Files
Serendipity 79ffb90823 feat: 优化DQN训练配置并支持Dueling网络评估
- 将学习率调整为5e-5,批次大小增加到64,经验回放缓冲区扩大到500,000
- 启用优先经验回放,调整目标网络更新频率为1000步
- 评估时使用Dueling网络架构,训练时评估模式的ε设为0
- 更新评估结果以反映配置改进后的性能变化
2026-05-02 11:36:12 +08:00

111 lines
3.4 KiB
Python

"""评估所有检查点并生成评估曲线数据"""
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import torch
import numpy as np
import json
from src.network import QNetwork, DuelingQNetwork
from src.utils import make_env, get_device
def evaluate_model(model_path, env, device, num_actions, num_episodes=10, use_dueling=True,
                   state_shape=(4, 84, 84)):
    """Evaluate a single checkpoint with greedy (argmax-Q) rollouts.

    Args:
        model_path: Path to a checkpoint file whose dict contains a
            'q_network' state_dict entry.
        env: Gymnasium-style environment (``reset()``/``step()`` returning the
            5-tuple with ``terminated``/``truncated`` flags).
        device: torch device to run inference on.
        num_actions: Size of the environment's discrete action space.
        num_episodes: Number of evaluation episodes to average over.
        use_dueling: If True, build a DuelingQNetwork; otherwise a plain QNetwork.
        state_shape: Observation shape fed to the network — default matches the
            4-frame 84x84 Atari stack used elsewhere in this script.

    Returns:
        Tuple ``(mean, std)`` of the per-episode rewards (numpy floats).
    """
    net_cls = DuelingQNetwork if use_dueling else QNetwork
    q_network = net_cls(state_shape, num_actions).to(device)
    # SECURITY NOTE(review): weights_only=False lets torch.load execute
    # arbitrary pickled code — only load checkpoints you produced yourself.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    q_network.load_state_dict(checkpoint['q_network'])
    q_network.eval()

    rewards = []
    for _ in range(num_episodes):
        state, _ = env.reset()
        episode_reward = 0
        done = False
        while not done:
            # Greedy action selection; no_grad avoids building a graph
            # during the forward pass.
            with torch.no_grad():
                state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(device)
                q_values = q_network(state_tensor)
                action = q_values.argmax(dim=1).item()
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            episode_reward += reward
        rewards.append(episode_reward)
    return np.mean(rewards), np.std(rewards)
def main():
    """Evaluate every available checkpoint and write a JSON reward curve."""
    device = get_device()
    env = make_env("ALE/SpaceInvaders-v5", gray_scale=True, resize=True, frame_stack=4)
    num_actions = env.action_space.n

    # Periodic snapshots, plus the best/final models tagged with
    # sentinel steps -1 / -2 so the summary can label them specially.
    snapshot_steps = [100000, 200000, 400000, 600000, 800000, 1000000,
                      1200000, 1400000, 1600000, 1800000, 2000000]
    checkpoints = [(f"models/dqn_step_{s}.pt", s) for s in snapshot_steps]
    checkpoints.append(("models/dqn_best.pt", -1))
    checkpoints.append(("models/dqn_final.pt", -2))

    results = []
    for model_path, step in checkpoints:
        if not os.path.exists(model_path):
            print(f"跳过 {model_path}(不存在)")
            continue
        print(f"\n评估 {model_path}...")
        avg_reward, std_reward = evaluate_model(
            model_path, env, device, num_actions, num_episodes=10, use_dueling=True
        )
        results.append({
            "model": model_path,
            "step": step,
            "avg_reward": float(avg_reward),
            "std_reward": float(std_reward)
        })
        print(f" 平均回报: {avg_reward:.2f} ± {std_reward:.2f}")

    # Persist the curve data for downstream plotting.
    output_file = "evaluation_results.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"\n结果已保存到 {output_file}")

    # Human-readable summary table.
    separator = "=" * 60
    print("\n" + separator)
    print("评估汇总:")
    print(separator)
    for entry in results:
        if entry['step'] > 0:
            label = f"Step {entry['step']:>10,}"
        else:
            label = f"{'Best' if entry['step'] == -1 else 'Final':>15}"
        print(f"{label}: 平均回报 = {entry['avg_reward']:.2f} ± {entry['std_reward']:.2f}")
if __name__ == "__main__":
    main()