rl-atari/强化学习个人项目报告(Atari 游戏方向)/src/utils.py
Serendipity ed0822966b feat(training): add parallel-environment DQN training script and Jupyter notebook
- Add train_parallel.py, which uses AsyncVectorEnv to run multiple Atari environments in parallel
- Add the companion Jupyter notebook train_parallel.ipynb for interactive training
- Fix the observation_space definitions in the utils.py wrappers so they match the preprocessed observation shapes
- Remove the old archive CW2_DQN_SpaceInvaders.zip
- Add image file image.png

The parallel trainer speeds up data collection significantly through batched GPU inference and asynchronous environment stepping, and is suited to multi-core server machines. It includes a complete hyperparameter configuration, progress monitoring, and model checkpointing.
2026-05-03 16:29:14 +08:00
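
train_parallel.py itself is not shown on this page, but a minimal sketch of the AsyncVectorEnv pattern the commit describes, assuming the make_env factory defined in utils.py below and a hypothetical worker count NUM_ENVS, looks like this:

import gymnasium as gym
from utils import make_env

NUM_ENVS = 8  # hypothetical worker count

# Each callable builds one fully preprocessed environment; AsyncVectorEnv
# steps all of them in parallel worker processes.
envs = gym.vector.AsyncVectorEnv([lambda: make_env() for _ in range(NUM_ENVS)])
obs, infos = envs.reset()
# obs has shape (NUM_ENVS, 4, 84, 84): one frame stack per worker, ready for
# a single batched forward pass through the Q-network on the GPU.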


"""Environment wrappers and utility functions."""
import gymnasium as gym
import numpy as np
import torch
from collections import deque
# 注册ALE环境
try:
import ale_py
gym.register_envs(ale_py)
except ImportError:
pass
class GrayScaleWrapper(gym.ObservationWrapper):
    """Convert RGB observations to grayscale."""

    def __init__(self, env):
        super().__init__(env)
        old_shape = self.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=old_shape[:2], dtype=np.uint8
        )

    def observation(self, obs):
        # RGB to grayscale via the standard ITU-R BT.601 luma weights
        gray = 0.299 * obs[:, :, 0] + 0.587 * obs[:, :, 1] + 0.114 * obs[:, :, 2]
        return gray.astype(np.uint8)
class ResizeWrapper(gym.ObservationWrapper):
    """Resize observations to a fixed spatial size."""

    def __init__(self, env, size=(84, 84)):
        super().__init__(env)
        self.size = size
        obs_shape = self.observation_space.shape
        if len(obs_shape) == 3:
            self.observation_space = gym.spaces.Box(
                low=0, high=255, shape=(*size, obs_shape[-1]), dtype=np.uint8
            )
        else:
            self.observation_space = gym.spaces.Box(
                low=0, high=255, shape=size, dtype=np.uint8
            )

    def observation(self, obs):
        # Imported lazily so OpenCV is only required when resizing is enabled
        import cv2
        # cv2.resize handles both grayscale (H, W) and RGB (H, W, C) inputs
        return cv2.resize(obs, self.size, interpolation=cv2.INTER_AREA)
class FrameStackWrapper(gym.ObservationWrapper):
    """Stack the N most recent observations along a new leading axis."""

    def __init__(self, env, num_stack=4):
        super().__init__(env)
        self.num_stack = num_stack
        self.frames = deque(maxlen=num_stack)
        obs_shape = env.observation_space.shape
        # The stacked observation gains a leading axis of length num_stack;
        # this holds for both grayscale (H, W) and RGB (H, W, C) inputs
        self.observation_space = gym.spaces.Box(
            low=0, high=255,
            shape=(num_stack, *obs_shape),
            dtype=np.uint8
        )

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        # Fill the buffer with copies of the first frame
        for _ in range(self.num_stack):
            self.frames.append(obs)
        return self._get_observation(), info

    def observation(self, obs):
        self.frames.append(obs)
        return self._get_observation()

    def _get_observation(self):
        return np.stack(list(self.frames), axis=0)
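
# Illustrative note: with 84x84 grayscale inputs and num_stack=4, this wrapper
# emits (4, 84, 84) arrays -- already the (C, H, W) layout a PyTorch CNN
# expects, so no transpose is needed before torch.from_numpy.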
class RewardScaleWrapper(gym.RewardWrapper):
    """Scale rewards to stabilize training while preserving their relative magnitudes."""

    def __init__(self, env, scale=10.0):
        super().__init__(env)
        self.scale = scale

    def reward(self, reward):
        return reward / self.scale
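
# Illustrative note: with scale=10.0, a raw Space Invaders reward of 30 becomes
# 3.0, so unlike hard clipping to [-1, 1] the relative reward sizes survive.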
class NoopResetWrapper(gym.Wrapper):
    """Take a random number of no-op actions on reset to randomize the initial state."""

    def __init__(self, env, noop_max=30):
        super().__init__(env)
        self.noop_max = noop_max
        self.noop_action = 0

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        # Step with the no-op action a random number of times
        noop_times = np.random.randint(1, self.noop_max + 1)
        for _ in range(noop_times):
            obs, reward, terminated, truncated, info = self.env.step(self.noop_action)
            if terminated or truncated:
                obs, info = self.env.reset(**kwargs)
        return obs, info
class MaxAndSkipWrapper(gym.Wrapper):
    """Repeat each action for `skip` frames and return the pixel-wise max of
    the last two, which cuts computation and removes Atari sprite flicker."""

    def __init__(self, env, skip=4):
        super().__init__(env)
        self.skip = skip
        self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)

    def step(self, action):
        total_reward = 0.0
        terminated = False
        truncated = False
        for i in range(self.skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            # Keep only the last two frames for the max operation
            if i == self.skip - 2:
                self._obs_buffer[0] = obs
            if i == self.skip - 1:
                self._obs_buffer[1] = obs
            if terminated or truncated:
                break
        # Pixel-wise max over the two most recent frames
        max_frame = self._obs_buffer.max(axis=0)
        return max_frame, total_reward, terminated, truncated, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
def make_env(env_id="ALE/SpaceInvaders-v5", gray_scale=True, resize=True,
             frame_stack=4, reward_clip=True, noop_reset=True, max_skip=4):
    """Create a preprocessed Atari environment.

    Args:
        env_id: environment ID
        gray_scale: whether to convert observations to grayscale
        resize: whether to resize observations to 84x84
        frame_stack: number of frames to stack
        reward_clip: whether to scale rewards (via RewardScaleWrapper)
        noop_reset: whether to apply random no-op resets
        max_skip: number of frames to skip per action

    Returns:
        env: the wrapped environment
    """
    env = gym.make(env_id, render_mode="rgb_array")
    if noop_reset:
        env = NoopResetWrapper(env, noop_max=30)
    if max_skip > 1:
        env = MaxAndSkipWrapper(env, skip=max_skip)
    if resize:
        env = ResizeWrapper(env, size=(84, 84))
    if gray_scale:
        env = GrayScaleWrapper(env)
    if reward_clip:
        env = RewardScaleWrapper(env, scale=10.0)
    if frame_stack > 1:
        env = FrameStackWrapper(env, num_stack=frame_stack)
    return env
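
# Illustrative usage (assumes a working ale_py install):
#   env = make_env("ALE/SpaceInvaders-v5")
#   obs, info = env.reset()
#   obs.shape == env.observation_space.shape == (4, 84, 84)  # uint8 frames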
def get_device():
    """Detect and return the available torch device."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        device = torch.device("cpu")
        print("Using CPU")
    return device
def preprocess_obs(obs):
    """Ensure the observation has a leading channel axis."""
    if len(obs.shape) == 2:
        obs = np.expand_dims(obs, axis=0)
    return obs
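
if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training pipeline): build
    # the default pipeline and verify that the declared observation_space
    # matches what the wrappers actually produce -- the invariant that the
    # observation_space fix in this commit is about.
    env = make_env()
    obs, info = env.reset()
    assert obs.shape == env.observation_space.shape, (
        f"shape mismatch: {obs.shape} vs {env.observation_space.shape}"
    )
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    assert obs.shape == env.observation_space.shape
    print("make_env OK:", obs.shape, obs.dtype)
    env.close()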