From e8b51240f965e8bd6511de5e416c3d6c036d0111 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com>
Date: Fri, 1 May 2026 10:01:12 +0800
Subject: [PATCH] feat: add the DQN reinforcement learning project framework
 and core implementation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

A complete DQN training framework for the Atari game Space Invaders, including:
- QNetwork and DuelingQNetwork architectures
- Experience replay buffers (uniform and prioritized variants)
- A DQN agent with an ε-greedy policy and Double DQN targets
- Environment wrappers (grayscale, resizing, frame stacking, etc.)
- Trainer, evaluation script, and plot-generation utilities
- Project documentation and dependency configuration
---
 强化学习个人项目报告/src/network.py          |  84 ------
 .../README.md                                | 119 ++++++++
 .../generate_plots.py                        | 264 ++++++++++++++++++
 .../pyproject.toml                           |   0
 .../requirements.txt                         |  10 +
 .../src/__init__.py                          |   1 +
 .../src/agent.py                             | 184 ++++++++++++
 .../src/evaluate.py                          | 135 +++++++++
 .../src/network.py                           | 150 ++++++++++
 .../src/replay_buffer.py                     | 162 +++++++++++
 .../src/trainer.py                           | 166 +++++++++++
 .../src/utils.py                             | 200 +++++++++++++
 .../train.py                                 | 170 +++++++++++
 13 files changed, 1561 insertions(+), 84 deletions(-)
 delete mode 100644 强化学习个人项目报告/src/network.py
 create mode 100644 强化学习个人项目报告(Atari 游戏方向)/README.md
 create mode 100644 强化学习个人项目报告(Atari 游戏方向)/generate_plots.py
 create mode 100644 强化学习个人项目报告(Atari 游戏方向)/pyproject.toml
 create mode 100644 强化学习个人项目报告(Atari 游戏方向)/requirements.txt
 create mode 100644 强化学习个人项目报告(Atari 游戏方向)/src/__init__.py
 create mode 100644 强化学习个人项目报告(Atari 游戏方向)/src/agent.py
 create mode 100644 强化学习个人项目报告(Atari 游戏方向)/src/evaluate.py
 create mode 100644 强化学习个人项目报告(Atari 游戏方向)/src/network.py
 create mode 100644 强化学习个人项目报告(Atari 游戏方向)/src/replay_buffer.py
 create mode 100644 强化学习个人项目报告(Atari 游戏方向)/src/trainer.py
 create mode 100644 强化学习个人项目报告(Atari 游戏方向)/src/utils.py
 create mode 100644 强化学习个人项目报告(Atari 游戏方向)/train.py

diff --git a/强化学习个人项目报告/src/network.py b/强化学习个人项目报告/src/network.py
deleted file mode 100644
index b332c12..0000000
--- a/强化学习个人项目报告/src/network.py
+++ /dev/null
@@ -1,84 +0,0 @@
-"""Neural network architectures for Actor and Critic."""
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class Actor(nn.Module):
-    """Actor network outputting Gaussian policy parameters (mu, sigma)."""
-
-    def __init__(self, state_shape=(84, 84, 4), action_dim=3):
-        super().__init__()
-        c, h, w = (
-            state_shape[2],
-            state_shape[0],
-            state_shape[1],
-        )  # channels, height, width
-
-        self.conv = nn.Sequential(
-            nn.Conv2d(c, 32, kernel_size=8, stride=4),
-            nn.ReLU(),
-            nn.Conv2d(32, 64, kernel_size=4, stride=2),
-            nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=3, stride=1),
-            nn.ReLU(),
-        )
-
-        out_h = (h - 8) // 4 + 1
-        out_h = (out_h - 4) // 2 + 1
-        out_h = (out_h - 3) // 1 + 1
-        feat_size = 64 * out_h * out_h
-
-        self.fc = nn.Sequential(
-            nn.Linear(feat_size, 512),
-            nn.ReLU(),
-        )
-        self.mu_head = nn.Linear(512, action_dim)
-        self.log_std_head = nn.Linear(512, action_dim)
-
-        # Initialize output layers
-        nn.init.orthogonal_(self.mu_head.weight, gain=0.01)
-        nn.init.orthogonal_(self.log_std_head.weight, gain=0.01)
-
-    def forward(self, x):
-        """Forward pass returning (mu, log_std)."""
-        x = x / 255.0  # Normalize
-        x = self.conv(x)
-        x = x.view(x.size(0), -1)
-        x = self.fc(x)
-        mu = torch.tanh(self.mu_head(x))
-        log_std = self.log_std_head(x)
-        log_std = torch.clamp(log_std, -20, 2)
-        return mu, log_std.exp()
-
-
-class Critic(nn.Module):
-    """Critic network estimating state value V(s)."""
-
-    def __init__(self, state_shape=(84, 84, 4)):
-        super().__init__()
-        c, h, w = state_shape[2], state_shape[0], state_shape[1]
-
-        self.conv = nn.Sequential(
-            nn.Conv2d(c, 32, kernel_size=8, stride=4),
-            nn.ReLU(),
-            nn.Conv2d(32, 64, kernel_size=4, stride=2),
-            nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=3, stride=1),
-            nn.ReLU(),
-        )
-
-        out_h = (h - 8) // 4 + 1
-        out_h = (out_h - 4) // 2 + 1
-        out_h = (out_h - 3) // 1 + 1
-        feat_size = 64 * out_h * out_h
-
-        self.fc = nn.Sequential(nn.Linear(feat_size, 512), nn.ReLU(), nn.Linear(512, 1))
-
-    def forward(self, x):
-        """Forward pass returning V(s)."""
-        x = x / 255.0
-        x = self.conv(x)
-        x = x.view(x.size(0), -1)
-        return self.fc(x)
diff --git a/强化学习个人项目报告(Atari 游戏方向)/README.md b/强化学习个人项目报告(Atari 游戏方向)/README.md
new file mode 100644
index 0000000..acda5df
--- /dev/null
+++ b/强化学习个人项目报告(Atari 游戏方向)/README.md
@@ -0,0 +1,119 @@
+# DQN for Space Invaders
+
+A from-scratch implementation of DQN (Deep Q-Network) for the Atari game Space Invaders. No RL libraries such as Stable-Baselines are used.
+
+## Highlights
+
+- **Algorithms**: DQN / Double DQN / Dueling DQN
+- **Game**: Space Invaders (ALE/SpaceInvaders-v5)
+- **Framework**: PyTorch
+- **Environment**: Gymnasium + ALE
+
+## Project Layout
+
+```
+├── src/
+│   ├── __init__.py
+│   ├── network.py        # Q-network architectures (standard DQN and Dueling DQN)
+│   ├── replay_buffer.py  # Experience replay buffers (uniform and prioritized)
+│   ├── agent.py          # DQN agent
+│   ├── trainer.py        # Training loop
+│   ├── utils.py          # Environment wrappers and utility functions
+│   └── evaluate.py       # Evaluation script
+├── train.py              # Main training script
+├── generate_plots.py     # Plot generation
+├── requirements.txt      # Dependencies
+└── README.md             # This file
+```
+
+## Setup
+
+```bash
+# Create and activate a virtual environment
+conda create -n my_env python
+conda activate my_env
+
+# Install dependencies
+pip install -r requirements.txt
+
+# Install the Atari ROMs
+AutoROM --accept-license
+```
+
+## Training
+
+```bash
+# Standard DQN
+python train.py --steps 2000000
+
+# Double DQN
+python train.py --steps 2000000 --double-dqn
+
+# Dueling DQN
+python train.py --steps 2000000 --dueling
+
+# Custom hyperparameters
+python train.py \
+    --steps 2000000 \
+    --lr 1e-4 \
+    --gamma 0.99 \
+    --batch-size 32 \
+    --buffer-size 100000 \
+    --epsilon-decay 1000000 \
+    --target-update 1000
+```
+
+## Evaluation
+
+```bash
+# Evaluate a trained model
+python src/evaluate.py --model models/dqn_best.pt --episodes 10
+
+# Evaluate with rendering
+python src/evaluate.py --model models/dqn_best.pt --episodes 5 --render
+```
+
+## Generating Plots
+
+```bash
+# Generate sample plots
+python generate_plots.py --sample
+
+# Generate plots from a training log
+python generate_plots.py --log-file logs/training_log.json
+```
+
+## Hyperparameters
+
+| Parameter | Default | Description |
+|------|--------|------|
+| lr | 1e-4 | learning rate |
+| gamma | 0.99 | discount factor |
+| epsilon-start | 1.0 | initial ε |
+| epsilon-end | 0.01 | final ε |
+| epsilon-decay | 1,000,000 | ε decay steps |
+| buffer-size | 100,000 | replay buffer capacity |
+| batch-size | 32 | batch size |
+| target-update | 1,000 | target-network update period |
+| double-dqn | True | use Double DQN |
+| dueling | False | use Dueling DQN |
+
+## Algorithm Notes
+
+### DQN (Deep Q-Network)
+- Approximates the Q-function with a deep neural network
+- Experience replay breaks the correlation between consecutive samples
+- A separate target network stabilizes training
+
+### Double DQN
+- Mitigates Q-value overestimation
+- The online network selects the next action; the target network evaluates it (see the sketch below)
+
+### Dueling DQN
+- Separates the state value from the action advantages
+- Learns the value of states more efficiently
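+
+The two targets differ only in how the next-state action is chosen. A minimal, self-contained PyTorch sketch (random tensors stand in for actual network outputs):
+
+```python
+import torch
+
+gamma = 0.99
+B, A = 32, 6                       # batch size, number of actions
+rewards, dones = torch.zeros(B), torch.zeros(B)
+q_next_online = torch.randn(B, A)  # Q(s', ·) from the online network
+q_next_target = torch.randn(B, A)  # Q(s', ·) from the target network
+
+# Standard DQN: max over the target network's own estimates
+target_dqn = rewards + gamma * q_next_target.max(dim=1).values * (1 - dones)
+
+# Double DQN: the online network picks the action, the target network scores it
+a_star = q_next_online.argmax(dim=1, keepdim=True)
+target_ddqn = rewards + gamma * q_next_target.gather(1, a_star).squeeze(1) * (1 - dones)
+```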
+
+## References
+
+1. Mnih, V., et al. (2015). Human-level control through deep reinforcement learning. Nature.
+2. Van Hasselt, H., et al. (2016). Deep Reinforcement Learning with Double Q-learning. AAAI.
+3. Wang, Z., et al. (2016). Dueling Network Architectures for Deep Reinforcement Learning. ICML.
\ No newline at end of file
diff --git a/强化学习个人项目报告(Atari 游戏方向)/generate_plots.py b/强化学习个人项目报告(Atari 游戏方向)/generate_plots.py
new file mode 100644
index 0000000..d7ac3ea
--- /dev/null
+++ b/强化学习个人项目报告(Atari 游戏方向)/generate_plots.py
@@ -0,0 +1,264 @@
+"""Generate training plots for the report."""
+import os
+import numpy as np
+import matplotlib.pyplot as plt
+import json
+from collections import defaultdict
+
+
+def load_training_logs(log_file):
+    """Load training logs.
+
+    Args:
+        log_file: Path to the JSON log file.
+
+    Returns:
+        logs: Dict mapping metric names to value lists.
+    """
+    logs = defaultdict(list)
+
+    if os.path.exists(log_file):
+        with open(log_file, 'r') as f:
+            data = json.load(f)
+        for key, values in data.items():
+            logs[key] = values
+
+    return logs
+
+
+def smooth_data(data, window=100):
+    """Smooth data with a trailing moving average.
+
+    Args:
+        data: Raw data.
+        window: Smoothing window size.
+
+    Returns:
+        smoothed: Smoothed data.
+    """
+    if len(data) < window:
+        return data
+
+    smoothed = []
+    for i in range(len(data)):
+        start = max(0, i - window + 1)
+        smoothed.append(np.mean(data[start:i+1]))
+
+    return smoothed
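+
+# An equivalent vectorized form of the trailing mean above, via cumulative
+# sums (a sketch; returns an ndarray rather than a list, but values match):
+def smooth_data_vectorized(data, window=100):
+    arr = np.asarray(data, dtype=np.float64)
+    if arr.size < window:
+        return arr
+    csum = np.cumsum(np.insert(arr, 0, 0.0))
+    idx = np.arange(arr.size)
+    start = np.clip(idx - window + 1, 0, None)
+    return (csum[idx + 1] - csum[start]) / (idx - start + 1)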
+
+
+def plot_training_curves(rewards, losses, q_values, save_dir="plots"):
+    """Plot the training curves.
+
+    Args:
+        rewards: Episode returns.
+        losses: Training losses.
+        q_values: Mean Q-values.
+        save_dir: Output directory.
+    """
+    os.makedirs(save_dir, exist_ok=True)
+
+    # 2x2 grid of subplots
+    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
+
+    # 1. Episode returns (len() avoids the ambiguous truth value of ndarrays)
+    ax1 = axes[0, 0]
+    if len(rewards) > 0:
+        episodes = range(1, len(rewards) + 1)
+        ax1.plot(episodes, rewards, alpha=0.3, color='blue', label='raw')
+        smoothed_rewards = smooth_data(rewards, window=100)
+        ax1.plot(episodes, smoothed_rewards, color='red', linewidth=2, label='smoothed (window=100)')
+        ax1.set_xlabel('Episode', fontsize=12)
+        ax1.set_ylabel('Return', fontsize=12)
+        ax1.set_title('Training returns', fontsize=14)
+        ax1.legend(fontsize=10)
+        ax1.grid(True, alpha=0.3)
+
+    # 2. Loss curve
+    ax2 = axes[0, 1]
+    if len(losses) > 0:
+        steps = range(1, len(losses) + 1)
+        ax2.plot(steps, losses, alpha=0.3, color='green', label='raw')
+        smoothed_losses = smooth_data(losses, window=100)
+        ax2.plot(steps, smoothed_losses, color='red', linewidth=2, label='smoothed')
+        ax2.set_xlabel('Training step', fontsize=12)
+        ax2.set_ylabel('Loss', fontsize=12)
+        ax2.set_title('Training loss', fontsize=14)
+        ax2.legend(fontsize=10)
+        ax2.grid(True, alpha=0.3)
+
+    # 3. Mean Q-value curve
+    ax3 = axes[1, 0]
+    if len(q_values) > 0:
+        steps = range(1, len(q_values) + 1)
+        ax3.plot(steps, q_values, alpha=0.3, color='purple', label='raw')
+        smoothed_q = smooth_data(q_values, window=100)
+        ax3.plot(steps, smoothed_q, color='red', linewidth=2, label='smoothed')
+        ax3.set_xlabel('Training step', fontsize=12)
+        ax3.set_ylabel('Mean Q-value', fontsize=12)
+        ax3.set_title('Mean Q-value over training', fontsize=14)
+        ax3.legend(fontsize=10)
+        ax3.grid(True, alpha=0.3)
+
+    # 4. Return histogram
+    ax4 = axes[1, 1]
+    if len(rewards) > 0:
+        ax4.hist(rewards, bins=30, color='skyblue', edgecolor='black', alpha=0.7)
+        ax4.axvline(np.mean(rewards), color='red', linestyle='--', linewidth=2,
+                    label=f'mean: {np.mean(rewards):.1f}')
+        ax4.axvline(np.median(rewards), color='green', linestyle='--', linewidth=2,
+                    label=f'median: {np.median(rewards):.1f}')
+        ax4.set_xlabel('Return', fontsize=12)
+        ax4.set_ylabel('Count', fontsize=12)
+        ax4.set_title('Return distribution', fontsize=14)
+        ax4.legend(fontsize=10)
+        ax4.grid(True, alpha=0.3)
+
+    plt.tight_layout()
+
+    # Save the figure
+    save_path = os.path.join(save_dir, 'training_curves.png')
+    plt.savefig(save_path, dpi=300, bbox_inches='tight')
+    print(f"Training curves saved to: {save_path}")
+
+    plt.close()
+
+
+def plot_epsilon_decay(epsilon_start, epsilon_end, decay_steps, save_dir="plots"):
+    """Plot the ε decay schedule.
+
+    Args:
+        epsilon_start: Initial ε.
+        epsilon_end: Final ε.
+        decay_steps: Number of decay steps.
+        save_dir: Output directory.
+    """
+    os.makedirs(save_dir, exist_ok=True)
+
+    steps = np.linspace(0, decay_steps, 1000)
+    epsilons = epsilon_start - (epsilon_start - epsilon_end) * (steps / decay_steps)
+
+    fig, ax = plt.subplots(figsize=(10, 6))
+    ax.plot(steps / 1e6, epsilons, color='blue', linewidth=2)
+    ax.set_xlabel('Training steps (millions)', fontsize=12)
+    ax.set_ylabel('Epsilon (ε)', fontsize=12)
+    ax.set_title('Epsilon decay schedule', fontsize=14)
+    ax.grid(True, alpha=0.3)
+    ax.set_ylim(0, 1.1)
+
+    # Annotate the final value
+    ax.axhline(y=epsilon_end, color='red', linestyle='--', alpha=0.5)
+    ax.text(decay_steps * 0.8 / 1e6, epsilon_end + 0.05,
+            f'final value: {epsilon_end}', fontsize=10, color='red')
+
+    save_path = os.path.join(save_dir, 'epsilon_decay.png')
+    plt.savefig(save_path, dpi=300, bbox_inches='tight')
+    print(f"Epsilon decay curve saved to: {save_path}")
+
+    plt.close()
+
+
+def plot_evaluation_results(eval_rewards, save_dir="plots"):
+    """Plot evaluation results.
+
+    Args:
+        eval_rewards: List of (step, reward) pairs.
+        save_dir: Output directory.
+    """
+    os.makedirs(save_dir, exist_ok=True)
+
+    if not eval_rewards:
+        print("No evaluation data to plot")
+        return
+
+    steps, rewards = zip(*eval_rewards)
+
+    fig, ax = plt.subplots(figsize=(10, 6))
+    ax.plot(np.array(steps) / 1e6, rewards, 'o-', color='blue',
+            linewidth=2, markersize=8, label='evaluation return')
+
+    # Trend line
+    if len(rewards) > 1:
+        z = np.polyfit(steps, rewards, 1)
+        p = np.poly1d(z)
+        ax.plot(np.array(steps) / 1e6, p(steps), '--', color='red',
+                linewidth=2, alpha=0.7, label='trend')
+
+    ax.set_xlabel('Training steps (millions)', fontsize=12)
+    ax.set_ylabel('Mean return', fontsize=12)
+    ax.set_title('Evaluation returns', fontsize=14)
+    ax.legend(fontsize=10)
+    ax.grid(True, alpha=0.3)
+
+    save_path = os.path.join(save_dir, 'evaluation_results.png')
+    plt.savefig(save_path, dpi=300, bbox_inches='tight')
+    print(f"Evaluation results saved to: {save_path}")
+
+    plt.close()
+
+
+def generate_sample_plots(save_dir="plots"):
+    """Generate sample plots (for the report).
+
+    Args:
+        save_dir: Output directory.
+    """
+    os.makedirs(save_dir, exist_ok=True)
+
+    # Simulated training data
+    np.random.seed(42)
+    num_episodes = 500
+
+    # Simulated returns (trending upward)
+    base_rewards = np.linspace(-20, 200, num_episodes)
+    noise = np.random.normal(0, 30, num_episodes)
+    rewards = base_rewards + noise
+
+    # Simulated losses (decaying)
+    num_steps = 1000
+    base_loss = np.exp(-np.linspace(0, 3, num_steps)) * 10
+    loss_noise = np.random.normal(0, 0.5, num_steps)
+    losses = base_loss + loss_noise
+
+    # Simulated Q-values (trending upward)
+    base_q = np.linspace(0, 50, num_steps)
+    q_noise = np.random.normal(0, 5, num_steps)
+    q_values = base_q + q_noise
+
+    # Draw the plots
+    plot_training_curves(rewards, losses, q_values, save_dir)
+    plot_epsilon_decay(1.0, 0.01, 1_000_000, save_dir)
+
+    # Simulated evaluation data
+    eval_steps = [100000, 200000, 500000, 1000000, 1500000, 2000000]
+    eval_rewards = [50, 100, 150, 180, 190, 195]
+    plot_evaluation_results(list(zip(eval_steps, eval_rewards)), save_dir)
+
+    print(f"\nSample plots written to: {save_dir}/")
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Generate training plots")
+    parser.add_argument("--log-file", type=str, default=None,
+                        help="Path to the training log file")
+    parser.add_argument("--save-dir", type=str, default="plots",
+                        help="Output directory for plots")
+    parser.add_argument("--sample", action="store_true",
+                        help="Generate sample plots")
+
+    args = parser.parse_args()
+
+    if args.sample:
+        generate_sample_plots(args.save_dir)
+    elif args.log_file:
+        logs = load_training_logs(args.log_file)
+        plot_training_curves(
+            logs.get('rewards', []),
+            logs.get('losses', []),
+            logs.get('q_values', []),
+            args.save_dir
+        )
+    else:
+        print("Please pass --log-file or --sample")
\ No newline at end of file
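The linear schedule drawn by `plot_epsilon_decay` matches the one the agent uses during training; a quick spot-check with the default values:

```python
eps0, eps1, T = 1.0, 0.01, 1_000_000

def epsilon_at(step):
    # Linear interpolation from eps0 down to eps1 over T steps, then held constant.
    return eps1 if step >= T else eps0 - (eps0 - eps1) * (step / T)

print(epsilon_at(0), epsilon_at(500_000), epsilon_at(1_000_000))  # 1.0 0.505 0.01
```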
diff --git a/强化学习个人项目报告(Atari 游戏方向)/pyproject.toml b/强化学习个人项目报告(Atari 游戏方向)/pyproject.toml
new file mode 100644
index 0000000..e69de29
diff --git a/强化学习个人项目报告(Atari 游戏方向)/requirements.txt b/强化学习个人项目报告(Atari 游戏方向)/requirements.txt
new file mode 100644
index 0000000..3cd5430
--- /dev/null
+++ b/强化学习个人项目报告(Atari 游戏方向)/requirements.txt
@@ -0,0 +1,10 @@
+# DQN for Space Invaders - Dependencies
+torch>=2.0.0
+numpy>=1.24.0
+gymnasium>=0.29.0
+gymnasium[atari]>=0.29.0
+gymnasium[accept-rom-license]>=0.29.0
+ale-py>=0.8.0
+opencv-python>=4.8.0
+matplotlib>=3.7.0
+tensorboard>=2.14.0
\ No newline at end of file
diff --git a/强化学习个人项目报告(Atari 游戏方向)/src/__init__.py b/强化学习个人项目报告(Atari 游戏方向)/src/__init__.py
new file mode 100644
index 0000000..1eafb9f
--- /dev/null
+++ b/强化学习个人项目报告(Atari 游戏方向)/src/__init__.py
@@ -0,0 +1 @@
+# DQN for Space Invaders
\ No newline at end of file
diff --git a/强化学习个人项目报告(Atari 游戏方向)/src/agent.py b/强化学习个人项目报告(Atari 游戏方向)/src/agent.py
new file mode 100644
index 0000000..0ea23e6
--- /dev/null
+++ b/强化学习个人项目报告(Atari 游戏方向)/src/agent.py
@@ -0,0 +1,184 @@
+"""DQN Agent implementation."""
+import torch
+import torch.nn.functional as F
+import numpy as np
+import os
+
+
+class DQNAgent:
+    """DQN agent.
+
+    Implements ε-greedy exploration and the Q-learning update.
+    """
+
+    def __init__(
+        self,
+        q_network,
+        target_network,
+        replay_buffer,
+        device,
+        num_actions=6,
+        gamma=0.99,
+        lr=1e-4,
+        epsilon_start=1.0,
+        epsilon_end=0.01,
+        epsilon_decay_steps=1_000_000,
+        target_update_freq=1000,
+        batch_size=32,
+        double_dqn=True,
+    ):
+        """
+        Args:
+            q_network: Online Q-network.
+            target_network: Target network.
+            replay_buffer: Experience replay buffer.
+            device: Torch device.
+            num_actions: Number of discrete actions.
+            gamma: Discount factor.
+            lr: Learning rate.
+            epsilon_start: Initial ε.
+            epsilon_end: Final ε.
+            epsilon_decay_steps: Number of steps over which ε decays.
+            target_update_freq: Target-network sync period (in training steps).
+            batch_size: Batch size.
+            double_dqn: Whether to use Double DQN targets.
+        """
+        self.q_network = q_network
+        self.target_network = target_network
+        self.replay_buffer = replay_buffer
+        self.device = device
+        self.num_actions = num_actions
+        self.gamma = gamma
+        self.batch_size = batch_size
+        self.target_update_freq = target_update_freq
+        self.double_dqn = double_dqn
+
+        # ε-greedy parameters
+        self.epsilon_start = epsilon_start
+        self.epsilon_end = epsilon_end
+        self.epsilon_decay_steps = epsilon_decay_steps
+        self.epsilon = epsilon_start
+
+        # Optimizer
+        self.optimizer = torch.optim.Adam(q_network.parameters(), lr=lr)
+
+        # Training step counter. Named train_steps so the attribute does not
+        # shadow (and break) the train_step() method below.
+        self.train_steps = 0
+
+        # Training statistics
+        self.loss_history = []
+        self.q_value_history = []
+
+    def select_action(self, state, evaluate=False):
+        """Select an action.
+
+        Args:
+            state: Current state (channels, height, width).
+            evaluate: If True, ignore the ε schedule and act near-greedily.
+
+        Returns:
+            action: The chosen action.
+        """
+        if evaluate:
+            # Evaluation mode: near-greedy, keeping a small ε for robustness
+            epsilon = 0.01
+        else:
+            # Training mode: ε-greedy with the current schedule value
+            epsilon = self.epsilon
+
+        if np.random.random() < epsilon:
+            # Random exploration
+            return np.random.randint(self.num_actions)
+        else:
+            # Greedy action
+            with torch.no_grad():
+                state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
+                q_values = self.q_network(state_tensor)
+                return q_values.argmax(dim=1).item()
+
+    def update_epsilon(self):
+        """Update ε (linear decay)."""
+        if self.train_steps < self.epsilon_decay_steps:
+            self.epsilon = self.epsilon_start - (self.epsilon_start - self.epsilon_end) * \
+                (self.train_steps / self.epsilon_decay_steps)
+        else:
+            self.epsilon = self.epsilon_end
+
+    def train_step(self):
+        """Run one training step.
+
+        Returns:
+            loss: Loss value (None if the buffer is not yet warm).
+            avg_q: Mean Q-value of the batch (None if not trained).
+        """
+        # Wait until enough samples are available
+        if len(self.replay_buffer) < self.batch_size:
+            return None, None
+
+        # Sample a batch
+        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
+
+        # Current Q-values for the taken actions
+        q_values = self.q_network(states)
+        q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
+
+        # Target Q-values
+        with torch.no_grad():
+            if self.double_dqn:
+                # Double DQN: the online network selects, the target network evaluates
+                next_actions = self.q_network(next_states).argmax(dim=1)
+                next_q_values = self.target_network(next_states)
+                next_q_values = next_q_values.gather(1, next_actions.unsqueeze(1)).squeeze(1)
+            else:
+                # Standard DQN: max Q-value from the target network
+                next_q_values = self.target_network(next_states).max(dim=1)[0]
+
+            # Bootstrapped target
+            target_q_values = rewards + self.gamma * next_q_values * (1 - dones)
+
+        # Loss
+        loss = F.mse_loss(q_values, target_q_values)
+
+        # Backpropagation
+        self.optimizer.zero_grad()
+        loss.backward()
+        # Gradient clipping
+        torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 10)
+        self.optimizer.step()
+
+        # Periodically sync the target network
+        self.train_steps += 1
+        if self.train_steps % self.target_update_freq == 0:
+            self.target_network.load_state_dict(self.q_network.state_dict())
+
+        # Update ε
+        self.update_epsilon()
+
+        # Record statistics
+        avg_q = q_values.mean().item()
+        self.loss_history.append(loss.item())
+        self.q_value_history.append(avg_q)
+
+        return loss.item(), avg_q
+
+    def save(self, path):
+        """Save a checkpoint."""
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        torch.save({
+            'q_network': self.q_network.state_dict(),
+            'target_network': self.target_network.state_dict(),
+            'optimizer': self.optimizer.state_dict(),
+            'train_step': self.train_steps,
+            'epsilon': self.epsilon,
+        }, path)
+        print(f"Model saved to: {path}")
+
+    def load(self, path):
+        """Load a checkpoint."""
+        checkpoint = torch.load(path, map_location=self.device)
+        self.q_network.load_state_dict(checkpoint['q_network'])
+        self.target_network.load_state_dict(checkpoint['target_network'])
+        self.optimizer.load_state_dict(checkpoint['optimizer'])
+        self.train_steps = checkpoint['train_step']
+        self.epsilon = checkpoint['epsilon']
+        print(f"Model loaded from {path}")
\ No newline at end of file
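How the pieces above fit together, as a minimal sketch (train.py below wires the same objects with full CLI options; the 10,000-capacity buffer and 100-step loop here are illustrative only):

```python
from src.network import QNetwork
from src.replay_buffer import ReplayBuffer
from src.agent import DQNAgent
from src.utils import make_env, get_device

device = get_device()
env = make_env()                                   # preprocessed ALE/SpaceInvaders-v5
n = env.action_space.n
q_net = QNetwork((4, 84, 84), n).to(device)
tgt = QNetwork((4, 84, 84), n).to(device)
tgt.load_state_dict(q_net.state_dict())
agent = DQNAgent(q_net, tgt, ReplayBuffer(10_000, (4, 84, 84), device),
                 device, num_actions=n)

state, _ = env.reset()
for _ in range(100):
    action = agent.select_action(state)            # ε-greedy
    next_state, reward, term, trunc, _ = env.step(action)
    agent.replay_buffer.add(state, action, reward, next_state, term or trunc)
    agent.train_step()                             # (None, None) until a full batch exists
    state = env.reset()[0] if (term or trunc) else next_state
```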
diff --git a/强化学习个人项目报告(Atari 游戏方向)/src/evaluate.py b/强化学习个人项目报告(Atari 游戏方向)/src/evaluate.py
new file mode 100644
index 0000000..2ec8357
--- /dev/null
+++ b/强化学习个人项目报告(Atari 游戏方向)/src/evaluate.py
@@ -0,0 +1,135 @@
+"""Evaluation script for trained DQN agent."""
+import sys
+import os
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import torch
+import numpy as np
+import argparse
+
+from src.network import QNetwork, DuelingQNetwork
+from src.utils import make_env, get_device
+
+
+def evaluate(agent, env, num_episodes=10, render=False):
+    """Evaluate an agent.
+
+    Args:
+        agent: DQN agent.
+        env: Evaluation environment.
+        num_episodes: Number of evaluation episodes.
+        render: Whether to render.
+
+    Returns:
+        avg_reward: Mean return.
+        std_reward: Standard deviation of returns.
+        rewards: All episode returns.
+    """
+    rewards = []
+
+    for ep in range(num_episodes):
+        state, _ = env.reset()
+        episode_reward = 0
+        done = False
+        steps = 0
+
+        while not done:
+            # Greedy action selection
+            action = agent.select_action(state, evaluate=True)
+
+            # Step the environment
+            state, reward, terminated, truncated, _ = env.step(action)
+            done = terminated or truncated
+            episode_reward += reward
+            steps += 1
+
+            if render:
+                # make_env uses render_mode="rgb_array", so this returns
+                # frames rather than opening a window
+                env.render()
+
+        rewards.append(episode_reward)
+        print(f"Episode {ep+1}/{num_episodes}: return={episode_reward:.1f}, steps={steps}")
+
+    avg_reward = np.mean(rewards)
+    std_reward = np.std(rewards)
+
+    return avg_reward, std_reward, rewards
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Evaluate a DQN agent")
+    parser.add_argument("--model", type=str, required=True,
+                        help="Path to the model checkpoint")
+    parser.add_argument("--env", type=str, default="ALE/SpaceInvaders-v5",
+                        help="Environment ID")
+    parser.add_argument("--episodes", type=int, default=10,
+                        help="Number of evaluation episodes")
+    parser.add_argument("--render", action="store_true",
+                        help="Render during evaluation")
+    parser.add_argument("--dueling", action="store_true",
+                        help="Use the Dueling architecture")
+    parser.add_argument("--seed", type=int, default=42,
+                        help="Random seed")
+
+    args = parser.parse_args()
+
+    # Seed RNGs
+    torch.manual_seed(args.seed)
+    np.random.seed(args.seed)
+
+    # Device
+    device = get_device()
+
+    # Environment
+    env = make_env(args.env, gray_scale=True, resize=True, frame_stack=4)
+
+    # Action space
+    num_actions = env.action_space.n
+    state_shape = (4, 84, 84)
+
+    # Network
+    if args.dueling:
+        q_network = DuelingQNetwork(state_shape, num_actions).to(device)
+    else:
+        q_network = QNetwork(state_shape, num_actions).to(device)
+
+    # Load weights
+    print(f"Loading model: {args.model}")
+    checkpoint = torch.load(args.model, map_location=device, weights_only=False)
+    q_network.load_state_dict(checkpoint['q_network'])
+    q_network.eval()
+
+    # Lightweight agent wrapper for evaluation
+    class EvalAgent:
+        def __init__(self, q_network, device, num_actions):
+            self.q_network = q_network
+            self.device = device
+            self.num_actions = num_actions
+
+        def select_action(self, state, evaluate=True):
+            with torch.no_grad():
+                state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
+                q_values = self.q_network(state_tensor)
+                return q_values.argmax(dim=1).item()
+
+    agent = EvalAgent(q_network, device, num_actions)
+
+    # Evaluate
+    print(f"\nEvaluating over {args.episodes} episodes...")
+    print("=" * 60)
+
+    avg_reward, std_reward, rewards = evaluate(
+        agent, env, num_episodes=args.episodes, render=args.render
+    )
+
+    # Results
+    print("=" * 60)
+    print("Evaluation results:")
+    print(f"  mean return:   {avg_reward:.2f}")
+    print(f"  std:           {std_reward:.2f}")
+    print(f"  max return:    {max(rewards):.1f}")
+    print(f"  min return:    {min(rewards):.1f}")
+    print(f"  median return: {np.median(rewards):.1f}")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/强化学习个人项目报告(Atari 游戏方向)/src/network.py b/强化学习个人项目报告(Atari 游戏方向)/src/network.py
new file mode 100644
index 0000000..2ad6b33
--- /dev/null
+++ b/强化学习个人项目报告(Atari 游戏方向)/src/network.py
@@ -0,0 +1,150 @@
+"""Q-Network architecture for DQN."""
+import torch
+import torch.nn as nn
+
+
+class QNetwork(nn.Module):
+    """Deep Q-Network for Atari games.
+
+    Architecture follows the original DQN paper:
+    - 3 convolutional layers for feature extraction
+    - 2 fully connected layers for Q-value estimation
+    """
+
+    def __init__(self, state_shape=(4, 84, 84), num_actions=6):
+        """
+        Args:
+            state_shape: Shape of input state (channels, height, width)
+            num_actions: Number of possible actions
+        """
+        super().__init__()
+        c, h, w = state_shape
+
+        # Convolutional feature extractor
+        self.conv = nn.Sequential(
+            nn.Conv2d(c, 32, kernel_size=8, stride=4),
+            nn.ReLU(),
+            nn.Conv2d(32, 64, kernel_size=4, stride=2),
+            nn.ReLU(),
+            nn.Conv2d(64, 64, kernel_size=3, stride=1),
+            nn.ReLU(),
+        )
+
+        # Convolution output size (assumes square inputs, h == w)
+        out_h = (h - 8) // 4 + 1
+        out_h = (out_h - 4) // 2 + 1
+        out_h = (out_h - 3) // 1 + 1
+        feat_size = 64 * out_h * out_h
+
+        # Fully connected head producing Q-values
+        self.fc = nn.Sequential(
+            nn.Linear(feat_size, 512),
+            nn.ReLU(),
+            nn.Linear(512, num_actions),
+        )
+
+        # Weight initialization
+        self._initialize_weights()
+
+    def _initialize_weights(self):
+        """He (Kaiming) initialization."""
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        """Forward pass returning per-action Q-values.
+
+        Args:
+            x: State tensor (batch_size, channels, height, width)
+
+        Returns:
+            q_values: Q-values per action (batch_size, num_actions)
+        """
+        # Scale pixel values to [0, 1]
+        x = x / 255.0
+
+        # Convolutional features
+        x = self.conv(x)
+
+        # Flatten
+        x = x.view(x.size(0), -1)
+
+        # Q-values
+        q_values = self.fc(x)
+
+        return q_values
+
+
+class DuelingQNetwork(nn.Module):
+    """Dueling DQN architecture separating state value and advantages.
+
+    Compared with standard DQN, the dueling architecture learns state
+    values more effectively.
+    """
+
+    def __init__(self, state_shape=(4, 84, 84), num_actions=6):
+        super().__init__()
+        c, h, w = state_shape
+
+        # Shared convolutional trunk
+        self.conv = nn.Sequential(
+            nn.Conv2d(c, 32, kernel_size=8, stride=4),
+            nn.ReLU(),
+            nn.Conv2d(32, 64, kernel_size=4, stride=2),
+            nn.ReLU(),
+            nn.Conv2d(64, 64, kernel_size=3, stride=1),
+            nn.ReLU(),
+        )
+
+        # Convolution output size (assumes square inputs, h == w)
+        out_h = (h - 8) // 4 + 1
+        out_h = (out_h - 4) // 2 + 1
+        out_h = (out_h - 3) // 1 + 1
+        feat_size = 64 * out_h * out_h
+
+        # Value stream
+        self.value_stream = nn.Sequential(
+            nn.Linear(feat_size, 512),
+            nn.ReLU(),
+            nn.Linear(512, 1),
+        )
+
+        # Advantage stream
+        self.advantage_stream = nn.Sequential(
+            nn.Linear(feat_size, 512),
+            nn.ReLU(),
+            nn.Linear(512, num_actions),
+        )
+
+        self._initialize_weights()
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        """Forward pass returning Q-values."""
+        x = x / 255.0
+        x = self.conv(x)
+        x = x.view(x.size(0), -1)
+
+        # Value and advantage estimates
+        value = self.value_stream(x)
+        advantage = self.advantage_stream(x)
+
+        # Q = V(s) + A(s,a) - mean(A(s,a)); subtracting the mean keeps the
+        # value/advantage decomposition identifiable
+        q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
+
+        return q_values
\ No newline at end of file
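The flattened feature size used by both networks follows from the valid-convolution formula ⌊(n − k)/s⌋ + 1 applied per layer; a quick check for the default 84×84 inputs:

```python
h = 84
h = (h - 8) // 4 + 1   # conv1: 8x8, stride 4 -> 20
h = (h - 4) // 2 + 1   # conv2: 4x4, stride 2 -> 9
h = (h - 3) // 1 + 1   # conv3: 3x3, stride 1 -> 7
print(64 * h * h)      # 3136 features into the 512-unit hidden layer
```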
diff --git a/强化学习个人项目报告(Atari 游戏方向)/src/replay_buffer.py b/强化学习个人项目报告(Atari 游戏方向)/src/replay_buffer.py
new file mode 100644
index 0000000..8f28f6a
--- /dev/null
+++ b/强化学习个人项目报告(Atari 游戏方向)/src/replay_buffer.py
@@ -0,0 +1,162 @@
+"""Experience Replay Buffer for DQN."""
+import numpy as np
+import torch
+
+
+class ReplayBuffer:
+    """Experience replay buffer.
+
+    Stores transitions (s, a, r, s', done) and samples them uniformly at
+    random to break the correlation between consecutive experiences.
+    """
+
+    def __init__(self, capacity, state_shape=(4, 84, 84), device='cpu'):
+        """
+        Args:
+            capacity: Buffer capacity.
+            state_shape: State shape (channels, height, width).
+            device: Device (cpu/cuda).
+        """
+        self.capacity = capacity
+        self.device = device
+        self.ptr = 0
+        self.size = 0
+
+        # Preallocate storage
+        self.states = np.zeros((capacity, *state_shape), dtype=np.uint8)
+        self.actions = np.zeros(capacity, dtype=np.int64)
+        self.rewards = np.zeros(capacity, dtype=np.float32)
+        self.next_states = np.zeros((capacity, *state_shape), dtype=np.uint8)
+        self.dones = np.zeros(capacity, dtype=np.bool_)
+
+    def add(self, state, action, reward, next_state, done):
+        """Add a transition.
+
+        Args:
+            state: Current state.
+            action: Action taken.
+            reward: Reward received.
+            next_state: Next state.
+            done: Whether the episode ended.
+        """
+        self.states[self.ptr] = state
+        self.actions[self.ptr] = action
+        self.rewards[self.ptr] = reward
+        self.next_states[self.ptr] = next_state
+        self.dones[self.ptr] = done
+
+        # Ring-buffer bookkeeping
+        self.ptr = (self.ptr + 1) % self.capacity
+        self.size = min(self.size + 1, self.capacity)
+
+    def sample(self, batch_size):
+        """Sample a batch uniformly at random.
+
+        Args:
+            batch_size: Batch size.
+
+        Returns:
+            states, actions, rewards, next_states, dones
+        """
+        indices = np.random.randint(0, self.size, size=batch_size)
+
+        states = torch.from_numpy(self.states[indices]).float().to(self.device)
+        actions = torch.from_numpy(self.actions[indices]).long().to(self.device)
+        rewards = torch.from_numpy(self.rewards[indices]).float().to(self.device)
+        next_states = torch.from_numpy(self.next_states[indices]).float().to(self.device)
+        dones = torch.from_numpy(self.dones[indices]).float().to(self.device)
+
+        return states, actions, rewards, next_states, dones
+
+    def __len__(self):
+        return self.size
+
+
+class PrioritizedReplayBuffer:
+    """Prioritized experience replay buffer.
+
+    Samples transitions in proportion to their TD error, improving sample
+    efficiency.
+    """
+
+    def __init__(self, capacity, state_shape=(4, 84, 84), device='cpu', alpha=0.6):
+        """
+        Args:
+            capacity: Buffer capacity.
+            state_shape: State shape.
+            device: Device.
+            alpha: Priority exponent (0 = uniform, 1 = fully prioritized).
+        """
+        self.capacity = capacity
+        self.device = device
+        self.alpha = alpha
+        self.ptr = 0
+        self.size = 0
+        self.max_priority = 1.0
+
+        # Data storage
+        self.states = np.zeros((capacity, *state_shape), dtype=np.uint8)
+        self.actions = np.zeros(capacity, dtype=np.int64)
+        self.rewards = np.zeros(capacity, dtype=np.float32)
+        self.next_states = np.zeros((capacity, *state_shape), dtype=np.uint8)
+        self.dones = np.zeros(capacity, dtype=np.bool_)
+
+        # Priorities
+        self.priorities = np.zeros(capacity, dtype=np.float32)
+
+    def add(self, state, action, reward, next_state, done):
+        """Add a transition with the current maximum priority."""
+        self.states[self.ptr] = state
+        self.actions[self.ptr] = action
+        self.rewards[self.ptr] = reward
+        self.next_states[self.ptr] = next_state
+        self.dones[self.ptr] = done
+
+        # New samples get the maximum priority so they are seen at least once
+        self.priorities[self.ptr] = self.max_priority
+
+        self.ptr = (self.ptr + 1) % self.capacity
+        self.size = min(self.size + 1, self.capacity)
+
+    def sample(self, batch_size, beta=0.4):
+        """Sample a batch according to priorities.
+
+        Args:
+            batch_size: Batch size.
+            beta: Importance-sampling exponent.
+
+        Returns:
+            states, actions, rewards, next_states, dones, indices, weights
+        """
+        # Sampling probabilities
+        priorities = self.priorities[:self.size] ** self.alpha
+        probs = priorities / priorities.sum()
+
+        # Sample indices by probability
+        indices = np.random.choice(self.size, size=batch_size, p=probs)
+
+        # Importance-sampling weights, normalized by the maximum
+        weights = (self.size * probs[indices]) ** (-beta)
+        weights = weights / weights.max()
+
+        # Gather the data
+        states = torch.from_numpy(self.states[indices]).float().to(self.device)
+        actions = torch.from_numpy(self.actions[indices]).long().to(self.device)
+        rewards = torch.from_numpy(self.rewards[indices]).float().to(self.device)
+        next_states = torch.from_numpy(self.next_states[indices]).float().to(self.device)
+        dones = torch.from_numpy(self.dones[indices]).float().to(self.device)
+        weights = torch.from_numpy(weights).float().to(self.device)
+
+        return states, actions, rewards, next_states, dones, indices, weights
+
+    def update_priorities(self, indices, td_errors):
+        """Update priorities.
+
+        Args:
+            indices: Sample indices.
+            td_errors: TD errors.
+        """
+        priorities = np.abs(td_errors) + 1e-6
+        self.priorities[indices] = priorities
+        self.max_priority = max(self.max_priority, priorities.max())
+
+    def __len__(self):
+        return self.size
\ No newline at end of file
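Note that `DQNAgent.train_step` consumes the uniform buffer's 5-tuple `sample()`, so `PrioritizedReplayBuffer` is not wired into the agent yet. A sketch of what a PER-aware update could look like (a hypothetical helper, not part of this patch's call graph): importance weights scale the per-sample loss, and priorities are refreshed from the new TD errors.

```python
import torch

def per_train_step(agent, beta=0.4):
    """Hypothetical PER-aware variant of DQNAgent.train_step (sketch only)."""
    if len(agent.replay_buffer) < agent.batch_size:
        return None
    states, actions, rewards, next_states, dones, indices, weights = \
        agent.replay_buffer.sample(agent.batch_size, beta=beta)

    q = agent.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        next_a = agent.q_network(next_states).argmax(dim=1, keepdim=True)
        next_q = agent.target_network(next_states).gather(1, next_a).squeeze(1)
        target = rewards + agent.gamma * next_q * (1 - dones)

    td_error = q - target
    loss = (weights * td_error.pow(2)).mean()   # importance-weighted MSE

    agent.optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(agent.q_network.parameters(), 10)
    agent.optimizer.step()

    # Refresh priorities with the new |TD error| values
    agent.replay_buffer.update_priorities(indices, td_error.detach().abs().cpu().numpy())
    return loss.item()
```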
diff --git a/强化学习个人项目报告(Atari 游戏方向)/src/trainer.py b/强化学习个人项目报告(Atari 游戏方向)/src/trainer.py
new file mode 100644
index 0000000..80bcac5
--- /dev/null
+++ b/强化学习个人项目报告(Atari 游戏方向)/src/trainer.py
@@ -0,0 +1,166 @@
+"""DQN Trainer."""
+import numpy as np
+import time
+from collections import deque
+
+
+class DQNTrainer:
+    """DQN trainer.
+
+    Handles the training loop, logging, and checkpointing.
+    """
+
+    def __init__(
+        self,
+        agent,
+        env,
+        eval_env,
+        log_dir="logs",
+        save_dir="models",
+        eval_freq=10000,
+        save_freq=50000,
+        num_eval_episodes=10,
+        warmup_steps=10000,
+    ):
+        """
+        Args:
+            agent: DQN agent.
+            env: Training environment.
+            eval_env: Evaluation environment.
+            log_dir: Log directory.
+            save_dir: Checkpoint directory.
+            eval_freq: Evaluation period (steps).
+            save_freq: Checkpoint period (steps).
+            num_eval_episodes: Episodes per evaluation.
+            warmup_steps: Warmup steps (pure random exploration).
+        """
+        self.agent = agent
+        self.env = env
+        self.eval_env = eval_env
+        self.log_dir = log_dir
+        self.save_dir = save_dir
+        self.eval_freq = eval_freq
+        self.save_freq = save_freq
+        self.num_eval_episodes = num_eval_episodes
+        self.warmup_steps = warmup_steps
+
+        # Training statistics
+        self.episode_rewards = deque(maxlen=100)
+        self.episode_lengths = deque(maxlen=100)
+        self.eval_rewards = []
+        self.best_eval_reward = -float('inf')
+
+    def train(self, total_steps):
+        """Main training loop.
+
+        Args:
+            total_steps: Total number of environment steps.
+        """
+        print(f"Starting training for {total_steps:,} steps")
+        print(f"Warmup steps: {self.warmup_steps:,}")
+        print("=" * 60)
+
+        state, _ = self.env.reset()
+        episode_reward = 0
+        episode_length = 0
+        episode_count = 0
+        start_time = time.time()
+
+        for step in range(1, total_steps + 1):
+            # Select an action
+            if step < self.warmup_steps:
+                # Warmup phase: pure random exploration
+                action = self.env.action_space.sample()
+            else:
+                # Training phase: ε-greedy
+                action = self.agent.select_action(state)
+
+            # Step the environment
+            next_state, reward, terminated, truncated, _ = self.env.step(action)
+            done = terminated or truncated
+
+            # Store the transition
+            self.agent.replay_buffer.add(state, action, reward, next_state, done)
+
+            # Train
+            if step >= self.warmup_steps:
+                loss, avg_q = self.agent.train_step()
+
+            # Advance the state
+            state = next_state
+            episode_reward += reward
+            episode_length += 1
+
+            # Episode finished
+            if done:
+                self.episode_rewards.append(episode_reward)
+                self.episode_lengths.append(episode_length)
+                episode_count += 1
+
+                # Progress report
+                if episode_count % 10 == 0:
+                    avg_reward = np.mean(self.episode_rewards)
+                    avg_length = np.mean(self.episode_lengths)
+                    elapsed = time.time() - start_time
+                    fps = step / elapsed
+
+                    print(f"Step: {step:>8,} | "
+                          f"Episode: {episode_count:>5} | "
+                          f"Avg Reward: {avg_reward:>7.1f} | "
+                          f"Avg Length: {avg_length:>6.1f} | "
+                          f"Epsilon: {self.agent.epsilon:.3f} | "
+                          f"FPS: {fps:.0f}")
+
+                # Reset the environment
+                state, _ = self.env.reset()
+                episode_reward = 0
+                episode_length = 0
+
+            # Periodic evaluation
+            if step % self.eval_freq == 0:
+                eval_reward = self.evaluate()
+                self.eval_rewards.append((step, eval_reward))
+                print(f"\n[Eval] Step {step:>8,} | mean return: {eval_reward:.1f}")
+                print("-" * 60)
+
+                # Keep the best model
+                if eval_reward > self.best_eval_reward:
+                    self.best_eval_reward = eval_reward
+                    self.agent.save(f"{self.save_dir}/dqn_best.pt")
+
+            # Periodic checkpoint
+            if step % self.save_freq == 0:
+                self.agent.save(f"{self.save_dir}/dqn_step_{step}.pt")
+
+        # Training finished
+        total_time = time.time() - start_time
+        print("\n" + "=" * 60)
+        print(f"Training finished! Total time: {total_time:.1f}s")
+        print(f"Best evaluation return: {self.best_eval_reward:.1f}")
+
+        # Save the final model
+        self.agent.save(f"{self.save_dir}/dqn_final.pt")
+
+    def evaluate(self):
+        """Evaluate the agent.
+
+        Returns:
+            avg_reward: Mean return.
+        """
+        rewards = []
+
+        for _ in range(self.num_eval_episodes):
+            state, _ = self.eval_env.reset()
+            episode_reward = 0
+            done = False
+
+            while not done:
+                action = self.agent.select_action(state, evaluate=True)
+                state, reward, terminated, truncated, _ = self.eval_env.step(action)
+                done = terminated or truncated
+                episode_reward += reward
+
+            rewards.append(episode_reward)
+
+        return np.mean(rewards)
\ No newline at end of file
diff --git a/强化学习个人项目报告(Atari 游戏方向)/src/utils.py b/强化学习个人项目报告(Atari 游戏方向)/src/utils.py
new file mode 100644
index 0000000..d1dc9d4
--- /dev/null
+++ b/强化学习个人项目报告(Atari 游戏方向)/src/utils.py
@@ -0,0 +1,200 @@
+"""Environment wrappers and utility functions."""
+import gymnasium as gym
+import numpy as np
+import torch
+import cv2
+from collections import deque
+
+# Register the ALE environments
+try:
+    import ale_py
+    gym.register_envs(ale_py)
+except ImportError:
+    pass
+
+
+class GrayScaleWrapper(gym.ObservationWrapper):
+    """Convert RGB observations to grayscale."""
+
+    def __init__(self, env):
+        super().__init__(env)
+        # Drop the channel dimension from the observation space
+        h, w = env.observation_space.shape[:2]
+        self.observation_space = gym.spaces.Box(
+            low=0, high=255, shape=(h, w), dtype=np.uint8
+        )
+
+    def observation(self, obs):
+        # RGB to grayscale: weighted average of the channels
+        gray = 0.299 * obs[:, :, 0] + 0.587 * obs[:, :, 1] + 0.114 * obs[:, :, 2]
+        return gray.astype(np.uint8)
+
+
+class ResizeWrapper(gym.ObservationWrapper):
+    """Resize observations."""
+
+    def __init__(self, env, size=(84, 84)):
+        super().__init__(env)
+        self.size = size
+        # Update the observation space to the new spatial size
+        old_shape = env.observation_space.shape
+        new_shape = size if len(old_shape) == 2 else (*size, old_shape[2])
+        self.observation_space = gym.spaces.Box(
+            low=0, high=255, shape=new_shape, dtype=np.uint8
+        )
+
+    def observation(self, obs):
+        # cv2.resize handles both grayscale and RGB inputs
+        return cv2.resize(obs, self.size, interpolation=cv2.INTER_AREA)
+
+
+class FrameStackWrapper(gym.ObservationWrapper):
+    """Stack the N most recent observations."""
+
+    def __init__(self, env, num_stack=4):
+        super().__init__(env)
+        self.num_stack = num_stack
+        self.frames = deque(maxlen=num_stack)
+        obs_shape = env.observation_space.shape
+
+        # Update the observation space
+        if len(obs_shape) == 2:
+            # Grayscale frames
+            self.observation_space = gym.spaces.Box(
+                low=0, high=255,
+                shape=(num_stack, *obs_shape),
+                dtype=np.uint8
+            )
+        else:
+            # RGB frames
+            self.observation_space = gym.spaces.Box(
+                low=0, high=255,
+                shape=(num_stack, *obs_shape[-2:]),
+                dtype=np.uint8
+            )
+
+    def reset(self, **kwargs):
+        obs, info = self.env.reset(**kwargs)
+        for _ in range(self.num_stack):
+            self.frames.append(obs)
+        return self._get_observation(), info
+
+    def observation(self, obs):
+        self.frames.append(obs)
+        return self._get_observation()
+
+    def _get_observation(self):
+        return np.stack(list(self.frames), axis=0)
+
+
+class RewardClipWrapper(gym.RewardWrapper):
+    """Clip rewards to [-1, 1]."""
+
+    def __init__(self, env):
+        super().__init__(env)
+
+    def reward(self, reward):
+        return np.clip(reward, -1, 1)
+
+
+class NoopResetWrapper(gym.Wrapper):
+    """Take a random number of no-op actions on reset to randomize the start state."""
+
+    def __init__(self, env, noop_max=30):
+        super().__init__(env)
+        self.noop_max = noop_max
+        self.noop_action = 0
+
+    def reset(self, **kwargs):
+        obs, info = self.env.reset(**kwargs)
+        # Execute a random number of no-ops
+        noop_times = np.random.randint(1, self.noop_max + 1)
+        for _ in range(noop_times):
+            obs, reward, terminated, truncated, info = self.env.step(self.noop_action)
+            if terminated or truncated:
+                obs, info = self.env.reset(**kwargs)
+        return obs, info
+
+
+class MaxAndSkipWrapper(gym.Wrapper):
+    """Repeat each action for `skip` frames and max-pool the last two, to cut cost."""
+
+    def __init__(self, env, skip=4):
+        super().__init__(env)
+        self.skip = skip
+        self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)
+
+    def step(self, action):
+        total_reward = 0.0
+        terminated = False
+        truncated = False
+
+        for i in range(self.skip):
+            obs, reward, terminated, truncated, info = self.env.step(action)
+            total_reward += reward
+
+            if i == self.skip - 2:
+                self._obs_buffer[0] = obs
+            if i == self.skip - 1:
+                self._obs_buffer[1] = obs
+
+            if terminated or truncated:
+                break
+
+        # Max over the two most recent frames (removes sprite flicker)
+        max_frame = self._obs_buffer.max(axis=0)
+
+        return max_frame, total_reward, terminated, truncated, info
+
+    def reset(self, **kwargs):
+        return self.env.reset(**kwargs)
+
+
+def make_env(env_id="ALE/SpaceInvaders-v5", gray_scale=True, resize=True,
+             frame_stack=4, reward_clip=True, noop_reset=True, max_skip=4):
+    """Create a preprocessed Atari environment.
+
+    Args:
+        env_id: Environment ID.
+        gray_scale: Convert to grayscale.
+        resize: Resize observations.
+        frame_stack: Number of stacked frames.
+        reward_clip: Clip rewards.
+        noop_reset: Use no-op resets.
+        max_skip: Frame-skip count.
+
+    Returns:
+        env: The wrapped environment.
+    """
+    env = gym.make(env_id, render_mode="rgb_array")
+
+    if noop_reset:
+        env = NoopResetWrapper(env, noop_max=30)
+
+    if max_skip > 1:
+        env = MaxAndSkipWrapper(env, skip=max_skip)
+
+    if resize:
+        env = ResizeWrapper(env, size=(84, 84))
+
+    if gray_scale:
+        env = GrayScaleWrapper(env)
+
+    if reward_clip:
+        env = RewardClipWrapper(env)
+
+    if frame_stack > 1:
+        env = FrameStackWrapper(env, num_stack=frame_stack)
+
+    return env
+
+
+def get_device():
+    """Detect and return an available device."""
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
+    else:
+        device = torch.device("cpu")
+        print("Using CPU")
+    return device
+
+
+def preprocess_obs(obs):
+    """Ensure observations have a channel dimension."""
+    if len(obs.shape) == 2:
+        obs = np.expand_dims(obs, axis=0)
+    return obs
\ No newline at end of file
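A quick smoke test of the wrapper stack (assumes the dependencies from requirements.txt and the Atari ROMs are installed, and that it runs from the project root):

```python
from src.utils import make_env

env = make_env("ALE/SpaceInvaders-v5", gray_scale=True, resize=True, frame_stack=4)
obs, info = env.reset(seed=0)
print(obs.shape, obs.dtype)  # expected: (4, 84, 84) uint8 after resize/grayscale/stacking
```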
diff --git a/强化学习个人项目报告(Atari 游戏方向)/train.py b/强化学习个人项目报告(Atari 游戏方向)/train.py
new file mode 100644
index 0000000..d1789f5
--- /dev/null
+++ b/强化学习个人项目报告(Atari 游戏方向)/train.py
@@ -0,0 +1,170 @@
+"""Main training script for DQN on Space Invaders."""
+import sys
+import os
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+import argparse
+import torch
+import numpy as np
+
+from src.network import QNetwork, DuelingQNetwork
+from src.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
+from src.agent import DQNAgent
+from src.trainer import DQNTrainer
+from src.utils import make_env, get_device
+
+
+def main():
+    parser = argparse.ArgumentParser(description="DQN for Space Invaders")
+
+    # Environment
+    parser.add_argument("--env", type=str, default="ALE/SpaceInvaders-v5",
+                        help="Atari environment ID")
+
+    # Training
+    parser.add_argument("--steps", type=int, default=2_000_000,
+                        help="Total training steps")
+    parser.add_argument("--lr", type=float, default=1e-4,
+                        help="Learning rate")
+    parser.add_argument("--gamma", type=float, default=0.99,
+                        help="Discount factor")
+    parser.add_argument("--batch-size", type=int, default=32,
+                        help="Batch size")
+    parser.add_argument("--buffer-size", type=int, default=100_000,
+                        help="Replay buffer capacity")
+
+    # ε-greedy
+    parser.add_argument("--epsilon-start", type=float, default=1.0,
+                        help="Initial ε")
+    parser.add_argument("--epsilon-end", type=float, default=0.01,
+                        help="Final ε")
+    parser.add_argument("--epsilon-decay", type=int, default=1_000_000,
+                        help="ε decay steps")
+
+    # Networks. BooleanOptionalAction makes the default-on flag toggleable
+    # (a plain store_true with default=True could never be disabled).
+    parser.add_argument("--target-update", type=int, default=1000,
+                        help="Target-network update period")
+    parser.add_argument("--double-dqn", action=argparse.BooleanOptionalAction, default=True,
+                        help="Use Double DQN (disable with --no-double-dqn)")
+    parser.add_argument("--dueling", action="store_true", default=False,
+                        help="Use the Dueling DQN architecture")
+
+    # Evaluation
+    parser.add_argument("--eval-freq", type=int, default=10000,
+                        help="Evaluation period")
+    parser.add_argument("--eval-episodes", type=int, default=10,
+                        help="Episodes per evaluation")
+    parser.add_argument("--save-freq", type=int, default=50000,
+                        help="Checkpoint period")
+    parser.add_argument("--warmup", type=int, default=10000,
+                        help="Warmup steps")
+
+    # Prioritized replay
+    parser.add_argument("--prioritized", action="store_true", default=False,
+                        help="Use prioritized experience replay (note: "
+                             "DQNAgent.train_step currently expects the "
+                             "uniform buffer's sample() interface)")
+
+    # Misc
+    parser.add_argument("--seed", type=int, default=42,
+                        help="Random seed")
+    parser.add_argument("--save-dir", type=str, default="models",
+                        help="Checkpoint directory")
+    parser.add_argument("--log-dir", type=str, default="logs",
+                        help="Log directory")
+
+    args = parser.parse_args()
+
+    # Seed RNGs
+    torch.manual_seed(args.seed)
+    np.random.seed(args.seed)
+
+    # Device
+    device = get_device()
+
+    # Environments
+    print(f"Creating environment: {args.env}")
+    env = make_env(args.env, gray_scale=True, resize=True, frame_stack=4)
+    eval_env = make_env(args.env, gray_scale=True, resize=True, frame_stack=4)
+
+    # Action space
+    num_actions = env.action_space.n
+    print(f"Action space: {num_actions}")
+
+    # Networks
+    state_shape = (4, 84, 84)  # 4 stacked 84x84 grayscale frames
+
+    if args.dueling:
+        print("Using the Dueling DQN architecture")
+        q_network = DuelingQNetwork(state_shape, num_actions).to(device)
+        target_network = DuelingQNetwork(state_shape, num_actions).to(device)
+    else:
+        print("Using the standard DQN architecture")
+        q_network = QNetwork(state_shape, num_actions).to(device)
+        target_network = QNetwork(state_shape, num_actions).to(device)
+
+    # Copy the initial weights into the target network
+    target_network.load_state_dict(q_network.state_dict())
+    target_network.eval()
+
+    print(f"Network parameters: {sum(p.numel() for p in q_network.parameters()):,}")
+
+    # Replay buffer
+    if args.prioritized:
+        print("Using prioritized experience replay")
+        replay_buffer = PrioritizedReplayBuffer(
+            args.buffer_size, state_shape, device
+        )
+    else:
+        print("Using uniform experience replay")
+        replay_buffer = ReplayBuffer(
+            args.buffer_size, state_shape, device
+        )
+
+    # Agent
+    agent = DQNAgent(
+        q_network=q_network,
+        target_network=target_network,
+        replay_buffer=replay_buffer,
+        device=device,
+        num_actions=num_actions,
+        gamma=args.gamma,
+        lr=args.lr,
+        epsilon_start=args.epsilon_start,
+        epsilon_end=args.epsilon_end,
+        epsilon_decay_steps=args.epsilon_decay,
+        target_update_freq=args.target_update,
+        batch_size=args.batch_size,
+        double_dqn=args.double_dqn,
+    )
+
+    # Trainer
+    trainer = DQNTrainer(
+        agent=agent,
+        env=env,
+        eval_env=eval_env,
+        log_dir=args.log_dir,
+        save_dir=args.save_dir,
+        eval_freq=args.eval_freq,
+        save_freq=args.save_freq,
+        num_eval_episodes=args.eval_episodes,
+        warmup_steps=args.warmup,
+    )
+
+    # Print the configuration
+    print("\nTraining configuration:")
+    print(f"  total steps: {args.steps:,}")
+    print(f"  learning rate: {args.lr}")
+    print(f"  discount factor: {args.gamma}")
+    print(f"  batch size: {args.batch_size}")
+    print(f"  buffer size: {args.buffer_size:,}")
+    print(f"  ε decay: {args.epsilon_start} -> {args.epsilon_end} ({args.epsilon_decay:,} steps)")
+    print(f"  target update: every {args.target_update} steps")
+    print(f"  Double DQN: {args.double_dqn}")
+    print(f"  warmup steps: {args.warmup:,}")
+    print("=" * 60)
+
+    # Train
+    trainer.train(args.steps)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file