Files
rl-atari/CW1_id_name/notebooks/03_evaluate.ipynb
T
Serendipity fb09e66d09 feat: 重构项目结构并添加向量化PPO训练与评估脚本
- 将原始单环境训练代码重构为模块化结构,添加向量化环境支持以提高数据采集效率
- 实现完整的PPO训练流水线,包括共享CNN的Actor-Critic网络、向量化经验回放缓冲和GAE优势估计
- 添加训练脚本(train_vec.py)、评估脚本(evaluate.py)和SB3基线对比脚本(train_sb3_baseline.py)
- 提供详细的文档和开发日志,包含问题解决记录和实验分析
- 移除旧版项目文件,统一项目结构到CW1_id_name目录下
2026-05-02 13:44:08 +08:00

119 lines
4.0 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 03 - Evaluate the trained PPO agent\n",
"\n",
"This notebook is a thin wrapper around `src/eval_utils.py`.\n",
"All real logic lives in `src/eval_utils.py` and `evaluate.py` so that\n",
"the same code runs from the command line and from Jupyter.\n",
"\n",
"Steps:\n",
"1. Load the trained checkpoint\n",
"2. Roll out 20 episodes on unseen seeds -> per-episode returns, mean and std\n",
"3. Plot the evaluation bar chart (saved to `docs/fig_eval_bar.png`)\n",
"4. Plot the multi-run training curves (saved to `docs/fig_training_curves.png`)\n",
"5. Optionally, record a demo video (saved to `docs/demo.mp4`)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Setup: all imports for this notebook live in this one cell.\n",
"import sys\n",
"from pathlib import Path\n",
"\n",
"import numpy as np\n",
"from IPython.display import Image, display  # used by the plotting cells below\n",
"\n",
"# The notebook lives in <project>/notebooks/, so the project root is one level up.\n",
"# NOTE: assumes the kernel's working directory is the notebooks/ folder.\n",
"# Putting the root on sys.path lets us import the shared `src` package (same code as evaluate.py).\n",
"project_root = Path.cwd().parent\n",
"if str(project_root) not in sys.path:\n",
"    sys.path.insert(0, str(project_root))\n",
"\n",
"from src.eval_utils import (\n",
"    evaluate_agent,\n",
"    plot_eval_bar,\n",
"    plot_training_curves,\n",
"    record_demo_video,\n",
")\n",
"from src.ppo_agent import PPOAgent\n",
"\n",
"print('Project root:', project_root)\n",
"print('Available checkpoints:')\n",
"models_dir = project_root / 'models'\n",
"# Guard the listing so a missing models/ folder does not abort the setup cell.\n",
"if models_dir.is_dir():\n",
"    for ckpt_path in sorted(models_dir.iterdir()):\n",
"        print('  ', ckpt_path.name)\n",
"else:\n",
"    print(f'  (no models directory found at {models_dir})')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 1) Load the submitted model checkpoint\n",
"CKPT = project_root / 'models' / 'ppo_final.pt'\n",
"assert CKPT.exists(), f'Not found: {CKPT}'\n",
"\n",
"# NOTE(review): n_actions=5 presumably matches the action space the checkpoint was trained with -- confirm.\n",
"agent = PPOAgent(n_actions=5)\n",
"agent.load(str(CKPT))\n",
"agent.net.eval()  # put the policy network in eval mode before the evaluation rollouts\n",
"\n",
"print(f'Loaded {CKPT}')\n",
"print(f'Device: {agent.device}')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 2) Numerical evaluation: 20 unseen seeds\n",
"N_EVAL_EPISODES = 20\n",
"EVAL_SEED_START = 1000  # first evaluation seed; episodes use seeds not seen during training\n",
"\n",
"returns = evaluate_agent(agent, n_episodes=N_EVAL_EPISODES, seed_start=EVAL_SEED_START)\n",
"for ep_idx, ep_return in enumerate(returns):\n",
"    print(f' ep {ep_idx:>2d}: return = {ep_return:7.2f}')\n",
"\n",
"# Summary statistics over all episodes (np.std is the population std, ddof=0).\n",
"mean_r = float(np.mean(returns))\n",
"std_r = float(np.std(returns))\n",
"print(f'\\nMean: {mean_r:.2f} Std: {std_r:.2f}')\n",
"print(f'Min : {min(returns):.2f} Max: {max(returns):.2f}')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 3) Bar chart of the per-episode evaluation returns\n",
"from IPython.display import Image, display\n",
"\n",
"# NOTE(review): the origin of this reference return is not shown in this notebook -- document it.\n",
"BASELINE_RETURN = -54.19\n",
"EVAL_BAR_PNG = project_root / 'docs' / 'fig_eval_bar.png'\n",
"\n",
"out = plot_eval_bar(returns, baseline=BASELINE_RETURN, save_path=EVAL_BAR_PNG)\n",
"print(f'Saved {out}')\n",
"display(Image(str(out)))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 4) Training curves: vec_main_v3 (our run) overlaid with sb3_baseline (reference)\n",
"RUN_NAMES = ['vec_main_v3', 'sb3_baseline']\n",
"runs_dir = project_root / 'runs'\n",
"\n",
"# Collect only the runs that exist locally, but say which ones were skipped\n",
"# instead of dropping them silently.\n",
"run_dirs = []\n",
"labels = []\n",
"for run_name in RUN_NAMES:\n",
"    run_dir = runs_dir / run_name\n",
"    if run_dir.exists():\n",
"        run_dirs.append(run_dir)\n",
"        labels.append(run_name)\n",
"    else:\n",
"        print(f'Skipping missing run: {run_dir}')\n",
"\n",
"# Only plot when at least one run was found; otherwise plot_training_curves\n",
"# would be handed empty lists.\n",
"if run_dirs:\n",
"    out = plot_training_curves(\n",
"        run_dirs, labels,\n",
"        save_path=project_root / 'docs' / 'fig_training_curves.png',\n",
"    )\n",
"    print(f'Saved {out}')\n",
"    display(Image(str(out)))\n",
"else:\n",
"    print(f'No training runs found under {runs_dir}; skipping the training-curve plot.')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 5) Optional: record one demo video using the cleanest seed\n",
"RECORD_DEMO = True  # set to False to skip recording the demo video\n",
"DEMO_SEED = 117     # an unseen seed where the agent achieves ~925 with early completion\n",
"\n",
"if RECORD_DEMO:\n",
"    n_frames, video_path = record_demo_video(\n",
"        agent,\n",
"        out_path=project_root / 'docs' / 'demo.mp4',\n",
"        seed=DEMO_SEED,\n",
"    )\n",
"    print(f'Saved {video_path} with {n_frames} frames')\n",
"else:\n",
"    print('Skipping demo video (RECORD_DEMO is False)')\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.21"
}
},
"nbformat": 4,
"nbformat_minor": 5
}