feat: 添加模型评估脚本并更新实验报告
- 添加 evaluate_checkpoints.py 脚本,用于评估训练过程中的检查点模型
- 更新 generate_plots.py 以支持从真实评估结果生成图表
- 更新实验报告内容,包含具体实验结果数据和分析
- 添加中文支持并更新作者信息
- 生成评估结果JSON文件和相应图表
This commit is contained in:
@@ -0,0 +1,107 @@
|
||||
"""评估所有检查点并生成评估曲线数据"""
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import json
|
||||
|
||||
from src.network import QNetwork
|
||||
from src.utils import make_env, get_device
|
||||
|
||||
|
||||
def evaluate_model(model_path, env, device, num_actions, num_episodes=10,
                   state_shape=(4, 84, 84)):
    """Evaluate a single checkpoint with greedy (argmax-Q) rollouts.

    Args:
        model_path: Path to a checkpoint file containing a ``'q_network'``
            state dict.
        env: Gymnasium-style environment exposing ``reset()`` and the
            5-tuple ``step()`` API.
        device: Torch device to run inference on.
        num_actions: Size of the discrete action space.
        num_episodes: Number of evaluation episodes to average over.
        state_shape: Observation shape fed to ``QNetwork``; the default
            matches the 4-frame stack of 84x84 grayscale Atari frames.

    Returns:
        Tuple ``(mean, std)`` of per-episode total rewards (numpy floats).
    """
    q_network = QNetwork(state_shape, num_actions).to(device)

    # NOTE(review): weights_only=False lets torch.load execute arbitrary
    # pickle code — only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    q_network.load_state_dict(checkpoint['q_network'])
    q_network.eval()

    rewards = []
    for _ in range(num_episodes):
        state, _ = env.reset()
        episode_reward = 0
        done = False

        while not done:
            # Pure exploitation during evaluation: no epsilon-greedy noise.
            with torch.no_grad():
                state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(device)
                q_values = q_network(state_tensor)
                action = q_values.argmax(dim=1).item()

            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            episode_reward += reward

        rewards.append(episode_reward)

    return np.mean(rewards), np.std(rewards)
||||
def main():
    """Evaluate all saved checkpoints and write an evaluation curve JSON.

    Side effects: prints progress/summary to stdout and writes
    ``evaluation_results.json`` in the working directory.
    """
    # Inference device (CPU/GPU) chosen by the project helper.
    device = get_device()

    # Preprocessed Atari environment: grayscale, resized, 4-frame stack.
    env = make_env("ALE/SpaceInvaders-v5", gray_scale=True, resize=True, frame_stack=4)
    num_actions = env.action_space.n

    # Checkpoints were saved at 100k, 200k, then every 200k steps up to 2M.
    # Sentinel steps -1 / -2 tag the best and final models respectively.
    steps = [100000, 200000] + list(range(400000, 2000001, 200000))
    checkpoints = [(f"models/dqn_step_{s}.pt", s) for s in steps]
    checkpoints.append(("models/dqn_best.pt", -1))   # 最佳模型
    checkpoints.append(("models/dqn_final.pt", -2))  # 最终模型

    results = []

    # Ensure the ALE emulator is released even if an evaluation fails.
    try:
        for model_path, step in checkpoints:
            if not os.path.exists(model_path):
                print(f"跳过 {model_path}(不存在)")
                continue

            print(f"\n评估 {model_path}...")
            avg_reward, std_reward = evaluate_model(
                model_path, env, device, num_actions, num_episodes=10
            )

            results.append({
                "model": model_path,
                "step": step,
                "avg_reward": float(avg_reward),
                "std_reward": float(std_reward)
            })

            print(f"  平均回报: {avg_reward:.2f} ± {std_reward:.2f}")
    finally:
        env.close()

    # Persist results for the plotting script (generate_plots.py).
    output_file = "evaluation_results.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"\n结果已保存到 {output_file}")

    # Print a human-readable summary table.
    print("\n" + "=" * 60)
    print("评估汇总:")
    print("=" * 60)
    for r in results:
        step_str = f"Step {r['step']:>10,}" if r['step'] > 0 else f"{'Best' if r['step'] == -1 else 'Final':>15}"
        print(f"{step_str}: 平均回报 = {r['avg_reward']:.2f} ± {r['std_reward']:.2f}")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user