ed0822966b
- 新增 train_parallel.py 脚本,使用 AsyncVectorEnv 并行运行多个Atari环境 - 添加配套的 Jupyter 笔记本 train_parallel.ipynb 用于交互式训练 - 在 utils.py 的 wrapper 中修复 observation_space 定义,确保与预处理后的观测形状一致 - 删除旧的压缩文件 CW2_DQN_SpaceInvaders.zip - 新增图片文件 image.png 并行训练器通过批量GPU推理和异步环境步进显著提升数据收集速度,适合在多核服务器环境下运行。包含完整的超参数配置、进度监控和模型保存功能。
215 lines
6.0 KiB
Python
"""Environment wrappers and utility functions."""
|
|
import gymnasium as gym
|
|
import numpy as np
|
|
import torch
|
|
from collections import deque
|
|
|
|
# Register the ALE (Arcade Learning Environment) Atari environments with
# gymnasium so ids like "ALE/SpaceInvaders-v5" (used by make_env below)
# resolve; silently skipped when ale_py is not installed.
try:
    import ale_py
    gym.register_envs(ale_py)
except ImportError:
    pass
|
|
|
|
|
|
class GrayScaleWrapper(gym.ObservationWrapper):
    """Convert RGB observations to grayscale using BT.601 luma weights."""

    def __init__(self, env):
        super().__init__(env)
        # Drop the trailing channel axis: the wrapped space is 2-D uint8.
        spatial_shape = self.observation_space.shape[:2]
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=spatial_shape, dtype=np.uint8
        )

    def observation(self, obs):
        """Return the luma-weighted grayscale version of an RGB frame."""
        r = obs[:, :, 0]
        g = obs[:, :, 1]
        b = obs[:, :, 2]
        luma = 0.299 * r + 0.587 * g + 0.114 * b
        # Truncate back to uint8 to match the declared observation space.
        return luma.astype(np.uint8)
|
|
|
|
|
|
class ResizeWrapper(gym.ObservationWrapper):
    """Resize observations to a fixed spatial size.

    NOTE(review): ``cv2.resize`` interprets its dsize argument as
    (width, height), while the Box shape below treats ``size`` as
    (height, width). With the square default (84, 84) the two agree, but
    a non-square ``size`` would produce a transposed result — confirm
    intent before passing a rectangular size.
    """

    def __init__(self, env, size=(84, 84)):
        """
        Args:
            env: environment to wrap.
            size: target spatial size, default (84, 84).
        """
        super().__init__(env)
        self.size = size
        obs_shape = self.observation_space.shape
        if len(obs_shape) == 3:
            # Color input: preserve the trailing channel dimension.
            new_shape = (*size, obs_shape[-1])
        else:
            # Grayscale input: spatial dimensions only.
            new_shape = size
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=new_shape, dtype=np.uint8
        )

    def observation(self, obs):
        """Return the observation resized to ``self.size``.

        cv2 resizes 2-D (grayscale) and 3-D (color) arrays uniformly, so no
        branching on rank is needed — the previous implementation had two
        byte-identical branches with a comment claiming the grayscale path
        expanded dimensions, which it never did.
        """
        # Local import keeps cv2 an optional dependency of this module.
        import cv2
        return cv2.resize(obs, self.size, interpolation=cv2.INTER_AREA)
|
|
|
|
|
|
class FrameStackWrapper(gym.ObservationWrapper):
    """Stack the last ``num_stack`` observations along a new leading axis."""

    def __init__(self, env, num_stack=4):
        """
        Args:
            env: environment to wrap.
            num_stack: number of most-recent frames to stack.
        """
        super().__init__(env)
        self.num_stack = num_stack
        # Rolling buffer of the most recent frames; old frames fall off.
        self.frames = deque(maxlen=num_stack)
        obs_shape = env.observation_space.shape

        # np.stack(..., axis=0) prepends the stack axis regardless of frame
        # rank, so the stacked observation always has shape
        # (num_stack, *obs_shape). The previous RGB branch advertised
        # (num_stack, *obs_shape[-2:]) — i.e. (N, W, C) for an (H, W, C)
        # frame — which did not match the arrays actually returned.
        self.observation_space = gym.spaces.Box(
            low=0, high=255,
            shape=(num_stack, *obs_shape),
            dtype=np.uint8
        )

    def reset(self, **kwargs):
        """Reset the env and prefill the buffer by repeating the first frame."""
        obs, info = self.env.reset(**kwargs)
        for _ in range(self.num_stack):
            self.frames.append(obs)
        return self._get_observation(), info

    def observation(self, obs):
        """Push the newest frame and return the stacked observation."""
        self.frames.append(obs)
        return self._get_observation()

    def _get_observation(self):
        # deque -> (num_stack, *frame_shape) array, oldest frame first.
        return np.stack(list(self.frames), axis=0)
