fix(ppo): 修正日志概率维度与状态张量格式

修复 replay buffer 中 log_probs 的维度错误,从 (buffer_size, action_dim) 改为 buffer_size
修正训练时状态张量格式,从 (N, H, W, C) 转换为 (N, C, H, W)
更新 collect_rollout 返回观测值并修正 log_prob 计算
添加项目配置文件和训练曲线生成脚本
This commit is contained in:
2026-04-30 20:30:40 +08:00
parent d353133b31
commit b32490ae03
19 changed files with 185 additions and 22 deletions
+14 -4
View File
@@ -30,7 +30,7 @@ def collect_rollout(actor, critic, env, buffer, device, rollout_steps):
value = critic(obs_t).squeeze(0).item()
action_np = action.squeeze(0).cpu().numpy()
log_prob_np = log_prob.squeeze(0).cpu().numpy()
log_prob_np = log_prob.squeeze(0).cpu().numpy().sum()
next_obs, reward, terminated, truncated, _ = env.step(action_np)
done = terminated or truncated
@@ -46,6 +46,8 @@ def collect_rollout(actor, critic, env, buffer, device, rollout_steps):
obs, _ = env.reset()
obs = np.transpose(obs, (1, 2, 0))
return obs
def train(
total_steps=500000,
@@ -102,10 +104,8 @@ def train(
recent_rewards = []
while total_timesteps < total_steps:
# Collect rollout
collect_rollout(actor, critic, env, buffer, device, rollout_steps)
obs = collect_rollout(actor, critic, env, buffer, device, rollout_steps)
# Get last value for GAE
with torch.no_grad():
obs_t = torch.from_numpy(obs).float().unsqueeze(0).permute(0, 3, 1, 2).to(device)
last_value = critic(obs_t).squeeze(0).item()
@@ -190,3 +190,13 @@ if __name__ == "__main__":
device = get_device()
train(total_steps=args.steps, rollout_steps=args.rollout, device=device)
def main():
    """Command-line entry point: parse training options and launch PPO training.

    Reads --steps (total environment steps) and --rollout (rollout buffer
    size) from the command line, picks a device via get_device(), and
    delegates to train().
    """
    arg_parser = argparse.ArgumentParser()
    # (flag, type, default, help) tuples keep the CLI definition in one place.
    cli_options = [
        ("--steps", int, 500000, "Total training steps"),
        ("--rollout", int, 2048, "Rollout buffer size"),
    ]
    for flag, flag_type, flag_default, flag_help in cli_options:
        arg_parser.add_argument(flag, type=flag_type, default=flag_default, help=flag_help)
    parsed = arg_parser.parse_args()

    # Device selection (CPU/GPU) is resolved by the project helper.
    selected_device = get_device()
    train(total_steps=parsed.steps, rollout_steps=parsed.rollout, device=selected_device)