b32490ae03
修复 replay buffer 中 log_probs 的维度错误,从 (buffer_size, action_dim) 改为 buffer_size 修正训练时状态张量格式,从 (N, H, W, C) 转换为 (N, C, H, W) 更新 collect_rollout 返回观测值并修正 log_prob 计算 添加项目配置文件和训练曲线生成脚本
40 lines
812 B
TOML
40 lines
812 B
TOML
[project]
name = "ppo-carracing"
version = "0.1.0"
description = "PPO (Proximal Policy Optimization) for CarRacing-v3 environment"
requires-python = ">=3.10"
# Runtime requirements, one PEP 508 specifier per line, sorted by name.
dependencies = [
    # CarRacing-v3 is only registered from Gymnasium 1.0.0 onward (the 0.29.x
    # line ships CarRacing-v2), so a lower bound below 1.0.0 would allow an
    # install in which the environment named in `description` does not exist.
    "gymnasium[box2d]>=1.0.0",
    "matplotlib>=3.7.0",
    "numpy>=1.24.0",
    "opencv-python>=4.8.0",
    "tensorboard>=2.14.0",
    "torch>=2.0.0",
]

[project.optional-dependencies]
# Extra tooling for contributors, installed via the `dev` extra:
# test runner, formatter, and linter.
dev = [
    "black>=23.0.0",
    "pytest>=7.4.0",
    "ruff>=0.1.0",
]

[project.scripts]
# Console commands, each given as a "module:function" entry point.
# `evaluate` is imported from the `src` package; `train` is resolved as a
# top-level module.
ppo-evaluate = "src.evaluate:main"
ppo-train = "train:main"

[build-system]
# PEP 517 build configuration: the package is built with Hatchling.
build-backend = "hatchling.build"
requires = ["hatchling"]

[tool.ruff]
# Analyse code as Python 3.10, the minimum from `requires-python`.
target-version = "py310"
# Same width as the black configuration.
line-length = 100

[tool.ruff.lint]
# Rule families to enable. Setting `select` replaces ruff's default rule set.
select = [
    "E",  # pycodestyle errors
    "F",  # Pyflakes
    "I",  # isort (import ordering)
    "N",  # pep8-naming
    "W",  # pycodestyle warnings
]
# E501 (line too long) is not linted; line width is presumably left to the
# formatter's `line-length` setting.
ignore = ["E501"]

[tool.black]
# Format for Python 3.10 syntax, wrapping at the same width ruff uses.
target-version = ["py310"]
line-length = 100

[tool.hatch.build.targets.wheel]
# Ship the `src` directory as an importable package named `src`; the
# `ppo-evaluate` entry point ("src.evaluate:main") imports from it.
packages = ["src"]

# The `ppo-train` entry point ("train:main") imports a top-level `train`
# module. That module is not part of `src`, so with only `packages = ["src"]`
# it would be missing from the built wheel. The console script would then
# fail with ModuleNotFoundError in any non-editable install.
# NOTE(review): this assumes train.py lives at the project root — confirm.
# If it actually lives in src/, drop this table and point the script at
# "src.train:main" instead.
[tool.hatch.build.targets.wheel.force-include]
"train.py" = "train.py"