85f912483a
- Rename the train_step variable in agent.py to step_count to make its meaning clearer
- Update all related references, including the epsilon decay and target-network update logic
- Update the key names used when saving and loading the model accordingly
- Fix trailing-newline issues at the end of several source files
185 lines · 5.8 KiB · Python
"""DQN Agent implementation."""

import os

import numpy as np
import torch
import torch.nn.functional as F


class DQNAgent:
    """DQN agent.

    Implements an epsilon-greedy exploration policy and Q-learning updates.
    """

    def __init__(
        self,
        q_network,
        target_network,
        replay_buffer,
        device,
        num_actions=6,
        gamma=0.99,
        lr=1e-4,
        epsilon_start=1.0,
        epsilon_end=0.01,
        epsilon_decay_steps=1_000_000,
        target_update_freq=1000,
        batch_size=32,
        double_dqn=True,
    ):
"""
|
|
Args:
|
|
q_network: Q网络
|
|
target_network: 目标网络
|
|
replay_buffer: 经验回放缓冲区
|
|
device: 设备
|
|
num_actions: 动作数量
|
|
gamma: 折扣因子
|
|
lr: 学习率
|
|
epsilon_start: ε初始值
|
|
epsilon_end: ε最终值
|
|
epsilon_decay_steps: ε衰减步数
|
|
target_update_freq: 目标网络更新频率
|
|
batch_size: 批次大小
|
|
double_dqn: 是否使用Double DQN
|
|
"""
|
|
        self.q_network = q_network
        self.target_network = target_network
        self.replay_buffer = replay_buffer
        self.device = device
        self.num_actions = num_actions
        self.gamma = gamma
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq
        self.double_dqn = double_dqn

        # Epsilon-greedy parameters
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay_steps = epsilon_decay_steps
        self.epsilon = epsilon_start

        # Optimizer (updates the online network only)
        self.optimizer = torch.optim.Adam(q_network.parameters(), lr=lr)

        # Number of completed training steps
        self.step_count = 0

        # Training statistics
        self.loss_history = []
        self.q_value_history = []

    def select_action(self, state, evaluate=False):
        """Select an action.

        Args:
            state: Current state of shape (channels, height, width).
            evaluate: If True, use a small fixed epsilon rather than the
                training schedule.

        Returns:
            action: Index of the selected action.
        """
        if evaluate:
            # Evaluation mode: near-greedy with a small fixed epsilon
            # (a common Atari evaluation convention)
            epsilon = 0.01
        else:
            # Training mode: epsilon-greedy with the scheduled epsilon
            epsilon = self.epsilon

        if np.random.random() < epsilon:
            # Random exploration
            return np.random.randint(self.num_actions)
        else:
            # Greedy action w.r.t. the online Q-network
            with torch.no_grad():
                state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
                q_values = self.q_network(state_tensor)
                return q_values.argmax(dim=1).item()

    def update_epsilon(self):
        """Update epsilon (linear decay)."""
        if self.step_count < self.epsilon_decay_steps:
            self.epsilon = self.epsilon_start - (self.epsilon_start - self.epsilon_end) * \
                (self.step_count / self.epsilon_decay_steps)
        else:
            self.epsilon = self.epsilon_end
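
    # Worked example with the defaults (epsilon_start=1.0, epsilon_end=0.01,
    # epsilon_decay_steps=1_000_000): epsilon is 1.0 at step 0, about 0.505
    # at step 500_000, and is held at 0.01 from step 1_000_000 onward.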

    def train_step(self):
        """Run one training step.

        Returns:
            loss: The scalar loss, or None if the buffer is not full enough.
            avg_q: Mean Q-value over the sampled batch, or None.
        """
        # Wait until the buffer holds at least one batch of samples
        if len(self.replay_buffer) < self.batch_size:
            return None, None

        # Sample a batch of transitions
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)

        # Q-values predicted for the actions actually taken
        q_values = self.q_network(states)
        q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Target Q-values
        with torch.no_grad():
            if self.double_dqn:
                # Double DQN: select actions with the online network,
                # evaluate them with the target network
                next_actions = self.q_network(next_states).argmax(dim=1)
                next_q_values = self.target_network(next_states)
                next_q_values = next_q_values.gather(1, next_actions.unsqueeze(1)).squeeze(1)
            else:
                # Vanilla DQN: maximum Q-value from the target network
                next_q_values = self.target_network(next_states).max(dim=1)[0]

        # Compute the update target
        target_q_values = rewards + self.gamma * next_q_values * (1 - dones)
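        # The line above is the one-step Bellman backup
        #     y = r + gamma * Q_target(s', a*),
        # where a* is greedy w.r.t. the online network under Double DQN and
        # w.r.t. the target network otherwise; (1 - dones) zeroes the
        # bootstrap term on terminal transitions.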

        # TD loss. (Many DQN implementations use the Huber / smooth-L1 loss
        # here for robustness to outliers; this code uses plain MSE.)
        loss = F.mse_loss(q_values, target_q_values)

        # Backpropagation
        self.optimizer.zero_grad()
        loss.backward()
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 10)
        self.optimizer.step()

        # Periodically sync the target network with the online network
        self.step_count += 1
        if self.step_count % self.target_update_freq == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

        # Update epsilon
        self.update_epsilon()

        # Record statistics
        avg_q = q_values.mean().item()
        self.loss_history.append(loss.item())
        self.q_value_history.append(avg_q)

        return loss.item(), avg_q

    def save(self, path):
        """Save a checkpoint."""
        # `or "."` guards against paths with no directory component
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        torch.save({
            'q_network': self.q_network.state_dict(),
            'target_network': self.target_network.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'step_count': self.step_count,
            'epsilon': self.epsilon,
        }, path)
        print(f"Model saved to: {path}")

    def load(self, path):
        """Load a checkpoint."""
        checkpoint = torch.load(path, map_location=self.device)
        self.q_network.load_state_dict(checkpoint['q_network'])
        self.target_network.load_state_dict(checkpoint['target_network'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.step_count = checkpoint['step_count']
        self.epsilon = checkpoint['epsilon']
        print(f"Model loaded from: {path}")
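

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). `QNetwork` and `ReplayBuffer`
# are assumed to be defined elsewhere in this project; the names, module
# paths, and APIs below are assumptions for the sake of the example.
# ---------------------------------------------------------------------------
#
#   from model import QNetwork        # hypothetical module/class
#   from replay import ReplayBuffer   # hypothetical module/class
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   q_net = QNetwork(num_actions=6).to(device)
#   target_net = QNetwork(num_actions=6).to(device)
#   target_net.load_state_dict(q_net.state_dict())
#
#   agent = DQNAgent(q_net, target_net, ReplayBuffer(100_000), device)
#
#   # Per environment step:
#   action = agent.select_action(state)   # state: (C, H, W) float ndarray
#   # ... step the environment, push the transition into the buffer ...
#   loss, avg_q = agent.train_step()      # (None, None) until the buffer
#                                         # holds at least batch_size items
#
#   agent.save("checkpoints/dqn.pt")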