feat(训练): 添加并行环境DQN训练脚本和Jupyter笔记本

- 新增 train_parallel.py 脚本,使用 AsyncVectorEnv 并行运行多个Atari环境
- 添加配套的 Jupyter 笔记本 train_parallel.ipynb 用于交互式训练
- 在 utils.py 的 wrapper 中修复 observation_space 定义,确保与预处理后的观测形状一致
- 删除旧的压缩文件 CW2_DQN_SpaceInvaders.zip
- 新增图片文件 image.png

并行训练器通过批量GPU推理和异步环境步进显著提升数据收集速度,适合在多核服务器环境下运行。包含完整的超参数配置、进度监控和模型保存功能。
This commit is contained in:
2026-05-03 16:29:14 +08:00
parent b474e7976e
commit ed0822966b
5 changed files with 962 additions and 0 deletions
@@ -17,6 +17,10 @@ class GrayScaleWrapper(gym.ObservationWrapper):
def __init__(self, env):
super().__init__(env)
old_shape = self.observation_space.shape
self.observation_space = gym.spaces.Box(
low=0, high=255, shape=old_shape[:2], dtype=np.uint8
)
def observation(self, obs):
# RGB转灰度:加权平均
@@ -30,6 +34,15 @@ class ResizeWrapper(gym.ObservationWrapper):
def __init__(self, env, size=(84, 84)):
super().__init__(env)
self.size = size
obs_shape = self.observation_space.shape
if len(obs_shape) == 3:
self.observation_space = gym.spaces.Box(
low=0, high=255, shape=(*size, obs_shape[-1]), dtype=np.uint8
)
else:
self.observation_space = gym.spaces.Box(
low=0, high=255, shape=size, dtype=np.uint8
)
def observation(self, obs):
import cv2