fix(ppo): 修正日志概率维度与状态张量格式
修复 replay buffer 中 log_probs 的维度错误，从 (buffer_size, action_dim) 改为 buffer_size；修正训练时状态张量格式，从 (N, H, W, C) 转换为 (N, C, H, W)；更新 collect_rollout 返回观测值并修正 log_prob 计算；添加项目配置文件和训练曲线生成脚本。
This commit is contained in:
@@ -0,0 +1,39 @@
|
||||
# Core package metadata (PEP 621).
[project]
name = "ppo-carracing"
version = "0.1.0"
description = "PPO (Proximal Policy Optimization) for CarRacing-v3 environment"
requires-python = ">=3.10"
# Runtime requirements as PEP 508 strings, one per line, sorted alphabetically.
# NOTE(review): the description targets CarRacing-v3; confirm that the
# gymnasium lower bound below actually provides that environment id (it
# appears to have been introduced in Gymnasium 1.0, not 0.29) -- TODO verify.
dependencies = [
    "gymnasium[box2d]>=0.29.0",
    "matplotlib>=3.7.0",
    "numpy>=1.24.0",
    "opencv-python>=4.8.0",
    "tensorboard>=2.14.0",
    "torch>=2.0.0",
]

# Extras installable via `pip install .[dev]`: test runner, formatter, linter.
[project.optional-dependencies]
dev = [
    "black>=23.0.0",
    "pytest>=7.4.0",
    "ruff>=0.1.0",
]

# Console entry points, each in "module.path:callable" form.
# NOTE(review): `ppo-train` resolves to a top-level `train` module, while
# `ppo-evaluate` resolves inside the `src` package. The wheel target in this
# file lists only `packages = ["src"]`, so a top-level `train.py` would
# presumably not be shipped and `ppo-train` could fail after installation --
# confirm where `train.py` lives before relying on this script.
[project.scripts]
ppo-train = "train:main"
ppo-evaluate = "src.evaluate:main"

# PEP 517 build backend: the project is built with Hatchling.
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

# Ruff settings; line length and target version mirror the [tool.black] values.
[tool.ruff]
line-length = 100
target-version = "py310"

# Enabled rule families: pycodestyle errors (E) and warnings (W), Pyflakes (F),
# isort import ordering (I), and pep8-naming (N).
[tool.ruff.lint]
select = ["E", "F", "I", "N", "W"]
# E501 is the line-too-long check; presumably disabled because line wrapping
# is left to the formatter rather than enforced by the linter.
ignore = ["E501"]

# Black formatter settings; keep in sync with the [tool.ruff] values.
[tool.black]
line-length = 100
target-version = ["py310"]

# Directories Hatchling includes in the wheel as packages.
# NOTE(review): this ships a package literally named `src`, which matches the
# `src.evaluate:main` entry point but is a generic import name that may
# collide with other projects -- consider a project-specific package name.
[tool.hatch.build.targets.wheel]
packages = ["src"]
Reference in New Issue
Block a user