feat: optimize DQN training config and support Dueling network evaluation

- Adjust the learning rate to 5e-5, increase the batch size to 64, and expand the experience replay buffer to 500,000
- Enable prioritized experience replay and change the target network update frequency to every 1000 steps
- Use the Dueling network architecture for evaluation; set ε to 0 for evaluation mode during training
- Update the evaluation results to reflect the performance changes after these configuration improvements
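The diff below covers only the evaluation script, not the training config itself; as a minimal sketch, the hyperparameters described above might look like the following config block (the variable names are assumptions, only the values come from the commit message):

    # Hypothetical config sketch; names are assumed, values are from the commit message.
    LEARNING_RATE = 5e-5            # lowered learning rate
    BATCH_SIZE = 64                 # larger batches per update
    REPLAY_BUFFER_SIZE = 500_000    # expanded experience replay capacity
    PRIORITIZED_REPLAY = True       # prioritized experience replay enabled
    TARGET_UPDATE_FREQ = 1_000      # target network sync interval, in steps
    EVAL_EPSILON = 0.0              # greedy policy during in-training evaluation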
@@ -7,14 +7,17 @@ import torch
 import numpy as np
 import json
 
-from src.network import QNetwork
+from src.network import QNetwork, DuelingQNetwork
 from src.utils import make_env, get_device
 
 
-def evaluate_model(model_path, env, device, num_actions, num_episodes=10):
+def evaluate_model(model_path, env, device, num_actions, num_episodes=10, use_dueling=True):
     """Evaluate a single model"""
     state_shape = (4, 84, 84)
-    q_network = QNetwork(state_shape, num_actions).to(device)
+    if use_dueling:
+        q_network = DuelingQNetwork(state_shape, num_actions).to(device)
+    else:
+        q_network = QNetwork(state_shape, num_actions).to(device)
 
     checkpoint = torch.load(model_path, map_location=device, weights_only=False)
     q_network.load_state_dict(checkpoint['q_network'])
@@ -75,7 +78,7 @@ def main():
 
     print(f"\nEvaluating {model_path}...")
     avg_reward, std_reward = evaluate_model(
-        model_path, env, device, num_actions, num_episodes=10
+        model_path, env, device, num_actions, num_episodes=10, use_dueling=True
     )
 
     results.append({
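For context, here is a minimal sketch of a dueling architecture matching the DuelingQNetwork(state_shape, num_actions) constructor used above; only the class name and signature come from the diff, while the conv trunk and layer sizes are the standard Atari-DQN assumptions:

    import torch
    import torch.nn as nn

    class DuelingQNetwork(nn.Module):
        """Dueling head: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)."""

        def __init__(self, state_shape, num_actions):
            super().__init__()
            c, h, w = state_shape  # e.g. (4, 84, 84) stacked grayscale frames
            self.features = nn.Sequential(  # assumed Atari-style conv trunk
                nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(),
                nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
                nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
                nn.Flatten(),
            )
            with torch.no_grad():  # infer flattened feature size from a dummy pass
                n_flat = self.features(torch.zeros(1, c, h, w)).shape[1]
            # Separate value and advantage streams.
            self.value = nn.Sequential(nn.Linear(n_flat, 512), nn.ReLU(), nn.Linear(512, 1))
            self.advantage = nn.Sequential(nn.Linear(n_flat, 512), nn.ReLU(), nn.Linear(512, num_actions))

        def forward(self, x):
            feats = self.features(x)
            v = self.value(feats)       # state value V(s)
            a = self.advantage(feats)   # advantages A(s, a)
            # Subtract the mean advantage so V and A are identifiable.
            return v + a - a.mean(dim=1, keepdim=True)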
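And a sketch of the greedy (ε = 0) rollout that the commit message describes for evaluation mode, assuming a Gymnasium-style env API; the loop body is an illustration, not the repository's actual evaluate_model internals:

    # Hypothetical per-episode loop inside evaluate_model; with ε = 0 the
    # policy is purely greedy, i.e. the argmax over predicted Q-values.
    q_network.eval()
    state, _ = env.reset()
    done, episode_reward = False, 0.0
    while not done:
        with torch.no_grad():
            state_t = torch.as_tensor(np.asarray(state), dtype=torch.float32, device=device).unsqueeze(0)
            action = int(q_network(state_t).argmax(dim=1).item())
        state, reward, terminated, truncated, _ = env.step(action)
        episode_reward += reward
        done = terminated or truncated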