Files
Serendipity fb09e66d09 feat: 重构项目结构并添加向量化PPO训练与评估脚本
- 将原始单环境训练代码重构为模块化结构,添加向量化环境支持以提高数据采集效率
- 实现完整的PPO训练流水线,包括共享CNN的Actor-Critic网络、向量化经验回放缓冲和GAE优势估计
- 添加训练脚本(train_vec.py)、评估脚本(evaluate.py)和SB3基线对比脚本(train_sb3_baseline.py)
- 提供详细的文档和开发日志,包含问题解决记录和实验分析
- 移除旧版项目文件,统一项目结构到CW1_id_name目录下
2026-05-02 13:44:08 +08:00

148 lines
4.7 KiB
Python

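"""Evaluate a trained PPO checkpoint.

Usage:
    python evaluate.py --ckpt models/vec_main/final.pt
    python evaluate.py --ckpt models/vec_main/final.pt --episodes 50 --video

Outputs go to docs/ by default (override with --out-dir):
    eval_summary.json        mean/std/min/max and the per-episode returns
    fig_eval_bar.png         bar chart of per-episode returns
    fig_training_curves.png  6-panel training curves for --main-run
                             (default vec_main), overlaid with --baseline-run
                             when that run exists; skipped if runs/<main-run>
                             is not found
    demo.mp4                 one demo episode (only if --video)
"""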
import argparse
import json
from pathlib import Path
import numpy as np
from src.eval_utils import (
evaluate_agent,
plot_eval_bar,
plot_training_curves,
record_demo_video,
)
from src.ppo_agent import PPOAgent
def _positive_int(text):
    """Convert *text* to an int for argparse, rejecting values below 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {value}"
        )
    return value


def parse_args():
    """Parse the command-line options of the evaluation script.

    ``--episodes`` must be >= 1: the summary statistics (mean/std/min/max)
    are undefined for zero episodes, so such values are rejected here with a
    clear argparse usage error instead of failing later inside NumPy.

    Returns:
        argparse.Namespace with ``ckpt``, ``episodes``, ``seed_start``,
        ``video``, ``video_seed``, ``deterministic``, ``out_dir``,
        ``baseline``, ``main_run`` and ``baseline_run``.
    """
    p = argparse.ArgumentParser(description="Evaluate a trained PPO checkpoint.")
    p.add_argument("--ckpt", type=str, default="models/vec_main/final.pt")
    p.add_argument("--episodes", type=_positive_int, default=20,
                   help="Number of evaluation episodes (>= 1)")
    p.add_argument("--seed-start", type=int, default=1000)
    p.add_argument("--video", action="store_true",
                   help="Record one demo mp4 to docs/demo.mp4")
    p.add_argument("--video-seed", type=int, default=42)
    p.add_argument("--deterministic", action="store_true",
                   help="Use argmax action instead of sampling")
    p.add_argument("--out-dir", type=str, default="docs",
                   help="Where to save plots / video / json summary")
    p.add_argument("--baseline", type=float, default=-54.19,
                   help="Random-policy baseline mean for the comparison line")
    p.add_argument("--main-run", type=str, default="vec_main",
                   help="TensorBoard run-name to plot in the curves figure")
    p.add_argument("--baseline-run", type=str, default="main_v1_baseline",
                   help="Optional second run-name to overlay (or empty)")
    return p.parse_args()
def _summary_stats(returns):
    """Return mean / std / min / max of *returns* as built-in Python floats.

    ``std`` is NumPy's default population standard deviation (``ddof=0``).

    Raises:
        ValueError: if *returns* is empty (the statistics are undefined).
    """
    if len(returns) == 0:
        raise ValueError(
            "No episode returns to summarise: evaluate_agent returned nothing"
        )
    arr = np.asarray(returns, dtype=np.float64)
    return {
        "mean": float(arr.mean()),
        "std": float(arr.std()),
        "min": float(arr.min()),
        "max": float(arr.max()),
    }


def _plot_curves_if_available(runs_root, main_run, baseline_run, save_path):
    """Plot TensorBoard curves for *main_run* (plus *baseline_run* if found).

    Prints a notice and returns None when ``runs_root / main_run`` is missing;
    otherwise returns whatever ``plot_training_curves`` returns.
    """
    main_run_dir = runs_root / main_run
    if not main_run_dir.exists():
        print(f"Skipping training curves: {main_run_dir} not found")
        return None
    run_dirs = [main_run_dir]
    labels = [main_run]
    # An empty --baseline-run disables the overlay entirely.
    if baseline_run:
        baseline_run_dir = runs_root / baseline_run
        if baseline_run_dir.exists():
            run_dirs.append(baseline_run_dir)
            labels.append(baseline_run)
        else:
            print(f"Baseline run {baseline_run_dir} not found; "
                  f"plotting {main_run} only")
    curves_path = plot_training_curves(run_dirs, labels, save_path=save_path)
    print(f"Saved {curves_path}")
    return curves_path


def main():
    """Evaluate a PPO checkpoint and write the summary, plots and optional video.

    Steps:
      1. load the checkpoint into a ``PPOAgent`` and run ``evaluate_agent``
         for ``--episodes`` episodes starting at ``--seed-start``;
      2. print the per-episode returns and summary statistics and save them
         to ``<out-dir>/eval_summary.json``;
      3. draw the per-episode bar chart against ``--baseline``;
      4. plot training curves from ``runs/<main-run>`` when it exists;
      5. record ``<out-dir>/demo.mp4`` if ``--video`` is given.

    Raises:
        FileNotFoundError: if the checkpoint path does not exist.
        ValueError: if evaluation produced no episode returns.
    """
    args = parse_args()
    project_root = Path(__file__).resolve().parent
    # Relative paths resolve against the script's directory rather than the
    # CWD; an absolute path wins the join and passes through unchanged.
    ckpt_path = (project_root / args.ckpt).resolve()
    out_dir = (project_root / args.out_dir).resolve()
    # Validate the checkpoint before creating out_dir, so a mistyped --ckpt
    # does not leave an empty output directory behind.
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    out_dir.mkdir(parents=True, exist_ok=True)

    print("=" * 60)
    print(f"Checkpoint: {ckpt_path}")
    print(f"Episodes: {args.episodes}")
    print(f"Out dir: {out_dir}")
    print("=" * 60)

    # Load agent.
    # NOTE(review): n_actions=5 must match the action space the checkpoint
    # was trained with — confirm against the training script.
    agent = PPOAgent(n_actions=5)
    agent.load(str(ckpt_path))
    # Presumably a torch module: switch it to inference mode — TODO confirm.
    agent.net.eval()

    # 1) Numerical evaluation. Coerce to built-in floats so the values are
    # JSON-serialisable even if evaluate_agent yields NumPy scalars/arrays.
    raw_returns = evaluate_agent(
        agent,
        n_episodes=args.episodes,
        seed_start=args.seed_start,
        deterministic=args.deterministic,
    )
    returns = [float(r) for r in raw_returns]
    stats = _summary_stats(returns)

    print("\nPer-episode returns:")
    for i, r in enumerate(returns):
        # The seed label assumes evaluate_agent uses seed_start + i for
        # episode i — TODO confirm against src.eval_utils.
        print(f" ep {i:>2d} (seed={args.seed_start + i}): {r:7.2f}")
    print(f"\n=== Summary over {args.episodes} unseen seeds ===")
    print(f" Mean: {stats['mean']:.2f}")
    print(f" Std : {stats['std']:.2f}")
    print(f" Min : {stats['min']:.2f}")
    print(f" Max : {stats['max']:.2f}")

    # Save JSON summary (key order: run metadata, statistics, raw returns).
    summary = {
        "checkpoint": str(ckpt_path),
        "n_episodes": args.episodes,
        "seed_start": args.seed_start,
        "deterministic": args.deterministic,
        **stats,
        "returns": returns,
    }
    summary_path = out_dir / "eval_summary.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    print(f"\nSaved {summary_path}")

    # 2) Bar chart
    bar_path = plot_eval_bar(
        returns,
        baseline=args.baseline,
        save_path=out_dir / "fig_eval_bar.png",
        title=f"PPO evaluation returns over {args.episodes} unseen seeds",
    )
    print(f"Saved {bar_path}")

    # 3) Training curves (skipped with a notice if the run directory is absent)
    _plot_curves_if_available(
        project_root / "runs",
        args.main_run,
        args.baseline_run,
        save_path=out_dir / "fig_training_curves.png",
    )

    # 4) Optional demo video
    if args.video:
        n_frames, video_path = record_demo_video(
            agent,
            out_path=out_dir / "demo.mp4",
            seed=args.video_seed,
        )
        print(f"Saved {video_path} ({n_frames} frames)")
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()