feat: 重构项目结构并添加向量化PPO训练与评估脚本

- 将原始单环境训练代码重构为模块化结构,添加向量化环境支持以提高数据采集效率
- 实现完整的PPO训练流水线,包括共享CNN的Actor-Critic网络、向量化经验回放缓冲和GAE优势估计
- 添加训练脚本(train_vec.py)、评估脚本(evaluate.py)和SB3基线对比脚本(train_sb3_baseline.py)
- 提供详细的文档和开发日志,包含问题解决记录和实验分析
- 移除旧版项目文件,统一项目结构到CW1_id_name目录下
This commit is contained in:
2026-05-02 13:44:08 +08:00
parent 79ffb90823
commit fb09e66d09
80 changed files with 2971 additions and 4822 deletions
File diff suppressed because one or more lines are too long
+238
View File
@@ -0,0 +1,238 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "6df75d72-c1c1-40e0-a7f8-ef3da32e4592",
"metadata": {},
"source": [
"# 02 — Sanity-checking the Actor-Critic network\n",
"\n",
"### Verify that:\n",
"### - the network accepts uint8 (4, 84, 84) input\n",
"### - it runs on GPU\n",
"### - forward pass returns the expected shapes\n",
"### - get_action_and_value works for both sampling and scoring"
]
},
{
"cell_type": "markdown",
"id": "9c6c1d35-f17c-4fca-9cfb-b5b001b7a0c8",
"metadata": {},
"source": [
"## Cell 1 test env"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "9e09b2e5-c076-4599-8e98-1cb09c0a7cf5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Device: cuda\n",
"ActorCritic(\n",
" (cnn): Sequential(\n",
" (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))\n",
" (1): ReLU()\n",
" (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))\n",
" (3): ReLU()\n",
" (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n",
" (5): ReLU()\n",
" (6): Flatten(start_dim=1, end_dim=-1)\n",
" (7): Linear(in_features=3136, out_features=512, bias=True)\n",
" (8): ReLU()\n",
" )\n",
" (actor): Linear(in_features=512, out_features=5, bias=True)\n",
" (critic): Linear(in_features=512, out_features=1, bias=True)\n",
")\n",
"\n",
"Total parameters: 1,687,206\n"
]
}
],
"source": [
"import sys\n",
"from pathlib import Path\n",
"project_root = Path.cwd().parent\n",
"if str(project_root) not in sys.path:\n",
" sys.path.insert(0, str(project_root))\n",
"\n",
"import torch\n",
"from src.networks import ActorCritic\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"print(\"Device:\", device)\n",
"\n",
"net = ActorCritic(n_actions=5).to(device)\n",
"print(net)\n",
"\n",
"# Count parameters\n",
"total_params = sum(p.numel() for p in net.parameters())\n",
"print(f\"\\nTotal parameters: {total_params:,}\")"
]
},
{
"cell_type": "markdown",
"id": "8334668c-8f1c-4460-9e1b-9cc2c8c938a1",
"metadata": {},
"source": [
"## Cell 2 test forward"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "affedbd5-d08b-441b-8cae-b46057be5c63",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input shape : torch.Size([8, 4, 84, 84]) torch.uint8\n",
"Logits shape: torch.Size([8, 5]) torch.float32\n",
"Value shape : torch.Size([8]) torch.float32\n",
"Sample logits: [-0.0013231671182438731, 0.0014129895716905594, 0.0010137694189324975, 0.0005002821562811732, -0.0012777929659932852]\n",
"Sample value : 0.9111840724945068\n"
]
}
],
"source": [
"# Fake batch of 8 observations, shape (8, 4, 84, 84) uint8\n",
"fake_obs = torch.randint(0, 255, (8, 4, 84, 84), dtype=torch.uint8, device=device)\n",
"\n",
"logits, value = net(fake_obs)\n",
"print(\"Input shape :\", fake_obs.shape, fake_obs.dtype)\n",
"print(\"Logits shape:\", logits.shape, logits.dtype)\n",
"print(\"Value shape :\", value.shape, value.dtype)\n",
"print(\"Sample logits:\", logits[0].detach().cpu().tolist())\n",
"print(\"Sample value :\", value[0].item())"
]
},
{
"cell_type": "markdown",
"id": "6eb6ee0d-adc7-4ac5-953b-91018599dd7f",
"metadata": {},
"source": [
"## Cell 3 test get_action_and_value"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "e5ad46a8-7f62-442a-96a2-2d2c3ef91d59",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mode 1 (sample):\n",
" action : torch.Size([8]) torch.int64, sample = [4, 2, 3]\n",
" log_prob: torch.Size([8]), sample = [-1.6107815504074097, -1.60985267162323, -1.6094183921813965]\n",
" entropy : torch.Size([8]), sample = [1.6094372272491455, 1.609436273574829, 1.6094352006912231]\n",
" value : torch.Size([8]), sample = [0.9111840724945068, 0.8728611469268799, 0.9081785678863525]\n",
"\n",
"Mode 2 (score given action):\n",
" log_prob shape: torch.Size([8])\n",
" entropy shape : torch.Size([8])\n",
"\n",
"Reference: ln(5) = 1.6094\n",
"Mean entropy at init: 1.6094\n"
]
}
],
"source": [
"# Mode 1: sample action\n",
"action, log_prob, entropy, value = net.get_action_and_value(fake_obs)\n",
"print(\"Mode 1 (sample):\")\n",
"print(f\" action : {action.shape} {action.dtype}, sample = {action[:3].tolist()}\")\n",
"print(f\" log_prob: {log_prob.shape}, sample = {log_prob[:3].detach().cpu().tolist()}\")\n",
"print(f\" entropy : {entropy.shape}, sample = {entropy[:3].detach().cpu().tolist()}\")\n",
"print(f\" value : {value.shape}, sample = {value[:3].detach().cpu().tolist()}\")\n",
"\n",
"# Mode 2: score given action (this is what PPO update uses)\n",
"provided = torch.tensor([0, 3, 2, 1, 4, 0, 3, 2], device=device)\n",
"_, log_prob2, entropy2, value2 = net.get_action_and_value(fake_obs, provided)\n",
"print(\"\\nMode 2 (score given action):\")\n",
"print(f\" log_prob shape: {log_prob2.shape}\")\n",
"print(f\" entropy shape : {entropy2.shape}\")\n",
"\n",
"# Sanity: entropy of a uniform 5-action distribution should be ln(5) ≈ 1.6094\n",
"import math\n",
"print(f\"\\nReference: ln(5) = {math.log(5):.4f}\")\n",
"print(f\"Mean entropy at init: {entropy.mean().item():.4f}\")"
]
},
{
"cell_type": "markdown",
"id": "d4c00e27-2a64-46f9-99d0-aff04fe2e714",
"metadata": {},
"source": [
"## Cell 4 : Run it once with the obs of the real env"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "c362b84b-1b74-4f55-b7e9-d842414c4a9f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"obs_t shape: torch.Size([1, 4, 84, 84]) torch.uint8\n",
"Sampled action: 0\n",
"Log prob: -1.6102\n",
"Entropy: 1.6094\n",
"Value estimate: 0.3286\n"
]
}
],
"source": [
"from src.env_wrappers import make_env\n",
"import numpy as np\n",
"\n",
"env = make_env(seed=42)\n",
"obs, _ = env.reset(seed=42)\n",
"\n",
"# obs is a numpy uint8 array (4, 84, 84). Add a batch dim and move to device.\n",
"obs_t = torch.as_tensor(obs).unsqueeze(0).to(device)\n",
"print(\"obs_t shape:\", obs_t.shape, obs_t.dtype)\n",
"\n",
"action, log_prob, entropy, value = net.get_action_and_value(obs_t)\n",
"print(f\"Sampled action: {action.item()}\")\n",
"print(f\"Log prob: {log_prob.item():.4f}\")\n",
"print(f\"Entropy: {entropy.item():.4f}\")\n",
"print(f\"Value estimate: {value.item():.4f}\")\n",
"\n",
"env.close()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "PYTORCH",
"language": "python",
"name": "pytorch"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.21"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+119
View File
@@ -0,0 +1,119 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 05 - Evaluate the trained PPO agent\n",
"\n",
"This notebook is a thin wrapper around `src/eval_utils.py`.\n",
"All real logic lives in `src/eval_utils.py` and `evaluate.py` so that\n",
"the same code runs from the command line and from Jupyter.\n",
"\n",
"Steps:\n",
"1. Load the trained checkpoint\n",
"2. Roll 20 unseen-seed episodes -> mean / std / per-ep returns\n",
"3. Plot evaluation bar chart (saved to `docs/fig_eval_bar.png`)\n",
"4. Plot multi-run training curves (saved to `docs/fig_training_curves.png`)\n",
"5. Optionally record a demo video (saved to `docs/demo.mp4`)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"from pathlib import Path\n",
"\n",
"project_root = Path.cwd().parent\n",
"if str(project_root) not in sys.path:\n",
" sys.path.insert(0, str(project_root))\n",
"\n",
"import numpy as np\n",
"from src.eval_utils import (\n",
" evaluate_agent,\n",
" plot_eval_bar,\n",
" plot_training_curves,\n",
" record_demo_video,\n",
")\n",
"from src.ppo_agent import PPOAgent\n",
"\n",
"print('Project root:', project_root)\n",
"print('Available checkpoints:')\n",
"for d in sorted((project_root / 'models').iterdir()):\n",
" print(' ', d.name)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "# 1) Load the submitted model checkpoint\nCKPT = project_root / 'models' / 'ppo_final.pt'\nassert CKPT.exists(), f'Not found: {CKPT}'\n\nagent = PPOAgent(n_actions=5)\nagent.load(str(CKPT))\nagent.net.eval()\nprint(f'Loaded {CKPT}')\nprint(f'Device: {agent.device}')\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 2) Numerical evaluation: 20 unseen seeds\n",
"returns = evaluate_agent(agent, n_episodes=20, seed_start=1000)\n",
"for i, r in enumerate(returns):\n",
" print(f' ep {i:>2d}: return = {r:7.2f}')\n",
"\n",
"mean_r = float(np.mean(returns))\n",
"std_r = float(np.std(returns))\n",
"print(f'\\nMean: {mean_r:.2f} Std: {std_r:.2f}')\n",
"print(f'Min : {min(returns):.2f} Max: {max(returns):.2f}')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 3) Bar chart\n",
"out = plot_eval_bar(\n",
" returns,\n",
" baseline=-54.19,\n",
" save_path=project_root / 'docs' / 'fig_eval_bar.png',\n",
")\n",
"print(f'Saved {out}')\n",
"\n",
"from IPython.display import Image, display\n",
"display(Image(str(out)))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "# 4) Training curves: vec_main_v3 (our run) overlaid with sb3_baseline (reference)\nruns_dir = project_root / 'runs'\nrun_dirs = []\nlabels = []\nfor name in ['vec_main_v3', 'sb3_baseline']:\n d = runs_dir / name\n if d.exists():\n run_dirs.append(d)\n labels.append(name)\n\nout = plot_training_curves(\n run_dirs, labels,\n save_path=project_root / 'docs' / 'fig_training_curves.png',\n)\nprint(f'Saved {out}')\ndisplay(Image(str(out)))\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "# 5) Optional: record one demo video using the cleanest seed\nn_frames, video_path = record_demo_video(\n agent,\n out_path=project_root / 'docs' / 'demo.mp4',\n seed=117, # an unseen seed where the agent achieves ~925 with early completion\n)\nprint(f'Saved {video_path} with {n_frames} frames')\n"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.21"
}
},
"nbformat": 4,
"nbformat_minor": 5
}