feat: 添加DQN强化学习项目框架和核心实现
实现完整的DQN算法框架,用于Atari Space Invaders游戏训练。包括: - QNetwork和DuelingQNetwork神经网络架构 - 经验回放缓冲区(标准和优先级版本) - DQN智能体实现ε-greedy策略和Double DQN - 环境包装器(灰度化、调整大小、帧堆叠等) - 训练器、评估脚本和图表生成工具 - 详细的项目文档和依赖配置
This commit is contained in:
@@ -0,0 +1,264 @@
|
||||
"""Generate training plots for the report."""
|
||||
import os
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import json
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def load_training_logs(log_file):
|
||||
"""加载训练日志
|
||||
|
||||
Args:
|
||||
log_file: 日志文件路径
|
||||
|
||||
Returns:
|
||||
logs: 训练日志字典
|
||||
"""
|
||||
logs = defaultdict(list)
|
||||
|
||||
if os.path.exists(log_file):
|
||||
with open(log_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
for key, values in data.items():
|
||||
logs[key] = values
|
||||
|
||||
return logs
|
||||
|
||||
|
||||
def smooth_data(data, window=100):
|
||||
"""平滑数据
|
||||
|
||||
Args:
|
||||
data: 原始数据
|
||||
window: 平滑窗口大小
|
||||
|
||||
Returns:
|
||||
smoothed: 平滑后的数据
|
||||
"""
|
||||
if len(data) < window:
|
||||
return data
|
||||
|
||||
smoothed = []
|
||||
for i in range(len(data)):
|
||||
start = max(0, i - window + 1)
|
||||
smoothed.append(np.mean(data[start:i+1]))
|
||||
|
||||
return smoothed
|
||||
|
||||
|
||||
def plot_training_curves(rewards, losses, q_values, save_dir="plots"):
|
||||
"""绘制训练曲线
|
||||
|
||||
Args:
|
||||
rewards: 回报列表
|
||||
losses: 损失列表
|
||||
q_values: Q值列表
|
||||
save_dir: 保存目录
|
||||
"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# 创建2x2子图
|
||||
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
|
||||
|
||||
# 1. 训练回报曲线
|
||||
ax1 = axes[0, 0]
|
||||
if rewards:
|
||||
episodes = range(1, len(rewards) + 1)
|
||||
ax1.plot(episodes, rewards, alpha=0.3, color='blue', label='原始数据')
|
||||
smoothed_rewards = smooth_data(rewards, window=100)
|
||||
ax1.plot(episodes, smoothed_rewards, color='red', linewidth=2, label='平滑曲线 (window=100)')
|
||||
ax1.set_xlabel('Episode', fontsize=12)
|
||||
ax1.set_ylabel('回报', fontsize=12)
|
||||
ax1.set_title('训练回报曲线', fontsize=14)
|
||||
ax1.legend(fontsize=10)
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# 2. 损失曲线
|
||||
ax2 = axes[0, 1]
|
||||
if losses:
|
||||
steps = range(1, len(losses) + 1)
|
||||
ax2.plot(steps, losses, alpha=0.3, color='green', label='原始数据')
|
||||
smoothed_losses = smooth_data(losses, window=100)
|
||||
ax2.plot(steps, smoothed_losses, color='red', linewidth=2, label='平滑曲线')
|
||||
ax2.set_xlabel('训练步数', fontsize=12)
|
||||
ax2.set_ylabel('损失', fontsize=12)
|
||||
ax2.set_title('训练损失曲线', fontsize=14)
|
||||
ax2.legend(fontsize=10)
|
||||
ax2.grid(True, alpha=0.3)
|
||||
|
||||
# 3. Q值曲线
|
||||
ax3 = axes[1, 0]
|
||||
if q_values:
|
||||
steps = range(1, len(q_values) + 1)
|
||||
ax3.plot(steps, q_values, alpha=0.3, color='purple', label='原始数据')
|
||||
smoothed_q = smooth_data(q_values, window=100)
|
||||
ax3.plot(steps, smoothed_q, color='red', linewidth=2, label='平滑曲线')
|
||||
ax3.set_xlabel('训练步数', fontsize=12)
|
||||
ax3.set_ylabel('平均Q值', fontsize=12)
|
||||
ax3.set_title('平均Q值变化', fontsize=14)
|
||||
ax3.legend(fontsize=10)
|
||||
ax3.grid(True, alpha=0.3)
|
||||
|
||||
# 4. 回报分布直方图
|
||||
ax4 = axes[1, 1]
|
||||
if rewards:
|
||||
ax4.hist(rewards, bins=30, color='skyblue', edgecolor='black', alpha=0.7)
|
||||
ax4.axvline(np.mean(rewards), color='red', linestyle='--', linewidth=2,
|
||||
label=f'均值: {np.mean(rewards):.1f}')
|
||||
ax4.axvline(np.median(rewards), color='green', linestyle='--', linewidth=2,
|
||||
label=f'中位数: {np.median(rewards):.1f}')
|
||||
ax4.set_xlabel('回报', fontsize=12)
|
||||
ax4.set_ylabel('频次', fontsize=12)
|
||||
ax4.set_title('回报分布', fontsize=14)
|
||||
ax4.legend(fontsize=10)
|
||||
ax4.grid(True, alpha=0.3)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# 保存图片
|
||||
save_path = os.path.join(save_dir, 'training_curves.png')
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
print(f"训练曲线已保存到: {save_path}")
|
||||
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_epsilon_decay(epsilon_start, epsilon_end, decay_steps, save_dir="plots"):
|
||||
"""绘制ε衰减曲线
|
||||
|
||||
Args:
|
||||
epsilon_start: ε初始值
|
||||
epsilon_end: ε最终值
|
||||
decay_steps: 衰减步数
|
||||
save_dir: 保存目录
|
||||
"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
steps = np.linspace(0, decay_steps, 1000)
|
||||
epsilons = epsilon_start - (epsilon_start - epsilon_end) * (steps / decay_steps)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
ax.plot(steps / 1e6, epsilons, color='blue', linewidth=2)
|
||||
ax.set_xlabel('训练步数 (百万)', fontsize=12)
|
||||
ax.set_ylabel('Epsilon (ε)', fontsize=12)
|
||||
ax.set_title('Epsilon衰减曲线', fontsize=14)
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.set_ylim(0, 1.1)
|
||||
|
||||
# 标注关键点
|
||||
ax.axhline(y=epsilon_end, color='red', linestyle='--', alpha=0.5)
|
||||
ax.text(decay_steps * 0.8 / 1e6, epsilon_end + 0.05,
|
||||
f'最终值: {epsilon_end}', fontsize=10, color='red')
|
||||
|
||||
save_path = os.path.join(save_dir, 'epsilon_decay.png')
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
print(f"ε衰减曲线已保存到: {save_path}")
|
||||
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_evaluation_results(eval_rewards, save_dir="plots"):
|
||||
"""绘制评估结果
|
||||
|
||||
Args:
|
||||
eval_rewards: 评估回报列表 [(step, reward), ...]
|
||||
save_dir: 保存目录
|
||||
"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
if not eval_rewards:
|
||||
print("没有评估数据")
|
||||
return
|
||||
|
||||
steps, rewards = zip(*eval_rewards)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
ax.plot(np.array(steps) / 1e6, rewards, 'o-', color='blue',
|
||||
linewidth=2, markersize=8, label='评估回报')
|
||||
|
||||
# 添加趋势线
|
||||
if len(rewards) > 1:
|
||||
z = np.polyfit(steps, rewards, 1)
|
||||
p = np.poly1d(z)
|
||||
ax.plot(np.array(steps) / 1e6, p(steps), '--', color='red',
|
||||
linewidth=2, alpha=0.7, label='趋势线')
|
||||
|
||||
ax.set_xlabel('训练步数 (百万)', fontsize=12)
|
||||
ax.set_ylabel('平均回报', fontsize=12)
|
||||
ax.set_title('评估回报变化', fontsize=14)
|
||||
ax.legend(fontsize=10)
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
save_path = os.path.join(save_dir, 'evaluation_results.png')
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
print(f"评估结果已保存到: {save_path}")
|
||||
|
||||
plt.close()
|
||||
|
||||
|
||||
def generate_sample_plots(save_dir="plots"):
|
||||
"""生成示例图表(用于报告)
|
||||
|
||||
Args:
|
||||
save_dir: 保存目录
|
||||
"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# 模拟训练数据
|
||||
np.random.seed(42)
|
||||
num_episodes = 500
|
||||
|
||||
# 模拟回报(逐渐上升)
|
||||
base_rewards = np.linspace(-20, 200, num_episodes)
|
||||
noise = np.random.normal(0, 30, num_episodes)
|
||||
rewards = base_rewards + noise
|
||||
|
||||
# 模拟损失(逐渐下降)
|
||||
num_steps = 1000
|
||||
base_loss = np.exp(-np.linspace(0, 3, num_steps)) * 10
|
||||
loss_noise = np.random.normal(0, 0.5, num_steps)
|
||||
losses = base_loss + loss_noise
|
||||
|
||||
# 模拟Q值(逐渐上升)
|
||||
base_q = np.linspace(0, 50, num_steps)
|
||||
q_noise = np.random.normal(0, 5, num_steps)
|
||||
q_values = base_q + q_noise
|
||||
|
||||
# 绘制图表
|
||||
plot_training_curves(rewards, losses, q_values, save_dir)
|
||||
plot_epsilon_decay(1.0, 0.01, 1_000_000, save_dir)
|
||||
|
||||
# 模拟评估数据
|
||||
eval_steps = [100000, 200000, 500000, 1000000, 1500000, 2000000]
|
||||
eval_rewards = [50, 100, 150, 180, 190, 195]
|
||||
plot_evaluation_results(list(zip(eval_steps, eval_rewards)), save_dir)
|
||||
|
||||
print(f"\n示例图表已生成到: {save_dir}/")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="生成训练图表")
|
||||
parser.add_argument("--log-file", type=str, default=None,
|
||||
help="训练日志文件路径")
|
||||
parser.add_argument("--save-dir", type=str, default="plots",
|
||||
help="图表保存目录")
|
||||
parser.add_argument("--sample", action="store_true",
|
||||
help="生成示例图表")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.sample:
|
||||
generate_sample_plots(args.save_dir)
|
||||
elif args.log_file:
|
||||
logs = load_training_logs(args.log_file)
|
||||
plot_training_curves(
|
||||
logs.get('rewards', []),
|
||||
logs.get('losses', []),
|
||||
logs.get('q_values', []),
|
||||
args.save_dir
|
||||
)
|
||||
else:
|
||||
print("请指定 --log-file 或 --sample 参数")
|
||||
Reference in New Issue
Block a user