perf: 为PPO和DQN添加GPU优化——AMP混合精度、pinned memory、torch.compile

- PPO (CW1_id_name): 添加 AMP GradScaler + autocast 混合精度训练,pinned memory 加速 CPU→GPU 传输,torch.compile JIT 编译支持,调整默认超参适配 RTX 5090
- DQN (Atari): 添加 AMP 混合精度、pinned memory 回放缓冲区、向量化批量添加经验 (add_batch) 和批量动作选择 (batch_select_actions),消除 Python 循环
- train_parallel.py: 重写为无缓冲脚本,集成所有优化,64 并行环境 + 每步 4 次训练更新

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-05 00:50:16 +08:00
parent ed0822966b
commit d5c9baffe6
7 changed files with 495 additions and 883 deletions
+41 -27
View File
@@ -10,6 +10,7 @@ of PPO* (Huang et al. 2022) and *RAD* (Laskin et al. 2020):
- Linear schedule for clip range (clip_init -> clip_floor)
- Random-shift data augmentation on observations during the update
- Linear annealing of learning rate and entropy coefficient with floors
- AMP mixed precision training for GPU acceleration
Public API:
- PPOAgent.act(obs) -> (action, log_prob, value)
@@ -48,11 +49,17 @@ class PPOAgent:
target_kl=None,
use_data_aug=False,
aug_pad=4,
use_amp=True,
):
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
dev = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.device = torch.device(dev) if isinstance(dev, str) else dev
self.net = ActorCritic(n_actions=n_actions).to(self.device)
self.optim = optim.Adam(self.net.parameters(), lr=lr, eps=1e-5)
# AMP 混合精度训练
self.use_amp = use_amp and self.device.type == 'cuda'
self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)
# Save initial values for scheduling
self.lr_init = lr
self.clip_init = clip
@@ -84,7 +91,8 @@ class PPOAgent:
def act_batch(self, obs_batch):
"""Vectorised act for n_envs obs at once."""
obs_t = torch.as_tensor(obs_batch, device=self.device)
action, log_prob, _, value = self.net.get_action_and_value(obs_t)
with torch.amp.autocast('cuda', enabled=self.use_amp):
action, log_prob, _, value = self.net.get_action_and_value(obs_t)
return (
action.cpu().numpy(),
log_prob.cpu().numpy(),
@@ -94,7 +102,8 @@ class PPOAgent:
@torch.no_grad()
def evaluate_value_batch(self, obs_batch):
obs_t = torch.as_tensor(obs_batch, device=self.device)
_, value = self.net(obs_t)
with torch.amp.autocast('cuda', enabled=self.use_amp):
_, value = self.net(obs_t)
return value.cpu().numpy()
def _random_shift(self, obs):
@@ -166,38 +175,43 @@ class PPOAgent:
if self.use_data_aug:
b_obs = self._random_shift(b_obs)
_, new_logp, entropy, value = self.net.get_action_and_value(
b_obs, b_actions
)
# AMP 前向传播
with torch.amp.autocast('cuda', enabled=self.use_amp):
_, new_logp, entropy, value = self.net.get_action_and_value(
b_obs, b_actions
)
log_ratio = new_logp - b_old_logp
ratio = log_ratio.exp()
log_ratio = new_logp - b_old_logp
ratio = log_ratio.exp()
# Clipped policy loss
surr1 = ratio * b_adv
surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
policy_loss = -torch.min(surr1, surr2).mean()
# Clipped policy loss
surr1 = ratio * b_adv
surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
policy_loss = -torch.min(surr1, surr2).mean()
# Clipped value loss (refinement #1, SB3 standard)
v_clipped = b_old_values + torch.clamp(
value - b_old_values, -self.clip, self.clip
)
v_loss_unclipped = (value - b_ret).pow(2)
v_loss_clipped = (v_clipped - b_ret).pow(2)
value_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
# Clipped value loss (refinement #1, SB3 standard)
v_clipped = b_old_values + torch.clamp(
value - b_old_values, -self.clip, self.clip
)
v_loss_unclipped = (value - b_ret).pow(2)
v_loss_clipped = (v_clipped - b_ret).pow(2)
value_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
entropy_loss = entropy.mean()
entropy_loss = entropy.mean()
loss = (
policy_loss
+ self.vf_coef * value_loss
- self.ent_coef * entropy_loss
)
loss = (
policy_loss
+ self.vf_coef * value_loss
- self.ent_coef * entropy_loss
)
# AMP 反向传播
self.optim.zero_grad()
loss.backward()
self.scaler.scale(loss).backward()
self.scaler.unscale_(self.optim)
nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm)
self.optim.step()
self.scaler.step(self.optim)
self.scaler.update()
with torch.no_grad():
approx_kl = ((ratio - 1) - log_ratio).mean().item()
+28 -6
View File
@@ -4,6 +4,8 @@ Uses CleanRL's indexing convention:
dones[t] flags whether obs[t] is the FIRST obs of a fresh episode
(i.e., the previous action terminated). GAE then uses dones[t+1]
as the mask for V(s_{t+1}) at time t.
Supports pinned memory for faster CPU→GPU transfer.
"""
import torch
@@ -16,6 +18,7 @@ class VecRolloutBuffer:
self.obs_shape = obs_shape
self.device = device
# 主存储在 GPU 上
self.obs = torch.zeros(
(n_steps, n_envs, *obs_shape), dtype=torch.uint8, device=device
)
@@ -28,16 +31,35 @@ class VecRolloutBuffer:
self.advantages = torch.zeros((n_steps, n_envs), dtype=torch.float32, device=device)
self.returns = torch.zeros((n_steps, n_envs), dtype=torch.float32, device=device)
# Pinned memory 缓冲区(加速 CPU→GPU 传输)
self._obs_pin = torch.zeros(
(n_steps, n_envs, *obs_shape), dtype=torch.uint8, pin_memory=True
)
self._actions_pin = torch.zeros((n_steps, n_envs), dtype=torch.long, pin_memory=True)
self._log_probs_pin = torch.zeros((n_steps, n_envs), dtype=torch.float32, pin_memory=True)
self._rewards_pin = torch.zeros((n_steps, n_envs), dtype=torch.float32, pin_memory=True)
self._values_pin = torch.zeros((n_steps, n_envs), dtype=torch.float32, pin_memory=True)
self._dones_pin = torch.zeros((n_steps, n_envs), dtype=torch.float32, pin_memory=True)
self.ptr = 0
def add(self, obs, action, log_prob, reward, value, done):
i = self.ptr
self.obs[i] = torch.as_tensor(obs, device=self.device)
self.actions[i] = torch.as_tensor(action, device=self.device, dtype=torch.long)
self.log_probs[i] = torch.as_tensor(log_prob, device=self.device, dtype=torch.float32)
self.rewards[i] = torch.as_tensor(reward, device=self.device, dtype=torch.float32)
self.values[i] = torch.as_tensor(value, device=self.device, dtype=torch.float32)
self.dones[i] = torch.as_tensor(done, device=self.device, dtype=torch.float32)
# 先写入 pinned memory,再 non-blocking 传输到 GPU
self._obs_pin[i] = torch.as_tensor(obs)
self._actions_pin[i] = torch.as_tensor(action, dtype=torch.long)
self._log_probs_pin[i] = torch.as_tensor(log_prob, dtype=torch.float32)
self._rewards_pin[i] = torch.as_tensor(reward, dtype=torch.float32)
self._values_pin[i] = torch.as_tensor(value, dtype=torch.float32)
self._dones_pin[i] = torch.as_tensor(done, dtype=torch.float32)
# non_blocking 传输到 GPU
self.obs[i] = self._obs_pin[i].to(self.device, non_blocking=True)
self.actions[i] = self._actions_pin[i].to(self.device, non_blocking=True)
self.log_probs[i] = self._log_probs_pin[i].to(self.device, non_blocking=True)
self.rewards[i] = self._rewards_pin[i].to(self.device, non_blocking=True)
self.values[i] = self._values_pin[i].to(self.device, non_blocking=True)
self.dones[i] = self._dones_pin[i].to(self.device, non_blocking=True)
self.ptr += 1
def compute_gae(self, last_value, last_done, gamma=0.99, lam=0.95):
+23 -5
View File
@@ -5,6 +5,11 @@ Usage (Windows):
python train_vec.py --n-envs 4 --total-steps 500000 --run-name vec_main \
--anneal-lr --anneal-ent --reward-clip 1.0
Usage (Linux server with RTX 5090):
python train_vec.py --n-envs 16 --total-steps 2000000 --run-name vec_main \
--n-steps 512 --batch-size 512 --n-epochs 10 \
--anneal-lr --anneal-ent --reward-clip 1.0 --use-amp
The ``if __name__ == "__main__"`` guard at the bottom is mandatory on
Windows for AsyncVectorEnv (otherwise child processes infinite-spawn).
"""
@@ -26,11 +31,11 @@ from src.vec_rollout_buffer import VecRolloutBuffer
def parse_args():
p = argparse.ArgumentParser()
p.add_argument("--total-steps", type=int, default=3_000_000)
p.add_argument("--n-envs", type=int, default=8)
p.add_argument("--n-steps", type=int, default=256)
p.add_argument("--n-epochs", type=int, default=6)
p.add_argument("--batch-size", type=int, default=128)
p.add_argument("--total-steps", type=int, default=2_000_000)
p.add_argument("--n-envs", type=int, default=16)
p.add_argument("--n-steps", type=int, default=512)
p.add_argument("--n-epochs", type=int, default=10)
p.add_argument("--batch-size", type=int, default=512)
p.add_argument("--lr", type=float, default=2.5e-4)
p.add_argument("--gamma", type=float, default=0.99)
p.add_argument("--lam", type=float, default=0.95)
@@ -56,6 +61,10 @@ def parse_args():
help="Apply random-shift augmentation to obs during PPO update")
p.add_argument("--sync-mode", action="store_true",
help="Use SyncVectorEnv (debug mode)")
p.add_argument("--use-amp", action="store_true",
help="Use AMP mixed precision training for GPU acceleration")
p.add_argument("--use-compile", action="store_true",
help="Use torch.compile for JIT compilation acceleration")
return p.parse_args()
@@ -94,7 +103,14 @@ def main():
clip_floor=args.clip_floor,
target_kl=args.target_kl,
use_data_aug=args.use_data_aug,
use_amp=args.use_amp,
)
# torch.compile JIT 编译加速
if args.use_compile and hasattr(torch, 'compile'):
print("应用 torch.compile 加速...")
agent.net = torch.compile(agent.net)
print("torch.compile 完成")
buffer = VecRolloutBuffer(
n_steps=args.n_steps,
n_envs=args.n_envs,
@@ -117,6 +133,8 @@ def main():
print(f"clip_floor={args.clip_floor} target_kl={args.target_kl} "
f"use_data_aug={args.use_data_aug}")
print(f"n_epochs={args.n_epochs} batch_size={args.batch_size}")
print(f"AMP: {args.use_amp}")
print(f"Compile: {args.use_compile}")
print(f"Device: {agent.device}")
print(f"Logs: {run_dir}")
print(f"Ckpts: {ckpt_dir}")