{ "cells": [ { "cell_type": "markdown", "id": "c4071b53", "metadata": { "origin_pos": 0 }, "source": [ "# 多层感知机\n", ":label:`sec_mlp`\n", "\n", "在 :numref:`chap_linear`中,\n", "我们介绍了softmax回归( :numref:`sec_softmax`),\n", "然后我们从零开始实现了softmax回归( :numref:`sec_softmax_scratch`),\n", "接着使用高级API实现了算法( :numref:`sec_softmax_concise`),\n", "并训练分类器从低分辨率图像中识别10类服装。\n", "在这个过程中,我们学习了如何处理数据,如何将输出转换为有效的概率分布,\n", "并应用适当的损失函数,根据模型参数最小化损失。\n", "我们已经在简单的线性模型背景下掌握了这些知识,\n", "现在我们可以开始对深度神经网络的探索,这也是本书主要涉及的一类模型。\n", "\n", "## 隐藏层\n", "\n", "我们在 :numref:`subsec_linear_model`中描述了仿射变换,\n", "它是一种带有偏置项的线性变换。\n", "首先,回想一下如 :numref:`fig_softmaxreg`中所示的softmax回归的模型架构。\n", "该模型通过单个仿射变换将我们的输入直接映射到输出,然后进行softmax操作。\n", "如果我们的标签通过仿射变换后确实与我们的输入数据相关,那么这种方法确实足够了。\n", "但是,仿射变换中的*线性*是一个很强的假设。\n", "\n", "### 线性模型可能会出错\n", "\n", "例如,线性意味着*单调*假设:\n", "任何特征的增大都会导致模型输出的增大(如果对应的权重为正),\n", "或者导致模型输出的减小(如果对应的权重为负)。\n", "有时这是有道理的。\n", "例如,如果我们试图预测一个人是否会偿还贷款。\n", "我们可以认为,在其他条件不变的情况下,\n", "收入较高的申请人比收入较低的申请人更有可能偿还贷款。\n", "但是,虽然收入与还款概率存在单调性,但它们不是线性相关的。\n", "收入从0增加到5万,可能比从100万增加到105万带来更大的还款可能性。\n", "处理这一问题的一种方法是对我们的数据进行预处理,\n", "使线性变得更合理,如使用收入的对数作为我们的特征。\n", "\n", "然而我们可以很容易找出违反单调性的例子。\n", "例如,我们想要根据体温预测死亡率。\n", "对体温高于37摄氏度的人来说,温度越高风险越大。\n", "然而,对体温低于37摄氏度的人来说,温度越高风险就越低。\n", "在这种情况下,我们也可以通过一些巧妙的预处理来解决问题。\n", "例如,我们可以使用与37摄氏度的距离作为特征。\n", "\n", "但是,如何对猫和狗的图像进行分类呢?\n", "增加位置$(13, 17)$处像素的强度是否总是增加(或降低)图像描绘狗的似然?\n", "对线性模型的依赖对应于一个隐含的假设,\n", "即区分猫和狗的唯一要求是评估单个像素的强度。\n", "在一个倒置图像后依然保留类别的世界里,这种方法注定会失败。\n", "\n", "与我们前面的例子相比,这里的线性很荒谬,\n", "而且我们难以通过简单的预处理来解决这个问题。\n", "这是因为任何像素的重要性都以复杂的方式取决于该像素的上下文(周围像素的值)。\n", "我们的数据可能会有一种表示,这种表示会考虑到我们在特征之间的相关交互作用。\n", "在此表示的基础上建立一个线性模型可能会是合适的,\n", "但我们不知道如何手动计算这么一种表示。\n", "对于深度神经网络,我们使用观测数据来联合学习隐藏层表示和应用于该表示的线性预测器。\n", "\n", "### 在网络中加入隐藏层\n", "\n", "我们可以通过在网络中加入一个或多个隐藏层来克服线性模型的限制,\n", "使其能处理更普遍的函数关系类型。\n", "要做到这一点,最简单的方法是将许多全连接层堆叠在一起。\n", "每一层都输出到上面的层,直到生成最后的输出。\n", "我们可以把前$L-1$层看作表示,把最后一层看作线性预测器。\n", "这种架构通常称为*多层感知机*(multilayer perceptron),通常缩写为*MLP*。\n", "下面,我们以图的方式描述了多层感知机( :numref:`fig_mlp`)。\n", "\n", "![一个单隐藏层的多层感知机,具有5个隐藏单元](../img/mlp.svg)\n", ":label:`fig_mlp`\n", "\n", "这个多层感知机有4个输入,3个输出,其隐藏层包含5个隐藏单元。\n", "输入层不涉及任何计算,因此使用此网络产生输出只需要实现隐藏层和输出层的计算。\n", "因此,这个多层感知机中的层数为2。\n", "注意,这两个层都是全连接的。\n", "每个输入都会影响隐藏层中的每个神经元,\n", "而隐藏层中的每个神经元又会影响输出层中的每个神经元。\n", "\n", "然而,正如 :numref:`subsec_parameterization-cost-fc-layers`所说,\n", "具有全连接层的多层感知机的参数开销可能会高得令人望而却步。\n", "即使在不改变输入或输出大小的情况下,\n", "可能在参数节约和模型有效性之间进行权衡 :cite:`Zhang.Tay.Zhang.ea.2021`。\n", "\n", "### 从线性到非线性\n", "\n", "同之前的章节一样,\n", "我们通过矩阵$\\mathbf{X} \\in \\mathbb{R}^{n \\times d}$\n", "来表示$n$个样本的小批量,\n", "其中每个样本具有$d$个输入特征。\n", "对于具有$h$个隐藏单元的单隐藏层多层感知机,\n", "用$\\mathbf{H} \\in \\mathbb{R}^{n \\times h}$表示隐藏层的输出,\n", "称为*隐藏表示*(hidden representations)。\n", "在数学或代码中,$\\mathbf{H}$也被称为*隐藏层变量*(hidden-layer variable)\n", "或*隐藏变量*(hidden variable)。\n", "因为隐藏层和输出层都是全连接的,\n", "所以我们有隐藏层权重$\\mathbf{W}^{(1)} \\in \\mathbb{R}^{d \\times h}$\n", "和隐藏层偏置$\\mathbf{b}^{(1)} \\in \\mathbb{R}^{1 \\times h}$\n", "以及输出层权重$\\mathbf{W}^{(2)} \\in \\mathbb{R}^{h \\times q}$\n", "和输出层偏置$\\mathbf{b}^{(2)} \\in \\mathbb{R}^{1 \\times q}$。\n", "形式上,我们按如下方式计算单隐藏层多层感知机的输出\n", "$\\mathbf{O} \\in \\mathbb{R}^{n \\times q}$:\n", "\n", "$$\n", "\\begin{aligned}\n", " \\mathbf{H} & = \\mathbf{X} \\mathbf{W}^{(1)} + \\mathbf{b}^{(1)}, \\\\\n", " \\mathbf{O} & = \\mathbf{H}\\mathbf{W}^{(2)} + \\mathbf{b}^{(2)}.\n", "\\end{aligned}\n", "$$\n", "\n", "注意在添加隐藏层之后,模型现在需要跟踪和更新额外的参数。\n", "可我们能从中得到什么好处呢?在上面定义的模型里,我们没有好处!\n", "原因很简单:上面的隐藏单元由输入的仿射函数给出,\n", "而输出(softmax操作前)只是隐藏单元的仿射函数。\n", "仿射函数的仿射函数本身就是仿射函数,\n", "但是我们之前的线性模型已经能够表示任何仿射函数。\n", "\n", "我们可以证明这一等价性,即对于任意权重值,\n", "我们只需合并隐藏层,便可产生具有参数\n", "$\\mathbf{W} = \\mathbf{W}^{(1)}\\mathbf{W}^{(2)}$\n", "和$\\mathbf{b} = \\mathbf{b}^{(1)} \\mathbf{W}^{(2)} + \\mathbf{b}^{(2)}$\n", "的等价单层模型:\n", "\n", "$$\n", "\\mathbf{O} = (\\mathbf{X} \\mathbf{W}^{(1)} + \\mathbf{b}^{(1)})\\mathbf{W}^{(2)} + \\mathbf{b}^{(2)} = \\mathbf{X} \\mathbf{W}^{(1)}\\mathbf{W}^{(2)} + \\mathbf{b}^{(1)} \\mathbf{W}^{(2)} + \\mathbf{b}^{(2)} = \\mathbf{X} \\mathbf{W} + \\mathbf{b}.\n", "$$\n", "\n", "为了发挥多层架构的潜力,\n", "我们还需要一个额外的关键要素:\n", "在仿射变换之后对每个隐藏单元应用非线性的*激活函数*(activation function)$\\sigma$。\n", "激活函数的输出(例如,$\\sigma(\\cdot)$)被称为*活性值*(activations)。\n", "一般来说,有了激活函数,就不可能再将我们的多层感知机退化成线性模型:\n", "\n", "$$\n", "\\begin{aligned}\n", " \\mathbf{H} & = \\sigma(\\mathbf{X} \\mathbf{W}^{(1)} + \\mathbf{b}^{(1)}), \\\\\n", " \\mathbf{O} & = \\mathbf{H}\\mathbf{W}^{(2)} + \\mathbf{b}^{(2)}.\\\\\n", "\\end{aligned}\n", "$$\n", "\n", "由于$\\mathbf{X}$中的每一行对应于小批量中的一个样本,\n", "出于记号习惯的考量,\n", "我们定义非线性函数$\\sigma$也以按行的方式作用于其输入,\n", "即一次计算一个样本。\n", "我们在 :numref:`subsec_softmax_vectorization`中\n", "以相同的方式使用了softmax符号来表示按行操作。\n", "但是本节应用于隐藏层的激活函数通常不仅按行操作,也按元素操作。\n", "这意味着在计算每一层的线性部分之后,我们可以计算每个活性值,\n", "而不需要查看其他隐藏单元所取的值。对于大多数激活函数都是这样。\n", "\n", "为了构建更通用的多层感知机,\n", "我们可以继续堆叠这样的隐藏层,\n", "例如$\\mathbf{H}^{(1)} = \\sigma_1(\\mathbf{X} \\mathbf{W}^{(1)} + \\mathbf{b}^{(1)})$和$\\mathbf{H}^{(2)} = \\sigma_2(\\mathbf{H}^{(1)} \\mathbf{W}^{(2)} + \\mathbf{b}^{(2)})$,\n", "一层叠一层,从而产生更有表达能力的模型。\n", "\n", "### 通用近似定理\n", "\n", "多层感知机可以通过隐藏神经元,捕捉到输入之间复杂的相互作用,\n", "这些神经元依赖于每个输入的值。\n", "我们可以很容易地设计隐藏节点来执行任意计算。\n", "例如,在一对输入上进行基本逻辑操作,多层感知机是通用近似器。\n", "即使是网络只有一个隐藏层,给定足够的神经元和正确的权重,\n", "我们可以对任意函数建模,尽管实际中学习该函数是很困难的。\n", "神经网络有点像C语言。\n", "C语言和任何其他现代编程语言一样,能够表达任何可计算的程序。\n", "但实际上,想出一个符合规范的程序才是最困难的部分。\n", "\n", "而且,虽然一个单隐层网络能学习任何函数,\n", "但并不意味着我们应该尝试使用单隐藏层网络来解决所有问题。\n", "事实上,通过使用更深(而不是更广)的网络,我们可以更容易地逼近许多函数。\n", "我们将在后面的章节中进行更细致的讨论。\n", "\n", "## 激活函数\n", ":label:`subsec_activation_functions`\n", "\n", "*激活函数*(activation function)通过计算加权和并加上偏置来确定神经元是否应该被激活,\n", "它们将输入信号转换为输出的可微运算。\n", "大多数激活函数都是非线性的。\n", "由于激活函数是深度学习的基础,下面(**简要介绍一些常见的激活函数**)。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "5fe77ef0", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:56:49.792151Z", "iopub.status.busy": "2023-08-18T06:56:49.791335Z", "iopub.status.idle": "2023-08-18T06:56:51.817055Z", "shell.execute_reply": "2023-08-18T06:56:51.816151Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "import torch\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "be8e8952", "metadata": { "origin_pos": 5 }, "source": [ "### ReLU函数\n", "\n", "最受欢迎的激活函数是*修正线性单元*(Rectified linear unit,*ReLU*),\n", "因为它实现简单,同时在各种预测任务中表现良好。\n", "[**ReLU提供了一种非常简单的非线性变换**]。\n", "给定元素$x$,ReLU函数被定义为该元素与$0$的最大值:\n", "\n", "(**$$\\operatorname{ReLU}(x) = \\max(x, 0).$$**)\n", "\n", "通俗地说,ReLU函数通过将相应的活性值设为0,仅保留正元素并丢弃所有负元素。\n", "为了直观感受一下,我们可以画出函数的曲线图。\n", "正如从图中所看到,激活函数是分段线性的。\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "e5c9d313", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:56:51.821221Z", "iopub.status.busy": "2023-08-18T06:56:51.820525Z", "iopub.status.idle": "2023-08-18T06:56:52.021660Z", "shell.execute_reply": "2023-08-18T06:56:52.020787Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T06:56:51.980684\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "x = torch.arange(-8.0, 8.0, 0.1, requires_grad=True)\n", "y = torch.relu(x)\n", "d2l.plot(x.detach(), y.detach(), 'x', 'relu(x)', figsize=(5, 2.5))" ] }, { "cell_type": "markdown", "id": "aa375bdc", "metadata": { "origin_pos": 10 }, "source": [ "当输入为负时,ReLU函数的导数为0,而当输入为正时,ReLU函数的导数为1。\n", "注意,当输入值精确等于0时,ReLU函数不可导。\n", "在此时,我们默认使用左侧的导数,即当输入为0时导数为0。\n", "我们可以忽略这种情况,因为输入可能永远都不会是0。\n", "这里引用一句古老的谚语,“如果微妙的边界条件很重要,我们很可能是在研究数学而非工程”,\n", "这个观点正好适用于这里。\n", "下面我们绘制ReLU函数的导数。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "98dfa40b", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:56:52.025683Z", "iopub.status.busy": "2023-08-18T06:56:52.025087Z", "iopub.status.idle": "2023-08-18T06:56:52.225589Z", "shell.execute_reply": "2023-08-18T06:56:52.224709Z" }, "origin_pos": 12, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T06:56:52.181598\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "y.backward(torch.ones_like(x), retain_graph=True)\n", "d2l.plot(x.detach(), x.grad, 'x', 'grad of relu', figsize=(5, 2.5))" ] }, { "cell_type": "markdown", "id": "eb182493", "metadata": { "origin_pos": 15 }, "source": [ "使用ReLU的原因是,它求导表现得特别好:要么让参数消失,要么让参数通过。\n", "这使得优化表现得更好,并且ReLU减轻了困扰以往神经网络的梯度消失问题(稍后将详细介绍)。\n", "\n", "注意,ReLU函数有许多变体,包括*参数化ReLU*(Parameterized ReLU,*pReLU*)\n", "函数 :cite:`He.Zhang.Ren.ea.2015`。\n", "该变体为ReLU添加了一个线性项,因此即使参数是负的,某些信息仍然可以通过:\n", "\n", "$$\\operatorname{pReLU}(x) = \\max(0, x) + \\alpha \\min(0, x).$$\n", "\n", "### sigmoid函数\n", "\n", "[**对于一个定义域在$\\mathbb{R}$中的输入,\n", "*sigmoid函数*将输入变换为区间(0, 1)上的输出**]。\n", "因此,sigmoid通常称为*挤压函数*(squashing function):\n", "它将范围(-inf, inf)中的任意输入压缩到区间(0, 1)中的某个值:\n", "\n", "(**$$\\operatorname{sigmoid}(x) = \\frac{1}{1 + \\exp(-x)}.$$**)\n", "\n", "在最早的神经网络中,科学家们感兴趣的是对“激发”或“不激发”的生物神经元进行建模。\n", "因此,这一领域的先驱可以一直追溯到人工神经元的发明者麦卡洛克和皮茨,他们专注于阈值单元。\n", "阈值单元在其输入低于某个阈值时取值0,当输入超过阈值时取值1。\n", "\n", "当人们逐渐关注到到基于梯度的学习时,\n", "sigmoid函数是一个自然的选择,因为它是一个平滑的、可微的阈值单元近似。\n", "当我们想要将输出视作二元分类问题的概率时,\n", "sigmoid仍然被广泛用作输出单元上的激活函数\n", "(sigmoid可以视为softmax的特例)。\n", "然而,sigmoid在隐藏层中已经较少使用,\n", "它在大部分时候被更简单、更容易训练的ReLU所取代。\n", "在后面关于循环神经网络的章节中,我们将描述利用sigmoid单元来控制时序信息流的架构。\n", "\n", "下面,我们绘制sigmoid函数。\n", "注意,当输入接近0时,sigmoid函数接近线性变换。\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "2dff6175", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:56:52.229275Z", "iopub.status.busy": "2023-08-18T06:56:52.228651Z", "iopub.status.idle": "2023-08-18T06:56:52.431085Z", "shell.execute_reply": "2023-08-18T06:56:52.430246Z" }, "origin_pos": 17, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T06:56:52.387012\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "y = torch.sigmoid(x)\n", "d2l.plot(x.detach(), y.detach(), 'x', 'sigmoid(x)', figsize=(5, 2.5))" ] }, { "cell_type": "markdown", "id": "8bd55f6c", "metadata": { "origin_pos": 20 }, "source": [ "sigmoid函数的导数为下面的公式:\n", "\n", "$$\\frac{d}{dx} \\operatorname{sigmoid}(x) = \\frac{\\exp(-x)}{(1 + \\exp(-x))^2} = \\operatorname{sigmoid}(x)\\left(1-\\operatorname{sigmoid}(x)\\right).$$\n", "\n", "sigmoid函数的导数图像如下所示。\n", "注意,当输入为0时,sigmoid函数的导数达到最大值0.25;\n", "而输入在任一方向上越远离0点时,导数越接近0。\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "21eafca3", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:56:52.434872Z", "iopub.status.busy": "2023-08-18T06:56:52.434265Z", "iopub.status.idle": "2023-08-18T06:56:52.637287Z", "shell.execute_reply": "2023-08-18T06:56:52.636457Z" }, "origin_pos": 22, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T06:56:52.592042\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 清除以前的梯度\n", "x.grad.data.zero_()\n", "y.backward(torch.ones_like(x),retain_graph=True)\n", "d2l.plot(x.detach(), x.grad, 'x', 'grad of sigmoid', figsize=(5, 2.5))" ] }, { "cell_type": "markdown", "id": "db161ec3", "metadata": { "origin_pos": 25 }, "source": [ "### tanh函数\n", "\n", "与sigmoid函数类似,\n", "[**tanh(双曲正切)函数也能将其输入压缩转换到区间(-1, 1)上**]。\n", "tanh函数的公式如下:\n", "\n", "(**$$\\operatorname{tanh}(x) = \\frac{1 - \\exp(-2x)}{1 + \\exp(-2x)}.$$**)\n", "\n", "下面我们绘制tanh函数。\n", "注意,当输入在0附近时,tanh函数接近线性变换。\n", "函数的形状类似于sigmoid函数,\n", "不同的是tanh函数关于坐标系原点中心对称。\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "3d4603cd", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:56:52.641041Z", "iopub.status.busy": "2023-08-18T06:56:52.640414Z", "iopub.status.idle": "2023-08-18T06:56:52.836897Z", "shell.execute_reply": "2023-08-18T06:56:52.836054Z" }, "origin_pos": 27, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T06:56:52.794854\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "y = torch.tanh(x)\n", "d2l.plot(x.detach(), y.detach(), 'x', 'tanh(x)', figsize=(5, 2.5))" ] }, { "cell_type": "markdown", "id": "30ad8975", "metadata": { "origin_pos": 30 }, "source": [ "tanh函数的导数是:\n", "\n", "$$\\frac{d}{dx} \\operatorname{tanh}(x) = 1 - \\operatorname{tanh}^2(x).$$\n", "\n", "tanh函数的导数图像如下所示。\n", "当输入接近0时,tanh函数的导数接近最大值1。\n", "与我们在sigmoid函数图像中看到的类似,\n", "输入在任一方向上越远离0点,导数越接近0。\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "175057b9", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:56:52.840722Z", "iopub.status.busy": "2023-08-18T06:56:52.840116Z", "iopub.status.idle": "2023-08-18T06:56:53.041511Z", "shell.execute_reply": "2023-08-18T06:56:53.040572Z" }, "origin_pos": 32, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T06:56:52.997123\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 清除以前的梯度\n", "x.grad.data.zero_()\n", "y.backward(torch.ones_like(x),retain_graph=True)\n", "d2l.plot(x.detach(), x.grad, 'x', 'grad of tanh', figsize=(5, 2.5))" ] }, { "cell_type": "markdown", "id": "a2beb0bf", "metadata": { "origin_pos": 35 }, "source": [ "总结一下,我们现在了解了如何结合非线性函数来构建具有更强表达能力的多层神经网络架构。\n", "顺便说一句,这些知识已经让你掌握了一个类似于1990年左右深度学习从业者的工具。\n", "在某些方面,你比在20世纪90年代工作的任何人都有优势,\n", "因为你可以利用功能强大的开源深度学习框架,只需几行代码就可以快速构建模型,\n", "而以前训练这些网络需要研究人员编写数千行的C或Fortran代码。\n", "\n", "## 小结\n", "\n", "* 多层感知机在输出层和输入层之间增加一个或多个全连接隐藏层,并通过激活函数转换隐藏层的输出。\n", "* 常用的激活函数包括ReLU函数、sigmoid函数和tanh函数。\n", "\n", "## 练习\n", "\n", "1. 计算pReLU激活函数的导数。\n", "1. 证明一个仅使用ReLU(或pReLU)的多层感知机构造了一个连续的分段线性函数。\n", "1. 证明$\\operatorname{tanh}(x) + 1 = 2 \\operatorname{sigmoid}(2x)$。\n", "1. 假设我们有一个非线性单元,将它一次应用于一个小批量的数据。这会导致什么样的问题?\n" ] }, { "cell_type": "markdown", "id": "0609e29e", "metadata": { "origin_pos": 37, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/1796)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }