Files
Serendipity fb09e66d09 feat: 重构项目结构并添加向量化PPO训练与评估脚本
- 将原始单环境训练代码重构为模块化结构,添加向量化环境支持以提高数据采集效率
- 实现完整的PPO训练流水线,包括共享CNN的Actor-Critic网络、向量化经验回放缓冲和GAE优势估计
- 添加训练脚本(train_vec.py)、评估脚本(evaluate.py)和SB3基线对比脚本(train_sb3_baseline.py)
- 提供详细的文档和开发日志,包含问题解决记录和实验分析
- 移除旧版项目文件,统一项目结构到CW1_id_name目录下
2026-05-02 13:44:08 +08:00

239 lines
7.4 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "markdown",
"id": "6df75d72-c1c1-40e0-a7f8-ef3da32e4592",
"metadata": {},
"source": [
"# 02 — Sanity-checking the Actor-Critic network\n",
"\n",
"Verify that:\n",
"\n",
"- the network accepts uint8 (4, 84, 84) input\n",
"- it runs on GPU (when available)\n",
"- the forward pass returns the expected shapes\n",
"- get_action_and_value works for both sampling and scoring"
]
},
{
"cell_type": "markdown",
"id": "9c6c1d35-f17c-4fca-9cfb-b5b001b7a0c8",
"metadata": {},
"source": [
"## Cell 1: Build the network and count its parameters"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "9e09b2e5-c076-4599-8e98-1cb09c0a7cf5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Device: cuda\n",
"ActorCritic(\n",
" (cnn): Sequential(\n",
" (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))\n",
" (1): ReLU()\n",
" (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))\n",
" (3): ReLU()\n",
" (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n",
" (5): ReLU()\n",
" (6): Flatten(start_dim=1, end_dim=-1)\n",
" (7): Linear(in_features=3136, out_features=512, bias=True)\n",
" (8): ReLU()\n",
" )\n",
" (actor): Linear(in_features=512, out_features=5, bias=True)\n",
" (critic): Linear(in_features=512, out_features=1, bias=True)\n",
")\n",
"\n",
"Total parameters: 1,687,206\n"
]
}
],
"source": [
"import sys\n",
"from pathlib import Path\n",
"\n",
"# Make the project root importable so the `src.*` modules resolve.\n",
"repo_root = str(Path.cwd().parent)\n",
"if repo_root not in sys.path:\n",
" sys.path.insert(0, repo_root)\n",
"\n",
"import torch\n",
"from src.networks import ActorCritic\n",
"\n",
"# Prefer the GPU when one is available; everything below follows `device`.\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"print(\"Device:\", device)\n",
"\n",
"net = ActorCritic(n_actions=5).to(device)\n",
"print(net)\n",
"\n",
"# Total trainable + non-trainable parameter count, for a quick size check.\n",
"n_params = sum(p.numel() for p in net.parameters())\n",
"print(f\"\\nTotal parameters: {n_params:,}\")"
]
},
{
"cell_type": "markdown",
"id": "8334668c-8f1c-4460-9e1b-9cc2c8c938a1",
"metadata": {},
"source": [
"## Cell 2: Test the forward pass"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "affedbd5-d08b-441b-8cae-b46057be5c63",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input shape : torch.Size([8, 4, 84, 84]) torch.uint8\n",
"Logits shape: torch.Size([8, 5]) torch.float32\n",
"Value shape : torch.Size([8]) torch.float32\n",
"Sample logits: [-0.0013231671182438731, 0.0014129895716905594, 0.0010137694189324975, 0.0005002821562811732, -0.0012777929659932852]\n",
"Sample value : 0.9111840724945068\n"
]
}
],
"source": [
"# Fake batch of 8 observations, shape (8, 4, 84, 84) uint8.\n",
"# NOTE: torch.randint's upper bound is exclusive, so use 256 (not 255)\n",
"# to cover the full uint8 pixel range [0, 255].\n",
"fake_obs = torch.randint(0, 256, (8, 4, 84, 84), dtype=torch.uint8, device=device)\n",
"\n",
"# Forward pass: the network should accept raw uint8 and return float32.\n",
"logits, value = net(fake_obs)\n",
"print(\"Input shape :\", fake_obs.shape, fake_obs.dtype)\n",
"print(\"Logits shape:\", logits.shape, logits.dtype)\n",
"print(\"Value shape :\", value.shape, value.dtype)\n",
"print(\"Sample logits:\", logits[0].detach().cpu().tolist())\n",
"print(\"Sample value :\", value[0].item())"
]
]
},
{
"cell_type": "markdown",
"id": "6eb6ee0d-adc7-4ac5-953b-91018599dd7f",
"metadata": {},
"source": [
"## Cell 3: Test get_action_and_value (sampling and scoring modes)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "e5ad46a8-7f62-442a-96a2-2d2c3ef91d59",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mode 1 (sample):\n",
" action : torch.Size([8]) torch.int64, sample = [4, 2, 3]\n",
" log_prob: torch.Size([8]), sample = [-1.6107815504074097, -1.60985267162323, -1.6094183921813965]\n",
" entropy : torch.Size([8]), sample = [1.6094372272491455, 1.609436273574829, 1.6094352006912231]\n",
" value : torch.Size([8]), sample = [0.9111840724945068, 0.8728611469268799, 0.9081785678863525]\n",
"\n",
"Mode 2 (score given action):\n",
" log_prob shape: torch.Size([8])\n",
" entropy shape : torch.Size([8])\n",
"\n",
"Reference: ln(5) = 1.6094\n",
"Mean entropy at init: 1.6094\n"
]
}
],
"source": [
"import math\n",
"\n",
"# Mode 1: let the policy sample an action for each observation.\n",
"action, log_prob, entropy, value = net.get_action_and_value(fake_obs)\n",
"print(\"Mode 1 (sample):\")\n",
"print(f\" action : {action.shape} {action.dtype}, sample = {action[:3].tolist()}\")\n",
"print(f\" log_prob: {log_prob.shape}, sample = {log_prob[:3].detach().cpu().tolist()}\")\n",
"print(f\" entropy : {entropy.shape}, sample = {entropy[:3].detach().cpu().tolist()}\")\n",
"print(f\" value : {value.shape}, sample = {value[:3].detach().cpu().tolist()}\")\n",
"\n",
"# Mode 2: score a caller-supplied action batch (the PPO update path).\n",
"given_actions = torch.tensor([0, 3, 2, 1, 4, 0, 3, 2], device=device)\n",
"_, log_prob2, entropy2, value2 = net.get_action_and_value(fake_obs, given_actions)\n",
"print(\"\\nMode 2 (score given action):\")\n",
"print(f\" log_prob shape: {log_prob2.shape}\")\n",
"print(f\" entropy shape : {entropy2.shape}\")\n",
"\n",
"# Sanity check: a freshly initialised policy should be near-uniform, and\n",
"# the entropy of a uniform 5-way categorical is ln(5).\n",
"print(f\"\\nReference: ln(5) = {math.log(5):.4f}\")\n",
"print(f\"Mean entropy at init: {entropy.mean().item():.4f}\")"
]
]
},
{
"cell_type": "markdown",
"id": "d4c00e27-2a64-46f9-99d0-aff04fe2e714",
"metadata": {},
"source": [
"## Cell 4: Run the network once on a real environment observation"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "c362b84b-1b74-4f55-b7e9-d842414c4a9f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"obs_t shape: torch.Size([1, 4, 84, 84]) torch.uint8\n",
"Sampled action: 0\n",
"Log prob: -1.6102\n",
"Entropy: 1.6094\n",
"Value estimate: 0.3286\n"
]
}
],
"source": [
"from src.env_wrappers import make_env\n",
"\n",
"# Build the real (wrapped) environment and grab one observation.\n",
"# (The previous `import numpy as np` was unused and has been removed.)\n",
"env = make_env(seed=42)\n",
"obs, _ = env.reset(seed=42)\n",
"\n",
"# obs is a numpy uint8 array (4, 84, 84). Add a batch dim and move to device.\n",
"obs_t = torch.as_tensor(obs).unsqueeze(0).to(device)\n",
"print(\"obs_t shape:\", obs_t.shape, obs_t.dtype)\n",
"\n",
"# Single-step sanity check: sample one action and inspect the scalar outputs.\n",
"action, log_prob, entropy, value = net.get_action_and_value(obs_t)\n",
"print(f\"Sampled action: {action.item()}\")\n",
"print(f\"Log prob: {log_prob.item():.4f}\")\n",
"print(f\"Entropy: {entropy.item():.4f}\")\n",
"print(f\"Value estimate: {value.item():.4f}\")\n",
"\n",
"env.close()"
]
]
}
],
"metadata": {
"kernelspec": {
"display_name": "PYTORCH",
"language": "python",
"name": "pytorch"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.21"
}
},
"nbformat": 4,
"nbformat_minor": 5
}