perf: 为PPO和DQN添加GPU优化——AMP混合精度、pinned memory、torch.compile

- PPO (CW1_id_name): 添加 AMP GradScaler + autocast 混合精度训练,pinned memory 加速 CPU→GPU 传输,torch.compile JIT 编译支持,调整默认超参适配 RTX 5090
- DQN (Atari): 添加 AMP 混合精度、pinned memory 回放缓冲区、向量化批量添加经验 (add_batch) 和批量动作选择 (batch_select_actions),消除 Python 循环
- train_parallel.py: 重写为无缓冲脚本,集成所有优化,64 并行环境 + 每步 4 次训练更新

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-05 00:50:16 +08:00
parent ed0822966b
commit d5c9baffe6
7 changed files with 495 additions and 883 deletions
+41 -27
View File
@@ -10,6 +10,7 @@ of PPO* (Huang et al. 2022) and *RAD* (Laskin et al. 2020):
- Linear schedule for clip range (clip_init -> clip_floor) - Linear schedule for clip range (clip_init -> clip_floor)
- Random-shift data augmentation on observations during the update - Random-shift data augmentation on observations during the update
- Linear annealing of learning rate and entropy coefficient with floors - Linear annealing of learning rate and entropy coefficient with floors
- AMP mixed precision training for GPU acceleration
Public API: Public API:
- PPOAgent.act(obs) -> (action, log_prob, value) - PPOAgent.act(obs) -> (action, log_prob, value)
@@ -48,11 +49,17 @@ class PPOAgent:
target_kl=None, target_kl=None,
use_data_aug=False, use_data_aug=False,
aug_pad=4, aug_pad=4,
use_amp=True,
): ):
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") dev = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.device = torch.device(dev) if isinstance(dev, str) else dev
self.net = ActorCritic(n_actions=n_actions).to(self.device) self.net = ActorCritic(n_actions=n_actions).to(self.device)
self.optim = optim.Adam(self.net.parameters(), lr=lr, eps=1e-5) self.optim = optim.Adam(self.net.parameters(), lr=lr, eps=1e-5)
# AMP 混合精度训练
self.use_amp = use_amp and self.device.type == 'cuda'
self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)
# Save initial values for scheduling # Save initial values for scheduling
self.lr_init = lr self.lr_init = lr
self.clip_init = clip self.clip_init = clip
@@ -84,7 +91,8 @@ class PPOAgent:
def act_batch(self, obs_batch): def act_batch(self, obs_batch):
"""Vectorised act for n_envs obs at once.""" """Vectorised act for n_envs obs at once."""
obs_t = torch.as_tensor(obs_batch, device=self.device) obs_t = torch.as_tensor(obs_batch, device=self.device)
action, log_prob, _, value = self.net.get_action_and_value(obs_t) with torch.amp.autocast('cuda', enabled=self.use_amp):
action, log_prob, _, value = self.net.get_action_and_value(obs_t)
return ( return (
action.cpu().numpy(), action.cpu().numpy(),
log_prob.cpu().numpy(), log_prob.cpu().numpy(),
@@ -94,7 +102,8 @@ class PPOAgent:
@torch.no_grad() @torch.no_grad()
def evaluate_value_batch(self, obs_batch): def evaluate_value_batch(self, obs_batch):
obs_t = torch.as_tensor(obs_batch, device=self.device) obs_t = torch.as_tensor(obs_batch, device=self.device)
_, value = self.net(obs_t) with torch.amp.autocast('cuda', enabled=self.use_amp):
_, value = self.net(obs_t)
return value.cpu().numpy() return value.cpu().numpy()
def _random_shift(self, obs): def _random_shift(self, obs):
@@ -166,38 +175,43 @@ class PPOAgent:
if self.use_data_aug: if self.use_data_aug:
b_obs = self._random_shift(b_obs) b_obs = self._random_shift(b_obs)
_, new_logp, entropy, value = self.net.get_action_and_value( # AMP 前向传播
b_obs, b_actions with torch.amp.autocast('cuda', enabled=self.use_amp):
) _, new_logp, entropy, value = self.net.get_action_and_value(
b_obs, b_actions
)
log_ratio = new_logp - b_old_logp log_ratio = new_logp - b_old_logp
ratio = log_ratio.exp() ratio = log_ratio.exp()
# Clipped policy loss # Clipped policy loss
surr1 = ratio * b_adv surr1 = ratio * b_adv
surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
policy_loss = -torch.min(surr1, surr2).mean() policy_loss = -torch.min(surr1, surr2).mean()
# Clipped value loss (refinement #1, SB3 standard) # Clipped value loss (refinement #1, SB3 standard)
v_clipped = b_old_values + torch.clamp( v_clipped = b_old_values + torch.clamp(
value - b_old_values, -self.clip, self.clip value - b_old_values, -self.clip, self.clip
) )
v_loss_unclipped = (value - b_ret).pow(2) v_loss_unclipped = (value - b_ret).pow(2)
v_loss_clipped = (v_clipped - b_ret).pow(2) v_loss_clipped = (v_clipped - b_ret).pow(2)
value_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean() value_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
entropy_loss = entropy.mean() entropy_loss = entropy.mean()
loss = ( loss = (
policy_loss policy_loss
+ self.vf_coef * value_loss + self.vf_coef * value_loss
- self.ent_coef * entropy_loss - self.ent_coef * entropy_loss
) )
# AMP 反向传播
self.optim.zero_grad() self.optim.zero_grad()
loss.backward() self.scaler.scale(loss).backward()
self.scaler.unscale_(self.optim)
nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm) nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm)
self.optim.step() self.scaler.step(self.optim)
self.scaler.update()
with torch.no_grad(): with torch.no_grad():
approx_kl = ((ratio - 1) - log_ratio).mean().item() approx_kl = ((ratio - 1) - log_ratio).mean().item()
+28 -6
View File
@@ -4,6 +4,8 @@ Uses CleanRL's indexing convention:
dones[t] flags whether obs[t] is the FIRST obs of a fresh episode dones[t] flags whether obs[t] is the FIRST obs of a fresh episode
(i.e., the previous action terminated). GAE then uses dones[t+1] (i.e., the previous action terminated). GAE then uses dones[t+1]
as the mask for V(s_{t+1}) at time t. as the mask for V(s_{t+1}) at time t.
Supports pinned memory for faster CPU→GPU transfer.
""" """
import torch import torch
@@ -16,6 +18,7 @@ class VecRolloutBuffer:
self.obs_shape = obs_shape self.obs_shape = obs_shape
self.device = device self.device = device
# 主存储在 GPU 上
self.obs = torch.zeros( self.obs = torch.zeros(
(n_steps, n_envs, *obs_shape), dtype=torch.uint8, device=device (n_steps, n_envs, *obs_shape), dtype=torch.uint8, device=device
) )
@@ -28,16 +31,35 @@ class VecRolloutBuffer:
self.advantages = torch.zeros((n_steps, n_envs), dtype=torch.float32, device=device) self.advantages = torch.zeros((n_steps, n_envs), dtype=torch.float32, device=device)
self.returns = torch.zeros((n_steps, n_envs), dtype=torch.float32, device=device) self.returns = torch.zeros((n_steps, n_envs), dtype=torch.float32, device=device)
# Pinned memory 缓冲区(加速 CPU→GPU 传输)
self._obs_pin = torch.zeros(
(n_steps, n_envs, *obs_shape), dtype=torch.uint8, pin_memory=True
)
self._actions_pin = torch.zeros((n_steps, n_envs), dtype=torch.long, pin_memory=True)
self._log_probs_pin = torch.zeros((n_steps, n_envs), dtype=torch.float32, pin_memory=True)
self._rewards_pin = torch.zeros((n_steps, n_envs), dtype=torch.float32, pin_memory=True)
self._values_pin = torch.zeros((n_steps, n_envs), dtype=torch.float32, pin_memory=True)
self._dones_pin = torch.zeros((n_steps, n_envs), dtype=torch.float32, pin_memory=True)
self.ptr = 0 self.ptr = 0
def add(self, obs, action, log_prob, reward, value, done): def add(self, obs, action, log_prob, reward, value, done):
i = self.ptr i = self.ptr
self.obs[i] = torch.as_tensor(obs, device=self.device) # 先写入 pinned memory,再 non-blocking 传输到 GPU
self.actions[i] = torch.as_tensor(action, device=self.device, dtype=torch.long) self._obs_pin[i] = torch.as_tensor(obs)
self.log_probs[i] = torch.as_tensor(log_prob, device=self.device, dtype=torch.float32) self._actions_pin[i] = torch.as_tensor(action, dtype=torch.long)
self.rewards[i] = torch.as_tensor(reward, device=self.device, dtype=torch.float32) self._log_probs_pin[i] = torch.as_tensor(log_prob, dtype=torch.float32)
self.values[i] = torch.as_tensor(value, device=self.device, dtype=torch.float32) self._rewards_pin[i] = torch.as_tensor(reward, dtype=torch.float32)
self.dones[i] = torch.as_tensor(done, device=self.device, dtype=torch.float32) self._values_pin[i] = torch.as_tensor(value, dtype=torch.float32)
self._dones_pin[i] = torch.as_tensor(done, dtype=torch.float32)
# non_blocking 传输到 GPU
self.obs[i] = self._obs_pin[i].to(self.device, non_blocking=True)
self.actions[i] = self._actions_pin[i].to(self.device, non_blocking=True)
self.log_probs[i] = self._log_probs_pin[i].to(self.device, non_blocking=True)
self.rewards[i] = self._rewards_pin[i].to(self.device, non_blocking=True)
self.values[i] = self._values_pin[i].to(self.device, non_blocking=True)
self.dones[i] = self._dones_pin[i].to(self.device, non_blocking=True)
self.ptr += 1 self.ptr += 1
def compute_gae(self, last_value, last_done, gamma=0.99, lam=0.95): def compute_gae(self, last_value, last_done, gamma=0.99, lam=0.95):
+23 -5
View File
@@ -5,6 +5,11 @@ Usage (Windows):
python train_vec.py --n-envs 4 --total-steps 500000 --run-name vec_main \ python train_vec.py --n-envs 4 --total-steps 500000 --run-name vec_main \
--anneal-lr --anneal-ent --reward-clip 1.0 --anneal-lr --anneal-ent --reward-clip 1.0
Usage (Linux server with RTX 5090):
python train_vec.py --n-envs 16 --total-steps 2000000 --run-name vec_main \
--n-steps 512 --batch-size 512 --n-epochs 10 \
--anneal-lr --anneal-ent --reward-clip 1.0 --use-amp
The ``if __name__ == "__main__"`` guard at the bottom is mandatory on The ``if __name__ == "__main__"`` guard at the bottom is mandatory on
Windows for AsyncVectorEnv (otherwise child processes infinite-spawn). Windows for AsyncVectorEnv (otherwise child processes infinite-spawn).
""" """
@@ -26,11 +31,11 @@ from src.vec_rollout_buffer import VecRolloutBuffer
def parse_args(): def parse_args():
p = argparse.ArgumentParser() p = argparse.ArgumentParser()
p.add_argument("--total-steps", type=int, default=3_000_000) p.add_argument("--total-steps", type=int, default=2_000_000)
p.add_argument("--n-envs", type=int, default=8) p.add_argument("--n-envs", type=int, default=16)
p.add_argument("--n-steps", type=int, default=256) p.add_argument("--n-steps", type=int, default=512)
p.add_argument("--n-epochs", type=int, default=6) p.add_argument("--n-epochs", type=int, default=10)
p.add_argument("--batch-size", type=int, default=128) p.add_argument("--batch-size", type=int, default=512)
p.add_argument("--lr", type=float, default=2.5e-4) p.add_argument("--lr", type=float, default=2.5e-4)
p.add_argument("--gamma", type=float, default=0.99) p.add_argument("--gamma", type=float, default=0.99)
p.add_argument("--lam", type=float, default=0.95) p.add_argument("--lam", type=float, default=0.95)
@@ -56,6 +61,10 @@ def parse_args():
help="Apply random-shift augmentation to obs during PPO update") help="Apply random-shift augmentation to obs during PPO update")
p.add_argument("--sync-mode", action="store_true", p.add_argument("--sync-mode", action="store_true",
help="Use SyncVectorEnv (debug mode)") help="Use SyncVectorEnv (debug mode)")
p.add_argument("--use-amp", action="store_true",
help="Use AMP mixed precision training for GPU acceleration")
p.add_argument("--use-compile", action="store_true",
help="Use torch.compile for JIT compilation acceleration")
return p.parse_args() return p.parse_args()
@@ -94,7 +103,14 @@ def main():
clip_floor=args.clip_floor, clip_floor=args.clip_floor,
target_kl=args.target_kl, target_kl=args.target_kl,
use_data_aug=args.use_data_aug, use_data_aug=args.use_data_aug,
use_amp=args.use_amp,
) )
# torch.compile JIT 编译加速
if args.use_compile and hasattr(torch, 'compile'):
print("应用 torch.compile 加速...")
agent.net = torch.compile(agent.net)
print("torch.compile 完成")
buffer = VecRolloutBuffer( buffer = VecRolloutBuffer(
n_steps=args.n_steps, n_steps=args.n_steps,
n_envs=args.n_envs, n_envs=args.n_envs,
@@ -117,6 +133,8 @@ def main():
print(f"clip_floor={args.clip_floor} target_kl={args.target_kl} " print(f"clip_floor={args.clip_floor} target_kl={args.target_kl} "
f"use_data_aug={args.use_data_aug}") f"use_data_aug={args.use_data_aug}")
print(f"n_epochs={args.n_epochs} batch_size={args.batch_size}") print(f"n_epochs={args.n_epochs} batch_size={args.batch_size}")
print(f"AMP: {args.use_amp}")
print(f"Compile: {args.use_compile}")
print(f"Device: {agent.device}") print(f"Device: {agent.device}")
print(f"Logs: {run_dir}") print(f"Logs: {run_dir}")
print(f"Ckpts: {ckpt_dir}") print(f"Ckpts: {ckpt_dir}")
@@ -12,36 +12,10 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{ "source": "import sys\nimport os\nimport time\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom collections import deque\nfrom multiprocessing import Process, Queue\n\n# notebooks/ 的上级目录即项目根目录\nsys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), \"..\")))\n\nfrom src.network import QNetwork, DuelingQNetwork\nfrom src.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer\nfrom src.agent import DQNAgent\nfrom src.utils import make_env, get_device\n\nprint(\"导入完成\")"
"name": "stdout",
"output_type": "stream",
"text": [
"导入完成\n"
]
}
],
"source": [
"import sys\n",
"import os\n",
"import time\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from collections import deque\n",
"\n",
"# notebooks/ 的上级目录即项目根目录\n",
"sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), \"..\")))\n",
"\n",
"from src.network import QNetwork, DuelingQNetwork\n",
"from src.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer\n",
"from src.agent import DQNAgent\n",
"from src.utils import make_env, get_device\n",
"\n",
"print(\"导入完成\")"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
@@ -69,134 +43,10 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{ "source": "# ── 并行训练器(优化版) ──\nclass ParallelTrainer:\n def __init__(\n self, agent, envs, eval_env, num_envs,\n save_dir=\"models\", eval_freq=10000, save_freq=50000,\n num_eval_episodes=10, warmup_steps=10000,\n train_steps_per_update=1,\n ):\n self.agent = agent\n self.envs = envs\n self.eval_env = eval_env\n self.num_envs = num_envs\n self.save_dir = save_dir\n self.eval_freq = eval_freq\n self.save_freq = save_freq\n self.num_eval_episodes = num_eval_episodes\n self.warmup_steps = warmup_steps\n self.train_steps_per_update = train_steps_per_update\n self.episode_rewards = deque(maxlen=100)\n self.eval_rewards = []\n self.best_eval_reward = -float(\"inf\")\n\n def evaluate(self):\n \"\"\"评估智能体\"\"\"\n rewards = []\n for _ in range(self.num_eval_episodes):\n state, _ = self.eval_env.reset()\n ep_reward = 0\n done = False\n while not done:\n action = self.agent.select_action(state, evaluate=True)\n state, reward, terminated, truncated, _ = self.eval_env.step(action)\n done = terminated or truncated\n ep_reward += reward\n rewards.append(ep_reward)\n return np.mean(rewards)\n\n def train(self, total_steps):\n n = self.num_envs\n device = self.agent.device\n envs = self.envs\n\n print(f\"开始训练: {total_steps:,} 步, {n} 并行环境, 每步训练 {self.train_steps_per_update} 次\")\n print(\"=\" * 60)\n\n states, _ = envs.reset()\n ep_rewards = np.zeros(n, dtype=np.float32)\n ep_count = 0\n start_time = time.time()\n step = 0\n\n while step < total_steps:\n # ── 动作选择(向量化) ──\n if step < self.warmup_steps:\n actions = np.array([envs.single_action_space.sample() for _ in range(n)])\n else:\n actions = self.agent.batch_select_actions(states, self.agent.epsilon)\n\n # ── 环境步进 ──\n next_states, rewards, terminateds, truncateds, _ = envs.step(actions)\n dones = np.logical_or(terminateds, truncateds)\n\n # ── 向量化批量添加经验(消除 Python 循环) ──\n self.agent.replay_buffer.add_batch(states, actions, rewards, next_states, dones)\n\n ep_rewards += rewards\n\n # ── 处理结束的 episode ──\n for i in range(n):\n if dones[i]:\n self.episode_rewards.append(ep_rewards[i])\n ep_count += 1\n ep_rewards[i] = 0\n\n step += n\n states = next_states\n\n # ── 每步训练多次(提升 GPU 利用率) ──\n if step >= self.warmup_steps:\n for _ in range(self.train_steps_per_update):\n self.agent.train_step()\n\n # ── 进度打印 ──\n if ep_count > 0 and ep_count % 20 == 0:\n avg_r = np.mean(self.episode_rewards) if self.episode_rewards else 0\n elapsed = time.time() - start_time\n fps = step / elapsed\n lr = self.agent.optimizer.param_groups[0][\"lr\"]\n print(f\"Step:{step:>10,} | Ep:{ep_count:>5} | AvgR:{avg_r:>7.1f} | \"\n f\"Eps:{self.agent.epsilon:.3f} | LR:{lr:.2e} | FPS:{fps:.0f}\")\n\n # ── 定期评估 ──\n if step % self.eval_freq == 0 and step > 0:\n eval_r = self.evaluate()\n self.eval_rewards.append((step, eval_r))\n print(f\"\\n[评估] Step:{step:>10,} | 平均回报:{eval_r:.1f}\\n\")\n if eval_r > self.best_eval_reward:\n self.best_eval_reward = eval_r\n self.agent.save(f\"{self.save_dir}/dqn_best.pt\")\n\n # ── 定期保存 ──\n if step % self.save_freq == 0 and step > 0:\n self.agent.save(f\"{self.save_dir}/dqn_step_{step}.pt\")\n\n total_time = time.time() - start_time\n print(\"\\n\" + \"=\" * 60)\n print(f\"训练完成!总时间: {total_time:.1f} 秒 | FPS: {total_steps/total_time:.0f}\")\n print(f\"最佳评估回报: {self.best_eval_reward:.1f}\")\n self.agent.save(f\"{self.save_dir}/dqn_final.pt\")\n\nprint(\"训练器就绪\")"
"name": "stdout",
"output_type": "stream",
"text": [
"训练器就绪\n"
]
}
],
"source": [
"# ── 并行训练器 ──\n",
"class ParallelTrainer:\n",
" def __init__(\n",
" self, agent, envs, eval_env, num_envs,\n",
" save_dir=\"models\", eval_freq=10000, save_freq=50000,\n",
" num_eval_episodes=10, warmup_steps=10000,\n",
" ):\n",
" self.agent = agent\n",
" self.envs = envs\n",
" self.eval_env = eval_env\n",
" self.num_envs = num_envs\n",
" self.save_dir = save_dir\n",
" self.eval_freq = eval_freq\n",
" self.save_freq = save_freq\n",
" self.num_eval_episodes = num_eval_episodes\n",
" self.warmup_steps = warmup_steps\n",
" self.episode_rewards = deque(maxlen=100)\n",
" self.eval_rewards = []\n",
" self.best_eval_reward = -float(\"inf\")\n",
"\n",
" def _batch_select_actions(self, states):\n",
" epsilon = self.agent.epsilon\n",
" n = len(states)\n",
" random_mask = np.random.random(n) < epsilon\n",
" actions = np.zeros(n, dtype=np.int64)\n",
" non_random = ~random_mask\n",
" if non_random.any():\n",
" state_tensor = torch.from_numpy(states[non_random]).float().to(self.agent.device)\n",
" with torch.no_grad():\n",
" q_values = self.agent.q_network(state_tensor)\n",
" actions[non_random] = q_values.argmax(dim=1).cpu().numpy()\n",
" if random_mask.any():\n",
" actions[random_mask] = np.random.randint(0, self.agent.num_actions, size=random_mask.sum())\n",
" return actions\n",
"\n",
" def evaluate(self):\n",
" rewards = []\n",
" for _ in range(self.num_eval_episodes):\n",
" state, _ = self.eval_env.reset()\n",
" ep_reward = 0\n",
" done = False\n",
" while not done:\n",
" action = self.agent.select_action(state, evaluate=True)\n",
" state, reward, terminated, truncated, _ = self.eval_env.step(action)\n",
" done = terminated or truncated\n",
" ep_reward += reward\n",
" rewards.append(ep_reward)\n",
" return np.mean(rewards)\n",
"\n",
" def train(self, total_steps):\n",
" n = self.num_envs\n",
" device = self.agent.device\n",
" envs = self.envs\n",
"\n",
" print(f\"开始训练: {total_steps:,} 步, {n} 并行环境\")\n",
" print(\"=\" * 60)\n",
"\n",
" states, _ = envs.reset()\n",
" ep_rewards = np.zeros(n, dtype=np.float32)\n",
" ep_count = 0\n",
" start_time = time.time()\n",
" step = 0\n",
"\n",
" while step < total_steps:\n",
" if step < self.warmup_steps:\n",
" actions = np.array([envs.single_action_space.sample() for _ in range(n)])\n",
" else:\n",
" actions = self._batch_select_actions(states)\n",
"\n",
" next_states, rewards, terminateds, truncateds, _ = envs.step(actions)\n",
" dones = np.logical_or(terminateds, truncateds)\n",
"\n",
" for i in range(n):\n",
" self.agent.replay_buffer.add(states[i], actions[i], rewards[i], next_states[i], dones[i])\n",
"\n",
" ep_rewards += rewards\n",
"\n",
" for i in range(n):\n",
" if dones[i]:\n",
" self.episode_rewards.append(ep_rewards[i])\n",
" ep_count += 1\n",
" ep_rewards[i] = 0\n",
"\n",
" step += n\n",
" states = next_states\n",
"\n",
" if step >= self.warmup_steps:\n",
" self.agent.train_step()\n",
"\n",
" if ep_count > 0 and ep_count % 20 == 0:\n",
" avg_r = np.mean(self.episode_rewards) if self.episode_rewards else 0\n",
" elapsed = time.time() - start_time\n",
" fps = step / elapsed\n",
" lr = self.agent.optimizer.param_groups[0][\"lr\"]\n",
" print(f\"Step:{step:>10,} | Ep:{ep_count:>5} | AvgR:{avg_r:>7.1f} | \"\n",
" f\"Eps:{self.agent.epsilon:.3f} | LR:{lr:.2e} | FPS:{fps:.0f}\")\n",
"\n",
" if step % self.eval_freq == 0 and step > 0:\n",
" eval_r = self.evaluate()\n",
" self.eval_rewards.append((step, eval_r))\n",
" print(f\"\\n[评估] Step:{step:>10,} | 平均回报:{eval_r:.1f}\\n\")\n",
" if eval_r > self.best_eval_reward:\n",
" self.best_eval_reward = eval_r\n",
" self.agent.save(f\"{self.save_dir}/dqn_best.pt\")\n",
"\n",
" if step % self.save_freq == 0:\n",
" self.agent.save(f\"{self.save_dir}/dqn_step_{step}.pt\")\n",
"\n",
" total_time = time.time() - start_time\n",
" print(\"\\n\" + \"=\" * 60)\n",
" print(f\"训练完成!总时间: {total_time:.1f} 秒 | FPS: {total_steps/total_time:.0f}\")\n",
" print(f\"最佳评估回报: {self.best_eval_reward:.1f}\")\n",
" self.agent.save(f\"{self.save_dir}/dqn_final.pt\")\n",
"\n",
"print(\"训练器就绪\")"
]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@@ -212,37 +62,7 @@
"execution_count": null, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": "# ── 可修改的超参数 ──\n\nENV_ID = \"ALE/SpaceInvaders-v5\"\nN_ENVS = 64 # 超配:64 个并行环境,最大化 CPU 利用率\nTOTAL_STEPS = 10_000_000 # 总步数\nLR = 1e-4 # 大 batch 配合稍高 lr\nGAMMA = 0.99 # 折扣因子\nBATCH_SIZE = 2048 # 大 batch 充分利用 RTX 5090\nBUFFER_SIZE = 1_000_000 # 回放缓冲区\nEPSILON_START = 1.0\nEPSILON_END = 0.01\nEPSILON_DECAY = 4_000_000 # ε衰减步数\nTARGET_UPDATE = 5000 # 降低目标网络更新频率\nLR_DECAY_STEPS = 5_000_000\nLR_DECAY_FACTOR = 0.5\nWARMUP_STEPS = 50_000\nEVAL_FREQ = 200000 # 评估频率降低,减少中断\nEVAL_EPISODES = 10\nSAVE_FREQ = 500000\nSEED = 42\nSAVE_DIR = os.path.join(os.path.abspath(os.path.join(os.getcwd(), \"..\")), \"models\")\n\nTRAIN_STEPS_PER_UPDATE = 4 # 每步训练 4 次,提升 GPU 利用率\nUSE_AMP = True # AMP 混合精度训练\nUSE_COMPILE = True # torch.compile 编译加速\n\nUSE_DUELING = True\nUSE_DOUBLE = True\nUSE_PER = True\n\nos.makedirs(SAVE_DIR, exist_ok=True)\n\nprint(f\"配置: {TOTAL_STEPS/1e6:.0f}M 步, {N_ENVS} 并行环境\")\nprint(f\"每步训练 {TRAIN_STEPS_PER_UPDATE} 次, Batch {BATCH_SIZE}\")\nprint(f\"AMP: {USE_AMP}, torch.compile: {USE_COMPILE}\")\nprint(f\"模型保存: {SAVE_DIR}\")"
"# ── 可修改的超参数 ──\n",
"\n",
"ENV_ID = \"ALE/SpaceInvaders-v5\"\n",
"N_ENVS = 24 # 25 核 CPU,留 1 核给主进程\n",
"TOTAL_STEPS = 10_000_000 # 总步数\n",
"LR = 3e-5 # 学习率(大 batch 配低 lr 更稳定)\n",
"GAMMA = 0.99 # 折扣因子\n",
"BATCH_SIZE = 512 # RTX 5090 跑大 batch 才不浪费\n",
"BUFFER_SIZE = 1_000_000 # 回放缓冲区\n",
"EPSILON_START = 1.0\n",
"EPSILON_END = 0.01\n",
"EPSILON_DECAY = 4_000_000 # ε衰减步数(24 环境探索效率高,延长探索期)\n",
"TARGET_UPDATE = 2000\n",
"LR_DECAY_STEPS = 5_000_000\n",
"LR_DECAY_FACTOR = 0.5\n",
"WARMUP_STEPS = 50_000\n",
"EVAL_FREQ = 50000\n",
"EVAL_EPISODES = 10\n",
"SAVE_FREQ = 200000\n",
"SEED = 42\n",
"SAVE_DIR = \"models\"\n",
"\n",
"USE_DUELING = True\n",
"USE_DOUBLE = True\n",
"USE_PER = True # 优先经验回放\n",
"\n",
"print(f\"配置: {TOTAL_STEPS/1e6:.0f}M 步, {N_ENVS} 并行环境\")\n",
"print(f\"预计环境交互: {TOTAL_STEPS * 4 / 1e6:.0f}M frames\")\n",
"print(f\"预计时间 (AutoDL 5090): ~{TOTAL_STEPS / 1000 / N_ENVS / 3600:.1f} 小时\")"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
@@ -288,203 +108,24 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{ "source": "# ── 网络 + torch.compile ──\nstate_shape = (4, 84, 84)\n\nif USE_DUELING:\n q_network = DuelingQNetwork(state_shape, num_actions).to(device)\n target_network = DuelingQNetwork(state_shape, num_actions).to(device)\n print(f\"Dueling DQN: {sum(p.numel() for p in q_network.parameters()):,} 参数\")\nelse:\n q_network = QNetwork(state_shape, num_actions).to(device)\n target_network = QNetwork(state_shape, num_actions).to(device)\n print(f\"标准 DQN: {sum(p.numel() for p in q_network.parameters()):,} 参数\")\n\n# torch.compile 编译加速(PyTorch 2.x\nif USE_COMPILE and hasattr(torch, 'compile'):\n print(\"应用 torch.compile 加速...\")\n q_network = torch.compile(q_network)\n target_network = torch.compile(target_network)\n print(\"torch.compile 完成\")\n\ntarget_network.load_state_dict(q_network.state_dict())\ntarget_network.eval()"
"name": "stdout",
"output_type": "stream",
"text": [
"Dueling DQN: 3,293,863 参数\n"
]
},
{
"data": {
"text/plain": [
"DuelingQNetwork(\n",
" (conv): Sequential(\n",
" (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))\n",
" (1): ReLU()\n",
" (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))\n",
" (3): ReLU()\n",
" (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n",
" (5): ReLU()\n",
" )\n",
" (value_stream): Sequential(\n",
" (0): Linear(in_features=3136, out_features=512, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=512, out_features=1, bias=True)\n",
" )\n",
" (advantage_stream): Sequential(\n",
" (0): Linear(in_features=3136, out_features=512, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=512, out_features=6, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# ── 网络 ──\n",
"state_shape = (4, 84, 84)\n",
"\n",
"if USE_DUELING:\n",
" q_network = DuelingQNetwork(state_shape, num_actions).to(device)\n",
" target_network = DuelingQNetwork(state_shape, num_actions).to(device)\n",
" print(f\"Dueling DQN: {sum(p.numel() for p in q_network.parameters()):,} 参数\")\n",
"else:\n",
" q_network = QNetwork(state_shape, num_actions).to(device)\n",
" target_network = QNetwork(state_shape, num_actions).to(device)\n",
" print(f\"标准 DQN: {sum(p.numel() for p in q_network.parameters()):,} 参数\")\n",
"\n",
"target_network.load_state_dict(q_network.state_dict())\n",
"target_network.eval()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"优先经验回放\n",
"Agent 创建完成\n"
]
}
],
"source": [
"# ── 回放缓冲区 + Agent ──\n",
"if USE_PER:\n",
" replay_buffer = PrioritizedReplayBuffer(BUFFER_SIZE, state_shape, device)\n",
" print(\"优先经验回放\")\n",
"else:\n",
" replay_buffer = ReplayBuffer(BUFFER_SIZE, state_shape, device)\n",
" print(\"标准经验回放\")\n",
"\n",
"agent = DQNAgent(\n",
" q_network=q_network,\n",
" target_network=target_network,\n",
" replay_buffer=replay_buffer,\n",
" device=device,\n",
" num_actions=num_actions,\n",
" gamma=GAMMA,\n",
" lr=LR,\n",
" epsilon_start=EPSILON_START,\n",
" epsilon_end=EPSILON_END,\n",
" epsilon_decay_steps=EPSILON_DECAY,\n",
" target_update_freq=TARGET_UPDATE,\n",
" batch_size=BATCH_SIZE,\n",
" double_dqn=USE_DOUBLE,\n",
" lr_decay_steps=LR_DECAY_STEPS,\n",
" lr_decay_factor=LR_DECAY_FACTOR,\n",
" warmup_steps=WARMUP_STEPS,\n",
")\n",
"print(\"Agent 创建完成\")"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{ "source": "# ── 回放缓冲区 + Agent ──\nif USE_PER:\n replay_buffer = PrioritizedReplayBuffer(BUFFER_SIZE, state_shape, device)\n print(\"优先经验回放 (Pinned Memory)\")\nelse:\n replay_buffer = ReplayBuffer(BUFFER_SIZE, state_shape, device)\n print(\"标准经验回放 (Pinned Memory)\")\n\nagent = DQNAgent(\n q_network=q_network,\n target_network=target_network,\n replay_buffer=replay_buffer,\n device=device,\n num_actions=num_actions,\n gamma=GAMMA,\n lr=LR,\n epsilon_start=EPSILON_START,\n epsilon_end=EPSILON_END,\n epsilon_decay_steps=EPSILON_DECAY,\n target_update_freq=TARGET_UPDATE,\n batch_size=BATCH_SIZE,\n double_dqn=USE_DOUBLE,\n lr_decay_steps=LR_DECAY_STEPS,\n lr_decay_factor=LR_DECAY_FACTOR,\n warmup_steps=WARMUP_STEPS,\n use_amp=USE_AMP,\n)\nprint(f\"Agent 创建完成 (AMP: {USE_AMP})\")"
"name": "stdout", },
"output_type": "stream", {
"text": [ "cell_type": "code",
"\n", "execution_count": null,
"============================================================\n", "metadata": {},
"开始 10M 步并行训练\n", "outputs": [],
" GPU: cuda\n", "source": "# ── 开始训练 ──\ntrainer = ParallelTrainer(\n agent=agent,\n envs=envs,\n eval_env=eval_env,\n num_envs=N_ENVS,\n save_dir=SAVE_DIR,\n eval_freq=EVAL_FREQ,\n save_freq=SAVE_FREQ,\n num_eval_episodes=EVAL_EPISODES,\n warmup_steps=WARMUP_STEPS,\n train_steps_per_update=TRAIN_STEPS_PER_UPDATE,\n)\n\nprint(\"\\n\" + \"=\" * 60)\nprint(f\"开始 10M 步并行训练(全优化版)\")\nprint(f\" GPU: {device}\")\nprint(f\" 并行环境: {N_ENVS}\")\nprint(f\" Batch Size: {BATCH_SIZE}\")\nprint(f\" 每步训练: {TRAIN_STEPS_PER_UPDATE} 次\")\nprint(f\" AMP 混合精度: {USE_AMP}\")\nprint(f\" torch.compile: {USE_COMPILE}\")\nprint(f\" Dueling: {USE_DUELING}\")\nprint(f\" Double DQN: {USE_DOUBLE}\")\nprint(f\" PER: {USE_PER}\")\nprint(f\" Pinned Memory: 是\")\nprint(f\" 向量化批量添加: 是\")\nprint(\"=\" * 60 + \"\\n\")\n\ntrainer.train(TOTAL_STEPS)"
" 并行环境: 16\n",
" Dueling: True\n",
" Double DQN: True\n",
" PER: True\n",
"============================================================\n",
"\n",
"开始训练: 10,000,000 步, 16 并行环境\n",
"============================================================\n",
"Step: 3,232 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:651\n",
"Step: 3,248 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:649\n",
"Step: 3,264 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:649\n",
"Step: 3,280 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:649\n",
"Step: 3,296 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:648\n",
"Step: 3,312 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:648\n",
"Step: 3,328 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:648\n",
"Step: 5,776 | Ep: 40 | AvgR: 12.3 | Eps:1.000 | LR:5.00e-05 | FPS:596\n",
"Step: 5,792 | Ep: 40 | AvgR: 12.3 | Eps:1.000 | LR:5.00e-05 | FPS:596\n",
"Step: 8,144 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
"Step: 8,160 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 8,176 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 8,192 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 8,208 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
"Step: 8,224 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
"Step: 8,240 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
"Step: 8,256 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
"Step: 8,272 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
"Step: 8,288 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
"Step: 8,304 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
"Step: 8,320 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 8,336 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 8,352 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 8,368 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 8,384 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 8,400 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 8,416 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 8,432 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 10,912 | Ep: 80 | AvgR: 13.2 | Eps:1.000 | LR:2.90e-07 | FPS:520\n",
"Step: 10,928 | Ep: 80 | AvgR: 13.2 | Eps:1.000 | LR:2.95e-07 | FPS:520\n",
"Step: 10,944 | Ep: 80 | AvgR: 13.2 | Eps:1.000 | LR:3.00e-07 | FPS:519\n",
"Step: 10,960 | Ep: 80 | AvgR: 13.2 | Eps:1.000 | LR:3.05e-07 | FPS:519\n",
"Step: 13,280 | Ep: 100 | AvgR: 13.0 | Eps:1.000 | LR:1.03e-06 | FPS:481\n",
"Step: 13,296 | Ep: 100 | AvgR: 13.0 | Eps:1.000 | LR:1.04e-06 | FPS:481\n",
"Step: 13,312 | Ep: 100 | AvgR: 13.0 | Eps:1.000 | LR:1.04e-06 | FPS:481\n",
"Step: 13,328 | Ep: 100 | AvgR: 13.0 | Eps:1.000 | LR:1.04e-06 | FPS:481\n",
"Step: 15,648 | Ep: 120 | AvgR: 12.8 | Eps:1.000 | LR:1.77e-06 | FPS:454\n",
"Step: 21,184 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.50e-06 | FPS:434\n",
"Step: 21,200 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.50e-06 | FPS:434\n",
"Step: 21,216 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.51e-06 | FPS:434\n",
"Step: 21,232 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.52e-06 | FPS:434\n",
"Step: 21,248 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.52e-06 | FPS:434\n",
"Step: 21,264 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.52e-06 | FPS:434\n",
"Step: 21,280 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.53e-06 | FPS:434\n",
"Step: 21,296 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.54e-06 | FPS:434\n",
"Step: 23,824 | Ep: 180 | AvgR: 14.4 | Eps:1.000 | LR:4.33e-06 | FPS:432\n",
"Step: 26,144 | Ep: 200 | AvgR: 14.3 | Eps:1.000 | LR:5.05e-06 | FPS:422\n"
]
}
],
"source": [
"# ── 开始训练 ──\n",
"trainer = ParallelTrainer(\n",
" agent=agent,\n",
" envs=envs,\n",
" eval_env=eval_env,\n",
" num_envs=N_ENVS,\n",
" save_dir=SAVE_DIR,\n",
" eval_freq=EVAL_FREQ,\n",
" save_freq=SAVE_FREQ,\n",
" num_eval_episodes=EVAL_EPISODES,\n",
" warmup_steps=WARMUP_STEPS,\n",
")\n",
"\n",
"print(\"\\n\" + \"=\" * 60)\n",
"print(f\"开始 10M 步并行训练\")\n",
"print(f\" GPU: {device}\")\n",
"print(f\" 并行环境: {N_ENVS}\")\n",
"print(f\" Dueling: {USE_DUELING}\")\n",
"print(f\" Double DQN: {USE_DOUBLE}\")\n",
"print(f\" PER: {USE_PER}\")\n",
"print(\"=\" * 60 + \"\\n\")\n",
"\n",
"trainer.train(TOTAL_STEPS)"
]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@@ -495,63 +136,10 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{ "source": "# ── 评估最佳模型 ──\nprint(\"加载最佳模型...\")\nagent.load(f\"{SAVE_DIR}/dqn_best.pt\")\n\nprint(\"\\n评估中...\")\nall_rewards = []\nfor i in range(20):\n state, _ = eval_env.reset()\n ep_r = 0\n done = False\n while not done:\n action = agent.select_action(state, evaluate=True)\n state, reward, terminated, truncated, _ = eval_env.step(action)\n done = terminated or truncated\n ep_r += reward\n all_rewards.append(ep_r)\n print(f\" Episode {i+1:>2}: {ep_r:.1f}\")\n\nprint(f\"\\n结果: 平均 {np.mean(all_rewards):.2f} ± {np.std(all_rewards):.2f}\")\nprint(f\"最佳: {max(all_rewards):.1f} | 最差: {min(all_rewards):.1f}\")\nprint(f\"中位数: {np.median(all_rewards):.1f}\")"
"name": "stdout",
"output_type": "stream",
"text": [
"加载最佳模型...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\Code\\doing_exercises\\programs\\外教作业外快\\强化学习个人项目报告(Atari 游戏方向)\\src\\agent.py:219: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" checkpoint = torch.load(path, map_location=self.device)\n"
]
},
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] No such file or directory: 'models/dqn_best.pt'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[9], line 3\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;66;03m# ── 评估最佳模型 ──\u001b[39;00m\n\u001b[0;32m 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m加载最佳模型...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m----> 3\u001b[0m \u001b[43magent\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mSAVE_DIR\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m/dqn_best.pt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 5\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m评估中...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 6\u001b[0m all_rewards \u001b[38;5;241m=\u001b[39m []\n",
"File \u001b[1;32md:\\Code\\doing_exercises\\programs\\外教作业外快\\强化学习个人项目报告(Atari 游戏方向)\\src\\agent.py:219\u001b[0m, in \u001b[0;36mDQNAgent.load\u001b[1;34m(self, path)\u001b[0m\n\u001b[0;32m 217\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mload\u001b[39m(\u001b[38;5;28mself\u001b[39m, path):\n\u001b[0;32m 218\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"加载模型\"\"\"\u001b[39;00m\n\u001b[1;32m--> 219\u001b[0m checkpoint \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmap_location\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 220\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mq_network\u001b[38;5;241m.\u001b[39mload_state_dict(checkpoint[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mq_network\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[0;32m 221\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtarget_network\u001b[38;5;241m.\u001b[39mload_state_dict(checkpoint[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtarget_network\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n",
"File \u001b[1;32md:\\ProgramData\\anaconda3\\envs\\my_env\\lib\\site-packages\\torch\\serialization.py:1065\u001b[0m, in \u001b[0;36mload\u001b[1;34m(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)\u001b[0m\n\u001b[0;32m 1062\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mencoding\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m pickle_load_args\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[0;32m 1063\u001b[0m pickle_load_args[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mencoding\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mutf-8\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m-> 1065\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43m_open_file_like\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mrb\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m opened_file:\n\u001b[0;32m 1066\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _is_zipfile(opened_file):\n\u001b[0;32m 1067\u001b[0m \u001b[38;5;66;03m# The zipfile reader is going to advance the current file position.\u001b[39;00m\n\u001b[0;32m 1068\u001b[0m \u001b[38;5;66;03m# If we want to actually tail call to torch.jit.load, we need to\u001b[39;00m\n\u001b[0;32m 1069\u001b[0m \u001b[38;5;66;03m# reset back to the original position.\u001b[39;00m\n\u001b[0;32m 1070\u001b[0m orig_position \u001b[38;5;241m=\u001b[39m opened_file\u001b[38;5;241m.\u001b[39mtell()\n",
"File \u001b[1;32md:\\ProgramData\\anaconda3\\envs\\my_env\\lib\\site-packages\\torch\\serialization.py:468\u001b[0m, in \u001b[0;36m_open_file_like\u001b[1;34m(name_or_buffer, mode)\u001b[0m\n\u001b[0;32m 466\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_open_file_like\u001b[39m(name_or_buffer, mode):\n\u001b[0;32m 467\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _is_path(name_or_buffer):\n\u001b[1;32m--> 468\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_open_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 469\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 470\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01min\u001b[39;00m mode:\n",
"File \u001b[1;32md:\\ProgramData\\anaconda3\\envs\\my_env\\lib\\site-packages\\torch\\serialization.py:449\u001b[0m, in \u001b[0;36m_open_file.__init__\u001b[1;34m(self, name, mode)\u001b[0m\n\u001b[0;32m 448\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, name, mode):\n\u001b[1;32m--> 449\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m)\n",
"\u001b[1;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'models/dqn_best.pt'"
]
}
],
"source": [
"# ── 评估最佳模型 ──\n",
"print(\"加载最佳模型...\")\n",
"agent.load(f\"{SAVE_DIR}/dqn_best.pt\")\n",
"\n",
"print(\"\\n评估中...\")\n",
"all_rewards = []\n",
"for i in range(20):\n",
" state, _ = eval_env.reset()\n",
" ep_r = 0\n",
" done = False\n",
" while not done:\n",
" action = agent.select_action(state, evaluate=True)\n",
" state, reward, terminated, truncated, _ = eval_env.step(action)\n",
" done = terminated or truncated\n",
" ep_r += reward\n",
" all_rewards.append(ep_r)\n",
" print(f\" Episode {i+1:>2}: {ep_r:.1f}\")\n",
"\n",
"print(f\"\\n结果: 平均 {np.mean(all_rewards):.2f} ± {np.std(all_rewards):.2f}\")\n",
"print(f\"最佳: {max(all_rewards):.1f} | 最差: {min(all_rewards):.1f}\")\n",
"print(f\"中位数: {np.median(all_rewards):.1f}\")"
]
} }
], ],
"metadata": { "metadata": {
@@ -10,6 +10,7 @@ class DQNAgent:
"""DQN智能体 """DQN智能体
实现ε-greedy探索策略和Q-learning更新 实现ε-greedy探索策略和Q-learning更新
支持 AMP 混合精度训练
""" """
def __init__( def __init__(
@@ -30,26 +31,8 @@ class DQNAgent:
lr_decay_steps=1_000_000, lr_decay_steps=1_000_000,
lr_decay_factor=0.5, lr_decay_factor=0.5,
warmup_steps=10_000, warmup_steps=10_000,
use_amp=True,
): ):
"""
Args:
q_network: Q网络
target_network: 目标网络
replay_buffer: 经验回放缓冲区
device: 设备
num_actions: 动作数量
gamma: 折扣因子
lr: 学习率
epsilon_start: ε初始值
epsilon_end: ε最终值
epsilon_decay_steps: ε衰减步数
target_update_freq: 目标网络更新频率
batch_size: 批次大小
double_dqn: 是否使用Double DQN
lr_decay_steps: 学习率衰减步数
lr_decay_factor: 学习率衰减因子
warmup_steps: 预热步数
"""
self.q_network = q_network self.q_network = q_network
self.target_network = target_network self.target_network = target_network
self.replay_buffer = replay_buffer self.replay_buffer = replay_buffer
@@ -72,40 +55,52 @@ class DQNAgent:
self.optimizer = torch.optim.Adam(q_network.parameters(), lr=lr) self.optimizer = torch.optim.Adam(q_network.parameters(), lr=lr)
self.step_count = 0 # AMP 混合精度
self.use_amp = use_amp and device.type == 'cuda'
self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)
self.step_count = 0
self.loss_history = [] self.loss_history = []
self.q_value_history = [] self.q_value_history = []
def select_action(self, state, evaluate=False): def select_action(self, state, evaluate=False):
"""选择动作 """选择动作"""
epsilon = 0.0 if evaluate else self.epsilon
Args:
state: 当前状态 (channels, height, width)
evaluate: 是否为评估模式(不使用ε-greedy)
Returns:
action: 选择的动作
"""
if evaluate:
# 评估模式:纯贪心
epsilon = 0.0
else:
# 训练模式:ε-greedy
epsilon = self.epsilon
if np.random.random() < epsilon: if np.random.random() < epsilon:
# 随机探索
return np.random.randint(self.num_actions) return np.random.randint(self.num_actions)
else: else:
# 贪心选择
with torch.no_grad(): with torch.no_grad():
state_tensor = ( state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
torch.from_numpy(state).float().unsqueeze(0).to(self.device)
)
q_values = self.q_network(state_tensor) q_values = self.q_network(state_tensor)
return q_values.argmax(dim=1).item() return q_values.argmax(dim=1).item()
def batch_select_actions(self, states, epsilon):
"""批量选择动作(向量化,用于并行训练)
Args:
states: (n, C, H, W) numpy 数组
epsilon: 当前 ε 值
Returns:
actions: (n,) numpy 数组
"""
n = len(states)
random_mask = np.random.random(n) < epsilon
actions = np.zeros(n, dtype=np.int64)
non_random = ~random_mask
if non_random.any():
state_tensor = torch.from_numpy(states[non_random]).float().to(self.device)
with torch.no_grad(), torch.amp.autocast('cuda', enabled=self.use_amp):
q_values = self.q_network(state_tensor)
actions[non_random] = q_values.argmax(dim=1).cpu().numpy()
if random_mask.any():
actions[random_mask] = np.random.randint(0, self.num_actions, size=random_mask.sum())
return actions
def update_epsilon(self): def update_epsilon(self):
"""更新ε值(线性衰减)""" """更新ε值(线性衰减)"""
if self.step_count < self.epsilon_decay_steps: if self.step_count < self.epsilon_decay_steps:
@@ -129,17 +124,10 @@ class DQNAgent:
param_group["lr"] *= self.lr_decay_factor param_group["lr"] *= self.lr_decay_factor
def train_step(self): def train_step(self):
"""执行一步训练 """执行一步训练(支持 AMP 混合精度)"""
Returns:
loss: 损失值
avg_q: 平均Q值
"""
# 检查是否有足够样本
if len(self.replay_buffer) < self.batch_size: if len(self.replay_buffer) < self.batch_size:
return None, None return None, None
# 采样(兼容标准和优先经验回放)
sample_result = self.replay_buffer.sample(self.batch_size) sample_result = self.replay_buffer.sample(self.batch_size)
if len(sample_result) == 7: if len(sample_result) == 7:
states, actions, rewards, next_states, dones, indices, weights = sample_result states, actions, rewards, next_states, dones, indices, weights = sample_result
@@ -147,52 +135,46 @@ class DQNAgent:
states, actions, rewards, next_states, dones = sample_result states, actions, rewards, next_states, dones = sample_result
indices, weights = None, None indices, weights = None, None
# 计算当前Q值 # AMP 前向传播
q_values = self.q_network(states) with torch.amp.autocast('cuda', enabled=self.use_amp):
q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1) q_values = self.q_network(states)
q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
# 计算目标Q值 with torch.no_grad():
with torch.no_grad(): if self.double_dqn:
if self.double_dqn: next_actions = self.q_network(next_states).argmax(dim=1)
next_actions = self.q_network(next_states).argmax(dim=1) next_q_values = self.target_network(next_states)
next_q_values = self.target_network(next_states) next_q_values = next_q_values.gather(1, next_actions.unsqueeze(1)).squeeze(1)
next_q_values = next_q_values.gather( else:
1, next_actions.unsqueeze(1) next_q_values = self.target_network(next_states).max(dim=1)[0]
).squeeze(1) target_q_values = rewards + self.gamma * next_q_values * (1 - dones)
if weights is not None:
loss = (weights * F.mse_loss(q_values, target_q_values, reduction='none')).mean()
else: else:
next_q_values = self.target_network(next_states).max(dim=1)[0] loss = F.mse_loss(q_values, target_q_values)
target_q_values = rewards + self.gamma * next_q_values * (1 - dones) # AMP 反向传播
# 计算TD误差
td_errors = (q_values - target_q_values).detach()
# 计算损失(优先经验回放使用重要性采样权重)
if weights is not None:
loss = (weights * F.mse_loss(q_values, target_q_values, reduction='none')).mean()
else:
loss = F.mse_loss(q_values, target_q_values)
# 反向传播
self.optimizer.zero_grad() self.optimizer.zero_grad()
loss.backward() self.scaler.scale(loss).backward()
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 10) torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 10)
self.optimizer.step() self.scaler.step(self.optimizer)
self.scaler.update()
# 更新优先级 # 更新优先级
if indices is not None and hasattr(self.replay_buffer, 'update_priorities'): if indices is not None and hasattr(self.replay_buffer, 'update_priorities'):
self.replay_buffer.update_priorities(indices, td_errors.cpu().numpy()) td_errors = (q_values - target_q_values).detach()
self.replay_buffer.update_priorities(indices, td_errors.cpu().float().numpy())
# 更新目标网络 # 更新目标网络
self.step_count += 1 self.step_count += 1
if self.step_count % self.target_update_freq == 0: if self.step_count % self.target_update_freq == 0:
self.target_network.load_state_dict(self.q_network.state_dict()) self.target_network.load_state_dict(self.q_network.state_dict())
# 更新ε和学习率
self.update_epsilon() self.update_epsilon()
self.update_learning_rate() self.update_learning_rate()
# 记录统计
avg_q = q_values.mean().item() avg_q = q_values.mean().item()
self.loss_history.append(loss.item()) self.loss_history.append(loss.item())
self.q_value_history.append(avg_q) self.q_value_history.append(avg_q)
@@ -216,7 +198,7 @@ class DQNAgent:
def load(self, path): def load(self, path):
"""加载模型""" """加载模型"""
checkpoint = torch.load(path, map_location=self.device) checkpoint = torch.load(path, map_location=self.device, weights_only=False)
self.q_network.load_state_dict(checkpoint["q_network"]) self.q_network.load_state_dict(checkpoint["q_network"])
self.target_network.load_state_dict(checkpoint["target_network"]) self.target_network.load_state_dict(checkpoint["target_network"])
self.optimizer.load_state_dict(checkpoint["optimizer"]) self.optimizer.load_state_dict(checkpoint["optimizer"])
@@ -7,63 +7,94 @@ class ReplayBuffer:
"""经验回放缓冲区 """经验回放缓冲区
存储转移 (s, a, r, s', done),随机采样打破数据相关性 存储转移 (s, a, r, s', done),随机采样打破数据相关性
支持批量添加和 Pinned Memory 加速传输
""" """
def __init__(self, capacity, state_shape=(4, 84, 84), device='cpu'): def __init__(self, capacity, state_shape=(4, 84, 84), device='cpu'):
"""
Args:
capacity: 缓冲区容量
state_shape: 状态形状 (channels, height, width)
device: 设备 (cpu/cuda)
"""
self.capacity = capacity self.capacity = capacity
self.device = device self.device = device
self.ptr = 0 self.ptr = 0
self.size = 0 self.size = 0
# 预分配内存 # 预分配 numpy 内存
self.states = np.zeros((capacity, *state_shape), dtype=np.uint8) self.states = np.zeros((capacity, *state_shape), dtype=np.uint8)
self.actions = np.zeros(capacity, dtype=np.int64) self.actions = np.zeros(capacity, dtype=np.int64)
self.rewards = np.zeros(capacity, dtype=np.float32) self.rewards = np.zeros(capacity, dtype=np.float32)
self.next_states = np.zeros((capacity, *state_shape), dtype=np.uint8) self.next_states = np.zeros((capacity, *state_shape), dtype=np.uint8)
self.dones = np.zeros(capacity, dtype=np.bool_) self.dones = np.zeros(capacity, dtype=np.bool_)
def add(self, state, action, reward, next_state, done): # 预分配 pinned memory 张量(加速 CPU→GPU 传输)
"""添加一个转移 pin_device = 'cpu'
self._states_pin = torch.empty((capacity, *state_shape), dtype=torch.uint8, pin_memory=True)
self._actions_pin = torch.empty(capacity, dtype=torch.int64, pin_memory=True)
self._rewards_pin = torch.empty(capacity, dtype=torch.float32, pin_memory=True)
self._next_states_pin = torch.empty((capacity, *state_shape), dtype=torch.uint8, pin_memory=True)
self._dones_pin = torch.empty(capacity, dtype=torch.float32, pin_memory=True)
Args: def add(self, state, action, reward, next_state, done):
state: 当前状态 """添加一个转移"""
action: 执行的动作
reward: 获得的奖励
next_state: 下一个状态
done: 是否结束
"""
self.states[self.ptr] = state self.states[self.ptr] = state
self.actions[self.ptr] = action self.actions[self.ptr] = action
self.rewards[self.ptr] = reward self.rewards[self.ptr] = reward
self.next_states[self.ptr] = next_state self.next_states[self.ptr] = next_state
self.dones[self.ptr] = done self.dones[self.ptr] = done
# 循环缓冲区
self.ptr = (self.ptr + 1) % self.capacity self.ptr = (self.ptr + 1) % self.capacity
self.size = min(self.size + 1, self.capacity) self.size = min(self.size + 1, self.capacity)
def sample(self, batch_size): def add_batch(self, states, actions, rewards, next_states, dones):
"""随机采样一个批次 """批量添加转移(向量化,消除 Python 循环)
Args: Args:
batch_size: 批次大小 states: (n, C, H, W) 数组
actions: (n,) 数组
Returns: rewards: (n,) 数组
states, actions, rewards, next_states, dones next_states: (n, C, H, W) 数组
dones: (n,) 数组
""" """
n = len(states)
end = self.ptr + n
if end <= self.capacity:
self.states[self.ptr:end] = states
self.actions[self.ptr:end] = actions
self.rewards[self.ptr:end] = rewards
self.next_states[self.ptr:end] = next_states
self.dones[self.ptr:end] = dones
else:
# 循环缓冲区回绕
overflow = end - self.capacity
first = n - overflow
self.states[self.ptr:] = states[:first]
self.actions[self.ptr:] = actions[:first]
self.rewards[self.ptr:] = rewards[:first]
self.next_states[self.ptr:] = next_states[:first]
self.dones[self.ptr:] = dones[:first]
self.states[:overflow] = states[first:]
self.actions[:overflow] = actions[first:]
self.rewards[:overflow] = rewards[first:]
self.next_states[:overflow] = next_states[first:]
self.dones[:overflow] = dones[first:]
self.ptr = end % self.capacity
self.size = min(self.size + n, self.capacity)
def sample(self, batch_size):
"""随机采样一个批次,使用 pinned memory 加速传输"""
indices = np.random.randint(0, self.size, size=batch_size) indices = np.random.randint(0, self.size, size=batch_size)
states = torch.from_numpy(self.states[indices]).float().to(self.device) # 先写入 pinned memory,再 non-blocking 传输到 GPU
actions = torch.from_numpy(self.actions[indices]).long().to(self.device) self._states_pin[:batch_size].copy_(torch.from_numpy(self.states[indices]))
rewards = torch.from_numpy(self.rewards[indices]).float().to(self.device) self._actions_pin[:batch_size].copy_(torch.from_numpy(self.actions[indices]))
next_states = torch.from_numpy(self.next_states[indices]).float().to(self.device) self._rewards_pin[:batch_size].copy_(torch.from_numpy(self.rewards[indices]))
dones = torch.from_numpy(self.dones[indices]).float().to(self.device) self._next_states_pin[:batch_size].copy_(torch.from_numpy(self.next_states[indices]))
self._dones_pin[:batch_size].copy_(torch.from_numpy(self.dones[indices].astype(np.float32)))
states = self._states_pin[:batch_size].float().to(self.device, non_blocking=True)
actions = self._actions_pin[:batch_size].to(self.device, non_blocking=True)
rewards = self._rewards_pin[:batch_size].to(self.device, non_blocking=True)
next_states = self._next_states_pin[:batch_size].float().to(self.device, non_blocking=True)
dones = self._dones_pin[:batch_size].to(self.device, non_blocking=True)
return states, actions, rewards, next_states, dones return states, actions, rewards, next_states, dones
@@ -75,16 +106,10 @@ class PrioritizedReplayBuffer:
"""优先经验回放缓冲区 """优先经验回放缓冲区
根据TD误差优先采样,提高样本效率 根据TD误差优先采样,提高样本效率
支持批量添加和 Pinned Memory 加速传输
""" """
def __init__(self, capacity, state_shape=(4, 84, 84), device='cpu', alpha=0.6): def __init__(self, capacity, state_shape=(4, 84, 84), device='cpu', alpha=0.6):
"""
Args:
capacity: 缓冲区容量
state_shape: 状态形状
device: 设备
alpha: 优先级指数 (0=均匀采样, 1=完全按优先级采样)
"""
self.capacity = capacity self.capacity = capacity
self.device = device self.device = device
self.alpha = alpha self.alpha = alpha
@@ -102,6 +127,14 @@ class PrioritizedReplayBuffer:
# 优先级存储 # 优先级存储
self.priorities = np.zeros(capacity, dtype=np.float32) self.priorities = np.zeros(capacity, dtype=np.float32)
# 预分配 pinned memory
self._states_pin = torch.empty((capacity, *state_shape), dtype=torch.uint8, pin_memory=True)
self._actions_pin = torch.empty(capacity, dtype=torch.int64, pin_memory=True)
self._rewards_pin = torch.empty(capacity, dtype=torch.float32, pin_memory=True)
self._next_states_pin = torch.empty((capacity, *state_shape), dtype=torch.uint8, pin_memory=True)
self._dones_pin = torch.empty(capacity, dtype=torch.float32, pin_memory=True)
self._weights_pin = torch.empty(capacity, dtype=torch.float32, pin_memory=True)
def add(self, state, action, reward, next_state, done): def add(self, state, action, reward, next_state, done):
"""添加转移,使用最大优先级""" """添加转移,使用最大优先级"""
self.states[self.ptr] = state self.states[self.ptr] = state
@@ -109,51 +142,71 @@ class PrioritizedReplayBuffer:
self.rewards[self.ptr] = reward self.rewards[self.ptr] = reward
self.next_states[self.ptr] = next_state self.next_states[self.ptr] = next_state
self.dones[self.ptr] = done self.dones[self.ptr] = done
# 新样本使用最大优先级
self.priorities[self.ptr] = self.max_priority self.priorities[self.ptr] = self.max_priority
self.ptr = (self.ptr + 1) % self.capacity self.ptr = (self.ptr + 1) % self.capacity
self.size = min(self.size + 1, self.capacity) self.size = min(self.size + 1, self.capacity)
def add_batch(self, states, actions, rewards, next_states, dones):
"""批量添加转移(向量化)"""
n = len(states)
end = self.ptr + n
if end <= self.capacity:
self.states[self.ptr:end] = states
self.actions[self.ptr:end] = actions
self.rewards[self.ptr:end] = rewards
self.next_states[self.ptr:end] = next_states
self.dones[self.ptr:end] = dones
self.priorities[self.ptr:end] = self.max_priority
else:
overflow = end - self.capacity
first = n - overflow
self.states[self.ptr:] = states[:first]
self.actions[self.ptr:] = actions[:first]
self.rewards[self.ptr:] = rewards[:first]
self.next_states[self.ptr:] = next_states[:first]
self.dones[self.ptr:] = dones[:first]
self.priorities[self.ptr:] = self.max_priority
self.states[:overflow] = states[first:]
self.actions[:overflow] = actions[first:]
self.rewards[:overflow] = rewards[first:]
self.next_states[:overflow] = next_states[first:]
self.dones[:overflow] = dones[first:]
self.priorities[:overflow] = self.max_priority
self.ptr = end % self.capacity
self.size = min(self.size + n, self.capacity)
def sample(self, batch_size, beta=0.4): def sample(self, batch_size, beta=0.4):
"""按优先级采样 """按优先级采样,使用 pinned memory 加速传输"""
Args:
batch_size: 批次大小
beta: 重要性采样指数
Returns:
states, actions, rewards, next_states, dones, indices, weights
"""
# 计算采样概率
priorities = self.priorities[:self.size] ** self.alpha priorities = self.priorities[:self.size] ** self.alpha
probs = priorities / priorities.sum() probs = priorities / priorities.sum()
# 按概率采样
indices = np.random.choice(self.size, size=batch_size, p=probs) indices = np.random.choice(self.size, size=batch_size, p=probs)
# 计算重要性采样权重
weights = (self.size * probs[indices]) ** (-beta) weights = (self.size * probs[indices]) ** (-beta)
weights = weights / weights.max() weights = weights / weights.max()
# 获取数据 # pinned memory 传输
states = torch.from_numpy(self.states[indices]).float().to(self.device) self._states_pin[:batch_size].copy_(torch.from_numpy(self.states[indices]))
actions = torch.from_numpy(self.actions[indices]).long().to(self.device) self._actions_pin[:batch_size].copy_(torch.from_numpy(self.actions[indices]))
rewards = torch.from_numpy(self.rewards[indices]).float().to(self.device) self._rewards_pin[:batch_size].copy_(torch.from_numpy(self.rewards[indices]))
next_states = torch.from_numpy(self.next_states[indices]).float().to(self.device) self._next_states_pin[:batch_size].copy_(torch.from_numpy(self.next_states[indices]))
dones = torch.from_numpy(self.dones[indices]).float().to(self.device) self._dones_pin[:batch_size].copy_(torch.from_numpy(self.dones[indices].astype(np.float32)))
weights = torch.from_numpy(weights).float().to(self.device) self._weights_pin[:batch_size].copy_(torch.from_numpy(weights))
return states, actions, rewards, next_states, dones, indices, weights states = self._states_pin[:batch_size].float().to(self.device, non_blocking=True)
actions = self._actions_pin[:batch_size].to(self.device, non_blocking=True)
rewards = self._rewards_pin[:batch_size].to(self.device, non_blocking=True)
next_states = self._next_states_pin[:batch_size].float().to(self.device, non_blocking=True)
dones = self._dones_pin[:batch_size].to(self.device, non_blocking=True)
ws = self._weights_pin[:batch_size].to(self.device, non_blocking=True)
return states, actions, rewards, next_states, dones, indices, ws
def update_priorities(self, indices, td_errors): def update_priorities(self, indices, td_errors):
"""更新优先级 """更新优先级"""
Args:
indices: 样本索引
td_errors: TD误差
"""
priorities = np.abs(td_errors) + 1e-6 priorities = np.abs(td_errors) + 1e-6
self.priorities[indices] = priorities self.priorities[indices] = priorities
self.max_priority = max(self.max_priority, priorities.max()) self.max_priority = max(self.max_priority, priorities.max())
@@ -1,31 +1,36 @@
"""并行环境 DQN 训练脚本 - 使用 AsyncVectorEnv 加速数据收集. """Dueling Double DQN - Space Invaders 并行训练脚本
每个训练迭代并行采集 N 个环境的转移,批量 GPU 推理,显著提升 FPS 使用 AsyncVectorEnv 并行运行多个 Atari 环境,GPU 批量推理加速
适合在 AutoDL 等多核服务器+GPU 环境运行。 适合在 AutoDL 等多核服务器环境运行。
与 notebooks/train_parallel.ipynb 内容一致,但使用 Python 脚本直接运行,
确保 stdout 实时输出(无缓冲)。
""" """
import sys import sys
import os import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import argparse
import time import time
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F
from collections import deque from collections import deque
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, PROJECT_ROOT)
from src.network import QNetwork, DuelingQNetwork from src.network import QNetwork, DuelingQNetwork
from src.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer from src.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
from src.agent import DQNAgent
from src.utils import make_env, get_device from src.utils import make_env, get_device
# 强制无缓冲输出
sys.stdout.reconfigure(line_buffering=True) if hasattr(sys.stdout, 'reconfigure') else None
os.environ['PYTHONUNBUFFERED'] = '1'
# ── 环境工厂函数(供 AsyncVectorEnv 子进程使用)── print("导入完成", flush=True)
# ── 环境工厂 ──
def _make_env_fn(env_id): def _make_env_fn(env_id):
"""环境工厂 - 必须在模块级别以便 multiprocessing pickle."""
# AsyncVectorEnv 子进程需要独立注册 ALE
try: try:
import ale_py import ale_py
import gymnasium as gym import gymnasium as gym
@@ -37,28 +42,16 @@ def _make_env_fn(env_id):
return make_env(env_id, gray_scale=True, resize=True, frame_stack=4) return make_env(env_id, gray_scale=True, resize=True, frame_stack=4)
return _make return _make
print("环境工厂就绪", flush=True)
# ── 并行训练器 ── # ── 并行训练器 ──
class ParallelTrainer: class ParallelTrainer:
"""并行环境 DQN 训练器.
使用 AsyncVectorEnv 并行运行 N 个环境,
同时收集转移 + 批量推理,大幅提升训练速度。
"""
def __init__( def __init__(
self, self, agent, envs, eval_env, num_envs,
agent, save_dir="models", eval_freq=10000, save_freq=50000,
envs, num_eval_episodes=10, warmup_steps=10000,
eval_env, train_steps_per_update=1,
num_envs,
save_dir="models",
eval_freq=10000,
save_freq=50000,
num_eval_episodes=10,
warmup_steps=10000,
n_steps_per_env=1,
): ):
self.agent = agent self.agent = agent
self.envs = envs self.envs = envs
@@ -69,256 +62,173 @@ class ParallelTrainer:
self.save_freq = save_freq self.save_freq = save_freq
self.num_eval_episodes = num_eval_episodes self.num_eval_episodes = num_eval_episodes
self.warmup_steps = warmup_steps self.warmup_steps = warmup_steps
self.n_steps_per_env = n_steps_per_env self.train_steps_per_update = train_steps_per_update
self.episode_rewards = deque(maxlen=100) self.episode_rewards = deque(maxlen=100)
self.eval_rewards = [] self.eval_rewards = []
self.best_eval_reward = -float("inf") self.best_eval_reward = -float("inf")
def train(self, total_steps):
"""主并行训练循环.
Args:
total_steps: 总环境交互步数
"""
num_envs = self.num_envs
device = self.agent.device
envs = self.envs
print(f"开始并行训练,总步数: {total_steps:,}")
print(f"并行环境数: {num_envs}")
print(f"预热步数: {self.warmup_steps:,}")
print("=" * 60)
# 重置所有环境
states, _ = envs.reset()
episode_rewards = np.zeros(num_envs, dtype=np.float32)
episode_lengths = np.zeros(num_envs, dtype=np.int32)
episode_count = 0
start_time = time.time()
step = 0
while step < total_steps:
# ── 动作选择 ──
if step < self.warmup_steps:
actions = np.array([envs.single_action_space.sample() for _ in range(num_envs)])
else:
actions = self._batch_select_actions(states)
# ── 环境步进(N 个环境并行)──
next_states, rewards, terminateds, truncateds, _ = envs.step(actions)
dones = np.logical_or(terminateds, truncateds)
# ── 存储转移 ──
for i in range(num_envs):
self.agent.replay_buffer.add(
states[i], actions[i], rewards[i], next_states[i], dones[i]
)
# ── 统计 ──
episode_rewards += rewards
episode_lengths += 1
# 处理结束的 episode
for i in range(num_envs):
if dones[i]:
self.episode_rewards.append(episode_rewards[i])
episode_count += 1
episode_rewards[i] = 0
episode_lengths[i] = 0
step += num_envs
states = next_states
# ── 训练(环境每步一个 mini-batch)──
if step >= self.warmup_steps:
self.agent.train_step()
# ── 进度打印 ──
if episode_count > 0 and episode_count % 10 == 0:
avg_reward = np.mean(self.episode_rewards) if self.episode_rewards else 0
elapsed = time.time() - start_time
fps = step / elapsed
current_lr = self.agent.optimizer.param_groups[0]["lr"]
print(
f"Step: {step:>10,} | "
f"Ep: {episode_count:>5} | "
f"AvgReward: {avg_reward:>7.1f} | "
f"Epsilon: {self.agent.epsilon:.3f} | "
f"LR: {current_lr:.2e} | "
f"FPS: {fps:.0f}"
)
# ── 定期评估 ──
if step % self.eval_freq == 0 and step > 0:
eval_reward = self.evaluate()
self.eval_rewards.append((step, eval_reward))
print(f"\n[Eval] Step: {step:>10,} | AvgReward: {eval_reward:.1f}\n" + "-" * 60)
if eval_reward > self.best_eval_reward:
self.best_eval_reward = eval_reward
self.agent.save(f"{self.save_dir}/dqn_best.pt")
# ── 定期保存 ──
if step % self.save_freq == 0:
self.agent.save(f"{self.save_dir}/dqn_step_{step}.pt")
# 训练结束
total_time = time.time() - start_time
print("\n" + "=" * 60)
print(f"训练完成!总时间: {total_time:.1f}")
print(f"平均 FPS: {total_steps / total_time:.0f}")
print(f"最佳评估回报: {self.best_eval_reward:.1f}")
self.agent.save(f"{self.save_dir}/dqn_final.pt")
def _batch_select_actions(self, states):
"""批量选择动作(使用 GPU 批量推理)."""
epsilon = self.agent.epsilon
num_envs = len(states)
# 随机探索
random_mask = np.random.random(num_envs) < epsilon
actions = np.zeros(num_envs, dtype=np.int64)
# 对非随机的环境做批量推理
non_random = ~random_mask
if non_random.any():
state_tensor = (
torch.from_numpy(states[non_random]).float().to(self.agent.device)
)
with torch.no_grad():
q_values = self.agent.q_network(state_tensor)
actions[non_random] = q_values.argmax(dim=1).cpu().numpy()
# 随机的环境
if random_mask.any():
actions[random_mask] = np.random.randint(
0, self.agent.num_actions, size=random_mask.sum()
)
return actions
def evaluate(self): def evaluate(self):
"""评估智能体."""
rewards = [] rewards = []
for _ in range(self.num_eval_episodes): for _ in range(self.num_eval_episodes):
state, _ = self.eval_env.reset() state, _ = self.eval_env.reset()
episode_reward = 0 ep_reward = 0
done = False done = False
while not done: while not done:
action = self.agent.select_action(state, evaluate=True) action = self.agent.select_action(state, evaluate=True)
state, reward, terminated, truncated, _ = self.eval_env.step(action) state, reward, terminated, truncated, _ = self.eval_env.step(action)
done = terminated or truncated done = terminated or truncated
episode_reward += reward ep_reward += reward
rewards.append(ep_reward)
rewards.append(episode_reward)
return np.mean(rewards) return np.mean(rewards)
def train(self, total_steps):
n = self.num_envs
device = self.agent.device
envs = self.envs
print(f"开始训练: {total_steps:,} 步, {n} 并行环境, 每步训练 {self.train_steps_per_update}", flush=True)
print("=" * 60, flush=True)
states, _ = envs.reset()
ep_rewards = np.zeros(n, dtype=np.float32)
ep_count = 0
start_time = time.time()
step = 0
while step < total_steps:
if step < self.warmup_steps:
actions = np.array([envs.single_action_space.sample() for _ in range(n)])
else:
actions = self.agent.batch_select_actions(states, self.agent.epsilon)
next_states, rewards, terminateds, truncateds, _ = envs.step(actions)
dones = np.logical_or(terminateds, truncateds)
# 向量化批量添加经验
self.agent.replay_buffer.add_batch(states, actions, rewards, next_states, dones)
ep_rewards += rewards
for i in range(n):
if dones[i]:
self.episode_rewards.append(ep_rewards[i])
ep_count += 1
ep_rewards[i] = 0
step += n
states = next_states
if step >= self.warmup_steps:
for _ in range(self.train_steps_per_update):
self.agent.train_step()
if ep_count > 0 and ep_count % 20 == 0:
avg_r = np.mean(self.episode_rewards) if self.episode_rewards else 0
elapsed = time.time() - start_time
fps = step / elapsed
lr = self.agent.optimizer.param_groups[0]["lr"]
print(f"Step:{step:>10,} | Ep:{ep_count:>5} | AvgR:{avg_r:>7.1f} | "
f"Eps:{self.agent.epsilon:.3f} | LR:{lr:.2e} | FPS:{fps:.0f}", flush=True)
if step % self.eval_freq == 0 and step > 0:
eval_r = self.evaluate()
self.eval_rewards.append((step, eval_r))
print(f"\n[评估] Step:{step:>10,} | 平均回报:{eval_r:.1f}\n", flush=True)
if eval_r > self.best_eval_reward:
self.best_eval_reward = eval_r
self.agent.save(f"{self.save_dir}/dqn_best.pt")
print(f"最佳模型已更新 (回报: {eval_r:.1f})", flush=True)
if step % self.save_freq == 0 and step > 0:
self.agent.save(f"{self.save_dir}/dqn_step_{step}.pt")
total_time = time.time() - start_time
print("\n" + "=" * 60, flush=True)
print(f"训练完成!总时间: {total_time:.1f} 秒 | FPS: {total_steps/total_time:.0f}", flush=True)
print(f"最佳评估回报: {self.best_eval_reward:.1f}", flush=True)
self.agent.save(f"{self.save_dir}/dqn_final.pt")
print("训练器就绪", flush=True)
# ── 主入口 ──
def main(): def main():
parser = argparse.ArgumentParser(description="Parallel DQN for Space Invaders") # ── 超参数 ──
ENV_ID = "ALE/SpaceInvaders-v5"
N_ENVS = 64
TOTAL_STEPS = 2_000_000
LR = 1e-4
GAMMA = 0.99
BATCH_SIZE = 2048
BUFFER_SIZE = 1_000_000
EPSILON_START = 1.0
EPSILON_END = 0.01
EPSILON_DECAY = 4_000_000
TARGET_UPDATE = 5000
LR_DECAY_STEPS = 5_000_000
LR_DECAY_FACTOR = 0.5
WARMUP_STEPS = 50_000
EVAL_FREQ = 200000
EVAL_EPISODES = 10
SAVE_FREQ = 500000
SEED = 42
SAVE_DIR = os.path.join(PROJECT_ROOT, "models")
# 并行参数 TRAIN_STEPS_PER_UPDATE = 4
parser.add_argument("--n-envs", type=int, default=8, help="并行环境数") USE_AMP = True
USE_COMPILE = True
USE_DUELING = True
USE_DOUBLE = True
USE_PER = True
# 训练参数 os.makedirs(SAVE_DIR, exist_ok=True)
parser.add_argument("--env", type=str, default="ALE/SpaceInvaders-v5")
parser.add_argument("--steps", type=int, default=10_000_000, help="总训练步数")
parser.add_argument("--lr", type=float, default=5e-5, help="学习率")
parser.add_argument("--gamma", type=float, default=0.99, help="折扣因子")
parser.add_argument("--batch-size", type=int, default=64, help="批次大小")
parser.add_argument("--buffer-size", type=int, default=500_000, help="回放缓冲区大小")
# ε-greedy print(f"配置: {TOTAL_STEPS/1e6:.0f}M 步, {N_ENVS} 并行环境", flush=True)
parser.add_argument("--epsilon-start", type=float, default=1.0) print(f"每步训练 {TRAIN_STEPS_PER_UPDATE} 次, Batch {BATCH_SIZE}", flush=True)
parser.add_argument("--epsilon-end", type=float, default=0.01) print(f"AMP: {USE_AMP}, torch.compile: {USE_COMPILE}", flush=True)
parser.add_argument("--epsilon-decay", type=int, default=2_000_000) print(f"模型保存: {SAVE_DIR}", flush=True)
# 网络 torch.manual_seed(SEED)
parser.add_argument("--target-update", type=int, default=1000) np.random.seed(SEED)
parser.add_argument("--double-dqn", action="store_true", default=True) import platform
parser.add_argument("--dueling", action="store_true", default=True)
# 学习率
parser.add_argument("--lr-decay-steps", type=int, default=5_000_000)
parser.add_argument("--lr-decay-factor", type=float, default=0.5)
parser.add_argument("--warmup-steps", type=int, default=10_000)
# 评估
parser.add_argument("--eval-freq", type=int, default=50000)
parser.add_argument("--eval-episodes", type=int, default=10)
parser.add_argument("--save-freq", type=int, default=100000)
# 优先回放
parser.add_argument("--prioritized", action="store_true", default=True)
# 其他
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--save-dir", type=str, default="models")
parser.add_argument("--log-dir", type=str, default="logs")
args = parser.parse_args()
# 随机种子
torch.manual_seed(args.seed)
np.random.seed(args.seed)
# 设备
device = get_device() device = get_device()
print(f"使用设备: {device}", flush=True)
# 创建并行训练环境 from gymnasium.vector import SyncVectorEnv
print(f"创建 {args.n_envs} 个并行训练环境...") env_fns = [_make_env_fn(ENV_ID) for _ in range(N_ENVS)]
try: envs = SyncVectorEnv(env_fns)
from gymnasium.vector import AsyncVectorEnv print(f"SyncVectorEnv: {envs.num_envs} 个环境", flush=True)
env_fns = [_make_env_fn(args.env) for _ in range(args.n_envs)]
envs = AsyncVectorEnv(env_fns, shared_memory=True)
except ImportError:
print("AsyncVectorEnv 不可用,回退到 SyncVectorEnv")
from gymnasium.vector import SyncVectorEnv
env_fns = [_make_env_fn(args.env) for _ in range(args.n_envs)]
envs = SyncVectorEnv(env_fns)
# 创建评估环境(单环境)
eval_env = make_env(args.env, gray_scale=True, resize=True, frame_stack=4)
eval_env = make_env(ENV_ID, gray_scale=True, resize=True, frame_stack=4)
num_actions = envs.single_action_space.n num_actions = envs.single_action_space.n
print(f"动作空间: {num_actions}") print(f"动作空间: {num_actions}", flush=True)
print(f"实际环境数: {envs.num_envs}")
state_shape = (4, 84, 84) state_shape = (4, 84, 84)
# 创建网络 if USE_DUELING:
if args.dueling:
print("使用 Dueling Double DQN")
q_network = DuelingQNetwork(state_shape, num_actions).to(device) q_network = DuelingQNetwork(state_shape, num_actions).to(device)
target_network = DuelingQNetwork(state_shape, num_actions).to(device) target_network = DuelingQNetwork(state_shape, num_actions).to(device)
print(f"Dueling DQN: {sum(p.numel() for p in q_network.parameters()):,} 参数", flush=True)
else: else:
print("使用标准 DQN")
q_network = QNetwork(state_shape, num_actions).to(device) q_network = QNetwork(state_shape, num_actions).to(device)
target_network = QNetwork(state_shape, num_actions).to(device) target_network = QNetwork(state_shape, num_actions).to(device)
print(f"标准 DQN: {sum(p.numel() for p in q_network.parameters()):,} 参数", flush=True)
if USE_COMPILE and hasattr(torch, 'compile'):
print("应用 torch.compile 加速...", flush=True)
q_network = torch.compile(q_network)
target_network = torch.compile(target_network)
print("torch.compile 完成", flush=True)
target_network.load_state_dict(q_network.state_dict()) target_network.load_state_dict(q_network.state_dict())
target_network.eval() target_network.eval()
print(f"网络参数量: {sum(p.numel() for p in q_network.parameters()):,}") if USE_PER:
replay_buffer = PrioritizedReplayBuffer(BUFFER_SIZE, state_shape, device)
# 回放缓冲区 print("优先经验回放 (Pinned Memory)", flush=True)
if args.prioritized:
print("使用优先经验回放")
replay_buffer = PrioritizedReplayBuffer(args.buffer_size, state_shape, device)
else: else:
print("使用标准经验回放") replay_buffer = ReplayBuffer(BUFFER_SIZE, state_shape, device)
replay_buffer = ReplayBuffer(args.buffer_size, state_shape, device) print("标准经验回放 (Pinned Memory)", flush=True)
# 创建 Agent
from src.agent import DQNAgent
agent = DQNAgent( agent = DQNAgent(
q_network=q_network, q_network=q_network,
@@ -326,45 +236,70 @@ def main():
replay_buffer=replay_buffer, replay_buffer=replay_buffer,
device=device, device=device,
num_actions=num_actions, num_actions=num_actions,
gamma=args.gamma, gamma=GAMMA,
lr=args.lr, lr=LR,
epsilon_start=args.epsilon_start, epsilon_start=EPSILON_START,
epsilon_end=args.epsilon_end, epsilon_end=EPSILON_END,
epsilon_decay_steps=args.epsilon_decay, epsilon_decay_steps=EPSILON_DECAY,
target_update_freq=args.target_update, target_update_freq=TARGET_UPDATE,
batch_size=args.batch_size, batch_size=BATCH_SIZE,
double_dqn=args.double_dqn, double_dqn=USE_DOUBLE,
lr_decay_steps=args.lr_decay_steps, lr_decay_steps=LR_DECAY_STEPS,
lr_decay_factor=args.lr_decay_factor, lr_decay_factor=LR_DECAY_FACTOR,
warmup_steps=args.warmup_steps, warmup_steps=WARMUP_STEPS,
use_amp=USE_AMP,
) )
print(f"Agent 创建完成 (AMP: {USE_AMP})", flush=True)
# 创建训练器
trainer = ParallelTrainer( trainer = ParallelTrainer(
agent=agent, agent=agent,
envs=envs, envs=envs,
eval_env=eval_env, eval_env=eval_env,
num_envs=args.n_envs, num_envs=N_ENVS,
save_dir=args.save_dir, save_dir=SAVE_DIR,
eval_freq=args.eval_freq, eval_freq=EVAL_FREQ,
save_freq=args.save_freq, save_freq=SAVE_FREQ,
num_eval_episodes=args.eval_episodes, num_eval_episodes=EVAL_EPISODES,
warmup_steps=args.warmup_steps, warmup_steps=WARMUP_STEPS,
train_steps_per_update=TRAIN_STEPS_PER_UPDATE,
) )
# 打印配置 print("\n" + "=" * 60, flush=True)
print("\n训练配置:") print(f"开始 10M 步并行训练(全优化版)", flush=True)
print(f" 并行环境数: {args.n_envs}") print(f" GPU: {device}", flush=True)
print(f" 总步数: {args.steps:,}") print(f" 并行环境: {N_ENVS}", flush=True)
print(f" 学习率: {args.lr} (Warmup: {args.warmup_steps:,} 步)") print(f" Batch Size: {BATCH_SIZE}", flush=True)
print(f" ε衰减: {args.epsilon_start} -> {args.epsilon_end} ({args.epsilon_decay:,} 步)") print(f" 每步训练: {TRAIN_STEPS_PER_UPDATE} ", flush=True)
print(f" 批次大小: {args.batch_size}") print(f" AMP 混合精度: {USE_AMP}", flush=True)
print(f" 缓冲区大小: {args.buffer_size:,}") print(f" torch.compile: {USE_COMPILE}", flush=True)
print(f" Double DQN: {args.double_dqn}") print(f" Dueling: {USE_DUELING}", flush=True)
print(f" Dueling: {args.dueling}") print(f" Double DQN: {USE_DOUBLE}", flush=True)
print("=" * 60) print(f" PER: {USE_PER}", flush=True)
print("=" * 60 + "\n", flush=True)
trainer.train(args.steps) trainer.train(TOTAL_STEPS)
# ── 评估最佳模型 ──
print("\n加载最佳模型...", flush=True)
agent.load(f"{SAVE_DIR}/dqn_best.pt")
print("\n评估中...", flush=True)
all_rewards = []
for i in range(20):
state, _ = eval_env.reset()
ep_r = 0
done = False
while not done:
action = agent.select_action(state, evaluate=True)
state, reward, terminated, truncated, _ = eval_env.step(action)
done = terminated or truncated
ep_r += reward
all_rewards.append(ep_r)
print(f" Episode {i+1:>2}: {ep_r:.1f}", flush=True)
print(f"\n结果: 平均 {np.mean(all_rewards):.2f} +/- {np.std(all_rewards):.2f}", flush=True)
print(f"最佳: {max(all_rewards):.1f} | 最差: {min(all_rewards):.1f}", flush=True)
print(f"中位数: {np.median(all_rewards):.1f}", flush=True)
if __name__ == "__main__": if __name__ == "__main__":