"""Generate training plots for the report."""
import os
import numpy as np
import matplotlib.pyplot as plt
import json
from collections import defaultdict
from matplotlib import font_manager

# Fonts able to render the Chinese labels used in this module, in order of
# preference. Matplotlib's bundled default (DejaVu Sans) has no CJK glyphs,
# so without one of these every Chinese label is drawn as empty boxes.
_CJK_FONT_CANDIDATES = (
    "SimHei",
    "Microsoft YaHei",
    "PingFang SC",
    "Heiti SC",
    "Noto Sans CJK SC",
    "Source Han Sans SC",
    "WenQuanYi Micro Hei",
    "Arial Unicode MS",
)


def _cjk_font_rc():
    """Return rcParams overrides that select an installed CJK-capable font.

    Only fonts already registered with matplotlib's font manager are chosen,
    so no "font not found" warnings are triggered. When none is installed the
    result is empty and matplotlib's current settings apply unchanged.
    Intended for ``plt.rc_context`` so global rcParams are never modified.
    """
    installed = {entry.name for entry in font_manager.fontManager.ttflist}
    cjk_fonts = [name for name in _CJK_FONT_CANDIDATES if name in installed]
    if not cjk_fonts:
        return {}
    return {
        # Prefer the CJK font; keep the existing list as a fallback.
        "font.sans-serif": cjk_fonts + list(plt.rcParams["font.sans-serif"]),
        # Many CJK fonts lack U+2212 (minus sign); draw an ASCII hyphen instead.
        "axes.unicode_minus": False,
    }


def _has_data(values):
    """Return True if *values* is not None and has at least one element.

    Unlike ``if values:``, this also works for NumPy arrays, whose truth
    value is ambiguous (ValueError) when they hold more than one element.
    """
    return values is not None and len(values) > 0


def load_training_logs(log_file):
    """Load training logs from a JSON file.

    Args:
        log_file: Path to a JSON file whose top-level object maps metric
            names (e.g. ``'rewards'``, ``'losses'``) to lists of values.

    Returns:
        A ``defaultdict(list)`` holding the file's key/value pairs, so missing
        keys read as empty lists. If the file does not exist, a message is
        printed and the returned dict is empty.
    """
    logs = defaultdict(list)
    if not os.path.exists(log_file):
        print(f"日志文件不存在: {log_file}")
        return logs
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(log_file, 'r', encoding='utf-8') as f:
        logs.update(json.load(f))
    return logs


