feat: 添加DQN强化学习项目框架和核心实现

实现完整的DQN算法框架,用于Atari Space Invaders游戏训练。包括:
- QNetwork和DuelingQNetwork神经网络架构
- 经验回放缓冲区(标准和优先级版本)
- DQN智能体实现ε-greedy策略和Double DQN
- 环境包装器(灰度化、调整大小、帧堆叠等)
- 训练器、评估脚本和图表生成工具
- 详细的项目文档和依赖配置
This commit is contained in:
2026-05-01 10:01:12 +08:00
parent cdec40a7c7
commit e8b51240f9
13 changed files with 1561 additions and 84 deletions
@@ -0,0 +1,200 @@
"""Environment wrappers and utility functions."""
import gymnasium as gym
import numpy as np
import torch
from collections import deque
# 注册ALE环境
try:
import ale_py
gym.register_envs(ale_py)
except ImportError:
pass
class GrayScaleWrapper(gym.ObservationWrapper):
"""将RGB观测转换为灰度图"""
def __init__(self, env):
super().__init__(env)
def observation(self, obs):
# RGB转灰度:加权平均
gray = 0.299 * obs[:, :, 0] + 0.587 * obs[:, :, 1] + 0.114 * obs[:, :, 2]
return gray.astype(np.uint8)
class ResizeWrapper(gym.ObservationWrapper):
"""调整观测大小"""
def __init__(self, env, size=(84, 84)):
super().__init__(env)
self.size = size
def observation(self, obs):
import cv2
# 如果是灰度图,需要扩展维度
if len(obs.shape) == 2:
obs = cv2.resize(obs, self.size, interpolation=cv2.INTER_AREA)
else:
obs = cv2.resize(obs, self.size, interpolation=cv2.INTER_AREA)
return obs
class FrameStackWrapper(gym.ObservationWrapper):
"""堆叠N帧观测"""
def __init__(self, env, num_stack=4):
super().__init__(env)
self.num_stack = num_stack
self.frames = deque(maxlen=num_stack)
obs_shape = env.observation_space.shape
# 更新观测空间
if len(obs_shape) == 2:
# 灰度图
self.observation_space = gym.spaces.Box(
low=0, high=255,
shape=(num_stack, *obs_shape),
dtype=np.uint8
)
else:
# RGB图
self.observation_space = gym.spaces.Box(
low=0, high=255,
shape=(num_stack, *obs_shape[-2:]),
dtype=np.uint8
)
def reset(self, **kwargs):
obs, info = self.env.reset(**kwargs)
for _ in range(self.num_stack):
self.frames.append(obs)
return self._get_observation(), info
def observation(self, obs):
self.frames.append(obs)
return self._get_observation()
def _get_observation(self):
return np.stack(list(self.frames), axis=0)
class RewardClipWrapper(gym.RewardWrapper):
"""裁剪奖励到[-1, 1]"""
def __init__(self, env):
super().__init__(env)
def reward(self, reward):
return np.clip(reward, -1, 1)
class NoopResetWrapper(gym.Wrapper):
"""在reset时随机执行noop动作,增加初始状态随机性"""
def __init__(self, env, noop_max=30):
super().__init__(env)
self.noop_max = noop_max
self.noop_action = 0
def reset(self, **kwargs):
obs, info = self.env.reset(**kwargs)
# 随机执行noop动作
noop_times = np.random.randint(1, self.noop_max + 1)
for _ in range(noop_times):
obs, reward, terminated, truncated, info = self.env.step(self.noop_action)
if terminated or truncated:
obs, info = self.env.reset(**kwargs)
return obs, info
class MaxAndSkipWrapper(gym.Wrapper):
"""跳帧并取最大值,减少计算量"""
def __init__(self, env, skip=4):
super().__init__(env)
self.skip = skip
self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)
def step(self, action):
total_reward = 0.0
terminated = False
truncated = False
for i in range(self.skip):
obs, reward, terminated, truncated, info = self.env.step(action)
total_reward += reward
if i == self.skip - 2:
self._obs_buffer[0] = obs
if i == self.skip - 1:
self._obs_buffer[1] = obs
if terminated or truncated:
break
# 取最近两帧的最大值
max_frame = self._obs_buffer.max(axis=0)
return max_frame, total_reward, terminated, truncated, info
def reset(self, **kwargs):
return self.env.reset(**kwargs)
def make_env(env_id="ALE/SpaceInvaders-v5", gray_scale=True, resize=True,
frame_stack=4, reward_clip=True, noop_reset=True, max_skip=4):
"""创建预处理后的Atari环境
Args:
env_id: 环境ID
gray_scale: 是否灰度化
resize: 是否调整大小
frame_stack: 堆叠帧数
reward_clip: 是否裁剪奖励
noop_reset: 是否使用noop reset
max_skip: 跳帧数
Returns:
env: 预处理后的环境
"""
env = gym.make(env_id, render_mode="rgb_array")
if noop_reset:
env = NoopResetWrapper(env, noop_max=30)
if max_skip > 1:
env = MaxAndSkipWrapper(env, skip=max_skip)
if resize:
env = ResizeWrapper(env, size=(84, 84))
if gray_scale:
env = GrayScaleWrapper(env)
if reward_clip:
env = RewardClipWrapper(env)
if frame_stack > 1:
env = FrameStackWrapper(env, num_stack=frame_stack)
return env
def get_device():
"""检测并返回可用设备"""
if torch.cuda.is_available():
device = torch.device("cuda")
print(f"使用GPU: {torch.cuda.get_device_name(0)}")
else:
device = torch.device("cpu")
print("使用CPU")
return device
def preprocess_obs(obs):
"""确保观测格式正确"""
if len(obs.shape) == 2:
obs = np.expand_dims(obs, axis=0)
return obs