feat: 改进DQN训练稳定性和性能
- 将奖励裁剪替换为奖励缩放,保留奖励大小信号 - 添加学习率调度器,支持warmup和步进衰减 - 增加经验回放缓冲区大小至200,000 - 默认启用Dueling DQN架构 - 优化代码格式和参数传递 - 添加更多训练中间模型保存点
This commit is contained in:
@@ -21,3 +21,7 @@ __pycache__/
|
||||
*.o
|
||||
*.exe
|
||||
*.out
|
||||
|
||||
# 模型文件
|
||||
*.pth
|
||||
*.pt
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,4 +1,5 @@
|
||||
"""DQN Agent implementation."""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
@@ -26,6 +27,9 @@ class DQNAgent:
|
||||
target_update_freq=1000,
|
||||
batch_size=32,
|
||||
double_dqn=True,
|
||||
lr_decay_steps=1_000_000,
|
||||
lr_decay_factor=0.5,
|
||||
warmup_steps=10_000,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -42,6 +46,9 @@ class DQNAgent:
|
||||
target_update_freq: 目标网络更新频率
|
||||
batch_size: 批次大小
|
||||
double_dqn: 是否使用Double DQN
|
||||
lr_decay_steps: 学习率衰减步数
|
||||
lr_decay_factor: 学习率衰减因子
|
||||
warmup_steps: 预热步数
|
||||
"""
|
||||
self.q_network = q_network
|
||||
self.target_network = target_network
|
||||
@@ -53,19 +60,20 @@ class DQNAgent:
|
||||
self.target_update_freq = target_update_freq
|
||||
self.double_dqn = double_dqn
|
||||
|
||||
# ε-greedy参数
|
||||
self.epsilon_start = epsilon_start
|
||||
self.epsilon_end = epsilon_end
|
||||
self.epsilon_decay_steps = epsilon_decay_steps
|
||||
self.epsilon = epsilon_start
|
||||
|
||||
# 优化器
|
||||
self.base_lr = lr
|
||||
self.lr_decay_steps = lr_decay_steps
|
||||
self.lr_decay_factor = lr_decay_factor
|
||||
self.warmup_steps = warmup_steps
|
||||
|
||||
self.optimizer = torch.optim.Adam(q_network.parameters(), lr=lr)
|
||||
|
||||
# 训练步数
|
||||
self.step_count = 0
|
||||
|
||||
# 训练统计
|
||||
self.loss_history = []
|
||||
self.q_value_history = []
|
||||
|
||||
@@ -92,18 +100,34 @@ class DQNAgent:
|
||||
else:
|
||||
# 贪心选择
|
||||
with torch.no_grad():
|
||||
state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
|
||||
state_tensor = (
|
||||
torch.from_numpy(state).float().unsqueeze(0).to(self.device)
|
||||
)
|
||||
q_values = self.q_network(state_tensor)
|
||||
return q_values.argmax(dim=1).item()
|
||||
|
||||
def update_epsilon(self):
|
||||
"""更新ε值(线性衰减)"""
|
||||
if self.step_count < self.epsilon_decay_steps:
|
||||
self.epsilon = self.epsilon_start - (self.epsilon_start - self.epsilon_end) * \
|
||||
(self.step_count / self.epsilon_decay_steps)
|
||||
self.epsilon = self.epsilon_start - (
|
||||
self.epsilon_start - self.epsilon_end
|
||||
) * (self.step_count / self.epsilon_decay_steps)
|
||||
else:
|
||||
self.epsilon = self.epsilon_end
|
||||
|
||||
def update_learning_rate(self):
|
||||
"""更新学习率:warmup + 步进衰减"""
|
||||
if self.step_count < self.warmup_steps:
|
||||
current_lr = self.base_lr * (self.step_count / self.warmup_steps)
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group["lr"] = current_lr
|
||||
elif (
|
||||
self.step_count % self.lr_decay_steps == 0
|
||||
and self.step_count > self.warmup_steps
|
||||
):
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group["lr"] *= self.lr_decay_factor
|
||||
|
||||
def train_step(self):
|
||||
"""执行一步训练
|
||||
|
||||
@@ -116,7 +140,9 @@ class DQNAgent:
|
||||
return None, None
|
||||
|
||||
# 采样
|
||||
states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
|
||||
states, actions, rewards, next_states, dones = self.replay_buffer.sample(
|
||||
self.batch_size
|
||||
)
|
||||
|
||||
# 计算当前Q值
|
||||
q_values = self.q_network(states)
|
||||
@@ -128,7 +154,9 @@ class DQNAgent:
|
||||
# Double DQN: 用Q网络选择动作,用目标网络评估
|
||||
next_actions = self.q_network(next_states).argmax(dim=1)
|
||||
next_q_values = self.target_network(next_states)
|
||||
next_q_values = next_q_values.gather(1, next_actions.unsqueeze(1)).squeeze(1)
|
||||
next_q_values = next_q_values.gather(
|
||||
1, next_actions.unsqueeze(1)
|
||||
).squeeze(1)
|
||||
else:
|
||||
# 标准DQN: 直接用目标网络的最大Q值
|
||||
next_q_values = self.target_network(next_states).max(dim=1)[0]
|
||||
@@ -151,8 +179,9 @@ class DQNAgent:
|
||||
if self.step_count % self.target_update_freq == 0:
|
||||
self.target_network.load_state_dict(self.q_network.state_dict())
|
||||
|
||||
# 更新ε
|
||||
# 更新ε和学习率
|
||||
self.update_epsilon()
|
||||
self.update_learning_rate()
|
||||
|
||||
# 记录统计
|
||||
avg_q = q_values.mean().item()
|
||||
@@ -164,21 +193,24 @@ class DQNAgent:
|
||||
def save(self, path):
|
||||
"""保存模型"""
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
torch.save({
|
||||
'q_network': self.q_network.state_dict(),
|
||||
'target_network': self.target_network.state_dict(),
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
'step_count': self.step_count,
|
||||
'epsilon': self.epsilon,
|
||||
}, path)
|
||||
torch.save(
|
||||
{
|
||||
"q_network": self.q_network.state_dict(),
|
||||
"target_network": self.target_network.state_dict(),
|
||||
"optimizer": self.optimizer.state_dict(),
|
||||
"step_count": self.step_count,
|
||||
"epsilon": self.epsilon,
|
||||
},
|
||||
path,
|
||||
)
|
||||
print(f"模型已保存到: {path}")
|
||||
|
||||
def load(self, path):
|
||||
"""加载模型"""
|
||||
checkpoint = torch.load(path, map_location=self.device)
|
||||
self.q_network.load_state_dict(checkpoint['q_network'])
|
||||
self.target_network.load_state_dict(checkpoint['target_network'])
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
self.step_count = checkpoint['step_count']
|
||||
self.epsilon = checkpoint['epsilon']
|
||||
self.q_network.load_state_dict(checkpoint["q_network"])
|
||||
self.target_network.load_state_dict(checkpoint["target_network"])
|
||||
self.optimizer.load_state_dict(checkpoint["optimizer"])
|
||||
self.step_count = checkpoint["step_count"]
|
||||
self.epsilon = checkpoint["epsilon"]
|
||||
print(f"模型已从 {path} 加载")
|
||||
|
||||
@@ -80,14 +80,15 @@ class FrameStackWrapper(gym.ObservationWrapper):
|
||||
return np.stack(list(self.frames), axis=0)
|
||||
|
||||
|
||||
class RewardClipWrapper(gym.RewardWrapper):
|
||||
"""裁剪奖励到[-1, 1]"""
|
||||
class RewardScaleWrapper(gym.RewardWrapper):
|
||||
"""缩放奖励以稳定训练,同时保留奖励大小信号"""
|
||||
|
||||
def __init__(self, env):
|
||||
def __init__(self, env, scale=10.0):
|
||||
super().__init__(env)
|
||||
self.scale = scale
|
||||
|
||||
def reward(self, reward):
|
||||
return np.clip(reward, -1, 1)
|
||||
return reward / self.scale
|
||||
|
||||
|
||||
class NoopResetWrapper(gym.Wrapper):
|
||||
@@ -174,7 +175,7 @@ def make_env(env_id="ALE/SpaceInvaders-v5", gray_scale=True, resize=True,
|
||||
env = GrayScaleWrapper(env)
|
||||
|
||||
if reward_clip:
|
||||
env = RewardClipWrapper(env)
|
||||
env = RewardScaleWrapper(env, scale=10.0)
|
||||
|
||||
if frame_stack > 1:
|
||||
env = FrameStackWrapper(env, num_stack=frame_stack)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Main training script for DQN on Space Invaders."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
import argparse
|
||||
@@ -18,58 +20,61 @@ def main():
|
||||
parser = argparse.ArgumentParser(description="DQN for Space Invaders")
|
||||
|
||||
# 环境参数
|
||||
parser.add_argument("--env", type=str, default="ALE/SpaceInvaders-v5",
|
||||
help="Atari环境ID")
|
||||
parser.add_argument(
|
||||
"--env", type=str, default="ALE/SpaceInvaders-v5", help="Atari环境ID"
|
||||
)
|
||||
|
||||
# 训练参数
|
||||
parser.add_argument("--steps", type=int, default=2_000_000,
|
||||
help="总训练步数")
|
||||
parser.add_argument("--lr", type=float, default=1e-4,
|
||||
help="学习率")
|
||||
parser.add_argument("--gamma", type=float, default=0.99,
|
||||
help="折扣因子")
|
||||
parser.add_argument("--batch-size", type=int, default=32,
|
||||
help="批次大小")
|
||||
parser.add_argument("--buffer-size", type=int, default=100_000,
|
||||
help="经验回放缓冲区大小")
|
||||
parser.add_argument("--steps", type=int, default=2_000_000, help="总训练步数")
|
||||
parser.add_argument("--lr", type=float, default=1e-4, help="学习率")
|
||||
parser.add_argument("--gamma", type=float, default=0.99, help="折扣因子")
|
||||
parser.add_argument("--batch-size", type=int, default=32, help="批次大小")
|
||||
parser.add_argument(
|
||||
"--buffer-size", type=int, default=200_000, help="经验回放缓冲区大小"
|
||||
)
|
||||
|
||||
# ε-greedy参数
|
||||
parser.add_argument("--epsilon-start", type=float, default=1.0,
|
||||
help="ε初始值")
|
||||
parser.add_argument("--epsilon-end", type=float, default=0.01,
|
||||
help="ε最终值")
|
||||
parser.add_argument("--epsilon-decay", type=int, default=1_000_000,
|
||||
help="ε衰减步数")
|
||||
parser.add_argument("--epsilon-start", type=float, default=1.0, help="ε初始值")
|
||||
parser.add_argument("--epsilon-end", type=float, default=0.01, help="ε最终值")
|
||||
parser.add_argument(
|
||||
"--epsilon-decay", type=int, default=1_000_000, help="ε衰减步数"
|
||||
)
|
||||
|
||||
# 网络参数
|
||||
parser.add_argument("--target-update", type=int, default=1000,
|
||||
help="目标网络更新频率")
|
||||
parser.add_argument("--double-dqn", action="store_true", default=True,
|
||||
help="使用Double DQN")
|
||||
parser.add_argument("--dueling", action="store_true", default=False,
|
||||
help="使用Dueling DQN架构")
|
||||
parser.add_argument(
|
||||
"--target-update", type=int, default=500, help="目标网络更新频率"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--double-dqn", action="store_true", default=True, help="使用Double DQN"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dueling", action="store_true", default=True, help="使用Dueling DQN架构"
|
||||
)
|
||||
|
||||
# 学习率参数
|
||||
parser.add_argument(
|
||||
"--lr-decay-steps", type=int, default=1_000_000, help="学习率衰减步数"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr-decay-factor", type=float, default=0.5, help="学习率衰减因子"
|
||||
)
|
||||
parser.add_argument("--warmup-steps", type=int, default=10_000, help="预热步数")
|
||||
|
||||
# 评估参数
|
||||
parser.add_argument("--eval-freq", type=int, default=10000,
|
||||
help="评估频率")
|
||||
parser.add_argument("--eval-episodes", type=int, default=10,
|
||||
help="评估episode数")
|
||||
parser.add_argument("--save-freq", type=int, default=50000,
|
||||
help="模型保存频率")
|
||||
parser.add_argument("--warmup", type=int, default=10000,
|
||||
help="预热步数")
|
||||
parser.add_argument("--eval-freq", type=int, default=10000, help="评估频率")
|
||||
parser.add_argument("--eval-episodes", type=int, default=10, help="评估episode数")
|
||||
parser.add_argument("--save-freq", type=int, default=50000, help="模型保存频率")
|
||||
parser.add_argument("--warmup", type=int, default=10000, help="预热步数")
|
||||
|
||||
# 优先经验回放
|
||||
parser.add_argument("--prioritized", action="store_true", default=False,
|
||||
help="使用优先经验回放")
|
||||
parser.add_argument(
|
||||
"--prioritized", action="store_true", default=False, help="使用优先经验回放"
|
||||
)
|
||||
|
||||
# 其他
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="随机种子")
|
||||
parser.add_argument("--save-dir", type=str, default="models",
|
||||
help="模型保存目录")
|
||||
parser.add_argument("--log-dir", type=str, default="logs",
|
||||
help="日志目录")
|
||||
parser.add_argument("--seed", type=int, default=42, help="随机种子")
|
||||
parser.add_argument("--save-dir", type=str, default="models", help="模型保存目录")
|
||||
parser.add_argument("--log-dir", type=str, default="logs", help="日志目录")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -110,14 +115,10 @@ def main():
|
||||
# 创建经验回放缓冲区
|
||||
if args.prioritized:
|
||||
print("使用优先经验回放")
|
||||
replay_buffer = PrioritizedReplayBuffer(
|
||||
args.buffer_size, state_shape, device
|
||||
)
|
||||
replay_buffer = PrioritizedReplayBuffer(args.buffer_size, state_shape, device)
|
||||
else:
|
||||
print("使用标准经验回放")
|
||||
replay_buffer = ReplayBuffer(
|
||||
args.buffer_size, state_shape, device
|
||||
)
|
||||
replay_buffer = ReplayBuffer(args.buffer_size, state_shape, device)
|
||||
|
||||
# 创建智能体
|
||||
agent = DQNAgent(
|
||||
@@ -134,6 +135,9 @@ def main():
|
||||
target_update_freq=args.target_update,
|
||||
batch_size=args.batch_size,
|
||||
double_dqn=args.double_dqn,
|
||||
lr_decay_steps=args.lr_decay_steps,
|
||||
lr_decay_factor=args.lr_decay_factor,
|
||||
warmup_steps=args.warmup_steps,
|
||||
)
|
||||
|
||||
# 创建训练器
|
||||
@@ -146,20 +150,24 @@ def main():
|
||||
eval_freq=args.eval_freq,
|
||||
save_freq=args.save_freq,
|
||||
num_eval_episodes=args.eval_episodes,
|
||||
warmup_steps=args.warmup,
|
||||
warmup_steps=args.warmup_steps,
|
||||
)
|
||||
|
||||
# 打印配置
|
||||
print("\n训练配置:")
|
||||
print(f" 总步数: {args.steps:,}")
|
||||
print(f" 学习率: {args.lr}")
|
||||
print(f" 学习率衰减: 每{args.lr_decay_steps:,}步衰减{args.lr_decay_factor}倍")
|
||||
print(f" Warmup步数: {args.warmup_steps:,}")
|
||||
print(f" 折扣因子: {args.gamma}")
|
||||
print(f" 批次大小: {args.batch_size}")
|
||||
print(f" 缓冲区大小: {args.buffer_size:,}")
|
||||
print(f" ε衰减: {args.epsilon_start} -> {args.epsilon_end} ({args.epsilon_decay:,}步)")
|
||||
print(
|
||||
f" ε衰减: {args.epsilon_start} -> {args.epsilon_end} ({args.epsilon_decay:,}步)"
|
||||
)
|
||||
print(f" 目标网络更新: 每{args.target_update}步")
|
||||
print(f" Double DQN: {args.double_dqn}")
|
||||
print(f" 预热步数: {args.warmup:,}")
|
||||
print(f" Dueling DQN: {args.dueling}")
|
||||
print("=" * 60)
|
||||
|
||||
# 开始训练
|
||||
|
||||
Reference in New Issue
Block a user