{ "cells": [ { "cell_type": "markdown", "id": "59b31683", "metadata": { "origin_pos": 0 }, "source": [ "# 多层感知机的从零开始实现\n", ":label:`sec_mlp_scratch`\n", "\n", "我们已经在 :numref:`sec_mlp`中描述了多层感知机(MLP),\n", "现在让我们尝试自己实现一个多层感知机。\n", "为了与之前softmax回归( :numref:`sec_softmax_scratch` )\n", "获得的结果进行比较,\n", "我们将继续使用Fashion-MNIST图像分类数据集\n", "( :numref:`sec_fashion_mnist`)。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "ffbb0fc1", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:59:21.394152Z", "iopub.status.busy": "2023-08-18T06:59:21.393407Z", "iopub.status.idle": "2023-08-18T06:59:24.364157Z", "shell.execute_reply": "2023-08-18T06:59:24.362977Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "code", "execution_count": 2, "id": "0be61c4f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:59:24.369567Z", "iopub.status.busy": "2023-08-18T06:59:24.368990Z", "iopub.status.idle": "2023-08-18T06:59:24.501326Z", "shell.execute_reply": "2023-08-18T06:59:24.500151Z" }, "origin_pos": 5, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "batch_size = 256\n", "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)" ] }, { "cell_type": "markdown", "id": "8236e2cd", "metadata": { "origin_pos": 6 }, "source": [ "## 初始化模型参数\n", "\n", "回想一下,Fashion-MNIST中的每个图像由\n", "$28 \\times 28 = 784$个灰度像素值组成。\n", "所有图像共分为10个类别。\n", "忽略像素之间的空间结构,\n", "我们可以将每个图像视为具有784个输入特征\n", "和10个类的简单分类数据集。\n", "首先,我们将[**实现一个具有单隐藏层的多层感知机,\n", "它包含256个隐藏单元**]。\n", "注意,我们可以将这两个变量都视为超参数。\n", "通常,我们选择2的若干次幂作为层的宽度。\n", "因为内存在硬件中的分配和寻址方式,这么做往往可以在计算上更高效。\n", "\n", "我们用几个张量来表示我们的参数。\n", "注意,对于每一层我们都要记录一个权重矩阵和一个偏置向量。\n", "跟以前一样,我们要为损失关于这些参数的梯度分配内存。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "7730f280", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:59:24.508163Z", "iopub.status.busy": "2023-08-18T06:59:24.506257Z", "iopub.status.idle": "2023-08-18T06:59:24.520861Z", "shell.execute_reply": "2023-08-18T06:59:24.519861Z" }, "origin_pos": 8, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "num_inputs, num_outputs, num_hiddens = 784, 10, 256\n", "\n", "W1 = nn.Parameter(torch.randn(\n", " num_inputs, num_hiddens, requires_grad=True) * 0.01)\n", "b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))\n", "W2 = nn.Parameter(torch.randn(\n", " num_hiddens, num_outputs, requires_grad=True) * 0.01)\n", "b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))\n", "\n", "params = [W1, b1, W2, b2]" ] }, { "cell_type": "markdown", "id": "2700dfe8", "metadata": { "origin_pos": 11 }, "source": [ "## 激活函数\n", "\n", "为了确保我们对模型的细节了如指掌,\n", "我们将[**实现ReLU激活函数**],\n", "而不是直接调用内置的`relu`函数。\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "5f46a813", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:59:24.528151Z", "iopub.status.busy": "2023-08-18T06:59:24.526356Z", "iopub.status.idle": "2023-08-18T06:59:24.533695Z", "shell.execute_reply": "2023-08-18T06:59:24.532654Z" }, "origin_pos": 13, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def relu(X):\n", " a = torch.zeros_like(X)\n", " return torch.max(X, a)" ] }, { "cell_type": "markdown", "id": "741dbe39", "metadata": { "origin_pos": 16 }, "source": [ "## 模型\n", "\n", "因为我们忽略了空间结构,\n", "所以我们使用`reshape`将每个二维图像转换为一个长度为`num_inputs`的向量。\n", "只需几行代码就可以(**实现我们的模型**)。\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "b3d9923a", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:59:24.541482Z", "iopub.status.busy": "2023-08-18T06:59:24.539621Z", "iopub.status.idle": "2023-08-18T06:59:24.547435Z", "shell.execute_reply": "2023-08-18T06:59:24.546468Z" }, "origin_pos": 18, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def net(X):\n", " X = X.reshape((-1, num_inputs))\n", " H = relu(X@W1 + b1) # 这里“@”代表矩阵乘法\n", " return (H@W2 + b2)" ] }, { "cell_type": "markdown", "id": "bd600c14", "metadata": { "origin_pos": 21 }, "source": [ "## 损失函数\n", "\n", "由于我们已经从零实现过softmax函数( :numref:`sec_softmax_scratch`),\n", "因此在这里我们直接使用高级API中的内置函数来计算softmax和交叉熵损失。\n", "回想一下我们之前在 :numref:`subsec_softmax-implementation-revisited`中\n", "对这些复杂问题的讨论。\n", "我们鼓励感兴趣的读者查看损失函数的源代码,以加深对实现细节的了解。\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "f55fe0ea", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:59:24.554675Z", "iopub.status.busy": "2023-08-18T06:59:24.552824Z", "iopub.status.idle": "2023-08-18T06:59:24.560084Z", "shell.execute_reply": "2023-08-18T06:59:24.559049Z" }, "origin_pos": 23, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "loss = nn.CrossEntropyLoss(reduction='none')" ] }, { "cell_type": "markdown", "id": "b3a03c3a", "metadata": { "origin_pos": 25 }, "source": [ "## 训练\n", "\n", "幸运的是,[**多层感知机的训练过程与softmax回归的训练过程完全相同**]。\n", "可以直接调用`d2l`包的`train_ch3`函数(参见 :numref:`sec_softmax_scratch` ),\n", "将迭代周期数设置为10,并将学习率设置为0.1.\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "c83cc0c7", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:59:24.567796Z", "iopub.status.busy": "2023-08-18T06:59:24.566005Z", "iopub.status.idle": "2023-08-18T07:00:19.750339Z", "shell.execute_reply": "2023-08-18T07:00:19.748990Z" }, "origin_pos": 27, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:00:19.710036\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "num_epochs, lr = 10, 0.1\n", "updater = torch.optim.SGD(params, lr=lr)\n", "d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)" ] }, { "cell_type": "markdown", "id": "4da98919", "metadata": { "origin_pos": 30 }, "source": [ "为了对学习到的模型进行评估,我们将[**在一些测试数据上应用这个模型**]。\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "8230ba7c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:19.755336Z", "iopub.status.busy": "2023-08-18T07:00:19.754506Z", "iopub.status.idle": "2023-08-18T07:00:20.323813Z", "shell.execute_reply": "2023-08-18T07:00:20.322738Z" }, "origin_pos": 31, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:00:20.249980\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "d2l.predict_ch3(net, test_iter)" ] }, { "cell_type": "markdown", "id": "c97420c6", "metadata": { "origin_pos": 32 }, "source": [ "## 小结\n", "\n", "* 手动实现一个简单的多层感知机是很容易的。然而如果有大量的层,从零开始实现多层感知机会变得很麻烦(例如,要命名和记录模型的参数)。\n", "\n", "## 练习\n", "\n", "1. 在所有其他参数保持不变的情况下,更改超参数`num_hiddens`的值,并查看此超参数的变化对结果有何影响。确定此超参数的最佳值。\n", "1. 尝试添加更多的隐藏层,并查看它对结果有何影响。\n", "1. 改变学习速率会如何影响结果?保持模型架构和其他超参数(包括轮数)不变,学习率设置为多少会带来最好的结果?\n", "1. 通过对所有超参数(学习率、轮数、隐藏层数、每层的隐藏单元数)进行联合优化,可以得到的最佳结果是什么?\n", "1. 描述为什么涉及多个超参数更具挑战性。\n", "1. 如果想要构建多个超参数的搜索方法,请想出一个聪明的策略。\n" ] }, { "cell_type": "markdown", "id": "d0e00850", "metadata": { "origin_pos": 34, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/1804)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }