Files
Serendipity fb09e66d09 feat: 重构项目结构并添加向量化PPO训练与评估脚本
- 将原始单环境训练代码重构为模块化结构,添加向量化环境支持以提高数据采集效率
- 实现完整的PPO训练流水线,包括共享CNN的Actor-Critic网络、向量化经验回放缓冲和GAE优势估计
- 添加训练脚本(train_vec.py)、评估脚本(evaluate.py)和SB3基线对比脚本(train_sb3_baseline.py)
- 提供详细的文档和开发日志,包含问题解决记录和实验分析
- 移除旧版项目文件,统一项目结构到CW1_id_name目录下
2026-05-02 13:44:08 +08:00

164 lines
5.7 KiB
Python

"""Scan all checkpoints in a run directory and evaluate each one.
Usage:
python scan_checkpoints.py --run-name vec_main --episodes 10
For each .pt file, evaluates with both stochastic and deterministic
policies and prints a comparison table. Helps identify the best
checkpoint to submit when the final one over-fits / over-anneals.
"""
import argparse
import json
import math
from pathlib import Path

import numpy as np

from src.eval_utils import evaluate_agent
from src.ppo_agent import PPOAgent
def parse_args(argv=None):
    """Parse command-line options for the checkpoint scan.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``,
            in which case argparse reads ``sys.argv[1:]`` — existing
            callers that pass nothing are unaffected. Accepting an
            explicit list makes the parser unit-testable.

    Returns:
        argparse.Namespace with run_name, episodes, seed_start,
        out_dir, from_iter, every_k, and mode attributes.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--run-name", type=str, default="vec_main")
    p.add_argument("--episodes", type=int, default=10,
                   help="Episodes per checkpoint per mode (10 is enough for ranking)")
    p.add_argument("--seed-start", type=int, default=2000,
                   help="Use seeds different from final-evaluation 1000-1019")
    p.add_argument("--out-dir", type=str, default="docs")
    p.add_argument("--from-iter", type=int, default=0,
                   help="Skip checkpoints whose iter number is below this")
    p.add_argument("--every-k", type=int, default=1,
                   help="Subsample: only evaluate every k-th checkpoint")
    p.add_argument("--mode", type=str, default="both",
                   choices=["both", "stochastic", "deterministic"])
    return p.parse_args(argv)
def _collect_checkpoints(ckpt_dir, from_iter, every_k):
    """Gather iter_* checkpoints in NUMERIC order, then final.pt last.

    BUGFIX: the previous implementation sorted by the raw stem string,
    so ``iter_10`` ordered before ``iter_9`` (lexicographic), and its
    sort key actually placed ``final`` FIRST despite the comment saying
    "final last". Sorting by the parsed iteration number fixes both.

    Args:
        ckpt_dir: Directory containing ``*.pt`` checkpoint files.
        from_iter: Skip checkpoints whose iteration number is below this.
        every_k: Keep only every k-th iter checkpoint (index taken over
            the full numerically-sorted iter list, matching the original
            behaviour of filtering by position before the from_iter cut).

    Raises:
        FileNotFoundError: If the directory contains no ``.pt`` files.
    """
    all_ckpts = list(ckpt_dir.glob("*.pt"))
    if not all_ckpts:
        raise FileNotFoundError(f"No .pt files in {ckpt_dir}")
    numbered = []
    for p in all_ckpts:
        if not p.stem.startswith("iter_"):
            continue
        try:
            numbered.append((int(p.stem.replace("iter_", "")), p))
        except ValueError:
            continue  # ignore malformed names such as iter_foo.pt
    numbered.sort(key=lambda pair: pair[0])
    ckpts = [
        p for i, (num, p) in enumerate(numbered)
        if num >= from_iter and i % every_k == 0
    ]
    # final.pt (if present) is always evaluated, and always last.
    ckpts.extend(p for p in all_ckpts if p.stem == "final")
    return ckpts


def _stats(returns):
    """Return (mean, std, min) of a list of returns; NaNs if empty."""
    if not returns:
        nan = float("nan")
        return nan, nan, nan
    return (float(np.mean(returns)),
            float(np.std(returns)),
            float(np.min(returns)))


def _evaluate_checkpoint(ckpt, episodes, seed_start, do_sto, do_det):
    """Load one checkpoint and evaluate it in the requested mode(s).

    Returns a result dict (checkpoint name, per-mode mean/std/min and
    raw returns) suitable for JSON serialisation.
    """
    agent = PPOAgent(n_actions=5)
    agent.load(str(ckpt))
    agent.net.eval()  # disable dropout/batch-norm updates during eval
    sto_returns = []
    det_returns = []
    if do_sto:
        sto_returns = evaluate_agent(
            agent, n_episodes=episodes,
            seed_start=seed_start, deterministic=False,
        )
    if do_det:
        det_returns = evaluate_agent(
            agent, n_episodes=episodes,
            seed_start=seed_start, deterministic=True,
        )
    sto_mean, sto_std, sto_min = _stats(sto_returns)
    det_mean, det_std, det_min = _stats(det_returns)
    print(
        f"{ckpt.stem:>14s} | "
        f"sto: {sto_mean:7.1f} +/- {sto_std:6.1f} (min {sto_min:6.1f}) | "
        f"det: {det_mean:7.1f} +/- {det_std:6.1f} (min {det_min:6.1f})"
    )
    return {
        "ckpt": ckpt.name,
        "stochastic_mean": sto_mean,
        "stochastic_std": sto_std,
        "stochastic_min": sto_min,
        "stochastic_returns": sto_returns,
        "deterministic_mean": det_mean,
        "deterministic_std": det_std,
        "deterministic_min": det_min,
        "deterministic_returns": det_returns,
    }


def _safe_max(items, key):
    """max() over items whose key is not NaN; None if none qualify."""
    finite = [r for r in items if not math.isnan(key(r))]
    return max(finite, key=key) if finite else None


def main():
    """Scan every checkpoint of a run, print a comparison table, and
    save a JSON summary plus best-by-criterion picks."""
    args = parse_args()
    project_root = Path(__file__).resolve().parent
    ckpt_dir = project_root / "models" / args.run_name
    out_dir = project_root / args.out_dir
    out_dir.mkdir(parents=True, exist_ok=True)
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"No such run dir: {ckpt_dir}")

    ckpts = _collect_checkpoints(ckpt_dir, args.from_iter, args.every_k)
    if not ckpts:
        raise RuntimeError("No checkpoints survived filtering")

    print("=" * 80)
    print(f"Scanning {len(ckpts)} checkpoints in {ckpt_dir}")
    print(f"Episodes per ckpt per mode: {args.episodes}")
    print(f"Seeds: {args.seed_start} to {args.seed_start + args.episodes - 1}")
    print("=" * 80)

    do_sto = args.mode in ("both", "stochastic")
    do_det = args.mode in ("both", "deterministic")
    results = [
        _evaluate_checkpoint(ckpt, args.episodes, args.seed_start,
                             do_sto, do_det)
        for ckpt in ckpts
    ]

    # Save scan summary
    out_path = out_dir / f"checkpoint_scan_{args.run_name}.json"
    with open(out_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved scan summary to {out_path}")

    # Print best by each criterion (NaN-safe: an all-NaN column yields None)
    print("\n" + "=" * 80)
    print("BEST BY EACH CRITERION")
    print("=" * 80)
    best_sto_mean = _safe_max(results, key=lambda r: r["stochastic_mean"])
    best_det_mean = _safe_max(results, key=lambda r: r["deterministic_mean"])
    best_robust = _safe_max(results, key=lambda r: r["stochastic_min"])
    if best_sto_mean:
        print(f"Highest stochastic mean : {best_sto_mean['ckpt']:>14s} "
              f"({best_sto_mean['stochastic_mean']:.1f})")
    if best_det_mean:
        print(f"Highest deterministic : {best_det_mean['ckpt']:>14s} "
              f"({best_det_mean['deterministic_mean']:.1f})")
    if best_robust:
        print(f"Most robust (high min) : {best_robust['ckpt']:>14s} "
              f"(stochastic min {best_robust['stochastic_min']:.1f})")


if __name__ == "__main__":
    main()