fb09e66d09
- 将原始单环境训练代码重构为模块化结构,添加向量化环境支持以提高数据采集效率 - 实现完整的PPO训练流水线,包括共享CNN的Actor-Critic网络、向量化经验回放缓冲和GAE优势估计 - 添加训练脚本(train_vec.py)、评估脚本(evaluate.py)和SB3基线对比脚本(train_sb3_baseline.py) - 提供详细的文档和开发日志,包含问题解决记录和实验分析 - 移除旧版项目文件,统一项目结构到CW1_id_name目录下
148 lines
4.7 KiB
Python
148 lines
4.7 KiB
Python
"""Evaluate a trained PPO checkpoint.

Usage:
    python evaluate.py --ckpt models/vec_main/final.pt
    python evaluate.py --ckpt models/vec_main/final.pt --episodes 50 --video

Outputs go to docs/:
    fig_eval_bar.png         bar chart of per-episode returns
    fig_training_curves.png  6-panel training curves (vec_main only)
    demo.mp4                 one demo episode (only if --video)
"""
|
|
|
|
import argparse
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
|
|
from src.eval_utils import (
|
|
evaluate_agent,
|
|
plot_eval_bar,
|
|
plot_training_curves,
|
|
record_demo_video,
|
|
)
|
|
from src.ppo_agent import PPOAgent
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for the evaluation script.

    Parameters
    ----------
    argv : list[str] | None
        Argument strings to parse. ``None`` (the default) falls back to
        ``sys.argv[1:]``, preserving the original CLI behaviour; passing
        an explicit list enables programmatic use and testing.

    Returns
    -------
    argparse.Namespace
        The parsed options (checkpoint path, episode count, seeds,
        output directory, plotting/video flags).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", type=str, default="models/vec_main/final.pt")
    p.add_argument("--episodes", type=int, default=20)
    p.add_argument("--seed-start", type=int, default=1000)
    p.add_argument("--video", action="store_true",
                   help="Record one demo mp4 to docs/demo.mp4")
    p.add_argument("--video-seed", type=int, default=42)
    p.add_argument("--deterministic", action="store_true",
                   help="Use argmax action instead of sampling")
    p.add_argument("--out-dir", type=str, default="docs",
                   help="Where to save plots / video / json summary")
    p.add_argument("--baseline", type=float, default=-54.19,
                   help="Random-policy baseline mean for the comparison line")
    p.add_argument("--main-run", type=str, default="vec_main",
                   help="TensorBoard run-name to plot in the curves figure")
    p.add_argument("--baseline-run", type=str, default="main_v1_baseline",
                   help="Optional second run-name to overlay (or empty)")
    return p.parse_args(argv)
|
|
|
|
|
|
def main():
    """Run the full evaluation pipeline for a trained PPO checkpoint.

    Steps:
      1. Numerical evaluation over unseen seeds; prints per-episode
         returns and writes a JSON summary to ``<out-dir>/eval_summary.json``.
      2. Bar chart of per-episode returns vs. the random baseline.
      3. Training-curve figure from TensorBoard run directories, if the
         main run directory exists under ``runs/``.
      4. Optional demo video (``--video``).

    Raises
    ------
    FileNotFoundError
        If the checkpoint path does not exist.
    """
    args = parse_args()
    # Resolve all paths relative to this script so the tool works from any CWD.
    project_root = Path(__file__).resolve().parent
    ckpt_path = (project_root / args.ckpt).resolve()
    out_dir = (project_root / args.out_dir).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    print("=" * 60)
    print(f"Checkpoint: {ckpt_path}")
    print(f"Episodes: {args.episodes}")
    print(f"Out dir: {out_dir}")
    print("=" * 60)

    # Load agent and switch the network to inference mode.
    # NOTE(review): n_actions=5 is assumed to match the env's discrete
    # action space used during training — confirm against src/ppo_agent.
    agent = PPOAgent(n_actions=5)
    agent.load(str(ckpt_path))
    agent.net.eval()

    # 1) Numerical evaluation
    returns = evaluate_agent(
        agent,
        n_episodes=args.episodes,
        seed_start=args.seed_start,
        deterministic=args.deterministic,
    )
    # Cast to plain Python floats up front: evaluate_agent may return
    # numpy scalars, which json.dump below cannot serialize.
    returns = [float(r) for r in returns]
    mean_r = float(np.mean(returns))
    std_r = float(np.std(returns))
    min_r = float(np.min(returns))
    max_r = float(np.max(returns))

    print("\nPer-episode returns:")
    for i, r in enumerate(returns):
        print(f" ep {i:>2d} (seed={args.seed_start + i}): {r:7.2f}")

    print(f"\n=== Summary over {args.episodes} unseen seeds ===")
    print(f" Mean: {mean_r:.2f}")
    print(f" Std : {std_r:.2f}")
    print(f" Min : {min_r:.2f}")
    print(f" Max : {max_r:.2f}")

    # Save JSON summary so results can be compared across runs.
    summary = {
        "checkpoint": str(ckpt_path),
        "n_episodes": args.episodes,
        "seed_start": args.seed_start,
        "deterministic": args.deterministic,
        "mean": mean_r,
        "std": std_r,
        "min": min_r,
        "max": max_r,
        "returns": returns,
    }
    summary_path = out_dir / "eval_summary.json"
    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2)
    print(f"\nSaved {summary_path}")

    # 2) Bar chart
    bar_path = plot_eval_bar(
        returns,
        baseline=args.baseline,
        save_path=out_dir / "fig_eval_bar.png",
        title=f"PPO evaluation returns over {args.episodes} unseen seeds",
    )
    print(f"Saved {bar_path}")

    # 3) Training curves — only when the main TensorBoard run exists;
    #    the baseline run is overlaid opportunistically if present.
    runs_root = project_root / "runs"
    main_run_dir = runs_root / args.main_run
    if main_run_dir.exists():
        run_dirs = [main_run_dir]
        labels = [args.main_run]
        if args.baseline_run:
            baseline_run_dir = runs_root / args.baseline_run
            if baseline_run_dir.exists():
                run_dirs.append(baseline_run_dir)
                labels.append(args.baseline_run)
        curves_path = plot_training_curves(
            run_dirs, labels, save_path=out_dir / "fig_training_curves.png"
        )
        print(f"Saved {curves_path}")
    else:
        print(f"Skipping training curves: {main_run_dir} not found")

    # 4) Optional demo video
    if args.video:
        n_frames, video_path = record_demo_video(
            agent,
            out_path=out_dir / "demo.mp4",
            seed=args.video_seed,
        )
        print(f"Saved {video_path} ({n_frames} frames)")
|
|
|
|
|
|
# Script entry point: run evaluation only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
|