perf: 为PPO和DQN添加GPU优化——AMP混合精度、pinned memory、torch.compile

- PPO (CW1_id_name): 添加 AMP GradScaler + autocast 混合精度训练,pinned memory 加速 CPU→GPU 传输,torch.compile JIT 编译支持,调整默认超参适配 RTX 5090
- DQN (Atari): 添加 AMP 混合精度、pinned memory 回放缓冲区、向量化批量添加经验 (add_batch) 和批量动作选择 (batch_select_actions),消除 Python 循环
- train_parallel.py: 重写为无缓冲脚本,集成所有优化,64 并行环境 + 每步 4 次训练更新

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-05 00:50:16 +08:00
parent ed0822966b
commit d5c9baffe6
7 changed files with 495 additions and 883 deletions
+41 -27
View File
@@ -10,6 +10,7 @@ of PPO* (Huang et al. 2022) and *RAD* (Laskin et al. 2020):
- Linear schedule for clip range (clip_init -> clip_floor)
- Random-shift data augmentation on observations during the update
- Linear annealing of learning rate and entropy coefficient with floors
- AMP mixed precision training for GPU acceleration
Public API:
- PPOAgent.act(obs) -> (action, log_prob, value)
@@ -48,11 +49,17 @@ class PPOAgent:
target_kl=None,
use_data_aug=False,
aug_pad=4,
use_amp=True,
):
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
dev = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.device = torch.device(dev) if isinstance(dev, str) else dev
self.net = ActorCritic(n_actions=n_actions).to(self.device)
self.optim = optim.Adam(self.net.parameters(), lr=lr, eps=1e-5)
# AMP 混合精度训练
self.use_amp = use_amp and self.device.type == 'cuda'
self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)
# Save initial values for scheduling
self.lr_init = lr
self.clip_init = clip
@@ -84,7 +91,8 @@ class PPOAgent:
def act_batch(self, obs_batch):
"""Vectorised act for n_envs obs at once."""
obs_t = torch.as_tensor(obs_batch, device=self.device)
action, log_prob, _, value = self.net.get_action_and_value(obs_t)
with torch.amp.autocast('cuda', enabled=self.use_amp):
action, log_prob, _, value = self.net.get_action_and_value(obs_t)
return (
action.cpu().numpy(),
log_prob.cpu().numpy(),
@@ -94,7 +102,8 @@ class PPOAgent:
@torch.no_grad()
def evaluate_value_batch(self, obs_batch):
obs_t = torch.as_tensor(obs_batch, device=self.device)
_, value = self.net(obs_t)
with torch.amp.autocast('cuda', enabled=self.use_amp):
_, value = self.net(obs_t)
return value.cpu().numpy()
def _random_shift(self, obs):
@@ -166,38 +175,43 @@ class PPOAgent:
if self.use_data_aug:
b_obs = self._random_shift(b_obs)
_, new_logp, entropy, value = self.net.get_action_and_value(
b_obs, b_actions
)
# AMP 前向传播
with torch.amp.autocast('cuda', enabled=self.use_amp):
_, new_logp, entropy, value = self.net.get_action_and_value(
b_obs, b_actions
)
log_ratio = new_logp - b_old_logp
ratio = log_ratio.exp()
log_ratio = new_logp - b_old_logp
ratio = log_ratio.exp()
# Clipped policy loss
surr1 = ratio * b_adv
surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
policy_loss = -torch.min(surr1, surr2).mean()
# Clipped policy loss
surr1 = ratio * b_adv
surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
policy_loss = -torch.min(surr1, surr2).mean()
# Clipped value loss (refinement #1, SB3 standard)
v_clipped = b_old_values + torch.clamp(
value - b_old_values, -self.clip, self.clip
)
v_loss_unclipped = (value - b_ret).pow(2)
v_loss_clipped = (v_clipped - b_ret).pow(2)
value_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
# Clipped value loss (refinement #1, SB3 standard)
v_clipped = b_old_values + torch.clamp(
value - b_old_values, -self.clip, self.clip
)
v_loss_unclipped = (value - b_ret).pow(2)
v_loss_clipped = (v_clipped - b_ret).pow(2)
value_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
entropy_loss = entropy.mean()
entropy_loss = entropy.mean()
loss = (
policy_loss
+ self.vf_coef * value_loss
- self.ent_coef * entropy_loss
)
loss = (
policy_loss
+ self.vf_coef * value_loss
- self.ent_coef * entropy_loss
)
# AMP 反向传播
self.optim.zero_grad()
loss.backward()
self.scaler.scale(loss).backward()
self.scaler.unscale_(self.optim)
nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm)
self.optim.step()
self.scaler.step(self.optim)
self.scaler.update()
with torch.no_grad():
approx_kl = ((ratio - 1) - log_ratio).mean().item()