# Commit 79ffb90823:
# - Adjusted learning rate to 5e-5, increased batch size to 64, expanded the
#   experience replay buffer to 500,000.
# - Enabled prioritized experience replay; target network update frequency set
#   to 1000 steps.
# - Evaluation uses the Dueling network architecture; evaluation-mode epsilon
#   is set to 0.
# - Updated evaluation results to reflect the performance change after the
#   configuration improvements.
"""评估所有检查点并生成评估曲线数据"""
|
|
import sys
|
|
import os
|
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
import torch
|
|
import numpy as np
|
|
import json
|
|
|
|
from src.network import QNetwork, DuelingQNetwork
|
|
from src.utils import make_env, get_device
|
|
|
|
|
|
def evaluate_model(model_path, env, device, num_actions, num_episodes=10, use_dueling=True,
                   state_shape=(4, 84, 84)):
    """Evaluate a single checkpoint by running greedy (epsilon = 0) episodes.

    Args:
        model_path: Path to a checkpoint file containing a 'q_network' state dict.
        env: Gymnasium-style environment exposing ``reset()`` / ``step()``.
        device: torch device used for inference.
        num_actions: Size of the discrete action space.
        num_episodes: Number of evaluation episodes to average over.
        use_dueling: If True, build a DuelingQNetwork; otherwise a plain QNetwork.
        state_shape: Observation shape fed to the network. Defaults to 4 stacked
            84x84 grayscale frames (previously hard-coded; now parameterized so
            other preprocessing configurations can be evaluated too).

    Returns:
        Tuple ``(mean_reward, std_reward)`` (numpy scalars) over the episodes.
    """
    net_cls = DuelingQNetwork if use_dueling else QNetwork
    q_network = net_cls(state_shape, num_actions).to(device)

    # SECURITY NOTE: weights_only=False unpickles arbitrary objects and can
    # execute code on load — only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    q_network.load_state_dict(checkpoint['q_network'])
    q_network.eval()

    rewards = []
    for _ in range(num_episodes):
        state, _ = env.reset()
        episode_reward = 0
        done = False

        while not done:
            # Greedy action selection: no exploration during evaluation.
            with torch.no_grad():
                state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(device)
                q_values = q_network(state_tensor)
                action = q_values.argmax(dim=1).item()

            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            episode_reward += reward

        rewards.append(episode_reward)

    return np.mean(rewards), np.std(rewards)
|
def main():
    """Evaluate all saved DQN checkpoints and write the results to JSON.

    Skips checkpoints that do not exist on disk, prints per-checkpoint
    mean/std episode returns, saves everything to ``evaluation_results.json``,
    and finally prints a summary table.
    """
    # Select the compute device (CPU/GPU) via the project helper.
    device = get_device()

    # Create the evaluation environment (grayscale, resized, 4-frame stack).
    env = make_env("ALE/SpaceInvaders-v5", gray_scale=True, resize=True, frame_stack=4)
    num_actions = env.action_space.n

    # (path, training step) pairs; negative steps tag special checkpoints.
    checkpoints = [
        ("models/dqn_step_100000.pt", 100000),
        ("models/dqn_step_200000.pt", 200000),
        ("models/dqn_step_400000.pt", 400000),
        ("models/dqn_step_600000.pt", 600000),
        ("models/dqn_step_800000.pt", 800000),
        ("models/dqn_step_1000000.pt", 1000000),
        ("models/dqn_step_1200000.pt", 1200000),
        ("models/dqn_step_1400000.pt", 1400000),
        ("models/dqn_step_1600000.pt", 1600000),
        ("models/dqn_step_1800000.pt", 1800000),
        ("models/dqn_step_2000000.pt", 2000000),
        ("models/dqn_best.pt", -1),   # best model
        ("models/dqn_final.pt", -2),  # final model
    ]

    results = []

    # FIX: ensure the environment (and its ALE resources) is released even if
    # an evaluation raises — the original leaked it by never calling close().
    try:
        for model_path, step in checkpoints:
            if not os.path.exists(model_path):
                print(f"跳过 {model_path}(不存在)")
                continue

            print(f"\n评估 {model_path}...")
            avg_reward, std_reward = evaluate_model(
                model_path, env, device, num_actions, num_episodes=10, use_dueling=True
            )

            results.append({
                "model": model_path,
                "step": step,
                "avg_reward": float(avg_reward),
                "std_reward": float(std_reward)
            })

            print(f" 平均回报: {avg_reward:.2f} ± {std_reward:.2f}")
    finally:
        env.close()

    # Persist the evaluation curve for later plotting.
    output_file = "evaluation_results.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"\n结果已保存到 {output_file}")

    # Print a human-readable summary table.
    print("\n" + "=" * 60)
    print("评估汇总:")
    print("=" * 60)
    for r in results:
        # Negative sentinel steps map to the named checkpoints.
        step_str = f"Step {r['step']:>10,}" if r['step'] > 0 else f"{'Best' if r['step'] == -1 else 'Final':>15}"
        print(f"{step_str}: 平均回报 = {r['avg_reward']:.2f} ± {r['std_reward']:.2f}")
if __name__ == "__main__":
|
|
main()
|