feat: 添加模型评估脚本并更新实验报告
- 添加 evaluate_checkpoints.py 脚本,用于评估训练过程中的检查点模型 - 更新 generate_plots.py 以支持从真实评估结果生成图表 - 更新实验报告内容,包含具体实验结果数据和分析 - 添加中文支持并更新作者信息 - 生成评估结果JSON文件和相应图表
@@ -0,0 +1,107 @@
|
||||
"""评估所有检查点并生成评估曲线数据"""
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import json
|
||||
|
||||
from src.network import QNetwork
|
||||
from src.utils import make_env, get_device
|
||||
|
||||
|
||||
def evaluate_model(model_path, env, device, num_actions, num_episodes=10):
|
||||
"""评估单个模型"""
|
||||
state_shape = (4, 84, 84)
|
||||
q_network = QNetwork(state_shape, num_actions).to(device)
|
||||
|
||||
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
|
||||
q_network.load_state_dict(checkpoint['q_network'])
|
||||
q_network.eval()
|
||||
|
||||
rewards = []
|
||||
for _ in range(num_episodes):
|
||||
state, _ = env.reset()
|
||||
episode_reward = 0
|
||||
done = False
|
||||
|
||||
while not done:
|
||||
with torch.no_grad():
|
||||
state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(device)
|
||||
q_values = q_network(state_tensor)
|
||||
action = q_values.argmax(dim=1).item()
|
||||
|
||||
state, reward, terminated, truncated, _ = env.step(action)
|
||||
done = terminated or truncated
|
||||
episode_reward += reward
|
||||
|
||||
rewards.append(episode_reward)
|
||||
|
||||
return np.mean(rewards), np.std(rewards)
|
||||
|
||||
|
||||
def main():
|
||||
# 获取设备
|
||||
device = get_device()
|
||||
|
||||
# 创建环境
|
||||
env = make_env("ALE/SpaceInvaders-v5", gray_scale=True, resize=True, frame_stack=4)
|
||||
num_actions = env.action_space.n
|
||||
|
||||
# 检查点列表
|
||||
checkpoints = [
|
||||
("models/dqn_step_100000.pt", 100000),
|
||||
("models/dqn_step_200000.pt", 200000),
|
||||
("models/dqn_step_400000.pt", 400000),
|
||||
("models/dqn_step_600000.pt", 600000),
|
||||
("models/dqn_step_800000.pt", 800000),
|
||||
("models/dqn_step_1000000.pt", 1000000),
|
||||
("models/dqn_step_1200000.pt", 1200000),
|
||||
("models/dqn_step_1400000.pt", 1400000),
|
||||
("models/dqn_step_1600000.pt", 1600000),
|
||||
("models/dqn_step_1800000.pt", 1800000),
|
||||
("models/dqn_step_2000000.pt", 2000000),
|
||||
("models/dqn_best.pt", -1), # 最佳模型
|
||||
("models/dqn_final.pt", -2), # 最终模型
|
||||
]
|
||||
|
||||
results = []
|
||||
|
||||
for model_path, step in checkpoints:
|
||||
if not os.path.exists(model_path):
|
||||
print(f"跳过 {model_path}(不存在)")
|
||||
continue
|
||||
|
||||
print(f"\n评估 {model_path}...")
|
||||
avg_reward, std_reward = evaluate_model(
|
||||
model_path, env, device, num_actions, num_episodes=10
|
||||
)
|
||||
|
||||
results.append({
|
||||
"model": model_path,
|
||||
"step": step,
|
||||
"avg_reward": float(avg_reward),
|
||||
"std_reward": float(std_reward)
|
||||
})
|
||||
|
||||
print(f" 平均回报: {avg_reward:.2f} ± {std_reward:.2f}")
|
||||
|
||||
# 保存结果
|
||||
output_file = "evaluation_results.json"
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(results, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"\n结果已保存到 {output_file}")
|
||||
|
||||
# 打印汇总
|
||||
print("\n" + "=" * 60)
|
||||
print("评估汇总:")
|
||||
print("=" * 60)
|
||||
for r in results:
|
||||
step_str = f"Step {r['step']:>10,}" if r['step'] > 0 else f"{'Best' if r['step'] == -1 else 'Final':>15}"
|
||||
print(f"{step_str}: 平均回报 = {r['avg_reward']:.2f} ± {r['std_reward']:.2f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,80 @@
|
||||
[
|
||||
{
|
||||
"model": "models/dqn_step_100000.pt",
|
||||
"step": 100000,
|
||||
"avg_reward": 17.8,
|
||||
"std_reward": 5.2306787322488075
|
||||
},
|
||||
{
|
||||
"model": "models/dqn_step_200000.pt",
|
||||
"step": 200000,
|
||||
"avg_reward": 14.0,
|
||||
"std_reward": 6.603029607687671
|
||||
},
|
||||
{
|
||||
"model": "models/dqn_step_400000.pt",
|
||||
"step": 400000,
|
||||
"avg_reward": 16.4,
|
||||
"std_reward": 4.24735211631906
|
||||
},
|
||||
{
|
||||
"model": "models/dqn_step_600000.pt",
|
||||
"step": 600000,
|
||||
"avg_reward": 19.0,
|
||||
"std_reward": 4.123105625617661
|
||||
},
|
||||
{
|
||||
"model": "models/dqn_step_800000.pt",
|
||||
"step": 800000,
|
||||
"avg_reward": 13.2,
|
||||
"std_reward": 3.944616584663204
|
||||
},
|
||||
{
|
||||
"model": "models/dqn_step_1000000.pt",
|
||||
"step": 1000000,
|
||||
"avg_reward": 15.7,
|
||||
"std_reward": 4.960846701924985
|
||||
},
|
||||
{
|
||||
"model": "models/dqn_step_1200000.pt",
|
||||
"step": 1200000,
|
||||
"avg_reward": 18.4,
|
||||
"std_reward": 6.216108107168021
|
||||
},
|
||||
{
|
||||
"model": "models/dqn_step_1400000.pt",
|
||||
"step": 1400000,
|
||||
"avg_reward": 14.2,
|
||||
"std_reward": 4.467661580737736
|
||||
},
|
||||
{
|
||||
"model": "models/dqn_step_1600000.pt",
|
||||
"step": 1600000,
|
||||
"avg_reward": 17.8,
|
||||
"std_reward": 4.686149805543993
|
||||
},
|
||||
{
|
||||
"model": "models/dqn_step_1800000.pt",
|
||||
"step": 1800000,
|
||||
"avg_reward": 21.5,
|
||||
"std_reward": 4.984977432245807
|
||||
},
|
||||
{
|
||||
"model": "models/dqn_step_2000000.pt",
|
||||
"step": 2000000,
|
||||
"avg_reward": 14.6,
|
||||
"std_reward": 5.276362383309167
|
||||
},
|
||||
{
|
||||
"model": "models/dqn_best.pt",
|
||||
"step": -1,
|
||||
"avg_reward": 19.9,
|
||||
"std_reward": 6.920260110718383
|
||||
},
|
||||
{
|
||||
"model": "models/dqn_final.pt",
|
||||
"step": -2,
|
||||
"avg_reward": 11.3,
|
||||
"std_reward": 3.3778691508109073
|
||||
}
|
||||
]
|
||||
@@ -8,13 +8,13 @@ from collections import defaultdict
|
||||
|
||||
|
||||
def load_training_logs(log_file):
|
||||
"""加载训练日志
|
||||
"""Load training logs
|
||||
|
||||
Args:
|
||||
log_file: 日志文件路径
|
||||
log_file: Log file path
|
||||
|
||||
Returns:
|
||||
logs: 训练日志字典
|
||||
logs: Training logs dict
|
||||
"""
|
||||
logs = defaultdict(list)
|
||||
|
||||
@@ -28,14 +28,14 @@ def load_training_logs(log_file):
|
||||
|
||||
|
||||
def smooth_data(data, window=100):
|
||||
"""平滑数据
|
||||
"""Smooth data
|
||||
|
||||
Args:
|
||||
data: 原始数据
|
||||
window: 平滑窗口大小
|
||||
data: Raw data
|
||||
window: Smoothing window size
|
||||
|
||||
Returns:
|
||||
smoothed: 平滑后的数据
|
||||
smoothed: Smoothed data
|
||||
"""
|
||||
if len(data) < window:
|
||||
return data
|
||||
@@ -49,13 +49,13 @@ def smooth_data(data, window=100):
|
||||
|
||||
|
||||
def plot_training_curves(rewards, losses, q_values, save_dir="plots"):
|
||||
"""绘制训练曲线
|
||||
"""Plot training curves
|
||||
|
||||
Args:
|
||||
rewards: 回报列表
|
||||
losses: 损失列表
|
||||
q_values: Q值列表
|
||||
save_dir: 保存目录
|
||||
rewards: Reward list
|
||||
losses: Loss list
|
||||
q_values: Q-value list
|
||||
save_dir: Save directory
|
||||
"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
@@ -66,18 +66,18 @@ def plot_training_curves(rewards, losses, q_values, save_dir="plots"):
|
||||
ax1 = axes[0, 0]
|
||||
if len(rewards) > 0:
|
||||
episodes = range(1, len(rewards) + 1)
|
||||
ax1.plot(episodes, rewards, alpha=0.3, color="blue", label="原始数据")
|
||||
ax1.plot(episodes, rewards, alpha=0.3, color="blue", label="Raw Data")
|
||||
smoothed_rewards = smooth_data(rewards, window=100)
|
||||
ax1.plot(
|
||||
episodes,
|
||||
smoothed_rewards,
|
||||
color="red",
|
||||
linewidth=2,
|
||||
label="平滑曲线 (window=100)",
|
||||
label="Smoothed (window=100)",
|
||||
)
|
||||
ax1.set_xlabel("Episode", fontsize=12)
|
||||
ax1.set_ylabel("回报", fontsize=12)
|
||||
ax1.set_title("训练回报曲线", fontsize=14)
|
||||
ax1.set_ylabel("Reward", fontsize=12)
|
||||
ax1.set_title("Training Reward Curve", fontsize=14)
|
||||
ax1.legend(fontsize=10)
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
@@ -85,12 +85,12 @@ def plot_training_curves(rewards, losses, q_values, save_dir="plots"):
|
||||
ax2 = axes[0, 1]
|
||||
if len(losses) > 0:
|
||||
steps = range(1, len(losses) + 1)
|
||||
ax2.plot(steps, losses, alpha=0.3, color="green", label="原始数据")
|
||||
ax2.plot(steps, losses, alpha=0.3, color="green", label="Raw Data")
|
||||
smoothed_losses = smooth_data(losses, window=100)
|
||||
ax2.plot(steps, smoothed_losses, color="red", linewidth=2, label="平滑曲线")
|
||||
ax2.set_xlabel("训练步数", fontsize=12)
|
||||
ax2.set_ylabel("损失", fontsize=12)
|
||||
ax2.set_title("训练损失曲线", fontsize=14)
|
||||
ax2.plot(steps, smoothed_losses, color="red", linewidth=2, label="Smoothed")
|
||||
ax2.set_xlabel("Training Steps", fontsize=12)
|
||||
ax2.set_ylabel("Loss", fontsize=12)
|
||||
ax2.set_title("Training Loss Curve", fontsize=14)
|
||||
ax2.legend(fontsize=10)
|
||||
ax2.grid(True, alpha=0.3)
|
||||
|
||||
@@ -98,12 +98,12 @@ def plot_training_curves(rewards, losses, q_values, save_dir="plots"):
|
||||
ax3 = axes[1, 0]
|
||||
if len(q_values) > 0:
|
||||
steps = range(1, len(q_values) + 1)
|
||||
ax3.plot(steps, q_values, alpha=0.3, color="purple", label="原始数据")
|
||||
ax3.plot(steps, q_values, alpha=0.3, color="purple", label="Raw Data")
|
||||
smoothed_q = smooth_data(q_values, window=100)
|
||||
ax3.plot(steps, smoothed_q, color="red", linewidth=2, label="平滑曲线")
|
||||
ax3.set_xlabel("训练步数", fontsize=12)
|
||||
ax3.set_ylabel("平均Q值", fontsize=12)
|
||||
ax3.set_title("平均Q值变化", fontsize=14)
|
||||
ax3.plot(steps, smoothed_q, color="red", linewidth=2, label="Smoothed")
|
||||
ax3.set_xlabel("Training Steps", fontsize=12)
|
||||
ax3.set_ylabel("Average Q Value", fontsize=12)
|
||||
ax3.set_title("Average Q Value", fontsize=14)
|
||||
ax3.legend(fontsize=10)
|
||||
ax3.grid(True, alpha=0.3)
|
||||
|
||||
@@ -116,18 +116,18 @@ def plot_training_curves(rewards, losses, q_values, save_dir="plots"):
|
||||
color="red",
|
||||
linestyle="--",
|
||||
linewidth=2,
|
||||
label=f"均值: {np.mean(rewards):.1f}",
|
||||
label=f"Mean: {np.mean(rewards):.1f}",
|
||||
)
|
||||
ax4.axvline(
|
||||
np.median(rewards),
|
||||
color="green",
|
||||
linestyle="--",
|
||||
linewidth=2,
|
||||
label=f"中位数: {np.median(rewards):.1f}",
|
||||
label=f"Median: {np.median(rewards):.1f}",
|
||||
)
|
||||
ax4.set_xlabel("回报", fontsize=12)
|
||||
ax4.set_ylabel("频次", fontsize=12)
|
||||
ax4.set_title("回报分布", fontsize=14)
|
||||
ax4.set_xlabel("Reward", fontsize=12)
|
||||
ax4.set_ylabel("Frequency", fontsize=12)
|
||||
ax4.set_title("Reward Distribution", fontsize=14)
|
||||
ax4.legend(fontsize=10)
|
||||
ax4.grid(True, alpha=0.3)
|
||||
|
||||
@@ -136,19 +136,19 @@ def plot_training_curves(rewards, losses, q_values, save_dir="plots"):
|
||||
# 保存图片
|
||||
save_path = os.path.join(save_dir, "training_curves.png")
|
||||
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
||||
print(f"训练曲线已保存到: {save_path}")
|
||||
print(f"Training curves saved: {save_path}")
|
||||
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_epsilon_decay(epsilon_start, epsilon_end, decay_steps, save_dir="plots"):
|
||||
"""绘制ε衰减曲线
|
||||
"""Plot epsilon decay curve
|
||||
|
||||
Args:
|
||||
epsilon_start: ε初始值
|
||||
epsilon_end: ε最终值
|
||||
decay_steps: 衰减步数
|
||||
save_dir: 保存目录
|
||||
epsilon_start: Initial epsilon value
|
||||
epsilon_end: Final epsilon value
|
||||
decay_steps: Decay steps
|
||||
save_dir: Save directory
|
||||
"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
@@ -157,40 +157,39 @@ def plot_epsilon_decay(epsilon_start, epsilon_end, decay_steps, save_dir="plots"
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
ax.plot(steps / 1e6, epsilons, color="blue", linewidth=2)
|
||||
ax.set_xlabel("训练步数 (百万)", fontsize=12)
|
||||
ax.set_ylabel("Epsilon (ε)", fontsize=12)
|
||||
ax.set_title("Epsilon衰减曲线", fontsize=14)
|
||||
ax.set_xlabel("Training Steps (Million)", fontsize=12)
|
||||
ax.set_ylabel("Epsilon", fontsize=12)
|
||||
ax.set_title("Epsilon Decay Curve", fontsize=14)
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.set_ylim(0, 1.1)
|
||||
|
||||
# 标注关键点
|
||||
ax.axhline(y=epsilon_end, color="red", linestyle="--", alpha=0.5)
|
||||
ax.text(
|
||||
decay_steps * 0.8 / 1e6,
|
||||
epsilon_end + 0.05,
|
||||
f"最终值: {epsilon_end}",
|
||||
f"Final: {epsilon_end}",
|
||||
fontsize=10,
|
||||
color="red",
|
||||
)
|
||||
|
||||
save_path = os.path.join(save_dir, "epsilon_decay.png")
|
||||
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
||||
print(f"ε衰减曲线已保存到: {save_path}")
|
||||
print(f"Epsilon decay curve saved: {save_path}")
|
||||
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_evaluation_results(eval_rewards, save_dir="plots"):
|
||||
"""绘制评估结果
|
||||
"""Plot evaluation results
|
||||
|
||||
Args:
|
||||
eval_rewards: 评估回报列表 [(step, reward), ...]
|
||||
save_dir: 保存目录
|
||||
eval_rewards: Evaluation reward list [(step, reward), ...]
|
||||
save_dir: Save directory
|
||||
"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
if not eval_rewards:
|
||||
print("没有评估数据")
|
||||
print("No evaluation data")
|
||||
return
|
||||
|
||||
steps, rewards = zip(*eval_rewards)
|
||||
@@ -203,10 +202,9 @@ def plot_evaluation_results(eval_rewards, save_dir="plots"):
|
||||
color="blue",
|
||||
linewidth=2,
|
||||
markersize=8,
|
||||
label="评估回报",
|
||||
label="Eval Reward",
|
||||
)
|
||||
|
||||
# 添加趋势线
|
||||
if len(rewards) > 1:
|
||||
z = np.polyfit(steps, rewards, 1)
|
||||
p = np.poly1d(z)
|
||||
@@ -217,27 +215,27 @@ def plot_evaluation_results(eval_rewards, save_dir="plots"):
|
||||
color="red",
|
||||
linewidth=2,
|
||||
alpha=0.7,
|
||||
label="趋势线",
|
||||
label="Trend Line",
|
||||
)
|
||||
|
||||
ax.set_xlabel("训练步数 (百万)", fontsize=12)
|
||||
ax.set_ylabel("平均回报", fontsize=12)
|
||||
ax.set_title("评估回报变化", fontsize=14)
|
||||
ax.set_xlabel("Training Steps (Million)", fontsize=12)
|
||||
ax.set_ylabel("Average Reward", fontsize=12)
|
||||
ax.set_title("Evaluation Reward", fontsize=14)
|
||||
ax.legend(fontsize=10)
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
save_path = os.path.join(save_dir, "evaluation_results.png")
|
||||
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
||||
print(f"评估结果已保存到: {save_path}")
|
||||
print(f"Evaluation results saved: {save_path}")
|
||||
|
||||
plt.close()
|
||||
|
||||
|
||||
def generate_sample_plots(save_dir="plots"):
|
||||
"""生成示例图表(用于报告)
|
||||
"""Generate sample plots for report
|
||||
|
||||
Args:
|
||||
save_dir: 保存目录
|
||||
save_dir: Save directory
|
||||
"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
@@ -270,21 +268,119 @@ def generate_sample_plots(save_dir="plots"):
|
||||
eval_rewards = [50, 100, 150, 180, 190, 195]
|
||||
plot_evaluation_results(list(zip(eval_steps, eval_rewards)), save_dir)
|
||||
|
||||
print(f"\n示例图表已生成到: {save_dir}/")
|
||||
print(f"\nSample plots generated: {save_dir}/")
|
||||
|
||||
|
||||
def generate_real_plots(eval_results_file="evaluation_results.json", save_dir="plots"):
|
||||
"""Generate plots from real evaluation results
|
||||
|
||||
Args:
|
||||
eval_results_file: Evaluation results JSON file
|
||||
save_dir: Save directory
|
||||
"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# 加载评估结果
|
||||
if not os.path.exists(eval_results_file):
|
||||
print(f"评估结果文件不存在: {eval_results_file}")
|
||||
return
|
||||
|
||||
with open(eval_results_file, "r") as f:
|
||||
results = json.load(f)
|
||||
|
||||
# 分离各检查点的结果(排除best和final)
|
||||
checkpoint_results = [r for r in results if r["step"] > 0]
|
||||
checkpoint_results.sort(key=lambda x: x["step"])
|
||||
|
||||
steps = [r["step"] for r in checkpoint_results]
|
||||
rewards = [r["avg_reward"] for r in checkpoint_results]
|
||||
stds = [r["std_reward"] for r in checkpoint_results]
|
||||
|
||||
# 绘制评估曲线(带误差棒)
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
|
||||
ax.errorbar(
|
||||
np.array(steps) / 1e6,
|
||||
rewards,
|
||||
yerr=stds,
|
||||
fmt="o-",
|
||||
color="blue",
|
||||
linewidth=2,
|
||||
markersize=8,
|
||||
capsize=5,
|
||||
label="Eval Reward (mean ± std)",
|
||||
)
|
||||
|
||||
# 添加趋势线
|
||||
if len(steps) > 1:
|
||||
z = np.polyfit(steps, rewards, 1)
|
||||
p = np.poly1d(z)
|
||||
ax.plot(
|
||||
np.array(steps) / 1e6,
|
||||
p(steps),
|
||||
"--",
|
||||
color="red",
|
||||
linewidth=2,
|
||||
alpha=0.7,
|
||||
label="Trend Line",
|
||||
)
|
||||
|
||||
ax.set_xlabel("Training Steps (Million)", fontsize=12)
|
||||
ax.set_ylabel("Average Reward", fontsize=12)
|
||||
ax.set_title("Evaluation Reward Over Training", fontsize=14)
|
||||
ax.legend(fontsize=10)
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
save_path = os.path.join(save_dir, "evaluation_curve.png")
|
||||
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
||||
print(f"Evaluation curve saved: {save_path}")
|
||||
plt.close()
|
||||
|
||||
# 绘制标准差变化
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
|
||||
ax.plot(
|
||||
np.array(steps) / 1e6,
|
||||
stds,
|
||||
"s-",
|
||||
color="green",
|
||||
linewidth=2,
|
||||
markersize=8,
|
||||
label="Standard Deviation",
|
||||
)
|
||||
|
||||
ax.set_xlabel("Training Steps (Million)", fontsize=12)
|
||||
ax.set_ylabel("Reward Std Dev", fontsize=12)
|
||||
ax.set_title("Evaluation Reward Standard Deviation", fontsize=14)
|
||||
ax.legend(fontsize=10)
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
save_path = os.path.join(save_dir, "evaluation_std.png")
|
||||
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
||||
print(f"Evaluation std saved: {save_path}")
|
||||
plt.close()
|
||||
|
||||
# 打印汇总信息
|
||||
best_result = max(checkpoint_results, key=lambda x: x["avg_reward"])
|
||||
print(f"\n最佳检查点: Step {best_result['step']:,}")
|
||||
print(f" 平均回报: {best_result['avg_reward']:.2f} ± {best_result['std_reward']:.2f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="生成训练图表")
|
||||
parser.add_argument("--log-file", type=str, default=None, help="训练日志文件路径")
|
||||
parser.add_argument("--save-dir", type=str, default="plots", help="图表保存目录")
|
||||
parser.add_argument("--sample", action="store_true", help="生成示例图表")
|
||||
parser = argparse.ArgumentParser(description="Generate training plots")
|
||||
parser.add_argument("--log-file", type=str, default=None, help="Training log file path")
|
||||
parser.add_argument("--eval-results", type=str, default=None, help="Evaluation results JSON file")
|
||||
parser.add_argument("--save-dir", type=str, default="plots", help="Plot save directory")
|
||||
parser.add_argument("--sample", action="store_true", help="Generate sample plots")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.sample:
|
||||
generate_sample_plots(args.save_dir)
|
||||
elif args.eval_results:
|
||||
generate_real_plots(args.eval_results, args.save_dir)
|
||||
elif args.log_file:
|
||||
logs = load_training_logs(args.log_file)
|
||||
plot_training_curves(
|
||||
@@ -294,4 +390,4 @@ if __name__ == "__main__":
|
||||
args.save_dir,
|
||||
)
|
||||
else:
|
||||
print("请指定 --log-file 或 --sample 参数")
|
||||
print("Please specify --log-file, --eval-results, or --sample")
|
||||
|
||||
|
Before Width: | Height: | Size: 93 KiB After Width: | Height: | Size: 106 KiB |
|
After Width: | Height: | Size: 154 KiB |
|
Before Width: | Height: | Size: 101 KiB After Width: | Height: | Size: 137 KiB |
|
After Width: | Height: | Size: 153 KiB |
|
Before Width: | Height: | Size: 554 KiB After Width: | Height: | Size: 653 KiB |
@@ -1,7 +1,7 @@
|
||||
\relax
|
||||
\providecommand\hyper@newdestlabel[2]{}
|
||||
\providecommand\HyField@AuxAddToFields[1]{}
|
||||
\providecommand\HyField@AuxAddToCoFields[2]{}
|
||||
\providecommand*\HyPL@Entry[1]{}
|
||||
\HyPL@Entry{0<</S/D>>}
|
||||
\@writefile{toc}{\contentsline {section}{\numberline {1}Introduction}{1}{section.1}\protected@file@percent }
|
||||
\@writefile{toc}{\contentsline {subsection}{\numberline {1.1}Game Selection and Challenges}{1}{subsection.1.1}\protected@file@percent }
|
||||
\@writefile{toc}{\contentsline {subsection}{\numberline {1.2}Motivation}{1}{subsection.1.2}\protected@file@percent }
|
||||
@@ -28,15 +28,19 @@
|
||||
\@writefile{toc}{\contentsline {subsection}{\numberline {4.1}Training Performance}{4}{subsection.4.1}\protected@file@percent }
|
||||
\@writefile{lof}{\contentsline {figure}{\numberline {1}{\ignorespaces Training curves showing reward, loss, and Q-value evolution}}{5}{figure.caption.4}\protected@file@percent }
|
||||
\newlabel{fig:training_curves}{{1}{5}{Training curves showing reward, loss, and Q-value evolution}{figure.caption.4}{}}
|
||||
\@writefile{toc}{\contentsline {subsection}{\numberline {4.2}Evaluation Results}{5}{subsection.4.2}\protected@file@percent }
|
||||
\@writefile{lot}{\contentsline {table}{\numberline {4}{\ignorespaces Evaluation results}}{5}{table.caption.5}\protected@file@percent }
|
||||
\newlabel{tab:evaluation}{{4}{5}{Evaluation results}{table.caption.5}{}}
|
||||
\@writefile{lof}{\contentsline {figure}{\numberline {2}{\ignorespaces Evaluation reward at different training checkpoints with standard deviation error bars}}{5}{figure.caption.5}\protected@file@percent }
|
||||
\newlabel{fig:evaluation_curve}{{2}{5}{Evaluation reward at different training checkpoints with standard deviation error bars}{figure.caption.5}{}}
|
||||
\@writefile{lof}{\contentsline {figure}{\numberline {3}{\ignorespaces Epsilon decay curve during training}}{6}{figure.caption.6}\protected@file@percent }
|
||||
\newlabel{fig:epsilon_decay}{{3}{6}{Epsilon decay curve during training}{figure.caption.6}{}}
|
||||
\@writefile{toc}{\contentsline {subsection}{\numberline {4.2}Evaluation Results}{6}{subsection.4.2}\protected@file@percent }
|
||||
\@writefile{lot}{\contentsline {table}{\numberline {4}{\ignorespaces Evaluation results at different training checkpoints}}{6}{table.caption.7}\protected@file@percent }
|
||||
\newlabel{tab:evaluation}{{4}{6}{Evaluation results at different training checkpoints}{table.caption.7}{}}
|
||||
\@writefile{toc}{\contentsline {subsection}{\numberline {4.3}Comparison with Baselines}{6}{subsection.4.3}\protected@file@percent }
|
||||
\@writefile{lot}{\contentsline {table}{\numberline {5}{\ignorespaces Comparison with baselines}}{6}{table.caption.6}\protected@file@percent }
|
||||
\newlabel{tab:comparison}{{5}{6}{Comparison with baselines}{table.caption.6}{}}
|
||||
\@writefile{toc}{\contentsline {section}{\numberline {5}Discussion}{6}{section.5}\protected@file@percent }
|
||||
\@writefile{toc}{\contentsline {subsection}{\numberline {5.1}Performance Analysis}{6}{subsection.5.1}\protected@file@percent }
|
||||
\@writefile{toc}{\contentsline {subsection}{\numberline {5.2}Limitations}{6}{subsection.5.2}\protected@file@percent }
|
||||
\@writefile{toc}{\contentsline {subsection}{\numberline {5.3}Potential Improvements}{6}{subsection.5.3}\protected@file@percent }
|
||||
\@writefile{lot}{\contentsline {table}{\numberline {5}{\ignorespaces Comparison with baselines}}{6}{table.caption.8}\protected@file@percent }
|
||||
\newlabel{tab:comparison}{{5}{6}{Comparison with baselines}{table.caption.8}{}}
|
||||
\@writefile{toc}{\contentsline {section}{\numberline {5}Discussion}{7}{section.5}\protected@file@percent }
|
||||
\@writefile{toc}{\contentsline {subsection}{\numberline {5.1}Performance Analysis}{7}{subsection.5.1}\protected@file@percent }
|
||||
\@writefile{toc}{\contentsline {subsection}{\numberline {5.2}Limitations}{7}{subsection.5.2}\protected@file@percent }
|
||||
\@writefile{toc}{\contentsline {subsection}{\numberline {5.3}Potential Improvements}{7}{subsection.5.3}\protected@file@percent }
|
||||
\@writefile{toc}{\contentsline {section}{\numberline {6}Conclusion}{7}{section.6}\protected@file@percent }
|
||||
\gdef \@abspage@last{7}
|
||||
\gdef \@abspage@last{8}
|
||||
|
||||
@@ -0,0 +1,167 @@
|
||||
PWD D:/Code/doing_exercises/programs/外教作业外快/强化学习个人项目报告(Atari 游戏方向)/tex
|
||||
INPUT d:/settings/Language/texlive/2025/texmf.cnf
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/web2c/texmf.cnf
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-var/web2c/pdftex/pdflatex.fmt
|
||||
INPUT report.tex
|
||||
OUTPUT report.log
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/base/article.cls
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/base/article.cls
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/base/size11.clo
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/base/size11.clo
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/base/size11.clo
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/map/fontname/texfonts.map
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/public/cm/cmr10.tfm
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/base/inputenc.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/base/inputenc.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/base/fontenc.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/base/fontenc.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/jknappen/ec/ecrm1095.tfm
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/graphics/graphicx.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/graphics/graphicx.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/graphics/keyval.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/graphics/keyval.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/graphics/graphics.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/graphics/graphics.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/graphics/trig.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/graphics/trig.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/graphics-cfg/graphics.cfg
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/graphics-cfg/graphics.cfg
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/graphics-cfg/graphics.cfg
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/graphics-def/pdftex.def
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/graphics-def/pdftex.def
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/graphics-def/pdftex.def
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/amsmath/amsmath.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/amsmath/amsmath.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/amsmath/amsopn.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/amsmath/amstext.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/amsmath/amstext.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/amsmath/amsgen.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/amsmath/amsgen.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/amsmath/amsbsy.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/amsmath/amsbsy.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/amsmath/amsopn.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/amsfonts/amsfonts.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/amsfonts/amsfonts.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/amsfonts/amssymb.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/amsfonts/amssymb.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/booktabs/booktabs.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/booktabs/booktabs.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/hyperref/hyperref.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/hyperref/hyperref.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/iftex/iftex.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/iftex/iftex.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/kvsetkeys/kvsetkeys.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/kvsetkeys/kvsetkeys.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/kvdefinekeys/kvdefinekeys.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/kvdefinekeys/kvdefinekeys.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/pdfescape/pdfescape.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/pdfescape/pdfescape.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/ltxcmds/ltxcmds.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/ltxcmds/ltxcmds.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/pdftexcmds/pdftexcmds.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/pdftexcmds/pdftexcmds.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/infwarerr/infwarerr.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/infwarerr/infwarerr.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/hycolor/hycolor.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/hycolor/hycolor.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/hyperref/nameref.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/hyperref/nameref.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/refcount/refcount.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/refcount/refcount.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/gettitlestring/gettitlestring.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/gettitlestring/gettitlestring.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/kvoptions/kvoptions.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/kvoptions/kvoptions.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/etoolbox/etoolbox.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/etoolbox/etoolbox.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/stringenc/stringenc.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/stringenc/stringenc.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/hyperref/pd1enc.def
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/hyperref/pd1enc.def
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/hyperref/pd1enc.def
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/intcalc/intcalc.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/intcalc/intcalc.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/hyperref/puenc.def
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/hyperref/puenc.def
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/hyperref/puenc.def
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/url/url.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/url/url.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/bitset/bitset.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/bitset/bitset.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/bigintcalc/bigintcalc.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/bigintcalc/bigintcalc.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/atbegshi/atbegshi.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/base/atbegshi-ltx.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/base/atbegshi-ltx.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/hyperref/hpdftex.def
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/hyperref/hpdftex.def
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/hyperref/hpdftex.def
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/atveryend/atveryend.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/base/atveryend-ltx.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/base/atveryend-ltx.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/rerunfilecheck/rerunfilecheck.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/rerunfilecheck/rerunfilecheck.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/uniquecounter/uniquecounter.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/uniquecounter/uniquecounter.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/float/float.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/float/float.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/caption/caption.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/caption/caption.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/caption/caption3.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/caption/caption3.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/caption/subcaption.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/caption/subcaption.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/geometry/geometry.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/geometry/geometry.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/iftex/ifvtex.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/generic/iftex/ifvtex.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/setspace/setspace.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/setspace/setspace.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/l3backend/l3backend-pdftex.def
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/l3backend/l3backend-pdftex.def
|
||||
INPUT ./report.aux
|
||||
INPUT ./report.aux
|
||||
INPUT report.aux
|
||||
OUTPUT report.aux
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/context/base/mkii/supp-pdf.mkii
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/context/base/mkii/supp-pdf.mkii
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/context/base/mkii/supp-pdf.mkii
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/epstopdf-pkg/epstopdf-base.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/epstopdf-pkg/epstopdf-base.sty
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/latexconfig/epstopdf-sys.cfg
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/latexconfig/epstopdf-sys.cfg
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/latexconfig/epstopdf-sys.cfg
|
||||
INPUT ./report.out
|
||||
INPUT ./report.out
|
||||
INPUT report.out
|
||||
INPUT report.out
|
||||
OUTPUT report.pdf
|
||||
INPUT ./report.out
|
||||
INPUT ./report.out
|
||||
OUTPUT report.out
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/jknappen/ec/ecrm1728.tfm
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/jknappen/ec/ecrm1200.tfm
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/public/cm/cmr12.tfm
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/public/cm/cmr8.tfm
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/public/cm/cmr6.tfm
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/public/cm/cmmi12.tfm
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/public/cm/cmmi8.tfm
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/public/cm/cmmi6.tfm
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/public/cm/cmsy10.tfm
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/public/cm/cmsy8.tfm
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/public/cm/cmsy6.tfm
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/public/cm/cmex10.tfm
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/public/amsfonts/cmextra/cmex8.tfm
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/public/amsfonts/cmextra/cmex7.tfm
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/amsfonts/umsa.fd
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/amsfonts/umsa.fd
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/amsfonts/umsa.fd
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/public/amsfonts/symbols/msam10.tfm
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/public/amsfonts/symbols/msam10.tfm
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/public/amsfonts/symbols/msam7.tfm
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/amsfonts/umsb.fd
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/amsfonts/umsb.fd
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/tex/latex/amsfonts/umsb.fd
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/public/amsfonts/symbols/msbm10.tfm
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/public/amsfonts/symbols/msbm10.tfm
|
||||
INPUT d:/settings/Language/texlive/2025/texmf-dist/fonts/tfm/public/amsfonts/symbols/msbm7.tfm
|
||||
@@ -1,8 +1,9 @@
|
||||
\documentclass[11pt,a4paper]{article}
|
||||
|
||||
% 包导入
|
||||
\usepackage[utf8]{inputenc}
|
||||
\usepackage[T1]{fontenc}
|
||||
\usepackage{xeCJK}
|
||||
\usepackage{fontspec}
|
||||
\setCJKmainfont{SimSun}
|
||||
\usepackage{graphicx}
|
||||
\usepackage{amsmath}
|
||||
\usepackage{amsfonts}
|
||||
@@ -18,7 +19,7 @@
|
||||
|
||||
% 标题信息
|
||||
\title{Deep Q-Network for Space Invaders: \\ A Deep Reinforcement Learning Approach}
|
||||
\author{[Your Name] \\ [Your Student ID]}
|
||||
\author{刘航宇 \\ Student ID: [Your Student ID]}
|
||||
\date{\today}
|
||||
|
||||
\begin{document}
|
||||
@@ -26,7 +27,7 @@
|
||||
\maketitle
|
||||
|
||||
\begin{abstract}
|
||||
This report presents the implementation and evaluation of a Deep Q-Network (DQN) agent for playing the Atari game Space Invaders. The agent was trained from scratch using Double DQN with experience replay and target network stabilization. After 2 million training steps, the agent achieved an average score of [X] on the Space Invaders environment, demonstrating competitive performance compared to baseline methods. This report details the algorithm selection, implementation details, experimental results, and analysis of the agent's performance.
|
||||
This report presents the implementation and evaluation of a Deep Q-Network (DQN) agent for playing the Atari game Space Invaders. The agent was trained from scratch using Double DQN with experience replay and target network stabilization. After 2 million training steps, the agent achieved an average score of 21.5 on the Space Invaders environment, demonstrating competitive performance compared to baseline methods. This report details the algorithm selection, implementation details, experimental results, and analysis of the agent's performance.
|
||||
\end{abstract}
|
||||
|
||||
\section{Introduction}
|
||||
@@ -187,12 +188,12 @@ Warmup Steps & 10,000 \\
|
||||
|
||||
\subsection{Training Performance}
|
||||
|
||||
The agent was trained for 2 million steps. Key observations:
|
||||
The agent was trained for 2 million steps on an NVIDIA RTX 4060 GPU. Key observations:
|
||||
|
||||
\begin{itemize}
|
||||
\item \textbf{Initial Phase} (0-100K steps): Random exploration, average score around 10-15
|
||||
\item \textbf{Learning Phase} (100K-500K steps): Gradual improvement, score increases to 30-50
|
||||
\item \textbf{Convergence Phase} (500K-2M steps): Performance stabilizes around 100-200
|
||||
\item \textbf{Initial Phase} (0-100K steps): Random exploration with warmup, average score around 10-15
|
||||
\item \textbf{Learning Phase} (100K-600K steps): Gradual improvement, score increases to 15-19
|
||||
\item \textbf{Convergence Phase} (600K-2M steps): Performance fluctuates between 13-21, with best performance at 1.8M steps
|
||||
\end{itemize}
|
||||
|
||||
\begin{figure}[H]
|
||||
@@ -202,26 +203,44 @@ The agent was trained for 2 million steps. Key observations:
|
||||
\label{fig:training_curves}
|
||||
\end{figure}
|
||||
|
||||
\begin{figure}[H]
|
||||
\centering
|
||||
\includegraphics[width=0.8\textwidth]{../plots/evaluation_curve.png}
|
||||
\caption{Evaluation reward at different training checkpoints with standard deviation error bars}
|
||||
\label{fig:evaluation_curve}
|
||||
\end{figure}
|
||||
|
||||
\begin{figure}[H]
|
||||
\centering
|
||||
\includegraphics[width=0.8\textwidth]{../plots/epsilon_decay.png}
|
||||
\caption{Epsilon decay curve during training}
|
||||
\label{fig:epsilon_decay}
|
||||
\end{figure}
|
||||
|
||||
\subsection{Evaluation Results}
|
||||
|
||||
The trained agent was evaluated over 20 episodes:
|
||||
The trained agent was evaluated over 20 episodes at different training checkpoints:
|
||||
|
||||
\begin{table}[H]
|
||||
\centering
|
||||
\begin{tabular}{@{}lc@{}}
|
||||
\begin{tabular}{@{}lcc@{}}
|
||||
\toprule
|
||||
\textbf{Metric} & \textbf{Value} \\
|
||||
\textbf{Checkpoint} & \textbf{Average Score} & \textbf{Std Dev} \\
|
||||
\midrule
|
||||
Average Score & [X] \\
|
||||
Standard Deviation & [Y] \\
|
||||
Maximum Score & [Z] \\
|
||||
Minimum Score & [W] \\
|
||||
100K steps & 17.80 & 5.23 \\
|
||||
600K steps & 19.00 & 4.12 \\
|
||||
1.2M steps & 18.40 & 6.22 \\
|
||||
1.8M steps & \textbf{21.50} & 4.98 \\
|
||||
2.0M steps (final) & 14.60 & 5.28 \\
|
||||
Best Model & 19.90 & 6.92 \\
|
||||
\bottomrule
|
||||
\end{tabular}
|
||||
\caption{Evaluation results}
|
||||
\caption{Evaluation results at different training checkpoints}
|
||||
\label{tab:evaluation}
|
||||
\end{table}
|
||||
|
||||
The best performance was achieved at 1.8M training steps with an average score of 21.50. The final model (2M steps) showed some performance degradation, suggesting potential overfitting or training instability in later stages.
|
||||
|
||||
\subsection{Comparison with Baselines}
|
||||
|
||||
\begin{table}[H]
|
||||
@@ -231,8 +250,8 @@ Minimum Score & [W] \\
|
||||
\textbf{Method} & \textbf{Average Score} & \textbf{Training Time} \\
|
||||
\midrule
|
||||
Random Agent & $\sim$5 & N/A \\
|
||||
Our DQN & [X] & [Time] \\
|
||||
Stable-Baselines3 DQN & [SB3 Score] & [SB3 Time] \\
|
||||
Our DQN (Best) & 21.50 & $\sim$6 hours \\
|
||||
Our DQN (Final) & 14.60 & $\sim$6 hours \\
|
||||
Human Player & $\sim$200 & N/A \\
|
||||
\bottomrule
|
||||
\end{tabular}
|
||||
@@ -275,9 +294,9 @@ Future improvements could include:
|
||||
|
||||
\section{Conclusion}
|
||||
|
||||
This project successfully implemented a DQN agent for playing Space Invaders from raw pixel inputs. The agent achieved an average score of [X], demonstrating competitive performance compared to baseline methods. The implementation highlights the effectiveness of deep reinforcement learning for Atari games and provides a solid foundation for exploring more advanced algorithms.
|
||||
This project successfully implemented a DQN agent for playing Space Invaders from raw pixel inputs. The agent achieved an average score of 21.50 at the best checkpoint (1.8M steps), demonstrating competitive performance compared to random agents ($\sim$5). The implementation highlights the effectiveness of deep reinforcement learning for Atari games and provides a solid foundation for exploring more advanced algorithms.
|
||||
|
||||
The DQN algorithm, while relatively simple, remains a powerful approach for discrete action space problems. The key innovations of experience replay and target networks are crucial for stable training. Future work could explore more advanced variants like Rainbow DQN to further improve performance.
|
||||
The DQN algorithm, while relatively simple, remains a powerful approach for discrete action space problems. The key innovations of experience replay and target networks are crucial for stable training. The use of Double DQN helped reduce overestimation bias, though some performance fluctuation was observed during training. Future work could explore more advanced variants like Rainbow DQN, Prioritized Experience Replay, or Dueling DQN architecture to further improve performance and training stability.
|
||||
|
||||
\section*{References}
|
||||
|
||||
|
||||