"""Environment wrappers and utility functions."""
import gymnasium as gym
import numpy as np
import torch
from collections import deque

# Register the ALE (Atari) environments with gymnasium when ale_py is installed.
# Gymnasium releases without `register_envs` raise AttributeError here; that is
# tolerated so importing this module does not crash.
try:
    import ale_py

    gym.register_envs(ale_py)
except (ImportError, AttributeError):
    pass


class GrayScaleWrapper(gym.ObservationWrapper):
    """Convert channel-last RGB observations (H, W, C>=3) to grayscale (H, W) uint8."""

    def __init__(self, env):
        super().__init__(env)
        old_shape = self.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=old_shape[:2], dtype=np.uint8
        )

    def observation(self, obs):
        # Weighted sum of the R, G, B channels; the float result is truncated
        # (not rounded) when cast back to uint8.
        gray = 0.299 * obs[:, :, 0] + 0.587 * obs[:, :, 1] + 0.114 * obs[:, :, 2]
        return gray.astype(np.uint8)


class ResizeWrapper(gym.ObservationWrapper):
    """Resize image observations to ``size`` = (height, width) with area interpolation.

    Accepts 2-D (H, W) and channel-last 3-D (H, W, C) observations; the channel
    count of 3-D observations is preserved.
    """

    def __init__(self, env, size=(84, 84)):
        super().__init__(env)
        self.size = tuple(size)
        obs_shape = self.observation_space.shape
        if len(obs_shape) == 3:
            new_shape = (*self.size, obs_shape[-1])
        else:
            new_shape = self.size
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=new_shape, dtype=np.uint8
        )

    def observation(self, obs):
        # cv2 is imported here rather than at module level, so it is only needed
        # once a ResizeWrapper actually processes an observation.
        import cv2

        height, width = self.size
        # cv2.resize expects dsize as (width, height): the reverse of the
        # (height, width) order used for observation_space above.
        resized = cv2.resize(obs, (width, height), interpolation=cv2.INTER_AREA)
        # cv2 drops a trailing singleton channel, (H, W, 1) -> (h, w); restore it
        # so the output matches observation_space.
        if obs.ndim == 3 and resized.ndim == 2:
            resized = resized[:, :, np.newaxis]
        return resized


class FrameStackWrapper(gym.ObservationWrapper):
    """Stack the ``num_stack`` most recent observations along a new leading axis.

    Output shape is ``(num_stack, *obs_shape)``. On reset, the first observation
    is repeated to fill the stack.
    """

    def __init__(self, env, num_stack=4):
        super().__init__(env)
        self.num_stack = num_stack
        self.frames = deque(maxlen=num_stack)
        obs_shape = env.observation_space.shape
        # np.stack(axis=0) prepends the stack axis and keeps every frame axis,
        # for grayscale (H, W) and channel-last (H, W, C) frames alike.
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(num_stack, *obs_shape), dtype=np.uint8
        )

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        for _ in range(self.num_stack):
            self.frames.append(obs)
        return self._get_observation(), info

    def observation(self, obs):
        self.frames.append(obs)
        return self._get_observation()

    def _get_observation(self):
        return np.stack(list(self.frames), axis=0)


class RewardScaleWrapper(gym.RewardWrapper):
    """Divide every reward by ``scale``.

    Unlike sign clipping, this keeps the relative magnitude of rewards.
    """

    def __init__(self, env, scale=10.0):
        super().__init__(env)
        if scale == 0:
            raise ValueError("scale must be non-zero")
        self.scale = scale

    def reward(self, reward):
        return reward / self.scale


