rl-atari/CW1_id_name/train_sb3_baseline.py
Serendipity fb09e66d09 feat: restructure the project and add vectorised PPO training and evaluation scripts
- Refactor the original single-environment training code into a modular structure and add vectorised-environment support to improve data-collection efficiency
- Implement the full PPO training pipeline, including a shared-CNN actor-critic network, a vectorised experience buffer, and GAE advantage estimation
- Add a training script (train_vec.py), an evaluation script (evaluate.py), and an SB3 baseline comparison script (train_sb3_baseline.py)
- Provide detailed documentation and a development log, including problem-resolution notes and experiment analysis
- Remove legacy project files and consolidate the project structure under the CW1_id_name directory
2026-05-02 13:44:08 +08:00


"""Stable-Baselines3 PPO baseline for fair comparison.
We compare against SB3's default CNN PPO under the same observation
preprocessing (4-frame stacked grayscale 84x84) and similar core
hyperparameters. Note: SB3 has additional optimisations on by default
(orthogonal init, reward normalisation in some versions, vectorised env);
this baseline is intentionally a "production" reference, not a like-for-like
comparison. The report should discuss this honestly.
Usage:
python train_sb3_baseline.py --total-steps 500000 --run-name sb3_baseline
This script is ONLY for the evaluation/comparison phase. The main PPO
implementation in src/ uses no SB3 code.
"""
import argparse
from pathlib import Path

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

from src.env_wrappers import FrameStack, GrayScaleResize, SkipFrame


def _make_one(rank, seed):
    """Build a thunk that creates one preprocessed CarRacing-v3 env."""

    def _init():
        env = gym.make("CarRacing-v3", continuous=False)
        env = SkipFrame(env, k=4)
        env = GrayScaleResize(env, size=84)
        env = FrameStack(env, k=4)
        env.reset(seed=seed + rank)
        return env

    return _init


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--total-steps", type=int, default=500_000)
    p.add_argument("--n-envs", type=int, default=4)
    p.add_argument("--run-name", type=str, default="sb3_baseline")
    p.add_argument("--seed", type=int, default=42)
    return p.parse_args()


def main():
    args = parse_args()

    project_root = Path(__file__).resolve().parent
    log_dir = project_root / "runs" / args.run_name
    ckpt_dir = project_root / "models" / args.run_name
    log_dir.mkdir(parents=True, exist_ok=True)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    print("=" * 60)
    print(f"SB3 PPO baseline: {args.run_name}")
    print(f"Total steps: {args.total_steps:,} n_envs: {args.n_envs}")
    print(f"Logs: {log_dir}")
    print(f"Ckpts: {ckpt_dir}")
    print("=" * 60)

    fns = [_make_one(i, args.seed) for i in range(args.n_envs)]
    vec_env = SubprocVecEnv(fns) if args.n_envs > 1 else DummyVecEnv(fns)

    # Match our hyperparameters as closely as possible
    model = PPO(
        policy="CnnPolicy",
        env=vec_env,
        learning_rate=2.5e-4,
        n_steps=512,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        vf_coef=0.5,
        ent_coef=0.01,
        max_grad_norm=0.5,
        tensorboard_log=str(log_dir),
        seed=args.seed,
        verbose=1,
    )

    # CheckpointCallback counts vec-env steps (n_envs transitions per call),
    # so divide by n_envs to keep roughly a 50k total-environment-step interval.
    ckpt_cb = CheckpointCallback(
        save_freq=max(50_000 // args.n_envs, 1),
        save_path=str(ckpt_dir),
        name_prefix="sb3",
    )

    model.learn(
        total_timesteps=args.total_steps,
        callback=ckpt_cb,
        tb_log_name="run",
    )

    model.save(str(ckpt_dir / "final.zip"))
    print(f"\nSaved final SB3 model to {ckpt_dir / 'final.zip'}")
    vec_env.close()


if __name__ == "__main__":
    main()
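

# Hedged usage note (illustrative only): after training, the checkpoint saved
# under models/<run-name>/final.zip can be scored with the rollout_baseline
# sketch above and compared against the custom PPO's evaluate.py, e.g.:
#
#     from train_sb3_baseline import rollout_baseline
#     returns = rollout_baseline("models/sb3_baseline/final.zip", episodes=5)
#     print(sum(returns) / len(returns))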