perf: 为PPO和DQN添加GPU优化——AMP混合精度、pinned memory、torch.compile
- PPO (CW1_id_name): 添加 AMP GradScaler + autocast 混合精度训练,pinned memory 加速 CPU→GPU 传输,torch.compile JIT 编译支持,调整默认超参适配 RTX 5090
- DQN (Atari): 添加 AMP 混合精度、pinned memory 回放缓冲区、向量化批量添加经验 (add_batch) 和批量动作选择 (batch_select_actions),消除 Python 循环
- train_parallel.py: 重写为无缓冲脚本,集成所有优化,64 并行环境 + 每步 4 次训练更新

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -10,6 +10,7 @@ of PPO* (Huang et al. 2022) and *RAD* (Laskin et al. 2020):
|
||||
- Linear schedule for clip range (clip_init -> clip_floor)
|
||||
- Random-shift data augmentation on observations during the update
|
||||
- Linear annealing of learning rate and entropy coefficient with floors
|
||||
- AMP mixed precision training for GPU acceleration
|
||||
|
||||
Public API:
|
||||
- PPOAgent.act(obs) -> (action, log_prob, value)
|
||||
@@ -48,11 +49,17 @@ class PPOAgent:
|
||||
target_kl=None,
|
||||
use_data_aug=False,
|
||||
aug_pad=4,
|
||||
use_amp=True,
|
||||
):
|
||||
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dev = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device(dev) if isinstance(dev, str) else dev
|
||||
self.net = ActorCritic(n_actions=n_actions).to(self.device)
|
||||
self.optim = optim.Adam(self.net.parameters(), lr=lr, eps=1e-5)
|
||||
|
||||
# AMP 混合精度训练
|
||||
self.use_amp = use_amp and self.device.type == 'cuda'
|
||||
self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)
|
||||
|
||||
# Save initial values for scheduling
|
||||
self.lr_init = lr
|
||||
self.clip_init = clip
|
||||
@@ -84,7 +91,8 @@ class PPOAgent:
|
||||
def act_batch(self, obs_batch):
|
||||
"""Vectorised act for n_envs obs at once."""
|
||||
obs_t = torch.as_tensor(obs_batch, device=self.device)
|
||||
action, log_prob, _, value = self.net.get_action_and_value(obs_t)
|
||||
with torch.amp.autocast('cuda', enabled=self.use_amp):
|
||||
action, log_prob, _, value = self.net.get_action_and_value(obs_t)
|
||||
return (
|
||||
action.cpu().numpy(),
|
||||
log_prob.cpu().numpy(),
|
||||
@@ -94,7 +102,8 @@ class PPOAgent:
|
||||
@torch.no_grad()
def evaluate_value_batch(self, obs_batch):
    """Return the critic's value estimates for a batch of observations.

    Args:
        obs_batch: array-like batch of observations convertible via
            ``torch.as_tensor`` — presumably shaped (n_envs, *obs_shape);
            TODO confirm against the rollout caller.

    Returns:
        numpy.ndarray of per-observation value estimates on the CPU.
    """
    obs_t = torch.as_tensor(obs_batch, device=self.device)
    # Single forward pass under autocast so the critic runs in mixed
    # precision when AMP is enabled. (The un-autocast duplicate forward
    # left over from the patch is removed — it computed the same values
    # twice and discarded the first result.)
    with torch.amp.autocast('cuda', enabled=self.use_amp):
        _, value = self.net(obs_t)
    return value.cpu().numpy()
|
||||
|
||||
def _random_shift(self, obs):
|
||||
@@ -166,38 +175,43 @@ class PPOAgent:
|
||||
if self.use_data_aug:
|
||||
b_obs = self._random_shift(b_obs)
|
||||
|
||||
_, new_logp, entropy, value = self.net.get_action_and_value(
|
||||
b_obs, b_actions
|
||||
)
|
||||
# AMP 前向传播
|
||||
with torch.amp.autocast('cuda', enabled=self.use_amp):
|
||||
_, new_logp, entropy, value = self.net.get_action_and_value(
|
||||
b_obs, b_actions
|
||||
)
|
||||
|
||||
log_ratio = new_logp - b_old_logp
|
||||
ratio = log_ratio.exp()
|
||||
log_ratio = new_logp - b_old_logp
|
||||
ratio = log_ratio.exp()
|
||||
|
||||
# Clipped policy loss
|
||||
surr1 = ratio * b_adv
|
||||
surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
|
||||
policy_loss = -torch.min(surr1, surr2).mean()
|
||||
# Clipped policy loss
|
||||
surr1 = ratio * b_adv
|
||||
surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
|
||||
policy_loss = -torch.min(surr1, surr2).mean()
|
||||
|
||||
# Clipped value loss (refinement #1, SB3 standard)
|
||||
v_clipped = b_old_values + torch.clamp(
|
||||
value - b_old_values, -self.clip, self.clip
|
||||
)
|
||||
v_loss_unclipped = (value - b_ret).pow(2)
|
||||
v_loss_clipped = (v_clipped - b_ret).pow(2)
|
||||
value_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
|
||||
# Clipped value loss (refinement #1, SB3 standard)
|
||||
v_clipped = b_old_values + torch.clamp(
|
||||
value - b_old_values, -self.clip, self.clip
|
||||
)
|
||||
v_loss_unclipped = (value - b_ret).pow(2)
|
||||
v_loss_clipped = (v_clipped - b_ret).pow(2)
|
||||
value_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
|
||||
|
||||
entropy_loss = entropy.mean()
|
||||
entropy_loss = entropy.mean()
|
||||
|
||||
loss = (
|
||||
policy_loss
|
||||
+ self.vf_coef * value_loss
|
||||
- self.ent_coef * entropy_loss
|
||||
)
|
||||
loss = (
|
||||
policy_loss
|
||||
+ self.vf_coef * value_loss
|
||||
- self.ent_coef * entropy_loss
|
||||
)
|
||||
|
||||
# AMP 反向传播
|
||||
self.optim.zero_grad()
|
||||
loss.backward()
|
||||
self.scaler.scale(loss).backward()
|
||||
self.scaler.unscale_(self.optim)
|
||||
nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm)
|
||||
self.optim.step()
|
||||
self.scaler.step(self.optim)
|
||||
self.scaler.update()
|
||||
|
||||
with torch.no_grad():
|
||||
approx_kl = ((ratio - 1) - log_ratio).mean().item()
|
||||
|
||||
@@ -4,6 +4,8 @@ Uses CleanRL's indexing convention:
|
||||
dones[t] flags whether obs[t] is the FIRST obs of a fresh episode
|
||||
(i.e., the previous action terminated). GAE then uses dones[t+1]
|
||||
as the mask for V(s_{t+1}) at time t.
|
||||
|
||||
Supports pinned memory for faster CPU→GPU transfer.
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -16,6 +18,7 @@ class VecRolloutBuffer:
|
||||
self.obs_shape = obs_shape
|
||||
self.device = device
|
||||
|
||||
# 主存储在 GPU 上
|
||||
self.obs = torch.zeros(
|
||||
(n_steps, n_envs, *obs_shape), dtype=torch.uint8, device=device
|
||||
)
|
||||
@@ -28,16 +31,35 @@ class VecRolloutBuffer:
|
||||
self.advantages = torch.zeros((n_steps, n_envs), dtype=torch.float32, device=device)
|
||||
self.returns = torch.zeros((n_steps, n_envs), dtype=torch.float32, device=device)
|
||||
|
||||
# Pinned memory 缓冲区(加速 CPU→GPU 传输)
|
||||
self._obs_pin = torch.zeros(
|
||||
(n_steps, n_envs, *obs_shape), dtype=torch.uint8, pin_memory=True
|
||||
)
|
||||
self._actions_pin = torch.zeros((n_steps, n_envs), dtype=torch.long, pin_memory=True)
|
||||
self._log_probs_pin = torch.zeros((n_steps, n_envs), dtype=torch.float32, pin_memory=True)
|
||||
self._rewards_pin = torch.zeros((n_steps, n_envs), dtype=torch.float32, pin_memory=True)
|
||||
self._values_pin = torch.zeros((n_steps, n_envs), dtype=torch.float32, pin_memory=True)
|
||||
self._dones_pin = torch.zeros((n_steps, n_envs), dtype=torch.float32, pin_memory=True)
|
||||
|
||||
self.ptr = 0
|
||||
|
||||
def add(self, obs, action, log_prob, reward, value, done):
    """Store one vectorised environment step at the current write slot.

    Each argument is a length-n_envs batch (``obs`` additionally carries
    ``obs_shape``). Data is first staged into pinned (page-locked) host
    buffers, then copied into the device-resident storage with
    ``non_blocking=True`` so the host-to-GPU transfer can overlap with
    other work when the device is CUDA.

    Args:
        obs: batch of observations (stored as uint8).
        action: batch of integer actions.
        log_prob: batch of action log-probabilities.
        reward: batch of rewards.
        value: batch of value estimates.
        done: batch of episode-start flags (stored as float — used as
            GAE masks downstream).
    """
    i = self.ptr
    # Stage into pinned host memory first; only pinned sources are
    # eligible for truly asynchronous DMA transfer to the GPU.
    # (The leftover direct-to-device writes from the pre-patch version
    # are removed — they transferred every field a second time.)
    self._obs_pin[i] = torch.as_tensor(obs)
    self._actions_pin[i] = torch.as_tensor(action, dtype=torch.long)
    self._log_probs_pin[i] = torch.as_tensor(log_prob, dtype=torch.float32)
    self._rewards_pin[i] = torch.as_tensor(reward, dtype=torch.float32)
    self._values_pin[i] = torch.as_tensor(value, dtype=torch.float32)
    self._dones_pin[i] = torch.as_tensor(done, dtype=torch.float32)

    # Asynchronous copy into the device-resident rollout storage.
    self.obs[i] = self._obs_pin[i].to(self.device, non_blocking=True)
    self.actions[i] = self._actions_pin[i].to(self.device, non_blocking=True)
    self.log_probs[i] = self._log_probs_pin[i].to(self.device, non_blocking=True)
    self.rewards[i] = self._rewards_pin[i].to(self.device, non_blocking=True)
    self.values[i] = self._values_pin[i].to(self.device, non_blocking=True)
    self.dones[i] = self._dones_pin[i].to(self.device, non_blocking=True)
    self.ptr += 1
|
||||
|
||||
def compute_gae(self, last_value, last_done, gamma=0.99, lam=0.95):
|
||||
|
||||
Reference in New Issue
Block a user