"""Stable-Baselines3 PPO baseline for fair comparison.

We compare against SB3's default CNN PPO under the same observation
preprocessing (4-frame stacked grayscale 84x84) and similar core
hyperparameters.

Note: SB3 has additional optimisations on by default (orthogonal init,
reward normalisation in some versions, vectorised env); this baseline is
intentionally a "production" reference, not a like-for-like comparison.
The report should discuss this honestly.

Usage:
    python train_sb3_baseline.py --total-steps 500000 --run-name sb3_baseline

This script is ONLY for the evaluation/comparison phase. The main PPO
implementation in src/ uses no SB3 code.
"""

import argparse
from pathlib import Path
from typing import Callable, Optional, Sequence

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

from src.env_wrappers import FrameStack, GrayScaleResize, SkipFrame


def _make_one(rank: int, seed: int) -> Callable[[], gym.Env]:
    """Return a zero-argument factory that builds one wrapped environment.

    The vectorised-env constructors take a list of callables rather than
    environment instances, so construction is deferred to the returned
    function.

    Args:
        rank: Index of this environment within the vectorised env.
        seed: Base seed; the environment is reset with ``seed + rank`` so
            that each copy receives a different seed.

    Returns:
        A callable that creates the discrete-action ``CarRacing-v3``
        environment wrapped with ``SkipFrame(k=4)``,
        ``GrayScaleResize(size=84)`` and ``FrameStack(k=4)``.
    """

    def _init() -> gym.Env:
        env = gym.make("CarRacing-v3", continuous=False)
        env = SkipFrame(env, k=4)
        env = GrayScaleResize(env, size=84)
        env = FrameStack(env, k=4)
        env.reset(seed=seed + rank)
        return env

    return _init


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for the baseline run.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        Namespace with ``total_steps``, ``n_envs``, ``run_name`` and
        ``seed``.

    Raises:
        SystemExit: Raised by argparse on invalid input, including
            ``--n-envs`` values below 1.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--total-steps", type=int, default=500_000)
    p.add_argument("--n-envs", type=int, default=4)
    p.add_argument("--run-name", type=str, default="sb3_baseline")
    p.add_argument("--seed", type=int, default=42)
    args = p.parse_args(argv)

    # An empty factory list would otherwise fail later, inside the vec-env
    # constructor, with an error unrelated to the offending flag.
    if args.n_envs < 1:
        p.error("--n-envs must be >= 1")
    return args


def main() -> None:
    """Train the SB3 PPO baseline and save checkpoints plus a final model.

    TensorBoard logs go to ``runs/<run-name>`` and checkpoints to
    ``models/<run-name>``, both relative to this script's directory.
    """
    args = parse_args()

    project_root = Path(__file__).resolve().parent
    log_dir = project_root / "runs" / args.run_name
    ckpt_dir = project_root / "models" / args.run_name
    log_dir.mkdir(parents=True, exist_ok=True)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    final_path = ckpt_dir / "final.zip"

    print("=" * 60)
    print(f"SB3 PPO baseline: {args.run_name}")
    print(f"Total steps: {args.total_steps:,} n_envs: {args.n_envs}")
    print(f"Logs: {log_dir}")
    print(f"Ckpts: {ckpt_dir}")
    print("=" * 60)

    fns = [_make_one(i, args.seed) for i in range(args.n_envs)]
    vec_env = SubprocVecEnv(fns) if args.n_envs > 1 else DummyVecEnv(fns)

    # Close the vectorised env even if model construction, training or
    # saving fails (or is interrupted), so SubprocVecEnv workers are not
    # left running.
    try:
        # Match our hyperparameters as closely as possible
        model = PPO(
            policy="CnnPolicy",
            env=vec_env,
            learning_rate=2.5e-4,
            n_steps=512,
            batch_size=64,
            n_epochs=10,
            gamma=0.99,
            gae_lambda=0.95,
            clip_range=0.2,
            vf_coef=0.5,
            ent_coef=0.01,
            max_grad_norm=0.5,
            tensorboard_log=str(log_dir),
            seed=args.seed,
            verbose=1,
        )

        # The save frequency is divided by n_envs so that a checkpoint is
        # written roughly every 50k environment steps in total; max(..., 1)
        # guards against a zero frequency for very large n_envs.
        ckpt_cb = CheckpointCallback(
            save_freq=max(50_000 // args.n_envs, 1),
            save_path=str(ckpt_dir),
            name_prefix="sb3",
        )

        model.learn(
            total_timesteps=args.total_steps,
            callback=ckpt_cb,
            tb_log_name="run",
        )

        model.save(str(final_path))
        print(f"\nSaved final SB3 model to {final_path}")
    finally:
        vec_env.close()


if __name__ == "__main__":
    main()