def smooth_data(data, window=100):
    """Smooth *data* with a trailing moving average.

    Element ``i`` of the result is the mean of
    ``data[max(0, i - window + 1):i + 1]``, so the averaging window grows
    from 1 up to ``window`` over the first points.

    Args:
        data: 1-D sequence of numbers (list or array).
        window: Maximum number of trailing points to average; must be >= 1.

    Returns:
        A list of floats with the same length as *data*. If *data* is shorter
        than *window*, it is returned unchanged (no smoothing is applied).

    Raises:
        ValueError: If *window* is less than 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    if len(data) < window:
        return data
    values = np.asarray(data, dtype=float)
    # The first len(values) entries of a full convolution with a ones-kernel
    # are exactly the trailing window sums. Unlike a cumulative-sum trick, a
    # NaN/inf only affects the windows that actually contain it.
    window_sums = np.convolve(values, np.ones(window))[:len(values)]
    counts = np.minimum(np.arange(1, len(values) + 1), window)
    return (window_sums / counts).tolist()


def _plot_series(ax, values, color, xlabel, ylabel, title, smooth_window, smooth_label):
    """Draw raw *values* and their moving average on *ax*, then style the axes."""
    if _has_data(values):
        x = range(1, len(values) + 1)
        ax.plot(x, values, alpha=0.3, color=color, label='原始数据')
        ax.plot(x, smooth_data(values, window=smooth_window), color='red',
                linewidth=2, label=smooth_label)
        # Only add a legend when there are labelled artists to show.
        ax.legend(fontsize=10)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.grid(True, alpha=0.3)


def _plot_reward_histogram(ax, rewards):
    """Draw a histogram of *rewards* with mean and median markers on *ax*."""
    if _has_data(rewards):
        mean = np.mean(rewards)
        median = np.median(rewards)
        ax.hist(rewards, bins=30, color='skyblue', edgecolor='black', alpha=0.7)
        ax.axvline(mean, color='red', linestyle='--', linewidth=2,
                   label=f'均值: {mean:.1f}')
        ax.axvline(median, color='green', linestyle='--', linewidth=2,
                   label=f'中位数: {median:.1f}')
        ax.legend(fontsize=10)
    ax.set_xlabel('回报', fontsize=12)
    ax.set_ylabel('频次', fontsize=12)
    ax.set_title('回报分布', fontsize=14)
    ax.grid(True, alpha=0.3)


def plot_training_curves(rewards, losses, q_values, save_dir="plots", smooth_window=100):
    """Plot reward, loss and Q-value curves plus a reward histogram (2x2 grid).

    Saves the figure as ``training_curves.png`` in *save_dir*.

    Args:
        rewards: Per-episode returns (list or 1-D array; may be empty).
        losses: Per-step training losses (list or 1-D array; may be empty).
        q_values: Per-step mean Q-values (list or 1-D array; may be empty).
        save_dir: Output directory, created if missing.
        smooth_window: Moving-average window used for all smoothed curves.
    """
    os.makedirs(save_dir, exist_ok=True)
    with plt.rc_context(_cjk_font_rc()):
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        try:
            _plot_series(axes[0, 0], rewards, 'blue', 'Episode', '回报', '训练回报曲线',
                         smooth_window, f'平滑曲线 (window={smooth_window})')
            _plot_series(axes[0, 1], losses, 'green', '训练步数', '损失', '训练损失曲线',
                         smooth_window, '平滑曲线')
            _plot_series(axes[1, 0], q_values, 'purple', '训练步数', '平均Q值', '平均Q值变化',
                         smooth_window, '平滑曲线')
            _plot_reward_histogram(axes[1, 1], rewards)
            fig.tight_layout()

            save_path = os.path.join(save_dir, 'training_curves.png')
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"训练曲线已保存到: {save_path}")
        finally:
            # Close even if saving fails, so figures don't accumulate.
            plt.close(fig)


def plot_epsilon_decay(epsilon_start, epsilon_end, decay_steps, save_dir="plots"):
    """Plot a linear epsilon decay schedule and save ``epsilon_decay.png``.

    Args:
        epsilon_start: Epsilon at step 0.
        epsilon_end: Epsilon reached at *decay_steps*.
        decay_steps: Number of steps over which epsilon decays; must be > 0.
        save_dir: Output directory, created if missing.

    Raises:
        ValueError: If *decay_steps* is not positive.
    """
    if decay_steps <= 0:
        raise ValueError(f"decay_steps must be > 0, got {decay_steps}")
    os.makedirs(save_dir, exist_ok=True)

    steps = np.linspace(0, decay_steps, 1000)
    epsilons = epsilon_start - (epsilon_start - epsilon_end) * (steps / decay_steps)

    with plt.rc_context(_cjk_font_rc()):
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            # The x axis is shown in millions of steps.
            ax.plot(steps / 1e6, epsilons, color='blue', linewidth=2)
            ax.set_xlabel('训练步数 (百万)', fontsize=12)
            ax.set_ylabel('Epsilon (ε)', fontsize=12)
            ax.set_title('Epsilon衰减曲线', fontsize=14)
            ax.grid(True, alpha=0.3)
            ax.set_ylim(0, 1.1)

            # Mark the final epsilon value.
            ax.axhline(y=epsilon_end, color='red', linestyle='--', alpha=0.5)
            ax.text(decay_steps * 0.8 / 1e6, epsilon_end + 0.05,
                    f'最终值: {epsilon_end}', fontsize=10, color='red')

            save_path = os.path.join(save_dir, 'epsilon_decay.png')
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"ε衰减曲线已保存到: {save_path}")
        finally:
            plt.close(fig)


def plot_evaluation_results(eval_rewards, save_dir="plots"):
    """Plot evaluation returns over training and save ``evaluation_results.png``.

    Args:
        eval_rewards: Sequence of ``(step, reward)`` pairs.
        save_dir: Output directory, created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)

    if not _has_data(eval_rewards):
        print("没有评估数据")
        return

    steps, rewards = zip(*eval_rewards)
    # Work in millions of steps: matches the axis label and keeps the
    # least-squares fit well conditioned.
    steps_m = np.asarray(steps, dtype=float) / 1e6

    with plt.rc_context(_cjk_font_rc()):
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.plot(steps_m, rewards, 'o-', color='blue', linewidth=2,
                    markersize=8, label='评估回报')

            # A linear trend needs at least two distinct x positions;
            # otherwise polyfit is rank-deficient.
            if len(np.unique(steps_m)) > 1:
                trend = np.poly1d(np.polyfit(steps_m, rewards, 1))
                ax.plot(steps_m, trend(steps_m), '--', color='red',
                        linewidth=2, alpha=0.7, label='趋势线')

            ax.set_xlabel('训练步数 (百万)', fontsize=12)
            ax.set_ylabel('平均回报', fontsize=12)
            ax.set_title('评估回报变化', fontsize=14)
            ax.legend(fontsize=10)
            ax.grid(True, alpha=0.3)

            save_path = os.path.join(save_dir, 'evaluation_results.png')
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"评估结果已保存到: {save_path}")
        finally:
            plt.close(fig)


