fix(ppo): 修正对数概率维度与状态张量格式
修复 rollout buffer 中 log_probs 的维度错误，从 (buffer_size, action_dim) 改为 (buffer_size,)；修正训练时状态张量格式，从 (N, H, W, C) 转换为 (N, C, H, W)；更新 collect_rollout 返回观测值并修正 log_prob 计算；添加项目配置文件和训练曲线生成脚本。
This commit is contained in:
@@ -1,4 +1,8 @@
|
||||
"""Evaluation script for trained PPO agent."""
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import gymnasium as gym
|
||||
|
||||
+16
-10
@@ -1,4 +1,5 @@
|
||||
"""Neural network architectures for Actor and Critic."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -9,7 +10,11 @@ class Actor(nn.Module):
|
||||
|
||||
def __init__(self, state_shape=(84, 84, 4), action_dim=3):
|
||||
super().__init__()
|
||||
c, h, w = state_shape[2], state_shape[0], state_shape[1] # channels, height, width
|
||||
c, h, w = (
|
||||
state_shape[2],
|
||||
state_shape[0],
|
||||
state_shape[1],
|
||||
) # channels, height, width
|
||||
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(c, 32, kernel_size=8, stride=4),
|
||||
@@ -20,8 +25,10 @@ class Actor(nn.Module):
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
# Calculate feature map size: 84x84 -> 20x20 after conv layers
|
||||
feat_size = 64 * 20 * 20
|
||||
out_h = (h - 8) // 4 + 1
|
||||
out_h = (out_h - 4) // 2 + 1
|
||||
out_h = (out_h - 3) // 1 + 1
|
||||
feat_size = 64 * out_h * out_h
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(feat_size, 512),
|
||||
@@ -62,17 +69,16 @@ class Critic(nn.Module):
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
feat_size = 64 * 20 * 20
|
||||
out_h = (h - 8) // 4 + 1
|
||||
out_h = (out_h - 4) // 2 + 1
|
||||
out_h = (out_h - 3) // 1 + 1
|
||||
feat_size = 64 * out_h * out_h
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(feat_size, 512),
|
||||
nn.ReLU(),
|
||||
nn.Linear(512, 1)
|
||||
)
|
||||
self.fc = nn.Sequential(nn.Linear(feat_size, 512), nn.ReLU(), nn.Linear(512, 1))
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass returning V(s)."""
|
||||
x = x / 255.0
|
||||
x = self.conv(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
return self.fc(x)
|
||||
return self.fc(x)
|
||||
|
||||
@@ -15,7 +15,7 @@ class RolloutBuffer:
|
||||
self.rewards = np.zeros(buffer_size, dtype=np.float32)
|
||||
self.dones = np.zeros(buffer_size, dtype=np.bool_)
|
||||
self.values = np.zeros(buffer_size, dtype=np.float32)
|
||||
self.log_probs = np.zeros((buffer_size, action_dim), dtype=np.float32)
|
||||
self.log_probs = np.zeros(buffer_size, dtype=np.float32)
|
||||
|
||||
def add(self, state, action, reward, done, value, log_prob):
|
||||
"""Add a transition to the buffer."""
|
||||
|
||||
@@ -56,8 +56,8 @@ class PPOTrainer:
|
||||
# Normalize advantages
|
||||
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
||||
|
||||
# Convert to tensors
|
||||
states_t = torch.from_numpy(states).float().to(self.device)
|
||||
# Convert to tensors (states: N, H, W, C -> N, C, H, W)
|
||||
states_t = torch.from_numpy(states).float().permute(0, 3, 1, 2).to(self.device)
|
||||
actions_t = torch.from_numpy(actions).float().to(self.device)
|
||||
log_probs_old_t = torch.from_numpy(log_probs_old).float().to(self.device)
|
||||
returns_t = torch.from_numpy(returns).float().to(self.device)
|
||||
@@ -75,16 +75,13 @@ class PPOTrainer:
|
||||
for batch in loader:
|
||||
s, a, log_pi_old, ret, adv = batch
|
||||
|
||||
# Get current policy distribution
|
||||
mu, std = self.actor(s)
|
||||
dist = torch.distributions.Normal(mu, std)
|
||||
log_pi = dist.log_prob(a).sum(dim=-1, keepdim=True)
|
||||
entropy = dist.entropy().sum(dim=-1, keepdim=True)
|
||||
log_pi = dist.log_prob(a).sum(dim=-1)
|
||||
entropy = dist.entropy().sum(dim=-1)
|
||||
|
||||
# Probability ratio
|
||||
ratio = torch.exp(log_pi - log_pi_old)
|
||||
|
||||
# Clipped surrogate objective
|
||||
surr1 = ratio * adv
|
||||
surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * adv
|
||||
actor_loss = -torch.min(surr1, surr2).mean()
|
||||
|
||||
Reference in New Issue
Block a user