"""PPO agent: clipped surrogate objective + value loss + entropy bonus. Implements the PPO-Clip algorithm (Schulman et al. 2017) on top of our shared-CNN ActorCritic network and a vectorised rollout buffer. Includes production-grade refinements catalogued in *The 37 Implementation Details of PPO* (Huang et al. 2022) and *RAD* (Laskin et al. 2020): - Clipped value-function loss (SB3 standard) - KL early stopping within update epochs (target_kl) - Linear schedule for clip range (clip_init -> clip_floor) - Random-shift data augmentation on observations during the update - Linear annealing of learning rate and entropy coefficient with floors - AMP mixed precision training for GPU acceleration Public API: - PPOAgent.act(obs) -> (action, log_prob, value) - PPOAgent.act_batch(obs_batch) -> batched act for n_envs obs - PPOAgent.evaluate_value_batch(obs) -> bootstrap value for GAE - PPOAgent.update_vec(buffer) -> PPO update over vectorised rollout - PPOAgent.step_schedule(progress) -> linear LR/entropy/clip annealing - PPOAgent.save / load -> state_dict checkpoints """ import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from src.networks import ActorCritic class PPOAgent: """PPO-Clip agent for discrete action spaces.""" def __init__( self, n_actions=5, lr=2.5e-4, clip=0.2, vf_coef=0.5, ent_coef=0.01, max_grad_norm=0.5, n_epochs=6, batch_size=128, device=None, anneal_lr=False, anneal_ent=False, clip_floor=None, target_kl=None, use_data_aug=False, aug_pad=4, use_amp=True, ): dev = device or ("cuda" if torch.cuda.is_available() else "cpu") self.device = torch.device(dev) if isinstance(dev, str) else dev self.net = ActorCritic(n_actions=n_actions).to(self.device) self.optim = optim.Adam(self.net.parameters(), lr=lr, eps=1e-5) # AMP 混合精度训练 self.use_amp = use_amp and self.device.type == 'cuda' self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp) # Save initial values for scheduling self.lr_init = lr self.clip_init = clip self.ent_coef_init = ent_coef self.clip = clip self.vf_coef = vf_coef self.ent_coef = ent_coef self.max_grad_norm = max_grad_norm self.n_epochs = n_epochs self.batch_size = batch_size # Schedule and refinement flags self.anneal_lr = anneal_lr self.anneal_ent = anneal_ent self.clip_floor = clip_floor # None = no clip annealing self.target_kl = target_kl # None = no KL early stopping self.use_data_aug = use_data_aug self.aug_pad = aug_pad @torch.no_grad() def act(self, obs): """Sample one action for the rollout phase.""" obs_t = torch.as_tensor(obs, device=self.device).unsqueeze(0) action, log_prob, _, value = self.net.get_action_and_value(obs_t) return action.item(), log_prob.item(), value.item() @torch.no_grad() def act_batch(self, obs_batch): """Vectorised act for n_envs obs at once.""" obs_t = torch.as_tensor(obs_batch, device=self.device) with torch.amp.autocast('cuda', enabled=self.use_amp): action, log_prob, _, value = self.net.get_action_and_value(obs_t) return ( action.cpu().numpy(), log_prob.cpu().numpy(), value.cpu().numpy(), ) @torch.no_grad() def evaluate_value_batch(self, obs_batch): obs_t = torch.as_tensor(obs_batch, device=self.device) with torch.amp.autocast('cuda', enabled=self.use_amp): _, value = self.net(obs_t) return value.cpu().numpy() def _random_shift(self, obs): """Random-shift data augmentation (DrQ / RAD style), fully vectorised. Pads the (B, C, H, W) image by ``aug_pad`` on each side using replicate padding, then crops a random H x W window per sample. 
    def _random_shift(self, obs):
        """Random-shift data augmentation (DrQ / RAD style), fully vectorised.

        Pads the (B, C, H, W) image by ``aug_pad`` on each side using
        replicate padding, then crops a random H x W window per sample.
        Uses a single grid_sample call instead of a Python loop.
        """
        n, c, h, w = obs.shape
        pad = self.aug_pad
        x = F.pad(obs.float(), (pad, pad, pad, pad), mode="replicate")

        # Per-sample random integer offsets in [0, 2*pad]
        h_off = torch.randint(0, 2 * pad + 1, (n,), device=obs.device)
        w_off = torch.randint(0, 2 * pad + 1, (n,), device=obs.device)

        # Build a per-sample affine grid that translates by the offset.
        # The padded image is (h + 2*pad) x (w + 2*pad). To crop an (h x w)
        # window starting at (h_off, w_off), we sample at normalized coords:
        #   y_norm in [(h_off + 0.5)/H' * 2 - 1, (h_off + h - 0.5)/H' * 2 - 1]
        # where H' = h + 2*pad. We build that grid in one shot.
        Hp = h + 2 * pad
        Wp = w + 2 * pad

        # Base grid for an (h, w) crop in normalized [-1, 1] coords on the
        # padded image
        ys = torch.arange(h, device=obs.device, dtype=torch.float32)
        xs = torch.arange(w, device=obs.device, dtype=torch.float32)
        y_indices = h_off.unsqueeze(1).float() + ys.unsqueeze(0)  # (n, h)
        x_indices = w_off.unsqueeze(1).float() + xs.unsqueeze(0)  # (n, w)

        # Convert to normalized coords on the padded image, [-1, 1]
        y_norm = (y_indices + 0.5) / Hp * 2.0 - 1.0  # (n, h)
        x_norm = (x_indices + 0.5) / Wp * 2.0 - 1.0  # (n, w)

        # Build (n, h, w, 2) grid: [..., 0] = x, [..., 1] = y per the
        # grid_sample API
        grid = torch.stack(
            [
                x_norm.unsqueeze(1).expand(n, h, w),
                y_norm.unsqueeze(2).expand(n, h, w),
            ],
            dim=-1,
        )
        out = F.grid_sample(
            x, grid, mode="nearest", align_corners=False, padding_mode="border"
        )
        return out.to(obs.dtype)
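
    # Worked example for the grid math above (illustrative numbers; an
    # 84x84 observation is an assumption, aug_pad=4 is the default): the
    # padded image is Hp = 84 + 2*4 = 92 rows tall. A random offset
    # h_off = 3 maps crop row 0 to padded row 3, whose align_corners=False
    # normalized coordinate is (3 + 0.5) / 92 * 2 - 1 = -0.9239, and crop
    # row 83 to padded row 86, i.e. (86 + 0.5) / 92 * 2 - 1 = 0.8804.
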
    def update_vec(self, vec_buffer):
        """PPO update for a vectorised buffer (flattens n_steps * n_envs)."""
        obs_shape = vec_buffer.obs_shape
        b_obs_flat = vec_buffer.obs.reshape(-1, *obs_shape)
        b_actions_flat = vec_buffer.actions.reshape(-1)
        b_old_logp_flat = vec_buffer.log_probs.reshape(-1)
        b_old_values_flat = vec_buffer.values.reshape(-1)
        b_adv_flat = vec_buffer.advantages.reshape(-1)
        b_ret_flat = vec_buffer.returns.reshape(-1)

        pg_losses, v_losses, ent_losses, approx_kls, clip_fracs = [], [], [], [], []
        epochs_completed = 0
        early_stopped = False

        for epoch in range(self.n_epochs):
            epoch_kls = []
            for idx in vec_buffer.get_minibatches(self.batch_size):
                b_obs = b_obs_flat[idx]
                b_actions = b_actions_flat[idx]
                b_old_logp = b_old_logp_flat[idx]
                b_old_values = b_old_values_flat[idx]
                b_adv = b_adv_flat[idx]
                b_ret = b_ret_flat[idx]

                # Random-shift data augmentation (refinement #4)
                if self.use_data_aug:
                    b_obs = self._random_shift(b_obs)

                # AMP forward pass
                with torch.amp.autocast('cuda', enabled=self.use_amp):
                    _, new_logp, entropy, value = self.net.get_action_and_value(
                        b_obs, b_actions
                    )
                    log_ratio = new_logp - b_old_logp
                    ratio = log_ratio.exp()

                    # Clipped policy loss
                    surr1 = ratio * b_adv
                    surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
                    policy_loss = -torch.min(surr1, surr2).mean()

                    # Clipped value loss (refinement #1, SB3 standard)
                    v_clipped = b_old_values + torch.clamp(
                        value - b_old_values, -self.clip, self.clip
                    )
                    v_loss_unclipped = (value - b_ret).pow(2)
                    v_loss_clipped = (v_clipped - b_ret).pow(2)
                    value_loss = 0.5 * torch.max(
                        v_loss_unclipped, v_loss_clipped
                    ).mean()

                    entropy_loss = entropy.mean()
                    loss = (
                        policy_loss
                        + self.vf_coef * value_loss
                        - self.ent_coef * entropy_loss
                    )

                # AMP backward pass
                self.optim.zero_grad()
                self.scaler.scale(loss).backward()
                self.scaler.unscale_(self.optim)
                nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm)
                self.scaler.step(self.optim)
                self.scaler.update()

                with torch.no_grad():
                    approx_kl = ((ratio - 1) - log_ratio).mean().item()
                    clip_frac = (
                        ((ratio - 1.0).abs() > self.clip).float().mean().item()
                    )

                pg_losses.append(policy_loss.item())
                v_losses.append(value_loss.item())
                ent_losses.append(entropy_loss.item())
                approx_kls.append(approx_kl)
                clip_fracs.append(clip_frac)
                epoch_kls.append(approx_kl)

            epochs_completed += 1

            # KL early stopping (refinement #2): stop update epochs once the
            # mean approx KL exceeds 1.5 * target_kl
            if self.target_kl is not None and len(epoch_kls) > 0:
                if sum(epoch_kls) / len(epoch_kls) > 1.5 * self.target_kl:
                    early_stopped = True
                    break

        return {
            "policy_loss": sum(pg_losses) / len(pg_losses),
            "value_loss": sum(v_losses) / len(v_losses),
            "entropy": sum(ent_losses) / len(ent_losses),
            "approx_kl": sum(approx_kls) / len(approx_kls),
            "clip_frac": sum(clip_fracs) / len(clip_fracs),
            "epochs_completed": epochs_completed,
            "early_stopped": float(early_stopped),
            "current_clip": self.clip,
        }

    def save(self, path):
        torch.save(self.net.state_dict(), path)

    def load(self, path):
        self.net.load_state_dict(torch.load(path, map_location=self.device))

    def step_schedule(self, progress, ent_floor=0.0, lr_floor=0.0):
        """Linearly decay lr / ent_coef / clip toward floors over training.

        - LR: lr_init -> lr_floor
        - Entropy coefficient: ent_coef_init -> ent_floor (preserves exploration)
        - Clip range: clip_init -> clip_floor (only if clip_floor is set)
        """
        progress = min(max(progress, 0.0), 1.0)
        if self.anneal_lr:
            target_lr = self.lr_init * (1.0 - progress)
            for g in self.optim.param_groups:
                g["lr"] = max(target_lr, lr_floor)
        if self.anneal_ent:
            target_ent = self.ent_coef_init * (1.0 - progress)
            self.ent_coef = max(target_ent, ent_floor)
        # Clip range schedule (refinement #3)
        if self.clip_floor is not None:
            target_clip = self.clip_init * (1.0 - progress) + self.clip_floor * progress
            self.clip = max(target_clip, self.clip_floor)
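

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch. The 4x84x84 observation shape below is an
# assumption for illustration only -- adjust it to whatever
# src.networks.ActorCritic actually consumes. A full update_vec call needs
# the project's rollout buffer, so only the rollout-side API is exercised.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    agent = PPOAgent(
        n_actions=5, device="cpu", use_amp=False,
        anneal_lr=True, anneal_ent=True, clip_floor=0.1,
    )

    obs = np.random.rand(4, 84, 84).astype(np.float32)       # assumed shape
    action, log_prob, value = agent.act(obs)
    print(f"act: action={action} log_prob={log_prob:.3f} value={value:.3f}")

    batch = np.random.rand(8, 4, 84, 84).astype(np.float32)  # 8 parallel envs
    actions, log_probs, values = agent.act_batch(batch)
    bootstrap = agent.evaluate_value_batch(batch)             # V(s_T) for GAE
    print(f"act_batch: actions={actions.shape} values={values.shape} "
          f"bootstrap={bootstrap.shape}")

    agent.step_schedule(progress=0.5)                         # halfway anneal
    print(f"schedule: lr={agent.optim.param_groups[0]['lr']:.2e} "
          f"ent_coef={agent.ent_coef:.4f} clip={agent.clip:.3f}")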