def generate_sample_plots(save_dir="plots"):
    """Generate all plots from synthetic data (for the report).

    Args:
        save_dir: Output directory, created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Synthetic training data; fixed seed for reproducible figures.
    np.random.seed(42)
    num_episodes = 500

    # Returns: noisy upward trend.
    base_rewards = np.linspace(-20, 200, num_episodes)
    noise = np.random.normal(0, 30, num_episodes)
    rewards = base_rewards + noise

    # Loss: noisy exponential decay, clipped because a loss is never negative
    # (the noise would otherwise push the tail below zero).
    num_steps = 1000
    base_loss = np.exp(-np.linspace(0, 3, num_steps)) * 10
    loss_noise = np.random.normal(0, 0.5, num_steps)
    losses = np.clip(base_loss + loss_noise, 0.0, None)

    # Q-values: noisy upward trend.
    base_q = np.linspace(0, 50, num_steps)
    q_noise = np.random.normal(0, 5, num_steps)
    q_values = base_q + q_noise

    plot_training_curves(rewards, losses, q_values, save_dir)
    plot_epsilon_decay(1.0, 0.01, 1_000_000, save_dir)

    # Synthetic evaluation results.
    eval_steps = [100000, 200000, 500000, 1000000, 1500000, 2000000]
    eval_rewards = [50, 100, 150, 180, 190, 195]
    plot_evaluation_results(list(zip(eval_steps, eval_rewards)), save_dir)

    print(f"\n示例图表已生成到: {save_dir}/")


def main():
    """Command-line entry point: plot from a JSON log file or generate samples."""
    import argparse

    parser = argparse.ArgumentParser(description="生成训练图表")
    parser.add_argument("--log-file", type=str, default=None, help="训练日志文件路径")
    parser.add_argument("--save-dir", type=str, default="plots", help="图表保存目录")
    parser.add_argument("--sample", action="store_true", help="生成示例图表")
    args = parser.parse_args()

    if args.sample:
        generate_sample_plots(args.save_dir)
    elif args.log_file:
        logs = load_training_logs(args.log_file)
        plot_training_curves(
            logs.get('rewards', []),
            logs.get('losses', []),
            logs.get('q_values', []),
            args.save_dir,
        )
    else:
        print("请指定 --log-file 或 --sample 参数")


if __name__ == "__main__":
    main()