fb09e66d09
- Refactor the original single-environment training code into a modular structure and add vectorized-environment support to speed up data collection
- Implement a complete PPO training pipeline, including a shared-CNN Actor-Critic network, a vectorized rollout buffer, and GAE advantage estimation (a rough sketch of the GAE step appears below)
- Add a training script (train_vec.py), an evaluation script (evaluate.py), and an SB3 baseline comparison script (train_sb3_baseline.py)
- Provide detailed documentation and a development log, including problem-solving notes and experiment analysis
- Remove legacy project files and consolidate the project structure under the CW1_id_name directory
59 lines
1.8 KiB
Python
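Neither train_vec.py nor the rollout buffer is shown on this page. Purely as a hedged sketch of the GAE step the commit message mentions, the advantage computation over a vectorized rollout might look like the following; the function name, tensor layout, done-flag convention, and the gamma/gae_lambda defaults are illustrative assumptions, not taken from the repository:

import torch

def compute_gae(rewards, values, dones, last_value,
                gamma=0.99, gae_lambda=0.95):
    """Generalized Advantage Estimation over a (T, N) rollout.

    rewards, values, dones : (T, N) tensors from N vectorized envs;
    dones[t] == 1.0 means the episode ended on the transition at step t.
    last_value             : (N,) bootstrap value for the state after the rollout.
    """
    T = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    last_gae = torch.zeros_like(last_value)
    for t in reversed(range(T)):
        next_value = last_value if t == T - 1 else values[t + 1]
        next_nonterminal = 1.0 - dones[t]
        # TD residual: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), masked at episode ends
        delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
        # GAE recursion: A_t = delta_t + gamma * lambda * A_{t+1}
        last_gae = delta + gamma * gae_lambda * next_nonterminal * last_gae
        advantages[t] = last_gae
    returns = advantages + values  # regression targets for the critic
    return advantages, returns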
"""Shared-CNN Actor-Critic network for discrete CarRacing-v3 PPO.
|
|
|
|
Input : uint8 tensor (B, 4, 84, 84), values in [0, 255]
|
|
Output :
|
|
- logits (B, n_actions) for a Categorical policy
|
|
- value (B,) scalar state-value V(s)
|
|
"""
|
|
|
|
import math
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.distributions import Categorical
|
|
|
|
|
|
def layer_init(layer, std=math.sqrt(2), bias=0.0):
|
|
"""Orthogonal init with configurable gain (PPO best practice)."""
|
|
nn.init.orthogonal_(layer.weight, std)
|
|
nn.init.constant_(layer.bias, bias)
|
|
return layer
|
|
|
|
|
|
class ActorCritic(nn.Module):
|
|
"""Shared-CNN actor-critic for discrete visual control."""
|
|
|
|
def __init__(self, n_actions=5):
|
|
super().__init__()
|
|
self.cnn = nn.Sequential(
|
|
layer_init(nn.Conv2d(4, 32, kernel_size=8, stride=4)),
|
|
nn.ReLU(),
|
|
layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
|
|
nn.ReLU(),
|
|
layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
|
|
nn.ReLU(),
|
|
nn.Flatten(),
|
|
layer_init(nn.Linear(64 * 7 * 7, 512)),
|
|
nn.ReLU(),
|
|
)
|
|
# Small std on the actor head -> initial policy is nearly uniform
|
|
self.actor = layer_init(nn.Linear(512, n_actions), std=0.01)
|
|
# Standard std on the critic head
|
|
self.critic = layer_init(nn.Linear(512, 1), std=1.0)
|
|
|
|
def forward(self, x):
|
|
# uint8 [0, 255] -> float32 [0, 1]
|
|
x = x.float() / 255.0
|
|
feat = self.cnn(x)
|
|
logits = self.actor(feat)
|
|
value = self.critic(feat).squeeze(-1)
|
|
return logits, value
|
|
|
|
def get_action_and_value(self, x, action=None):
|
|
logits, value = self(x)
|
|
dist = Categorical(logits=logits)
|
|
if action is None:
|
|
action = dist.sample()
|
|
log_prob = dist.log_prob(action)
|
|
entropy = dist.entropy()
|
|
return action, log_prob, entropy, value |
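As a quick smoke test (not part of the repository file), the network can be exercised with a random uint8 batch shaped as the docstring describes; the near-uniform initial policy produced by the std=0.01 actor head is easy to verify:

if __name__ == "__main__":
    model = ActorCritic(n_actions=5)  # 5 discrete CarRacing-v3 actions
    obs = torch.randint(0, 256, (8, 4, 84, 84), dtype=torch.uint8)  # fake frame stacks
    action, log_prob, entropy, value = model.get_action_and_value(obs)
    print(action.shape, log_prob.shape, entropy.shape, value.shape)
    # each is torch.Size([8]); entropy starts near log(5) ~ 1.609
    logits, _ = model(obs)
    print(logits.softmax(-1)[0])  # roughly uniform (~0.2 per action) at init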