feat: 添加模型评估脚本并更新实验报告
- 添加 evaluate_checkpoints.py 脚本,用于评估训练过程中的检查点模型 - 更新 generate_plots.py 以支持从真实评估结果生成图表 - 更新实验报告内容,包含具体实验结果数据和分析 - 添加中文支持并更新作者信息 - 生成评估结果JSON文件和相应图表
This commit is contained in:
@@ -8,13 +8,13 @@ from collections import defaultdict
|
||||
|
||||
|
||||
def load_training_logs(log_file):
|
||||
"""加载训练日志
|
||||
"""Load training logs
|
||||
|
||||
Args:
|
||||
log_file: 日志文件路径
|
||||
log_file: Log file path
|
||||
|
||||
Returns:
|
||||
logs: 训练日志字典
|
||||
logs: Training logs dict
|
||||
"""
|
||||
logs = defaultdict(list)
|
||||
|
||||
@@ -28,14 +28,14 @@ def load_training_logs(log_file):
|
||||
|
||||
|
||||
def smooth_data(data, window=100):
|
||||
"""平滑数据
|
||||
"""Smooth data
|
||||
|
||||
Args:
|
||||
data: 原始数据
|
||||
window: 平滑窗口大小
|
||||
data: Raw data
|
||||
window: Smoothing window size
|
||||
|
||||
Returns:
|
||||
smoothed: 平滑后的数据
|
||||
smoothed: Smoothed data
|
||||
"""
|
||||
if len(data) < window:
|
||||
return data
|
||||
@@ -49,13 +49,13 @@ def smooth_data(data, window=100):
|
||||
|
||||
|
||||
def plot_training_curves(rewards, losses, q_values, save_dir="plots"):
|
||||
"""绘制训练曲线
|
||||
"""Plot training curves
|
||||
|
||||
Args:
|
||||
rewards: 回报列表
|
||||
losses: 损失列表
|
||||
q_values: Q值列表
|
||||
save_dir: 保存目录
|
||||
rewards: Reward list
|
||||
losses: Loss list
|
||||
q_values: Q-value list
|
||||
save_dir: Save directory
|
||||
"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
@@ -66,18 +66,18 @@ def plot_training_curves(rewards, losses, q_values, save_dir="plots"):
|
||||
ax1 = axes[0, 0]
|
||||
if len(rewards) > 0:
|
||||
episodes = range(1, len(rewards) + 1)
|
||||
ax1.plot(episodes, rewards, alpha=0.3, color="blue", label="原始数据")
|
||||
ax1.plot(episodes, rewards, alpha=0.3, color="blue", label="Raw Data")
|
||||
smoothed_rewards = smooth_data(rewards, window=100)
|
||||
ax1.plot(
|
||||
episodes,
|
||||
smoothed_rewards,
|
||||
color="red",
|
||||
linewidth=2,
|
||||
label="平滑曲线 (window=100)",
|
||||
label="Smoothed (window=100)",
|
||||
)
|
||||
ax1.set_xlabel("Episode", fontsize=12)
|
||||
ax1.set_ylabel("回报", fontsize=12)
|
||||
ax1.set_title("训练回报曲线", fontsize=14)
|
||||
ax1.set_ylabel("Reward", fontsize=12)
|
||||
ax1.set_title("Training Reward Curve", fontsize=14)
|
||||
ax1.legend(fontsize=10)
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
@@ -85,12 +85,12 @@ def plot_training_curves(rewards, losses, q_values, save_dir="plots"):
|
||||
ax2 = axes[0, 1]
|
||||
if len(losses) > 0:
|
||||
steps = range(1, len(losses) + 1)
|
||||
ax2.plot(steps, losses, alpha=0.3, color="green", label="原始数据")
|
||||
ax2.plot(steps, losses, alpha=0.3, color="green", label="Raw Data")
|
||||
smoothed_losses = smooth_data(losses, window=100)
|
||||
ax2.plot(steps, smoothed_losses, color="red", linewidth=2, label="平滑曲线")
|
||||
ax2.set_xlabel("训练步数", fontsize=12)
|
||||
ax2.set_ylabel("损失", fontsize=12)
|
||||
ax2.set_title("训练损失曲线", fontsize=14)
|
||||
ax2.plot(steps, smoothed_losses, color="red", linewidth=2, label="Smoothed")
|
||||
ax2.set_xlabel("Training Steps", fontsize=12)
|
||||
ax2.set_ylabel("Loss", fontsize=12)
|
||||
ax2.set_title("Training Loss Curve", fontsize=14)
|
||||
ax2.legend(fontsize=10)
|
||||
ax2.grid(True, alpha=0.3)
|
||||
|
||||
@@ -98,12 +98,12 @@ def plot_training_curves(rewards, losses, q_values, save_dir="plots"):
|
||||
ax3 = axes[1, 0]
|
||||
if len(q_values) > 0:
|
||||
steps = range(1, len(q_values) + 1)
|
||||
ax3.plot(steps, q_values, alpha=0.3, color="purple", label="原始数据")
|
||||
ax3.plot(steps, q_values, alpha=0.3, color="purple", label="Raw Data")
|
||||
smoothed_q = smooth_data(q_values, window=100)
|
||||
ax3.plot(steps, smoothed_q, color="red", linewidth=2, label="平滑曲线")
|
||||
ax3.set_xlabel("训练步数", fontsize=12)
|
||||
ax3.set_ylabel("平均Q值", fontsize=12)
|
||||
ax3.set_title("平均Q值变化", fontsize=14)
|
||||
ax3.plot(steps, smoothed_q, color="red", linewidth=2, label="Smoothed")
|
||||
ax3.set_xlabel("Training Steps", fontsize=12)
|
||||
ax3.set_ylabel("Average Q Value", fontsize=12)
|
||||
ax3.set_title("Average Q Value", fontsize=14)
|
||||
ax3.legend(fontsize=10)
|
||||
ax3.grid(True, alpha=0.3)
|
||||
|
||||
@@ -116,18 +116,18 @@ def plot_training_curves(rewards, losses, q_values, save_dir="plots"):
|
||||
color="red",
|
||||
linestyle="--",
|
||||
linewidth=2,
|
||||
label=f"均值: {np.mean(rewards):.1f}",
|
||||
label=f"Mean: {np.mean(rewards):.1f}",
|
||||
)
|
||||
ax4.axvline(
|
||||
np.median(rewards),
|
||||
color="green",
|
||||
linestyle="--",
|
||||
linewidth=2,
|
||||
label=f"中位数: {np.median(rewards):.1f}",
|
||||
label=f"Median: {np.median(rewards):.1f}",
|
||||
)
|
||||
ax4.set_xlabel("回报", fontsize=12)
|
||||
ax4.set_ylabel("频次", fontsize=12)
|
||||
ax4.set_title("回报分布", fontsize=14)
|
||||
ax4.set_xlabel("Reward", fontsize=12)
|
||||
ax4.set_ylabel("Frequency", fontsize=12)
|
||||
ax4.set_title("Reward Distribution", fontsize=14)
|
||||
ax4.legend(fontsize=10)
|
||||
ax4.grid(True, alpha=0.3)
|
||||
|
||||
@@ -136,19 +136,19 @@ def plot_training_curves(rewards, losses, q_values, save_dir="plots"):
|
||||
# 保存图片
|
||||
save_path = os.path.join(save_dir, "training_curves.png")
|
||||
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
||||
print(f"训练曲线已保存到: {save_path}")
|
||||
print(f"Training curves saved: {save_path}")
|
||||
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_epsilon_decay(epsilon_start, epsilon_end, decay_steps, save_dir="plots"):
|
||||
"""绘制ε衰减曲线
|
||||
"""Plot epsilon decay curve
|
||||
|
||||
Args:
|
||||
epsilon_start: ε初始值
|
||||
epsilon_end: ε最终值
|
||||
decay_steps: 衰减步数
|
||||
save_dir: 保存目录
|
||||
epsilon_start: Initial epsilon value
|
||||
epsilon_end: Final epsilon value
|
||||
decay_steps: Decay steps
|
||||
save_dir: Save directory
|
||||
"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
@@ -157,40 +157,39 @@ def plot_epsilon_decay(epsilon_start, epsilon_end, decay_steps, save_dir="plots"
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
ax.plot(steps / 1e6, epsilons, color="blue", linewidth=2)
|
||||
ax.set_xlabel("训练步数 (百万)", fontsize=12)
|
||||
ax.set_ylabel("Epsilon (ε)", fontsize=12)
|
||||
ax.set_title("Epsilon衰减曲线", fontsize=14)
|
||||
ax.set_xlabel("Training Steps (Million)", fontsize=12)
|
||||
ax.set_ylabel("Epsilon", fontsize=12)
|
||||
ax.set_title("Epsilon Decay Curve", fontsize=14)
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.set_ylim(0, 1.1)
|
||||
|
||||
# 标注关键点
|
||||
ax.axhline(y=epsilon_end, color="red", linestyle="--", alpha=0.5)
|
||||
ax.text(
|
||||
decay_steps * 0.8 / 1e6,
|
||||
epsilon_end + 0.05,
|
||||
f"最终值: {epsilon_end}",
|
||||
f"Final: {epsilon_end}",
|
||||
fontsize=10,
|
||||
color="red",
|
||||
)
|
||||
|
||||
save_path = os.path.join(save_dir, "epsilon_decay.png")
|
||||
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
||||
print(f"ε衰减曲线已保存到: {save_path}")
|
||||
print(f"Epsilon decay curve saved: {save_path}")
|
||||
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_evaluation_results(eval_rewards, save_dir="plots"):
|
||||
"""绘制评估结果
|
||||
"""Plot evaluation results
|
||||
|
||||
Args:
|
||||
eval_rewards: 评估回报列表 [(step, reward), ...]
|
||||
save_dir: 保存目录
|
||||
eval_rewards: Evaluation reward list [(step, reward), ...]
|
||||
save_dir: Save directory
|
||||
"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
if not eval_rewards:
|
||||
print("没有评估数据")
|
||||
print("No evaluation data")
|
||||
return
|
||||
|
||||
steps, rewards = zip(*eval_rewards)
|
||||
@@ -203,10 +202,9 @@ def plot_evaluation_results(eval_rewards, save_dir="plots"):
|
||||
color="blue",
|
||||
linewidth=2,
|
||||
markersize=8,
|
||||
label="评估回报",
|
||||
label="Eval Reward",
|
||||
)
|
||||
|
||||
# 添加趋势线
|
||||
if len(rewards) > 1:
|
||||
z = np.polyfit(steps, rewards, 1)
|
||||
p = np.poly1d(z)
|
||||
@@ -217,27 +215,27 @@ def plot_evaluation_results(eval_rewards, save_dir="plots"):
|
||||
color="red",
|
||||
linewidth=2,
|
||||
alpha=0.7,
|
||||
label="趋势线",
|
||||
label="Trend Line",
|
||||
)
|
||||
|
||||
ax.set_xlabel("训练步数 (百万)", fontsize=12)
|
||||
ax.set_ylabel("平均回报", fontsize=12)
|
||||
ax.set_title("评估回报变化", fontsize=14)
|
||||
ax.set_xlabel("Training Steps (Million)", fontsize=12)
|
||||
ax.set_ylabel("Average Reward", fontsize=12)
|
||||
ax.set_title("Evaluation Reward", fontsize=14)
|
||||
ax.legend(fontsize=10)
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
save_path = os.path.join(save_dir, "evaluation_results.png")
|
||||
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
||||
print(f"评估结果已保存到: {save_path}")
|
||||
print(f"Evaluation results saved: {save_path}")
|
||||
|
||||
plt.close()
|
||||
|
||||
|
||||
def generate_sample_plots(save_dir="plots"):
|
||||
"""生成示例图表(用于报告)
|
||||
"""Generate sample plots for report
|
||||
|
||||
Args:
|
||||
save_dir: 保存目录
|
||||
save_dir: Save directory
|
||||
"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
@@ -270,21 +268,119 @@ def generate_sample_plots(save_dir="plots"):
|
||||
eval_rewards = [50, 100, 150, 180, 190, 195]
|
||||
plot_evaluation_results(list(zip(eval_steps, eval_rewards)), save_dir)
|
||||
|
||||
print(f"\n示例图表已生成到: {save_dir}/")
|
||||
print(f"\nSample plots generated: {save_dir}/")
|
||||
|
||||
|
||||
def generate_real_plots(eval_results_file="evaluation_results.json", save_dir="plots"):
    """Generate plots from real evaluation results.

    Reads a JSON list of evaluation records, keeps the records whose ``step``
    is positive, and writes two figures into ``save_dir``:

    * ``evaluation_curve.png`` - average reward with standard-deviation error
      bars, plus a linear trend line when more than one checkpoint exists.
    * ``evaluation_std.png`` - reward standard deviation per checkpoint.

    Afterwards the checkpoint with the highest average reward is printed.

    Args:
        eval_results_file: Evaluation results JSON file. Every record is read
            for the keys ``step``, ``avg_reward`` and ``std_reward``.
        save_dir: Save directory (created if it does not exist).

    Returns:
        None. The function also returns early, without plotting, when the
        results file is missing or contains no record with ``step > 0``.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Load evaluation results
    if not os.path.exists(eval_results_file):
        print(f"评估结果文件不存在: {eval_results_file}")
        return

    # Explicit encoding so the read does not depend on the platform locale.
    with open(eval_results_file, "r", encoding="utf-8") as f:
        results = json.load(f)

    # Keep only per-checkpoint results, ordered by training step.
    # NOTE(review): the original comment says this filter excludes the
    # "best" and "final" entries - presumably those are stored with
    # step <= 0; confirm against the evaluation script.
    checkpoint_results = sorted(
        (r for r in results if r["step"] > 0),
        key=lambda r: r["step"],
    )

    # Without this guard the summary below would call max() on an empty
    # sequence (ValueError) after two empty figures had already been saved.
    if not checkpoint_results:
        print(f"No checkpoint results with step > 0 in: {eval_results_file}")
        return

    steps = [r["step"] for r in checkpoint_results]
    rewards = [r["avg_reward"] for r in checkpoint_results]
    stds = [r["std_reward"] for r in checkpoint_results]

    # The x axis of both figures is expressed in millions of training steps.
    steps_in_millions = np.asarray(steps) / 1e6

    # Evaluation curve (with error bars)
    fig, ax = plt.subplots(figsize=(10, 6))

    ax.errorbar(
        steps_in_millions,
        rewards,
        yerr=stds,
        fmt="o-",
        color="blue",
        linewidth=2,
        markersize=8,
        capsize=5,
        label="Eval Reward (mean ± std)",
    )

    # Add a trend line; a degree-1 fit needs at least two points.
    if len(steps) > 1:
        z = np.polyfit(steps, rewards, 1)
        p = np.poly1d(z)
        ax.plot(
            steps_in_millions,
            p(steps),
            "--",
            color="red",
            linewidth=2,
            alpha=0.7,
            label="Trend Line",
        )

    ax.set_xlabel("Training Steps (Million)", fontsize=12)
    ax.set_ylabel("Average Reward", fontsize=12)
    ax.set_title("Evaluation Reward Over Training", fontsize=14)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

    save_path = os.path.join(save_dir, "evaluation_curve.png")
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    print(f"Evaluation curve saved: {save_path}")
    plt.close()

    # Standard deviation over training
    fig, ax = plt.subplots(figsize=(10, 6))

    ax.plot(
        steps_in_millions,
        stds,
        "s-",
        color="green",
        linewidth=2,
        markersize=8,
        label="Standard Deviation",
    )

    ax.set_xlabel("Training Steps (Million)", fontsize=12)
    ax.set_ylabel("Reward Std Dev", fontsize=12)
    ax.set_title("Evaluation Reward Standard Deviation", fontsize=14)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

    save_path = os.path.join(save_dir, "evaluation_std.png")
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    print(f"Evaluation std saved: {save_path}")
    plt.close()

    # Print summary: the checkpoint with the highest average reward.
    best_result = max(checkpoint_results, key=lambda r: r["avg_reward"])
    print(f"\n最佳检查点: Step {best_result['step']:,}")
    print(f" 平均回报: {best_result['avg_reward']:.2f} ± {best_result['std_reward']:.2f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="生成训练图表")
|
||||
parser.add_argument("--log-file", type=str, default=None, help="训练日志文件路径")
|
||||
parser.add_argument("--save-dir", type=str, default="plots", help="图表保存目录")
|
||||
parser.add_argument("--sample", action="store_true", help="生成示例图表")
|
||||
parser = argparse.ArgumentParser(description="Generate training plots")
|
||||
parser.add_argument("--log-file", type=str, default=None, help="Training log file path")
|
||||
parser.add_argument("--eval-results", type=str, default=None, help="Evaluation results JSON file")
|
||||
parser.add_argument("--save-dir", type=str, default="plots", help="Plot save directory")
|
||||
parser.add_argument("--sample", action="store_true", help="Generate sample plots")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.sample:
|
||||
generate_sample_plots(args.save_dir)
|
||||
elif args.eval_results:
|
||||
generate_real_plots(args.eval_results, args.save_dir)
|
||||
elif args.log_file:
|
||||
logs = load_training_logs(args.log_file)
|
||||
plot_training_curves(
|
||||
@@ -294,4 +390,4 @@ if __name__ == "__main__":
|
||||
args.save_dir,
|
||||
)
|
||||
else:
|
||||
print("请指定 --log-file 或 --sample 参数")
|
||||
print("Please specify --log-file, --eval-results, or --sample")
|
||||
|
||||
Reference in New Issue
Block a user