class NoopResetWrapper(gym.Wrapper):
    """On reset, take 1..noop_max no-op actions to randomize the start state."""

    def __init__(self, env, noop_max=30):
        super().__init__(env)
        self.noop_max = noop_max
        # Action index used as the no-op; assumes index 0 means NOOP in the
        # wrapped env's action set -- TODO confirm for non-ALE envs.
        self.noop_action = 0

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        if self.noop_max < 1:
            return obs, info
        # Draw from the env's own RNG (seeded through reset(seed=...)) rather
        # than the global NumPy RNG, so seeded runs are reproducible.
        noop_times = int(self.unwrapped.np_random.integers(1, self.noop_max + 1))
        for _ in range(noop_times):
            obs, _reward, terminated, truncated, info = self.env.step(self.noop_action)
            if terminated or truncated:
                obs, info = self.env.reset(**kwargs)
        return obs, info


class MaxAndSkipWrapper(gym.Wrapper):
    """Repeat each action ``skip`` times, sum the rewards, and return the
    element-wise max of the last two frames produced during the step."""

    def __init__(self, env, skip=4):
        super().__init__(env)
        if skip < 1:
            raise ValueError(f"skip must be >= 1, got {skip}")
        self.skip = skip
        self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)

    def step(self, action):
        total_reward = 0.0
        terminated = False
        truncated = False
        info = {}
        # Only the two most recent frames are needed for the max.
        recent = deque(maxlen=2)
        for _ in range(self.skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            recent.append(obs)
            if terminated or truncated:
                break
        # Pool only frames produced during *this* step. If the episode ended on
        # the first inner step, the single frame is pooled with itself instead
        # of with stale frames left over from an earlier step.
        self._obs_buffer[0] = recent[0]
        self._obs_buffer[1] = recent[-1]
        max_frame = self._obs_buffer.max(axis=0)
        return max_frame, total_reward, terminated, truncated, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)


def make_env(env_id="ALE/SpaceInvaders-v5", gray_scale=True, resize=True,
             frame_stack=4, reward_clip=True, noop_reset=True, max_skip=4):
    """Create a preprocessed Atari environment.

    Wrappers are applied in this order, each only when enabled: noop reset,
    max-and-skip, resize to 84x84, grayscale, reward scaling, frame stacking.

    Args:
        env_id: Gymnasium environment ID.
        gray_scale: Convert observations to grayscale.
        resize: Resize observations to 84x84.
        frame_stack: Number of frames to stack (stacking applied only if > 1).
        reward_clip: Divide rewards by 10 via RewardScaleWrapper. Note that this
            scales rewards; it does not clip them.
        noop_reset: Take 1..30 random no-op actions on reset.
        max_skip: Action-repeat count for MaxAndSkipWrapper (applied only if > 1).

    Returns:
        The wrapped environment.
    """
    make_kwargs = {"render_mode": "rgb_array"}
    if max_skip > 1 and env_id.startswith("ALE/"):
        # ALE v5 environments already repeat each action internally
        # (frameskip=4 by default). MaxAndSkipWrapper does the skipping here, so
        # disable the built-in skip. Otherwise each agent action would span
        # 4 * max_skip emulator frames, and the max-pooled frames would be
        # several emulator frames apart.
        make_kwargs["frameskip"] = 1
    env = gym.make(env_id, **make_kwargs)

    if noop_reset:
        env = NoopResetWrapper(env, noop_max=30)
    if max_skip > 1:
        env = MaxAndSkipWrapper(env, skip=max_skip)
    if resize:
        env = ResizeWrapper(env, size=(84, 84))
    if gray_scale:
        env = GrayScaleWrapper(env)
    if reward_clip:
        env = RewardScaleWrapper(env, scale=10.0)
    if frame_stack > 1:
        env = FrameStackWrapper(env, num_stack=frame_stack)
    return env


def get_device():
    """Return the CUDA device if available, else the CPU, printing which is used."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"使用GPU: {torch.cuda.get_device_name(0)}")
    else:
        device = torch.device("cpu")
        print("使用CPU")
    return device


def preprocess_obs(obs):
    """Add a leading axis to 2-D observations; other shapes are returned unchanged."""
    if len(obs.shape) == 2:
        obs = np.expand_dims(obs, axis=0)
    return obs