feat: 优化DQN训练配置并支持Dueling网络评估
- 将学习率调整为5e-5,批次大小增加到64,经验回放缓冲区扩大到500,000
- 启用优先经验回放,调整目标网络更新频率为1000步
- 评估时使用Dueling网络架构,训练时评估模式的ε设为0
- 更新评估结果以反映配置改进后的性能变化
This commit is contained in:
@@ -7,13 +7,16 @@ import torch
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from src.network import QNetwork
|
from src.network import QNetwork, DuelingQNetwork
|
||||||
from src.utils import make_env, get_device
|
from src.utils import make_env, get_device
|
||||||
|
|
||||||
|
|
||||||
def evaluate_model(model_path, env, device, num_actions, num_episodes=10):
|
def evaluate_model(model_path, env, device, num_actions, num_episodes=10, use_dueling=True):
|
||||||
"""评估单个模型"""
|
"""评估单个模型"""
|
||||||
state_shape = (4, 84, 84)
|
state_shape = (4, 84, 84)
|
||||||
|
if use_dueling:
|
||||||
|
q_network = DuelingQNetwork(state_shape, num_actions).to(device)
|
||||||
|
else:
|
||||||
q_network = QNetwork(state_shape, num_actions).to(device)
|
q_network = QNetwork(state_shape, num_actions).to(device)
|
||||||
|
|
||||||
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
|
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
|
||||||
@@ -75,7 +78,7 @@ def main():
|
|||||||
|
|
||||||
print(f"\n评估 {model_path}...")
|
print(f"\n评估 {model_path}...")
|
||||||
avg_reward, std_reward = evaluate_model(
|
avg_reward, std_reward = evaluate_model(
|
||||||
model_path, env, device, num_actions, num_episodes=10
|
model_path, env, device, num_actions, num_episodes=10, use_dueling=True
|
||||||
)
|
)
|
||||||
|
|
||||||
results.append({
|
results.append({
|
||||||
|
|||||||
@@ -2,79 +2,79 @@
|
|||||||
{
|
{
|
||||||
"model": "models/dqn_step_100000.pt",
|
"model": "models/dqn_step_100000.pt",
|
||||||
"step": 100000,
|
"step": 100000,
|
||||||
"avg_reward": 17.8,
|
"avg_reward": 20.9,
|
||||||
"std_reward": 5.2306787322488075
|
"std_reward": 11.235657524150511
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"model": "models/dqn_step_200000.pt",
|
"model": "models/dqn_step_200000.pt",
|
||||||
"step": 200000,
|
"step": 200000,
|
||||||
"avg_reward": 14.0,
|
"avg_reward": 23.05,
|
||||||
"std_reward": 6.603029607687671
|
"std_reward": 8.361967471833408
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"model": "models/dqn_step_400000.pt",
|
"model": "models/dqn_step_400000.pt",
|
||||||
"step": 400000,
|
"step": 400000,
|
||||||
"avg_reward": 16.4,
|
"avg_reward": 14.5,
|
||||||
"std_reward": 4.24735211631906
|
"std_reward": 9.418067742376884
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"model": "models/dqn_step_600000.pt",
|
"model": "models/dqn_step_600000.pt",
|
||||||
"step": 600000,
|
"step": 600000,
|
||||||
"avg_reward": 19.0,
|
"avg_reward": 22.0,
|
||||||
"std_reward": 4.123105625617661
|
"std_reward": 11.218288639538564
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"model": "models/dqn_step_800000.pt",
|
"model": "models/dqn_step_800000.pt",
|
||||||
"step": 800000,
|
"step": 800000,
|
||||||
"avg_reward": 13.2,
|
"avg_reward": 24.95,
|
||||||
"std_reward": 3.944616584663204
|
"std_reward": 11.617766566771772
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"model": "models/dqn_step_1000000.pt",
|
"model": "models/dqn_step_1000000.pt",
|
||||||
"step": 1000000,
|
"step": 1000000,
|
||||||
"avg_reward": 15.7,
|
"avg_reward": 32.65,
|
||||||
"std_reward": 4.960846701924985
|
"std_reward": 14.44134689009304
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"model": "models/dqn_step_1200000.pt",
|
"model": "models/dqn_step_1200000.pt",
|
||||||
"step": 1200000,
|
"step": 1200000,
|
||||||
"avg_reward": 18.4,
|
"avg_reward": 21.5,
|
||||||
"std_reward": 6.216108107168021
|
"std_reward": 12.188108959145385
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"model": "models/dqn_step_1400000.pt",
|
"model": "models/dqn_step_1400000.pt",
|
||||||
"step": 1400000,
|
"step": 1400000,
|
||||||
"avg_reward": 14.2,
|
"avg_reward": 16.15,
|
||||||
"std_reward": 4.467661580737736
|
"std_reward": 13.950000000000001
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"model": "models/dqn_step_1600000.pt",
|
"model": "models/dqn_step_1600000.pt",
|
||||||
"step": 1600000,
|
"step": 1600000,
|
||||||
"avg_reward": 17.8,
|
"avg_reward": 30.5,
|
||||||
"std_reward": 4.686149805543993
|
"std_reward": 15.55795616396961
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"model": "models/dqn_step_1800000.pt",
|
"model": "models/dqn_step_1800000.pt",
|
||||||
"step": 1800000,
|
"step": 1800000,
|
||||||
"avg_reward": 21.5,
|
"avg_reward": 34.25,
|
||||||
"std_reward": 4.984977432245807
|
"std_reward": 16.40464873138099
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"model": "models/dqn_step_2000000.pt",
|
"model": "models/dqn_step_2000000.pt",
|
||||||
"step": 2000000,
|
"step": 2000000,
|
||||||
"avg_reward": 14.6,
|
"avg_reward": 23.65,
|
||||||
"std_reward": 5.276362383309167
|
"std_reward": 14.120995007434852
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"model": "models/dqn_best.pt",
|
"model": "models/dqn_best.pt",
|
||||||
"step": -1,
|
"step": -1,
|
||||||
"avg_reward": 19.9,
|
"avg_reward": 16.6,
|
||||||
"std_reward": 6.920260110718383
|
"std_reward": 9.606768447297977
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"model": "models/dqn_final.pt",
|
"model": "models/dqn_final.pt",
|
||||||
"step": -2,
|
"step": -2,
|
||||||
"avg_reward": 11.3,
|
"avg_reward": 20.2,
|
||||||
"std_reward": 3.3778691508109073
|
"std_reward": 11.185258155268478
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -89,7 +89,7 @@ class DQNAgent:
|
|||||||
"""
|
"""
|
||||||
if evaluate:
|
if evaluate:
|
||||||
# 评估模式:纯贪心
|
# 评估模式:纯贪心
|
||||||
epsilon = 0.01
|
epsilon = 0.0
|
||||||
else:
|
else:
|
||||||
# 训练模式:ε-greedy
|
# 训练模式:ε-greedy
|
||||||
epsilon = self.epsilon
|
epsilon = self.epsilon
|
||||||
@@ -139,10 +139,13 @@ class DQNAgent:
|
|||||||
if len(self.replay_buffer) < self.batch_size:
|
if len(self.replay_buffer) < self.batch_size:
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
# 采样
|
# 采样(兼容标准和优先经验回放)
|
||||||
states, actions, rewards, next_states, dones = self.replay_buffer.sample(
|
sample_result = self.replay_buffer.sample(self.batch_size)
|
||||||
self.batch_size
|
if len(sample_result) == 7:
|
||||||
)
|
states, actions, rewards, next_states, dones, indices, weights = sample_result
|
||||||
|
else:
|
||||||
|
states, actions, rewards, next_states, dones = sample_result
|
||||||
|
indices, weights = None, None
|
||||||
|
|
||||||
# 计算当前Q值
|
# 计算当前Q值
|
||||||
q_values = self.q_network(states)
|
q_values = self.q_network(states)
|
||||||
@@ -151,29 +154,35 @@ class DQNAgent:
|
|||||||
# 计算目标Q值
|
# 计算目标Q值
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if self.double_dqn:
|
if self.double_dqn:
|
||||||
# Double DQN: 用Q网络选择动作,用目标网络评估
|
|
||||||
next_actions = self.q_network(next_states).argmax(dim=1)
|
next_actions = self.q_network(next_states).argmax(dim=1)
|
||||||
next_q_values = self.target_network(next_states)
|
next_q_values = self.target_network(next_states)
|
||||||
next_q_values = next_q_values.gather(
|
next_q_values = next_q_values.gather(
|
||||||
1, next_actions.unsqueeze(1)
|
1, next_actions.unsqueeze(1)
|
||||||
).squeeze(1)
|
).squeeze(1)
|
||||||
else:
|
else:
|
||||||
# 标准DQN: 直接用目标网络的最大Q值
|
|
||||||
next_q_values = self.target_network(next_states).max(dim=1)[0]
|
next_q_values = self.target_network(next_states).max(dim=1)[0]
|
||||||
|
|
||||||
# 计算目标
|
|
||||||
target_q_values = rewards + self.gamma * next_q_values * (1 - dones)
|
target_q_values = rewards + self.gamma * next_q_values * (1 - dones)
|
||||||
|
|
||||||
# 计算损失
|
# 计算TD误差
|
||||||
|
td_errors = (q_values - target_q_values).detach()
|
||||||
|
|
||||||
|
# 计算损失(优先经验回放使用重要性采样权重)
|
||||||
|
if weights is not None:
|
||||||
|
loss = (weights * F.mse_loss(q_values, target_q_values, reduction='none')).mean()
|
||||||
|
else:
|
||||||
loss = F.mse_loss(q_values, target_q_values)
|
loss = F.mse_loss(q_values, target_q_values)
|
||||||
|
|
||||||
# 反向传播
|
# 反向传播
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
# 梯度裁剪
|
|
||||||
torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 10)
|
torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 10)
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
|
|
||||||
|
# 更新优先级
|
||||||
|
if indices is not None and hasattr(self.replay_buffer, 'update_priorities'):
|
||||||
|
self.replay_buffer.update_priorities(indices, td_errors.cpu().numpy())
|
||||||
|
|
||||||
# 更新目标网络
|
# 更新目标网络
|
||||||
self.step_count += 1
|
self.step_count += 1
|
||||||
if self.step_count % self.target_update_freq == 0:
|
if self.step_count % self.target_update_freq == 0:
|
||||||
|
|||||||
@@ -26,11 +26,11 @@ def main():
|
|||||||
|
|
||||||
# 训练参数
|
# 训练参数
|
||||||
parser.add_argument("--steps", type=int, default=2_000_000, help="总训练步数")
|
parser.add_argument("--steps", type=int, default=2_000_000, help="总训练步数")
|
||||||
parser.add_argument("--lr", type=float, default=1e-4, help="学习率")
|
parser.add_argument("--lr", type=float, default=5e-5, help="学习率")
|
||||||
parser.add_argument("--gamma", type=float, default=0.99, help="折扣因子")
|
parser.add_argument("--gamma", type=float, default=0.99, help="折扣因子")
|
||||||
parser.add_argument("--batch-size", type=int, default=32, help="批次大小")
|
parser.add_argument("--batch-size", type=int, default=64, help="批次大小")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--buffer-size", type=int, default=200_000, help="经验回放缓冲区大小"
|
"--buffer-size", type=int, default=500_000, help="经验回放缓冲区大小"
|
||||||
)
|
)
|
||||||
|
|
||||||
# ε-greedy参数
|
# ε-greedy参数
|
||||||
@@ -42,7 +42,7 @@ def main():
|
|||||||
|
|
||||||
# 网络参数
|
# 网络参数
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--target-update", type=int, default=500, help="目标网络更新频率"
|
"--target-update", type=int, default=1000, help="目标网络更新频率"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--double-dqn", action="store_true", default=True, help="使用Double DQN"
|
"--double-dqn", action="store_true", default=True, help="使用Double DQN"
|
||||||
@@ -68,7 +68,7 @@ def main():
|
|||||||
|
|
||||||
# 优先经验回放
|
# 优先经验回放
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prioritized", action="store_true", default=False, help="使用优先经验回放"
|
"--prioritized", action="store_true", default=True, help="使用优先经验回放"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 其他
|
# 其他
|
||||||
|
|||||||
Reference in New Issue
Block a user