feat: restructure the project and add vectorised PPO training and evaluation scripts

- Refactor the original single-environment training code into a modular layout and add vectorised-environment support to speed up data collection
- Implement the full PPO training pipeline: a shared-CNN Actor-Critic network, a vectorised rollout buffer, and GAE advantage estimation
- Add the training script (train_vec.py), the evaluation script (evaluate.py), and an SB3 baseline comparison script (train_sb3_baseline.py)
- Provide detailed documentation and a development log, including problem-solving notes and experiment analysis
- Remove the legacy project files and consolidate everything under the CW1_id_name directory
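How the new modules fit together: the sketch below is a minimal, illustrative outline of the vectorised training loop described above. It is not the committed train_vec.py; the module paths src.vec_env, src.vec_buffer and src.ppo_agent, and all hyperparameter values, are assumptions made for illustration only.

# Illustrative sketch of the vectorised PPO training loop (NOT train_vec.py itself).
import numpy as np

from src.vec_env import make_vec_env          # assumed module path
from src.vec_buffer import VecRolloutBuffer   # assumed module path
from src.ppo_agent import PPOAgent            # assumed module path

n_envs, n_steps, total_updates = 8, 256, 500  # example values
envs = make_vec_env(n_envs=n_envs, seed=0, async_mode=True)
agent = PPOAgent(n_actions=5, anneal_lr=True, anneal_ent=True, target_kl=0.02)
buffer = VecRolloutBuffer(n_steps, n_envs, obs_shape=(4, 84, 84), device=agent.device)

obs, _ = envs.reset(seed=0)
done = np.zeros(n_envs, dtype=np.float32)
for update in range(1, total_updates + 1):
    buffer.reset()
    for _ in range(n_steps):
        actions, log_probs, values = agent.act_batch(obs)
        next_obs, rewards, terms, truncs, _ = envs.step(actions)
        # dones[t] marks whether obs[t] started a fresh episode (CleanRL convention)
        buffer.add(obs, actions, log_probs, rewards, values, done)
        obs = next_obs
        done = np.logical_or(terms, truncs).astype(np.float32)
    last_value = agent.evaluate_value_batch(obs)
    buffer.compute_gae(last_value, done, gamma=0.99, lam=0.95)
    stats = agent.update_vec(buffer)
    agent.step_schedule(progress=update / total_updates)

envs.close()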
@@ -0,0 +1,96 @@
"""Environment wrappers for CarRacing-v3.

We stack four standard wrappers on top of the raw env:

- SkipFrame: repeat each action k times to reduce decision frequency
- GrayScaleResize: RGB(96, 96, 3) -> Gray(84, 84) to shrink the input
- FrameStack: stack the last k frames so the agent can perceive motion
- make_env: factory that returns a fully wrapped environment

After wrapping, an observation has shape (4, 84, 84) uint8.
"""

from collections import deque

import cv2
import gymnasium as gym
import numpy as np


class SkipFrame(gym.Wrapper):
    """Repeat the action ``k`` times and accumulate the rewards."""

    def __init__(self, env: gym.Env, k: int = 4):
        super().__init__(env)
        self.k = k

    def step(self, action):
        total_reward = 0.0
        terminated = False
        truncated = False
        info = {}
        obs = None
        for _ in range(self.k):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            if terminated or truncated:
                break
        return obs, total_reward, terminated, truncated, info


class GrayScaleResize(gym.ObservationWrapper):
    """Convert RGB frames to grayscale and resize to ``size`` x ``size``."""

    def __init__(self, env: gym.Env, size: int = 84):
        super().__init__(env)
        self.size = size
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(size, size), dtype=np.uint8
        )

    def observation(self, obs):
        gray = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
        resized = cv2.resize(gray, (self.size, self.size), interpolation=cv2.INTER_AREA)
        return resized


class FrameStack(gym.Wrapper):
    """Stack the most recent ``k`` frames along a new leading axis."""

    def __init__(self, env: gym.Env, k: int = 4):
        super().__init__(env)
        self.k = k
        self.frames = deque(maxlen=k)
        h, w = env.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(k, h, w), dtype=np.uint8
        )

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self.frames.clear()
        for _ in range(self.k):
            self.frames.append(obs)
        return self._get_obs(), info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.frames.append(obs)
        return self._get_obs(), reward, terminated, truncated, info

    def _get_obs(self):
        return np.stack(self.frames, axis=0)


def make_env(seed: int = 0, skip: int = 4, size: int = 84, stack: int = 4) -> gym.Env:
    """Create a CarRacing-v3 env with our standard preprocessing stack.

    Returns an environment whose observations are uint8 arrays of shape
    (stack, size, size), ready to feed into a CNN backbone.
    """
    env = gym.make("CarRacing-v3", continuous=False, render_mode="rgb_array")
    env = SkipFrame(env, k=skip)
    env = GrayScaleResize(env, size=size)
    env = FrameStack(env, k=stack)
    env.action_space.seed(seed)
    env.reset(seed=seed)
    return env
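A quick sanity check of the wrapped env (not part of the committed file); it only confirms the observation contract documented above.

import numpy as np

from src.env_wrappers import make_env

env = make_env(seed=0)
obs, info = env.reset(seed=0)
assert obs.shape == (4, 84, 84) and obs.dtype == np.uint8  # stacked grayscale frames
env.close()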
@@ -0,0 +1,222 @@
"""Evaluation helpers for PPO on CarRacing-v3.

Functions:
- evaluate_agent(agent, n_episodes, seed_start) -> list[float] returns
- record_demo_video(agent, out_path, seed) -> save mp4 of one episode
- load_tb_scalars(run_dir, tag) -> (steps, values) from TensorBoard
- plot_eval_bar(returns, baseline, save_path) -> bar chart of eval returns
- plot_training_curves(run_dirs, labels, save_path) -> 6-panel curves figure
"""

from pathlib import Path

import gymnasium as gym
import imageio
import matplotlib.pyplot as plt
import numpy as np

from src.env_wrappers import FrameStack, GrayScaleResize, SkipFrame, make_env


# ---------------------------------------------------------------------------
# Numerical evaluation
# ---------------------------------------------------------------------------

def evaluate_agent(agent, n_episodes=20, seed_start=1000, deterministic=False):
    """Roll the agent for n_episodes on freshly-seeded envs.

    Args:
        agent: a PPOAgent (loaded with a checkpoint).
        n_episodes: how many evaluation episodes.
        seed_start: starting seed; each ep uses seed_start + ep.
        deterministic: if True, take argmax over policy logits instead of
            sampling (slightly higher mean, slightly lower variance).

    Returns:
        list[float] of per-episode returns.
    """
    import torch
    from torch.distributions import Categorical

    returns = []
    env = make_env(seed=seed_start)
    for ep in range(n_episodes):
        obs, _ = env.reset(seed=seed_start + ep)
        ep_return, done = 0.0, False
        while not done:
            if deterministic:
                obs_t = torch.as_tensor(obs, device=agent.device).unsqueeze(0)
                logits, _ = agent.net(obs_t)
                action = int(logits.argmax(dim=-1).item())
            else:
                action, _, _ = agent.act(obs)
            obs, r, term, trunc, _ = env.step(action)
            ep_return += r
            done = term or trunc
        returns.append(ep_return)
    env.close()
    return returns


# ---------------------------------------------------------------------------
# Video recording
# ---------------------------------------------------------------------------

class _DualEnv:
    """Step a wrapped env (for the agent) and a raw env (for nice video) in lockstep."""

    def __init__(self, seed):
        self.wrapped = make_env(seed=seed)
        self.raw = gym.make("CarRacing-v3", continuous=False, render_mode="rgb_array")
        # Don't pre-reset raw env here; reset() below will do it once cleanly.

    def reset(self, seed):
        # Re-create wrapped to ensure both envs see the EXACT same first reset
        # under the given seed. This mirrors how the eval pipeline scores seeds.
        self.wrapped = make_env(seed=seed)
        wrapped_obs, _ = self.wrapped.reset(seed=seed)
        raw_obs, _ = self.raw.reset(seed=seed)
        return wrapped_obs, raw_obs

    def step(self, action):
        wrapped_obs, _, term, trunc, _ = self.wrapped.step(action)
        # SkipFrame inside wrapped env runs 4 raw frames per call; mirror that.
        raw_frames = []
        raw_done = False
        for _ in range(4):
            if not raw_done:
                raw_obs, _, t, tr, _ = self.raw.step(action)
                raw_done = t or tr
                raw_frames.append(raw_obs.copy())
        return wrapped_obs, raw_frames, (term or trunc) or raw_done

    def close(self):
        self.wrapped.close()
        self.raw.close()


def record_demo_video(agent, out_path, seed=42, fps=30, max_steps=600):
    """Record one evaluation episode as an mp4 using the original RGB renderer."""
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    denv = _DualEnv(seed=seed)
    wrapped_obs, _ = denv.reset(seed=seed)

    frames = []
    total_r = 0.0
    for _ in range(max_steps):
        action, _, _ = agent.act(wrapped_obs)
        wrapped_obs, raw_frames, done = denv.step(action)
        frames.extend(raw_frames)
        if done:
            break

    denv.close()
    imageio.mimsave(str(out_path), frames, fps=fps)
    return len(frames), out_path


# ---------------------------------------------------------------------------
# TensorBoard data extraction
# ---------------------------------------------------------------------------

def load_tb_scalars(run_dir, tag):
    """Read a single scalar tag from a TensorBoard run directory.

    Returns (steps_list, values_list); empty if tag is absent.
    """
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

    ea = EventAccumulator(str(run_dir))
    ea.Reload()
    if tag not in ea.Tags()["scalars"]:
        return [], []
    events = ea.Scalars(tag)
    return [e.step for e in events], [e.value for e in events]


# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------

def plot_eval_bar(returns, baseline, save_path, title=None):
    """Bar chart of per-episode returns + baseline reference line."""
    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)

    mean_r = float(np.mean(returns))
    std_r = float(np.std(returns))

    fig, ax = plt.subplots(figsize=(10, 5))
    xs = np.arange(len(returns))
    ax.bar(xs, returns, color="steelblue", edgecolor="black", alpha=0.7)
    ax.axhline(y=mean_r, color="red", linestyle="--",
               label=f"Mean = {mean_r:.1f} ± {std_r:.1f}")
    ax.axhline(y=baseline, color="gray", linestyle=":",
               label=f"Random baseline = {baseline:.1f}")
    ax.set_xlabel("Evaluation episode")
    ax.set_ylabel("Episode return")
    ax.set_title(title or f"Evaluation returns over {len(returns)} unseen seeds")
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close(fig)
    return save_path


def plot_training_curves(run_dirs, labels, save_path,
                         tags=None, smooth_window=10):
    """Multi-run multi-panel training curves (one panel per tag).

    Args:
        run_dirs: list[Path] — TensorBoard run directories
        labels: list[str] — label per run for the legend
        save_path: where to write the PNG
        tags: list[str] of TensorBoard scalar tags to plot. Defaults to a
            standard 6-panel set.
        smooth_window: rolling-mean window for visual smoothing (1 = none).
    """
    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)

    if tags is None:
        tags = [
            "episode/avg_return_100",
            "losses/value_loss",
            "losses/entropy",
            "losses/approx_kl",
            "losses/clip_frac",
            "episode/length",
        ]

    n = len(tags)
    rows = (n + 2) // 3
    fig, axes = plt.subplots(rows, 3, figsize=(15, 4 * rows))
    axes = axes.flatten()

    for ax, tag in zip(axes, tags):
        for run_dir, label in zip(run_dirs, labels):
            steps, values = load_tb_scalars(run_dir, tag)
            if not steps:
                continue
            if smooth_window > 1 and len(values) > smooth_window:
                values = np.convolve(
                    values, np.ones(smooth_window) / smooth_window, mode="valid"
                )
                steps = steps[smooth_window - 1:]
            ax.plot(steps, values, label=label, alpha=0.85)
        ax.set_title(tag)
        ax.set_xlabel("Env steps")
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=8)

    # Hide unused axes
    for ax in axes[n:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close(fig)
    return save_path
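The helpers above are typically chained as follows in evaluate.py. This is a sketch only: the module path src.evaluation, the checkpoint path and the baseline value are assumptions, not taken from the commit.

import numpy as np

from src.evaluation import evaluate_agent, plot_eval_bar, record_demo_video  # assumed module path
from src.ppo_agent import PPOAgent  # assumed module path

agent = PPOAgent(n_actions=5)
agent.load("checkpoints/ppo_final.pt")  # assumed checkpoint location
returns = evaluate_agent(agent, n_episodes=20, seed_start=1000)
print(f"mean return: {np.mean(returns):.1f} +/- {np.std(returns):.1f}")
plot_eval_bar(returns, baseline=-60.0, save_path="figures/eval_returns.png")  # example baseline value
record_demo_video(agent, "videos/demo.mp4", seed=42)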
@@ -0,0 +1,59 @@
"""Shared-CNN Actor-Critic network for discrete CarRacing-v3 PPO.

Input : uint8 tensor (B, 4, 84, 84), values in [0, 255]
Output :
- logits (B, n_actions) for a Categorical policy
- value (B,) scalar state-value V(s)
"""

import math

import torch
import torch.nn as nn
from torch.distributions import Categorical


def layer_init(layer, std=math.sqrt(2), bias=0.0):
    """Orthogonal init with configurable gain (PPO best practice)."""
    nn.init.orthogonal_(layer.weight, std)
    nn.init.constant_(layer.bias, bias)
    return layer


class ActorCritic(nn.Module):
    """Shared-CNN actor-critic for discrete visual control."""

    def __init__(self, n_actions=5):
        super().__init__()
        self.cnn = nn.Sequential(
            layer_init(nn.Conv2d(4, 32, kernel_size=8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(64 * 7 * 7, 512)),
            nn.ReLU(),
        )
        # Small std on the actor head -> initial policy is nearly uniform
        self.actor = layer_init(nn.Linear(512, n_actions), std=0.01)
        # Standard std on the critic head
        self.critic = layer_init(nn.Linear(512, 1), std=1.0)

    def forward(self, x):
        # uint8 [0, 255] -> float32 [0, 1]
        x = x.float() / 255.0
        feat = self.cnn(x)
        logits = self.actor(feat)
        value = self.critic(feat).squeeze(-1)
        return logits, value

    def get_action_and_value(self, x, action=None):
        logits, value = self(x)
        dist = Categorical(logits=logits)
        if action is None:
            action = dist.sample()
        log_prob = dist.log_prob(action)
        entropy = dist.entropy()
        return action, log_prob, entropy, value
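A minimal shape check for the network (not part of the committed file), confirming the input/output contract stated in the docstring; with 84x84 inputs the three conv layers produce a 7x7x64 feature map, hence the 64 * 7 * 7 flatten size.

import torch

from src.networks import ActorCritic

net = ActorCritic(n_actions=5)
x = torch.randint(0, 256, (2, 4, 84, 84), dtype=torch.uint8)  # dummy uint8 batch
logits, value = net(x)
assert logits.shape == (2, 5) and value.shape == (2,)
action, log_prob, entropy, value = net.get_action_and_value(x)
assert action.shape == log_prob.shape == entropy.shape == (2,)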
@@ -0,0 +1,255 @@
"""PPO agent: clipped surrogate objective + value loss + entropy bonus.

Implements the PPO-Clip algorithm (Schulman et al. 2017) on top of our
shared-CNN ActorCritic network and a vectorised rollout buffer. Includes
production-grade refinements catalogued in *The 37 Implementation Details
of PPO* (Huang et al. 2022) and *RAD* (Laskin et al. 2020):

- Clipped value-function loss (SB3 standard)
- KL early stopping within update epochs (target_kl)
- Linear schedule for clip range (clip_init -> clip_floor)
- Random-shift data augmentation on observations during the update
- Linear annealing of learning rate and entropy coefficient with floors

Public API:
- PPOAgent.act(obs) -> (action, log_prob, value)
- PPOAgent.act_batch(obs_batch) -> batched act for n_envs obs
- PPOAgent.evaluate_value_batch(obs) -> bootstrap value for GAE
- PPOAgent.update_vec(buffer) -> PPO update over vectorised rollout
- PPOAgent.step_schedule(progress) -> linear LR/entropy/clip annealing
- PPOAgent.save / load -> state_dict checkpoints
"""
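For reference, the clipped surrogate objective that update_vec optimises, in standard PPO-Clip notation (epsilon corresponds to self.clip, c_v to vf_coef, c_e to ent_coef; L^VF is the clipped value loss implemented below):

r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\mathrm{old}}(a_t \mid s_t)}, \qquad
L^{\mathrm{CLIP}}(\theta) = \hat{\mathbb{E}}_t\!\left[\min\!\big(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\right]

L(\theta) = -L^{\mathrm{CLIP}}(\theta) + c_v\,L^{\mathrm{VF}}(\theta) - c_e\,\hat{\mathbb{E}}_t\big[H[\pi_\theta](s_t)\big]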
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from src.networks import ActorCritic


class PPOAgent:
    """PPO-Clip agent for discrete action spaces."""

    def __init__(
        self,
        n_actions=5,
        lr=2.5e-4,
        clip=0.2,
        vf_coef=0.5,
        ent_coef=0.01,
        max_grad_norm=0.5,
        n_epochs=6,
        batch_size=128,
        device=None,
        anneal_lr=False,
        anneal_ent=False,
        clip_floor=None,
        target_kl=None,
        use_data_aug=False,
        aug_pad=4,
    ):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.net = ActorCritic(n_actions=n_actions).to(self.device)
        self.optim = optim.Adam(self.net.parameters(), lr=lr, eps=1e-5)

        # Save initial values for scheduling
        self.lr_init = lr
        self.clip_init = clip
        self.ent_coef_init = ent_coef

        self.clip = clip
        self.vf_coef = vf_coef
        self.ent_coef = ent_coef
        self.max_grad_norm = max_grad_norm
        self.n_epochs = n_epochs
        self.batch_size = batch_size

        # Schedule and refinement flags
        self.anneal_lr = anneal_lr
        self.anneal_ent = anneal_ent
        self.clip_floor = clip_floor  # None = no clip annealing
        self.target_kl = target_kl  # None = no KL early stopping
        self.use_data_aug = use_data_aug
        self.aug_pad = aug_pad

    @torch.no_grad()
    def act(self, obs):
        """Sample one action for the rollout phase."""
        obs_t = torch.as_tensor(obs, device=self.device).unsqueeze(0)
        action, log_prob, _, value = self.net.get_action_and_value(obs_t)
        return action.item(), log_prob.item(), value.item()

    @torch.no_grad()
    def act_batch(self, obs_batch):
        """Vectorised act for n_envs obs at once."""
        obs_t = torch.as_tensor(obs_batch, device=self.device)
        action, log_prob, _, value = self.net.get_action_and_value(obs_t)
        return (
            action.cpu().numpy(),
            log_prob.cpu().numpy(),
            value.cpu().numpy(),
        )

    @torch.no_grad()
    def evaluate_value_batch(self, obs_batch):
        obs_t = torch.as_tensor(obs_batch, device=self.device)
        _, value = self.net(obs_t)
        return value.cpu().numpy()

    def _random_shift(self, obs):
        """Random-shift data augmentation (DrQ / RAD style), fully vectorised.

        Pads the (B, C, H, W) image by ``aug_pad`` on each side using
        replicate padding, then crops a random H x W window per sample.
        Uses a single grid_sample call instead of a Python loop.
        """
        n, c, h, w = obs.shape
        pad = self.aug_pad
        x = F.pad(obs.float(), (pad, pad, pad, pad), mode="replicate")
        # Per-sample random integer offsets in [0, 2*pad]
        h_off = torch.randint(0, 2 * pad + 1, (n,), device=obs.device)
        w_off = torch.randint(0, 2 * pad + 1, (n,), device=obs.device)

        # Build per-sample affine grid that translates by the offset.
        # Padded image is (h + 2*pad) x (w + 2*pad). To crop a (h x w) window
        # starting at (h_off, w_off), we sample at normalized coords:
        #   y_norm in [(h_off + 0.5)/H' * 2 - 1, (h_off + h - 0.5)/H' * 2 - 1]
        # where H' = h + 2*pad. We build that grid in one shot.
        Hp = h + 2 * pad
        Wp = w + 2 * pad
        # Base grid for an (h, w) crop in normalized [-1, 1] coords on the padded image
        ys = torch.arange(h, device=obs.device, dtype=torch.float32)
        xs = torch.arange(w, device=obs.device, dtype=torch.float32)
        # (n, h)
        y_indices = h_off.unsqueeze(1).float() + ys.unsqueeze(0)
        # (n, w)
        x_indices = w_off.unsqueeze(1).float() + xs.unsqueeze(0)
        # Convert to normalized coords on the padded image, [-1, 1]
        y_norm = (y_indices + 0.5) / Hp * 2.0 - 1.0  # (n, h)
        x_norm = (x_indices + 0.5) / Wp * 2.0 - 1.0  # (n, w)
        # Build (n, h, w, 2) grid: [..., 0] = x, [..., 1] = y per grid_sample API
        grid = torch.stack(
            [x_norm.unsqueeze(1).expand(n, h, w),
             y_norm.unsqueeze(2).expand(n, h, w)],
            dim=-1,
        )
        out = F.grid_sample(x, grid, mode="nearest", align_corners=False,
                            padding_mode="border")
        return out.to(obs.dtype)

    def update_vec(self, vec_buffer):
        """PPO update for a vectorised buffer (flattens n_steps * n_envs)."""
        obs_shape = vec_buffer.obs_shape
        b_obs_flat = vec_buffer.obs.reshape(-1, *obs_shape)
        b_actions_flat = vec_buffer.actions.reshape(-1)
        b_old_logp_flat = vec_buffer.log_probs.reshape(-1)
        b_old_values_flat = vec_buffer.values.reshape(-1)
        b_adv_flat = vec_buffer.advantages.reshape(-1)
        b_ret_flat = vec_buffer.returns.reshape(-1)

        pg_losses, v_losses, ent_losses, approx_kls, clip_fracs = [], [], [], [], []
        epochs_completed = 0
        early_stopped = False

        for epoch in range(self.n_epochs):
            epoch_kls = []
            for idx in vec_buffer.get_minibatches(self.batch_size):
                b_obs = b_obs_flat[idx]
                b_actions = b_actions_flat[idx]
                b_old_logp = b_old_logp_flat[idx]
                b_old_values = b_old_values_flat[idx]
                b_adv = b_adv_flat[idx]
                b_ret = b_ret_flat[idx]

                # Random shift data augmentation (refinement #4)
                if self.use_data_aug:
                    b_obs = self._random_shift(b_obs)

                _, new_logp, entropy, value = self.net.get_action_and_value(
                    b_obs, b_actions
                )

                log_ratio = new_logp - b_old_logp
                ratio = log_ratio.exp()

                # Clipped policy loss
                surr1 = ratio * b_adv
                surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
                policy_loss = -torch.min(surr1, surr2).mean()

                # Clipped value loss (refinement #1, SB3 standard)
                v_clipped = b_old_values + torch.clamp(
                    value - b_old_values, -self.clip, self.clip
                )
                v_loss_unclipped = (value - b_ret).pow(2)
                v_loss_clipped = (v_clipped - b_ret).pow(2)
                value_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()

                entropy_loss = entropy.mean()

                loss = (
                    policy_loss
                    + self.vf_coef * value_loss
                    - self.ent_coef * entropy_loss
                )

                self.optim.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm)
                self.optim.step()

                with torch.no_grad():
                    approx_kl = ((ratio - 1) - log_ratio).mean().item()
                    clip_frac = ((ratio - 1.0).abs() > self.clip).float().mean().item()

                pg_losses.append(policy_loss.item())
                v_losses.append(value_loss.item())
                ent_losses.append(entropy_loss.item())
                approx_kls.append(approx_kl)
                clip_fracs.append(clip_frac)
                epoch_kls.append(approx_kl)

            epochs_completed += 1
            # KL early stopping (refinement #2): stop epochs if mean KL exceeds 1.5 * target
            if self.target_kl is not None and len(epoch_kls) > 0:
                if sum(epoch_kls) / len(epoch_kls) > 1.5 * self.target_kl:
                    early_stopped = True
                    break

        return {
            "policy_loss": sum(pg_losses) / len(pg_losses),
            "value_loss": sum(v_losses) / len(v_losses),
            "entropy": sum(ent_losses) / len(ent_losses),
            "approx_kl": sum(approx_kls) / len(approx_kls),
            "clip_frac": sum(clip_fracs) / len(clip_fracs),
            "epochs_completed": epochs_completed,
            "early_stopped": float(early_stopped),
            "current_clip": self.clip,
        }

    def save(self, path):
        torch.save(self.net.state_dict(), path)

    def load(self, path):
        self.net.load_state_dict(torch.load(path, map_location=self.device))

    def step_schedule(self, progress, ent_floor=0.0, lr_floor=0.0):
        """Linearly decay lr / ent_coef / clip toward floors over training.

        - LR: lr_init -> lr_floor
        - Entropy coefficient: ent_coef_init -> ent_floor (preserves exploration)
        - Clip range: clip_init -> clip_floor (only if clip_floor is set)
        """
        progress = min(max(progress, 0.0), 1.0)
        if self.anneal_lr:
            target_lr = self.lr_init * (1.0 - progress)
            for g in self.optim.param_groups:
                g["lr"] = max(target_lr, lr_floor)
        if self.anneal_ent:
            target_ent = self.ent_coef_init * (1.0 - progress)
            self.ent_coef = max(target_ent, ent_floor)
        # Clip range schedule (refinement #6)
        if self.clip_floor is not None:
            target_clip = self.clip_init * (1.0 - progress) + self.clip_floor * progress
            self.clip = max(target_clip, self.clip_floor)
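A quick check of the random-shift augmentation (not in the committed file; the module path src.ppo_agent is an assumption). It only confirms that the helper preserves the batch shape and uint8 dtype of an observation batch.

import torch

from src.ppo_agent import PPOAgent  # assumed module path

agent = PPOAgent(n_actions=5, device="cpu", use_data_aug=True, aug_pad=4)
obs = torch.randint(0, 256, (8, 4, 84, 84), dtype=torch.uint8)
shifted = agent._random_shift(obs)
assert shifted.shape == obs.shape and shifted.dtype == obs.dtype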
@@ -0,0 +1,23 @@
"""Small helpers used across the training and evaluation code."""

import random

import numpy as np
import torch


def set_seed(seed: int = 42):
    """Make Python / NumPy / PyTorch / CUDA randomness reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def format_seconds(s: float) -> str:
    """Pretty-print a duration in seconds as HH:MM:SS."""
    h = int(s // 3600)
    m = int((s % 3600) // 60)
    sec = int(s % 60)
    return f"{h:02d}:{m:02d}:{sec:02d}"
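Example usage (illustrative only; the module path src.utils is an assumption):

from src.utils import set_seed, format_seconds  # assumed module path

set_seed(42)                   # seed Python / NumPy / PyTorch / CUDA once at startup
print(format_seconds(3725.0))  # -> 01:02:05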
@@ -0,0 +1,31 @@
"""Vectorised CarRacing-v3 factory.

Wraps the existing single-env preprocessing stack inside a Gymnasium
vector env so n_envs copies can step in parallel processes (async mode).
"""

import gymnasium as gym

from src.env_wrappers import FrameStack, GrayScaleResize, SkipFrame


def _make_one(rank: int, seed: int):
    def _init():
        env = gym.make("CarRacing-v3", continuous=False)
        env = SkipFrame(env, k=4)
        env = GrayScaleResize(env, size=84)
        env = FrameStack(env, k=4)
        env.action_space.seed(seed + rank)
        return env
    return _init


def make_vec_env(n_envs: int = 4, seed: int = 0, async_mode: bool = True):
    """Build a vectorised CarRacing-v3 env.

    Returns a gym.vector env whose obs has shape (n_envs, 4, 84, 84) uint8.
    """
    fns = [_make_one(i, seed) for i in range(n_envs)]
    if async_mode:
        return gym.vector.AsyncVectorEnv(fns)
    return gym.vector.SyncVectorEnv(fns)
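A quick sanity check of the vector env (not part of the committed file; the module path src.vec_env is an assumption). Synchronous mode is used here so no subprocesses are spawned.

from src.vec_env import make_vec_env  # assumed module path

venv = make_vec_env(n_envs=2, seed=0, async_mode=False)
obs, infos = venv.reset(seed=0)
assert obs.shape == (2, 4, 84, 84)  # batched stacked-frame observations
venv.close()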
@@ -0,0 +1,70 @@
"""Vectorised rollout buffer (n_steps, n_envs, ...) with GAE.

Uses CleanRL's indexing convention:
    dones[t] flags whether obs[t] is the FIRST obs of a fresh episode
    (i.e., the previous action terminated). GAE then uses dones[t+1]
    as the mask for V(s_{t+1}) at time t.
"""
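In that notation, compute_gae below evaluates the standard GAE recursion, masking the bootstrap term with d_{t+1} (or last_done at the rollout boundary), and finally normalises the advantages over the whole n_steps * n_envs batch:

\delta_t = r_t + \gamma\,(1 - d_{t+1})\,V(s_{t+1}) - V(s_t)

\hat{A}_t = \delta_t + \gamma\lambda\,(1 - d_{t+1})\,\hat{A}_{t+1}, \qquad R_t = \hat{A}_t + V(s_t)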
import torch


class VecRolloutBuffer:
    def __init__(self, n_steps, n_envs, obs_shape, device):
        self.n_steps = n_steps
        self.n_envs = n_envs
        self.obs_shape = obs_shape
        self.device = device

        self.obs = torch.zeros(
            (n_steps, n_envs, *obs_shape), dtype=torch.uint8, device=device
        )
        self.actions = torch.zeros((n_steps, n_envs), dtype=torch.long, device=device)
        self.log_probs = torch.zeros((n_steps, n_envs), dtype=torch.float32, device=device)
        self.rewards = torch.zeros((n_steps, n_envs), dtype=torch.float32, device=device)
        self.values = torch.zeros((n_steps, n_envs), dtype=torch.float32, device=device)
        self.dones = torch.zeros((n_steps, n_envs), dtype=torch.float32, device=device)

        self.advantages = torch.zeros((n_steps, n_envs), dtype=torch.float32, device=device)
        self.returns = torch.zeros((n_steps, n_envs), dtype=torch.float32, device=device)

        self.ptr = 0

    def add(self, obs, action, log_prob, reward, value, done):
        i = self.ptr
        self.obs[i] = torch.as_tensor(obs, device=self.device)
        self.actions[i] = torch.as_tensor(action, device=self.device, dtype=torch.long)
        self.log_probs[i] = torch.as_tensor(log_prob, device=self.device, dtype=torch.float32)
        self.rewards[i] = torch.as_tensor(reward, device=self.device, dtype=torch.float32)
        self.values[i] = torch.as_tensor(value, device=self.device, dtype=torch.float32)
        self.dones[i] = torch.as_tensor(done, device=self.device, dtype=torch.float32)
        self.ptr += 1

    def compute_gae(self, last_value, last_done, gamma=0.99, lam=0.95):
        last_value = torch.as_tensor(last_value, device=self.device, dtype=torch.float32)
        last_done = torch.as_tensor(last_done, device=self.device, dtype=torch.float32)

        next_values = torch.cat([self.values[1:], last_value.unsqueeze(0)], dim=0)
        next_non_terminal = 1.0 - torch.cat([self.dones[1:], last_done.unsqueeze(0)], dim=0)

        deltas = self.rewards + gamma * next_values * next_non_terminal - self.values

        adv = torch.zeros((self.n_envs,), device=self.device)
        for t in reversed(range(self.n_steps)):
            adv = deltas[t] + gamma * lam * next_non_terminal[t] * adv
            self.advantages[t] = adv

        self.returns = self.advantages + self.values

        flat = self.advantages.reshape(-1)
        flat = (flat - flat.mean()) / (flat.std() + 1e-8)
        self.advantages = flat.reshape(self.n_steps, self.n_envs)

    def get_minibatches(self, batch_size):
        total = self.n_steps * self.n_envs
        idx = torch.randperm(total, device=self.device)
        for start in range(0, total, batch_size):
            yield idx[start: start + batch_size]

    def reset(self):
        self.ptr = 0