fix(ppo): 修正日志概率维度与状态张量格式
修复 replay buffer 中 log_probs 的维度错误,从 (buffer_size, action_dim) 改为 buffer_size 修正训练时状态张量格式,从 (N, H, W, C) 转换为 (N, C, H, W) 更新 collect_rollout 返回观测值并修正 log_prob 计算 添加项目配置文件和训练曲线生成脚本
This commit is contained in:
+14
-4
@@ -30,7 +30,7 @@ def collect_rollout(actor, critic, env, buffer, device, rollout_steps):
|
||||
value = critic(obs_t).squeeze(0).item()
|
||||
|
||||
action_np = action.squeeze(0).cpu().numpy()
|
||||
log_prob_np = log_prob.squeeze(0).cpu().numpy()
|
||||
log_prob_np = log_prob.squeeze(0).cpu().numpy().sum()
|
||||
|
||||
next_obs, reward, terminated, truncated, _ = env.step(action_np)
|
||||
done = terminated or truncated
|
||||
@@ -46,6 +46,8 @@ def collect_rollout(actor, critic, env, buffer, device, rollout_steps):
|
||||
obs, _ = env.reset()
|
||||
obs = np.transpose(obs, (1, 2, 0))
|
||||
|
||||
return obs
|
||||
|
||||
|
||||
def train(
|
||||
total_steps=500000,
|
||||
@@ -102,10 +104,8 @@ def train(
|
||||
recent_rewards = []
|
||||
|
||||
while total_timesteps < total_steps:
|
||||
# Collect rollout
|
||||
collect_rollout(actor, critic, env, buffer, device, rollout_steps)
|
||||
obs = collect_rollout(actor, critic, env, buffer, device, rollout_steps)
|
||||
|
||||
# Get last value for GAE
|
||||
with torch.no_grad():
|
||||
obs_t = torch.from_numpy(obs).float().unsqueeze(0).permute(0, 3, 1, 2).to(device)
|
||||
last_value = critic(obs_t).squeeze(0).item()
|
||||
@@ -190,3 +190,13 @@ if __name__ == "__main__":
|
||||
|
||||
device = get_device()
|
||||
train(total_steps=args.steps, rollout_steps=args.rollout, device=device)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--steps", type=int, default=500000, help="Total training steps")
|
||||
parser.add_argument("--rollout", type=int, default=2048, help="Rollout buffer size")
|
||||
args = parser.parse_args()
|
||||
|
||||
device = get_device()
|
||||
train(total_steps=args.steps, rollout_steps=args.rollout, device=device)
|
||||
|
||||
Reference in New Issue
Block a user