perf: 为PPO和DQN添加GPU优化——AMP混合精度、pinned memory、torch.compile
- PPO (CW1_id_name): 添加 AMP GradScaler + autocast 混合精度训练,pinned memory 加速 CPU→GPU 传输,torch.compile JIT 编译支持,调整默认超参适配 RTX 5090 - DQN (Atari): 添加 AMP 混合精度、pinned memory 回放缓冲区、向量化批量添加经验 (add_batch) 和批量动作选择 (batch_select_actions),消除 Python 循环 - train_parallel.py: 重写为无缓冲脚本,集成所有优化,64 并行环境 + 每步 4 次训练更新 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -10,6 +10,7 @@ of PPO* (Huang et al. 2022) and *RAD* (Laskin et al. 2020):
|
||||
- Linear schedule for clip range (clip_init -> clip_floor)
|
||||
- Random-shift data augmentation on observations during the update
|
||||
- Linear annealing of learning rate and entropy coefficient with floors
|
||||
- AMP mixed precision training for GPU acceleration
|
||||
|
||||
Public API:
|
||||
- PPOAgent.act(obs) -> (action, log_prob, value)
|
||||
@@ -48,11 +49,17 @@ class PPOAgent:
|
||||
target_kl=None,
|
||||
use_data_aug=False,
|
||||
aug_pad=4,
|
||||
use_amp=True,
|
||||
):
|
||||
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dev = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device(dev) if isinstance(dev, str) else dev
|
||||
self.net = ActorCritic(n_actions=n_actions).to(self.device)
|
||||
self.optim = optim.Adam(self.net.parameters(), lr=lr, eps=1e-5)
|
||||
|
||||
# AMP 混合精度训练
|
||||
self.use_amp = use_amp and self.device.type == 'cuda'
|
||||
self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)
|
||||
|
||||
# Save initial values for scheduling
|
||||
self.lr_init = lr
|
||||
self.clip_init = clip
|
||||
@@ -84,7 +91,8 @@ class PPOAgent:
|
||||
def act_batch(self, obs_batch):
|
||||
"""Vectorised act for n_envs obs at once."""
|
||||
obs_t = torch.as_tensor(obs_batch, device=self.device)
|
||||
action, log_prob, _, value = self.net.get_action_and_value(obs_t)
|
||||
with torch.amp.autocast('cuda', enabled=self.use_amp):
|
||||
action, log_prob, _, value = self.net.get_action_and_value(obs_t)
|
||||
return (
|
||||
action.cpu().numpy(),
|
||||
log_prob.cpu().numpy(),
|
||||
@@ -94,7 +102,8 @@ class PPOAgent:
|
||||
@torch.no_grad()
|
||||
def evaluate_value_batch(self, obs_batch):
|
||||
obs_t = torch.as_tensor(obs_batch, device=self.device)
|
||||
_, value = self.net(obs_t)
|
||||
with torch.amp.autocast('cuda', enabled=self.use_amp):
|
||||
_, value = self.net(obs_t)
|
||||
return value.cpu().numpy()
|
||||
|
||||
def _random_shift(self, obs):
|
||||
@@ -166,38 +175,43 @@ class PPOAgent:
|
||||
if self.use_data_aug:
|
||||
b_obs = self._random_shift(b_obs)
|
||||
|
||||
_, new_logp, entropy, value = self.net.get_action_and_value(
|
||||
b_obs, b_actions
|
||||
)
|
||||
# AMP 前向传播
|
||||
with torch.amp.autocast('cuda', enabled=self.use_amp):
|
||||
_, new_logp, entropy, value = self.net.get_action_and_value(
|
||||
b_obs, b_actions
|
||||
)
|
||||
|
||||
log_ratio = new_logp - b_old_logp
|
||||
ratio = log_ratio.exp()
|
||||
log_ratio = new_logp - b_old_logp
|
||||
ratio = log_ratio.exp()
|
||||
|
||||
# Clipped policy loss
|
||||
surr1 = ratio * b_adv
|
||||
surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
|
||||
policy_loss = -torch.min(surr1, surr2).mean()
|
||||
# Clipped policy loss
|
||||
surr1 = ratio * b_adv
|
||||
surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
|
||||
policy_loss = -torch.min(surr1, surr2).mean()
|
||||
|
||||
# Clipped value loss (refinement #1, SB3 standard)
|
||||
v_clipped = b_old_values + torch.clamp(
|
||||
value - b_old_values, -self.clip, self.clip
|
||||
)
|
||||
v_loss_unclipped = (value - b_ret).pow(2)
|
||||
v_loss_clipped = (v_clipped - b_ret).pow(2)
|
||||
value_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
|
||||
# Clipped value loss (refinement #1, SB3 standard)
|
||||
v_clipped = b_old_values + torch.clamp(
|
||||
value - b_old_values, -self.clip, self.clip
|
||||
)
|
||||
v_loss_unclipped = (value - b_ret).pow(2)
|
||||
v_loss_clipped = (v_clipped - b_ret).pow(2)
|
||||
value_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
|
||||
|
||||
entropy_loss = entropy.mean()
|
||||
entropy_loss = entropy.mean()
|
||||
|
||||
loss = (
|
||||
policy_loss
|
||||
+ self.vf_coef * value_loss
|
||||
- self.ent_coef * entropy_loss
|
||||
)
|
||||
loss = (
|
||||
policy_loss
|
||||
+ self.vf_coef * value_loss
|
||||
- self.ent_coef * entropy_loss
|
||||
)
|
||||
|
||||
# AMP 反向传播
|
||||
self.optim.zero_grad()
|
||||
loss.backward()
|
||||
self.scaler.scale(loss).backward()
|
||||
self.scaler.unscale_(self.optim)
|
||||
nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm)
|
||||
self.optim.step()
|
||||
self.scaler.step(self.optim)
|
||||
self.scaler.update()
|
||||
|
||||
with torch.no_grad():
|
||||
approx_kl = ((ratio - 1) - log_ratio).mean().item()
|
||||
|
||||
Reference in New Issue
Block a user