feat(训练): 添加并行环境DQN训练脚本和Jupyter笔记本
- 新增 train_parallel.py 脚本,使用 AsyncVectorEnv 并行运行多个Atari环境 - 添加配套的 Jupyter 笔记本 train_parallel.ipynb 用于交互式训练 - 在 utils.py 的 wrapper 中修复 observation_space 定义,确保与预处理后的观测形状一致 - 删除旧的压缩文件 CW2_DQN_SpaceInvaders.zip - 新增图片文件 image.png 并行训练器通过批量GPU推理和异步环境步进显著提升数据收集速度,适合在多核服务器环境下运行。包含完整的超参数配置、进度监控和模型保存功能。
This commit is contained in:
@@ -17,6 +17,10 @@ class GrayScaleWrapper(gym.ObservationWrapper):
|
||||
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
old_shape = self.observation_space.shape
|
||||
self.observation_space = gym.spaces.Box(
|
||||
low=0, high=255, shape=old_shape[:2], dtype=np.uint8
|
||||
)
|
||||
|
||||
def observation(self, obs):
|
||||
# RGB转灰度:加权平均
|
||||
@@ -30,6 +34,15 @@ class ResizeWrapper(gym.ObservationWrapper):
|
||||
def __init__(self, env, size=(84, 84)):
|
||||
super().__init__(env)
|
||||
self.size = size
|
||||
obs_shape = self.observation_space.shape
|
||||
if len(obs_shape) == 3:
|
||||
self.observation_space = gym.spaces.Box(
|
||||
low=0, high=255, shape=(*size, obs_shape[-1]), dtype=np.uint8
|
||||
)
|
||||
else:
|
||||
self.observation_space = gym.spaces.Box(
|
||||
low=0, high=255, shape=size, dtype=np.uint8
|
||||
)
|
||||
|
||||
def observation(self, obs):
|
||||
import cv2
|
||||
|
||||
Reference in New Issue
Block a user