"""Scan all checkpoints in a run directory and evaluate each one.

Usage:
    python scan_checkpoints.py --run-name vec_main --episodes 10

For each ``iter_<N>.pt`` checkpoint (plus ``final.pt``), evaluates the
stochastic and/or deterministic policy and prints a comparison table.
Helps identify the best checkpoint to submit when the final one
over-fits / over-anneals.
"""

import argparse
import json
import math
from pathlib import Path

import numpy as np

# NOTE: the project-local imports (src.eval_utils, src.ppo_agent) are done
# inside main() so that --help and argument/path validation work without
# loading them.


def parse_args(argv=None):
    """Parse command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    p = argparse.ArgumentParser(
        description="Evaluate every checkpoint of a training run."
    )
    p.add_argument("--run-name", type=str, default="vec_main")
    p.add_argument("--episodes", type=int, default=10,
                   help="Episodes per checkpoint per mode (10 is enough for ranking)")
    p.add_argument("--seed-start", type=int, default=2000,
                   help="Use seeds different from final-evaluation 1000-1019")
    p.add_argument("--out-dir", type=str, default="docs")
    p.add_argument("--from-iter", type=int, default=0,
                   help="Skip checkpoints whose iter number is below this")
    p.add_argument("--every-k", type=int, default=1,
                   help="Subsample: only evaluate every k-th checkpoint")
    p.add_argument("--mode", type=str, default="both",
                   choices=["both", "stochastic", "deterministic"])
    p.add_argument("--n-actions", type=int, default=5,
                   help="Action-space size passed to PPOAgent")
    args = p.parse_args(argv)
    # every_k < 1 would divide by zero / stride backwards; episodes < 1
    # would produce an all-NaN scan. p.error() exits with status 2.
    if args.every_k < 1:
        p.error("--every-k must be >= 1")
    if args.episodes < 1:
        p.error("--episodes must be >= 1")
    return args


def iter_number(path):
    """Return N for a checkpoint named ``iter_<N>.pt``, else ``None``."""
    stem = Path(path).stem
    if not stem.startswith("iter_"):
        return None
    try:
        return int(stem[len("iter_"):])
    except ValueError:
        return None


def select_checkpoints(paths, from_iter=0, every_k=1):
    """Order and filter checkpoint paths for scanning.

    ``iter_<N>.pt`` files are sorted by their integer N (so ``iter_20``
    precedes ``iter_100``, which a plain string sort gets wrong). Files with
    N < ``from_iter`` are dropped, then every ``every_k``-th survivor is kept,
    starting with the first. ``final.pt`` is always appended last. Other
    ``.pt`` files and ``iter_`` names without an integer suffix are ignored.

    Args:
        paths: Iterable of checkpoint paths (``Path`` or ``str``).
        from_iter: Minimum iteration number to keep.
        every_k: Subsampling stride; must be >= 1.

    Returns:
        list[Path] in evaluation order.

    Raises:
        ValueError: If ``every_k`` < 1.
    """
    if every_k < 1:
        raise ValueError(f"every_k must be >= 1, got {every_k}")
    paths = [Path(p) for p in paths]

    numbered = []
    for p in paths:
        n = iter_number(p)
        if n is not None and n >= from_iter:
            numbered.append((n, p))
    numbered.sort()

    # Stride over the checkpoints actually under consideration, not over
    # ones already excluded by --from-iter or a malformed name.
    selected = [p for _, p in numbered[::every_k]]
    selected.extend(p for p in paths if p.stem == "final")
    return selected


def summarize_returns(returns):
    """Return ``(values, mean, std, min)`` for a sequence of episode returns.

    ``values`` is a list of plain floats, so it is JSON-serializable even if
    the input held NumPy scalars or was a NumPy array. For an empty input all
    three statistics are NaN, marking a mode that was not evaluated.
    """
    values = [float(r) for r in returns]
    if not values:
        nan = float("nan")
        return values, nan, nan, nan
    return (
        values,
        float(np.mean(values)),
        float(np.std(values)),
        float(np.min(values)),
    )


def safe_max(items, key):
    """Return the item with the largest ``key``, skipping NaN keys.

    Returns ``None`` when no item has a non-NaN key (e.g. a mode was skipped).
    """
    candidates = [r for r in items if not math.isnan(key(r))]
    return max(candidates, key=key) if candidates else None


def main(argv=None):
    """Evaluate every selected checkpoint, save a JSON summary, print the best.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.

    Raises:
        FileNotFoundError: If the run directory or its ``.pt`` files are missing.
        RuntimeError: If no checkpoint survives ``--from-iter`` / ``--every-k``.
    """
    args = parse_args(argv)
    project_root = Path(__file__).resolve().parent
    ckpt_dir = project_root / "models" / args.run_name
    out_dir = project_root / args.out_dir

    if not ckpt_dir.exists():
        raise FileNotFoundError(f"No such run dir: {ckpt_dir}")

    all_ckpts = sorted(ckpt_dir.glob("*.pt"))
    if not all_ckpts:
        raise FileNotFoundError(f"No .pt files in {ckpt_dir}")

    ckpts = select_checkpoints(all_ckpts, args.from_iter, args.every_k)
    if not ckpts:
        raise RuntimeError("No checkpoints survived filtering")

    # Project-local imports are deferred until validation has passed.
    from src.eval_utils import evaluate_agent
    from src.ppo_agent import PPOAgent

    # Create the output dir only once we know there is something to scan.
    out_dir.mkdir(parents=True, exist_ok=True)

    print("=" * 80)
    print(f"Scanning {len(ckpts)} checkpoints in {ckpt_dir}")
    print(f"Episodes per ckpt per mode: {args.episodes}")
    print(f"Seeds: {args.seed_start} to {args.seed_start + args.episodes - 1}")
    print("=" * 80)

    results = []
    do_sto = args.mode in ("both", "stochastic")
    do_det = args.mode in ("both", "deterministic")

    for ckpt in ckpts:
        agent = PPOAgent(n_actions=args.n_actions)
        agent.load(str(ckpt))
        agent.net.eval()

        # Both modes use the same seed range, so their numbers are directly
        # comparable. NOTE(review): assumes evaluate_agent returns a sequence
        # of per-episode returns — confirm against src.eval_utils.
        sto_returns = []
        det_returns = []
        if do_sto:
            sto_returns = evaluate_agent(
                agent,
                n_episodes=args.episodes,
                seed_start=args.seed_start,
                deterministic=False,
            )
        if do_det:
            det_returns = evaluate_agent(
                agent,
                n_episodes=args.episodes,
                seed_start=args.seed_start,
                deterministic=True,
            )

        sto_values, sto_mean, sto_std, sto_min = summarize_returns(sto_returns)
        det_values, det_mean, det_std, det_min = summarize_returns(det_returns)

        print(
            f"{ckpt.stem:>14s} | "
            f"sto: {sto_mean:7.1f} +/- {sto_std:6.1f} (min {sto_min:6.1f}) | "
            f"det: {det_mean:7.1f} +/- {det_std:6.1f} (min {det_min:6.1f})"
        )

        results.append({
            "ckpt": ckpt.name,
            "stochastic_mean": sto_mean,
            "stochastic_std": sto_std,
            "stochastic_min": sto_min,
            "stochastic_returns": sto_values,
            "deterministic_mean": det_mean,
            "deterministic_std": det_std,
            "deterministic_min": det_min,
            "deterministic_returns": det_values,
        })

    # Save scan summary. Statistics for a skipped mode are NaN; json.dump
    # writes them as the non-standard token NaN (Python's json.load accepts it).
    out_path = out_dir / f"checkpoint_scan_{args.run_name}.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved scan summary to {out_path}")

    # Print best by each criterion (NaN-safe: skipped modes are ignored).
    print("\n" + "=" * 80)
    print("BEST BY EACH CRITERION")
    print("=" * 80)
    best_sto_mean = safe_max(results, key=lambda r: r["stochastic_mean"])
    best_det_mean = safe_max(results, key=lambda r: r["deterministic_mean"])
    best_robust = safe_max(results, key=lambda r: r["stochastic_min"])

    if best_sto_mean is not None:
        print(f"Highest stochastic mean : {best_sto_mean['ckpt']:>14s} "
              f"({best_sto_mean['stochastic_mean']:.1f})")
    if best_det_mean is not None:
        print(f"Highest deterministic   : {best_det_mean['ckpt']:>14s} "
              f"({best_det_mean['deterministic_mean']:.1f})")
    if best_robust is not None:
        print(f"Most robust (high min)  : {best_robust['ckpt']:>14s} "
              f"(stochastic min {best_robust['stochastic_min']:.1f})")


if __name__ == "__main__":
    main()