{ "cells": [ { "cell_type": "markdown", "id": "d5217b24", "metadata": { "origin_pos": 0 }, "source": [ "# 多层感知机的简洁实现\n", ":label:`sec_mlp_concise`\n", "\n", "本节将介绍(**通过高级API更简洁地实现多层感知机**)。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "f4b9d183", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:04:20.711610Z", "iopub.status.busy": "2023-08-18T07:04:20.711337Z", "iopub.status.idle": "2023-08-18T07:04:22.715766Z", "shell.execute_reply": "2023-08-18T07:04:22.714884Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "d1b8af0c", "metadata": { "origin_pos": 5 }, "source": [ "## 模型\n", "\n", "与softmax回归的简洁实现( :numref:`sec_softmax_concise`)相比,\n", "唯一的区别是我们添加了2个全连接层(之前我们只添加了1个全连接层)。\n", "第一层是[**隐藏层**],它(**包含256个隐藏单元,并使用了ReLU激活函数**)。\n", "第二层是输出层。\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "a11cfbe9", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:04:22.719981Z", "iopub.status.busy": "2023-08-18T07:04:22.719298Z", "iopub.status.idle": "2023-08-18T07:04:22.748628Z", "shell.execute_reply": "2023-08-18T07:04:22.747813Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "net = nn.Sequential(nn.Flatten(),\n", " nn.Linear(784, 256),\n", " nn.ReLU(),\n", " nn.Linear(256, 10))\n", "\n", "def init_weights(m):\n", " if type(m) == nn.Linear:\n", " nn.init.normal_(m.weight, std=0.01)\n", "\n", "net.apply(init_weights);" ] }, { "cell_type": "markdown", "id": "f5aceed6", "metadata": { "origin_pos": 10 }, "source": [ "[**训练过程**]的实现与我们实现softmax回归时完全相同,\n", "这种模块化设计使我们能够将与模型架构有关的内容独立出来。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "b23e8ab9", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:04:22.753701Z", "iopub.status.busy": "2023-08-18T07:04:22.753406Z", "iopub.status.idle": "2023-08-18T07:04:22.758051Z", "shell.execute_reply": "2023-08-18T07:04:22.757284Z" }, "origin_pos": 12, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "batch_size, lr, num_epochs = 256, 0.1, 10\n", "loss = nn.CrossEntropyLoss(reduction='none')\n", "trainer = torch.optim.SGD(net.parameters(), lr=lr)" ] }, { "cell_type": "code", "execution_count": 4, "id": "78ac9bf1", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:04:22.761842Z", "iopub.status.busy": "2023-08-18T07:04:22.761295Z", "iopub.status.idle": "2023-08-18T07:05:05.308680Z", "shell.execute_reply": "2023-08-18T07:05:05.307786Z" }, "origin_pos": 15, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:05:05.270258\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)\n", "d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)" ] }, { "cell_type": "markdown", "id": "9b636c57", "metadata": { "origin_pos": 16 }, "source": [ "## 小结\n", "\n", "* 我们可以使用高级API更简洁地实现多层感知机。\n", "* 对于相同的分类问题,多层感知机的实现与softmax回归的实现相同,只是多层感知机的实现里增加了带有激活函数的隐藏层。\n", "\n", "## 练习\n", "\n", "1. 尝试添加不同数量的隐藏层(也可以修改学习率),怎么样设置效果最好?\n", "1. 尝试不同的激活函数,哪个效果最好?\n", "1. 尝试不同的方案来初始化权重,什么方法效果最好?\n" ] }, { "cell_type": "markdown", "id": "36201fb3", "metadata": { "origin_pos": 18, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/1802)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }