d5c9baffe6
- PPO (CW1_id_name): 添加 AMP GradScaler + autocast 混合精度训练,pinned memory 加速 CPU→GPU 传输,torch.compile JIT 编译支持,调整默认超参适配 RTX 5090 - DQN (Atari): 添加 AMP 混合精度、pinned memory 回放缓冲区、向量化批量添加经验 (add_batch) 和批量动作选择 (batch_select_actions),消除 Python 循环 - train_parallel.py: 重写为无缓冲脚本,集成所有优化,64 并行环境 + 每步 4 次训练更新 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
270 lines
11 KiB
Python
270 lines
11 KiB
Python
"""PPO agent: clipped surrogate objective + value loss + entropy bonus.
|
|
|
|
Implements the PPO-Clip algorithm (Schulman et al. 2017) on top of our
|
|
shared-CNN ActorCritic network and a vectorised rollout buffer. Includes
|
|
production-grade refinements catalogued in *The 37 Implementation Details
|
|
of PPO* (Huang et al. 2022) and *RAD* (Laskin et al. 2020):
|
|
|
|
- Clipped value-function loss (SB3 standard)
|
|
- KL early stopping within update epochs (target_kl)
|
|
- Linear schedule for clip range (clip_init -> clip_floor)
|
|
- Random-shift data augmentation on observations during the update
|
|
- Linear annealing of learning rate and entropy coefficient with floors
|
|
- AMP mixed precision training for GPU acceleration
|
|
|
|
Public API:
|
|
- PPOAgent.act(obs) -> (action, log_prob, value)
|
|
- PPOAgent.act_batch(obs_batch) -> batched act for n_envs obs
|
|
- PPOAgent.evaluate_value_batch(obs) -> bootstrap value for GAE
|
|
- PPOAgent.update_vec(buffer) -> PPO update over vectorised rollout
|
|
- PPOAgent.step_schedule(progress) -> linear LR/entropy/clip annealing
|
|
- PPOAgent.save / load -> state_dict checkpoints
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.optim as optim
|
|
|
|
from src.networks import ActorCritic
|
|
|
|
|
|
class PPOAgent:
|
|
"""PPO-Clip agent for discrete action spaces."""
|
|
|
|
def __init__(
|
|
self,
|
|
n_actions=5,
|
|
lr=2.5e-4,
|
|
clip=0.2,
|
|
vf_coef=0.5,
|
|
ent_coef=0.01,
|
|
max_grad_norm=0.5,
|
|
n_epochs=6,
|
|
batch_size=128,
|
|
device=None,
|
|
anneal_lr=False,
|
|
anneal_ent=False,
|
|
clip_floor=None,
|
|
target_kl=None,
|
|
use_data_aug=False,
|
|
aug_pad=4,
|
|
use_amp=True,
|
|
):
|
|
dev = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
|
self.device = torch.device(dev) if isinstance(dev, str) else dev
|
|
self.net = ActorCritic(n_actions=n_actions).to(self.device)
|
|
self.optim = optim.Adam(self.net.parameters(), lr=lr, eps=1e-5)
|
|
|
|
# AMP 混合精度训练
|
|
self.use_amp = use_amp and self.device.type == 'cuda'
|
|
self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)
|
|
|
|
# Save initial values for scheduling
|
|
self.lr_init = lr
|
|
self.clip_init = clip
|
|
self.ent_coef_init = ent_coef
|
|
|
|
self.clip = clip
|
|
self.vf_coef = vf_coef
|
|
self.ent_coef = ent_coef
|
|
self.max_grad_norm = max_grad_norm
|
|
self.n_epochs = n_epochs
|
|
self.batch_size = batch_size
|
|
|
|
# Schedule and refinement flags
|
|
self.anneal_lr = anneal_lr
|
|
self.anneal_ent = anneal_ent
|
|
self.clip_floor = clip_floor # None = no clip annealing
|
|
self.target_kl = target_kl # None = no KL early stopping
|
|
self.use_data_aug = use_data_aug
|
|
self.aug_pad = aug_pad
|
|
|
|
@torch.no_grad()
|
|
def act(self, obs):
|
|
"""Sample one action for the rollout phase."""
|
|
obs_t = torch.as_tensor(obs, device=self.device).unsqueeze(0)
|
|
action, log_prob, _, value = self.net.get_action_and_value(obs_t)
|
|
return action.item(), log_prob.item(), value.item()
|
|
|
|
@torch.no_grad()
|
|
def act_batch(self, obs_batch):
|
|
"""Vectorised act for n_envs obs at once."""
|
|
obs_t = torch.as_tensor(obs_batch, device=self.device)
|
|
with torch.amp.autocast('cuda', enabled=self.use_amp):
|
|
action, log_prob, _, value = self.net.get_action_and_value(obs_t)
|
|
return (
|
|
action.cpu().numpy(),
|
|
log_prob.cpu().numpy(),
|
|
value.cpu().numpy(),
|
|
)
|
|
|
|
@torch.no_grad()
|
|
def evaluate_value_batch(self, obs_batch):
|
|
obs_t = torch.as_tensor(obs_batch, device=self.device)
|
|
with torch.amp.autocast('cuda', enabled=self.use_amp):
|
|
_, value = self.net(obs_t)
|
|
return value.cpu().numpy()
|
|
|
|
def _random_shift(self, obs):
|
|
"""Random-shift data augmentation (DrQ / RAD style), fully vectorised.
|
|
|
|
Pads the (B, C, H, W) image by ``aug_pad`` on each side using
|
|
replicate padding, then crops a random H x W window per sample.
|
|
Uses a single grid_sample call instead of a Python loop.
|
|
"""
|
|
n, c, h, w = obs.shape
|
|
pad = self.aug_pad
|
|
x = F.pad(obs.float(), (pad, pad, pad, pad), mode="replicate")
|
|
# Per-sample random integer offsets in [0, 2*pad]
|
|
h_off = torch.randint(0, 2 * pad + 1, (n,), device=obs.device)
|
|
w_off = torch.randint(0, 2 * pad + 1, (n,), device=obs.device)
|
|
|
|
# Build per-sample affine grid that translates by the offset.
|
|
# Padded image is (h + 2*pad) x (w + 2*pad). To crop a (h x w) window
|
|
# starting at (h_off, w_off), we sample at normalized coords:
|
|
# y_norm in [(h_off + 0.5)/H' * 2 - 1, (h_off + h - 0.5)/H' * 2 - 1]
|
|
# where H' = h + 2*pad. We build that grid in one shot.
|
|
Hp = h + 2 * pad
|
|
Wp = w + 2 * pad
|
|
# Base grid for an (h, w) crop in normalized [-1, 1] coords on the padded image
|
|
ys = torch.arange(h, device=obs.device, dtype=torch.float32)
|
|
xs = torch.arange(w, device=obs.device, dtype=torch.float32)
|
|
# (n, h)
|
|
y_indices = h_off.unsqueeze(1).float() + ys.unsqueeze(0)
|
|
# (n, w)
|
|
x_indices = w_off.unsqueeze(1).float() + xs.unsqueeze(0)
|
|
# Convert to normalized coords on the padded image, [-1, 1]
|
|
y_norm = (y_indices + 0.5) / Hp * 2.0 - 1.0 # (n, h)
|
|
x_norm = (x_indices + 0.5) / Wp * 2.0 - 1.0 # (n, w)
|
|
# Build (n, h, w, 2) grid: [..., 0] = x, [..., 1] = y per grid_sample API
|
|
grid = torch.stack(
|
|
[x_norm.unsqueeze(1).expand(n, h, w),
|
|
y_norm.unsqueeze(2).expand(n, h, w)],
|
|
dim=-1,
|
|
)
|
|
out = F.grid_sample(x, grid, mode="nearest", align_corners=False,
|
|
padding_mode="border")
|
|
return out.to(obs.dtype)
|
|
|
|
    def update_vec(self, vec_buffer):
        """PPO update for a vectorised buffer (flattens n_steps * n_envs).

        Runs up to ``n_epochs`` passes of minibatch SGD over the flattened
        rollout: clipped surrogate policy loss, clipped value loss, entropy
        bonus, optional random-shift augmentation, AMP autocast on the
        network forward, and KL-based early stopping across epochs.

        ``vec_buffer`` is assumed (from usage here — confirm against the
        buffer class) to expose ``obs``/``actions``/``log_probs``/``values``/
        ``advantages``/``returns`` tensors shaped (n_steps, n_envs, ...) and
        a ``get_minibatches(batch_size)`` iterator over flat index batches.

        Returns a dict of scalar statistics averaged over all minibatches,
        plus epoch bookkeeping and the current clip range.
        """
        obs_shape = vec_buffer.obs_shape
        # Flatten (n_steps, n_envs, ...) -> (n_steps * n_envs, ...)
        b_obs_flat = vec_buffer.obs.reshape(-1, *obs_shape)
        b_actions_flat = vec_buffer.actions.reshape(-1)
        b_old_logp_flat = vec_buffer.log_probs.reshape(-1)
        b_old_values_flat = vec_buffer.values.reshape(-1)
        b_adv_flat = vec_buffer.advantages.reshape(-1)
        b_ret_flat = vec_buffer.returns.reshape(-1)

        pg_losses, v_losses, ent_losses, approx_kls, clip_fracs = [], [], [], [], []
        epochs_completed = 0
        early_stopped = False

        for epoch in range(self.n_epochs):
            epoch_kls = []
            for idx in vec_buffer.get_minibatches(self.batch_size):
                b_obs = b_obs_flat[idx]
                b_actions = b_actions_flat[idx]
                b_old_logp = b_old_logp_flat[idx]
                b_old_values = b_old_values_flat[idx]
                b_adv = b_adv_flat[idx]
                b_ret = b_ret_flat[idx]

                # Random shift data augmentation (refinement #4)
                if self.use_data_aug:
                    b_obs = self._random_shift(b_obs)

                # AMP forward pass: only the network forward runs under
                # autocast (reduced precision when use_amp is True).
                with torch.amp.autocast('cuda', enabled=self.use_amp):
                    _, new_logp, entropy, value = self.net.get_action_and_value(
                        b_obs, b_actions
                    )

                # Importance ratio pi_new / pi_old via log-prob difference.
                log_ratio = new_logp - b_old_logp
                ratio = log_ratio.exp()

                # Clipped policy loss
                surr1 = ratio * b_adv
                surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
                policy_loss = -torch.min(surr1, surr2).mean()

                # Clipped value loss (refinement #1, SB3 standard)
                v_clipped = b_old_values + torch.clamp(
                    value - b_old_values, -self.clip, self.clip
                )
                v_loss_unclipped = (value - b_ret).pow(2)
                v_loss_clipped = (v_clipped - b_ret).pow(2)
                value_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()

                entropy_loss = entropy.mean()

                # Total objective; entropy is subtracted because we maximise it.
                loss = (
                    policy_loss
                    + self.vf_coef * value_loss
                    - self.ent_coef * entropy_loss
                )

                # AMP backward pass: scale the loss, unscale gradients before
                # clipping so clip_grad_norm_ sees true magnitudes, then let
                # the scaler step (it skips the step on inf/nan gradients).
                self.optim.zero_grad()
                self.scaler.scale(loss).backward()
                self.scaler.unscale_(self.optim)
                nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm)
                self.scaler.step(self.optim)
                self.scaler.update()

                with torch.no_grad():
                    # Low-variance KL estimator: E[(r - 1) - log r] >= 0.
                    approx_kl = ((ratio - 1) - log_ratio).mean().item()
                    clip_frac = ((ratio - 1.0).abs() > self.clip).float().mean().item()

                pg_losses.append(policy_loss.item())
                v_losses.append(value_loss.item())
                ent_losses.append(entropy_loss.item())
                approx_kls.append(approx_kl)
                clip_fracs.append(clip_frac)
                epoch_kls.append(approx_kl)

            epochs_completed += 1
            # KL early stopping (refinement #2): stop epochs if mean KL exceeds 1.5 * target
            if self.target_kl is not None and len(epoch_kls) > 0:
                if sum(epoch_kls) / len(epoch_kls) > 1.5 * self.target_kl:
                    early_stopped = True
                    break

        return {
            "policy_loss": sum(pg_losses) / len(pg_losses),
            "value_loss": sum(v_losses) / len(v_losses),
            "entropy": sum(ent_losses) / len(ent_losses),
            "approx_kl": sum(approx_kls) / len(approx_kls),
            "clip_frac": sum(clip_fracs) / len(clip_fracs),
            "epochs_completed": epochs_completed,
            "early_stopped": float(early_stopped),
            "current_clip": self.clip,
        }
|
def save(self, path):
|
|
torch.save(self.net.state_dict(), path)
|
|
|
|
def load(self, path):
|
|
self.net.load_state_dict(torch.load(path, map_location=self.device))
|
|
|
|
def step_schedule(self, progress, ent_floor=0.0, lr_floor=0.0):
|
|
"""Linearly decay lr / ent_coef / clip toward floors over training.
|
|
|
|
- LR: lr_init -> lr_floor
|
|
- Entropy coefficient: ent_coef_init -> ent_floor (preserves exploration)
|
|
- Clip range: clip_init -> clip_floor (only if clip_floor is set)
|
|
"""
|
|
progress = min(max(progress, 0.0), 1.0)
|
|
if self.anneal_lr:
|
|
target_lr = self.lr_init * (1.0 - progress)
|
|
for g in self.optim.param_groups:
|
|
g["lr"] = max(target_lr, lr_floor)
|
|
if self.anneal_ent:
|
|
target_ent = self.ent_coef_init * (1.0 - progress)
|
|
self.ent_coef = max(target_ent, ent_floor)
|
|
# Clip range schedule (refinement #6)
|
|
if self.clip_floor is not None:
|
|
target_clip = self.clip_init * (1.0 - progress) + self.clip_floor * progress
|
|
self.clip = max(target_clip, self.clip_floor)
|