Files
rl-atari/强化学习个人项目报告(Atari 游戏方向)/generate_plots.py
T
Serendipity e8b51240f9 feat: 添加DQN强化学习项目框架和核心实现
实现完整的DQN算法框架,用于Atari Space Invaders游戏训练。包括:
- QNetwork和DuelingQNetwork神经网络架构
- 经验回放缓冲区(标准和优先级版本)
- DQN智能体实现ε-greedy策略和Double DQN
- 环境包装器(灰度化、调整大小、帧堆叠等)
- 训练器、评估脚本和图表生成工具
- 详细的项目文档和依赖配置
2026-05-01 10:01:12 +08:00

264 lines
8.0 KiB
Python

"""Generate training plots for the report."""
import os
import numpy as np
import matplotlib.pyplot as plt
import json
from collections import defaultdict
def load_training_logs(log_file):
"""加载训练日志
Args:
log_file: 日志文件路径
Returns:
logs: 训练日志字典
"""
logs = defaultdict(list)
if os.path.exists(log_file):
with open(log_file, 'r') as f:
data = json.load(f)
for key, values in data.items():
logs[key] = values
return logs
def smooth_data(data, window=100):
"""平滑数据
Args:
data: 原始数据
window: 平滑窗口大小
Returns:
smoothed: 平滑后的数据
"""
if len(data) < window:
return data
smoothed = []
for i in range(len(data)):
start = max(0, i - window + 1)
smoothed.append(np.mean(data[start:i+1]))
return smoothed
def plot_training_curves(rewards, losses, q_values, save_dir="plots"):
"""绘制训练曲线
Args:
rewards: 回报列表
losses: 损失列表
q_values: Q值列表
save_dir: 保存目录
"""
os.makedirs(save_dir, exist_ok=True)
# 创建2x2子图
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# 1. 训练回报曲线
ax1 = axes[0, 0]
if rewards:
episodes = range(1, len(rewards) + 1)
ax1.plot(episodes, rewards, alpha=0.3, color='blue', label='原始数据')
smoothed_rewards = smooth_data(rewards, window=100)
ax1.plot(episodes, smoothed_rewards, color='red', linewidth=2, label='平滑曲线 (window=100)')
ax1.set_xlabel('Episode', fontsize=12)
ax1.set_ylabel('回报', fontsize=12)
ax1.set_title('训练回报曲线', fontsize=14)
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)
# 2. 损失曲线
ax2 = axes[0, 1]
if losses:
steps = range(1, len(losses) + 1)
ax2.plot(steps, losses, alpha=0.3, color='green', label='原始数据')
smoothed_losses = smooth_data(losses, window=100)
ax2.plot(steps, smoothed_losses, color='red', linewidth=2, label='平滑曲线')
ax2.set_xlabel('训练步数', fontsize=12)
ax2.set_ylabel('损失', fontsize=12)
ax2.set_title('训练损失曲线', fontsize=14)
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)
# 3. Q值曲线
ax3 = axes[1, 0]
if q_values:
steps = range(1, len(q_values) + 1)
ax3.plot(steps, q_values, alpha=0.3, color='purple', label='原始数据')
smoothed_q = smooth_data(q_values, window=100)
ax3.plot(steps, smoothed_q, color='red', linewidth=2, label='平滑曲线')
ax3.set_xlabel('训练步数', fontsize=12)
ax3.set_ylabel('平均Q值', fontsize=12)
ax3.set_title('平均Q值变化', fontsize=14)
ax3.legend(fontsize=10)
ax3.grid(True, alpha=0.3)
# 4. 回报分布直方图
ax4 = axes[1, 1]
if rewards:
ax4.hist(rewards, bins=30, color='skyblue', edgecolor='black', alpha=0.7)
ax4.axvline(np.mean(rewards), color='red', linestyle='--', linewidth=2,
label=f'均值: {np.mean(rewards):.1f}')
ax4.axvline(np.median(rewards), color='green', linestyle='--', linewidth=2,
label=f'中位数: {np.median(rewards):.1f}')
ax4.set_xlabel('回报', fontsize=12)
ax4.set_ylabel('频次', fontsize=12)
ax4.set_title('回报分布', fontsize=14)
ax4.legend(fontsize=10)
ax4.grid(True, alpha=0.3)
plt.tight_layout()
# 保存图片
save_path = os.path.join(save_dir, 'training_curves.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"训练曲线已保存到: {save_path}")
plt.close()
def plot_epsilon_decay(epsilon_start, epsilon_end, decay_steps, save_dir="plots"):
"""绘制ε衰减曲线
Args:
epsilon_start: ε初始值
epsilon_end: ε最终值
decay_steps: 衰减步数
save_dir: 保存目录
"""
os.makedirs(save_dir, exist_ok=True)
steps = np.linspace(0, decay_steps, 1000)
epsilons = epsilon_start - (epsilon_start - epsilon_end) * (steps / decay_steps)
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(steps / 1e6, epsilons, color='blue', linewidth=2)
ax.set_xlabel('训练步数 (百万)', fontsize=12)
ax.set_ylabel('Epsilon (ε)', fontsize=12)
ax.set_title('Epsilon衰减曲线', fontsize=14)
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 1.1)
# 标注关键点
ax.axhline(y=epsilon_end, color='red', linestyle='--', alpha=0.5)
ax.text(decay_steps * 0.8 / 1e6, epsilon_end + 0.05,
f'最终值: {epsilon_end}', fontsize=10, color='red')
save_path = os.path.join(save_dir, 'epsilon_decay.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"ε衰减曲线已保存到: {save_path}")
plt.close()
def plot_evaluation_results(eval_rewards, save_dir="plots"):
"""绘制评估结果
Args:
eval_rewards: 评估回报列表 [(step, reward), ...]
save_dir: 保存目录
"""
os.makedirs(save_dir, exist_ok=True)
if not eval_rewards:
print("没有评估数据")
return
steps, rewards = zip(*eval_rewards)
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(np.array(steps) / 1e6, rewards, 'o-', color='blue',
linewidth=2, markersize=8, label='评估回报')
# 添加趋势线
if len(rewards) > 1:
z = np.polyfit(steps, rewards, 1)
p = np.poly1d(z)
ax.plot(np.array(steps) / 1e6, p(steps), '--', color='red',
linewidth=2, alpha=0.7, label='趋势线')
ax.set_xlabel('训练步数 (百万)', fontsize=12)
ax.set_ylabel('平均回报', fontsize=12)
ax.set_title('评估回报变化', fontsize=14)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
save_path = os.path.join(save_dir, 'evaluation_results.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"评估结果已保存到: {save_path}")
plt.close()
def generate_sample_plots(save_dir="plots"):
"""生成示例图表(用于报告)
Args:
save_dir: 保存目录
"""
os.makedirs(save_dir, exist_ok=True)
# 模拟训练数据
np.random.seed(42)
num_episodes = 500
# 模拟回报(逐渐上升)
base_rewards = np.linspace(-20, 200, num_episodes)
noise = np.random.normal(0, 30, num_episodes)
rewards = base_rewards + noise
# 模拟损失(逐渐下降)
num_steps = 1000
base_loss = np.exp(-np.linspace(0, 3, num_steps)) * 10
loss_noise = np.random.normal(0, 0.5, num_steps)
losses = base_loss + loss_noise
# 模拟Q值(逐渐上升)
base_q = np.linspace(0, 50, num_steps)
q_noise = np.random.normal(0, 5, num_steps)
q_values = base_q + q_noise
# 绘制图表
plot_training_curves(rewards, losses, q_values, save_dir)
plot_epsilon_decay(1.0, 0.01, 1_000_000, save_dir)
# 模拟评估数据
eval_steps = [100000, 200000, 500000, 1000000, 1500000, 2000000]
eval_rewards = [50, 100, 150, 180, 190, 195]
plot_evaluation_results(list(zip(eval_steps, eval_rewards)), save_dir)
print(f"\n示例图表已生成到: {save_dir}/")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="生成训练图表")
parser.add_argument("--log-file", type=str, default=None,
help="训练日志文件路径")
parser.add_argument("--save-dir", type=str, default="plots",
help="图表保存目录")
parser.add_argument("--sample", action="store_true",
help="生成示例图表")
args = parser.parse_args()
if args.sample:
generate_sample_plots(args.save_dir)
elif args.log_file:
logs = load_training_logs(args.log_file)
plot_training_curves(
logs.get('rewards', []),
logs.get('losses', []),
logs.get('q_values', []),
args.save_dir
)
else:
print("请指定 --log-file 或 --sample 参数")