# Commit fb09e66d09:
# - Refactored the original single-environment training code into a modular
#   structure; added vectorised-environment support for faster data collection.
# - Implemented a full PPO training pipeline: shared-CNN actor-critic network,
#   vectorised rollout buffer, and GAE advantage estimation.
# - Added a training script (train_vec.py), an evaluation script (evaluate.py),
#   and an SB3 baseline comparison script (train_sb3_baseline.py).
# - Provided detailed documentation and a development log, including
#   problem-solving records and experiment analysis.
# - Removed legacy project files; unified the layout under CW1_id_name/.
"""Stable-Baselines3 PPO baseline for fair comparison.
|
|
|
|
We compare against SB3's default CNN PPO under the same observation
|
|
preprocessing (4-frame stacked grayscale 84x84) and similar core
|
|
hyperparameters. Note: SB3 has additional optimisations on by default
|
|
(orthogonal init, reward normalisation in some versions, vectorised env);
|
|
this baseline is intentionally a "production" reference, not a like-for-like
|
|
comparison. The report should discuss this honestly.
|
|
|
|
Usage:
|
|
python train_sb3_baseline.py --total-steps 500000 --run-name sb3_baseline
|
|
|
|
This script is ONLY for the evaluation/comparison phase. The main PPO
|
|
implementation in src/ uses no SB3 code.
|
|
"""
import argparse
from pathlib import Path

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

from src.env_wrappers import FrameStack, GrayScaleResize, SkipFrame


def _make_one(rank, seed):
|
|
def _init():
|
|
env = gym.make("CarRacing-v3", continuous=False)
|
|
env = SkipFrame(env, k=4)
|
|
env = GrayScaleResize(env, size=84)
|
|
env = FrameStack(env, k=4)
|
|
env.reset(seed=seed + rank)
|
|
return env
|
|
return _init
|
|
|
|
|
|
def parse_args():
    """Parse and return the command-line options for a baseline run."""
    parser = argparse.ArgumentParser()
    # (flag, type, default) triples keep the option table easy to scan.
    for flag, value_type, default in (
        ("--total-steps", int, 500_000),
        ("--n-envs", int, 4),
        ("--run-name", str, "sb3_baseline"),
        ("--seed", int, 42),
    ):
        parser.add_argument(flag, type=value_type, default=default)
    return parser.parse_args()
def main():
    """Train the SB3 PPO baseline end to end, checkpointing along the way."""
    opts = parse_args()
    root = Path(__file__).resolve().parent

    # Per-run output locations, created eagerly so SB3 can write into them.
    tb_dir = root / "runs" / opts.run_name
    save_dir = root / "models" / opts.run_name
    for directory in (tb_dir, save_dir):
        directory.mkdir(parents=True, exist_ok=True)

    banner = "=" * 60
    print(banner)
    print(f"SB3 PPO baseline: {opts.run_name}")
    print(f"Total steps: {opts.total_steps:,} n_envs: {opts.n_envs}")
    print(f"Logs: {tb_dir}")
    print(f"Ckpts: {save_dir}")
    print(banner)

    thunks = [_make_one(i, opts.seed) for i in range(opts.n_envs)]
    # Subprocess workers only pay off with more than one environment.
    vec_env = SubprocVecEnv(thunks) if opts.n_envs > 1 else DummyVecEnv(thunks)

    # Hyperparameters mirror our own PPO implementation as closely as possible.
    model = PPO(
        policy="CnnPolicy",
        env=vec_env,
        learning_rate=2.5e-4,
        n_steps=512,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        vf_coef=0.5,
        ent_coef=0.01,
        max_grad_norm=0.5,
        tensorboard_log=str(tb_dir),
        seed=opts.seed,
        verbose=1,
    )

    # SB3 counts save_freq per callback call (one per vectorised step), so
    # dividing by n_envs checkpoints roughly every 50k collected steps.
    checkpoint_cb = CheckpointCallback(
        save_freq=max(50_000 // opts.n_envs, 1),
        save_path=str(save_dir),
        name_prefix="sb3",
    )
    model.learn(
        total_timesteps=opts.total_steps,
        callback=checkpoint_cb,
        tb_log_name="run",
    )
    model.save(str(save_dir / "final.zip"))
    print(f"\nSaved final SB3 model to {save_dir / 'final.zip'}")

    vec_env.close()
if __name__ == "__main__":
|
|
main()
|