ed0822966b
- 新增 train_parallel.py 脚本,使用 AsyncVectorEnv 并行运行多个Atari环境 - 添加配套的 Jupyter 笔记本 train_parallel.ipynb 用于交互式训练 - 在 utils.py 的 wrapper 中修复 observation_space 定义,确保与预处理后的观测形状一致 - 删除旧的压缩文件 CW2_DQN_SpaceInvaders.zip - 新增图片文件 image.png 并行训练器通过批量GPU推理和异步环境步进显著提升数据收集速度,适合在多核服务器环境下运行。包含完整的超参数配置、进度监控和模型保存功能。
579 lines
28 KiB
Plaintext
579 lines
28 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Dueling Double DQN - Space Invaders 并行训练\n",
|
|
"\n",
|
|
"使用 AsyncVectorEnv 并行运行多个 Atari 环境,GPU 批量推理加速。\n",
|
|
"适合在 AutoDL 等多核服务器环境运行。"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"导入完成\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import sys\n",
|
|
"import os\n",
|
|
"import time\n",
|
|
"import numpy as np\n",
|
|
"import torch\n",
|
|
"import torch.nn.functional as F\n",
|
|
"from collections import deque\n",
|
|
"\n",
|
|
"# notebooks/ 的上级目录即项目根目录\n",
|
|
"sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), \"..\")))\n",
|
|
"\n",
|
|
"from src.network import QNetwork, DuelingQNetwork\n",
|
|
"from src.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer\n",
|
|
"from src.agent import DQNAgent\n",
|
|
"from src.utils import make_env, get_device\n",
|
|
"\n",
|
|
"print(\"导入完成\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# ── 环境工厂 ──\n",
|
|
"def _make_env_fn(env_id):\n",
|
|
" \"\"\"环境工厂 - 必须在模块级别以便 multiprocessing pickle.\"\"\"\n",
|
|
" # AsyncVectorEnv 子进程需要独立注册 ALE\n",
|
|
" try:\n",
|
|
" import ale_py\n",
|
|
" import gymnasium as gym\n",
|
|
" gym.register_envs(ale_py)\n",
|
|
" except ImportError:\n",
|
|
" pass\n",
|
|
"\n",
|
|
" def _make():\n",
|
|
" return make_env(env_id, gray_scale=True, resize=True, frame_stack=4)\n",
|
|
" return _make\n",
|
|
"\n",
|
|
"print(\"环境工厂就绪\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"训练器就绪\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# ── 并行训练器 ──\n",
|
|
"class ParallelTrainer:\n",
|
|
" def __init__(\n",
|
|
" self, agent, envs, eval_env, num_envs,\n",
|
|
" save_dir=\"models\", eval_freq=10000, save_freq=50000,\n",
|
|
" num_eval_episodes=10, warmup_steps=10000,\n",
|
|
" ):\n",
|
|
" self.agent = agent\n",
|
|
" self.envs = envs\n",
|
|
" self.eval_env = eval_env\n",
|
|
" self.num_envs = num_envs\n",
|
|
" self.save_dir = save_dir\n",
|
|
" self.eval_freq = eval_freq\n",
|
|
" self.save_freq = save_freq\n",
|
|
" self.num_eval_episodes = num_eval_episodes\n",
|
|
" self.warmup_steps = warmup_steps\n",
|
|
" self.episode_rewards = deque(maxlen=100)\n",
|
|
" self.eval_rewards = []\n",
|
|
" self.best_eval_reward = -float(\"inf\")\n",
|
|
"\n",
|
|
" def _batch_select_actions(self, states):\n",
|
|
" epsilon = self.agent.epsilon\n",
|
|
" n = len(states)\n",
|
|
" random_mask = np.random.random(n) < epsilon\n",
|
|
" actions = np.zeros(n, dtype=np.int64)\n",
|
|
" non_random = ~random_mask\n",
|
|
" if non_random.any():\n",
|
|
" state_tensor = torch.from_numpy(states[non_random]).float().to(self.agent.device)\n",
|
|
" with torch.no_grad():\n",
|
|
" q_values = self.agent.q_network(state_tensor)\n",
|
|
" actions[non_random] = q_values.argmax(dim=1).cpu().numpy()\n",
|
|
" if random_mask.any():\n",
|
|
" actions[random_mask] = np.random.randint(0, self.agent.num_actions, size=random_mask.sum())\n",
|
|
" return actions\n",
|
|
"\n",
|
|
" def evaluate(self):\n",
|
|
" rewards = []\n",
|
|
" for _ in range(self.num_eval_episodes):\n",
|
|
" state, _ = self.eval_env.reset()\n",
|
|
" ep_reward = 0\n",
|
|
" done = False\n",
|
|
" while not done:\n",
|
|
" action = self.agent.select_action(state, evaluate=True)\n",
|
|
" state, reward, terminated, truncated, _ = self.eval_env.step(action)\n",
|
|
" done = terminated or truncated\n",
|
|
" ep_reward += reward\n",
|
|
" rewards.append(ep_reward)\n",
|
|
" return np.mean(rewards)\n",
|
|
"\n",
|
|
" def train(self, total_steps):\n",
|
|
" n = self.num_envs\n",
|
|
" device = self.agent.device\n",
|
|
" envs = self.envs\n",
|
|
"\n",
|
|
" print(f\"开始训练: {total_steps:,} 步, {n} 并行环境\")\n",
|
|
" print(\"=\" * 60)\n",
|
|
"\n",
|
|
" states, _ = envs.reset()\n",
|
|
" ep_rewards = np.zeros(n, dtype=np.float32)\n",
|
|
" ep_count = 0\n",
|
|
" start_time = time.time()\n",
|
|
" step = 0\n",
|
|
"\n",
|
|
" while step < total_steps:\n",
|
|
" if step < self.warmup_steps:\n",
|
|
" actions = np.array([envs.single_action_space.sample() for _ in range(n)])\n",
|
|
" else:\n",
|
|
" actions = self._batch_select_actions(states)\n",
|
|
"\n",
|
|
" next_states, rewards, terminateds, truncateds, _ = envs.step(actions)\n",
|
|
" dones = np.logical_or(terminateds, truncateds)\n",
|
|
"\n",
|
|
" for i in range(n):\n",
|
|
" self.agent.replay_buffer.add(states[i], actions[i], rewards[i], next_states[i], dones[i])\n",
|
|
"\n",
|
|
" ep_rewards += rewards\n",
|
|
"\n",
|
|
" for i in range(n):\n",
|
|
" if dones[i]:\n",
|
|
" self.episode_rewards.append(ep_rewards[i])\n",
|
|
" ep_count += 1\n",
|
|
" ep_rewards[i] = 0\n",
|
|
"\n",
|
|
" step += n\n",
|
|
" states = next_states\n",
|
|
"\n",
|
|
" if step >= self.warmup_steps:\n",
|
|
" self.agent.train_step()\n",
|
|
"\n",
|
|
" if ep_count > 0 and ep_count % 20 == 0:\n",
|
|
" avg_r = np.mean(self.episode_rewards) if self.episode_rewards else 0\n",
|
|
" elapsed = time.time() - start_time\n",
|
|
" fps = step / elapsed\n",
|
|
" lr = self.agent.optimizer.param_groups[0][\"lr\"]\n",
|
|
" print(f\"Step:{step:>10,} | Ep:{ep_count:>5} | AvgR:{avg_r:>7.1f} | \"\n",
|
|
" f\"Eps:{self.agent.epsilon:.3f} | LR:{lr:.2e} | FPS:{fps:.0f}\")\n",
|
|
"\n",
|
|
" if step % self.eval_freq == 0 and step > 0:\n",
|
|
" eval_r = self.evaluate()\n",
|
|
" self.eval_rewards.append((step, eval_r))\n",
|
|
" print(f\"\\n[评估] Step:{step:>10,} | 平均回报:{eval_r:.1f}\\n\")\n",
|
|
" if eval_r > self.best_eval_reward:\n",
|
|
" self.best_eval_reward = eval_r\n",
|
|
" self.agent.save(f\"{self.save_dir}/dqn_best.pt\")\n",
|
|
"\n",
|
|
" if step % self.save_freq == 0:\n",
|
|
" self.agent.save(f\"{self.save_dir}/dqn_step_{step}.pt\")\n",
|
|
"\n",
|
|
" total_time = time.time() - start_time\n",
|
|
" print(\"\\n\" + \"=\" * 60)\n",
|
|
" print(f\"训练完成!总时间: {total_time:.1f} 秒 | FPS: {total_steps/total_time:.0f}\")\n",
|
|
" print(f\"最佳评估回报: {self.best_eval_reward:.1f}\")\n",
|
|
" self.agent.save(f\"{self.save_dir}/dqn_final.pt\")\n",
|
|
"\n",
|
|
"print(\"训练器就绪\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 配置参数\n",
|
|
"\n",
|
|
"根据 GPU 和预期训练时间调整以下参数。"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# ── 可修改的超参数 ──\n",
|
|
"\n",
|
|
"ENV_ID = \"ALE/SpaceInvaders-v5\"\n",
|
|
"N_ENVS = 24 # 25 核 CPU,留 1 核给主进程\n",
|
|
"TOTAL_STEPS = 10_000_000 # 总步数\n",
|
|
"LR = 3e-5 # 学习率(大 batch 配低 lr 更稳定)\n",
|
|
"GAMMA = 0.99 # 折扣因子\n",
|
|
"BATCH_SIZE = 512 # RTX 5090 跑大 batch 才不浪费\n",
|
|
"BUFFER_SIZE = 1_000_000 # 回放缓冲区\n",
|
|
"EPSILON_START = 1.0\n",
|
|
"EPSILON_END = 0.01\n",
|
|
"EPSILON_DECAY = 4_000_000 # ε衰减步数(24 环境探索效率高,延长探索期)\n",
|
|
"TARGET_UPDATE = 2000\n",
|
|
"LR_DECAY_STEPS = 5_000_000\n",
|
|
"LR_DECAY_FACTOR = 0.5\n",
|
|
"WARMUP_STEPS = 50_000\n",
|
|
"EVAL_FREQ = 50000\n",
|
|
"EVAL_EPISODES = 10\n",
|
|
"SAVE_FREQ = 200000\n",
|
|
"SEED = 42\n",
|
|
"SAVE_DIR = \"models\"\n",
|
|
"\n",
|
|
"USE_DUELING = True\n",
|
|
"USE_DOUBLE = True\n",
|
|
"USE_PER = True # 优先经验回放\n",
|
|
"\n",
|
|
"print(f\"配置: {TOTAL_STEPS/1e6:.0f}M 步, {N_ENVS} 并行环境\")\n",
|
|
"print(f\"预计环境交互: {TOTAL_STEPS * 4 / 1e6:.0f}M frames\")\n",
|
|
"print(f\"预计时间 (AutoDL 5090): ~{TOTAL_STEPS / 1000 / N_ENVS / 3600:.1f} 小时\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"使用GPU: NVIDIA GeForce RTX 4060 Laptop GPU\n",
|
|
"SyncVectorEnv (Windows): 16 个环境\n",
|
|
"动作空间: 6\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# ── 环境设置 ──\n",
|
|
"torch.manual_seed(SEED)\n",
|
|
"np.random.seed(SEED)\n",
|
|
"import platform\n",
|
|
"\n",
|
|
"device = get_device()\n",
|
|
"\n",
|
|
"# Windows Jupyter 不支持 AsyncVectorEnv 子进程,用 SyncVectorEnv 替代\n",
|
|
"if platform.system() == \"Linux\":\n",
|
|
" from gymnasium.vector import AsyncVectorEnv\n",
|
|
" env_fns = [_make_env_fn(ENV_ID) for _ in range(N_ENVS)]\n",
|
|
" envs = AsyncVectorEnv(env_fns, shared_memory=True)\n",
|
|
" print(f\"AsyncVectorEnv: {envs.num_envs} 个环境\")\n",
|
|
"else:\n",
|
|
" from gymnasium.vector import SyncVectorEnv\n",
|
|
" env_fns = [_make_env_fn(ENV_ID) for _ in range(N_ENVS)]\n",
|
|
" envs = SyncVectorEnv(env_fns)\n",
|
|
" print(f\"SyncVectorEnv (Windows): {envs.num_envs} 个环境\")\n",
|
|
"\n",
|
|
"# 评估环境\n",
|
|
"eval_env = make_env(ENV_ID, gray_scale=True, resize=True, frame_stack=4)\n",
|
|
"\n",
|
|
"num_actions = envs.single_action_space.n\n",
|
|
"print(f\"动作空间: {num_actions}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Dueling DQN: 3,293,863 参数\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"DuelingQNetwork(\n",
|
|
" (conv): Sequential(\n",
|
|
" (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))\n",
|
|
" (1): ReLU()\n",
|
|
" (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))\n",
|
|
" (3): ReLU()\n",
|
|
" (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n",
|
|
" (5): ReLU()\n",
|
|
" )\n",
|
|
" (value_stream): Sequential(\n",
|
|
" (0): Linear(in_features=3136, out_features=512, bias=True)\n",
|
|
" (1): ReLU()\n",
|
|
" (2): Linear(in_features=512, out_features=1, bias=True)\n",
|
|
" )\n",
|
|
" (advantage_stream): Sequential(\n",
|
|
" (0): Linear(in_features=3136, out_features=512, bias=True)\n",
|
|
" (1): ReLU()\n",
|
|
" (2): Linear(in_features=512, out_features=6, bias=True)\n",
|
|
" )\n",
|
|
")"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# ── 网络 ──\n",
|
|
"state_shape = (4, 84, 84)\n",
|
|
"\n",
|
|
"if USE_DUELING:\n",
|
|
" q_network = DuelingQNetwork(state_shape, num_actions).to(device)\n",
|
|
" target_network = DuelingQNetwork(state_shape, num_actions).to(device)\n",
|
|
" print(f\"Dueling DQN: {sum(p.numel() for p in q_network.parameters()):,} 参数\")\n",
|
|
"else:\n",
|
|
" q_network = QNetwork(state_shape, num_actions).to(device)\n",
|
|
" target_network = QNetwork(state_shape, num_actions).to(device)\n",
|
|
" print(f\"标准 DQN: {sum(p.numel() for p in q_network.parameters()):,} 参数\")\n",
|
|
"\n",
|
|
"target_network.load_state_dict(q_network.state_dict())\n",
|
|
"target_network.eval()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"优先经验回放\n",
|
|
"Agent 创建完成\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# ── 回放缓冲区 + Agent ──\n",
|
|
"if USE_PER:\n",
|
|
" replay_buffer = PrioritizedReplayBuffer(BUFFER_SIZE, state_shape, device)\n",
|
|
" print(\"优先经验回放\")\n",
|
|
"else:\n",
|
|
" replay_buffer = ReplayBuffer(BUFFER_SIZE, state_shape, device)\n",
|
|
" print(\"标准经验回放\")\n",
|
|
"\n",
|
|
"agent = DQNAgent(\n",
|
|
" q_network=q_network,\n",
|
|
" target_network=target_network,\n",
|
|
" replay_buffer=replay_buffer,\n",
|
|
" device=device,\n",
|
|
" num_actions=num_actions,\n",
|
|
" gamma=GAMMA,\n",
|
|
" lr=LR,\n",
|
|
" epsilon_start=EPSILON_START,\n",
|
|
" epsilon_end=EPSILON_END,\n",
|
|
" epsilon_decay_steps=EPSILON_DECAY,\n",
|
|
" target_update_freq=TARGET_UPDATE,\n",
|
|
" batch_size=BATCH_SIZE,\n",
|
|
" double_dqn=USE_DOUBLE,\n",
|
|
" lr_decay_steps=LR_DECAY_STEPS,\n",
|
|
" lr_decay_factor=LR_DECAY_FACTOR,\n",
|
|
" warmup_steps=WARMUP_STEPS,\n",
|
|
")\n",
|
|
"print(\"Agent 创建完成\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"============================================================\n",
|
|
"开始 10M 步并行训练\n",
|
|
" GPU: cuda\n",
|
|
" 并行环境: 16\n",
|
|
" Dueling: True\n",
|
|
" Double DQN: True\n",
|
|
" PER: True\n",
|
|
"============================================================\n",
|
|
"\n",
|
|
"开始训练: 10,000,000 步, 16 并行环境\n",
|
|
"============================================================\n",
|
|
"Step: 3,232 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:651\n",
|
|
"Step: 3,248 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:649\n",
|
|
"Step: 3,264 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:649\n",
|
|
"Step: 3,280 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:649\n",
|
|
"Step: 3,296 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:648\n",
|
|
"Step: 3,312 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:648\n",
|
|
"Step: 3,328 | Ep: 20 | AvgR: 12.4 | Eps:1.000 | LR:5.00e-05 | FPS:648\n",
|
|
"Step: 5,776 | Ep: 40 | AvgR: 12.3 | Eps:1.000 | LR:5.00e-05 | FPS:596\n",
|
|
"Step: 5,792 | Ep: 40 | AvgR: 12.3 | Eps:1.000 | LR:5.00e-05 | FPS:596\n",
|
|
"Step: 8,144 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
|
|
"Step: 8,160 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
|
"Step: 8,176 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
|
"Step: 8,192 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
|
"Step: 8,208 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
|
|
"Step: 8,224 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
|
|
"Step: 8,240 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
|
|
"Step: 8,256 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
|
|
"Step: 8,272 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
|
|
"Step: 8,288 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
|
|
"Step: 8,304 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:576\n",
|
|
"Step: 8,320 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
|
"Step: 8,336 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
|
"Step: 8,352 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
|
"Step: 8,368 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
|
"Step: 8,384 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
|
"Step: 8,400 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
|
"Step: 8,416 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
|
"Step: 8,432 | Ep: 60 | AvgR: 13.5 | Eps:1.000 | LR:5.00e-05 | FPS:575\n",
|
|
"Step: 10,912 | Ep: 80 | AvgR: 13.2 | Eps:1.000 | LR:2.90e-07 | FPS:520\n",
|
|
"Step: 10,928 | Ep: 80 | AvgR: 13.2 | Eps:1.000 | LR:2.95e-07 | FPS:520\n",
|
|
"Step: 10,944 | Ep: 80 | AvgR: 13.2 | Eps:1.000 | LR:3.00e-07 | FPS:519\n",
|
|
"Step: 10,960 | Ep: 80 | AvgR: 13.2 | Eps:1.000 | LR:3.05e-07 | FPS:519\n",
|
|
"Step: 13,280 | Ep: 100 | AvgR: 13.0 | Eps:1.000 | LR:1.03e-06 | FPS:481\n",
|
|
"Step: 13,296 | Ep: 100 | AvgR: 13.0 | Eps:1.000 | LR:1.04e-06 | FPS:481\n",
|
|
"Step: 13,312 | Ep: 100 | AvgR: 13.0 | Eps:1.000 | LR:1.04e-06 | FPS:481\n",
|
|
"Step: 13,328 | Ep: 100 | AvgR: 13.0 | Eps:1.000 | LR:1.04e-06 | FPS:481\n",
|
|
"Step: 15,648 | Ep: 120 | AvgR: 12.8 | Eps:1.000 | LR:1.77e-06 | FPS:454\n",
|
|
"Step: 21,184 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.50e-06 | FPS:434\n",
|
|
"Step: 21,200 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.50e-06 | FPS:434\n",
|
|
"Step: 21,216 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.51e-06 | FPS:434\n",
|
|
"Step: 21,232 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.52e-06 | FPS:434\n",
|
|
"Step: 21,248 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.52e-06 | FPS:434\n",
|
|
"Step: 21,264 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.52e-06 | FPS:434\n",
|
|
"Step: 21,280 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.53e-06 | FPS:434\n",
|
|
"Step: 21,296 | Ep: 160 | AvgR: 13.9 | Eps:1.000 | LR:3.54e-06 | FPS:434\n",
|
|
"Step: 23,824 | Ep: 180 | AvgR: 14.4 | Eps:1.000 | LR:4.33e-06 | FPS:432\n",
|
|
"Step: 26,144 | Ep: 200 | AvgR: 14.3 | Eps:1.000 | LR:5.05e-06 | FPS:422\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# ── 开始训练 ──\n",
|
|
"trainer = ParallelTrainer(\n",
|
|
" agent=agent,\n",
|
|
" envs=envs,\n",
|
|
" eval_env=eval_env,\n",
|
|
" num_envs=N_ENVS,\n",
|
|
" save_dir=SAVE_DIR,\n",
|
|
" eval_freq=EVAL_FREQ,\n",
|
|
" save_freq=SAVE_FREQ,\n",
|
|
" num_eval_episodes=EVAL_EPISODES,\n",
|
|
" warmup_steps=WARMUP_STEPS,\n",
|
|
")\n",
|
|
"\n",
|
|
"print(\"\\n\" + \"=\" * 60)\n",
|
|
"print(f\"开始 10M 步并行训练\")\n",
|
|
"print(f\" GPU: {device}\")\n",
|
|
"print(f\" 并行环境: {N_ENVS}\")\n",
|
|
"print(f\" Dueling: {USE_DUELING}\")\n",
|
|
"print(f\" Double DQN: {USE_DOUBLE}\")\n",
|
|
"print(f\" PER: {USE_PER}\")\n",
|
|
"print(\"=\" * 60 + \"\\n\")\n",
|
|
"\n",
|
|
"trainer.train(TOTAL_STEPS)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 训练完成后:评估最佳模型"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"加载最佳模型...\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"d:\\Code\\doing_exercises\\programs\\外教作业外快\\强化学习个人项目报告(Atari 游戏方向)\\src\\agent.py:219: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
|
|
" checkpoint = torch.load(path, map_location=self.device)\n"
|
|
]
|
|
},
|
|
{
|
|
"ename": "FileNotFoundError",
|
|
"evalue": "[Errno 2] No such file or directory: 'models/dqn_best.pt'",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[1;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
|
|
"Cell \u001b[1;32mIn[9], line 3\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;66;03m# ── 评估最佳模型 ──\u001b[39;00m\n\u001b[0;32m 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m加载最佳模型...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m----> 3\u001b[0m \u001b[43magent\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mSAVE_DIR\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m/dqn_best.pt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 5\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m评估中...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 6\u001b[0m all_rewards \u001b[38;5;241m=\u001b[39m []\n",
|
|
"File \u001b[1;32md:\\Code\\doing_exercises\\programs\\外教作业外快\\强化学习个人项目报告(Atari 游戏方向)\\src\\agent.py:219\u001b[0m, in \u001b[0;36mDQNAgent.load\u001b[1;34m(self, path)\u001b[0m\n\u001b[0;32m 217\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mload\u001b[39m(\u001b[38;5;28mself\u001b[39m, path):\n\u001b[0;32m 218\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"加载模型\"\"\"\u001b[39;00m\n\u001b[1;32m--> 219\u001b[0m checkpoint \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmap_location\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 220\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mq_network\u001b[38;5;241m.\u001b[39mload_state_dict(checkpoint[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mq_network\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[0;32m 221\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtarget_network\u001b[38;5;241m.\u001b[39mload_state_dict(checkpoint[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtarget_network\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n",
|
|
"File \u001b[1;32md:\\ProgramData\\anaconda3\\envs\\my_env\\lib\\site-packages\\torch\\serialization.py:1065\u001b[0m, in \u001b[0;36mload\u001b[1;34m(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)\u001b[0m\n\u001b[0;32m 1062\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mencoding\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m pickle_load_args\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[0;32m 1063\u001b[0m pickle_load_args[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mencoding\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mutf-8\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m-> 1065\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43m_open_file_like\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mrb\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m opened_file:\n\u001b[0;32m 1066\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _is_zipfile(opened_file):\n\u001b[0;32m 1067\u001b[0m \u001b[38;5;66;03m# The zipfile reader is going to advance the current file position.\u001b[39;00m\n\u001b[0;32m 1068\u001b[0m \u001b[38;5;66;03m# If we want to actually tail call to torch.jit.load, we need to\u001b[39;00m\n\u001b[0;32m 1069\u001b[0m \u001b[38;5;66;03m# reset back to the original position.\u001b[39;00m\n\u001b[0;32m 1070\u001b[0m orig_position \u001b[38;5;241m=\u001b[39m opened_file\u001b[38;5;241m.\u001b[39mtell()\n",
|
|
"File \u001b[1;32md:\\ProgramData\\anaconda3\\envs\\my_env\\lib\\site-packages\\torch\\serialization.py:468\u001b[0m, in \u001b[0;36m_open_file_like\u001b[1;34m(name_or_buffer, mode)\u001b[0m\n\u001b[0;32m 466\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_open_file_like\u001b[39m(name_or_buffer, mode):\n\u001b[0;32m 467\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _is_path(name_or_buffer):\n\u001b[1;32m--> 468\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_open_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 469\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 470\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01min\u001b[39;00m mode:\n",
|
|
"File \u001b[1;32md:\\ProgramData\\anaconda3\\envs\\my_env\\lib\\site-packages\\torch\\serialization.py:449\u001b[0m, in \u001b[0;36m_open_file.__init__\u001b[1;34m(self, name, mode)\u001b[0m\n\u001b[0;32m 448\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, name, mode):\n\u001b[1;32m--> 449\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m)\n",
|
|
"\u001b[1;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'models/dqn_best.pt'"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# ── 评估最佳模型 ──\n",
|
|
"print(\"加载最佳模型...\")\n",
|
|
"agent.load(f\"{SAVE_DIR}/dqn_best.pt\")\n",
|
|
"\n",
|
|
"print(\"\\n评估中...\")\n",
|
|
"all_rewards = []\n",
|
|
"for i in range(20):\n",
|
|
" state, _ = eval_env.reset()\n",
|
|
" ep_r = 0\n",
|
|
" done = False\n",
|
|
" while not done:\n",
|
|
" action = agent.select_action(state, evaluate=True)\n",
|
|
" state, reward, terminated, truncated, _ = eval_env.step(action)\n",
|
|
" done = terminated or truncated\n",
|
|
" ep_r += reward\n",
|
|
" all_rewards.append(ep_r)\n",
|
|
" print(f\" Episode {i+1:>2}: {ep_r:.1f}\")\n",
|
|
"\n",
|
|
"print(f\"\\n结果: 平均 {np.mean(all_rewards):.2f} ± {np.std(all_rewards):.2f}\")\n",
|
|
"print(f\"最佳: {max(all_rewards):.1f} | 最差: {min(all_rewards):.1f}\")\n",
|
|
"print(f\"中位数: {np.median(all_rewards):.1f}\")"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "my_env",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.20"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|