perf: 为PPO和DQN添加GPU优化——AMP混合精度、pinned memory、torch.compile

- PPO (CW1_id_name): 添加 AMP GradScaler + autocast 混合精度训练,pinned memory 加速 CPU→GPU 传输,torch.compile JIT 编译支持,调整默认超参适配 RTX 5090
- DQN (Atari): 添加 AMP 混合精度、pinned memory 回放缓冲区、向量化批量添加经验 (add_batch) 和批量动作选择 (batch_select_actions),消除 Python 循环
- train_parallel.py: 重写为无缓冲脚本,集成所有优化,64 并行环境 + 每步 4 次训练更新

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-05 00:50:16 +08:00
parent ed0822966b
commit d5c9baffe6
7 changed files with 495 additions and 883 deletions
+23 -5
View File
@@ -5,6 +5,11 @@ Usage (Windows):
python train_vec.py --n-envs 4 --total-steps 500000 --run-name vec_main \
--anneal-lr --anneal-ent --reward-clip 1.0
Usage (Linux server with RTX 5090):
python train_vec.py --n-envs 16 --total-steps 2000000 --run-name vec_main \
--n-steps 512 --batch-size 512 --n-epochs 10 \
--anneal-lr --anneal-ent --reward-clip 1.0 --use-amp
The ``if __name__ == "__main__"`` guard at the bottom is mandatory on
Windows for AsyncVectorEnv (otherwise child processes spawn endlessly).
"""
@@ -26,11 +31,11 @@ from src.vec_rollout_buffer import VecRolloutBuffer
def parse_args():
p = argparse.ArgumentParser()
p.add_argument("--total-steps", type=int, default=3_000_000)
p.add_argument("--n-envs", type=int, default=8)
p.add_argument("--n-steps", type=int, default=256)
p.add_argument("--n-epochs", type=int, default=6)
p.add_argument("--batch-size", type=int, default=128)
p.add_argument("--total-steps", type=int, default=2_000_000)
p.add_argument("--n-envs", type=int, default=16)
p.add_argument("--n-steps", type=int, default=512)
p.add_argument("--n-epochs", type=int, default=10)
p.add_argument("--batch-size", type=int, default=512)
p.add_argument("--lr", type=float, default=2.5e-4)
p.add_argument("--gamma", type=float, default=0.99)
p.add_argument("--lam", type=float, default=0.95)
@@ -56,6 +61,10 @@ def parse_args():
help="Apply random-shift augmentation to obs during PPO update")
p.add_argument("--sync-mode", action="store_true",
help="Use SyncVectorEnv (debug mode)")
p.add_argument("--use-amp", action="store_true",
help="Use AMP mixed precision training for GPU acceleration")
p.add_argument("--use-compile", action="store_true",
help="Use torch.compile for JIT compilation acceleration")
return p.parse_args()
@@ -94,7 +103,14 @@ def main():
clip_floor=args.clip_floor,
target_kl=args.target_kl,
use_data_aug=args.use_data_aug,
use_amp=args.use_amp,
)
# torch.compile JIT 编译加速
if args.use_compile and hasattr(torch, 'compile'):
print("应用 torch.compile 加速...")
agent.net = torch.compile(agent.net)
print("torch.compile 完成")
buffer = VecRolloutBuffer(
n_steps=args.n_steps,
n_envs=args.n_envs,
@@ -117,6 +133,8 @@ def main():
print(f"clip_floor={args.clip_floor} target_kl={args.target_kl} "
f"use_data_aug={args.use_data_aug}")
print(f"n_epochs={args.n_epochs} batch_size={args.batch_size}")
print(f"AMP: {args.use_amp}")
print(f"Compile: {args.use_compile}")
print(f"Device: {agent.device}")
print(f"Logs: {run_dir}")
print(f"Ckpts: {ckpt_dir}")