fb09e66d09
- 将原始单环境训练代码重构为模块化结构,添加向量化环境支持以提高数据采集效率 - 实现完整的PPO训练流水线,包括共享CNN的Actor-Critic网络、向量化经验回放缓冲和GAE优势估计 - 添加训练脚本(train_vec.py)、评估脚本(evaluate.py)和SB3基线对比脚本(train_sb3_baseline.py) - 提供详细的文档和开发日志,包含问题解决记录和实验分析 - 移除旧版项目文件,统一项目结构到CW1_id_name目录下
164 lines
5.7 KiB
Python
164 lines
5.7 KiB
Python
"""Scan all checkpoints in a run directory and evaluate each one.
|
|
|
|
Usage:
|
|
python scan_checkpoints.py --run-name vec_main --episodes 10
|
|
|
|
For each .pt file, evaluates with both stochastic and deterministic
|
|
policies and prints a comparison table. Helps identify the best
|
|
checkpoint to submit when the final one over-fits / over-anneals.
|
|
"""
|
|
|
|
import argparse
import json
import math
from pathlib import Path

import numpy as np

from src.eval_utils import evaluate_agent
from src.ppo_agent import PPOAgent
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for the checkpoint scan.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``,
            which makes argparse read ``sys.argv[1:]`` — so existing
            callers are unaffected, while tests can pass an explicit list.

    Returns:
        argparse.Namespace with the parsed options.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--run-name", type=str, default="vec_main")
    p.add_argument("--episodes", type=int, default=10,
                   help="Episodes per checkpoint per mode (10 is enough for ranking)")
    p.add_argument("--seed-start", type=int, default=2000,
                   help="Use seeds different from final-evaluation 1000-1019")
    p.add_argument("--out-dir", type=str, default="docs")
    p.add_argument("--from-iter", type=int, default=0,
                   help="Skip checkpoints whose iter number is below this")
    p.add_argument("--every-k", type=int, default=1,
                   help="Subsample: only evaluate every k-th checkpoint")
    p.add_argument("--mode", type=str, default="both",
                   choices=["both", "stochastic", "deterministic"])
    return p.parse_args(argv)
|
|
|
|
|
|
def _iter_num(path):
    """Integer iteration number of an ``iter_<N>.pt`` path, or None if malformed."""
    stem = path.stem
    if not stem.startswith("iter_"):
        return None
    try:
        return int(stem[len("iter_"):])
    except ValueError:
        return None


def _stats(returns):
    """Return (mean, std, min) of a list of returns; NaNs for an empty list."""
    if not returns:
        nan = float("nan")
        return nan, nan, nan
    return float(np.mean(returns)), float(np.std(returns)), float(np.min(returns))


def _collect_checkpoints(ckpt_dir, from_iter, every_k):
    """Gather .pt checkpoints sorted numerically by iteration, final.pt last.

    Applies the --from-iter threshold and --every-k subsampling to the
    ``iter_*`` checkpoints; ``final.pt`` (if present) is always kept and
    appended at the end.

    Raises:
        FileNotFoundError: if the directory holds no .pt files.
        RuntimeError: if filtering removes every checkpoint.
    """
    all_ckpts = list(ckpt_dir.glob("*.pt"))
    if not all_ckpts:
        raise FileNotFoundError(f"No .pt files in {ckpt_dir}")

    # Sort numerically: lexicographic stem order would put iter_100 before
    # iter_20, which both scrambles the printed table and corrupts the
    # --every-k subsampling index below.
    iter_pts = sorted(
        (p for p in all_ckpts if _iter_num(p) is not None),
        key=_iter_num,
    )
    final_pt = [p for p in all_ckpts if p.stem == "final"]

    ckpts = []
    for i, p in enumerate(iter_pts):
        if _iter_num(p) < from_iter:
            continue
        if (i % every_k) != 0:
            continue
        ckpts.append(p)
    ckpts.extend(final_pt)
    if not ckpts:
        raise RuntimeError("No checkpoints survived filtering")
    return ckpts


def main():
    """Evaluate every checkpoint of a run and report the best by each criterion."""
    args = parse_args()
    project_root = Path(__file__).resolve().parent
    ckpt_dir = project_root / "models" / args.run_name
    out_dir = project_root / args.out_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    if not ckpt_dir.exists():
        raise FileNotFoundError(f"No such run dir: {ckpt_dir}")

    ckpts = _collect_checkpoints(ckpt_dir, args.from_iter, args.every_k)

    print("=" * 80)
    print(f"Scanning {len(ckpts)} checkpoints in {ckpt_dir}")
    print(f"Episodes per ckpt per mode: {args.episodes}")
    print(f"Seeds: {args.seed_start} to {args.seed_start + args.episodes - 1}")
    print("=" * 80)

    results = []

    do_sto = args.mode in ("both", "stochastic")
    do_det = args.mode in ("both", "deterministic")

    for ckpt in ckpts:
        # Fresh agent per checkpoint so no weights leak between evaluations.
        agent = PPOAgent(n_actions=5)
        agent.load(str(ckpt))
        agent.net.eval()

        sto_returns = []
        det_returns = []
        if do_sto:
            sto_returns = evaluate_agent(
                agent, n_episodes=args.episodes,
                seed_start=args.seed_start, deterministic=False,
            )
        if do_det:
            det_returns = evaluate_agent(
                agent, n_episodes=args.episodes,
                seed_start=args.seed_start, deterministic=True,
            )

        sto_mean, sto_std, sto_min = _stats(sto_returns)
        det_mean, det_std, det_min = _stats(det_returns)

        print(
            f"{ckpt.stem:>14s} | "
            f"sto: {sto_mean:7.1f} +/- {sto_std:6.1f} (min {sto_min:6.1f}) | "
            f"det: {det_mean:7.1f} +/- {det_std:6.1f} (min {det_min:6.1f})"
        )

        results.append({
            "ckpt": ckpt.name,
            "stochastic_mean": sto_mean,
            "stochastic_std": sto_std,
            "stochastic_min": sto_min,
            "stochastic_returns": sto_returns,
            "deterministic_mean": det_mean,
            "deterministic_std": det_std,
            "deterministic_min": det_min,
            "deterministic_returns": det_returns,
        })

    # Save scan summary
    out_path = out_dir / f"checkpoint_scan_{args.run_name}.json"
    with open(out_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved scan summary to {out_path}")

    def safe_max(items, key):
        # NaN-aware max: drop entries whose criterion is NaN, since max()
        # over NaN keys gives an arbitrary winner.
        items = [r for r in items if not math.isnan(key(r))]
        return max(items, key=key) if items else None

    print("\n" + "=" * 80)
    print("BEST BY EACH CRITERION")
    print("=" * 80)
    best_sto_mean = safe_max(results, key=lambda r: r["stochastic_mean"])
    best_det_mean = safe_max(results, key=lambda r: r["deterministic_mean"])
    best_robust = safe_max(results, key=lambda r: r["stochastic_min"])

    if best_sto_mean:
        print(f"Highest stochastic mean : {best_sto_mean['ckpt']:>14s} "
              f"({best_sto_mean['stochastic_mean']:.1f})")
    if best_det_mean:
        print(f"Highest deterministic : {best_det_mean['ckpt']:>14s} "
              f"({best_det_mean['deterministic_mean']:.1f})")
    if best_robust:
        print(f"Most robust (high min) : {best_robust['ckpt']:>14s} "
              f"(stochastic min {best_robust['stochastic_min']:.1f})")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|