feat(training): add parallel-environment DQN training script and Jupyter notebook

- Add train_parallel.py, which uses AsyncVectorEnv to run multiple Atari environments in parallel
- Add the companion Jupyter notebook train_parallel.ipynb for interactive training
- Fix the observation_space definitions in the utils.py wrappers so they match the preprocessed observation shapes
- Remove the old archive CW2_DQN_SpaceInvaders.zip
- Add image.png

The parallel trainer speeds up data collection substantially through batched GPU inference and asynchronous environment stepping, and is intended for multi-core server environments. It includes full hyperparameter configuration, progress monitoring, and model checkpointing.
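The collection pattern described above, many environments stepped together with one batched forward pass per step, looks roughly like the minimal sketch below. This is illustrative only: it uses gymnasium's vector API with a CartPole stand-in environment and a placeholder `q_net`, not the committed implementation.

```python
# Illustrative sketch only: N environments advance together, one batched inference per step.
import numpy as np
import torch
import gymnasium as gym

def make_env():
    return gym.make("CartPole-v1")  # stand-in for the wrapped Atari environment

envs = gym.vector.AsyncVectorEnv([make_env for _ in range(8)])  # 8 parallel workers
q_net = torch.nn.Linear(4, 2)  # placeholder network matching CartPole's obs/action sizes

obs, _ = envs.reset(seed=0)
for _ in range(100):
    with torch.no_grad():
        q = q_net(torch.as_tensor(obs, dtype=torch.float32))  # one batched forward pass
    actions = q.argmax(dim=1).numpy()                          # greedy action per env
    obs, rewards, terminated, truncated, _ = envs.step(actions)  # all envs step in parallel
envs.close()
```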
Binary file not shown.
@@ -0,0 +1,578 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Dueling Double DQN - Space Invaders Parallel Training\n",
"\n",
"Runs multiple Atari environments in parallel with AsyncVectorEnv, accelerated by batched GPU inference.\n",
"Suited to multi-core server environments such as AutoDL."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Imports complete\n"
]
}
],
"source": [
"import sys\n",
"import os\n",
"import time\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from collections import deque\n",
"\n",
"# The parent directory of notebooks/ is the project root\n",
"sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), \"..\")))\n",
"\n",
"from src.network import QNetwork, DuelingQNetwork\n",
"from src.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer\n",
"from src.agent import DQNAgent\n",
"from src.utils import make_env, get_device\n",
"\n",
"print(\"Imports complete\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ── Environment factory ──\n",
"def _make_env_fn(env_id):\n",
"    \"\"\"Environment factory - must live at module level so multiprocessing can pickle it.\"\"\"\n",
"    # AsyncVectorEnv subprocesses need to register ALE themselves\n",
"    try:\n",
"        import ale_py\n",
"        import gymnasium as gym\n",
"        gym.register_envs(ale_py)\n",
"    except ImportError:\n",
"        pass\n",
"\n",
"    def _make():\n",
"        return make_env(env_id, gray_scale=True, resize=True, frame_stack=4)\n",
"    return _make\n",
"\n",
"print(\"Environment factory ready\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trainer ready\n"
]
}
],
"source": [
"# ── Parallel trainer ──\n",
"class ParallelTrainer:\n",
"    def __init__(\n",
"        self, agent, envs, eval_env, num_envs,\n",
"        save_dir=\"models\", eval_freq=10000, save_freq=50000,\n",
"        num_eval_episodes=10, warmup_steps=10000,\n",
"    ):\n",
"        self.agent = agent\n",
"        self.envs = envs\n",
"        self.eval_env = eval_env\n",
"        self.num_envs = num_envs\n",
"        self.save_dir = save_dir\n",
"        self.eval_freq = eval_freq\n",
"        self.save_freq = save_freq\n",
"        self.num_eval_episodes = num_eval_episodes\n",
"        self.warmup_steps = warmup_steps\n",
"        self.episode_rewards = deque(maxlen=100)\n",
"        self.eval_rewards = []\n",
"        self.best_eval_reward = -float(\"inf\")\n",
"\n",
"    def _batch_select_actions(self, states):\n",
"        epsilon = self.agent.epsilon\n",
"        n = len(states)\n",
"        random_mask = np.random.random(n) < epsilon\n",
"        actions = np.zeros(n, dtype=np.int64)\n",
"        non_random = ~random_mask\n",
"        if non_random.any():\n",
"            state_tensor = torch.from_numpy(states[non_random]).float().to(self.agent.device)\n",
"            with torch.no_grad():\n",
"                q_values = self.agent.q_network(state_tensor)\n",
"            actions[non_random] = q_values.argmax(dim=1).cpu().numpy()\n",
"        if random_mask.any():\n",
"            actions[random_mask] = np.random.randint(0, self.agent.num_actions, size=random_mask.sum())\n",
"        return actions\n",
"\n",
"    def evaluate(self):\n",
"        rewards = []\n",
"        for _ in range(self.num_eval_episodes):\n",
"            state, _ = self.eval_env.reset()\n",
"            ep_reward = 0\n",
"            done = False\n",
"            while not done:\n",
"                action = self.agent.select_action(state, evaluate=True)\n",
"                state, reward, terminated, truncated, _ = self.eval_env.step(action)\n",
"                done = terminated or truncated\n",
"                ep_reward += reward\n",
"            rewards.append(ep_reward)\n",
"        return np.mean(rewards)\n",
"\n",
"    def train(self, total_steps):\n",
"        n = self.num_envs\n",
"        device = self.agent.device\n",
"        envs = self.envs\n",
"\n",
"        print(f\"Starting training: {total_steps:,} steps, {n} parallel envs\")\n",
"        print(\"=\" * 60)\n",
"\n",
"        states, _ = envs.reset()\n",
"        ep_rewards = np.zeros(n, dtype=np.float32)\n",
"        ep_count = 0\n",
"        start_time = time.time()\n",
"        step = 0\n",
"\n",
"        while step < total_steps:\n",
"            if step < self.warmup_steps:\n",
"                actions = np.array([envs.single_action_space.sample() for _ in range(n)])\n",
"            else:\n",
"                actions = self._batch_select_actions(states)\n",
"\n",
"            next_states, rewards, terminateds, truncateds, _ = envs.step(actions)\n",
"            dones = np.logical_or(terminateds, truncateds)\n",
"\n",
"            for i in range(n):\n",
"                self.agent.replay_buffer.add(states[i], actions[i], rewards[i], next_states[i], dones[i])\n",
"\n",
"            ep_rewards += rewards\n",
"\n",
"            for i in range(n):\n",
"                if dones[i]:\n",
"                    self.episode_rewards.append(ep_rewards[i])\n",
"                    ep_count += 1\n",
"                    ep_rewards[i] = 0\n",
"\n",
"            step += n\n",
"            states = next_states\n",
"\n",
"            if step >= self.warmup_steps:\n",
"                self.agent.train_step()\n",
"\n",
"            if ep_count > 0 and ep_count % 20 == 0:\n",
"                avg_r = np.mean(self.episode_rewards) if self.episode_rewards else 0\n",
"                elapsed = time.time() - start_time\n",
"                fps = step / elapsed\n",
"                lr = self.agent.optimizer.param_groups[0][\"lr\"]\n",
"                print(f\"Step:{step:>10,} | Ep:{ep_count:>5} | AvgR:{avg_r:>7.1f} | \"\n",
"                      f\"Eps:{self.agent.epsilon:.3f} | LR:{lr:.2e} | FPS:{fps:.0f}\")\n",
"\n",
"            if step % self.eval_freq == 0 and step > 0:\n",
"                eval_r = self.evaluate()\n",
"                self.eval_rewards.append((step, eval_r))\n",
"                print(f\"\\n[Eval] Step:{step:>10,} | AvgReturn:{eval_r:.1f}\\n\")\n",
"                if eval_r > self.best_eval_reward:\n",
"                    self.best_eval_reward = eval_r\n",
"                    self.agent.save(f\"{self.save_dir}/dqn_best.pt\")\n",
"\n",
"            if step % self.save_freq == 0:\n",
"                self.agent.save(f\"{self.save_dir}/dqn_step_{step}.pt\")\n",
"\n",
"        total_time = time.time() - start_time\n",
"        print(\"\\n\" + \"=\" * 60)\n",
"        print(f\"Training complete! Total time: {total_time:.1f} s | FPS: {total_steps/total_time:.0f}\")\n",
"        print(f\"Best eval return: {self.best_eval_reward:.1f}\")\n",
"        self.agent.save(f\"{self.save_dir}/dqn_final.pt\")\n",
"\n",
"print(\"Trainer ready\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configuration\n",
"\n",
"Adjust the parameters below to match your GPU and training-time budget."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ── Tunable hyperparameters ──\n",
"\n",
"ENV_ID = \"ALE/SpaceInvaders-v5\"\n",
"N_ENVS = 24  # 25 CPU cores; leave one for the main process\n",
"TOTAL_STEPS = 10_000_000  # total environment steps\n",
"LR = 3e-5  # learning rate (a low LR is more stable with large batches)\n",
"GAMMA = 0.99  # discount factor\n",
"BATCH_SIZE = 512  # a large batch keeps an RTX 5090 busy\n",
"BUFFER_SIZE = 1_000_000  # replay buffer capacity\n",
"EPSILON_START = 1.0\n",
"EPSILON_END = 0.01\n",
"EPSILON_DECAY = 4_000_000  # epsilon decay steps (24 envs explore efficiently, so extend exploration)\n",
"TARGET_UPDATE = 2000\n",
"LR_DECAY_STEPS = 5_000_000\n",
"LR_DECAY_FACTOR = 0.5\n",
"WARMUP_STEPS = 50_000\n",
"EVAL_FREQ = 50000\n",
"EVAL_EPISODES = 10\n",
"SAVE_FREQ = 200000\n",
"SEED = 42\n",
"SAVE_DIR = \"models\"\n",
"\n",
"USE_DUELING = True\n",
"USE_DOUBLE = True\n",
"USE_PER = True  # prioritized experience replay\n",
"\n",
"print(f\"Config: {TOTAL_STEPS/1e6:.0f}M steps, {N_ENVS} parallel envs\")\n",
"print(f\"Expected env interaction: {TOTAL_STEPS * 4 / 1e6:.0f}M frames\")\n",
"print(f\"Estimated time (AutoDL 5090): ~{TOTAL_STEPS / 1000 / N_ENVS / 3600:.1f} hours\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using GPU: NVIDIA GeForce RTX 4060 Laptop GPU\n",
"SyncVectorEnv (Windows): 16 environments\n",
"Action space: 6\n"
]
}
],
"source": [
"# ── Environment setup ──\n",
"torch.manual_seed(SEED)\n",
"np.random.seed(SEED)\n",
"import platform\n",
"\n",
"device = get_device()\n",
"\n",
"# Jupyter on Windows cannot run AsyncVectorEnv subprocesses, so fall back to SyncVectorEnv\n",
"if platform.system() == \"Linux\":\n",
"    from gymnasium.vector import AsyncVectorEnv\n",
"    env_fns = [_make_env_fn(ENV_ID) for _ in range(N_ENVS)]\n",
"    envs = AsyncVectorEnv(env_fns, shared_memory=True)\n",
"    print(f\"AsyncVectorEnv: {envs.num_envs} environments\")\n",
"else:\n",
"    from gymnasium.vector import SyncVectorEnv\n",
"    env_fns = [_make_env_fn(ENV_ID) for _ in range(N_ENVS)]\n",
"    envs = SyncVectorEnv(env_fns)\n",
"    print(f\"SyncVectorEnv (Windows): {envs.num_envs} environments\")\n",
"\n",
"# Evaluation environment\n",
"eval_env = make_env(ENV_ID, gray_scale=True, resize=True, frame_stack=4)\n",
"\n",
"num_actions = envs.single_action_space.n\n",
"print(f\"Action space: {num_actions}\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dueling DQN: 3,293,863 parameters\n"
]
},
{
"data": {
"text/plain": [
"DuelingQNetwork(\n",
"  (conv): Sequential(\n",
"    (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))\n",
"    (1): ReLU()\n",
"    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))\n",
"    (3): ReLU()\n",
"    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n",
"    (5): ReLU()\n",
"  )\n",
"  (value_stream): Sequential(\n",
"    (0): Linear(in_features=3136, out_features=512, bias=True)\n",
"    (1): ReLU()\n",
"    (2): Linear(in_features=512, out_features=1, bias=True)\n",
"  )\n",
"  (advantage_stream): Sequential(\n",
"    (0): Linear(in_features=3136, out_features=512, bias=True)\n",
"    (1): ReLU()\n",
"    (2): Linear(in_features=512, out_features=6, bias=True)\n",
"  )\n",
")"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# ── Networks ──\n",
"state_shape = (4, 84, 84)\n",
"\n",
"if USE_DUELING:\n",
"    q_network = DuelingQNetwork(state_shape, num_actions).to(device)\n",
"    target_network = DuelingQNetwork(state_shape, num_actions).to(device)\n",
"    print(f\"Dueling DQN: {sum(p.numel() for p in q_network.parameters()):,} parameters\")\n",
"else:\n",
"    q_network = QNetwork(state_shape, num_actions).to(device)\n",
"    target_network = QNetwork(state_shape, num_actions).to(device)\n",
"    print(f\"Standard DQN: {sum(p.numel() for p in q_network.parameters()):,} parameters\")\n",
"\n",
"target_network.load_state_dict(q_network.state_dict())\n",
"target_network.eval()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prioritized experience replay\n",
"Agent created\n"
]
}
],
"source": [
"# ── Replay buffer + agent ──\n",
"if USE_PER:\n",
"    replay_buffer = PrioritizedReplayBuffer(BUFFER_SIZE, state_shape, device)\n",
"    print(\"Prioritized experience replay\")\n",
"else:\n",
"    replay_buffer = ReplayBuffer(BUFFER_SIZE, state_shape, device)\n",
"    print(\"Standard experience replay\")\n",
"\n",
"agent = DQNAgent(\n",
"    q_network=q_network,\n",
"    target_network=target_network,\n",
"    replay_buffer=replay_buffer,\n",
"    device=device,\n",
"    num_actions=num_actions,\n",
"    gamma=GAMMA,\n",
"    lr=LR,\n",
"    epsilon_start=EPSILON_START,\n",
"    epsilon_end=EPSILON_END,\n",
"    epsilon_decay_steps=EPSILON_DECAY,\n",
"    target_update_freq=TARGET_UPDATE,\n",
"    batch_size=BATCH_SIZE,\n",
"    double_dqn=USE_DOUBLE,\n",
"    lr_decay_steps=LR_DECAY_STEPS,\n",
"    lr_decay_factor=LR_DECAY_FACTOR,\n",
"    warmup_steps=WARMUP_STEPS,\n",
")\n",
"print(\"Agent created\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"============================================================\n",
"Starting 10M-step parallel training\n",
" GPU: cuda\n",
" Parallel envs: 16\n",
" Dueling: True\n",
" Double DQN: True\n",
" PER: True\n",
"============================================================\n",
"\n",
"Starting training: 10,000,000 steps, 16 parallel envs\n",
"============================================================\n",
"Step: 3,232 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:651\n",
"Step: 3,248 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:649\n",
"Step: 3,264 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:649\n",
"Step: 3,280 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:649\n",
"Step: 3,296 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:648\n",
"Step: 3,312 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:648\n",
"Step: 3,328 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:648\n",
"Step: 5,776 | Ep: 40 | AvgR: 12.3 | Eps:1.000 | LR:5.00e-05 | FPS:596\n",
"Step: 5,792 | Ep: 40 | AvgR: 12.3 | Eps:1.000 | LR:5.00e-05 | FPS:596\n",
"Step: 8,144 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
"Step: 8,160 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 8,176 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 8,192 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 8,208 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
"Step: 8,224 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
"Step: 8,240 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
"Step: 8,256 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
"Step: 8,272 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
"Step: 8,288 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
"Step: 8,304 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
"Step: 8,320 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 8,336 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 8,352 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 8,368 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 8,384 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 8,400 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 8,416 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 8,432 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
"Step: 10,912 | Ep: 80 | AvgR: 13.2 | Eps:1.000 | LR:2.90e-07 | FPS:520\n",
"Step: 10,928 | Ep: 80 | AvgR: 13.2 | Eps:1.000 | LR:2.95e-07 | FPS:520\n",
"Step: 10,944 | Ep: 80 | AvgR: 13.2 | Eps:1.000 | LR:3.00e-07 | FPS:519\n",
"Step: 10,960 | Ep: 80 | AvgR: 13.2 | Eps:1.000 | LR:3.05e-07 | FPS:519\n",
"Step: 13,280 | Ep: 100 | AvgR: 13.0 | Eps:1.000 | LR:1.03e-06 | FPS:481\n",
"Step: 13,296 | Ep: 100 | AvgR: 13.0 | Eps:1.000 | LR:1.04e-06 | FPS:481\n",
"Step: 13,312 | Ep: 100 | AvgR: 13.0 | Eps:1.000 | LR:1.04e-06 | FPS:481\n",
"Step: 13,328 | Ep: 100 | AvgR: 13.0 | Eps:1.000 | LR:1.04e-06 | FPS:481\n",
"Step: 15,648 | Ep: 120 | AvgR: 12.8 | Eps:1.000 | LR:1.77e-06 | FPS:454\n",
"Step: 21,184 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.50e-06 | FPS:434\n",
"Step: 21,200 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.50e-06 | FPS:434\n",
"Step: 21,216 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.51e-06 | FPS:434\n",
"Step: 21,232 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.52e-06 | FPS:434\n",
"Step: 21,248 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.52e-06 | FPS:434\n",
"Step: 21,264 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.52e-06 | FPS:434\n",
"Step: 21,280 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.53e-06 | FPS:434\n",
"Step: 21,296 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.54e-06 | FPS:434\n",
"Step: 23,824 | Ep: 180 | AvgR: 14.4 | Eps:1.000 | LR:4.33e-06 | FPS:432\n",
"Step: 26,144 | Ep: 200 | AvgR: 14.3 | Eps:1.000 | LR:5.05e-06 | FPS:422\n"
]
}
],
"source": [
"# ── Start training ──\n",
"trainer = ParallelTrainer(\n",
"    agent=agent,\n",
"    envs=envs,\n",
"    eval_env=eval_env,\n",
"    num_envs=N_ENVS,\n",
"    save_dir=SAVE_DIR,\n",
"    eval_freq=EVAL_FREQ,\n",
"    save_freq=SAVE_FREQ,\n",
"    num_eval_episodes=EVAL_EPISODES,\n",
"    warmup_steps=WARMUP_STEPS,\n",
")\n",
"\n",
"print(\"\\n\" + \"=\" * 60)\n",
"print(f\"Starting 10M-step parallel training\")\n",
"print(f\" GPU: {device}\")\n",
"print(f\" Parallel envs: {N_ENVS}\")\n",
"print(f\" Dueling: {USE_DUELING}\")\n",
"print(f\" Double DQN: {USE_DOUBLE}\")\n",
"print(f\" PER: {USE_PER}\")\n",
"print(\"=\" * 60 + \"\\n\")\n",
"\n",
"trainer.train(TOTAL_STEPS)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## After training: evaluate the best model"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading best model...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\Code\\doing_exercises\\programs\\外教作业外快\\强化学习个人项目报告(Atari 游戏方向)\\src\\agent.py:219: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" checkpoint = torch.load(path, map_location=self.device)\n"
]
},
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] No such file or directory: 'models/dqn_best.pt'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[9], line 3\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;66;03m# ── 评估最佳模型 ──\u001b[39;00m\n\u001b[0;32m 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m加载最佳模型...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m----> 3\u001b[0m \u001b[43magent\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mSAVE_DIR\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m/dqn_best.pt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 5\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m评估中...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 6\u001b[0m all_rewards \u001b[38;5;241m=\u001b[39m []\n",
"File \u001b[1;32md:\\Code\\doing_exercises\\programs\\外教作业外快\\强化学习个人项目报告(Atari 游戏方向)\\src\\agent.py:219\u001b[0m, in \u001b[0;36mDQNAgent.load\u001b[1;34m(self, path)\u001b[0m\n\u001b[0;32m 217\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mload\u001b[39m(\u001b[38;5;28mself\u001b[39m, path):\n\u001b[0;32m 218\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"加载模型\"\"\"\u001b[39;00m\n\u001b[1;32m--> 219\u001b[0m checkpoint \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmap_location\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 220\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mq_network\u001b[38;5;241m.\u001b[39mload_state_dict(checkpoint[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mq_network\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[0;32m 221\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtarget_network\u001b[38;5;241m.\u001b[39mload_state_dict(checkpoint[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtarget_network\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n",
"File \u001b[1;32md:\\ProgramData\\anaconda3\\envs\\my_env\\lib\\site-packages\\torch\\serialization.py:1065\u001b[0m, in \u001b[0;36mload\u001b[1;34m(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)\u001b[0m\n\u001b[0;32m 1062\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mencoding\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m pickle_load_args\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[0;32m 1063\u001b[0m pickle_load_args[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mencoding\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mutf-8\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m-> 1065\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43m_open_file_like\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mrb\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m opened_file:\n\u001b[0;32m 1066\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _is_zipfile(opened_file):\n\u001b[0;32m 1067\u001b[0m \u001b[38;5;66;03m# The zipfile reader is going to advance the current file position.\u001b[39;00m\n\u001b[0;32m 1068\u001b[0m \u001b[38;5;66;03m# If we want to actually tail call to torch.jit.load, we need to\u001b[39;00m\n\u001b[0;32m 1069\u001b[0m \u001b[38;5;66;03m# reset back to the original position.\u001b[39;00m\n\u001b[0;32m 1070\u001b[0m orig_position \u001b[38;5;241m=\u001b[39m opened_file\u001b[38;5;241m.\u001b[39mtell()\n",
"File \u001b[1;32md:\\ProgramData\\anaconda3\\envs\\my_env\\lib\\site-packages\\torch\\serialization.py:468\u001b[0m, in \u001b[0;36m_open_file_like\u001b[1;34m(name_or_buffer, mode)\u001b[0m\n\u001b[0;32m 466\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_open_file_like\u001b[39m(name_or_buffer, mode):\n\u001b[0;32m 467\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _is_path(name_or_buffer):\n\u001b[1;32m--> 468\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_open_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 469\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 470\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01min\u001b[39;00m mode:\n",
"File \u001b[1;32md:\\ProgramData\\anaconda3\\envs\\my_env\\lib\\site-packages\\torch\\serialization.py:449\u001b[0m, in \u001b[0;36m_open_file.__init__\u001b[1;34m(self, name, mode)\u001b[0m\n\u001b[0;32m 448\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, name, mode):\n\u001b[1;32m--> 449\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m)\n",
"\u001b[1;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'models/dqn_best.pt'"
]
}
],
"source": [
"# ── Evaluate the best model ──\n",
"print(\"Loading best model...\")\n",
"agent.load(f\"{SAVE_DIR}/dqn_best.pt\")\n",
"\n",
"print(\"\\nEvaluating...\")\n",
"all_rewards = []\n",
"for i in range(20):\n",
"    state, _ = eval_env.reset()\n",
"    ep_r = 0\n",
"    done = False\n",
"    while not done:\n",
"        action = agent.select_action(state, evaluate=True)\n",
"        state, reward, terminated, truncated, _ = eval_env.step(action)\n",
"        done = terminated or truncated\n",
"        ep_r += reward\n",
"    all_rewards.append(ep_r)\n",
"    print(f\" Episode {i+1:>2}: {ep_r:.1f}\")\n",
"\n",
"print(f\"\\nResult: mean {np.mean(all_rewards):.2f} ± {np.std(all_rewards):.2f}\")\n",
"print(f\"Best: {max(all_rewards):.1f} | Worst: {min(all_rewards):.1f}\")\n",
"print(f\"Median: {np.median(all_rewards):.1f}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "my_env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.20"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
@@ -17,6 +17,10 @@ class GrayScaleWrapper(gym.ObservationWrapper):

    def __init__(self, env):
        super().__init__(env)
        old_shape = self.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=old_shape[:2], dtype=np.uint8
        )

    def observation(self, obs):
        # RGB to grayscale: weighted average
@@ -30,6 +34,15 @@ class ResizeWrapper(gym.ObservationWrapper):
    def __init__(self, env, size=(84, 84)):
        super().__init__(env)
        self.size = size
        obs_shape = self.observation_space.shape
        if len(obs_shape) == 3:
            self.observation_space = gym.spaces.Box(
                low=0, high=255, shape=(*size, obs_shape[-1]), dtype=np.uint8
            )
        else:
            self.observation_space = gym.spaces.Box(
                low=0, high=255, shape=size, dtype=np.uint8
            )

    def observation(self, obs):
        import cv2

@@ -0,0 +1,371 @@
"""Parallel-environment DQN training script - uses AsyncVectorEnv to speed up data collection.

Each training iteration collects transitions from N environments in parallel and batches GPU inference, which raises FPS significantly.
Intended for multi-core server + GPU environments such as AutoDL.
"""

import sys
import os

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import argparse
import time
import numpy as np
import torch
import torch.nn.functional as F
from collections import deque

from src.network import QNetwork, DuelingQNetwork
from src.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
from src.utils import make_env, get_device


# ── Environment factory (used by AsyncVectorEnv subprocesses) ──

def _make_env_fn(env_id):
    """Environment factory - must live at module level so multiprocessing can pickle it."""
    # AsyncVectorEnv subprocesses need to register ALE themselves
    try:
        import ale_py
        import gymnasium as gym
        gym.register_envs(ale_py)
    except ImportError:
        pass

    def _make():
        return make_env(env_id, gray_scale=True, resize=True, frame_stack=4)
    return _make


# ── Parallel trainer ──

class ParallelTrainer:
    """Parallel-environment DQN trainer.

    Runs N environments in parallel with AsyncVectorEnv,
    collecting transitions together and batching inference for a large speedup.
    """

    def __init__(
        self,
        agent,
        envs,
        eval_env,
        num_envs,
        save_dir="models",
        eval_freq=10000,
        save_freq=50000,
        num_eval_episodes=10,
        warmup_steps=10000,
        n_steps_per_env=1,
    ):
        self.agent = agent
        self.envs = envs
        self.eval_env = eval_env
        self.num_envs = num_envs
        self.save_dir = save_dir
        self.eval_freq = eval_freq
        self.save_freq = save_freq
        self.num_eval_episodes = num_eval_episodes
        self.warmup_steps = warmup_steps
        self.n_steps_per_env = n_steps_per_env

        self.episode_rewards = deque(maxlen=100)
        self.eval_rewards = []
        self.best_eval_reward = -float("inf")

    def train(self, total_steps):
        """Main parallel training loop.

        Args:
            total_steps: total number of environment interaction steps
        """
        num_envs = self.num_envs
        device = self.agent.device
        envs = self.envs

        print(f"Starting parallel training, total steps: {total_steps:,}")
        print(f"Parallel environments: {num_envs}")
        print(f"Warmup steps: {self.warmup_steps:,}")
        print("=" * 60)

        # Reset all environments
        states, _ = envs.reset()
        episode_rewards = np.zeros(num_envs, dtype=np.float32)
        episode_lengths = np.zeros(num_envs, dtype=np.int32)
        episode_count = 0
        start_time = time.time()

        step = 0
        while step < total_steps:
            # ── Action selection ──
            if step < self.warmup_steps:
                actions = np.array([envs.single_action_space.sample() for _ in range(num_envs)])
            else:
                actions = self._batch_select_actions(states)

            # ── Environment step (N environments in parallel) ──
            next_states, rewards, terminateds, truncateds, _ = envs.step(actions)
            dones = np.logical_or(terminateds, truncateds)

            # ── Store transitions ──
            for i in range(num_envs):
                self.agent.replay_buffer.add(
                    states[i], actions[i], rewards[i], next_states[i], dones[i]
                )

            # ── Statistics ──
            episode_rewards += rewards
            episode_lengths += 1

            # Handle finished episodes
            for i in range(num_envs):
                if dones[i]:
                    self.episode_rewards.append(episode_rewards[i])
                    episode_count += 1
                    episode_rewards[i] = 0
                    episode_lengths[i] = 0

            step += num_envs
            states = next_states

            # ── Train (one mini-batch per vector-env step) ──
            if step >= self.warmup_steps:
                self.agent.train_step()

            # ── Progress logging ──
            if episode_count > 0 and episode_count % 10 == 0:
                avg_reward = np.mean(self.episode_rewards) if self.episode_rewards else 0
                elapsed = time.time() - start_time
                fps = step / elapsed

                current_lr = self.agent.optimizer.param_groups[0]["lr"]
                print(
                    f"Step: {step:>10,} | "
                    f"Ep: {episode_count:>5} | "
                    f"AvgReward: {avg_reward:>7.1f} | "
                    f"Epsilon: {self.agent.epsilon:.3f} | "
                    f"LR: {current_lr:.2e} | "
                    f"FPS: {fps:.0f}"
                )

            # ── Periodic evaluation ──
            if step % self.eval_freq == 0 and step > 0:
                eval_reward = self.evaluate()
                self.eval_rewards.append((step, eval_reward))
                print(f"\n[Eval] Step: {step:>10,} | AvgReward: {eval_reward:.1f}\n" + "-" * 60)

                if eval_reward > self.best_eval_reward:
                    self.best_eval_reward = eval_reward
                    self.agent.save(f"{self.save_dir}/dqn_best.pt")

            # ── Periodic checkpointing ──
            if step % self.save_freq == 0:
                self.agent.save(f"{self.save_dir}/dqn_step_{step}.pt")

        # Training finished
        total_time = time.time() - start_time
        print("\n" + "=" * 60)
        print(f"Training complete! Total time: {total_time:.1f} s")
        print(f"Average FPS: {total_steps / total_time:.0f}")
        print(f"Best eval return: {self.best_eval_reward:.1f}")

        self.agent.save(f"{self.save_dir}/dqn_final.pt")

    def _batch_select_actions(self, states):
        """Select actions for all environments at once (batched GPU inference)."""
        epsilon = self.agent.epsilon
        num_envs = len(states)

        # Random exploration
        random_mask = np.random.random(num_envs) < epsilon

        actions = np.zeros(num_envs, dtype=np.int64)

        # Batched inference for the non-random environments
        non_random = ~random_mask
        if non_random.any():
            state_tensor = (
                torch.from_numpy(states[non_random]).float().to(self.agent.device)
            )
            with torch.no_grad():
                q_values = self.agent.q_network(state_tensor)
            actions[non_random] = q_values.argmax(dim=1).cpu().numpy()

        # Random environments
        if random_mask.any():
            actions[random_mask] = np.random.randint(
                0, self.agent.num_actions, size=random_mask.sum()
            )

        return actions

    def evaluate(self):
        """Evaluate the agent."""
        rewards = []
        for _ in range(self.num_eval_episodes):
            state, _ = self.eval_env.reset()
            episode_reward = 0
            done = False

            while not done:
                action = self.agent.select_action(state, evaluate=True)
                state, reward, terminated, truncated, _ = self.eval_env.step(action)
                done = terminated or truncated
                episode_reward += reward

            rewards.append(episode_reward)

        return np.mean(rewards)


# ── Main entry point ──

def main():
    parser = argparse.ArgumentParser(description="Parallel DQN for Space Invaders")

    # Parallelism
    parser.add_argument("--n-envs", type=int, default=8, help="number of parallel environments")

    # Training
    parser.add_argument("--env", type=str, default="ALE/SpaceInvaders-v5")
    parser.add_argument("--steps", type=int, default=10_000_000, help="total training steps")
    parser.add_argument("--lr", type=float, default=5e-5, help="learning rate")
    parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
    parser.add_argument("--batch-size", type=int, default=64, help="batch size")
    parser.add_argument("--buffer-size", type=int, default=500_000, help="replay buffer size")

    # ε-greedy
    parser.add_argument("--epsilon-start", type=float, default=1.0)
    parser.add_argument("--epsilon-end", type=float, default=0.01)
    parser.add_argument("--epsilon-decay", type=int, default=2_000_000)

    # Network
    parser.add_argument("--target-update", type=int, default=1000)
    parser.add_argument("--double-dqn", action="store_true", default=True)
    parser.add_argument("--dueling", action="store_true", default=True)

    # Learning-rate schedule
    parser.add_argument("--lr-decay-steps", type=int, default=5_000_000)
    parser.add_argument("--lr-decay-factor", type=float, default=0.5)
    parser.add_argument("--warmup-steps", type=int, default=10_000)

    # Evaluation
    parser.add_argument("--eval-freq", type=int, default=50000)
    parser.add_argument("--eval-episodes", type=int, default=10)
    parser.add_argument("--save-freq", type=int, default=100000)

    # Prioritized replay
    parser.add_argument("--prioritized", action="store_true", default=True)

    # Misc
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--save-dir", type=str, default="models")
    parser.add_argument("--log-dir", type=str, default="logs")

    args = parser.parse_args()

    # Random seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Device
    device = get_device()

    # Create the parallel training environments
    print(f"Creating {args.n_envs} parallel training environments...")
    try:
        from gymnasium.vector import AsyncVectorEnv
        env_fns = [_make_env_fn(args.env) for _ in range(args.n_envs)]
        envs = AsyncVectorEnv(env_fns, shared_memory=True)
    except ImportError:
        print("AsyncVectorEnv unavailable, falling back to SyncVectorEnv")
        from gymnasium.vector import SyncVectorEnv
        env_fns = [_make_env_fn(args.env) for _ in range(args.n_envs)]
        envs = SyncVectorEnv(env_fns)

    # Create the evaluation environment (single env)
    eval_env = make_env(args.env, gray_scale=True, resize=True, frame_stack=4)

    num_actions = envs.single_action_space.n
    print(f"Action space: {num_actions}")
    print(f"Actual number of environments: {envs.num_envs}")

    state_shape = (4, 84, 84)

    # Create the networks
    if args.dueling:
        print("Using Dueling Double DQN")
        q_network = DuelingQNetwork(state_shape, num_actions).to(device)
        target_network = DuelingQNetwork(state_shape, num_actions).to(device)
    else:
        print("Using standard DQN")
        q_network = QNetwork(state_shape, num_actions).to(device)
        target_network = QNetwork(state_shape, num_actions).to(device)

    target_network.load_state_dict(q_network.state_dict())
    target_network.eval()

    print(f"Network parameters: {sum(p.numel() for p in q_network.parameters()):,}")

    # Replay buffer
    if args.prioritized:
        print("Using prioritized experience replay")
        replay_buffer = PrioritizedReplayBuffer(args.buffer_size, state_shape, device)
    else:
        print("Using standard experience replay")
        replay_buffer = ReplayBuffer(args.buffer_size, state_shape, device)

    # Create the agent
    from src.agent import DQNAgent

    agent = DQNAgent(
        q_network=q_network,
        target_network=target_network,
        replay_buffer=replay_buffer,
        device=device,
        num_actions=num_actions,
        gamma=args.gamma,
        lr=args.lr,
        epsilon_start=args.epsilon_start,
        epsilon_end=args.epsilon_end,
        epsilon_decay_steps=args.epsilon_decay,
        target_update_freq=args.target_update,
        batch_size=args.batch_size,
        double_dqn=args.double_dqn,
        lr_decay_steps=args.lr_decay_steps,
        lr_decay_factor=args.lr_decay_factor,
        warmup_steps=args.warmup_steps,
    )

    # Create the trainer
    trainer = ParallelTrainer(
        agent=agent,
        envs=envs,
        eval_env=eval_env,
        num_envs=args.n_envs,
        save_dir=args.save_dir,
        eval_freq=args.eval_freq,
        save_freq=args.save_freq,
        num_eval_episodes=args.eval_episodes,
        warmup_steps=args.warmup_steps,
    )

    # Print the configuration
    print("\nTraining configuration:")
    print(f"  Parallel envs: {args.n_envs}")
    print(f"  Total steps: {args.steps:,}")
    print(f"  Learning rate: {args.lr} (warmup: {args.warmup_steps:,} steps)")
    print(f"  Epsilon decay: {args.epsilon_start} -> {args.epsilon_end} ({args.epsilon_decay:,} steps)")
    print(f"  Batch size: {args.batch_size}")
    print(f"  Buffer size: {args.buffer_size:,}")
    print(f"  Double DQN: {args.double_dqn}")
    print(f"  Dueling: {args.dueling}")
    print("=" * 60)

    trainer.train(args.steps)


if __name__ == "__main__":
    main()