{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6df75d72-c1c1-40e0-a7f8-ef3da32e4592",
   "metadata": {},
   "source": [
    "# 02 — Sanity-checking the Actor-Critic network\n",
    "\n",
    "Verify that:\n",
    "\n",
    "- the network accepts uint8 `(4, 84, 84)` input\n",
    "- it runs on the GPU (falling back to CPU when CUDA is unavailable)\n",
    "- the forward pass returns the expected shapes\n",
    "- `get_action_and_value` works for both sampling and scoring\n",
    "\n",
    "Each check ends in `assert` statements, so **Restart Kernel → Run All** fails loudly if anything regresses."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c6c1d35-f17c-4fca-9cfb-b5b001b7a0c8",
   "metadata": {},
   "source": [
    "## Cell 1: setup and build the network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e09b2e5-c076-4599-8e98-1cb09c0a7cf5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "# Make `src.*` importable; assumes the kernel's cwd is one level below the project root.\n",
    "project_root = Path.cwd().parent\n",
    "if str(project_root) not in sys.path:\n",
    "    sys.path.insert(0, str(project_root))\n",
    "\n",
    "import torch\n",
    "\n",
    "from src.env_wrappers import make_env\n",
    "from src.networks import ActorCritic\n",
    "\n",
    "# --- Configuration ---\n",
    "SEED = 42\n",
    "N_ACTIONS = 5            # size of the policy head; should match the env's discrete action space\n",
    "BATCH_SIZE = 8           # number of fake observations used in cells 2-3\n",
    "OBS_SHAPE = (4, 84, 84)  # per-observation uint8 shape fed to the network\n",
    "\n",
    "# Seeds the RNG on CPU and all CUDA devices so weight init and action sampling are repeatable.\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(\"Device:\", device)\n",
    "\n",
    "net = ActorCritic(n_actions=N_ACTIONS).to(device)\n",
    "print(net)\n",
    "\n",
    "# Count parameters\n",
    "total_params = sum(p.numel() for p in net.parameters())\n",
    "print(f\"\\nTotal parameters: {total_params:,}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8334668c-8f1c-4460-9e1b-9cc2c8c938a1",
   "metadata": {},
   "source": [
    "## Cell 2: test the forward pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "affedbd5-d08b-441b-8cae-b46057be5c63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fake batch of BATCH_SIZE observations spanning the full uint8 range [0, 255].\n",
    "# torch.randint's `high` bound is exclusive, so it must be 256 for 255 to be drawn.\n",
    "fake_obs = torch.randint(0, 256, (BATCH_SIZE, *OBS_SHAPE), dtype=torch.uint8, device=device)\n",
    "\n",
    "with torch.no_grad():  # inspection only: no autograd graph needed\n",
    "    logits, value = net(fake_obs)\n",
    "\n",
    "print(\"Input shape :\", fake_obs.shape, fake_obs.dtype)\n",
    "print(\"Logits shape:\", logits.shape, logits.dtype)\n",
    "print(\"Value shape :\", value.shape, value.dtype)\n",
    "print(\"Sample logits:\", logits[0].cpu().tolist())\n",
    "print(\"Sample value :\", value[0].item())\n",
    "\n",
    "assert logits.shape == (BATCH_SIZE, N_ACTIONS) and logits.dtype == torch.float32\n",
    "assert value.shape == (BATCH_SIZE,) and value.dtype == torch.float32\n",
    "assert torch.isfinite(logits).all() and torch.isfinite(value).all()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6eb6ee0d-adc7-4ac5-953b-91018599dd7f",
   "metadata": {},
   "source": [
    "## Cell 3: test `get_action_and_value`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5ad46a8-7f62-442a-96a2-2d2c3ef91d59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mode 1: sample actions from the current policy\n",
    "with torch.no_grad():\n",
    "    action, log_prob, entropy, value = net.get_action_and_value(fake_obs)\n",
    "print(\"Mode 1 (sample):\")\n",
    "print(f\" action : {action.shape} {action.dtype}, sample = {action[:3].tolist()}\")\n",
    "print(f\" log_prob: {log_prob.shape}, sample = {log_prob[:3].cpu().tolist()}\")\n",
    "print(f\" entropy : {entropy.shape}, sample = {entropy[:3].cpu().tolist()}\")\n",
    "print(f\" value : {value.shape}, sample = {value[:3].cpu().tolist()}\")\n",
    "\n",
    "assert action.shape == (BATCH_SIZE,) and action.dtype == torch.int64\n",
    "assert ((action >= 0) & (action < N_ACTIONS)).all()\n",
    "assert log_prob.shape == entropy.shape == value.shape == (BATCH_SIZE,)\n",
    "assert (log_prob <= 0).all()  # log-probabilities of a discrete distribution are never positive\n",
    "\n",
    "# Mode 2: score given actions (this is what the PPO update uses).\n",
    "# One action per observation, cycling through every action index.\n",
    "provided = torch.arange(BATCH_SIZE, device=device) % N_ACTIONS\n",
    "with torch.no_grad():\n",
    "    _, log_prob2, entropy2, value2 = net.get_action_and_value(fake_obs, provided)\n",
    "print(\"\\nMode 2 (score given action):\")\n",
    "print(f\" log_prob shape: {log_prob2.shape}\")\n",
    "print(f\" entropy shape : {entropy2.shape}\")\n",
    "\n",
    "assert log_prob2.shape == entropy2.shape == value2.shape == (BATCH_SIZE,)\n",
    "# Entropy and value depend only on the observations, not on which action is scored.\n",
    "assert torch.allclose(entropy2, entropy)\n",
    "assert torch.allclose(value2, value)\n",
    "\n",
    "# Sanity: a near-uniform policy over N_ACTIONS actions has entropy ≈ ln(N_ACTIONS),\n",
    "# the maximum possible (ln(5) ≈ 1.6094).\n",
    "max_entropy = math.log(N_ACTIONS)\n",
    "mean_entropy = entropy.mean().item()\n",
    "print(f\"\\nReference: ln({N_ACTIONS}) = {max_entropy:.4f}\")\n",
    "print(f\"Mean entropy at init: {mean_entropy:.4f}\")\n",
    "assert abs(mean_entropy - max_entropy) < 1e-2, \"policy is not near-uniform at initialization\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4c00e27-2a64-46f9-99d0-aff04fe2e714",
   "metadata": {},
   "source": [
    "## Cell 4: run once on an observation from the real environment\n",
    "\n",
    "Unlike cells 2–3, this feeds a genuine observation from `make_env`, checking that the wrapper output matches what the network expects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c362b84b-1b74-4f55-b7e9-d842414c4a9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = make_env(seed=SEED)\n",
    "try:\n",
    "    obs, _ = env.reset(seed=SEED)\n",
    "\n",
    "    # obs is a numpy uint8 array (4, 84, 84). Add a batch dim and move to device.\n",
    "    obs_t = torch.as_tensor(obs).unsqueeze(0).to(device)\n",
    "    print(\"obs_t shape:\", obs_t.shape, obs_t.dtype)\n",
    "    assert obs_t.shape == (1, *OBS_SHAPE) and obs_t.dtype == torch.uint8\n",
    "\n",
    "    with torch.no_grad():\n",
    "        action, log_prob, entropy, value = net.get_action_and_value(obs_t)\n",
    "    print(f\"Sampled action: {action.item()}\")\n",
    "    print(f\"Log prob: {log_prob.item():.4f}\")\n",
    "    print(f\"Entropy: {entropy.item():.4f}\")\n",
    "    print(f\"Value estimate: {value.item():.4f}\")\n",
    "    assert 0 <= action.item() < N_ACTIONS\n",
    "finally:\n",
    "    env.close()  # always release the environment, even if a check above fails"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PYTORCH",
   "language": "python",
   "name": "pytorch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}