perf: add GPU optimisations to PPO and DQN (AMP mixed precision, pinned memory, torch.compile)

- PPO (CW1_id_name): add AMP GradScaler + autocast mixed-precision training, pinned memory to speed up CPU→GPU transfers, torch.compile JIT compilation support, and default hyperparameters adjusted for the RTX 5090
- DQN (Atari): add AMP mixed precision, a pinned-memory replay buffer, vectorised batch experience insertion (add_batch) and batch action selection (batch_select_actions) to eliminate Python loops (see the sketch below)
- train_parallel.py: rewritten as an unbuffered script that integrates all of the optimisations, with 64 parallel environments and 4 training updates per step

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
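Note: the DQN-side batching (add_batch, batch_select_actions) is not part of the hunks shown below. As a rough, hedged sketch of what replacing the per-env Python loop with a single batched forward pass could look like (q_net, the argument names, and the uint8 handling here are illustrative assumptions, not the repository's actual API):

import numpy as np
import torch

@torch.no_grad()
def batch_select_actions(q_net, obs_batch, epsilon, n_actions, device):
    # One forward pass for all parallel envs instead of a per-env Python loop.
    # Assumes q_net accepts the stacked uint8 frames directly (e.g. scales internally).
    obs_t = torch.as_tensor(np.asarray(obs_batch), device=device)
    greedy = q_net(obs_t).argmax(dim=1).cpu().numpy()
    # Vectorised epsilon-greedy: substitute random actions where exploration triggers.
    explore = np.random.rand(len(greedy)) < epsilon
    random_actions = np.random.randint(n_actions, size=len(greedy))
    return np.where(explore, random_actions, greedy)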
commit d5c9baffe6
parent ed0822966b
Date: 2026-05-05 00:50:16 +08:00
7 changed files with 495 additions and 883 deletions
+41 -27
@@ -10,6 +10,7 @@ of PPO* (Huang et al. 2022) and *RAD* (Laskin et al. 2020):
 - Linear schedule for clip range (clip_init -> clip_floor)
 - Random-shift data augmentation on observations during the update
 - Linear annealing of learning rate and entropy coefficient with floors
+- AMP mixed precision training for GPU acceleration
 
 Public API:
 - PPOAgent.act(obs) -> (action, log_prob, value)
@@ -48,11 +49,17 @@ class PPOAgent:
         target_kl=None,
         use_data_aug=False,
         aug_pad=4,
+        use_amp=True,
     ):
-        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        dev = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device(dev) if isinstance(dev, str) else dev
         self.net = ActorCritic(n_actions=n_actions).to(self.device)
         self.optim = optim.Adam(self.net.parameters(), lr=lr, eps=1e-5)
 
+        # AMP mixed-precision training
+        self.use_amp = use_amp and self.device.type == 'cuda'
+        self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)
+
         # Save initial values for scheduling
         self.lr_init = lr
         self.clip_init = clip
@@ -84,7 +91,8 @@ class PPOAgent:
     def act_batch(self, obs_batch):
         """Vectorised act for n_envs obs at once."""
         obs_t = torch.as_tensor(obs_batch, device=self.device)
-        action, log_prob, _, value = self.net.get_action_and_value(obs_t)
+        with torch.amp.autocast('cuda', enabled=self.use_amp):
+            action, log_prob, _, value = self.net.get_action_and_value(obs_t)
         return (
             action.cpu().numpy(),
             log_prob.cpu().numpy(),
@@ -94,7 +102,8 @@ class PPOAgent:
     @torch.no_grad()
     def evaluate_value_batch(self, obs_batch):
         obs_t = torch.as_tensor(obs_batch, device=self.device)
-        _, value = self.net(obs_t)
+        with torch.amp.autocast('cuda', enabled=self.use_amp):
+            _, value = self.net(obs_t)
         return value.cpu().numpy()
 
     def _random_shift(self, obs):
@@ -166,38 +175,43 @@ class PPOAgent:
                 if self.use_data_aug:
                     b_obs = self._random_shift(b_obs)
 
-                _, new_logp, entropy, value = self.net.get_action_and_value(
-                    b_obs, b_actions
-                )
+                # AMP forward pass
+                with torch.amp.autocast('cuda', enabled=self.use_amp):
+                    _, new_logp, entropy, value = self.net.get_action_and_value(
+                        b_obs, b_actions
+                    )
 
-                log_ratio = new_logp - b_old_logp
-                ratio = log_ratio.exp()
+                    log_ratio = new_logp - b_old_logp
+                    ratio = log_ratio.exp()
 
-                # Clipped policy loss
-                surr1 = ratio * b_adv
-                surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
-                policy_loss = -torch.min(surr1, surr2).mean()
+                    # Clipped policy loss
+                    surr1 = ratio * b_adv
+                    surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
+                    policy_loss = -torch.min(surr1, surr2).mean()
 
-                # Clipped value loss (refinement #1, SB3 standard)
-                v_clipped = b_old_values + torch.clamp(
-                    value - b_old_values, -self.clip, self.clip
-                )
-                v_loss_unclipped = (value - b_ret).pow(2)
-                v_loss_clipped = (v_clipped - b_ret).pow(2)
-                value_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
+                    # Clipped value loss (refinement #1, SB3 standard)
+                    v_clipped = b_old_values + torch.clamp(
+                        value - b_old_values, -self.clip, self.clip
+                    )
+                    v_loss_unclipped = (value - b_ret).pow(2)
+                    v_loss_clipped = (v_clipped - b_ret).pow(2)
+                    value_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
 
-                entropy_loss = entropy.mean()
+                    entropy_loss = entropy.mean()
 
-                loss = (
-                    policy_loss
-                    + self.vf_coef * value_loss
-                    - self.ent_coef * entropy_loss
-                )
+                    loss = (
+                        policy_loss
+                        + self.vf_coef * value_loss
+                        - self.ent_coef * entropy_loss
+                    )
 
+                # AMP backward pass
                 self.optim.zero_grad()
-                loss.backward()
+                self.scaler.scale(loss).backward()
+                self.scaler.unscale_(self.optim)
                 nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm)
-                self.optim.step()
+                self.scaler.step(self.optim)
+                self.scaler.update()
 
                 with torch.no_grad():
                     approx_kl = ((ratio - 1) - log_ratio).mean().item()
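The ordering above (scale, backward, unscale_, clip, step, update) follows the standard torch.amp recipe: gradients must be unscaled before clipping so the norm is measured in real units, and scaler.step skips the optimiser step when inf/NaN gradients are detected. A stripped-down sketch of the same pattern with generic placeholder names (net, optimizer, loss_fn, batch are not the project's API):

import torch
import torch.nn as nn

# scaler = torch.amp.GradScaler('cuda', enabled=use_amp)  # created once, reused every step

def amp_update_step(net, optimizer, scaler, loss_fn, batch,
                    use_amp=True, max_grad_norm=0.5):
    # Forward pass under autocast so eligible ops run in reduced precision.
    with torch.amp.autocast('cuda', enabled=use_amp):
        loss = loss_fn(net, batch)

    optimizer.zero_grad()
    # Scale the loss so small fp16 gradients do not underflow during backward.
    scaler.scale(loss).backward()
    # Unscale before clipping so the norm is computed on true gradient values.
    scaler.unscale_(optimizer)
    nn.utils.clip_grad_norm_(net.parameters(), max_grad_norm)
    scaler.step(optimizer)   # skipped internally if inf/NaN gradients are found
    scaler.update()          # adjusts the loss scale for the next step
    return loss.detach()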
+28 -6
@@ -4,6 +4,8 @@ Uses CleanRL's indexing convention:
     dones[t] flags whether obs[t] is the FIRST obs of a fresh episode
     (i.e., the previous action terminated). GAE then uses dones[t+1]
     as the mask for V(s_{t+1}) at time t.
+
+Supports pinned memory for faster CPU→GPU transfer.
 """
 
 import torch
@@ -16,6 +18,7 @@ class VecRolloutBuffer:
         self.obs_shape = obs_shape
         self.device = device
 
+        # Main storage lives on the GPU
         self.obs = torch.zeros(
             (n_steps, n_envs, *obs_shape), dtype=torch.uint8, device=device
         )
@@ -28,16 +31,35 @@ class VecRolloutBuffer:
         self.advantages = torch.zeros((n_steps, n_envs), dtype=torch.float32, device=device)
         self.returns = torch.zeros((n_steps, n_envs), dtype=torch.float32, device=device)
 
+        # Pinned-memory staging buffers (speed up CPU→GPU transfers)
+        self._obs_pin = torch.zeros(
+            (n_steps, n_envs, *obs_shape), dtype=torch.uint8, pin_memory=True
+        )
+        self._actions_pin = torch.zeros((n_steps, n_envs), dtype=torch.long, pin_memory=True)
+        self._log_probs_pin = torch.zeros((n_steps, n_envs), dtype=torch.float32, pin_memory=True)
+        self._rewards_pin = torch.zeros((n_steps, n_envs), dtype=torch.float32, pin_memory=True)
+        self._values_pin = torch.zeros((n_steps, n_envs), dtype=torch.float32, pin_memory=True)
+        self._dones_pin = torch.zeros((n_steps, n_envs), dtype=torch.float32, pin_memory=True)
+
         self.ptr = 0
 
     def add(self, obs, action, log_prob, reward, value, done):
         i = self.ptr
-        self.obs[i] = torch.as_tensor(obs, device=self.device)
-        self.actions[i] = torch.as_tensor(action, device=self.device, dtype=torch.long)
-        self.log_probs[i] = torch.as_tensor(log_prob, device=self.device, dtype=torch.float32)
-        self.rewards[i] = torch.as_tensor(reward, device=self.device, dtype=torch.float32)
-        self.values[i] = torch.as_tensor(value, device=self.device, dtype=torch.float32)
-        self.dones[i] = torch.as_tensor(done, device=self.device, dtype=torch.float32)
+        # Write into pinned memory first, then transfer to the GPU
+        self._obs_pin[i] = torch.as_tensor(obs)
+        self._actions_pin[i] = torch.as_tensor(action, dtype=torch.long)
+        self._log_probs_pin[i] = torch.as_tensor(log_prob, dtype=torch.float32)
+        self._rewards_pin[i] = torch.as_tensor(reward, dtype=torch.float32)
+        self._values_pin[i] = torch.as_tensor(value, dtype=torch.float32)
+        self._dones_pin[i] = torch.as_tensor(done, dtype=torch.float32)
+
+        # non_blocking transfer to the GPU
+        self.obs[i] = self._obs_pin[i].to(self.device, non_blocking=True)
+        self.actions[i] = self._actions_pin[i].to(self.device, non_blocking=True)
+        self.log_probs[i] = self._log_probs_pin[i].to(self.device, non_blocking=True)
+        self.rewards[i] = self._rewards_pin[i].to(self.device, non_blocking=True)
+        self.values[i] = self._values_pin[i].to(self.device, non_blocking=True)
+        self.dones[i] = self._dones_pin[i].to(self.device, non_blocking=True)
         self.ptr += 1
 
     def compute_gae(self, last_value, last_done, gamma=0.99, lam=0.95):
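The add() path above stages each step in page-locked host memory so the copy into the GPU slices can be issued with non_blocking=True. A minimal, self-contained illustration of that staging pattern (the shapes, dtypes and the stage_step helper are arbitrary examples, not the buffer's API):

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Destination lives on the GPU; the staging buffer is pinned (page-locked) host memory.
gpu_store = torch.zeros((16, 8, 4, 84, 84), dtype=torch.uint8, device=device)
pin_store = torch.zeros((16, 8, 4, 84, 84), dtype=torch.uint8,
                        pin_memory=torch.cuda.is_available())

def stage_step(i, obs_np):
    # Host-to-host copy into the pinned slot...
    pin_store[i] = torch.as_tensor(obs_np)
    # ...then an asynchronous copy to the GPU. non_blocking=True only overlaps
    # with other work when the source is pinned and the target is a CUDA tensor;
    # on a CPU-only machine it degrades to an ordinary synchronous copy.
    gpu_store[i] = pin_store[i].to(device, non_blocking=True)

Because later kernels that read the GPU tensors are enqueued on the same (default) CUDA stream, they run after the copies complete, so no explicit synchronisation is needed before compute_gae consumes the data.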