fix(ppo): 修正日志概率维度与状态张量格式

修复 replay buffer 中 log_probs 的维度错误,从 (buffer_size, action_dim) 改为 buffer_size
修正训练时状态张量格式,从 (N, H, W, C) 转换为 (N, C, H, W)
更新 collect_rollout 返回观测值并修正 log_prob 计算
添加项目配置文件和训练曲线生成脚本
This commit is contained in:
2026-04-30 20:30:40 +08:00
parent d353133b31
commit b32490ae03
19 changed files with 185 additions and 22 deletions
@@ -0,0 +1,39 @@
[project]
name = "ppo-carracing"
version = "0.1.0"
description = "PPO (Proximal Policy Optimization) for CarRacing-v3 environment"
requires-python = ">=3.10"
dependencies = [
"torch>=2.0.0",
"gymnasium[box2d]>=0.29.0",
"numpy>=1.24.0",
"matplotlib>=3.7.0",
"tensorboard>=2.14.0",
"opencv-python>=4.8.0",
]
[project.optional-dependencies]
dev = ["pytest>=7.4.0", "black>=23.0.0", "ruff>=0.1.0"]
[project.scripts]
ppo-train = "train:main"
ppo-evaluate = "src.evaluate:main"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.ruff]
line-length = 100
target-version = "py310"
[tool.ruff.lint]
select = ["E", "F", "I", "N", "W"]
ignore = ["E501"]
[tool.black]
line-length = 100
target-version = ["py310"]
[tool.hatch.build.targets.wheel]
packages = ["src"]