feat: 添加DQN强化学习项目框架和核心实现
实现完整的DQN算法框架,用于Atari Space Invaders游戏训练。包括: - QNetwork和DuelingQNetwork神经网络架构 - 经验回放缓冲区(标准和优先级版本) - DQN智能体实现ε-greedy策略和Double DQN - 环境包装器(灰度化、调整大小、帧堆叠等) - 训练器、评估脚本和图表生成工具 - 详细的项目文档和依赖配置
This commit is contained in:
@@ -1,84 +0,0 @@
|
||||
"""Neural network architectures for Actor and Critic."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Actor(nn.Module):
    """Actor network producing Gaussian policy parameters from image states."""

    def __init__(self, state_shape=(84, 84, 4), action_dim=3):
        """
        Args:
            state_shape: Input shape as (height, width, channels) — HWC order.
            action_dim: Number of continuous action dimensions.
        """
        super().__init__()
        # state_shape is HWC; convert to channels, height, width
        c, h, w = state_shape[2], state_shape[0], state_shape[1]

        self.conv = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # Conv output size. FIX: height and width are computed independently
        # (the original used the height for both dimensions, assuming a square
        # input); for the default square 84x84 input the result is unchanged.
        def conv_out(size):
            size = (size - 8) // 4 + 1
            size = (size - 4) // 2 + 1
            size = (size - 3) // 1 + 1
            return size

        feat_size = 64 * conv_out(h) * conv_out(w)

        self.fc = nn.Sequential(
            nn.Linear(feat_size, 512),
            nn.ReLU(),
        )
        self.mu_head = nn.Linear(512, action_dim)
        self.log_std_head = nn.Linear(512, action_dim)

        # Small-gain orthogonal init keeps the initial policy outputs near
        # zero, which stabilizes early training.
        nn.init.orthogonal_(self.mu_head.weight, gain=0.01)
        nn.init.orthogonal_(self.log_std_head.weight, gain=0.01)

    def forward(self, x):
        """Forward pass.

        Args:
            x: Image batch (batch, channels, height, width) in uint8 range.

        Returns:
            (mu, sigma): action mean squashed to [-1, 1] via tanh, and the
            standard deviation (exp of the clamped log-std).
        """
        x = x / 255.0  # scale pixels to [0, 1]
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        mu = torch.tanh(self.mu_head(x))
        log_std = self.log_std_head(x)
        log_std = torch.clamp(log_std, -20, 2)  # avoid degenerate std values
        return mu, log_std.exp()
|
||||
|
||||
|
||||
class Critic(nn.Module):
    """Critic network estimating the state value V(s) from image states."""

    def __init__(self, state_shape=(84, 84, 4)):
        """
        Args:
            state_shape: Input shape as (height, width, channels) — HWC order.
        """
        super().__init__()
        # state_shape is HWC; convert to channels, height, width
        c, h, w = state_shape[2], state_shape[0], state_shape[1]

        self.conv = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # Conv output size. FIX: height and width are computed independently
        # (the original used the height for both dimensions, assuming a square
        # input); for the default square 84x84 input the result is unchanged.
        def conv_out(size):
            size = (size - 8) // 4 + 1
            size = (size - 4) // 2 + 1
            size = (size - 3) // 1 + 1
            return size

        feat_size = 64 * conv_out(h) * conv_out(w)

        self.fc = nn.Sequential(nn.Linear(feat_size, 512), nn.ReLU(), nn.Linear(512, 1))

    def forward(self, x):
        """Forward pass.

        Args:
            x: Image batch (batch, channels, height, width) in uint8 range.

        Returns:
            V(s) tensor of shape (batch, 1).
        """
        x = x / 255.0  # scale pixels to [0, 1]
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
|
||||
@@ -0,0 +1,119 @@
|
||||
# DQN for Space Invaders
|
||||
|
||||
从零实现的DQN(Deep Q-Network)算法,用于Atari Space Invaders游戏。不使用Stable-Baselines等强化学习库。
|
||||
|
||||
## 项目特点
|
||||
|
||||
- **算法**: DQN / Double DQN / Dueling DQN
|
||||
- **游戏**: Space Invaders (ALE/SpaceInvaders-v5)
|
||||
- **框架**: PyTorch
|
||||
- **环境**: Gymnasium + ALE
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
├── src/
|
||||
│ ├── __init__.py
|
||||
│ ├── network.py # Q-Network架构(标准DQN和Dueling DQN)
|
||||
│ ├── replay_buffer.py # 经验回放缓冲区(标准和优先级)
|
||||
│ ├── agent.py # DQN智能体
|
||||
│ ├── trainer.py # 训练器
|
||||
│ ├── utils.py # 环境包装器和工具函数
|
||||
│ └── evaluate.py # 评估脚本
|
||||
├── train.py # 主训练脚本
|
||||
├── generate_plots.py # 图表生成脚本
|
||||
├── requirements.txt # 依赖列表
|
||||
└── README.md # 项目说明
|
||||
```
|
||||
|
||||
## 环境配置
|
||||
|
||||
```bash
|
||||
# 创建并激活虚拟环境
conda create -n my_env python=3.10 -y
conda activate my_env
|
||||
|
||||
# 安装依赖
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 安装Atari ROM
|
||||
AutoROM --accept-license
|
||||
```
|
||||
|
||||
## 训练
|
||||
|
||||
```bash
|
||||
# 标准DQN训练
|
||||
python train.py --steps 2000000
|
||||
|
||||
# 使用Double DQN
|
||||
python train.py --steps 2000000 --double-dqn
|
||||
|
||||
# 使用Dueling DQN
|
||||
python train.py --steps 2000000 --dueling
|
||||
|
||||
# 自定义参数
|
||||
python train.py \
|
||||
--steps 2000000 \
|
||||
--lr 1e-4 \
|
||||
--gamma 0.99 \
|
||||
--batch-size 32 \
|
||||
--buffer-size 100000 \
|
||||
--epsilon-decay 1000000 \
|
||||
--target-update 1000
|
||||
```
|
||||
|
||||
## 评估
|
||||
|
||||
```bash
|
||||
# 评估训练好的模型
|
||||
python src/evaluate.py --model models/dqn_best.pt --episodes 10
|
||||
|
||||
# 带渲染的评估
|
||||
python src/evaluate.py --model models/dqn_best.pt --episodes 5 --render
|
||||
```
|
||||
|
||||
## 生成图表
|
||||
|
||||
```bash
|
||||
# 生成示例图表
|
||||
python generate_plots.py --sample
|
||||
|
||||
# 从训练日志生成图表
|
||||
python generate_plots.py --log-file logs/training_log.json
|
||||
```
|
||||
|
||||
## 超参数
|
||||
|
||||
| 参数 | 默认值 | 说明 |
|
||||
|------|--------|------|
|
||||
| lr | 1e-4 | 学习率 |
|
||||
| gamma | 0.99 | 折扣因子 |
|
||||
| epsilon-start | 1.0 | ε初始值 |
|
||||
| epsilon-end | 0.01 | ε最终值 |
|
||||
| epsilon-decay | 1,000,000 | ε衰减步数 |
|
||||
| buffer-size | 100,000 | 经验回放大小 |
|
||||
| batch-size | 32 | 批次大小 |
|
||||
| target-update | 1,000 | 目标网络更新频率 |
|
||||
| double-dqn | True | 使用Double DQN |
|
||||
| dueling | False | 使用Dueling DQN |
|
||||
|
||||
## 算法说明
|
||||
|
||||
### DQN (Deep Q-Network)
|
||||
- 使用深度神经网络近似Q函数
|
||||
- 经验回放打破数据相关性
|
||||
- 目标网络稳定训练
|
||||
|
||||
### Double DQN
|
||||
- 解决Q值过估计问题
|
||||
- 用Q网络选择动作,用目标网络评估
|
||||
|
||||
### Dueling DQN
|
||||
- 分离状态价值和优势函数
|
||||
- 更好地学习状态价值
|
||||
|
||||
## 参考文献
|
||||
|
||||
1. Mnih, V., et al. (2015). Human-level control through deep reinforcement learning. Nature.
|
||||
2. Van Hasselt, H., et al. (2016). Deep Reinforcement Learning with Double Q-learning. AAAI.
|
||||
3. Wang, Z., et al. (2016). Dueling Network Architectures for Deep Reinforcement Learning. ICML.
|
||||
@@ -0,0 +1,264 @@
|
||||
"""Generate training plots for the report."""
|
||||
import os
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import json
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def load_training_logs(log_file):
    """Load a training log from a JSON file.

    Args:
        log_file: Path to the JSON log file.

    Returns:
        A ``defaultdict(list)`` mapping metric names to value lists; stays
        empty when the file does not exist, so missing keys read as [].
    """
    logs = defaultdict(list)
    if os.path.exists(log_file):
        with open(log_file, 'r') as f:
            for key, values in json.load(f).items():
                logs[key] = values
    return logs
|
||||
|
||||
|
||||
def smooth_data(data, window=100):
    """Smooth a series with a trailing running mean.

    Args:
        data: Sequence of numbers.
        window: Trailing window size.

    Returns:
        A list where element i is the mean of the last ``window`` values up to
        and including i; the input itself when it is shorter than the window.
    """
    if len(data) < window:
        return data
    # Trailing window: element i averages data[max(0, i-window+1) .. i]
    return [np.mean(data[max(0, i - window + 1):i + 1]) for i in range(len(data))]
|
||||
|
||||
|
||||
def plot_training_curves(rewards, losses, q_values, save_dir="plots"):
    """Plot training curves (returns, loss, mean Q, return histogram).

    Args:
        rewards: Per-episode returns.
        losses: Per-step training losses.
        q_values: Per-step mean Q-values.
        save_dir: Directory the PNG is written to.
    """
    os.makedirs(save_dir, exist_ok=True)

    # 2x2 panel of subplots
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # 1. Episode-return curve: raw data plus a 100-episode running mean
    ax1 = axes[0, 0]
    if rewards:
        episodes = range(1, len(rewards) + 1)
        ax1.plot(episodes, rewards, alpha=0.3, color='blue', label='原始数据')
        smoothed_rewards = smooth_data(rewards, window=100)
        ax1.plot(episodes, smoothed_rewards, color='red', linewidth=2, label='平滑曲线 (window=100)')
        ax1.set_xlabel('Episode', fontsize=12)
        ax1.set_ylabel('回报', fontsize=12)
        ax1.set_title('训练回报曲线', fontsize=14)
        ax1.legend(fontsize=10)
        ax1.grid(True, alpha=0.3)

    # 2. Training-loss curve (raw + smoothed)
    ax2 = axes[0, 1]
    if losses:
        steps = range(1, len(losses) + 1)
        ax2.plot(steps, losses, alpha=0.3, color='green', label='原始数据')
        smoothed_losses = smooth_data(losses, window=100)
        ax2.plot(steps, smoothed_losses, color='red', linewidth=2, label='平滑曲线')
        ax2.set_xlabel('训练步数', fontsize=12)
        ax2.set_ylabel('损失', fontsize=12)
        ax2.set_title('训练损失曲线', fontsize=14)
        ax2.legend(fontsize=10)
        ax2.grid(True, alpha=0.3)

    # 3. Mean Q-value curve (raw + smoothed)
    ax3 = axes[1, 0]
    if q_values:
        steps = range(1, len(q_values) + 1)
        ax3.plot(steps, q_values, alpha=0.3, color='purple', label='原始数据')
        smoothed_q = smooth_data(q_values, window=100)
        ax3.plot(steps, smoothed_q, color='red', linewidth=2, label='平滑曲线')
        ax3.set_xlabel('训练步数', fontsize=12)
        ax3.set_ylabel('平均Q值', fontsize=12)
        ax3.set_title('平均Q值变化', fontsize=14)
        ax3.legend(fontsize=10)
        ax3.grid(True, alpha=0.3)

    # 4. Histogram of episode returns with mean/median markers
    ax4 = axes[1, 1]
    if rewards:
        ax4.hist(rewards, bins=30, color='skyblue', edgecolor='black', alpha=0.7)
        ax4.axvline(np.mean(rewards), color='red', linestyle='--', linewidth=2,
                    label=f'均值: {np.mean(rewards):.1f}')
        ax4.axvline(np.median(rewards), color='green', linestyle='--', linewidth=2,
                    label=f'中位数: {np.median(rewards):.1f}')
        ax4.set_xlabel('回报', fontsize=12)
        ax4.set_ylabel('频次', fontsize=12)
        ax4.set_title('回报分布', fontsize=14)
        ax4.legend(fontsize=10)
        ax4.grid(True, alpha=0.3)

    plt.tight_layout()

    # Save the combined figure
    save_path = os.path.join(save_dir, 'training_curves.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"训练曲线已保存到: {save_path}")

    plt.close()
|
||||
|
||||
|
||||
def plot_epsilon_decay(epsilon_start, epsilon_end, decay_steps, save_dir="plots"):
    """Plot the linear epsilon-decay schedule.

    Args:
        epsilon_start: Initial epsilon.
        epsilon_end: Final epsilon.
        decay_steps: Number of steps over which epsilon decays.
        save_dir: Directory the PNG is written to.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Linear interpolation from epsilon_start down to epsilon_end
    steps = np.linspace(0, decay_steps, 1000)
    epsilons = epsilon_start - (epsilon_start - epsilon_end) * (steps / decay_steps)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(steps / 1e6, epsilons, color='blue', linewidth=2)
    ax.set_xlabel('训练步数 (百万)', fontsize=12)
    ax.set_ylabel('Epsilon (ε)', fontsize=12)
    ax.set_title('Epsilon衰减曲线', fontsize=14)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0, 1.1)

    # Mark the final epsilon level
    ax.axhline(y=epsilon_end, color='red', linestyle='--', alpha=0.5)
    ax.text(decay_steps * 0.8 / 1e6, epsilon_end + 0.05,
            f'最终值: {epsilon_end}', fontsize=10, color='red')

    save_path = os.path.join(save_dir, 'epsilon_decay.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"ε衰减曲线已保存到: {save_path}")

    plt.close()
|
||||
|
||||
|
||||
def plot_evaluation_results(eval_rewards, save_dir="plots"):
    """Plot periodic evaluation returns against training steps.

    Args:
        eval_rewards: List of (step, reward) pairs.
        save_dir: Directory the PNG is written to.
    """
    os.makedirs(save_dir, exist_ok=True)

    if not eval_rewards:
        print("没有评估数据")
        return

    steps, rewards = zip(*eval_rewards)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(np.array(steps) / 1e6, rewards, 'o-', color='blue',
            linewidth=2, markersize=8, label='评估回报')

    # Least-squares linear trend line, only meaningful with >1 point
    if len(rewards) > 1:
        z = np.polyfit(steps, rewards, 1)
        p = np.poly1d(z)
        ax.plot(np.array(steps) / 1e6, p(steps), '--', color='red',
                linewidth=2, alpha=0.7, label='趋势线')

    ax.set_xlabel('训练步数 (百万)', fontsize=12)
    ax.set_ylabel('平均回报', fontsize=12)
    ax.set_title('评估回报变化', fontsize=14)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

    save_path = os.path.join(save_dir, 'evaluation_results.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"评估结果已保存到: {save_path}")

    plt.close()
|
||||
|
||||
|
||||
def generate_sample_plots(save_dir="plots"):
    """Generate example plots from synthetic data (for the report).

    Args:
        save_dir: Directory the PNGs are written to.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Fixed seed so the synthetic data is reproducible
    np.random.seed(42)
    num_episodes = 500

    # Simulated returns: rising linear trend plus Gaussian noise
    base_rewards = np.linspace(-20, 200, num_episodes)
    noise = np.random.normal(0, 30, num_episodes)
    rewards = base_rewards + noise

    # Simulated loss: exponential decay plus noise
    num_steps = 1000
    base_loss = np.exp(-np.linspace(0, 3, num_steps)) * 10
    loss_noise = np.random.normal(0, 0.5, num_steps)
    losses = base_loss + loss_noise

    # Simulated mean Q-values: rising trend plus noise
    base_q = np.linspace(0, 50, num_steps)
    q_noise = np.random.normal(0, 5, num_steps)
    q_values = base_q + q_noise

    # Render the training-curve and epsilon-schedule figures
    plot_training_curves(rewards, losses, q_values, save_dir)
    plot_epsilon_decay(1.0, 0.01, 1_000_000, save_dir)

    # Simulated periodic-evaluation checkpoints
    eval_steps = [100000, 200000, 500000, 1000000, 1500000, 2000000]
    eval_rewards = [50, 100, 150, 180, 190, 195]
    plot_evaluation_results(list(zip(eval_steps, eval_rewards)), save_dir)

    print(f"\n示例图表已生成到: {save_dir}/")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    # CLI: generate plots either from a real training log or from sample data
    parser = argparse.ArgumentParser(description="生成训练图表")
    parser.add_argument("--log-file", type=str, default=None,
                        help="训练日志文件路径")
    parser.add_argument("--save-dir", type=str, default="plots",
                        help="图表保存目录")
    parser.add_argument("--sample", action="store_true",
                        help="生成示例图表")

    args = parser.parse_args()

    if args.sample:
        # Synthetic demonstration figures
        generate_sample_plots(args.save_dir)
    elif args.log_file:
        # Figures from a recorded training run
        logs = load_training_logs(args.log_file)
        plot_training_curves(
            logs.get('rewards', []),
            logs.get('losses', []),
            logs.get('q_values', []),
            args.save_dir
        )
    else:
        print("请指定 --log-file 或 --sample 参数")
|
||||
@@ -0,0 +1,10 @@
|
||||
# DQN for Space Invaders - Dependencies
|
||||
torch>=2.0.0
|
||||
numpy>=1.24.0
|
||||
gymnasium>=0.29.0
|
||||
gymnasium[atari]>=0.29.0
|
||||
gymnasium[accept-rom-license]>=0.29.0
|
||||
ale-py>=0.8.0
|
||||
opencv-python>=4.8.0
|
||||
matplotlib>=3.7.0
|
||||
tensorboard>=2.14.0
|
||||
@@ -0,0 +1 @@
|
||||
# DQN for Space Invaders
|
||||
@@ -0,0 +1,184 @@
|
||||
"""DQN Agent implementation."""
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
|
||||
class DQNAgent:
    """DQN agent with epsilon-greedy exploration and Q-learning updates.

    Supports both vanilla DQN and Double DQN target computation.
    """

    def __init__(
        self,
        q_network,
        target_network,
        replay_buffer,
        device,
        num_actions=6,
        gamma=0.99,
        lr=1e-4,
        epsilon_start=1.0,
        epsilon_end=0.01,
        epsilon_decay_steps=1_000_000,
        target_update_freq=1000,
        batch_size=32,
        double_dqn=True,
    ):
        """
        Args:
            q_network: Online Q-network.
            target_network: Target Q-network, periodically synced.
            replay_buffer: Buffer exposing ``sample()`` and ``__len__``.
            device: Torch device for tensors.
            num_actions: Size of the discrete action space.
            gamma: Discount factor.
            lr: Adam learning rate.
            epsilon_start: Initial exploration rate.
            epsilon_end: Final exploration rate.
            epsilon_decay_steps: Steps over which epsilon decays linearly.
            target_update_freq: Gradient steps between target-network syncs.
            batch_size: Mini-batch size for updates.
            double_dqn: Use Double DQN targets when True.
        """
        self.q_network = q_network
        self.target_network = target_network
        self.replay_buffer = replay_buffer
        self.device = device
        self.num_actions = num_actions
        self.gamma = gamma
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq
        self.double_dqn = double_dqn

        # epsilon-greedy schedule state
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay_steps = epsilon_decay_steps
        self.epsilon = epsilon_start

        # Optimizer over the online network only
        self.optimizer = torch.optim.Adam(q_network.parameters(), lr=lr)

        # Gradient-step counter. BUG FIX: this used to be stored as
        # ``self.train_step``, which shadowed the ``train_step`` method and
        # made ``agent.train_step()`` raise "'int' object is not callable".
        self.train_steps = 0

        # Rolling training statistics
        self.loss_history = []
        self.q_value_history = []

    def select_action(self, state, evaluate=False):
        """Select an action with epsilon-greedy.

        Args:
            state: Current state array (channels, height, width).
            evaluate: When True, use a small fixed epsilon (0.01) instead of
                the training schedule.

        Returns:
            The chosen action index (int).
        """
        # Evaluation keeps a small residual epsilon rather than being purely
        # greedy, which avoids deterministic loops in Atari evaluation.
        epsilon = 0.01 if evaluate else self.epsilon

        if np.random.random() < epsilon:
            # Uniform random exploration
            return np.random.randint(self.num_actions)
        # Greedy action from the online network
        with torch.no_grad():
            state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
            q_values = self.q_network(state_tensor)
            return q_values.argmax(dim=1).item()

    def update_epsilon(self):
        """Linearly decay epsilon towards ``epsilon_end``."""
        if self.train_steps < self.epsilon_decay_steps:
            frac = self.train_steps / self.epsilon_decay_steps
            self.epsilon = self.epsilon_start - (self.epsilon_start - self.epsilon_end) * frac
        else:
            self.epsilon = self.epsilon_end

    def train_step(self):
        """Run one gradient update from a replay-buffer batch.

        Returns:
            (loss, avg_q): scalar loss and mean predicted Q-value, or
            (None, None) when the buffer holds fewer than ``batch_size``
            transitions.
        """
        if len(self.replay_buffer) < self.batch_size:
            return None, None

        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)

        # Q(s, a) for the actions actually taken
        q_values = self.q_network(states)
        q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            if self.double_dqn:
                # Double DQN: online net selects the action, target net scores it
                next_actions = self.q_network(next_states).argmax(dim=1)
                next_q_values = self.target_network(next_states)
                next_q_values = next_q_values.gather(1, next_actions.unsqueeze(1)).squeeze(1)
            else:
                # Vanilla DQN: max over the target network's Q-values
                next_q_values = self.target_network(next_states).max(dim=1)[0]

            # Bootstrapped target; terminal transitions contribute reward only
            target_q_values = rewards + self.gamma * next_q_values * (1 - dones)

        loss = F.mse_loss(q_values, target_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        # Clip gradients to stabilize training
        torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 10)
        self.optimizer.step()

        # Periodically sync the target network with the online network
        self.train_steps += 1
        if self.train_steps % self.target_update_freq == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

        # Advance the exploration schedule
        self.update_epsilon()

        # Record statistics
        avg_q = q_values.mean().item()
        self.loss_history.append(loss.item())
        self.q_value_history.append(avg_q)

        return loss.item(), avg_q

    def save(self, path):
        """Save networks, optimizer state, and training progress to *path*."""
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save({
            'q_network': self.q_network.state_dict(),
            'target_network': self.target_network.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            # Keep the historical checkpoint key name for compatibility
            'train_step': self.train_steps,
            'epsilon': self.epsilon,
        }, path)
        print(f"模型已保存到: {path}")

    def load(self, path):
        """Restore networks, optimizer state, and training progress from *path*."""
        checkpoint = torch.load(path, map_location=self.device)
        self.q_network.load_state_dict(checkpoint['q_network'])
        self.target_network.load_state_dict(checkpoint['target_network'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.train_steps = checkpoint['train_step']
        self.epsilon = checkpoint['epsilon']
        print(f"模型已从 {path} 加载")
|
||||
@@ -0,0 +1,135 @@
|
||||
"""Evaluation script for trained DQN agent."""
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import argparse
|
||||
|
||||
from src.network import QNetwork, DuelingQNetwork
|
||||
from src.utils import make_env, get_device
|
||||
|
||||
|
||||
def evaluate(agent, env, num_episodes=10, render=False):
    """Run greedy-policy evaluation episodes.

    Args:
        agent: Agent exposing ``select_action(state, evaluate=True)``.
        env: Gymnasium-style environment (5-tuple ``step`` return).
        num_episodes: Number of episodes to run.
        render: Whether to call ``env.render()`` every step.

    Returns:
        (avg_reward, std_reward, rewards): mean, standard deviation, and the
        list of per-episode returns.
    """
    rewards = []

    for ep in range(num_episodes):
        state, _ = env.reset()
        episode_reward = 0
        steps = 0

        while True:
            # Greedy action selection
            action = agent.select_action(state, evaluate=True)

            # Advance the environment
            state, reward, terminated, truncated, _ = env.step(action)
            episode_reward += reward
            steps += 1

            if render:
                env.render()

            if terminated or truncated:
                break

        rewards.append(episode_reward)
        print(f"Episode {ep+1}/{num_episodes}: 回报={episode_reward:.1f}, 步数={steps}")

    return np.mean(rewards), np.std(rewards), rewards
|
||||
|
||||
|
||||
def main():
    """CLI entry point: load a trained checkpoint and run greedy evaluation."""
    parser = argparse.ArgumentParser(description="评估DQN智能体")
    parser.add_argument("--model", type=str, required=True,
                        help="模型路径")
    parser.add_argument("--env", type=str, default="ALE/SpaceInvaders-v5",
                        help="环境ID")
    parser.add_argument("--episodes", type=int, default=10,
                        help="评估episode数")
    parser.add_argument("--render", action="store_true",
                        help="是否渲染")
    parser.add_argument("--dueling", action="store_true",
                        help="是否使用Dueling架构")
    parser.add_argument("--seed", type=int, default=42,
                        help="随机种子")

    args = parser.parse_args()

    # Seed torch and numpy for reproducible evaluation
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    device = get_device()

    # Preprocessed env: grayscale, resized frames, 4-frame stack
    env = make_env(args.env, gray_scale=True, resize=True, frame_stack=4)

    num_actions = env.action_space.n
    # (frames, height, width) expected by the Q-networks
    state_shape = (4, 84, 84)

    # Architecture must match the one used at training time
    if args.dueling:
        q_network = DuelingQNetwork(state_shape, num_actions).to(device)
    else:
        q_network = QNetwork(state_shape, num_actions).to(device)

    # Load checkpoint weights (weights_only=False: checkpoint stores
    # non-tensor entries such as the step counter and epsilon)
    print(f"加载模型: {args.model}")
    checkpoint = torch.load(args.model, map_location=device, weights_only=False)
    q_network.load_state_dict(checkpoint['q_network'])
    q_network.eval()

    # Minimal wrapper exposing select_action for evaluate()
    class EvalAgent:
        def __init__(self, q_network, device, num_actions):
            self.q_network = q_network
            self.device = device
            self.num_actions = num_actions

        def select_action(self, state, evaluate=True):
            # Pure greedy action from the Q-network
            with torch.no_grad():
                state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
                q_values = self.q_network(state_tensor)
                return q_values.argmax(dim=1).item()

    agent = EvalAgent(q_network, device, num_actions)

    print(f"\n开始评估,共{args.episodes}个episode...")
    print("=" * 60)

    avg_reward, std_reward, rewards = evaluate(
        agent, env, num_episodes=args.episodes, render=args.render
    )

    # Summary statistics
    print("=" * 60)
    print(f"评估结果:")
    print(f"  平均回报: {avg_reward:.2f}")
    print(f"  标准差: {std_reward:.2f}")
    print(f"  最大回报: {max(rewards):.1f}")
    print(f"  最小回报: {min(rewards):.1f}")
    print(f"  中位数回报: {np.median(rewards):.1f}")
|
||||
|
||||
|
||||
# Script entry point
if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,150 @@
|
||||
"""Q-Network architecture for DQN."""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class QNetwork(nn.Module):
    """Deep Q-Network for Atari games.

    Architecture follows the original DQN paper:
    - 3 convolutional layers for feature extraction
    - 2 fully connected layers for Q-value estimation
    """

    def __init__(self, state_shape=(4, 84, 84), num_actions=6):
        """
        Args:
            state_shape: Shape of input state (channels, height, width)
            num_actions: Number of possible actions
        """
        super().__init__()
        c, h, w = state_shape

        # Convolutional feature extractor
        self.conv = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # Conv output size. FIX: height and width are computed independently
        # (the original used the height for both, assuming square inputs);
        # results are unchanged for the default square 84x84 input.
        feat_size = 64 * self._conv_out(h) * self._conv_out(w)

        # Fully connected head producing per-action Q-values
        self.fc = nn.Sequential(
            nn.Linear(feat_size, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions),
        )

        self._initialize_weights()

    @staticmethod
    def _conv_out(size):
        """Spatial size after the three conv layers (8/4, 4/2, 3/1)."""
        size = (size - 8) // 4 + 1
        size = (size - 4) // 2 + 1
        size = (size - 3) // 1 + 1
        return size

    def _initialize_weights(self):
        """He (Kaiming) initialization for conv and linear layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return per-action Q-values.

        Args:
            x: State tensor (batch_size, channels, height, width), uint8 range.

        Returns:
            q_values: Tensor of shape (batch_size, num_actions).
        """
        # Normalize pixels to [0, 1]
        x = x / 255.0

        # Convolutional feature extraction
        x = self.conv(x)

        # Flatten and project to Q-values
        x = x.view(x.size(0), -1)
        q_values = self.fc(x)

        return q_values
|
||||
|
||||
|
||||
class DuelingQNetwork(nn.Module):
    """Dueling DQN architecture separating state value and advantage.

    Compared to standard DQN, the dueling decomposition learns state values
    more effectively.
    """

    def __init__(self, state_shape=(4, 84, 84), num_actions=6):
        """
        Args:
            state_shape: Shape of input state (channels, height, width)
            num_actions: Number of possible actions
        """
        super().__init__()
        c, h, w = state_shape

        # Shared convolutional trunk
        self.conv = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # Conv output size. FIX: height and width are computed independently
        # (the original used the height for both, assuming square inputs);
        # results are unchanged for the default square 84x84 input.
        feat_size = 64 * self._conv_out(h) * self._conv_out(w)

        # Value stream: V(s)
        self.value_stream = nn.Sequential(
            nn.Linear(feat_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

        # Advantage stream: A(s, a)
        self.advantage_stream = nn.Sequential(
            nn.Linear(feat_size, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions),
        )

        self._initialize_weights()

    @staticmethod
    def _conv_out(size):
        """Spatial size after the three conv layers (8/4, 4/2, 3/1)."""
        size = (size - 8) // 4 + 1
        size = (size - 4) // 2 + 1
        size = (size - 3) // 1 + 1
        return size

    def _initialize_weights(self):
        """He (Kaiming) initialization for conv and linear layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return per-action Q-values via the dueling combination.

        Args:
            x: State tensor (batch_size, channels, height, width), uint8 range.

        Returns:
            q_values: Tensor of shape (batch_size, num_actions).
        """
        x = x / 255.0  # normalize pixels to [0, 1]
        x = self.conv(x)
        x = x.view(x.size(0), -1)

        # Separate value and advantage estimates
        value = self.value_stream(x)
        advantage = self.advantage_stream(x)

        # Q = V(s) + A(s,a) - mean(A(s,a)); subtracting the mean keeps the
        # decomposition identifiable
        q_values = value + advantage - advantage.mean(dim=1, keepdim=True)

        return q_values
|
||||
@@ -0,0 +1,162 @@
|
||||
"""Experience Replay Buffer for DQN."""
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class ReplayBuffer:
    """Uniform experience replay buffer.

    Stores transitions (s, a, r, s', done) in pre-allocated circular arrays
    and samples uniformly at random to decorrelate training data.
    """

    def __init__(self, capacity, state_shape=(4, 84, 84), device='cpu'):
        """
        Args:
            capacity: Maximum number of stored transitions.
            state_shape: State shape (channels, height, width).
            device: Torch device sampled batches are moved to.
        """
        self.capacity = capacity
        self.device = device
        self.ptr = 0
        self.size = 0

        # Pre-allocate storage; uint8 states keep memory usage low
        full_shape = (capacity, *state_shape)
        self.states = np.zeros(full_shape, dtype=np.uint8)
        self.next_states = np.zeros(full_shape, dtype=np.uint8)
        self.actions = np.zeros(capacity, dtype=np.int64)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.dones = np.zeros(capacity, dtype=np.bool_)

    def add(self, state, action, reward, next_state, done):
        """Insert one transition, overwriting the oldest when full.

        Args:
            state: Current state.
            action: Action taken.
            reward: Reward received.
            next_state: Resulting state.
            done: Whether the episode terminated.
        """
        i = self.ptr
        self.states[i] = state
        self.actions[i] = action
        self.rewards[i] = reward
        self.next_states[i] = next_state
        self.dones[i] = done

        # Advance the circular write pointer
        self.ptr = (i + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size):
        """Sample a uniform random batch.

        Args:
            batch_size: Number of transitions to draw.

        Returns:
            (states, actions, rewards, next_states, dones) as tensors on
            ``self.device``; states/rewards/dones are float, actions long.
        """
        idx = np.random.randint(0, self.size, size=batch_size)

        batch = (
            torch.from_numpy(self.states[idx]).float(),
            torch.from_numpy(self.actions[idx]).long(),
            torch.from_numpy(self.rewards[idx]).float(),
            torch.from_numpy(self.next_states[idx]).float(),
            torch.from_numpy(self.dones[idx]).float(),
        )
        return tuple(t.to(self.device) for t in batch)

    def __len__(self):
        """Number of transitions currently stored."""
        return self.size
|
||||
|
||||
|
||||
class PrioritizedReplayBuffer:
|
||||
"""优先经验回放缓冲区
|
||||
|
||||
根据TD误差优先采样,提高样本效率
|
||||
"""
|
||||
|
||||
def __init__(self, capacity, state_shape=(4, 84, 84), device='cpu', alpha=0.6):
|
||||
"""
|
||||
Args:
|
||||
capacity: 缓冲区容量
|
||||
state_shape: 状态形状
|
||||
device: 设备
|
||||
alpha: 优先级指数 (0=均匀采样, 1=完全按优先级采样)
|
||||
"""
|
||||
self.capacity = capacity
|
||||
self.device = device
|
||||
self.alpha = alpha
|
||||
self.ptr = 0
|
||||
self.size = 0
|
||||
self.max_priority = 1.0
|
||||
|
||||
# 数据存储
|
||||
self.states = np.zeros((capacity, *state_shape), dtype=np.uint8)
|
||||
self.actions = np.zeros(capacity, dtype=np.int64)
|
||||
self.rewards = np.zeros(capacity, dtype=np.float32)
|
||||
self.next_states = np.zeros((capacity, *state_shape), dtype=np.uint8)
|
||||
self.dones = np.zeros(capacity, dtype=np.bool_)
|
||||
|
||||
# 优先级存储
|
||||
self.priorities = np.zeros(capacity, dtype=np.float32)
|
||||
|
||||
def add(self, state, action, reward, next_state, done):
|
||||
"""添加转移,使用最大优先级"""
|
||||
self.states[self.ptr] = state
|
||||
self.actions[self.ptr] = action
|
||||
self.rewards[self.ptr] = reward
|
||||
self.next_states[self.ptr] = next_state
|
||||
self.dones[self.ptr] = done
|
||||
|
||||
# 新样本使用最大优先级
|
||||
self.priorities[self.ptr] = self.max_priority
|
||||
|
||||
self.ptr = (self.ptr + 1) % self.capacity
|
||||
self.size = min(self.size + 1, self.capacity)
|
||||
|
||||
def sample(self, batch_size, beta=0.4):
|
||||
"""按优先级采样
|
||||
|
||||
Args:
|
||||
batch_size: 批次大小
|
||||
beta: 重要性采样指数
|
||||
|
||||
Returns:
|
||||
states, actions, rewards, next_states, dones, indices, weights
|
||||
"""
|
||||
# 计算采样概率
|
||||
priorities = self.priorities[:self.size] ** self.alpha
|
||||
probs = priorities / priorities.sum()
|
||||
|
||||
# 按概率采样
|
||||
indices = np.random.choice(self.size, size=batch_size, p=probs)
|
||||
|
||||
# 计算重要性采样权重
|
||||
weights = (self.size * probs[indices]) ** (-beta)
|
||||
weights = weights / weights.max()
|
||||
|
||||
# 获取数据
|
||||
states = torch.from_numpy(self.states[indices]).float().to(self.device)
|
||||
actions = torch.from_numpy(self.actions[indices]).long().to(self.device)
|
||||
rewards = torch.from_numpy(self.rewards[indices]).float().to(self.device)
|
||||
next_states = torch.from_numpy(self.next_states[indices]).float().to(self.device)
|
||||
dones = torch.from_numpy(self.dones[indices]).float().to(self.device)
|
||||
weights = torch.from_numpy(weights).float().to(self.device)
|
||||
|
||||
return states, actions, rewards, next_states, dones, indices, weights
|
||||
|
||||
def update_priorities(self, indices, td_errors):
    """Refresh sample priorities from their latest TD errors.

    Args:
        indices: positions of the sampled transitions.
        td_errors: corresponding TD errors (array-like).
    """
    # A small epsilon keeps every priority strictly positive, so no
    # transition can become unsampleable.
    new_priorities = np.abs(td_errors) + 1e-6
    self.priorities[indices] = new_priorities
    self.max_priority = max(self.max_priority, new_priorities.max())
||||
def __len__(self):
    """Return the number of transitions currently stored (<= capacity)."""
    return self.size
|
||||
@@ -0,0 +1,166 @@
|
||||
"""DQN Trainer."""
|
||||
import torch
|
||||
import numpy as np
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
|
||||
class DQNTrainer:
    """DQN trainer.

    Runs the environment-interaction / learning loop, tracks rolling
    episode statistics, periodically evaluates the greedy policy and
    checkpoints the agent.
    """

    def __init__(
        self,
        agent,
        env,
        eval_env,
        log_dir="logs",
        save_dir="models",
        eval_freq=10000,
        save_freq=50000,
        num_eval_episodes=10,
        warmup_steps=10000,
    ):
        """
        Args:
            agent: DQN agent exposing select_action/train_step/save and
                a replay_buffer with an add() method.
            env: training environment (gymnasium-style step/reset API).
            eval_env: separate environment used only for evaluation.
            log_dir: log directory (stored but not written to here).
            save_dir: directory where model checkpoints are written.
            eval_freq: evaluation period, in environment steps.
            save_freq: checkpoint period, in environment steps.
            num_eval_episodes: number of episodes per evaluation.
            warmup_steps: initial steps of uniform random exploration
                used to pre-fill the replay buffer before learning.
        """
        self.agent = agent
        self.env = env
        self.eval_env = eval_env
        self.log_dir = log_dir
        self.save_dir = save_dir
        self.eval_freq = eval_freq
        self.save_freq = save_freq
        self.num_eval_episodes = num_eval_episodes
        self.warmup_steps = warmup_steps

        # Fix: create the checkpoint directory up front; previously the
        # first agent.save() crashed if save_dir did not exist.
        import os
        os.makedirs(self.save_dir, exist_ok=True)

        # Rolling statistics over the most recent 100 episodes.
        self.episode_rewards = deque(maxlen=100)
        self.episode_lengths = deque(maxlen=100)
        self.eval_rewards = []  # history of (step, mean eval reward)
        self.best_eval_reward = -float('inf')

    def train(self, total_steps):
        """Main training loop.

        Args:
            total_steps: total number of environment steps to run.
        """
        print(f"开始训练,总步数: {total_steps:,}")
        print(f"预热步数: {self.warmup_steps:,}")
        print("=" * 60)

        state, _ = self.env.reset()
        episode_reward = 0
        episode_length = 0
        episode_count = 0
        start_time = time.time()

        for step in range(1, total_steps + 1):
            # Action selection: pure random during warmup, epsilon-greedy after.
            if step < self.warmup_steps:
                action = self.env.action_space.sample()
            else:
                action = self.agent.select_action(state)

            # One environment transition.
            next_state, reward, terminated, truncated, _ = self.env.step(action)
            done = terminated or truncated

            # Store the transition for experience replay.
            self.agent.replay_buffer.add(state, action, reward, next_state, done)

            # One gradient step once the buffer is warmed up
            # (loss / mean-Q return values are not consumed here).
            if step >= self.warmup_steps:
                self.agent.train_step()

            state = next_state
            episode_reward += reward
            episode_length += 1

            if done:
                self.episode_rewards.append(episode_reward)
                self.episode_lengths.append(episode_length)
                episode_count += 1

                # Progress log every 10 episodes.
                if episode_count % 10 == 0:
                    avg_reward = np.mean(self.episode_rewards)
                    avg_length = np.mean(self.episode_lengths)
                    elapsed = time.time() - start_time
                    fps = step / elapsed

                    print(f"Step: {step:>8,} | "
                          f"Episode: {episode_count:>5} | "
                          f"Avg Reward: {avg_reward:>7.1f} | "
                          f"Avg Length: {avg_length:>6.1f} | "
                          f"Epsilon: {self.agent.epsilon:.3f} | "
                          f"FPS: {fps:.0f}")

                # Start the next episode.
                state, _ = self.env.reset()
                episode_reward = 0
                episode_length = 0

            # Periodic greedy evaluation.
            if step % self.eval_freq == 0:
                eval_reward = self.evaluate()
                self.eval_rewards.append((step, eval_reward))
                print(f"\n[评估] Step {step:>8,} | 平均回报: {eval_reward:.1f}")
                print("-" * 60)

                # Keep the best-scoring checkpoint separately.
                if eval_reward > self.best_eval_reward:
                    self.best_eval_reward = eval_reward
                    self.agent.save(f"{self.save_dir}/dqn_best.pt")

            # Periodic checkpoint, independent of evaluation score.
            if step % self.save_freq == 0:
                self.agent.save(f"{self.save_dir}/dqn_step_{step}.pt")

        # Final summary and checkpoint.
        total_time = time.time() - start_time
        print("\n" + "=" * 60)
        print(f"训练完成!总时间: {total_time:.1f}秒")
        print(f"最佳评估回报: {self.best_eval_reward:.1f}")

        self.agent.save(f"{self.save_dir}/dqn_final.pt")

    def evaluate(self):
        """Run the greedy policy on the evaluation environment.

        Returns:
            Mean undiscounted episode return over num_eval_episodes.
        """
        rewards = []

        for _ in range(self.num_eval_episodes):
            state, _ = self.eval_env.reset()
            episode_reward = 0
            done = False

            while not done:
                action = self.agent.select_action(state, evaluate=True)
                state, reward, terminated, truncated, _ = self.eval_env.step(action)
                done = terminated or truncated
                episode_reward += reward

            rewards.append(episode_reward)

        return np.mean(rewards)
|
||||
@@ -0,0 +1,200 @@
|
||||
"""Environment wrappers and utility functions."""
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from collections import deque
|
||||
|
||||
# 注册ALE环境
|
||||
try:
|
||||
import ale_py
|
||||
gym.register_envs(ale_py)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class GrayScaleWrapper(gym.ObservationWrapper):
    """Convert RGB observations to 8-bit grayscale (ITU-R BT.601 weights)."""

    def observation(self, obs):
        """Collapse the channel axis via luminance-weighted average."""
        red, green, blue = obs[:, :, 0], obs[:, :, 1], obs[:, :, 2]
        gray = 0.299 * red + 0.587 * green + 0.114 * blue
        return gray.astype(np.uint8)
||||
|
||||
|
||||
class ResizeWrapper(gym.ObservationWrapper):
    """Resize observations to a fixed (width, height) with OpenCV.

    ``cv2.resize`` handles both grayscale (H, W) and color (H, W, C)
    frames, so no branching on rank is needed. The original code had
    two byte-identical if/else branches here (the "expand dims" comment
    was never implemented); they are collapsed into one call.
    """

    def __init__(self, env, size=(84, 84)):
        super().__init__(env)
        self.size = size  # target (width, height), as cv2.resize expects

    def observation(self, obs):
        # Local import keeps cv2 an optional dependency of this wrapper.
        import cv2
        # INTER_AREA is the recommended interpolation for downscaling.
        return cv2.resize(obs, self.size, interpolation=cv2.INTER_AREA)
|
||||
|
||||
|
||||
class FrameStackWrapper(gym.ObservationWrapper):
    """Stack the last N observations along a new leading axis.

    The emitted observation is ``np.stack(frames, axis=0)``, i.e. shape
    ``(num_stack, *frame_shape)`` for grayscale and RGB frames alike.
    """

    def __init__(self, env, num_stack=4):
        super().__init__(env)
        self.num_stack = num_stack
        # deque(maxlen=...) drops the oldest frame automatically.
        self.frames = deque(maxlen=num_stack)
        obs_shape = env.observation_space.shape

        # Fix: the stacked output always has shape (num_stack, *obs_shape).
        # The original declared (num_stack, *obs_shape[-2:]) for RGB input,
        # which did not match what _get_observation actually returns.
        self.observation_space = gym.spaces.Box(
            low=0, high=255,
            shape=(num_stack, *obs_shape),
            dtype=np.uint8,
        )

    def reset(self, **kwargs):
        """Reset and pre-fill the stack by repeating the first frame."""
        obs, info = self.env.reset(**kwargs)
        for _ in range(self.num_stack):
            self.frames.append(obs)
        return self._get_observation(), info

    def observation(self, obs):
        """Push the newest frame and return the current stack."""
        self.frames.append(obs)
        return self._get_observation()

    def _get_observation(self):
        # New leading axis: (num_stack, *frame_shape).
        return np.stack(list(self.frames), axis=0)
|
||||
|
||||
|
||||
class RewardClipWrapper(gym.RewardWrapper):
    """Clip every reward into [-1, 1] (the standard DQN reward trick)."""

    def reward(self, reward):
        """Return the reward clamped to [-1, 1]."""
        return np.clip(reward, -1, 1)
|
||||
|
||||
|
||||
class NoopResetWrapper(gym.Wrapper):
    """Randomize the initial state by taking random no-op steps on reset."""

    def __init__(self, env, noop_max=30):
        super().__init__(env)
        self.noop_max = noop_max  # upper bound on no-op steps per reset
        self.noop_action = 0      # Atari action 0 is NOOP

    def reset(self, **kwargs):
        """Reset, then step 1..noop_max no-ops; re-reset if an episode ends."""
        obs, info = self.env.reset(**kwargs)
        num_noops = np.random.randint(1, self.noop_max + 1)
        for _ in range(num_noops):
            obs, _, terminated, truncated, info = self.env.step(self.noop_action)
            if terminated or truncated:
                obs, info = self.env.reset(**kwargs)
        return obs, info
|
||||
|
||||
|
||||
class MaxAndSkipWrapper(gym.Wrapper):
    """Repeat each action for `skip` frames and return the pixel-wise max
    of the last two frames (removes Atari sprite flicker and cuts cost)."""

    def __init__(self, env, skip=4):
        super().__init__(env)
        self.skip = skip
        # Buffer for the last two raw frames of each skip window.
        # NOTE(review): shape comes from the *wrapped* env's observation
        # space, so this wrapper must sit before any resize/grayscale.
        self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)

    def step(self, action):
        """Step the env `skip` times with the same action.

        Returns the element-wise max of the last two frames, the summed
        reward, and the final terminated/truncated/info values.
        NOTE(review): assumes skip >= 1 — with skip == 0 the loop never
        runs and `info` would be unbound; confirm callers never pass 0.
        """
        total_reward = 0.0
        terminated = False
        truncated = False

        for i in range(self.skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward

            # Keep only the last two frames of the skip window.
            if i == self.skip - 2:
                self._obs_buffer[0] = obs
            if i == self.skip - 1:
                self._obs_buffer[1] = obs

            # Stop early if the episode ends mid-window.
            if terminated or truncated:
                break

        # Pixel-wise max over the two most recent frames.
        max_frame = self._obs_buffer.max(axis=0)

        return max_frame, total_reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Pass-through reset (the buffer is overwritten on the next step)."""
        return self.env.reset(**kwargs)
|
||||
|
||||
|
||||
def make_env(env_id="ALE/SpaceInvaders-v5", gray_scale=True, resize=True,
             frame_stack=4, reward_clip=True, noop_reset=True, max_skip=4):
    """Build a preprocessed Atari environment.

    Wrappers are layered in order: noop reset -> frame skip -> resize ->
    grayscale -> reward clip -> frame stack.

    Args:
        env_id: Gymnasium environment id.
        gray_scale: convert frames to grayscale.
        resize: resize frames to 84x84.
        frame_stack: number of frames to stack (<= 1 disables stacking).
        reward_clip: clip rewards to [-1, 1].
        noop_reset: randomize the start state with no-op actions.
        max_skip: frames to repeat each action for (<= 1 disables skipping).

    Returns:
        The fully wrapped environment.
    """
    env = gym.make(env_id, render_mode="rgb_array")

    # Each preprocessing stage is optional and applied in a fixed order.
    if noop_reset:
        env = NoopResetWrapper(env, noop_max=30)
    if max_skip > 1:
        env = MaxAndSkipWrapper(env, skip=max_skip)
    if resize:
        env = ResizeWrapper(env, size=(84, 84))
    if gray_scale:
        env = GrayScaleWrapper(env)
    if reward_clip:
        env = RewardClipWrapper(env)
    if frame_stack > 1:
        env = FrameStackWrapper(env, num_stack=frame_stack)

    return env
|
||||
|
||||
|
||||
def get_device():
    """Detect and return the available torch device (CUDA if present)."""
    if not torch.cuda.is_available():
        print("使用CPU")
        return torch.device("cpu")
    print(f"使用GPU: {torch.cuda.get_device_name(0)}")
    return torch.device("cuda")
|
||||
|
||||
|
||||
def preprocess_obs(obs):
    """Ensure the observation carries a leading channel axis.

    A 2-D (H, W) frame becomes (1, H, W); anything else passes through
    unchanged.
    """
    return obs[np.newaxis] if obs.ndim == 2 else obs
|
||||
@@ -0,0 +1,170 @@
|
||||
"""Main training script for DQN on Space Invaders."""
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from src.network import QNetwork, DuelingQNetwork
|
||||
from src.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
|
||||
from src.agent import DQNAgent
|
||||
from src.trainer import DQNTrainer
|
||||
from src.utils import make_env, get_device
|
||||
|
||||
|
||||
def main():
    """Parse CLI arguments, assemble the DQN components and run training."""
    parser = argparse.ArgumentParser(description="DQN for Space Invaders")

    # Environment parameters
    parser.add_argument("--env", type=str, default="ALE/SpaceInvaders-v5",
                        help="Atari环境ID")

    # Training parameters
    parser.add_argument("--steps", type=int, default=2_000_000,
                        help="总训练步数")
    parser.add_argument("--lr", type=float, default=1e-4,
                        help="学习率")
    parser.add_argument("--gamma", type=float, default=0.99,
                        help="折扣因子")
    parser.add_argument("--batch-size", type=int, default=32,
                        help="批次大小")
    parser.add_argument("--buffer-size", type=int, default=100_000,
                        help="经验回放缓冲区大小")

    # Epsilon-greedy exploration schedule
    parser.add_argument("--epsilon-start", type=float, default=1.0,
                        help="ε初始值")
    parser.add_argument("--epsilon-end", type=float, default=0.01,
                        help="ε最终值")
    parser.add_argument("--epsilon-decay", type=int, default=1_000_000,
                        help="ε衰减步数")

    # Network parameters
    parser.add_argument("--target-update", type=int, default=1000,
                        help="目标网络更新频率")
    # NOTE(review): store_true combined with default=True means this flag
    # can never be switched off from the CLI — confirm that is intended.
    parser.add_argument("--double-dqn", action="store_true", default=True,
                        help="使用Double DQN")
    parser.add_argument("--dueling", action="store_true", default=False,
                        help="使用Dueling DQN架构")

    # Evaluation parameters
    parser.add_argument("--eval-freq", type=int, default=10000,
                        help="评估频率")
    parser.add_argument("--eval-episodes", type=int, default=10,
                        help="评估episode数")
    parser.add_argument("--save-freq", type=int, default=50000,
                        help="模型保存频率")
    parser.add_argument("--warmup", type=int, default=10000,
                        help="预热步数")

    # Prioritized experience replay
    parser.add_argument("--prioritized", action="store_true", default=False,
                        help="使用优先经验回放")

    # Misc
    parser.add_argument("--seed", type=int, default=42,
                        help="随机种子")
    parser.add_argument("--save-dir", type=str, default="models",
                        help="模型保存目录")
    parser.add_argument("--log-dir", type=str, default="logs",
                        help="日志目录")

    args = parser.parse_args()

    # Seed Python-level RNGs.
    # NOTE(review): torch.cuda seeding and env seeding are not set, so
    # GPU runs / environment randomness are not fully reproducible.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Pick CUDA if available, else CPU.
    device = get_device()

    # Build training and (separate) evaluation environments.
    print(f"创建环境: {args.env}")
    env = make_env(args.env, gray_scale=True, resize=True, frame_stack=4)
    eval_env = make_env(args.env, gray_scale=True, resize=True, frame_stack=4)

    # Discrete action count from the environment.
    num_actions = env.action_space.n
    print(f"动作空间: {num_actions}")

    # Network input: 4 stacked 84x84 grayscale frames (channels-first).
    state_shape = (4, 84, 84)

    if args.dueling:
        print("使用Dueling DQN架构")
        q_network = DuelingQNetwork(state_shape, num_actions).to(device)
        target_network = DuelingQNetwork(state_shape, num_actions).to(device)
    else:
        print("使用标准DQN架构")
        q_network = QNetwork(state_shape, num_actions).to(device)
        target_network = QNetwork(state_shape, num_actions).to(device)

    # Start the target network as an exact copy; it is only read from.
    target_network.load_state_dict(q_network.state_dict())
    target_network.eval()

    print(f"网络参数量: {sum(p.numel() for p in q_network.parameters()):,}")

    # Replay buffer: prioritized or uniform.
    if args.prioritized:
        print("使用优先经验回放")
        replay_buffer = PrioritizedReplayBuffer(
            args.buffer_size, state_shape, device
        )
    else:
        print("使用标准经验回放")
        replay_buffer = ReplayBuffer(
            args.buffer_size, state_shape, device
        )

    # Agent wires networks, buffer and exploration schedule together.
    agent = DQNAgent(
        q_network=q_network,
        target_network=target_network,
        replay_buffer=replay_buffer,
        device=device,
        num_actions=num_actions,
        gamma=args.gamma,
        lr=args.lr,
        epsilon_start=args.epsilon_start,
        epsilon_end=args.epsilon_end,
        epsilon_decay_steps=args.epsilon_decay,
        target_update_freq=args.target_update,
        batch_size=args.batch_size,
        double_dqn=args.double_dqn,
    )

    # Trainer owns the interaction / evaluation / checkpoint loop.
    trainer = DQNTrainer(
        agent=agent,
        env=env,
        eval_env=eval_env,
        log_dir=args.log_dir,
        save_dir=args.save_dir,
        eval_freq=args.eval_freq,
        save_freq=args.save_freq,
        num_eval_episodes=args.eval_episodes,
        warmup_steps=args.warmup,
    )

    # Echo the configuration before training starts.
    print("\n训练配置:")
    print(f" 总步数: {args.steps:,}")
    print(f" 学习率: {args.lr}")
    print(f" 折扣因子: {args.gamma}")
    print(f" 批次大小: {args.batch_size}")
    print(f" 缓冲区大小: {args.buffer_size:,}")
    print(f" ε衰减: {args.epsilon_start} -> {args.epsilon_end} ({args.epsilon_decay:,}步)")
    print(f" 目标网络更新: 每{args.target_update}步")
    print(f" Double DQN: {args.double_dqn}")
    print(f" 预热步数: {args.warmup:,}")
    print("=" * 60)

    # Run the training loop.
    trainer.train(args.steps)
|
||||
|
||||
|
||||
# Script entry point: run training when executed directly.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user