|
|
|
|
|
|
class RewardScaleWrapper(gym.RewardWrapper):
    """Divide every reward by a constant factor.

    Unlike sign clipping, this keeps relative reward magnitudes while
    shrinking values into a range that is easier to fit during training.
    """

    def __init__(self, env, scale=10.0):
        super().__init__(env)
        # Divisor applied to every raw reward.
        self.scale = scale

    def reward(self, reward):
        """Return the raw reward scaled down by ``self.scale``."""
        scaled = reward / self.scale
        return scaled
|
|
|
|
|
|
class NoopResetWrapper(gym.Wrapper):
    """Randomize the initial state by taking 1..noop_max no-op steps on reset."""

    def __init__(self, env, noop_max=30):
        super().__init__(env)
        self.noop_max = noop_max
        # Action index 0 is used as the no-op.
        self.noop_action = 0

    def reset(self, **kwargs):
        """Reset, then step the no-op action a random number of times."""
        obs, info = self.env.reset(**kwargs)
        n_noops = np.random.randint(1, self.noop_max + 1)
        for _ in range(n_noops):
            obs, reward, terminated, truncated, info = self.env.step(self.noop_action)
            # If the episode somehow ends during the no-ops, start over.
            if terminated or truncated:
                obs, info = self.env.reset(**kwargs)
        return obs, info
|
|
|
|
|
|
class MaxAndSkipWrapper(gym.Wrapper):
    """Repeat each action ``skip`` times and return the pixel-wise max of the
    last two frames observed.

    The max suppresses Atari sprite flicker (objects drawn only on alternate
    frames); the skip cuts per-agent-step cost by a factor of ``skip``.
    """

    def __init__(self, env, skip=4):
        """
        Args:
            env: environment to wrap.
            skip: number of times to repeat each action.
        """
        super().__init__(env)
        self.skip = skip

    def step(self, action):
        """Step the wrapped env ``skip`` times with the same action.

        Returns the element-wise max of the last two frames actually seen
        (or the single frame when only one step happened), the summed reward,
        and the final terminated/truncated/info values.

        Fixes over the previous revision:
        - The old fixed two-slot buffer was only written at steps
          ``skip - 2`` and ``skip - 1``, so an early termination (or
          ``skip == 1``) left stale frames from a *previous* call in the
          buffer and mixed them into the returned max. The last two frames
          are now tracked locally.
        - ``info`` is initialized so the method cannot raise NameError when
          ``skip`` is 0.
        """
        total_reward = 0.0
        terminated = False
        truncated = False
        info = {}
        prev_frame = None
        last_frame = None

        for _ in range(self.skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            prev_frame, last_frame = last_frame, obs
            if terminated or truncated:
                break

        if prev_frame is None:
            # Only one frame observed: nothing to max against.
            max_frame = last_frame
        else:
            max_frame = np.maximum(prev_frame, last_frame)

        return max_frame, total_reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Pass reset through to the wrapped environment."""
        return self.env.reset(**kwargs)
|
|
|
|
|
|
def make_env(env_id="ALE/SpaceInvaders-v5", gray_scale=True, resize=True,
             frame_stack=4, reward_clip=True, noop_reset=True, max_skip=4):
    """Build a preprocessed Atari environment.

    Args:
        env_id: Gymnasium environment id.
        gray_scale: convert observations to grayscale.
        resize: resize observations to 84x84.
        frame_stack: number of frames to stack (<=1 disables stacking).
        reward_clip: scale rewards down (via RewardScaleWrapper).
        noop_reset: take random no-op steps on reset.
        max_skip: action-repeat count (<=1 disables frame skipping).

    Returns:
        The fully wrapped environment.
    """
    env = gym.make(env_id, render_mode="rgb_array")

    # Wrappers are applied innermost-first; order matters (skipping operates
    # on raw frames, stacking runs last on the fully processed frames).
    pipeline = (
        (noop_reset, lambda e: NoopResetWrapper(e, noop_max=30)),
        (max_skip > 1, lambda e: MaxAndSkipWrapper(e, skip=max_skip)),
        (resize, lambda e: ResizeWrapper(e, size=(84, 84))),
        (gray_scale, GrayScaleWrapper),
        (reward_clip, lambda e: RewardScaleWrapper(e, scale=10.0)),
        (frame_stack > 1, lambda e: FrameStackWrapper(e, num_stack=frame_stack)),
    )
    for enabled, wrap in pipeline:
        if enabled:
            env = wrap(env)
    return env
|
|
|
|
|
|
def get_device():
    """Return the best available torch device (CUDA if present, else CPU)."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"使用GPU: {torch.cuda.get_device_name(0)}")
        return device
    print("使用CPU")
    return torch.device("cpu")
|
|
|
|
|
|
def preprocess_obs(obs):
    """Ensure the observation has a leading channel axis.

    A 2-D (grayscale) frame gains a size-1 leading axis; anything else is
    returned untouched.
    """
    if obs.ndim == 2:
        return obs[np.newaxis, ...]
    return obs
|