feat: 改进DQN训练稳定性和性能
- 将奖励裁剪替换为奖励缩放,保留奖励大小信号
- 添加学习率调度器,支持warmup和步进衰减
- 增加经验回放缓冲区大小至200,000
- 默认启用Dueling DQN架构
- 优化代码格式和参数传递
- 添加更多训练中间模型保存点
This commit is contained in:
@@ -21,3 +21,7 @@ __pycache__/
|
|||||||
*.o
|
*.o
|
||||||
*.exe
|
*.exe
|
||||||
*.out
|
*.out
|
||||||
|
|
||||||
|
# 模型文件
|
||||||
|
*.pth
|
||||||
|
*.pt
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,4 +1,5 @@
|
|||||||
"""DQN Agent implementation."""
|
"""DQN Agent implementation."""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -26,6 +27,9 @@ class DQNAgent:
|
|||||||
target_update_freq=1000,
|
target_update_freq=1000,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
double_dqn=True,
|
double_dqn=True,
|
||||||
|
lr_decay_steps=1_000_000,
|
||||||
|
lr_decay_factor=0.5,
|
||||||
|
warmup_steps=10_000,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -42,6 +46,9 @@ class DQNAgent:
|
|||||||
target_update_freq: 目标网络更新频率
|
target_update_freq: 目标网络更新频率
|
||||||
batch_size: 批次大小
|
batch_size: 批次大小
|
||||||
double_dqn: 是否使用Double DQN
|
double_dqn: 是否使用Double DQN
|
||||||
|
lr_decay_steps: 学习率衰减步数
|
||||||
|
lr_decay_factor: 学习率衰减因子
|
||||||
|
warmup_steps: 预热步数
|
||||||
"""
|
"""
|
||||||
self.q_network = q_network
|
self.q_network = q_network
|
||||||
self.target_network = target_network
|
self.target_network = target_network
|
||||||
@@ -53,19 +60,20 @@ class DQNAgent:
|
|||||||
self.target_update_freq = target_update_freq
|
self.target_update_freq = target_update_freq
|
||||||
self.double_dqn = double_dqn
|
self.double_dqn = double_dqn
|
||||||
|
|
||||||
# ε-greedy参数
|
|
||||||
self.epsilon_start = epsilon_start
|
self.epsilon_start = epsilon_start
|
||||||
self.epsilon_end = epsilon_end
|
self.epsilon_end = epsilon_end
|
||||||
self.epsilon_decay_steps = epsilon_decay_steps
|
self.epsilon_decay_steps = epsilon_decay_steps
|
||||||
self.epsilon = epsilon_start
|
self.epsilon = epsilon_start
|
||||||
|
|
||||||
# 优化器
|
self.base_lr = lr
|
||||||
|
self.lr_decay_steps = lr_decay_steps
|
||||||
|
self.lr_decay_factor = lr_decay_factor
|
||||||
|
self.warmup_steps = warmup_steps
|
||||||
|
|
||||||
self.optimizer = torch.optim.Adam(q_network.parameters(), lr=lr)
|
self.optimizer = torch.optim.Adam(q_network.parameters(), lr=lr)
|
||||||
|
|
||||||
# 训练步数
|
|
||||||
self.step_count = 0
|
self.step_count = 0
|
||||||
|
|
||||||
# 训练统计
|
|
||||||
self.loss_history = []
|
self.loss_history = []
|
||||||
self.q_value_history = []
|
self.q_value_history = []
|
||||||
|
|
||||||
@@ -92,18 +100,34 @@ class DQNAgent:
|
|||||||
else:
|
else:
|
||||||
# 贪心选择
|
# 贪心选择
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
|
state_tensor = (
|
||||||
|
torch.from_numpy(state).float().unsqueeze(0).to(self.device)
|
||||||
|
)
|
||||||
q_values = self.q_network(state_tensor)
|
q_values = self.q_network(state_tensor)
|
||||||
return q_values.argmax(dim=1).item()
|
return q_values.argmax(dim=1).item()
|
||||||
|
|
||||||
def update_epsilon(self):
|
def update_epsilon(self):
|
||||||
"""更新ε值(线性衰减)"""
|
"""更新ε值(线性衰减)"""
|
||||||
if self.step_count < self.epsilon_decay_steps:
|
if self.step_count < self.epsilon_decay_steps:
|
||||||
self.epsilon = self.epsilon_start - (self.epsilon_start - self.epsilon_end) * \
|
self.epsilon = self.epsilon_start - (
|
||||||
(self.step_count / self.epsilon_decay_steps)
|
self.epsilon_start - self.epsilon_end
|
||||||
|
) * (self.step_count / self.epsilon_decay_steps)
|
||||||
else:
|
else:
|
||||||
self.epsilon = self.epsilon_end
|
self.epsilon = self.epsilon_end
|
||||||
|
|
||||||
|
def update_learning_rate(self):
|
||||||
|
"""更新学习率:warmup + 步进衰减"""
|
||||||
|
if self.step_count < self.warmup_steps:
|
||||||
|
current_lr = self.base_lr * (self.step_count / self.warmup_steps)
|
||||||
|
for param_group in self.optimizer.param_groups:
|
||||||
|
param_group["lr"] = current_lr
|
||||||
|
elif (
|
||||||
|
self.step_count % self.lr_decay_steps == 0
|
||||||
|
and self.step_count > self.warmup_steps
|
||||||
|
):
|
||||||
|
for param_group in self.optimizer.param_groups:
|
||||||
|
param_group["lr"] *= self.lr_decay_factor
|
||||||
|
|
||||||
def train_step(self):
|
def train_step(self):
|
||||||
"""执行一步训练
|
"""执行一步训练
|
||||||
|
|
||||||
@@ -116,7 +140,9 @@ class DQNAgent:
|
|||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
# 采样
|
# 采样
|
||||||
states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
|
states, actions, rewards, next_states, dones = self.replay_buffer.sample(
|
||||||
|
self.batch_size
|
||||||
|
)
|
||||||
|
|
||||||
# 计算当前Q值
|
# 计算当前Q值
|
||||||
q_values = self.q_network(states)
|
q_values = self.q_network(states)
|
||||||
@@ -128,7 +154,9 @@ class DQNAgent:
|
|||||||
# Double DQN: 用Q网络选择动作,用目标网络评估
|
# Double DQN: 用Q网络选择动作,用目标网络评估
|
||||||
next_actions = self.q_network(next_states).argmax(dim=1)
|
next_actions = self.q_network(next_states).argmax(dim=1)
|
||||||
next_q_values = self.target_network(next_states)
|
next_q_values = self.target_network(next_states)
|
||||||
next_q_values = next_q_values.gather(1, next_actions.unsqueeze(1)).squeeze(1)
|
next_q_values = next_q_values.gather(
|
||||||
|
1, next_actions.unsqueeze(1)
|
||||||
|
).squeeze(1)
|
||||||
else:
|
else:
|
||||||
# 标准DQN: 直接用目标网络的最大Q值
|
# 标准DQN: 直接用目标网络的最大Q值
|
||||||
next_q_values = self.target_network(next_states).max(dim=1)[0]
|
next_q_values = self.target_network(next_states).max(dim=1)[0]
|
||||||
@@ -151,8 +179,9 @@ class DQNAgent:
|
|||||||
if self.step_count % self.target_update_freq == 0:
|
if self.step_count % self.target_update_freq == 0:
|
||||||
self.target_network.load_state_dict(self.q_network.state_dict())
|
self.target_network.load_state_dict(self.q_network.state_dict())
|
||||||
|
|
||||||
# 更新ε
|
# 更新ε和学习率
|
||||||
self.update_epsilon()
|
self.update_epsilon()
|
||||||
|
self.update_learning_rate()
|
||||||
|
|
||||||
# 记录统计
|
# 记录统计
|
||||||
avg_q = q_values.mean().item()
|
avg_q = q_values.mean().item()
|
||||||
@@ -164,21 +193,24 @@ class DQNAgent:
|
|||||||
def save(self, path):
|
def save(self, path):
|
||||||
"""保存模型"""
|
"""保存模型"""
|
||||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||||
torch.save({
|
torch.save(
|
||||||
'q_network': self.q_network.state_dict(),
|
{
|
||||||
'target_network': self.target_network.state_dict(),
|
"q_network": self.q_network.state_dict(),
|
||||||
'optimizer': self.optimizer.state_dict(),
|
"target_network": self.target_network.state_dict(),
|
||||||
'step_count': self.step_count,
|
"optimizer": self.optimizer.state_dict(),
|
||||||
'epsilon': self.epsilon,
|
"step_count": self.step_count,
|
||||||
}, path)
|
"epsilon": self.epsilon,
|
||||||
|
},
|
||||||
|
path,
|
||||||
|
)
|
||||||
print(f"模型已保存到: {path}")
|
print(f"模型已保存到: {path}")
|
||||||
|
|
||||||
def load(self, path):
|
def load(self, path):
|
||||||
"""加载模型"""
|
"""加载模型"""
|
||||||
checkpoint = torch.load(path, map_location=self.device)
|
checkpoint = torch.load(path, map_location=self.device)
|
||||||
self.q_network.load_state_dict(checkpoint['q_network'])
|
self.q_network.load_state_dict(checkpoint["q_network"])
|
||||||
self.target_network.load_state_dict(checkpoint['target_network'])
|
self.target_network.load_state_dict(checkpoint["target_network"])
|
||||||
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
self.optimizer.load_state_dict(checkpoint["optimizer"])
|
||||||
self.step_count = checkpoint['step_count']
|
self.step_count = checkpoint["step_count"]
|
||||||
self.epsilon = checkpoint['epsilon']
|
self.epsilon = checkpoint["epsilon"]
|
||||||
print(f"模型已从 {path} 加载")
|
print(f"模型已从 {path} 加载")
|
||||||
|
|||||||
@@ -80,14 +80,15 @@ class FrameStackWrapper(gym.ObservationWrapper):
|
|||||||
return np.stack(list(self.frames), axis=0)
|
return np.stack(list(self.frames), axis=0)
|
||||||
|
|
||||||
|
|
||||||
class RewardClipWrapper(gym.RewardWrapper):
|
class RewardScaleWrapper(gym.RewardWrapper):
|
||||||
"""裁剪奖励到[-1, 1]"""
|
"""缩放奖励以稳定训练,同时保留奖励大小信号"""
|
||||||
|
|
||||||
def __init__(self, env):
|
def __init__(self, env, scale=10.0):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
def reward(self, reward):
|
def reward(self, reward):
|
||||||
return np.clip(reward, -1, 1)
|
return reward / self.scale
|
||||||
|
|
||||||
|
|
||||||
class NoopResetWrapper(gym.Wrapper):
|
class NoopResetWrapper(gym.Wrapper):
|
||||||
@@ -174,7 +175,7 @@ def make_env(env_id="ALE/SpaceInvaders-v5", gray_scale=True, resize=True,
|
|||||||
env = GrayScaleWrapper(env)
|
env = GrayScaleWrapper(env)
|
||||||
|
|
||||||
if reward_clip:
|
if reward_clip:
|
||||||
env = RewardClipWrapper(env)
|
env = RewardScaleWrapper(env, scale=10.0)
|
||||||
|
|
||||||
if frame_stack > 1:
|
if frame_stack > 1:
|
||||||
env = FrameStackWrapper(env, num_stack=frame_stack)
|
env = FrameStackWrapper(env, num_stack=frame_stack)
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
"""Main training script for DQN on Space Invaders."""
|
"""Main training script for DQN on Space Invaders."""
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@@ -18,58 +20,61 @@ def main():
|
|||||||
parser = argparse.ArgumentParser(description="DQN for Space Invaders")
|
parser = argparse.ArgumentParser(description="DQN for Space Invaders")
|
||||||
|
|
||||||
# 环境参数
|
# 环境参数
|
||||||
parser.add_argument("--env", type=str, default="ALE/SpaceInvaders-v5",
|
parser.add_argument(
|
||||||
help="Atari环境ID")
|
"--env", type=str, default="ALE/SpaceInvaders-v5", help="Atari环境ID"
|
||||||
|
)
|
||||||
|
|
||||||
# 训练参数
|
# 训练参数
|
||||||
parser.add_argument("--steps", type=int, default=2_000_000,
|
parser.add_argument("--steps", type=int, default=2_000_000, help="总训练步数")
|
||||||
help="总训练步数")
|
parser.add_argument("--lr", type=float, default=1e-4, help="学习率")
|
||||||
parser.add_argument("--lr", type=float, default=1e-4,
|
parser.add_argument("--gamma", type=float, default=0.99, help="折扣因子")
|
||||||
help="学习率")
|
parser.add_argument("--batch-size", type=int, default=32, help="批次大小")
|
||||||
parser.add_argument("--gamma", type=float, default=0.99,
|
parser.add_argument(
|
||||||
help="折扣因子")
|
"--buffer-size", type=int, default=200_000, help="经验回放缓冲区大小"
|
||||||
parser.add_argument("--batch-size", type=int, default=32,
|
)
|
||||||
help="批次大小")
|
|
||||||
parser.add_argument("--buffer-size", type=int, default=100_000,
|
|
||||||
help="经验回放缓冲区大小")
|
|
||||||
|
|
||||||
# ε-greedy参数
|
# ε-greedy参数
|
||||||
parser.add_argument("--epsilon-start", type=float, default=1.0,
|
parser.add_argument("--epsilon-start", type=float, default=1.0, help="ε初始值")
|
||||||
help="ε初始值")
|
parser.add_argument("--epsilon-end", type=float, default=0.01, help="ε最终值")
|
||||||
parser.add_argument("--epsilon-end", type=float, default=0.01,
|
parser.add_argument(
|
||||||
help="ε最终值")
|
"--epsilon-decay", type=int, default=1_000_000, help="ε衰减步数"
|
||||||
parser.add_argument("--epsilon-decay", type=int, default=1_000_000,
|
)
|
||||||
help="ε衰减步数")
|
|
||||||
|
|
||||||
# 网络参数
|
# 网络参数
|
||||||
parser.add_argument("--target-update", type=int, default=1000,
|
parser.add_argument(
|
||||||
help="目标网络更新频率")
|
"--target-update", type=int, default=500, help="目标网络更新频率"
|
||||||
parser.add_argument("--double-dqn", action="store_true", default=True,
|
)
|
||||||
help="使用Double DQN")
|
parser.add_argument(
|
||||||
parser.add_argument("--dueling", action="store_true", default=False,
|
"--double-dqn", action="store_true", default=True, help="使用Double DQN"
|
||||||
help="使用Dueling DQN架构")
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dueling", action="store_true", default=True, help="使用Dueling DQN架构"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 学习率参数
|
||||||
|
parser.add_argument(
|
||||||
|
"--lr-decay-steps", type=int, default=1_000_000, help="学习率衰减步数"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lr-decay-factor", type=float, default=0.5, help="学习率衰减因子"
|
||||||
|
)
|
||||||
|
parser.add_argument("--warmup-steps", type=int, default=10_000, help="预热步数")
|
||||||
|
|
||||||
# 评估参数
|
# 评估参数
|
||||||
parser.add_argument("--eval-freq", type=int, default=10000,
|
parser.add_argument("--eval-freq", type=int, default=10000, help="评估频率")
|
||||||
help="评估频率")
|
parser.add_argument("--eval-episodes", type=int, default=10, help="评估episode数")
|
||||||
parser.add_argument("--eval-episodes", type=int, default=10,
|
parser.add_argument("--save-freq", type=int, default=50000, help="模型保存频率")
|
||||||
help="评估episode数")
|
parser.add_argument("--warmup", type=int, default=10000, help="预热步数")
|
||||||
parser.add_argument("--save-freq", type=int, default=50000,
|
|
||||||
help="模型保存频率")
|
|
||||||
parser.add_argument("--warmup", type=int, default=10000,
|
|
||||||
help="预热步数")
|
|
||||||
|
|
||||||
# 优先经验回放
|
# 优先经验回放
|
||||||
parser.add_argument("--prioritized", action="store_true", default=False,
|
parser.add_argument(
|
||||||
help="使用优先经验回放")
|
"--prioritized", action="store_true", default=False, help="使用优先经验回放"
|
||||||
|
)
|
||||||
|
|
||||||
# 其他
|
# 其他
|
||||||
parser.add_argument("--seed", type=int, default=42,
|
parser.add_argument("--seed", type=int, default=42, help="随机种子")
|
||||||
help="随机种子")
|
parser.add_argument("--save-dir", type=str, default="models", help="模型保存目录")
|
||||||
parser.add_argument("--save-dir", type=str, default="models",
|
parser.add_argument("--log-dir", type=str, default="logs", help="日志目录")
|
||||||
help="模型保存目录")
|
|
||||||
parser.add_argument("--log-dir", type=str, default="logs",
|
|
||||||
help="日志目录")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -110,14 +115,10 @@ def main():
|
|||||||
# 创建经验回放缓冲区
|
# 创建经验回放缓冲区
|
||||||
if args.prioritized:
|
if args.prioritized:
|
||||||
print("使用优先经验回放")
|
print("使用优先经验回放")
|
||||||
replay_buffer = PrioritizedReplayBuffer(
|
replay_buffer = PrioritizedReplayBuffer(args.buffer_size, state_shape, device)
|
||||||
args.buffer_size, state_shape, device
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
print("使用标准经验回放")
|
print("使用标准经验回放")
|
||||||
replay_buffer = ReplayBuffer(
|
replay_buffer = ReplayBuffer(args.buffer_size, state_shape, device)
|
||||||
args.buffer_size, state_shape, device
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建智能体
|
# 创建智能体
|
||||||
agent = DQNAgent(
|
agent = DQNAgent(
|
||||||
@@ -134,6 +135,9 @@ def main():
|
|||||||
target_update_freq=args.target_update,
|
target_update_freq=args.target_update,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
double_dqn=args.double_dqn,
|
double_dqn=args.double_dqn,
|
||||||
|
lr_decay_steps=args.lr_decay_steps,
|
||||||
|
lr_decay_factor=args.lr_decay_factor,
|
||||||
|
warmup_steps=args.warmup_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建训练器
|
# 创建训练器
|
||||||
@@ -146,20 +150,24 @@ def main():
|
|||||||
eval_freq=args.eval_freq,
|
eval_freq=args.eval_freq,
|
||||||
save_freq=args.save_freq,
|
save_freq=args.save_freq,
|
||||||
num_eval_episodes=args.eval_episodes,
|
num_eval_episodes=args.eval_episodes,
|
||||||
warmup_steps=args.warmup,
|
warmup_steps=args.warmup_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 打印配置
|
# 打印配置
|
||||||
print("\n训练配置:")
|
print("\n训练配置:")
|
||||||
print(f" 总步数: {args.steps:,}")
|
print(f" 总步数: {args.steps:,}")
|
||||||
print(f" 学习率: {args.lr}")
|
print(f" 学习率: {args.lr}")
|
||||||
|
print(f" 学习率衰减: 每{args.lr_decay_steps:,}步衰减{args.lr_decay_factor}倍")
|
||||||
|
print(f" Warmup步数: {args.warmup_steps:,}")
|
||||||
print(f" 折扣因子: {args.gamma}")
|
print(f" 折扣因子: {args.gamma}")
|
||||||
print(f" 批次大小: {args.batch_size}")
|
print(f" 批次大小: {args.batch_size}")
|
||||||
print(f" 缓冲区大小: {args.buffer_size:,}")
|
print(f" 缓冲区大小: {args.buffer_size:,}")
|
||||||
print(f" ε衰减: {args.epsilon_start} -> {args.epsilon_end} ({args.epsilon_decay:,}步)")
|
print(
|
||||||
|
f" ε衰减: {args.epsilon_start} -> {args.epsilon_end} ({args.epsilon_decay:,}步)"
|
||||||
|
)
|
||||||
print(f" 目标网络更新: 每{args.target_update}步")
|
print(f" 目标网络更新: 每{args.target_update}步")
|
||||||
print(f" Double DQN: {args.double_dqn}")
|
print(f" Double DQN: {args.double_dqn}")
|
||||||
print(f" 预热步数: {args.warmup:,}")
|
print(f" Dueling DQN: {args.dueling}")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
# 开始训练
|
# 开始训练
|
||||||
@@ -167,4 +175,4 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user