fb09e66d09
- 将原始单环境训练代码重构为模块化结构,添加向量化环境支持以提高数据采集效率 - 实现完整的PPO训练流水线,包括共享CNN的Actor-Critic网络、向量化经验回放缓冲和GAE优势估计 - 添加训练脚本(train_vec.py)、评估脚本(evaluate.py)和SB3基线对比脚本(train_sb3_baseline.py) - 提供详细的文档和开发日志,包含问题解决记录和实验分析 - 移除旧版项目文件,统一项目结构到CW1_id_name目录下
239 lines
7.4 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "6df75d72-c1c1-40e0-a7f8-ef3da32e4592",
|
||
"metadata": {},
|
||
"source": [
|
||
"# 02 — Sanity-checking the Actor-Critic network\n",
|
||
"\n",
|
||
"**Verify that:**\n",
"\n",
"- the network accepts uint8 (4, 84, 84) input\n",
"- it runs on GPU\n",
"- forward pass returns the expected shapes\n",
"- get_action_and_value works for both sampling and scoring"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "9c6c1d35-f17c-4fca-9cfb-b5b001b7a0c8",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Cell 1 : build the network and count parameters"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 17,
|
||
"id": "9e09b2e5-c076-4599-8e98-1cb09c0a7cf5",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Device: cuda\n",
|
||
"ActorCritic(\n",
|
||
" (cnn): Sequential(\n",
|
||
" (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))\n",
|
||
" (1): ReLU()\n",
|
||
" (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))\n",
|
||
" (3): ReLU()\n",
|
||
" (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n",
|
||
" (5): ReLU()\n",
|
||
" (6): Flatten(start_dim=1, end_dim=-1)\n",
|
||
" (7): Linear(in_features=3136, out_features=512, bias=True)\n",
|
||
" (8): ReLU()\n",
|
||
" )\n",
|
||
" (actor): Linear(in_features=512, out_features=5, bias=True)\n",
|
||
" (critic): Linear(in_features=512, out_features=1, bias=True)\n",
|
||
")\n",
|
||
"\n",
|
||
"Total parameters: 1,687,206\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import sys\n",
|
||
"from pathlib import Path\n",
|
||
"project_root = Path.cwd().parent\n",
|
||
"if str(project_root) not in sys.path:\n",
|
||
" sys.path.insert(0, str(project_root))\n",
|
||
"\n",
|
||
"import torch\n",
|
||
"from src.networks import ActorCritic\n",
|
||
"\n",
|
||
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
||
"print(\"Device:\", device)\n",
|
||
"\n",
|
||
"net = ActorCritic(n_actions=5).to(device)\n",
|
||
"print(net)\n",
|
||
"\n",
|
||
"# Count parameters\n",
|
||
"total_params = sum(p.numel() for p in net.parameters())\n",
|
||
"print(f\"\\nTotal parameters: {total_params:,}\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "8334668c-8f1c-4460-9e1b-9cc2c8c938a1",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Cell 2 : test forward"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 18,
|
||
"id": "affedbd5-d08b-441b-8cae-b46057be5c63",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Input shape : torch.Size([8, 4, 84, 84]) torch.uint8\n",
|
||
"Logits shape: torch.Size([8, 5]) torch.float32\n",
|
||
"Value shape : torch.Size([8]) torch.float32\n",
|
||
"Sample logits: [-0.0013231671182438731, 0.0014129895716905594, 0.0010137694189324975, 0.0005002821562811732, -0.0012777929659932852]\n",
|
||
"Sample value : 0.9111840724945068\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Fake batch of 8 observations, shape (8, 4, 84, 84) uint8.\n",
"# NOTE: torch.randint's upper bound is exclusive, so use 256 to cover the full uint8 range 0-255.\n",
"fake_obs = torch.randint(0, 256, (8, 4, 84, 84), dtype=torch.uint8, device=device)\n",
|
||
"\n",
|
||
"logits, value = net(fake_obs)\n",
|
||
"print(\"Input shape :\", fake_obs.shape, fake_obs.dtype)\n",
|
||
"print(\"Logits shape:\", logits.shape, logits.dtype)\n",
|
||
"print(\"Value shape :\", value.shape, value.dtype)\n",
|
||
"print(\"Sample logits:\", logits[0].detach().cpu().tolist())\n",
|
||
"print(\"Sample value :\", value[0].item())"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "6eb6ee0d-adc7-4ac5-953b-91018599dd7f",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Cell 3 : test get_action_and_value"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 19,
|
||
"id": "e5ad46a8-7f62-442a-96a2-2d2c3ef91d59",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Mode 1 (sample):\n",
|
||
" action : torch.Size([8]) torch.int64, sample = [4, 2, 3]\n",
|
||
" log_prob: torch.Size([8]), sample = [-1.6107815504074097, -1.60985267162323, -1.6094183921813965]\n",
|
||
" entropy : torch.Size([8]), sample = [1.6094372272491455, 1.609436273574829, 1.6094352006912231]\n",
|
||
" value : torch.Size([8]), sample = [0.9111840724945068, 0.8728611469268799, 0.9081785678863525]\n",
|
||
"\n",
|
||
"Mode 2 (score given action):\n",
|
||
" log_prob shape: torch.Size([8])\n",
|
||
" entropy shape : torch.Size([8])\n",
|
||
"\n",
|
||
"Reference: ln(5) = 1.6094\n",
|
||
"Mean entropy at init: 1.6094\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Mode 1: sample action\n",
|
||
"action, log_prob, entropy, value = net.get_action_and_value(fake_obs)\n",
|
||
"print(\"Mode 1 (sample):\")\n",
|
||
"print(f\" action : {action.shape} {action.dtype}, sample = {action[:3].tolist()}\")\n",
|
||
"print(f\" log_prob: {log_prob.shape}, sample = {log_prob[:3].detach().cpu().tolist()}\")\n",
|
||
"print(f\" entropy : {entropy.shape}, sample = {entropy[:3].detach().cpu().tolist()}\")\n",
|
||
"print(f\" value : {value.shape}, sample = {value[:3].detach().cpu().tolist()}\")\n",
|
||
"\n",
|
||
"# Mode 2: score given action (this is what PPO update uses)\n",
|
||
"provided = torch.tensor([0, 3, 2, 1, 4, 0, 3, 2], device=device)\n",
|
||
"_, log_prob2, entropy2, value2 = net.get_action_and_value(fake_obs, provided)\n",
|
||
"print(\"\\nMode 2 (score given action):\")\n",
|
||
"print(f\" log_prob shape: {log_prob2.shape}\")\n",
|
||
"print(f\" entropy shape : {entropy2.shape}\")\n",
|
||
"\n",
|
||
"# Sanity: entropy of a uniform 5-action distribution should be ln(5) ≈ 1.6094\n",
|
||
"import math\n",
|
||
"print(f\"\\nReference: ln(5) = {math.log(5):.4f}\")\n",
|
||
"print(f\"Mean entropy at init: {entropy.mean().item():.4f}\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "d4c00e27-2a64-46f9-99d0-aff04fe2e714",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Cell 4 : run the network once on a real environment observation"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 24,
|
||
"id": "c362b84b-1b74-4f55-b7e9-d842414c4a9f",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"obs_t shape: torch.Size([1, 4, 84, 84]) torch.uint8\n",
|
||
"Sampled action: 0\n",
|
||
"Log prob: -1.6102\n",
|
||
"Entropy: 1.6094\n",
|
||
"Value estimate: 0.3286\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"from src.env_wrappers import make_env\n",
|
||
"import numpy as np\n",
|
||
"\n",
|
||
"env = make_env(seed=42)\n",
|
||
"obs, _ = env.reset(seed=42)\n",
|
||
"\n",
|
||
"# obs is a numpy uint8 array (4, 84, 84). Add a batch dim and move to device.\n",
|
||
"obs_t = torch.as_tensor(obs).unsqueeze(0).to(device)\n",
|
||
"print(\"obs_t shape:\", obs_t.shape, obs_t.dtype)\n",
|
||
"\n",
|
||
"action, log_prob, entropy, value = net.get_action_and_value(obs_t)\n",
|
||
"print(f\"Sampled action: {action.item()}\")\n",
|
||
"print(f\"Log prob: {log_prob.item():.4f}\")\n",
|
||
"print(f\"Entropy: {entropy.item():.4f}\")\n",
|
||
"print(f\"Value estimate: {value.item():.4f}\")\n",
|
||
"\n",
|
||
"env.close()"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "PYTORCH",
|
||
"language": "python",
|
||
"name": "pytorch"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.9.21"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|