perf: 为PPO和DQN添加GPU优化——AMP混合精度、pinned memory、torch.compile
- PPO (CW1_id_name): 添加 AMP GradScaler + autocast 混合精度训练,pinned memory 加速 CPU→GPU 传输,torch.compile JIT 编译支持,调整默认超参适配 RTX 5090 - DQN (Atari): 添加 AMP 混合精度、pinned memory 回放缓冲区、向量化批量添加经验 (add_batch) 和批量动作选择 (batch_select_actions),消除 Python 循环 - train_parallel.py: 重写为无缓冲脚本,集成所有优化,64 并行环境 + 每步 4 次训练更新 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -5,6 +5,11 @@ Usage (Windows):
|
||||
python train_vec.py --n-envs 4 --total-steps 500000 --run-name vec_main \
|
||||
--anneal-lr --anneal-ent --reward-clip 1.0
|
||||
|
||||
Usage (Linux server with RTX 5090):
|
||||
python train_vec.py --n-envs 16 --total-steps 2000000 --run-name vec_main \
|
||||
--n-steps 512 --batch-size 512 --n-epochs 10 \
|
||||
--anneal-lr --anneal-ent --reward-clip 1.0 --use-amp
|
||||
|
||||
The ``if __name__ == "__main__"`` guard at the bottom is mandatory on
|
||||
Windows for AsyncVectorEnv (otherwise child processes infinite-spawn).
|
||||
"""
|
||||
@@ -26,11 +31,11 @@ from src.vec_rollout_buffer import VecRolloutBuffer
|
||||
|
||||
def parse_args():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--total-steps", type=int, default=3_000_000)
|
||||
p.add_argument("--n-envs", type=int, default=8)
|
||||
p.add_argument("--n-steps", type=int, default=256)
|
||||
p.add_argument("--n-epochs", type=int, default=6)
|
||||
p.add_argument("--batch-size", type=int, default=128)
|
||||
p.add_argument("--total-steps", type=int, default=2_000_000)
|
||||
p.add_argument("--n-envs", type=int, default=16)
|
||||
p.add_argument("--n-steps", type=int, default=512)
|
||||
p.add_argument("--n-epochs", type=int, default=10)
|
||||
p.add_argument("--batch-size", type=int, default=512)
|
||||
p.add_argument("--lr", type=float, default=2.5e-4)
|
||||
p.add_argument("--gamma", type=float, default=0.99)
|
||||
p.add_argument("--lam", type=float, default=0.95)
|
||||
@@ -56,6 +61,10 @@ def parse_args():
|
||||
help="Apply random-shift augmentation to obs during PPO update")
|
||||
p.add_argument("--sync-mode", action="store_true",
|
||||
help="Use SyncVectorEnv (debug mode)")
|
||||
p.add_argument("--use-amp", action="store_true",
|
||||
help="Use AMP mixed precision training for GPU acceleration")
|
||||
p.add_argument("--use-compile", action="store_true",
|
||||
help="Use torch.compile for JIT compilation acceleration")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
@@ -94,7 +103,14 @@ def main():
|
||||
clip_floor=args.clip_floor,
|
||||
target_kl=args.target_kl,
|
||||
use_data_aug=args.use_data_aug,
|
||||
use_amp=args.use_amp,
|
||||
)
|
||||
|
||||
# torch.compile JIT 编译加速
|
||||
if args.use_compile and hasattr(torch, 'compile'):
|
||||
print("应用 torch.compile 加速...")
|
||||
agent.net = torch.compile(agent.net)
|
||||
print("torch.compile 完成")
|
||||
buffer = VecRolloutBuffer(
|
||||
n_steps=args.n_steps,
|
||||
n_envs=args.n_envs,
|
||||
@@ -117,6 +133,8 @@ def main():
|
||||
print(f"clip_floor={args.clip_floor} target_kl={args.target_kl} "
|
||||
f"use_data_aug={args.use_data_aug}")
|
||||
print(f"n_epochs={args.n_epochs} batch_size={args.batch_size}")
|
||||
print(f"AMP: {args.use_amp}")
|
||||
print(f"Compile: {args.use_compile}")
|
||||
print(f"Device: {agent.device}")
|
||||
print(f"Logs: {run_dir}")
|
||||
print(f"Ckpts: {ckpt_dir}")
|
||||
|
||||
Reference in New Issue
Block a user