rl-atari/CW1_id_name/train_vec.py
Serendipity d5c9baffe6 perf: add GPU optimisations for PPO and DQN (AMP mixed precision, pinned memory, torch.compile)
- PPO (CW1_id_name): add AMP GradScaler + autocast mixed-precision training, pinned memory to speed up CPU→GPU transfers, optional torch.compile JIT compilation, and default hyperparameters retuned for an RTX 5090 (AMP pattern sketched below)
- DQN (Atari): add AMP mixed precision, a pinned-memory replay buffer, vectorised batch experience insertion (add_batch), and batched action selection (batch_select_actions) to eliminate Python loops (see the second sketch below)
- train_parallel.py: rewritten as an unbuffered script integrating all of the optimisations, running 64 parallel environments with 4 training updates per step

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-05 00:50:16 +08:00
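
For reference, the AMP change described above follows PyTorch's standard
autocast + GradScaler recipe. A minimal sketch, assuming a CUDA device is
available; net, optimizer, and compute_loss are hypothetical placeholders,
not this repo's actual API:

import torch

scaler = torch.cuda.amp.GradScaler()

def train_step(net, optimizer, compute_loss, batch):
    optimizer.zero_grad(set_to_none=True)
    # autocast runs eligible ops in float16 and keeps the rest in float32.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = compute_loss(net, batch)
    # Scale the loss so float16 gradients do not underflow; unscale before
    # clipping so the threshold applies to the true gradient norms.
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.5)
    scaler.step(optimizer)  # skipped automatically if gradients contain inf/NaN
    scaler.update()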
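
The pinned-memory idea likewise reduces to allocating host-side tensors with
pin_memory=True so that CPU→GPU copies can run asynchronously. A rough sketch
under the same caveat: PinnedReplayBuffer and its fields are illustrative
names (assuming uint8 stacked-frame observations), not the DQN code this
commit touches:

import torch

class PinnedReplayBuffer:
    def __init__(self, capacity, batch_size, obs_shape=(4, 84, 84)):
        self.obs = torch.empty((capacity, *obs_shape), dtype=torch.uint8)
        # Pinned (page-locked) staging buffer: non_blocking H2D copies only
        # overlap with compute when the source memory is pinned.
        self.staging = torch.empty((batch_size, *obs_shape), dtype=torch.uint8,
                                   pin_memory=True)
        self.capacity, self.pos = capacity, 0

    def add_batch(self, obs_batch):
        # One vectorised write for all n_envs transitions; no Python loop.
        n = obs_batch.shape[0]
        idx = (self.pos + torch.arange(n)) % self.capacity
        self.obs[idx] = torch.as_tensor(obs_batch)
        self.pos = (self.pos + n) % self.capacity

    def sample(self, device):
        idx = torch.randint(0, self.capacity, (self.staging.shape[0],))
        # Gather into the pinned staging buffer, then launch an async copy.
        torch.index_select(self.obs, 0, idx, out=self.staging)
        return self.staging.to(device, non_blocking=True)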


"""Train PPO on CarRacing-v3 with vectorised envs (parallel rollout).
Usage (Windows):
python train_vec.py --n-envs 4 --total-steps 10000 --run-name vec_smoke
python train_vec.py --n-envs 4 --total-steps 500000 --run-name vec_main \
--anneal-lr --anneal-ent --reward-clip 1.0
Usage (Linux server with RTX 5090):
python train_vec.py --n-envs 16 --total-steps 2000000 --run-name vec_main \
--n-steps 512 --batch-size 512 --n-epochs 10 \
--anneal-lr --anneal-ent --reward-clip 1.0 --use-amp
The ``if __name__ == "__main__"`` guard at the bottom is mandatory on
Windows for AsyncVectorEnv (otherwise child processes infinite-spawn).
"""
import argparse
import time
from collections import deque
from pathlib import Path

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from src.ppo_agent import PPOAgent
from src.utils import format_seconds, set_seed
from src.vec_env_wrappers import make_vec_env
from src.vec_rollout_buffer import VecRolloutBuffer


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--total-steps", type=int, default=2_000_000)
    p.add_argument("--n-envs", type=int, default=16)
    p.add_argument("--n-steps", type=int, default=512)
    p.add_argument("--n-epochs", type=int, default=10)
    p.add_argument("--batch-size", type=int, default=512)
    p.add_argument("--lr", type=float, default=2.5e-4)
    p.add_argument("--gamma", type=float, default=0.99)
    p.add_argument("--lam", type=float, default=0.95)
    p.add_argument("--clip", type=float, default=0.2)
    p.add_argument("--ent-coef", type=float, default=0.01)
    p.add_argument("--vf-coef", type=float, default=0.5)
    p.add_argument("--max-grad-norm", type=float, default=0.5)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--run-name", type=str, default="ppo_vec_main")
    p.add_argument("--save-every-iters", type=int, default=20)
    p.add_argument("--anneal-lr", action="store_true")
    p.add_argument("--anneal-ent", action="store_true")
    p.add_argument("--reward-clip", type=float, default=None)
    p.add_argument("--ent-floor", type=float, default=0.0,
                   help="Lower bound on ent_coef when --anneal-ent is on")
    p.add_argument("--clip-floor", type=float, default=None,
                   help="Linearly anneal clip range to this floor (e.g. 0.05). "
                        "None disables clip annealing.")
    p.add_argument("--target-kl", type=float, default=None,
                   help="Stop the current update epoch early if mean approx_kl "
                        "exceeds 1.5 * target_kl. None disables. SB3 default 0.015.")
    p.add_argument("--use-data-aug", action="store_true",
                   help="Apply random-shift augmentation to obs during PPO update")
    p.add_argument("--sync-mode", action="store_true",
                   help="Use SyncVectorEnv (debug mode)")
    p.add_argument("--use-amp", action="store_true",
                   help="Use AMP mixed precision training for GPU acceleration")
    p.add_argument("--use-compile", action="store_true",
                   help="Use torch.compile for JIT compilation acceleration")
    return p.parse_args()


def main():
    args = parse_args()
    project_root = Path(__file__).resolve().parent
    run_dir = project_root / "runs" / args.run_name
    ckpt_dir = project_root / "models" / args.run_name
    run_dir.mkdir(parents=True, exist_ok=True)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    set_seed(args.seed)
    # Throughput tweak: let cuDNN auto-pick the fastest conv algorithm
    # for our fixed (B, 4, 84, 84) input shape.
    torch.backends.cudnn.benchmark = True

    vec_env = make_vec_env(
        n_envs=args.n_envs,
        seed=args.seed,
        async_mode=not args.sync_mode,
    )
    agent = PPOAgent(
        n_actions=5,
        lr=args.lr,
        clip=args.clip,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        max_grad_norm=args.max_grad_norm,
        n_epochs=args.n_epochs,
        batch_size=args.batch_size,
        anneal_lr=args.anneal_lr,
        anneal_ent=args.anneal_ent,
        clip_floor=args.clip_floor,
        target_kl=args.target_kl,
        use_data_aug=args.use_data_aug,
        use_amp=args.use_amp,
    )
    # Optional torch.compile JIT compilation for extra speed.
    if args.use_compile and hasattr(torch, "compile"):
        print("Applying torch.compile...")
        agent.net = torch.compile(agent.net)
        print("torch.compile done")
    buffer = VecRolloutBuffer(
        n_steps=args.n_steps,
        n_envs=args.n_envs,
        obs_shape=(4, 84, 84),
        device=agent.device,
    )
    writer = SummaryWriter(str(run_dir))
    samples_per_iter = args.n_steps * args.n_envs

    print("=" * 60)
    print(f"Run: {args.run_name}")
    mode_str = "sync" if args.sync_mode else "async"
    print(f"Mode: {mode_str} vec env, n_envs={args.n_envs}")
    print(f"Total steps: {args.total_steps:,}")
    print(f"Per-iter samples: {samples_per_iter} (n_steps={args.n_steps} x n_envs={args.n_envs})")
    print(f"lr={args.lr} gamma={args.gamma} lam={args.lam} clip={args.clip}")
    print(f"anneal_lr={args.anneal_lr} anneal_ent={args.anneal_ent} "
          f"ent_floor={args.ent_floor} reward_clip={args.reward_clip}")
    print(f"clip_floor={args.clip_floor} target_kl={args.target_kl} "
          f"use_data_aug={args.use_data_aug}")
    print(f"n_epochs={args.n_epochs} batch_size={args.batch_size}")
    print(f"AMP: {args.use_amp}")
    print(f"Compile: {args.use_compile}")
    print(f"Device: {agent.device}")
    print(f"Logs: {run_dir}")
    print(f"Ckpts: {ckpt_dir}")
    print("=" * 60)

    obs, _ = vec_env.reset(seed=args.seed)
    next_done = np.zeros(args.n_envs, dtype=np.float32)
    global_step = 0
    iteration = 0
    episode_returns = deque(maxlen=100)
    cur_ep_returns = np.zeros(args.n_envs, dtype=np.float32)
    cur_ep_lens = np.zeros(args.n_envs, dtype=np.int64)
    start_time = time.time()

    while global_step < args.total_steps:
        iteration += 1
        agent.step_schedule(
            global_step / args.total_steps,
            ent_floor=args.ent_floor,
        )
        # Rollout (n_steps per env, total samples = n_steps * n_envs)
        for step in range(args.n_steps):
            actions, log_probs, values = agent.act_batch(obs)
            next_obs, rewards, terms, truncs, _ = vec_env.step(actions)
            done = np.logical_or(terms, truncs).astype(np.float32)
            train_rewards = (
                np.maximum(rewards, -args.reward_clip)
                if args.reward_clip is not None
                else rewards
            )
            # Use CleanRL convention: dones[step] = was obs[step] a fresh start
            buffer.add(obs, actions, log_probs, train_rewards, values, next_done)
            cur_ep_returns += rewards
            cur_ep_lens += 1
            for i in range(args.n_envs):
                if done[i]:
                    episode_returns.append(float(cur_ep_returns[i]))
                    writer.add_scalar("episode/return", cur_ep_returns[i], global_step)
                    writer.add_scalar("episode/length", cur_ep_lens[i], global_step)
                    cur_ep_returns[i] = 0.0
                    cur_ep_lens[i] = 0
            obs = next_obs
            next_done = done
            global_step += args.n_envs

        # GAE
        last_value = agent.evaluate_value_batch(obs)
        buffer.compute_gae(
            last_value=last_value,
            last_done=next_done,
            gamma=args.gamma,
            lam=args.lam,
        )

        # Update
        losses = agent.update_vec(buffer)
        for k, v in losses.items():
            writer.add_scalar(f"losses/{k}", v, global_step)

        elapsed = time.time() - start_time
        steps_per_sec = global_step / max(elapsed, 1e-6)
        avg_ret = sum(episode_returns) / len(episode_returns) if episode_returns else 0.0
        writer.add_scalar("perf/steps_per_sec", steps_per_sec, global_step)
        writer.add_scalar("episode/avg_return_100", avg_ret, global_step)
        writer.add_scalar("hp/lr", agent.optim.param_groups[0]["lr"], global_step)
        writer.add_scalar("hp/ent_coef", agent.ent_coef, global_step)
        writer.add_scalar("hp/clip", agent.clip, global_step)
        epochs_done = int(losses.get("epochs_completed", args.n_epochs))
        early = losses.get("early_stopped", 0.0) > 0.5
        mark = "*" if early else " "
        print(
            f"iter {iteration:4d} | step {global_step:>9,} | "
            f"avg_ret(100) {avg_ret:7.2f} | "
            f"pg {losses['policy_loss']:+.4f} | "
            f"v {losses['value_loss']:7.3f} | "
            f"ent {losses['entropy']:.3f} | "
            f"kl {losses['approx_kl']:.4f} | "
            f"clip {agent.clip:.3f} | "
            f"clip% {losses['clip_frac']:.2%} | "
            f"ep {epochs_done}{mark}/{args.n_epochs} | "
            f"sps {steps_per_sec:5.0f} | "
            f"{format_seconds(elapsed)}"
        )
        if iteration % args.save_every_iters == 0:
            agent.save(str(ckpt_dir / f"iter_{iteration:04d}.pt"))
        buffer.reset()

    final_path = ckpt_dir / "final.pt"
    agent.save(str(final_path))
    print(f"\nTraining done. Final model: {final_path}")
    writer.close()
    vec_env.close()


if __name__ == "__main__":
    main()