perf: 为PPO和DQN添加GPU优化——AMP混合精度、pinned memory、torch.compile
- PPO (CW1_id_name): 添加 AMP GradScaler + autocast 混合精度训练,pinned memory 加速 CPU→GPU 传输,torch.compile JIT 编译支持,调整默认超参适配 RTX 5090 - DQN (Atari): 添加 AMP 混合精度、pinned memory 回放缓冲区、向量化批量添加经验 (add_batch) 和批量动作选择 (batch_select_actions),消除 Python 循环 - train_parallel.py: 重写为无缓冲脚本,集成所有优化,64 并行环境 + 每步 4 次训练更新 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -10,6 +10,7 @@ of PPO* (Huang et al. 2022) and *RAD* (Laskin et al. 2020):
|
||||
- Linear schedule for clip range (clip_init -> clip_floor)
|
||||
- Random-shift data augmentation on observations during the update
|
||||
- Linear annealing of learning rate and entropy coefficient with floors
|
||||
- AMP mixed precision training for GPU acceleration
|
||||
|
||||
Public API:
|
||||
- PPOAgent.act(obs) -> (action, log_prob, value)
|
||||
@@ -48,11 +49,17 @@ class PPOAgent:
|
||||
target_kl=None,
|
||||
use_data_aug=False,
|
||||
aug_pad=4,
|
||||
use_amp=True,
|
||||
):
|
||||
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dev = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device(dev) if isinstance(dev, str) else dev
|
||||
self.net = ActorCritic(n_actions=n_actions).to(self.device)
|
||||
self.optim = optim.Adam(self.net.parameters(), lr=lr, eps=1e-5)
|
||||
|
||||
# AMP 混合精度训练
|
||||
self.use_amp = use_amp and self.device.type == 'cuda'
|
||||
self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)
|
||||
|
||||
# Save initial values for scheduling
|
||||
self.lr_init = lr
|
||||
self.clip_init = clip
|
||||
@@ -84,7 +91,8 @@ class PPOAgent:
|
||||
def act_batch(self, obs_batch):
|
||||
"""Vectorised act for n_envs obs at once."""
|
||||
obs_t = torch.as_tensor(obs_batch, device=self.device)
|
||||
action, log_prob, _, value = self.net.get_action_and_value(obs_t)
|
||||
with torch.amp.autocast('cuda', enabled=self.use_amp):
|
||||
action, log_prob, _, value = self.net.get_action_and_value(obs_t)
|
||||
return (
|
||||
action.cpu().numpy(),
|
||||
log_prob.cpu().numpy(),
|
||||
@@ -94,7 +102,8 @@ class PPOAgent:
|
||||
@torch.no_grad()
|
||||
def evaluate_value_batch(self, obs_batch):
|
||||
obs_t = torch.as_tensor(obs_batch, device=self.device)
|
||||
_, value = self.net(obs_t)
|
||||
with torch.amp.autocast('cuda', enabled=self.use_amp):
|
||||
_, value = self.net(obs_t)
|
||||
return value.cpu().numpy()
|
||||
|
||||
def _random_shift(self, obs):
|
||||
@@ -166,38 +175,43 @@ class PPOAgent:
|
||||
if self.use_data_aug:
|
||||
b_obs = self._random_shift(b_obs)
|
||||
|
||||
_, new_logp, entropy, value = self.net.get_action_and_value(
|
||||
b_obs, b_actions
|
||||
)
|
||||
# AMP 前向传播
|
||||
with torch.amp.autocast('cuda', enabled=self.use_amp):
|
||||
_, new_logp, entropy, value = self.net.get_action_and_value(
|
||||
b_obs, b_actions
|
||||
)
|
||||
|
||||
log_ratio = new_logp - b_old_logp
|
||||
ratio = log_ratio.exp()
|
||||
log_ratio = new_logp - b_old_logp
|
||||
ratio = log_ratio.exp()
|
||||
|
||||
# Clipped policy loss
|
||||
surr1 = ratio * b_adv
|
||||
surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
|
||||
policy_loss = -torch.min(surr1, surr2).mean()
|
||||
# Clipped policy loss
|
||||
surr1 = ratio * b_adv
|
||||
surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
|
||||
policy_loss = -torch.min(surr1, surr2).mean()
|
||||
|
||||
# Clipped value loss (refinement #1, SB3 standard)
|
||||
v_clipped = b_old_values + torch.clamp(
|
||||
value - b_old_values, -self.clip, self.clip
|
||||
)
|
||||
v_loss_unclipped = (value - b_ret).pow(2)
|
||||
v_loss_clipped = (v_clipped - b_ret).pow(2)
|
||||
value_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
|
||||
# Clipped value loss (refinement #1, SB3 standard)
|
||||
v_clipped = b_old_values + torch.clamp(
|
||||
value - b_old_values, -self.clip, self.clip
|
||||
)
|
||||
v_loss_unclipped = (value - b_ret).pow(2)
|
||||
v_loss_clipped = (v_clipped - b_ret).pow(2)
|
||||
value_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
|
||||
|
||||
entropy_loss = entropy.mean()
|
||||
entropy_loss = entropy.mean()
|
||||
|
||||
loss = (
|
||||
policy_loss
|
||||
+ self.vf_coef * value_loss
|
||||
- self.ent_coef * entropy_loss
|
||||
)
|
||||
loss = (
|
||||
policy_loss
|
||||
+ self.vf_coef * value_loss
|
||||
- self.ent_coef * entropy_loss
|
||||
)
|
||||
|
||||
# AMP 反向传播
|
||||
self.optim.zero_grad()
|
||||
loss.backward()
|
||||
self.scaler.scale(loss).backward()
|
||||
self.scaler.unscale_(self.optim)
|
||||
nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm)
|
||||
self.optim.step()
|
||||
self.scaler.step(self.optim)
|
||||
self.scaler.update()
|
||||
|
||||
with torch.no_grad():
|
||||
approx_kl = ((ratio - 1) - log_ratio).mean().item()
|
||||
|
||||
@@ -4,6 +4,8 @@ Uses CleanRL's indexing convention:
|
||||
dones[t] flags whether obs[t] is the FIRST obs of a fresh episode
|
||||
(i.e., the previous action terminated). GAE then uses dones[t+1]
|
||||
as the mask for V(s_{t+1}) at time t.
|
||||
|
||||
Supports pinned memory for faster CPU→GPU transfer.
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -16,6 +18,7 @@ class VecRolloutBuffer:
|
||||
self.obs_shape = obs_shape
|
||||
self.device = device
|
||||
|
||||
# 主存储在 GPU 上
|
||||
self.obs = torch.zeros(
|
||||
(n_steps, n_envs, *obs_shape), dtype=torch.uint8, device=device
|
||||
)
|
||||
@@ -28,16 +31,35 @@ class VecRolloutBuffer:
|
||||
self.advantages = torch.zeros((n_steps, n_envs), dtype=torch.float32, device=device)
|
||||
self.returns = torch.zeros((n_steps, n_envs), dtype=torch.float32, device=device)
|
||||
|
||||
# Pinned memory 缓冲区(加速 CPU→GPU 传输)
|
||||
self._obs_pin = torch.zeros(
|
||||
(n_steps, n_envs, *obs_shape), dtype=torch.uint8, pin_memory=True
|
||||
)
|
||||
self._actions_pin = torch.zeros((n_steps, n_envs), dtype=torch.long, pin_memory=True)
|
||||
self._log_probs_pin = torch.zeros((n_steps, n_envs), dtype=torch.float32, pin_memory=True)
|
||||
self._rewards_pin = torch.zeros((n_steps, n_envs), dtype=torch.float32, pin_memory=True)
|
||||
self._values_pin = torch.zeros((n_steps, n_envs), dtype=torch.float32, pin_memory=True)
|
||||
self._dones_pin = torch.zeros((n_steps, n_envs), dtype=torch.float32, pin_memory=True)
|
||||
|
||||
self.ptr = 0
|
||||
|
||||
def add(self, obs, action, log_prob, reward, value, done):
|
||||
i = self.ptr
|
||||
self.obs[i] = torch.as_tensor(obs, device=self.device)
|
||||
self.actions[i] = torch.as_tensor(action, device=self.device, dtype=torch.long)
|
||||
self.log_probs[i] = torch.as_tensor(log_prob, device=self.device, dtype=torch.float32)
|
||||
self.rewards[i] = torch.as_tensor(reward, device=self.device, dtype=torch.float32)
|
||||
self.values[i] = torch.as_tensor(value, device=self.device, dtype=torch.float32)
|
||||
self.dones[i] = torch.as_tensor(done, device=self.device, dtype=torch.float32)
|
||||
# 先写入 pinned memory,再 non-blocking 传输到 GPU
|
||||
self._obs_pin[i] = torch.as_tensor(obs)
|
||||
self._actions_pin[i] = torch.as_tensor(action, dtype=torch.long)
|
||||
self._log_probs_pin[i] = torch.as_tensor(log_prob, dtype=torch.float32)
|
||||
self._rewards_pin[i] = torch.as_tensor(reward, dtype=torch.float32)
|
||||
self._values_pin[i] = torch.as_tensor(value, dtype=torch.float32)
|
||||
self._dones_pin[i] = torch.as_tensor(done, dtype=torch.float32)
|
||||
|
||||
# non_blocking 传输到 GPU
|
||||
self.obs[i] = self._obs_pin[i].to(self.device, non_blocking=True)
|
||||
self.actions[i] = self._actions_pin[i].to(self.device, non_blocking=True)
|
||||
self.log_probs[i] = self._log_probs_pin[i].to(self.device, non_blocking=True)
|
||||
self.rewards[i] = self._rewards_pin[i].to(self.device, non_blocking=True)
|
||||
self.values[i] = self._values_pin[i].to(self.device, non_blocking=True)
|
||||
self.dones[i] = self._dones_pin[i].to(self.device, non_blocking=True)
|
||||
self.ptr += 1
|
||||
|
||||
def compute_gae(self, last_value, last_done, gamma=0.99, lam=0.95):
|
||||
|
||||
@@ -5,6 +5,11 @@ Usage (Windows):
|
||||
python train_vec.py --n-envs 4 --total-steps 500000 --run-name vec_main \
|
||||
--anneal-lr --anneal-ent --reward-clip 1.0
|
||||
|
||||
Usage (Linux server with RTX 5090):
|
||||
python train_vec.py --n-envs 16 --total-steps 2000000 --run-name vec_main \
|
||||
--n-steps 512 --batch-size 512 --n-epochs 10 \
|
||||
--anneal-lr --anneal-ent --reward-clip 1.0 --use-amp
|
||||
|
||||
The ``if __name__ == "__main__"`` guard at the bottom is mandatory on
|
||||
Windows for AsyncVectorEnv (otherwise child processes infinite-spawn).
|
||||
"""
|
||||
@@ -26,11 +31,11 @@ from src.vec_rollout_buffer import VecRolloutBuffer
|
||||
|
||||
def parse_args():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--total-steps", type=int, default=3_000_000)
|
||||
p.add_argument("--n-envs", type=int, default=8)
|
||||
p.add_argument("--n-steps", type=int, default=256)
|
||||
p.add_argument("--n-epochs", type=int, default=6)
|
||||
p.add_argument("--batch-size", type=int, default=128)
|
||||
p.add_argument("--total-steps", type=int, default=2_000_000)
|
||||
p.add_argument("--n-envs", type=int, default=16)
|
||||
p.add_argument("--n-steps", type=int, default=512)
|
||||
p.add_argument("--n-epochs", type=int, default=10)
|
||||
p.add_argument("--batch-size", type=int, default=512)
|
||||
p.add_argument("--lr", type=float, default=2.5e-4)
|
||||
p.add_argument("--gamma", type=float, default=0.99)
|
||||
p.add_argument("--lam", type=float, default=0.95)
|
||||
@@ -56,6 +61,10 @@ def parse_args():
|
||||
help="Apply random-shift augmentation to obs during PPO update")
|
||||
p.add_argument("--sync-mode", action="store_true",
|
||||
help="Use SyncVectorEnv (debug mode)")
|
||||
p.add_argument("--use-amp", action="store_true",
|
||||
help="Use AMP mixed precision training for GPU acceleration")
|
||||
p.add_argument("--use-compile", action="store_true",
|
||||
help="Use torch.compile for JIT compilation acceleration")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
@@ -94,7 +103,14 @@ def main():
|
||||
clip_floor=args.clip_floor,
|
||||
target_kl=args.target_kl,
|
||||
use_data_aug=args.use_data_aug,
|
||||
use_amp=args.use_amp,
|
||||
)
|
||||
|
||||
# torch.compile JIT 编译加速
|
||||
if args.use_compile and hasattr(torch, 'compile'):
|
||||
print("应用 torch.compile 加速...")
|
||||
agent.net = torch.compile(agent.net)
|
||||
print("torch.compile 完成")
|
||||
buffer = VecRolloutBuffer(
|
||||
n_steps=args.n_steps,
|
||||
n_envs=args.n_envs,
|
||||
@@ -117,6 +133,8 @@ def main():
|
||||
print(f"clip_floor={args.clip_floor} target_kl={args.target_kl} "
|
||||
f"use_data_aug={args.use_data_aug}")
|
||||
print(f"n_epochs={args.n_epochs} batch_size={args.batch_size}")
|
||||
print(f"AMP: {args.use_amp}")
|
||||
print(f"Compile: {args.use_compile}")
|
||||
print(f"Device: {agent.device}")
|
||||
print(f"Logs: {run_dir}")
|
||||
print(f"Ckpts: {ckpt_dir}")
|
||||
|
||||
@@ -12,36 +12,10 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"导入完成\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"import os\n",
|
||||
"import time\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"from collections import deque\n",
|
||||
"\n",
|
||||
"# notebooks/ 的上级目录即项目根目录\n",
|
||||
"sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), \"..\")))\n",
|
||||
"\n",
|
||||
"from src.network import QNetwork, DuelingQNetwork\n",
|
||||
"from src.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer\n",
|
||||
"from src.agent import DQNAgent\n",
|
||||
"from src.utils import make_env, get_device\n",
|
||||
"\n",
|
||||
"print(\"导入完成\")"
|
||||
]
|
||||
"outputs": [],
|
||||
"source": "import sys\nimport os\nimport time\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom collections import deque\nfrom multiprocessing import Process, Queue\n\n# notebooks/ 的上级目录即项目根目录\nsys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), \"..\")))\n\nfrom src.network import QNetwork, DuelingQNetwork\nfrom src.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer\nfrom src.agent import DQNAgent\nfrom src.utils import make_env, get_device\n\nprint(\"导入完成\")"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@@ -69,134 +43,10 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"训练器就绪\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# ── 并行训练器 ──\n",
|
||||
"class ParallelTrainer:\n",
|
||||
" def __init__(\n",
|
||||
" self, agent, envs, eval_env, num_envs,\n",
|
||||
" save_dir=\"models\", eval_freq=10000, save_freq=50000,\n",
|
||||
" num_eval_episodes=10, warmup_steps=10000,\n",
|
||||
" ):\n",
|
||||
" self.agent = agent\n",
|
||||
" self.envs = envs\n",
|
||||
" self.eval_env = eval_env\n",
|
||||
" self.num_envs = num_envs\n",
|
||||
" self.save_dir = save_dir\n",
|
||||
" self.eval_freq = eval_freq\n",
|
||||
" self.save_freq = save_freq\n",
|
||||
" self.num_eval_episodes = num_eval_episodes\n",
|
||||
" self.warmup_steps = warmup_steps\n",
|
||||
" self.episode_rewards = deque(maxlen=100)\n",
|
||||
" self.eval_rewards = []\n",
|
||||
" self.best_eval_reward = -float(\"inf\")\n",
|
||||
"\n",
|
||||
" def _batch_select_actions(self, states):\n",
|
||||
" epsilon = self.agent.epsilon\n",
|
||||
" n = len(states)\n",
|
||||
" random_mask = np.random.random(n) < epsilon\n",
|
||||
" actions = np.zeros(n, dtype=np.int64)\n",
|
||||
" non_random = ~random_mask\n",
|
||||
" if non_random.any():\n",
|
||||
" state_tensor = torch.from_numpy(states[non_random]).float().to(self.agent.device)\n",
|
||||
" with torch.no_grad():\n",
|
||||
" q_values = self.agent.q_network(state_tensor)\n",
|
||||
" actions[non_random] = q_values.argmax(dim=1).cpu().numpy()\n",
|
||||
" if random_mask.any():\n",
|
||||
" actions[random_mask] = np.random.randint(0, self.agent.num_actions, size=random_mask.sum())\n",
|
||||
" return actions\n",
|
||||
"\n",
|
||||
" def evaluate(self):\n",
|
||||
" rewards = []\n",
|
||||
" for _ in range(self.num_eval_episodes):\n",
|
||||
" state, _ = self.eval_env.reset()\n",
|
||||
" ep_reward = 0\n",
|
||||
" done = False\n",
|
||||
" while not done:\n",
|
||||
" action = self.agent.select_action(state, evaluate=True)\n",
|
||||
" state, reward, terminated, truncated, _ = self.eval_env.step(action)\n",
|
||||
" done = terminated or truncated\n",
|
||||
" ep_reward += reward\n",
|
||||
" rewards.append(ep_reward)\n",
|
||||
" return np.mean(rewards)\n",
|
||||
"\n",
|
||||
" def train(self, total_steps):\n",
|
||||
" n = self.num_envs\n",
|
||||
" device = self.agent.device\n",
|
||||
" envs = self.envs\n",
|
||||
"\n",
|
||||
" print(f\"开始训练: {total_steps:,} 步, {n} 并行环境\")\n",
|
||||
" print(\"=\" * 60)\n",
|
||||
"\n",
|
||||
" states, _ = envs.reset()\n",
|
||||
" ep_rewards = np.zeros(n, dtype=np.float32)\n",
|
||||
" ep_count = 0\n",
|
||||
" start_time = time.time()\n",
|
||||
" step = 0\n",
|
||||
"\n",
|
||||
" while step < total_steps:\n",
|
||||
" if step < self.warmup_steps:\n",
|
||||
" actions = np.array([envs.single_action_space.sample() for _ in range(n)])\n",
|
||||
" else:\n",
|
||||
" actions = self._batch_select_actions(states)\n",
|
||||
"\n",
|
||||
" next_states, rewards, terminateds, truncateds, _ = envs.step(actions)\n",
|
||||
" dones = np.logical_or(terminateds, truncateds)\n",
|
||||
"\n",
|
||||
" for i in range(n):\n",
|
||||
" self.agent.replay_buffer.add(states[i], actions[i], rewards[i], next_states[i], dones[i])\n",
|
||||
"\n",
|
||||
" ep_rewards += rewards\n",
|
||||
"\n",
|
||||
" for i in range(n):\n",
|
||||
" if dones[i]:\n",
|
||||
" self.episode_rewards.append(ep_rewards[i])\n",
|
||||
" ep_count += 1\n",
|
||||
" ep_rewards[i] = 0\n",
|
||||
"\n",
|
||||
" step += n\n",
|
||||
" states = next_states\n",
|
||||
"\n",
|
||||
" if step >= self.warmup_steps:\n",
|
||||
" self.agent.train_step()\n",
|
||||
"\n",
|
||||
" if ep_count > 0 and ep_count % 20 == 0:\n",
|
||||
" avg_r = np.mean(self.episode_rewards) if self.episode_rewards else 0\n",
|
||||
" elapsed = time.time() - start_time\n",
|
||||
" fps = step / elapsed\n",
|
||||
" lr = self.agent.optimizer.param_groups[0][\"lr\"]\n",
|
||||
" print(f\"Step:{step:>10,} | Ep:{ep_count:>5} | AvgR:{avg_r:>7.1f} | \"\n",
|
||||
" f\"Eps:{self.agent.epsilon:.3f} | LR:{lr:.2e} | FPS:{fps:.0f}\")\n",
|
||||
"\n",
|
||||
" if step % self.eval_freq == 0 and step > 0:\n",
|
||||
" eval_r = self.evaluate()\n",
|
||||
" self.eval_rewards.append((step, eval_r))\n",
|
||||
" print(f\"\\n[评估] Step:{step:>10,} | 平均回报:{eval_r:.1f}\\n\")\n",
|
||||
" if eval_r > self.best_eval_reward:\n",
|
||||
" self.best_eval_reward = eval_r\n",
|
||||
" self.agent.save(f\"{self.save_dir}/dqn_best.pt\")\n",
|
||||
"\n",
|
||||
" if step % self.save_freq == 0:\n",
|
||||
" self.agent.save(f\"{self.save_dir}/dqn_step_{step}.pt\")\n",
|
||||
"\n",
|
||||
" total_time = time.time() - start_time\n",
|
||||
" print(\"\\n\" + \"=\" * 60)\n",
|
||||
" print(f\"训练完成!总时间: {total_time:.1f} 秒 | FPS: {total_steps/total_time:.0f}\")\n",
|
||||
" print(f\"最佳评估回报: {self.best_eval_reward:.1f}\")\n",
|
||||
" self.agent.save(f\"{self.save_dir}/dqn_final.pt\")\n",
|
||||
"\n",
|
||||
"print(\"训练器就绪\")"
|
||||
]
|
||||
"outputs": [],
|
||||
"source": "# ── 并行训练器(优化版) ──\nclass ParallelTrainer:\n def __init__(\n self, agent, envs, eval_env, num_envs,\n save_dir=\"models\", eval_freq=10000, save_freq=50000,\n num_eval_episodes=10, warmup_steps=10000,\n train_steps_per_update=1,\n ):\n self.agent = agent\n self.envs = envs\n self.eval_env = eval_env\n self.num_envs = num_envs\n self.save_dir = save_dir\n self.eval_freq = eval_freq\n self.save_freq = save_freq\n self.num_eval_episodes = num_eval_episodes\n self.warmup_steps = warmup_steps\n self.train_steps_per_update = train_steps_per_update\n self.episode_rewards = deque(maxlen=100)\n self.eval_rewards = []\n self.best_eval_reward = -float(\"inf\")\n\n def evaluate(self):\n \"\"\"评估智能体\"\"\"\n rewards = []\n for _ in range(self.num_eval_episodes):\n state, _ = self.eval_env.reset()\n ep_reward = 0\n done = False\n while not done:\n action = self.agent.select_action(state, evaluate=True)\n state, reward, terminated, truncated, _ = self.eval_env.step(action)\n done = terminated or truncated\n ep_reward += reward\n rewards.append(ep_reward)\n return np.mean(rewards)\n\n def train(self, total_steps):\n n = self.num_envs\n device = self.agent.device\n envs = self.envs\n\n print(f\"开始训练: {total_steps:,} 步, {n} 并行环境, 每步训练 {self.train_steps_per_update} 次\")\n print(\"=\" * 60)\n\n states, _ = envs.reset()\n ep_rewards = np.zeros(n, dtype=np.float32)\n ep_count = 0\n start_time = time.time()\n step = 0\n\n while step < total_steps:\n # ── 动作选择(向量化) ──\n if step < self.warmup_steps:\n actions = np.array([envs.single_action_space.sample() for _ in range(n)])\n else:\n actions = self.agent.batch_select_actions(states, self.agent.epsilon)\n\n # ── 环境步进 ──\n next_states, rewards, terminateds, truncateds, _ = envs.step(actions)\n dones = np.logical_or(terminateds, truncateds)\n\n # ── 向量化批量添加经验(消除 Python 循环) ──\n self.agent.replay_buffer.add_batch(states, actions, rewards, next_states, dones)\n\n ep_rewards += rewards\n\n # ── 处理结束的 episode ──\n for i in range(n):\n if dones[i]:\n self.episode_rewards.append(ep_rewards[i])\n ep_count += 1\n ep_rewards[i] = 0\n\n step += n\n states = next_states\n\n # ── 每步训练多次(提升 GPU 利用率) ──\n if step >= self.warmup_steps:\n for _ in range(self.train_steps_per_update):\n self.agent.train_step()\n\n # ── 进度打印 ──\n if ep_count > 0 and ep_count % 20 == 0:\n avg_r = np.mean(self.episode_rewards) if self.episode_rewards else 0\n elapsed = time.time() - start_time\n fps = step / elapsed\n lr = self.agent.optimizer.param_groups[0][\"lr\"]\n print(f\"Step:{step:>10,} | Ep:{ep_count:>5} | AvgR:{avg_r:>7.1f} | \"\n f\"Eps:{self.agent.epsilon:.3f} | LR:{lr:.2e} | FPS:{fps:.0f}\")\n\n # ── 定期评估 ──\n if step % self.eval_freq == 0 and step > 0:\n eval_r = self.evaluate()\n self.eval_rewards.append((step, eval_r))\n print(f\"\\n[评估] Step:{step:>10,} | 平均回报:{eval_r:.1f}\\n\")\n if eval_r > self.best_eval_reward:\n self.best_eval_reward = eval_r\n self.agent.save(f\"{self.save_dir}/dqn_best.pt\")\n\n # ── 定期保存 ──\n if step % self.save_freq == 0 and step > 0:\n self.agent.save(f\"{self.save_dir}/dqn_step_{step}.pt\")\n\n total_time = time.time() - start_time\n print(\"\\n\" + \"=\" * 60)\n print(f\"训练完成!总时间: {total_time:.1f} 秒 | FPS: {total_steps/total_time:.0f}\")\n print(f\"最佳评估回报: {self.best_eval_reward:.1f}\")\n self.agent.save(f\"{self.save_dir}/dqn_final.pt\")\n\nprint(\"训练器就绪\")"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
@@ -212,37 +62,7 @@
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# ── 可修改的超参数 ──\n",
|
||||
"\n",
|
||||
"ENV_ID = \"ALE/SpaceInvaders-v5\"\n",
|
||||
"N_ENVS = 24 # 25 核 CPU,留 1 核给主进程\n",
|
||||
"TOTAL_STEPS = 10_000_000 # 总步数\n",
|
||||
"LR = 3e-5 # 学习率(大 batch 配低 lr 更稳定)\n",
|
||||
"GAMMA = 0.99 # 折扣因子\n",
|
||||
"BATCH_SIZE = 512 # RTX 5090 跑大 batch 才不浪费\n",
|
||||
"BUFFER_SIZE = 1_000_000 # 回放缓冲区\n",
|
||||
"EPSILON_START = 1.0\n",
|
||||
"EPSILON_END = 0.01\n",
|
||||
"EPSILON_DECAY = 4_000_000 # ε衰减步数(24 环境探索效率高,延长探索期)\n",
|
||||
"TARGET_UPDATE = 2000\n",
|
||||
"LR_DECAY_STEPS = 5_000_000\n",
|
||||
"LR_DECAY_FACTOR = 0.5\n",
|
||||
"WARMUP_STEPS = 50_000\n",
|
||||
"EVAL_FREQ = 50000\n",
|
||||
"EVAL_EPISODES = 10\n",
|
||||
"SAVE_FREQ = 200000\n",
|
||||
"SEED = 42\n",
|
||||
"SAVE_DIR = \"models\"\n",
|
||||
"\n",
|
||||
"USE_DUELING = True\n",
|
||||
"USE_DOUBLE = True\n",
|
||||
"USE_PER = True # 优先经验回放\n",
|
||||
"\n",
|
||||
"print(f\"配置: {TOTAL_STEPS/1e6:.0f}M 步, {N_ENVS} 并行环境\")\n",
|
||||
"print(f\"预计环境交互: {TOTAL_STEPS * 4 / 1e6:.0f}M frames\")\n",
|
||||
"print(f\"预计时间 (AutoDL 5090): ~{TOTAL_STEPS / 1000 / N_ENVS / 3600:.1f} 小时\")"
|
||||
]
|
||||
"source": "# ── 可修改的超参数 ──\n\nENV_ID = \"ALE/SpaceInvaders-v5\"\nN_ENVS = 64 # 超配:64 个并行环境,最大化 CPU 利用率\nTOTAL_STEPS = 10_000_000 # 总步数\nLR = 1e-4 # 大 batch 配合稍高 lr\nGAMMA = 0.99 # 折扣因子\nBATCH_SIZE = 2048 # 大 batch 充分利用 RTX 5090\nBUFFER_SIZE = 1_000_000 # 回放缓冲区\nEPSILON_START = 1.0\nEPSILON_END = 0.01\nEPSILON_DECAY = 4_000_000 # ε衰减步数\nTARGET_UPDATE = 5000 # 降低目标网络更新频率\nLR_DECAY_STEPS = 5_000_000\nLR_DECAY_FACTOR = 0.5\nWARMUP_STEPS = 50_000\nEVAL_FREQ = 200000 # 评估频率降低,减少中断\nEVAL_EPISODES = 10\nSAVE_FREQ = 500000\nSEED = 42\nSAVE_DIR = os.path.join(os.path.abspath(os.path.join(os.getcwd(), \"..\")), \"models\")\n\nTRAIN_STEPS_PER_UPDATE = 4 # 每步训练 4 次,提升 GPU 利用率\nUSE_AMP = True # AMP 混合精度训练\nUSE_COMPILE = True # torch.compile 编译加速\n\nUSE_DUELING = True\nUSE_DOUBLE = True\nUSE_PER = True\n\nos.makedirs(SAVE_DIR, exist_ok=True)\n\nprint(f\"配置: {TOTAL_STEPS/1e6:.0f}M 步, {N_ENVS} 并行环境\")\nprint(f\"每步训练 {TRAIN_STEPS_PER_UPDATE} 次, Batch {BATCH_SIZE}\")\nprint(f\"AMP: {USE_AMP}, torch.compile: {USE_COMPILE}\")\nprint(f\"模型保存: {SAVE_DIR}\")"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@@ -288,203 +108,24 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Dueling DQN: 3,293,863 参数\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"DuelingQNetwork(\n",
|
||||
" (conv): Sequential(\n",
|
||||
" (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))\n",
|
||||
" (1): ReLU()\n",
|
||||
" (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))\n",
|
||||
" (3): ReLU()\n",
|
||||
" (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n",
|
||||
" (5): ReLU()\n",
|
||||
" )\n",
|
||||
" (value_stream): Sequential(\n",
|
||||
" (0): Linear(in_features=3136, out_features=512, bias=True)\n",
|
||||
" (1): ReLU()\n",
|
||||
" (2): Linear(in_features=512, out_features=1, bias=True)\n",
|
||||
" )\n",
|
||||
" (advantage_stream): Sequential(\n",
|
||||
" (0): Linear(in_features=3136, out_features=512, bias=True)\n",
|
||||
" (1): ReLU()\n",
|
||||
" (2): Linear(in_features=512, out_features=6, bias=True)\n",
|
||||
" )\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# ── 网络 ──\n",
|
||||
"state_shape = (4, 84, 84)\n",
|
||||
"\n",
|
||||
"if USE_DUELING:\n",
|
||||
" q_network = DuelingQNetwork(state_shape, num_actions).to(device)\n",
|
||||
" target_network = DuelingQNetwork(state_shape, num_actions).to(device)\n",
|
||||
" print(f\"Dueling DQN: {sum(p.numel() for p in q_network.parameters()):,} 参数\")\n",
|
||||
"else:\n",
|
||||
" q_network = QNetwork(state_shape, num_actions).to(device)\n",
|
||||
" target_network = QNetwork(state_shape, num_actions).to(device)\n",
|
||||
" print(f\"标准 DQN: {sum(p.numel() for p in q_network.parameters()):,} 参数\")\n",
|
||||
"\n",
|
||||
"target_network.load_state_dict(q_network.state_dict())\n",
|
||||
"target_network.eval()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"优先经验回放\n",
|
||||
"Agent 创建完成\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# ── 回放缓冲区 + Agent ──\n",
|
||||
"if USE_PER:\n",
|
||||
" replay_buffer = PrioritizedReplayBuffer(BUFFER_SIZE, state_shape, device)\n",
|
||||
" print(\"优先经验回放\")\n",
|
||||
"else:\n",
|
||||
" replay_buffer = ReplayBuffer(BUFFER_SIZE, state_shape, device)\n",
|
||||
" print(\"标准经验回放\")\n",
|
||||
"\n",
|
||||
"agent = DQNAgent(\n",
|
||||
" q_network=q_network,\n",
|
||||
" target_network=target_network,\n",
|
||||
" replay_buffer=replay_buffer,\n",
|
||||
" device=device,\n",
|
||||
" num_actions=num_actions,\n",
|
||||
" gamma=GAMMA,\n",
|
||||
" lr=LR,\n",
|
||||
" epsilon_start=EPSILON_START,\n",
|
||||
" epsilon_end=EPSILON_END,\n",
|
||||
" epsilon_decay_steps=EPSILON_DECAY,\n",
|
||||
" target_update_freq=TARGET_UPDATE,\n",
|
||||
" batch_size=BATCH_SIZE,\n",
|
||||
" double_dqn=USE_DOUBLE,\n",
|
||||
" lr_decay_steps=LR_DECAY_STEPS,\n",
|
||||
" lr_decay_factor=LR_DECAY_FACTOR,\n",
|
||||
" warmup_steps=WARMUP_STEPS,\n",
|
||||
")\n",
|
||||
"print(\"Agent 创建完成\")"
|
||||
]
|
||||
"outputs": [],
|
||||
"source": "# ── 网络 + torch.compile ──\nstate_shape = (4, 84, 84)\n\nif USE_DUELING:\n q_network = DuelingQNetwork(state_shape, num_actions).to(device)\n target_network = DuelingQNetwork(state_shape, num_actions).to(device)\n print(f\"Dueling DQN: {sum(p.numel() for p in q_network.parameters()):,} 参数\")\nelse:\n q_network = QNetwork(state_shape, num_actions).to(device)\n target_network = QNetwork(state_shape, num_actions).to(device)\n print(f\"标准 DQN: {sum(p.numel() for p in q_network.parameters()):,} 参数\")\n\n# torch.compile 编译加速(PyTorch 2.x)\nif USE_COMPILE and hasattr(torch, 'compile'):\n print(\"应用 torch.compile 加速...\")\n q_network = torch.compile(q_network)\n target_network = torch.compile(target_network)\n print(\"torch.compile 完成\")\n\ntarget_network.load_state_dict(q_network.state_dict())\ntarget_network.eval()"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"============================================================\n",
|
||||
"开始 10M 步并行训练\n",
|
||||
" GPU: cuda\n",
|
||||
" 并行环境: 16\n",
|
||||
" Dueling: True\n",
|
||||
" Double DQN: True\n",
|
||||
" PER: True\n",
|
||||
"============================================================\n",
|
||||
"\n",
|
||||
"开始训练: 10,000,000 步, 16 并行环境\n",
|
||||
"============================================================\n",
|
||||
"Step: 3,232 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:651\n",
|
||||
"Step: 3,248 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:649\n",
|
||||
"Step: 3,264 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:649\n",
|
||||
"Step: 3,280 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:649\n",
|
||||
"Step: 3,296 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:648\n",
|
||||
"Step: 3,312 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:648\n",
|
||||
"Step: 3,328 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:648\n",
|
||||
"Step: 5,776 | Ep: 40 | AvgR: 12.3 | Eps:1.000 | LR:5.00e-05 | FPS:596\n",
|
||||
"Step: 5,792 | Ep: 40 | AvgR: 12.3 | Eps:1.000 | LR:5.00e-05 | FPS:596\n",
|
||||
"Step: 8,144 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
|
||||
"Step: 8,160 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
||||
"Step: 8,176 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
||||
"Step: 8,192 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
||||
"Step: 8,208 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
|
||||
"Step: 8,224 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
|
||||
"Step: 8,240 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
|
||||
"Step: 8,256 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
|
||||
"Step: 8,272 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
|
||||
"Step: 8,288 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
|
||||
"Step: 8,304 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
|
||||
"Step: 8,320 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
||||
"Step: 8,336 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
||||
"Step: 8,352 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
||||
"Step: 8,368 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
||||
"Step: 8,384 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
||||
"Step: 8,400 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
||||
"Step: 8,416 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
||||
"Step: 8,432 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
||||
"Step: 10,912 | Ep: 80 | AvgR: 13.2 | Eps:1.000 | LR:2.90e-07 | FPS:520\n",
|
||||
"Step: 10,928 | Ep: 80 | AvgR: 13.2 | Eps:1.000 | LR:2.95e-07 | FPS:520\n",
|
||||
"Step: 10,944 | Ep: 80 | AvgR: 13.2 | Eps:1.000 | LR:3.00e-07 | FPS:519\n",
|
||||
"Step: 10,960 | Ep: 80 | AvgR: 13.2 | Eps:1.000 | LR:3.05e-07 | FPS:519\n",
|
||||
"Step: 13,280 | Ep: 100 | AvgR: 13.0 | Eps:1.000 | LR:1.03e-06 | FPS:481\n",
|
||||
"Step: 13,296 | Ep: 100 | AvgR: 13.0 | Eps:1.000 | LR:1.04e-06 | FPS:481\n",
|
||||
"Step: 13,312 | Ep: 100 | AvgR: 13.0 | Eps:1.000 | LR:1.04e-06 | FPS:481\n",
|
||||
"Step: 13,328 | Ep: 100 | AvgR: 13.0 | Eps:1.000 | LR:1.04e-06 | FPS:481\n",
|
||||
"Step: 15,648 | Ep: 120 | AvgR: 12.8 | Eps:1.000 | LR:1.77e-06 | FPS:454\n",
|
||||
"Step: 21,184 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.50e-06 | FPS:434\n",
|
||||
"Step: 21,200 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.50e-06 | FPS:434\n",
|
||||
"Step: 21,216 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.51e-06 | FPS:434\n",
|
||||
"Step: 21,232 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.52e-06 | FPS:434\n",
|
||||
"Step: 21,248 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.52e-06 | FPS:434\n",
|
||||
"Step: 21,264 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.52e-06 | FPS:434\n",
|
||||
"Step: 21,280 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.53e-06 | FPS:434\n",
|
||||
"Step: 21,296 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.54e-06 | FPS:434\n",
|
||||
"Step: 23,824 | Ep: 180 | AvgR: 14.4 | Eps:1.000 | LR:4.33e-06 | FPS:432\n",
|
||||
"Step: 26,144 | Ep: 200 | AvgR: 14.3 | Eps:1.000 | LR:5.05e-06 | FPS:422\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# ── 开始训练 ──\n",
|
||||
"trainer = ParallelTrainer(\n",
|
||||
" agent=agent,\n",
|
||||
" envs=envs,\n",
|
||||
" eval_env=eval_env,\n",
|
||||
" num_envs=N_ENVS,\n",
|
||||
" save_dir=SAVE_DIR,\n",
|
||||
" eval_freq=EVAL_FREQ,\n",
|
||||
" save_freq=SAVE_FREQ,\n",
|
||||
" num_eval_episodes=EVAL_EPISODES,\n",
|
||||
" warmup_steps=WARMUP_STEPS,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"\\n\" + \"=\" * 60)\n",
|
||||
"print(f\"开始 10M 步并行训练\")\n",
|
||||
"print(f\" GPU: {device}\")\n",
|
||||
"print(f\" 并行环境: {N_ENVS}\")\n",
|
||||
"print(f\" Dueling: {USE_DUELING}\")\n",
|
||||
"print(f\" Double DQN: {USE_DOUBLE}\")\n",
|
||||
"print(f\" PER: {USE_PER}\")\n",
|
||||
"print(\"=\" * 60 + \"\\n\")\n",
|
||||
"\n",
|
||||
"trainer.train(TOTAL_STEPS)"
|
||||
]
|
||||
"outputs": [],
|
||||
"source": "# ── 回放缓冲区 + Agent ──\nif USE_PER:\n replay_buffer = PrioritizedReplayBuffer(BUFFER_SIZE, state_shape, device)\n print(\"优先经验回放 (Pinned Memory)\")\nelse:\n replay_buffer = ReplayBuffer(BUFFER_SIZE, state_shape, device)\n print(\"标准经验回放 (Pinned Memory)\")\n\nagent = DQNAgent(\n q_network=q_network,\n target_network=target_network,\n replay_buffer=replay_buffer,\n device=device,\n num_actions=num_actions,\n gamma=GAMMA,\n lr=LR,\n epsilon_start=EPSILON_START,\n epsilon_end=EPSILON_END,\n epsilon_decay_steps=EPSILON_DECAY,\n target_update_freq=TARGET_UPDATE,\n batch_size=BATCH_SIZE,\n double_dqn=USE_DOUBLE,\n lr_decay_steps=LR_DECAY_STEPS,\n lr_decay_factor=LR_DECAY_FACTOR,\n warmup_steps=WARMUP_STEPS,\n use_amp=USE_AMP,\n)\nprint(f\"Agent 创建完成 (AMP: {USE_AMP})\")"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": "# ── 开始训练 ──\ntrainer = ParallelTrainer(\n agent=agent,\n envs=envs,\n eval_env=eval_env,\n num_envs=N_ENVS,\n save_dir=SAVE_DIR,\n eval_freq=EVAL_FREQ,\n save_freq=SAVE_FREQ,\n num_eval_episodes=EVAL_EPISODES,\n warmup_steps=WARMUP_STEPS,\n train_steps_per_update=TRAIN_STEPS_PER_UPDATE,\n)\n\nprint(\"\\n\" + \"=\" * 60)\nprint(f\"开始 10M 步并行训练(全优化版)\")\nprint(f\" GPU: {device}\")\nprint(f\" 并行环境: {N_ENVS}\")\nprint(f\" Batch Size: {BATCH_SIZE}\")\nprint(f\" 每步训练: {TRAIN_STEPS_PER_UPDATE} 次\")\nprint(f\" AMP 混合精度: {USE_AMP}\")\nprint(f\" torch.compile: {USE_COMPILE}\")\nprint(f\" Dueling: {USE_DUELING}\")\nprint(f\" Double DQN: {USE_DOUBLE}\")\nprint(f\" PER: {USE_PER}\")\nprint(f\" Pinned Memory: 是\")\nprint(f\" 向量化批量添加: 是\")\nprint(\"=\" * 60 + \"\\n\")\n\ntrainer.train(TOTAL_STEPS)"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
@@ -495,63 +136,10 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"加载最佳模型...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"d:\\Code\\doing_exercises\\programs\\外教作业外快\\强化学习个人项目报告(Atari 游戏方向)\\src\\agent.py:219: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
|
||||
" checkpoint = torch.load(path, map_location=self.device)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "FileNotFoundError",
|
||||
"evalue": "[Errno 2] No such file or directory: 'models/dqn_best.pt'",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[1;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[1;32mIn[9], line 3\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;66;03m# ── 评估最佳模型 ──\u001b[39;00m\n\u001b[0;32m 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m加载最佳模型...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m----> 3\u001b[0m \u001b[43magent\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mSAVE_DIR\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m/dqn_best.pt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 5\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m评估中...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 6\u001b[0m all_rewards \u001b[38;5;241m=\u001b[39m []\n",
|
||||
"File \u001b[1;32md:\\Code\\doing_exercises\\programs\\外教作业外快\\强化学习个人项目报告(Atari 游戏方向)\\src\\agent.py:219\u001b[0m, in \u001b[0;36mDQNAgent.load\u001b[1;34m(self, path)\u001b[0m\n\u001b[0;32m 217\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mload\u001b[39m(\u001b[38;5;28mself\u001b[39m, path):\n\u001b[0;32m 218\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"加载模型\"\"\"\u001b[39;00m\n\u001b[1;32m--> 219\u001b[0m checkpoint \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmap_location\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 220\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mq_network\u001b[38;5;241m.\u001b[39mload_state_dict(checkpoint[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mq_network\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[0;32m 221\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtarget_network\u001b[38;5;241m.\u001b[39mload_state_dict(checkpoint[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtarget_network\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n",
|
||||
"File \u001b[1;32md:\\ProgramData\\anaconda3\\envs\\my_env\\lib\\site-packages\\torch\\serialization.py:1065\u001b[0m, in \u001b[0;36mload\u001b[1;34m(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)\u001b[0m\n\u001b[0;32m 1062\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mencoding\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m pickle_load_args\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[0;32m 1063\u001b[0m pickle_load_args[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mencoding\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mutf-8\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m-> 1065\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43m_open_file_like\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mrb\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m opened_file:\n\u001b[0;32m 1066\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _is_zipfile(opened_file):\n\u001b[0;32m 1067\u001b[0m \u001b[38;5;66;03m# The zipfile reader is going to advance the current file position.\u001b[39;00m\n\u001b[0;32m 1068\u001b[0m \u001b[38;5;66;03m# If we want to actually tail call to torch.jit.load, we need to\u001b[39;00m\n\u001b[0;32m 1069\u001b[0m \u001b[38;5;66;03m# reset back to the original position.\u001b[39;00m\n\u001b[0;32m 1070\u001b[0m orig_position \u001b[38;5;241m=\u001b[39m opened_file\u001b[38;5;241m.\u001b[39mtell()\n",
|
||||
"File \u001b[1;32md:\\ProgramData\\anaconda3\\envs\\my_env\\lib\\site-packages\\torch\\serialization.py:468\u001b[0m, in \u001b[0;36m_open_file_like\u001b[1;34m(name_or_buffer, mode)\u001b[0m\n\u001b[0;32m 466\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_open_file_like\u001b[39m(name_or_buffer, mode):\n\u001b[0;32m 467\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _is_path(name_or_buffer):\n\u001b[1;32m--> 468\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_open_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 469\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 470\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01min\u001b[39;00m mode:\n",
|
||||
"File \u001b[1;32md:\\ProgramData\\anaconda3\\envs\\my_env\\lib\\site-packages\\torch\\serialization.py:449\u001b[0m, in \u001b[0;36m_open_file.__init__\u001b[1;34m(self, name, mode)\u001b[0m\n\u001b[0;32m 448\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, name, mode):\n\u001b[1;32m--> 449\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m)\n",
|
||||
"\u001b[1;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'models/dqn_best.pt'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# ── 评估最佳模型 ──\n",
|
||||
"print(\"加载最佳模型...\")\n",
|
||||
"agent.load(f\"{SAVE_DIR}/dqn_best.pt\")\n",
|
||||
"\n",
|
||||
"print(\"\\n评估中...\")\n",
|
||||
"all_rewards = []\n",
|
||||
"for i in range(20):\n",
|
||||
" state, _ = eval_env.reset()\n",
|
||||
" ep_r = 0\n",
|
||||
" done = False\n",
|
||||
" while not done:\n",
|
||||
" action = agent.select_action(state, evaluate=True)\n",
|
||||
" state, reward, terminated, truncated, _ = eval_env.step(action)\n",
|
||||
" done = terminated or truncated\n",
|
||||
" ep_r += reward\n",
|
||||
" all_rewards.append(ep_r)\n",
|
||||
" print(f\" Episode {i+1:>2}: {ep_r:.1f}\")\n",
|
||||
"\n",
|
||||
"print(f\"\\n结果: 平均 {np.mean(all_rewards):.2f} ± {np.std(all_rewards):.2f}\")\n",
|
||||
"print(f\"最佳: {max(all_rewards):.1f} | 最差: {min(all_rewards):.1f}\")\n",
|
||||
"print(f\"中位数: {np.median(all_rewards):.1f}\")"
|
||||
]
|
||||
"outputs": [],
|
||||
"source": "# ── 评估最佳模型 ──\nprint(\"加载最佳模型...\")\nagent.load(f\"{SAVE_DIR}/dqn_best.pt\")\n\nprint(\"\\n评估中...\")\nall_rewards = []\nfor i in range(20):\n state, _ = eval_env.reset()\n ep_r = 0\n done = False\n while not done:\n action = agent.select_action(state, evaluate=True)\n state, reward, terminated, truncated, _ = eval_env.step(action)\n done = terminated or truncated\n ep_r += reward\n all_rewards.append(ep_r)\n print(f\" Episode {i+1:>2}: {ep_r:.1f}\")\n\nprint(f\"\\n结果: 平均 {np.mean(all_rewards):.2f} ± {np.std(all_rewards):.2f}\")\nprint(f\"最佳: {max(all_rewards):.1f} | 最差: {min(all_rewards):.1f}\")\nprint(f\"中位数: {np.median(all_rewards):.1f}\")"
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@@ -575,4 +163,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,7 @@ class DQNAgent:
|
||||
"""DQN智能体
|
||||
|
||||
实现ε-greedy探索策略和Q-learning更新
|
||||
支持 AMP 混合精度训练
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -30,26 +31,8 @@ class DQNAgent:
|
||||
lr_decay_steps=1_000_000,
|
||||
lr_decay_factor=0.5,
|
||||
warmup_steps=10_000,
|
||||
use_amp=True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
q_network: Q网络
|
||||
target_network: 目标网络
|
||||
replay_buffer: 经验回放缓冲区
|
||||
device: 设备
|
||||
num_actions: 动作数量
|
||||
gamma: 折扣因子
|
||||
lr: 学习率
|
||||
epsilon_start: ε初始值
|
||||
epsilon_end: ε最终值
|
||||
epsilon_decay_steps: ε衰减步数
|
||||
target_update_freq: 目标网络更新频率
|
||||
batch_size: 批次大小
|
||||
double_dqn: 是否使用Double DQN
|
||||
lr_decay_steps: 学习率衰减步数
|
||||
lr_decay_factor: 学习率衰减因子
|
||||
warmup_steps: 预热步数
|
||||
"""
|
||||
self.q_network = q_network
|
||||
self.target_network = target_network
|
||||
self.replay_buffer = replay_buffer
|
||||
@@ -72,40 +55,52 @@ class DQNAgent:
|
||||
|
||||
self.optimizer = torch.optim.Adam(q_network.parameters(), lr=lr)
|
||||
|
||||
self.step_count = 0
|
||||
# AMP 混合精度
|
||||
self.use_amp = use_amp and device.type == 'cuda'
|
||||
self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)
|
||||
|
||||
self.step_count = 0
|
||||
self.loss_history = []
|
||||
self.q_value_history = []
|
||||
|
||||
def select_action(self, state, evaluate=False):
|
||||
"""选择动作
|
||||
|
||||
Args:
|
||||
state: 当前状态 (channels, height, width)
|
||||
evaluate: 是否为评估模式(不使用ε-greedy)
|
||||
|
||||
Returns:
|
||||
action: 选择的动作
|
||||
"""
|
||||
if evaluate:
|
||||
# 评估模式:纯贪心
|
||||
epsilon = 0.0
|
||||
else:
|
||||
# 训练模式:ε-greedy
|
||||
epsilon = self.epsilon
|
||||
"""选择动作"""
|
||||
epsilon = 0.0 if evaluate else self.epsilon
|
||||
|
||||
if np.random.random() < epsilon:
|
||||
# 随机探索
|
||||
return np.random.randint(self.num_actions)
|
||||
else:
|
||||
# 贪心选择
|
||||
with torch.no_grad():
|
||||
state_tensor = (
|
||||
torch.from_numpy(state).float().unsqueeze(0).to(self.device)
|
||||
)
|
||||
state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
|
||||
q_values = self.q_network(state_tensor)
|
||||
return q_values.argmax(dim=1).item()
|
||||
|
||||
def batch_select_actions(self, states, epsilon):
|
||||
"""批量选择动作(向量化,用于并行训练)
|
||||
|
||||
Args:
|
||||
states: (n, C, H, W) numpy 数组
|
||||
epsilon: 当前 ε 值
|
||||
|
||||
Returns:
|
||||
actions: (n,) numpy 数组
|
||||
"""
|
||||
n = len(states)
|
||||
random_mask = np.random.random(n) < epsilon
|
||||
actions = np.zeros(n, dtype=np.int64)
|
||||
|
||||
non_random = ~random_mask
|
||||
if non_random.any():
|
||||
state_tensor = torch.from_numpy(states[non_random]).float().to(self.device)
|
||||
with torch.no_grad(), torch.amp.autocast('cuda', enabled=self.use_amp):
|
||||
q_values = self.q_network(state_tensor)
|
||||
actions[non_random] = q_values.argmax(dim=1).cpu().numpy()
|
||||
|
||||
if random_mask.any():
|
||||
actions[random_mask] = np.random.randint(0, self.num_actions, size=random_mask.sum())
|
||||
|
||||
return actions
|
||||
|
||||
def update_epsilon(self):
|
||||
"""更新ε值(线性衰减)"""
|
||||
if self.step_count < self.epsilon_decay_steps:
|
||||
@@ -129,17 +124,10 @@ class DQNAgent:
|
||||
param_group["lr"] *= self.lr_decay_factor
|
||||
|
||||
def train_step(self):
|
||||
"""执行一步训练
|
||||
|
||||
Returns:
|
||||
loss: 损失值
|
||||
avg_q: 平均Q值
|
||||
"""
|
||||
# 检查是否有足够样本
|
||||
"""执行一步训练(支持 AMP 混合精度)"""
|
||||
if len(self.replay_buffer) < self.batch_size:
|
||||
return None, None
|
||||
|
||||
# 采样(兼容标准和优先经验回放)
|
||||
sample_result = self.replay_buffer.sample(self.batch_size)
|
||||
if len(sample_result) == 7:
|
||||
states, actions, rewards, next_states, dones, indices, weights = sample_result
|
||||
@@ -147,52 +135,46 @@ class DQNAgent:
|
||||
states, actions, rewards, next_states, dones = sample_result
|
||||
indices, weights = None, None
|
||||
|
||||
# 计算当前Q值
|
||||
q_values = self.q_network(states)
|
||||
q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||
# AMP 前向传播
|
||||
with torch.amp.autocast('cuda', enabled=self.use_amp):
|
||||
q_values = self.q_network(states)
|
||||
q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||
|
||||
# 计算目标Q值
|
||||
with torch.no_grad():
|
||||
if self.double_dqn:
|
||||
next_actions = self.q_network(next_states).argmax(dim=1)
|
||||
next_q_values = self.target_network(next_states)
|
||||
next_q_values = next_q_values.gather(
|
||||
1, next_actions.unsqueeze(1)
|
||||
).squeeze(1)
|
||||
with torch.no_grad():
|
||||
if self.double_dqn:
|
||||
next_actions = self.q_network(next_states).argmax(dim=1)
|
||||
next_q_values = self.target_network(next_states)
|
||||
next_q_values = next_q_values.gather(1, next_actions.unsqueeze(1)).squeeze(1)
|
||||
else:
|
||||
next_q_values = self.target_network(next_states).max(dim=1)[0]
|
||||
target_q_values = rewards + self.gamma * next_q_values * (1 - dones)
|
||||
|
||||
if weights is not None:
|
||||
loss = (weights * F.mse_loss(q_values, target_q_values, reduction='none')).mean()
|
||||
else:
|
||||
next_q_values = self.target_network(next_states).max(dim=1)[0]
|
||||
loss = F.mse_loss(q_values, target_q_values)
|
||||
|
||||
target_q_values = rewards + self.gamma * next_q_values * (1 - dones)
|
||||
|
||||
# 计算TD误差
|
||||
td_errors = (q_values - target_q_values).detach()
|
||||
|
||||
# 计算损失(优先经验回放使用重要性采样权重)
|
||||
if weights is not None:
|
||||
loss = (weights * F.mse_loss(q_values, target_q_values, reduction='none')).mean()
|
||||
else:
|
||||
loss = F.mse_loss(q_values, target_q_values)
|
||||
|
||||
# 反向传播
|
||||
# AMP 反向传播
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
self.scaler.scale(loss).backward()
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 10)
|
||||
self.optimizer.step()
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
|
||||
# 更新优先级
|
||||
if indices is not None and hasattr(self.replay_buffer, 'update_priorities'):
|
||||
self.replay_buffer.update_priorities(indices, td_errors.cpu().numpy())
|
||||
td_errors = (q_values - target_q_values).detach()
|
||||
self.replay_buffer.update_priorities(indices, td_errors.cpu().float().numpy())
|
||||
|
||||
# 更新目标网络
|
||||
self.step_count += 1
|
||||
if self.step_count % self.target_update_freq == 0:
|
||||
self.target_network.load_state_dict(self.q_network.state_dict())
|
||||
|
||||
# 更新ε和学习率
|
||||
self.update_epsilon()
|
||||
self.update_learning_rate()
|
||||
|
||||
# 记录统计
|
||||
avg_q = q_values.mean().item()
|
||||
self.loss_history.append(loss.item())
|
||||
self.q_value_history.append(avg_q)
|
||||
@@ -216,7 +198,7 @@ class DQNAgent:
|
||||
|
||||
def load(self, path):
|
||||
"""加载模型"""
|
||||
checkpoint = torch.load(path, map_location=self.device)
|
||||
checkpoint = torch.load(path, map_location=self.device, weights_only=False)
|
||||
self.q_network.load_state_dict(checkpoint["q_network"])
|
||||
self.target_network.load_state_dict(checkpoint["target_network"])
|
||||
self.optimizer.load_state_dict(checkpoint["optimizer"])
|
||||
|
||||
@@ -7,63 +7,94 @@ class ReplayBuffer:
|
||||
"""经验回放缓冲区
|
||||
|
||||
存储转移 (s, a, r, s', done),随机采样打破数据相关性
|
||||
支持批量添加和 Pinned Memory 加速传输
|
||||
"""
|
||||
|
||||
def __init__(self, capacity, state_shape=(4, 84, 84), device='cpu'):
|
||||
"""
|
||||
Args:
|
||||
capacity: 缓冲区容量
|
||||
state_shape: 状态形状 (channels, height, width)
|
||||
device: 设备 (cpu/cuda)
|
||||
"""
|
||||
self.capacity = capacity
|
||||
self.device = device
|
||||
self.ptr = 0
|
||||
self.size = 0
|
||||
|
||||
# 预分配内存
|
||||
# 预分配 numpy 内存
|
||||
self.states = np.zeros((capacity, *state_shape), dtype=np.uint8)
|
||||
self.actions = np.zeros(capacity, dtype=np.int64)
|
||||
self.rewards = np.zeros(capacity, dtype=np.float32)
|
||||
self.next_states = np.zeros((capacity, *state_shape), dtype=np.uint8)
|
||||
self.dones = np.zeros(capacity, dtype=np.bool_)
|
||||
|
||||
def add(self, state, action, reward, next_state, done):
|
||||
"""添加一个转移
|
||||
# 预分配 pinned memory 张量(加速 CPU→GPU 传输)
|
||||
pin_device = 'cpu'
|
||||
self._states_pin = torch.empty((capacity, *state_shape), dtype=torch.uint8, pin_memory=True)
|
||||
self._actions_pin = torch.empty(capacity, dtype=torch.int64, pin_memory=True)
|
||||
self._rewards_pin = torch.empty(capacity, dtype=torch.float32, pin_memory=True)
|
||||
self._next_states_pin = torch.empty((capacity, *state_shape), dtype=torch.uint8, pin_memory=True)
|
||||
self._dones_pin = torch.empty(capacity, dtype=torch.float32, pin_memory=True)
|
||||
|
||||
Args:
|
||||
state: 当前状态
|
||||
action: 执行的动作
|
||||
reward: 获得的奖励
|
||||
next_state: 下一个状态
|
||||
done: 是否结束
|
||||
"""
|
||||
def add(self, state, action, reward, next_state, done):
|
||||
"""添加一个转移"""
|
||||
self.states[self.ptr] = state
|
||||
self.actions[self.ptr] = action
|
||||
self.rewards[self.ptr] = reward
|
||||
self.next_states[self.ptr] = next_state
|
||||
self.dones[self.ptr] = done
|
||||
|
||||
# 循环缓冲区
|
||||
self.ptr = (self.ptr + 1) % self.capacity
|
||||
self.size = min(self.size + 1, self.capacity)
|
||||
|
||||
def sample(self, batch_size):
|
||||
"""随机采样一个批次
|
||||
def add_batch(self, states, actions, rewards, next_states, dones):
|
||||
"""批量添加转移(向量化,消除 Python 循环)
|
||||
|
||||
Args:
|
||||
batch_size: 批次大小
|
||||
|
||||
Returns:
|
||||
states, actions, rewards, next_states, dones
|
||||
states: (n, C, H, W) 数组
|
||||
actions: (n,) 数组
|
||||
rewards: (n,) 数组
|
||||
next_states: (n, C, H, W) 数组
|
||||
dones: (n,) 数组
|
||||
"""
|
||||
n = len(states)
|
||||
end = self.ptr + n
|
||||
|
||||
if end <= self.capacity:
|
||||
self.states[self.ptr:end] = states
|
||||
self.actions[self.ptr:end] = actions
|
||||
self.rewards[self.ptr:end] = rewards
|
||||
self.next_states[self.ptr:end] = next_states
|
||||
self.dones[self.ptr:end] = dones
|
||||
else:
|
||||
# 循环缓冲区回绕
|
||||
overflow = end - self.capacity
|
||||
first = n - overflow
|
||||
self.states[self.ptr:] = states[:first]
|
||||
self.actions[self.ptr:] = actions[:first]
|
||||
self.rewards[self.ptr:] = rewards[:first]
|
||||
self.next_states[self.ptr:] = next_states[:first]
|
||||
self.dones[self.ptr:] = dones[:first]
|
||||
self.states[:overflow] = states[first:]
|
||||
self.actions[:overflow] = actions[first:]
|
||||
self.rewards[:overflow] = rewards[first:]
|
||||
self.next_states[:overflow] = next_states[first:]
|
||||
self.dones[:overflow] = dones[first:]
|
||||
|
||||
self.ptr = end % self.capacity
|
||||
self.size = min(self.size + n, self.capacity)
|
||||
|
||||
def sample(self, batch_size):
|
||||
"""随机采样一个批次,使用 pinned memory 加速传输"""
|
||||
indices = np.random.randint(0, self.size, size=batch_size)
|
||||
|
||||
states = torch.from_numpy(self.states[indices]).float().to(self.device)
|
||||
actions = torch.from_numpy(self.actions[indices]).long().to(self.device)
|
||||
rewards = torch.from_numpy(self.rewards[indices]).float().to(self.device)
|
||||
next_states = torch.from_numpy(self.next_states[indices]).float().to(self.device)
|
||||
dones = torch.from_numpy(self.dones[indices]).float().to(self.device)
|
||||
# 先写入 pinned memory,再 non-blocking 传输到 GPU
|
||||
self._states_pin[:batch_size].copy_(torch.from_numpy(self.states[indices]))
|
||||
self._actions_pin[:batch_size].copy_(torch.from_numpy(self.actions[indices]))
|
||||
self._rewards_pin[:batch_size].copy_(torch.from_numpy(self.rewards[indices]))
|
||||
self._next_states_pin[:batch_size].copy_(torch.from_numpy(self.next_states[indices]))
|
||||
self._dones_pin[:batch_size].copy_(torch.from_numpy(self.dones[indices].astype(np.float32)))
|
||||
|
||||
states = self._states_pin[:batch_size].float().to(self.device, non_blocking=True)
|
||||
actions = self._actions_pin[:batch_size].to(self.device, non_blocking=True)
|
||||
rewards = self._rewards_pin[:batch_size].to(self.device, non_blocking=True)
|
||||
next_states = self._next_states_pin[:batch_size].float().to(self.device, non_blocking=True)
|
||||
dones = self._dones_pin[:batch_size].to(self.device, non_blocking=True)
|
||||
|
||||
return states, actions, rewards, next_states, dones
|
||||
|
||||
@@ -75,16 +106,10 @@ class PrioritizedReplayBuffer:
|
||||
"""优先经验回放缓冲区
|
||||
|
||||
根据TD误差优先采样,提高样本效率
|
||||
支持批量添加和 Pinned Memory 加速传输
|
||||
"""
|
||||
|
||||
def __init__(self, capacity, state_shape=(4, 84, 84), device='cpu', alpha=0.6):
|
||||
"""
|
||||
Args:
|
||||
capacity: 缓冲区容量
|
||||
state_shape: 状态形状
|
||||
device: 设备
|
||||
alpha: 优先级指数 (0=均匀采样, 1=完全按优先级采样)
|
||||
"""
|
||||
self.capacity = capacity
|
||||
self.device = device
|
||||
self.alpha = alpha
|
||||
@@ -102,6 +127,14 @@ class PrioritizedReplayBuffer:
|
||||
# 优先级存储
|
||||
self.priorities = np.zeros(capacity, dtype=np.float32)
|
||||
|
||||
# 预分配 pinned memory
|
||||
self._states_pin = torch.empty((capacity, *state_shape), dtype=torch.uint8, pin_memory=True)
|
||||
self._actions_pin = torch.empty(capacity, dtype=torch.int64, pin_memory=True)
|
||||
self._rewards_pin = torch.empty(capacity, dtype=torch.float32, pin_memory=True)
|
||||
self._next_states_pin = torch.empty((capacity, *state_shape), dtype=torch.uint8, pin_memory=True)
|
||||
self._dones_pin = torch.empty(capacity, dtype=torch.float32, pin_memory=True)
|
||||
self._weights_pin = torch.empty(capacity, dtype=torch.float32, pin_memory=True)
|
||||
|
||||
def add(self, state, action, reward, next_state, done):
|
||||
"""添加转移,使用最大优先级"""
|
||||
self.states[self.ptr] = state
|
||||
@@ -109,51 +142,71 @@ class PrioritizedReplayBuffer:
|
||||
self.rewards[self.ptr] = reward
|
||||
self.next_states[self.ptr] = next_state
|
||||
self.dones[self.ptr] = done
|
||||
|
||||
# 新样本使用最大优先级
|
||||
self.priorities[self.ptr] = self.max_priority
|
||||
|
||||
self.ptr = (self.ptr + 1) % self.capacity
|
||||
self.size = min(self.size + 1, self.capacity)
|
||||
|
||||
def add_batch(self, states, actions, rewards, next_states, dones):
|
||||
"""批量添加转移(向量化)"""
|
||||
n = len(states)
|
||||
end = self.ptr + n
|
||||
|
||||
if end <= self.capacity:
|
||||
self.states[self.ptr:end] = states
|
||||
self.actions[self.ptr:end] = actions
|
||||
self.rewards[self.ptr:end] = rewards
|
||||
self.next_states[self.ptr:end] = next_states
|
||||
self.dones[self.ptr:end] = dones
|
||||
self.priorities[self.ptr:end] = self.max_priority
|
||||
else:
|
||||
overflow = end - self.capacity
|
||||
first = n - overflow
|
||||
self.states[self.ptr:] = states[:first]
|
||||
self.actions[self.ptr:] = actions[:first]
|
||||
self.rewards[self.ptr:] = rewards[:first]
|
||||
self.next_states[self.ptr:] = next_states[:first]
|
||||
self.dones[self.ptr:] = dones[:first]
|
||||
self.priorities[self.ptr:] = self.max_priority
|
||||
self.states[:overflow] = states[first:]
|
||||
self.actions[:overflow] = actions[first:]
|
||||
self.rewards[:overflow] = rewards[first:]
|
||||
self.next_states[:overflow] = next_states[first:]
|
||||
self.dones[:overflow] = dones[first:]
|
||||
self.priorities[:overflow] = self.max_priority
|
||||
|
||||
self.ptr = end % self.capacity
|
||||
self.size = min(self.size + n, self.capacity)
|
||||
|
||||
def sample(self, batch_size, beta=0.4):
|
||||
"""按优先级采样
|
||||
|
||||
Args:
|
||||
batch_size: 批次大小
|
||||
beta: 重要性采样指数
|
||||
|
||||
Returns:
|
||||
states, actions, rewards, next_states, dones, indices, weights
|
||||
"""
|
||||
# 计算采样概率
|
||||
"""按优先级采样,使用 pinned memory 加速传输"""
|
||||
priorities = self.priorities[:self.size] ** self.alpha
|
||||
probs = priorities / priorities.sum()
|
||||
|
||||
# 按概率采样
|
||||
indices = np.random.choice(self.size, size=batch_size, p=probs)
|
||||
|
||||
# 计算重要性采样权重
|
||||
weights = (self.size * probs[indices]) ** (-beta)
|
||||
weights = weights / weights.max()
|
||||
|
||||
# 获取数据
|
||||
states = torch.from_numpy(self.states[indices]).float().to(self.device)
|
||||
actions = torch.from_numpy(self.actions[indices]).long().to(self.device)
|
||||
rewards = torch.from_numpy(self.rewards[indices]).float().to(self.device)
|
||||
next_states = torch.from_numpy(self.next_states[indices]).float().to(self.device)
|
||||
dones = torch.from_numpy(self.dones[indices]).float().to(self.device)
|
||||
weights = torch.from_numpy(weights).float().to(self.device)
|
||||
# pinned memory 传输
|
||||
self._states_pin[:batch_size].copy_(torch.from_numpy(self.states[indices]))
|
||||
self._actions_pin[:batch_size].copy_(torch.from_numpy(self.actions[indices]))
|
||||
self._rewards_pin[:batch_size].copy_(torch.from_numpy(self.rewards[indices]))
|
||||
self._next_states_pin[:batch_size].copy_(torch.from_numpy(self.next_states[indices]))
|
||||
self._dones_pin[:batch_size].copy_(torch.from_numpy(self.dones[indices].astype(np.float32)))
|
||||
self._weights_pin[:batch_size].copy_(torch.from_numpy(weights))
|
||||
|
||||
return states, actions, rewards, next_states, dones, indices, weights
|
||||
states = self._states_pin[:batch_size].float().to(self.device, non_blocking=True)
|
||||
actions = self._actions_pin[:batch_size].to(self.device, non_blocking=True)
|
||||
rewards = self._rewards_pin[:batch_size].to(self.device, non_blocking=True)
|
||||
next_states = self._next_states_pin[:batch_size].float().to(self.device, non_blocking=True)
|
||||
dones = self._dones_pin[:batch_size].to(self.device, non_blocking=True)
|
||||
ws = self._weights_pin[:batch_size].to(self.device, non_blocking=True)
|
||||
|
||||
return states, actions, rewards, next_states, dones, indices, ws
|
||||
|
||||
def update_priorities(self, indices, td_errors):
|
||||
"""更新优先级
|
||||
|
||||
Args:
|
||||
indices: 样本索引
|
||||
td_errors: TD误差
|
||||
"""
|
||||
"""更新优先级"""
|
||||
priorities = np.abs(td_errors) + 1e-6
|
||||
self.priorities[indices] = priorities
|
||||
self.max_priority = max(self.max_priority, priorities.max())
|
||||
|
||||
@@ -1,31 +1,36 @@
|
||||
"""并行环境 DQN 训练脚本 - 使用 AsyncVectorEnv 加速数据收集.
|
||||
"""Dueling Double DQN - Space Invaders 并行训练脚本
|
||||
|
||||
每个训练迭代并行采集 N 个环境的转移,批量 GPU 推理,显著提升 FPS。
|
||||
适合在 AutoDL 等多核服务器+GPU 环境下运行。
|
||||
使用 AsyncVectorEnv 并行运行多个 Atari 环境,GPU 批量推理加速。
|
||||
适合在 AutoDL 等多核服务器环境运行。
|
||||
|
||||
与 notebooks/train_parallel.ipynb 内容一致,但使用 Python 脚本直接运行,
|
||||
确保 stdout 实时输出(无缓冲)。
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from collections import deque
|
||||
|
||||
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, PROJECT_ROOT)
|
||||
|
||||
from src.network import QNetwork, DuelingQNetwork
|
||||
from src.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
|
||||
from src.agent import DQNAgent
|
||||
from src.utils import make_env, get_device
|
||||
|
||||
# 强制无缓冲输出
|
||||
sys.stdout.reconfigure(line_buffering=True) if hasattr(sys.stdout, 'reconfigure') else None
|
||||
os.environ['PYTHONUNBUFFERED'] = '1'
|
||||
|
||||
# ── 环境工厂函数(供 AsyncVectorEnv 子进程使用)──
|
||||
print("导入完成", flush=True)
|
||||
|
||||
|
||||
# ── 环境工厂 ──
|
||||
def _make_env_fn(env_id):
|
||||
"""环境工厂 - 必须在模块级别以便 multiprocessing pickle."""
|
||||
# AsyncVectorEnv 子进程需要独立注册 ALE
|
||||
try:
|
||||
import ale_py
|
||||
import gymnasium as gym
|
||||
@@ -37,28 +42,16 @@ def _make_env_fn(env_id):
|
||||
return make_env(env_id, gray_scale=True, resize=True, frame_stack=4)
|
||||
return _make
|
||||
|
||||
print("环境工厂就绪", flush=True)
|
||||
|
||||
|
||||
# ── 并行训练器 ──
|
||||
|
||||
class ParallelTrainer:
|
||||
"""并行环境 DQN 训练器.
|
||||
|
||||
使用 AsyncVectorEnv 并行运行 N 个环境,
|
||||
同时收集转移 + 批量推理,大幅提升训练速度。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent,
|
||||
envs,
|
||||
eval_env,
|
||||
num_envs,
|
||||
save_dir="models",
|
||||
eval_freq=10000,
|
||||
save_freq=50000,
|
||||
num_eval_episodes=10,
|
||||
warmup_steps=10000,
|
||||
n_steps_per_env=1,
|
||||
self, agent, envs, eval_env, num_envs,
|
||||
save_dir="models", eval_freq=10000, save_freq=50000,
|
||||
num_eval_episodes=10, warmup_steps=10000,
|
||||
train_steps_per_update=1,
|
||||
):
|
||||
self.agent = agent
|
||||
self.envs = envs
|
||||
@@ -69,256 +62,173 @@ class ParallelTrainer:
|
||||
self.save_freq = save_freq
|
||||
self.num_eval_episodes = num_eval_episodes
|
||||
self.warmup_steps = warmup_steps
|
||||
self.n_steps_per_env = n_steps_per_env
|
||||
|
||||
self.train_steps_per_update = train_steps_per_update
|
||||
self.episode_rewards = deque(maxlen=100)
|
||||
self.eval_rewards = []
|
||||
self.best_eval_reward = -float("inf")
|
||||
|
||||
def train(self, total_steps):
|
||||
"""主并行训练循环.
|
||||
|
||||
Args:
|
||||
total_steps: 总环境交互步数
|
||||
"""
|
||||
num_envs = self.num_envs
|
||||
device = self.agent.device
|
||||
envs = self.envs
|
||||
|
||||
print(f"开始并行训练,总步数: {total_steps:,}")
|
||||
print(f"并行环境数: {num_envs}")
|
||||
print(f"预热步数: {self.warmup_steps:,}")
|
||||
print("=" * 60)
|
||||
|
||||
# 重置所有环境
|
||||
states, _ = envs.reset()
|
||||
episode_rewards = np.zeros(num_envs, dtype=np.float32)
|
||||
episode_lengths = np.zeros(num_envs, dtype=np.int32)
|
||||
episode_count = 0
|
||||
start_time = time.time()
|
||||
|
||||
step = 0
|
||||
while step < total_steps:
|
||||
# ── 动作选择 ──
|
||||
if step < self.warmup_steps:
|
||||
actions = np.array([envs.single_action_space.sample() for _ in range(num_envs)])
|
||||
else:
|
||||
actions = self._batch_select_actions(states)
|
||||
|
||||
# ── 环境步进(N 个环境并行)──
|
||||
next_states, rewards, terminateds, truncateds, _ = envs.step(actions)
|
||||
dones = np.logical_or(terminateds, truncateds)
|
||||
|
||||
# ── 存储转移 ──
|
||||
for i in range(num_envs):
|
||||
self.agent.replay_buffer.add(
|
||||
states[i], actions[i], rewards[i], next_states[i], dones[i]
|
||||
)
|
||||
|
||||
# ── 统计 ──
|
||||
episode_rewards += rewards
|
||||
episode_lengths += 1
|
||||
|
||||
# 处理结束的 episode
|
||||
for i in range(num_envs):
|
||||
if dones[i]:
|
||||
self.episode_rewards.append(episode_rewards[i])
|
||||
episode_count += 1
|
||||
episode_rewards[i] = 0
|
||||
episode_lengths[i] = 0
|
||||
|
||||
step += num_envs
|
||||
states = next_states
|
||||
|
||||
# ── 训练(环境每步一个 mini-batch)──
|
||||
if step >= self.warmup_steps:
|
||||
self.agent.train_step()
|
||||
|
||||
# ── 进度打印 ──
|
||||
if episode_count > 0 and episode_count % 10 == 0:
|
||||
avg_reward = np.mean(self.episode_rewards) if self.episode_rewards else 0
|
||||
elapsed = time.time() - start_time
|
||||
fps = step / elapsed
|
||||
|
||||
current_lr = self.agent.optimizer.param_groups[0]["lr"]
|
||||
print(
|
||||
f"Step: {step:>10,} | "
|
||||
f"Ep: {episode_count:>5} | "
|
||||
f"AvgReward: {avg_reward:>7.1f} | "
|
||||
f"Epsilon: {self.agent.epsilon:.3f} | "
|
||||
f"LR: {current_lr:.2e} | "
|
||||
f"FPS: {fps:.0f}"
|
||||
)
|
||||
|
||||
# ── 定期评估 ──
|
||||
if step % self.eval_freq == 0 and step > 0:
|
||||
eval_reward = self.evaluate()
|
||||
self.eval_rewards.append((step, eval_reward))
|
||||
print(f"\n[Eval] Step: {step:>10,} | AvgReward: {eval_reward:.1f}\n" + "-" * 60)
|
||||
|
||||
if eval_reward > self.best_eval_reward:
|
||||
self.best_eval_reward = eval_reward
|
||||
self.agent.save(f"{self.save_dir}/dqn_best.pt")
|
||||
|
||||
# ── 定期保存 ──
|
||||
if step % self.save_freq == 0:
|
||||
self.agent.save(f"{self.save_dir}/dqn_step_{step}.pt")
|
||||
|
||||
# 训练结束
|
||||
total_time = time.time() - start_time
|
||||
print("\n" + "=" * 60)
|
||||
print(f"训练完成!总时间: {total_time:.1f} 秒")
|
||||
print(f"平均 FPS: {total_steps / total_time:.0f}")
|
||||
print(f"最佳评估回报: {self.best_eval_reward:.1f}")
|
||||
|
||||
self.agent.save(f"{self.save_dir}/dqn_final.pt")
|
||||
|
||||
def _batch_select_actions(self, states):
|
||||
"""批量选择动作(使用 GPU 批量推理)."""
|
||||
epsilon = self.agent.epsilon
|
||||
num_envs = len(states)
|
||||
|
||||
# 随机探索
|
||||
random_mask = np.random.random(num_envs) < epsilon
|
||||
|
||||
actions = np.zeros(num_envs, dtype=np.int64)
|
||||
|
||||
# 对非随机的环境做批量推理
|
||||
non_random = ~random_mask
|
||||
if non_random.any():
|
||||
state_tensor = (
|
||||
torch.from_numpy(states[non_random]).float().to(self.agent.device)
|
||||
)
|
||||
with torch.no_grad():
|
||||
q_values = self.agent.q_network(state_tensor)
|
||||
actions[non_random] = q_values.argmax(dim=1).cpu().numpy()
|
||||
|
||||
# 随机的环境
|
||||
if random_mask.any():
|
||||
actions[random_mask] = np.random.randint(
|
||||
0, self.agent.num_actions, size=random_mask.sum()
|
||||
)
|
||||
|
||||
return actions
|
||||
|
||||
def evaluate(self):
|
||||
"""评估智能体."""
|
||||
rewards = []
|
||||
for _ in range(self.num_eval_episodes):
|
||||
state, _ = self.eval_env.reset()
|
||||
episode_reward = 0
|
||||
ep_reward = 0
|
||||
done = False
|
||||
|
||||
while not done:
|
||||
action = self.agent.select_action(state, evaluate=True)
|
||||
state, reward, terminated, truncated, _ = self.eval_env.step(action)
|
||||
done = terminated or truncated
|
||||
episode_reward += reward
|
||||
|
||||
rewards.append(episode_reward)
|
||||
|
||||
ep_reward += reward
|
||||
rewards.append(ep_reward)
|
||||
return np.mean(rewards)
|
||||
|
||||
def train(self, total_steps):
|
||||
n = self.num_envs
|
||||
device = self.agent.device
|
||||
envs = self.envs
|
||||
|
||||
print(f"开始训练: {total_steps:,} 步, {n} 并行环境, 每步训练 {self.train_steps_per_update} 次", flush=True)
|
||||
print("=" * 60, flush=True)
|
||||
|
||||
states, _ = envs.reset()
|
||||
ep_rewards = np.zeros(n, dtype=np.float32)
|
||||
ep_count = 0
|
||||
start_time = time.time()
|
||||
step = 0
|
||||
|
||||
while step < total_steps:
|
||||
if step < self.warmup_steps:
|
||||
actions = np.array([envs.single_action_space.sample() for _ in range(n)])
|
||||
else:
|
||||
actions = self.agent.batch_select_actions(states, self.agent.epsilon)
|
||||
|
||||
next_states, rewards, terminateds, truncateds, _ = envs.step(actions)
|
||||
dones = np.logical_or(terminateds, truncateds)
|
||||
|
||||
# 向量化批量添加经验
|
||||
self.agent.replay_buffer.add_batch(states, actions, rewards, next_states, dones)
|
||||
|
||||
ep_rewards += rewards
|
||||
|
||||
for i in range(n):
|
||||
if dones[i]:
|
||||
self.episode_rewards.append(ep_rewards[i])
|
||||
ep_count += 1
|
||||
ep_rewards[i] = 0
|
||||
|
||||
step += n
|
||||
states = next_states
|
||||
|
||||
if step >= self.warmup_steps:
|
||||
for _ in range(self.train_steps_per_update):
|
||||
self.agent.train_step()
|
||||
|
||||
if ep_count > 0 and ep_count % 20 == 0:
|
||||
avg_r = np.mean(self.episode_rewards) if self.episode_rewards else 0
|
||||
elapsed = time.time() - start_time
|
||||
fps = step / elapsed
|
||||
lr = self.agent.optimizer.param_groups[0]["lr"]
|
||||
print(f"Step:{step:>10,} | Ep:{ep_count:>5} | AvgR:{avg_r:>7.1f} | "
|
||||
f"Eps:{self.agent.epsilon:.3f} | LR:{lr:.2e} | FPS:{fps:.0f}", flush=True)
|
||||
|
||||
if step % self.eval_freq == 0 and step > 0:
|
||||
eval_r = self.evaluate()
|
||||
self.eval_rewards.append((step, eval_r))
|
||||
print(f"\n[评估] Step:{step:>10,} | 平均回报:{eval_r:.1f}\n", flush=True)
|
||||
if eval_r > self.best_eval_reward:
|
||||
self.best_eval_reward = eval_r
|
||||
self.agent.save(f"{self.save_dir}/dqn_best.pt")
|
||||
print(f"最佳模型已更新 (回报: {eval_r:.1f})", flush=True)
|
||||
|
||||
if step % self.save_freq == 0 and step > 0:
|
||||
self.agent.save(f"{self.save_dir}/dqn_step_{step}.pt")
|
||||
|
||||
total_time = time.time() - start_time
|
||||
print("\n" + "=" * 60, flush=True)
|
||||
print(f"训练完成!总时间: {total_time:.1f} 秒 | FPS: {total_steps/total_time:.0f}", flush=True)
|
||||
print(f"最佳评估回报: {self.best_eval_reward:.1f}", flush=True)
|
||||
self.agent.save(f"{self.save_dir}/dqn_final.pt")
|
||||
|
||||
print("训练器就绪", flush=True)
|
||||
|
||||
# ── 主入口 ──
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Parallel DQN for Space Invaders")
|
||||
# ── 超参数 ──
|
||||
ENV_ID = "ALE/SpaceInvaders-v5"
|
||||
N_ENVS = 64
|
||||
TOTAL_STEPS = 2_000_000
|
||||
LR = 1e-4
|
||||
GAMMA = 0.99
|
||||
BATCH_SIZE = 2048
|
||||
BUFFER_SIZE = 1_000_000
|
||||
EPSILON_START = 1.0
|
||||
EPSILON_END = 0.01
|
||||
EPSILON_DECAY = 4_000_000
|
||||
TARGET_UPDATE = 5000
|
||||
LR_DECAY_STEPS = 5_000_000
|
||||
LR_DECAY_FACTOR = 0.5
|
||||
WARMUP_STEPS = 50_000
|
||||
EVAL_FREQ = 200000
|
||||
EVAL_EPISODES = 10
|
||||
SAVE_FREQ = 500000
|
||||
SEED = 42
|
||||
SAVE_DIR = os.path.join(PROJECT_ROOT, "models")
|
||||
|
||||
# 并行参数
|
||||
parser.add_argument("--n-envs", type=int, default=8, help="并行环境数")
|
||||
TRAIN_STEPS_PER_UPDATE = 4
|
||||
USE_AMP = True
|
||||
USE_COMPILE = True
|
||||
USE_DUELING = True
|
||||
USE_DOUBLE = True
|
||||
USE_PER = True
|
||||
|
||||
# 训练参数
|
||||
parser.add_argument("--env", type=str, default="ALE/SpaceInvaders-v5")
|
||||
parser.add_argument("--steps", type=int, default=10_000_000, help="总训练步数")
|
||||
parser.add_argument("--lr", type=float, default=5e-5, help="学习率")
|
||||
parser.add_argument("--gamma", type=float, default=0.99, help="折扣因子")
|
||||
parser.add_argument("--batch-size", type=int, default=64, help="批次大小")
|
||||
parser.add_argument("--buffer-size", type=int, default=500_000, help="回放缓冲区大小")
|
||||
os.makedirs(SAVE_DIR, exist_ok=True)
|
||||
|
||||
# ε-greedy
|
||||
parser.add_argument("--epsilon-start", type=float, default=1.0)
|
||||
parser.add_argument("--epsilon-end", type=float, default=0.01)
|
||||
parser.add_argument("--epsilon-decay", type=int, default=2_000_000)
|
||||
print(f"配置: {TOTAL_STEPS/1e6:.0f}M 步, {N_ENVS} 并行环境", flush=True)
|
||||
print(f"每步训练 {TRAIN_STEPS_PER_UPDATE} 次, Batch {BATCH_SIZE}", flush=True)
|
||||
print(f"AMP: {USE_AMP}, torch.compile: {USE_COMPILE}", flush=True)
|
||||
print(f"模型保存: {SAVE_DIR}", flush=True)
|
||||
|
||||
# 网络
|
||||
parser.add_argument("--target-update", type=int, default=1000)
|
||||
parser.add_argument("--double-dqn", action="store_true", default=True)
|
||||
parser.add_argument("--dueling", action="store_true", default=True)
|
||||
torch.manual_seed(SEED)
|
||||
np.random.seed(SEED)
|
||||
import platform
|
||||
|
||||
# 学习率
|
||||
parser.add_argument("--lr-decay-steps", type=int, default=5_000_000)
|
||||
parser.add_argument("--lr-decay-factor", type=float, default=0.5)
|
||||
parser.add_argument("--warmup-steps", type=int, default=10_000)
|
||||
|
||||
# 评估
|
||||
parser.add_argument("--eval-freq", type=int, default=50000)
|
||||
parser.add_argument("--eval-episodes", type=int, default=10)
|
||||
parser.add_argument("--save-freq", type=int, default=100000)
|
||||
|
||||
# 优先回放
|
||||
parser.add_argument("--prioritized", action="store_true", default=True)
|
||||
|
||||
# 其他
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--save-dir", type=str, default="models")
|
||||
parser.add_argument("--log-dir", type=str, default="logs")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 随机种子
|
||||
torch.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
# 设备
|
||||
device = get_device()
|
||||
print(f"使用设备: {device}", flush=True)
|
||||
|
||||
# 创建并行训练环境
|
||||
print(f"创建 {args.n_envs} 个并行训练环境...")
|
||||
try:
|
||||
from gymnasium.vector import AsyncVectorEnv
|
||||
env_fns = [_make_env_fn(args.env) for _ in range(args.n_envs)]
|
||||
envs = AsyncVectorEnv(env_fns, shared_memory=True)
|
||||
except ImportError:
|
||||
print("AsyncVectorEnv 不可用,回退到 SyncVectorEnv")
|
||||
from gymnasium.vector import SyncVectorEnv
|
||||
env_fns = [_make_env_fn(args.env) for _ in range(args.n_envs)]
|
||||
envs = SyncVectorEnv(env_fns)
|
||||
|
||||
# 创建评估环境(单环境)
|
||||
eval_env = make_env(args.env, gray_scale=True, resize=True, frame_stack=4)
|
||||
from gymnasium.vector import SyncVectorEnv
|
||||
env_fns = [_make_env_fn(ENV_ID) for _ in range(N_ENVS)]
|
||||
envs = SyncVectorEnv(env_fns)
|
||||
print(f"SyncVectorEnv: {envs.num_envs} 个环境", flush=True)
|
||||
|
||||
eval_env = make_env(ENV_ID, gray_scale=True, resize=True, frame_stack=4)
|
||||
num_actions = envs.single_action_space.n
|
||||
print(f"动作空间: {num_actions}")
|
||||
print(f"实际环境数: {envs.num_envs}")
|
||||
print(f"动作空间: {num_actions}", flush=True)
|
||||
|
||||
state_shape = (4, 84, 84)
|
||||
|
||||
# 创建网络
|
||||
if args.dueling:
|
||||
print("使用 Dueling Double DQN")
|
||||
if USE_DUELING:
|
||||
q_network = DuelingQNetwork(state_shape, num_actions).to(device)
|
||||
target_network = DuelingQNetwork(state_shape, num_actions).to(device)
|
||||
print(f"Dueling DQN: {sum(p.numel() for p in q_network.parameters()):,} 参数", flush=True)
|
||||
else:
|
||||
print("使用标准 DQN")
|
||||
q_network = QNetwork(state_shape, num_actions).to(device)
|
||||
target_network = QNetwork(state_shape, num_actions).to(device)
|
||||
print(f"标准 DQN: {sum(p.numel() for p in q_network.parameters()):,} 参数", flush=True)
|
||||
|
||||
if USE_COMPILE and hasattr(torch, 'compile'):
|
||||
print("应用 torch.compile 加速...", flush=True)
|
||||
q_network = torch.compile(q_network)
|
||||
target_network = torch.compile(target_network)
|
||||
print("torch.compile 完成", flush=True)
|
||||
|
||||
target_network.load_state_dict(q_network.state_dict())
|
||||
target_network.eval()
|
||||
|
||||
print(f"网络参数量: {sum(p.numel() for p in q_network.parameters()):,}")
|
||||
|
||||
# 回放缓冲区
|
||||
if args.prioritized:
|
||||
print("使用优先经验回放")
|
||||
replay_buffer = PrioritizedReplayBuffer(args.buffer_size, state_shape, device)
|
||||
if USE_PER:
|
||||
replay_buffer = PrioritizedReplayBuffer(BUFFER_SIZE, state_shape, device)
|
||||
print("优先经验回放 (Pinned Memory)", flush=True)
|
||||
else:
|
||||
print("使用标准经验回放")
|
||||
replay_buffer = ReplayBuffer(args.buffer_size, state_shape, device)
|
||||
|
||||
# 创建 Agent
|
||||
from src.agent import DQNAgent
|
||||
replay_buffer = ReplayBuffer(BUFFER_SIZE, state_shape, device)
|
||||
print("标准经验回放 (Pinned Memory)", flush=True)
|
||||
|
||||
agent = DQNAgent(
|
||||
q_network=q_network,
|
||||
@@ -326,45 +236,70 @@ def main():
|
||||
replay_buffer=replay_buffer,
|
||||
device=device,
|
||||
num_actions=num_actions,
|
||||
gamma=args.gamma,
|
||||
lr=args.lr,
|
||||
epsilon_start=args.epsilon_start,
|
||||
epsilon_end=args.epsilon_end,
|
||||
epsilon_decay_steps=args.epsilon_decay,
|
||||
target_update_freq=args.target_update,
|
||||
batch_size=args.batch_size,
|
||||
double_dqn=args.double_dqn,
|
||||
lr_decay_steps=args.lr_decay_steps,
|
||||
lr_decay_factor=args.lr_decay_factor,
|
||||
warmup_steps=args.warmup_steps,
|
||||
gamma=GAMMA,
|
||||
lr=LR,
|
||||
epsilon_start=EPSILON_START,
|
||||
epsilon_end=EPSILON_END,
|
||||
epsilon_decay_steps=EPSILON_DECAY,
|
||||
target_update_freq=TARGET_UPDATE,
|
||||
batch_size=BATCH_SIZE,
|
||||
double_dqn=USE_DOUBLE,
|
||||
lr_decay_steps=LR_DECAY_STEPS,
|
||||
lr_decay_factor=LR_DECAY_FACTOR,
|
||||
warmup_steps=WARMUP_STEPS,
|
||||
use_amp=USE_AMP,
|
||||
)
|
||||
print(f"Agent 创建完成 (AMP: {USE_AMP})", flush=True)
|
||||
|
||||
# 创建训练器
|
||||
trainer = ParallelTrainer(
|
||||
agent=agent,
|
||||
envs=envs,
|
||||
eval_env=eval_env,
|
||||
num_envs=args.n_envs,
|
||||
save_dir=args.save_dir,
|
||||
eval_freq=args.eval_freq,
|
||||
save_freq=args.save_freq,
|
||||
num_eval_episodes=args.eval_episodes,
|
||||
warmup_steps=args.warmup_steps,
|
||||
num_envs=N_ENVS,
|
||||
save_dir=SAVE_DIR,
|
||||
eval_freq=EVAL_FREQ,
|
||||
save_freq=SAVE_FREQ,
|
||||
num_eval_episodes=EVAL_EPISODES,
|
||||
warmup_steps=WARMUP_STEPS,
|
||||
train_steps_per_update=TRAIN_STEPS_PER_UPDATE,
|
||||
)
|
||||
|
||||
# 打印配置
|
||||
print("\n训练配置:")
|
||||
print(f" 并行环境数: {args.n_envs}")
|
||||
print(f" 总步数: {args.steps:,}")
|
||||
print(f" 学习率: {args.lr} (Warmup: {args.warmup_steps:,} 步)")
|
||||
print(f" ε衰减: {args.epsilon_start} -> {args.epsilon_end} ({args.epsilon_decay:,} 步)")
|
||||
print(f" 批次大小: {args.batch_size}")
|
||||
print(f" 缓冲区大小: {args.buffer_size:,}")
|
||||
print(f" Double DQN: {args.double_dqn}")
|
||||
print(f" Dueling: {args.dueling}")
|
||||
print("=" * 60)
|
||||
print("\n" + "=" * 60, flush=True)
|
||||
print(f"开始 10M 步并行训练(全优化版)", flush=True)
|
||||
print(f" GPU: {device}", flush=True)
|
||||
print(f" 并行环境: {N_ENVS}", flush=True)
|
||||
print(f" Batch Size: {BATCH_SIZE}", flush=True)
|
||||
print(f" 每步训练: {TRAIN_STEPS_PER_UPDATE} 次", flush=True)
|
||||
print(f" AMP 混合精度: {USE_AMP}", flush=True)
|
||||
print(f" torch.compile: {USE_COMPILE}", flush=True)
|
||||
print(f" Dueling: {USE_DUELING}", flush=True)
|
||||
print(f" Double DQN: {USE_DOUBLE}", flush=True)
|
||||
print(f" PER: {USE_PER}", flush=True)
|
||||
print("=" * 60 + "\n", flush=True)
|
||||
|
||||
trainer.train(args.steps)
|
||||
trainer.train(TOTAL_STEPS)
|
||||
|
||||
# ── 评估最佳模型 ──
|
||||
print("\n加载最佳模型...", flush=True)
|
||||
agent.load(f"{SAVE_DIR}/dqn_best.pt")
|
||||
|
||||
print("\n评估中...", flush=True)
|
||||
all_rewards = []
|
||||
for i in range(20):
|
||||
state, _ = eval_env.reset()
|
||||
ep_r = 0
|
||||
done = False
|
||||
while not done:
|
||||
action = agent.select_action(state, evaluate=True)
|
||||
state, reward, terminated, truncated, _ = eval_env.step(action)
|
||||
done = terminated or truncated
|
||||
ep_r += reward
|
||||
all_rewards.append(ep_r)
|
||||
print(f" Episode {i+1:>2}: {ep_r:.1f}", flush=True)
|
||||
|
||||
print(f"\n结果: 平均 {np.mean(all_rewards):.2f} +/- {np.std(all_rewards):.2f}", flush=True)
|
||||
print(f"最佳: {max(all_rewards):.1f} | 最差: {min(all_rewards):.1f}", flush=True)
|
||||
print(f"中位数: {np.median(all_rewards):.1f}", flush=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user