2025-12-16 09:23:53 +08:00
parent 19138d3cc1
commit 9e7efd0626
409 changed files with 272713 additions and 241 deletions
@@ -0,0 +1,204 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "9ed6d9cb",
"metadata": {
"origin_pos": 0
},
"source": [
"# 前向传播、反向传播和计算图\n",
":label:`sec_backprop`\n",
"\n",
"我们已经学习了如何用小批量随机梯度下降训练模型。\n",
"然而当实现该算法时,我们只考虑了通过*前向传播*forward propagation)所涉及的计算。\n",
"在计算梯度时,我们只调用了深度学习框架提供的反向传播函数,而不知其所以然。\n",
"\n",
"梯度的自动计算(自动微分)大大简化了深度学习算法的实现。\n",
"在自动微分之前,即使是对复杂模型的微小调整也需要手工重新计算复杂的导数,\n",
"学术论文也不得不分配大量页面来推导更新规则。\n",
"本节将通过一些基本的数学和计算图,\n",
"深入探讨*反向传播*的细节。\n",
"首先,我们将重点放在带权重衰减($L_2$正则化)的单隐藏层多层感知机上。\n",
"\n",
"## 前向传播\n",
"\n",
"*前向传播*forward propagation或forward pass\n",
"指的是:按顺序(从输入层到输出层)计算和存储神经网络中每层的结果。\n",
"\n",
"我们将一步步研究单隐藏层神经网络的机制,\n",
"为了简单起见,我们假设输入样本是 $\\mathbf{x}\\in \\mathbb{R}^d$\n",
"并且我们的隐藏层不包括偏置项。\n",
"这里的中间变量是:\n",
"\n",
"$$\\mathbf{z}= \\mathbf{W}^{(1)} \\mathbf{x},$$\n",
"\n",
"其中$\\mathbf{W}^{(1)} \\in \\mathbb{R}^{h \\times d}$\n",
"是隐藏层的权重参数。\n",
"将中间变量$\\mathbf{z}\\in \\mathbb{R}^h$通过激活函数$\\phi$后,\n",
"我们得到长度为$h$的隐藏激活向量:\n",
"\n",
"$$\\mathbf{h}= \\phi (\\mathbf{z}).$$\n",
"\n",
"隐藏变量$\\mathbf{h}$也是一个中间变量。\n",
"假设输出层的参数只有权重$\\mathbf{W}^{(2)} \\in \\mathbb{R}^{q \\times h}$\n",
"我们可以得到输出层变量,它是一个长度为$q$的向量:\n",
"\n",
"$$\\mathbf{o}= \\mathbf{W}^{(2)} \\mathbf{h}.$$\n",
"\n",
"假设损失函数为$l$,样本标签为$y$,我们可以计算单个数据样本的损失项,\n",
"\n",
"$$L = l(\\mathbf{o}, y).$$\n",
"\n",
"根据$L_2$正则化的定义,给定超参数$\\lambda$,正则化项为\n",
"\n",
"$$s = \\frac{\\lambda}{2} \\left(\\|\\mathbf{W}^{(1)}\\|_F^2 + \\|\\mathbf{W}^{(2)}\\|_F^2\\right),$$\n",
":eqlabel:`eq_forward-s`\n",
"\n",
"其中矩阵的Frobenius范数是将矩阵展平为向量后应用的$L_2$范数。\n",
"最后,模型在给定数据样本上的正则化损失为:\n",
"\n",
"$$J = L + s.$$\n",
"\n",
"在下面的讨论中,我们将$J$称为*目标函数*objective function)。\n",
"\n",
"## 前向传播计算图\n",
"\n",
"绘制*计算图*有助于我们可视化计算中操作符和变量的依赖关系。\n",
" :numref:`fig_forward` 是与上述简单网络相对应的计算图,\n",
" 其中正方形表示变量,圆圈表示操作符。\n",
" 左下角表示输入,右上角表示输出。\n",
" 注意显示数据流的箭头方向主要是向右和向上的。\n",
"\n",
"![前向传播的计算图](../img/forward.svg)\n",
":label:`fig_forward`\n",
"\n",
"## 反向传播\n",
"\n",
"*反向传播*backward propagation或backpropagation)指的是计算神经网络参数梯度的方法。\n",
"简言之,该方法根据微积分中的*链式规则*,按相反的顺序从输出层到输入层遍历网络。\n",
"该算法存储了计算某些参数梯度时所需的任何中间变量(偏导数)。\n",
"假设我们有函数$\\mathsf{Y}=f(\\mathsf{X})$和$\\mathsf{Z}=g(\\mathsf{Y})$\n",
"其中输入和输出$\\mathsf{X}, \\mathsf{Y}, \\mathsf{Z}$是任意形状的张量。\n",
"利用链式法则,我们可以计算$\\mathsf{Z}$关于$\\mathsf{X}$的导数\n",
"\n",
"$$\\frac{\\partial \\mathsf{Z}}{\\partial \\mathsf{X}} = \\text{prod}\\left(\\frac{\\partial \\mathsf{Z}}{\\partial \\mathsf{Y}}, \\frac{\\partial \\mathsf{Y}}{\\partial \\mathsf{X}}\\right).$$\n",
"\n",
"在这里,我们使用$\\text{prod}$运算符在执行必要的操作(如换位和交换输入位置)后将其参数相乘。\n",
"对于向量,这很简单,它只是矩阵-矩阵乘法。\n",
"对于高维张量,我们使用适当的对应项。\n",
"运算符$\\text{prod}$指代了所有的这些符号。\n",
"\n",
"回想一下,在计算图 :numref:`fig_forward`中的单隐藏层简单网络的参数是\n",
"$\\mathbf{W}^{(1)}$和$\\mathbf{W}^{(2)}$。\n",
"反向传播的目的是计算梯度$\\partial J/\\partial \\mathbf{W}^{(1)}$和\n",
"$\\partial J/\\partial \\mathbf{W}^{(2)}$。\n",
"为此,我们应用链式法则,依次计算每个中间变量和参数的梯度。\n",
"计算的顺序与前向传播中执行的顺序相反,因为我们需要从计算图的结果开始,并朝着参数的方向努力。第一步是计算目标函数$J=L+s$相对于损失项$L$和正则项$s$的梯度。\n",
"\n",
"$$\\frac{\\partial J}{\\partial L} = 1 \\; \\text{and} \\; \\frac{\\partial J}{\\partial s} = 1.$$\n",
"\n",
"接下来,我们根据链式法则计算目标函数关于输出层变量$\\mathbf{o}$的梯度:\n",
"\n",
"$$\n",
"\\frac{\\partial J}{\\partial \\mathbf{o}}\n",
"= \\text{prod}\\left(\\frac{\\partial J}{\\partial L}, \\frac{\\partial L}{\\partial \\mathbf{o}}\\right)\n",
"= \\frac{\\partial L}{\\partial \\mathbf{o}}\n",
"\\in \\mathbb{R}^q.\n",
"$$\n",
"\n",
"接下来,我们计算正则化项相对于两个参数的梯度:\n",
"\n",
"$$\\frac{\\partial s}{\\partial \\mathbf{W}^{(1)}} = \\lambda \\mathbf{W}^{(1)}\n",
"\\; \\text{and} \\;\n",
"\\frac{\\partial s}{\\partial \\mathbf{W}^{(2)}} = \\lambda \\mathbf{W}^{(2)}.$$\n",
"\n",
"现在我们可以计算最接近输出层的模型参数的梯度\n",
"$\\partial J/\\partial \\mathbf{W}^{(2)} \\in \\mathbb{R}^{q \\times h}$。\n",
"使用链式法则得出:\n",
"\n",
"$$\\frac{\\partial J}{\\partial \\mathbf{W}^{(2)}}= \\text{prod}\\left(\\frac{\\partial J}{\\partial \\mathbf{o}}, \\frac{\\partial \\mathbf{o}}{\\partial \\mathbf{W}^{(2)}}\\right) + \\text{prod}\\left(\\frac{\\partial J}{\\partial s}, \\frac{\\partial s}{\\partial \\mathbf{W}^{(2)}}\\right)= \\frac{\\partial J}{\\partial \\mathbf{o}} \\mathbf{h}^\\top + \\lambda \\mathbf{W}^{(2)}.$$\n",
":eqlabel:`eq_backprop-J-h`\n",
"\n",
"为了获得关于$\\mathbf{W}^{(1)}$的梯度,我们需要继续沿着输出层到隐藏层反向传播。\n",
"关于隐藏层输出的梯度$\\partial J/\\partial \\mathbf{h} \\in \\mathbb{R}^h$由下式给出:\n",
"\n",
"$$\n",
"\\frac{\\partial J}{\\partial \\mathbf{h}}\n",
"= \\text{prod}\\left(\\frac{\\partial J}{\\partial \\mathbf{o}}, \\frac{\\partial \\mathbf{o}}{\\partial \\mathbf{h}}\\right)\n",
"= {\\mathbf{W}^{(2)}}^\\top \\frac{\\partial J}{\\partial \\mathbf{o}}.\n",
"$$\n",
"\n",
"由于激活函数$\\phi$是按元素计算的,\n",
"计算中间变量$\\mathbf{z}$的梯度$\\partial J/\\partial \\mathbf{z} \\in \\mathbb{R}^h$\n",
"需要使用按元素乘法运算符,我们用$\\odot$表示:\n",
"\n",
"$$\n",
"\\frac{\\partial J}{\\partial \\mathbf{z}}\n",
"= \\text{prod}\\left(\\frac{\\partial J}{\\partial \\mathbf{h}}, \\frac{\\partial \\mathbf{h}}{\\partial \\mathbf{z}}\\right)\n",
"= \\frac{\\partial J}{\\partial \\mathbf{h}} \\odot \\phi'\\left(\\mathbf{z}\\right).\n",
"$$\n",
"\n",
"最后,我们可以得到最接近输入层的模型参数的梯度\n",
"$\\partial J/\\partial \\mathbf{W}^{(1)} \\in \\mathbb{R}^{h \\times d}$。\n",
"根据链式法则,我们得到:\n",
"\n",
"$$\n",
"\\frac{\\partial J}{\\partial \\mathbf{W}^{(1)}}\n",
"= \\text{prod}\\left(\\frac{\\partial J}{\\partial \\mathbf{z}}, \\frac{\\partial \\mathbf{z}}{\\partial \\mathbf{W}^{(1)}}\\right) + \\text{prod}\\left(\\frac{\\partial J}{\\partial s}, \\frac{\\partial s}{\\partial \\mathbf{W}^{(1)}}\\right)\n",
"= \\frac{\\partial J}{\\partial \\mathbf{z}} \\mathbf{x}^\\top + \\lambda \\mathbf{W}^{(1)}.\n",
"$$\n",
"\n",
"## 训练神经网络\n",
"\n",
"在训练神经网络时,前向传播和反向传播相互依赖。\n",
"对于前向传播,我们沿着依赖的方向遍历计算图并计算其路径上的所有变量。\n",
"然后将这些用于反向传播,其中计算顺序与计算图的相反。\n",
"\n",
"以上述简单网络为例:一方面,在前向传播期间计算正则项\n",
" :eqref:`eq_forward-s`取决于模型参数$\\mathbf{W}^{(1)}$和\n",
"$\\mathbf{W}^{(2)}$的当前值。\n",
"它们是由优化算法根据最近迭代的反向传播给出的。\n",
"另一方面,反向传播期间参数 :eqref:`eq_backprop-J-h`的梯度计算,\n",
"取决于由前向传播给出的隐藏变量$\\mathbf{h}$的当前值。\n",
"\n",
"因此,在训练神经网络时,在初始化模型参数后,\n",
"我们交替使用前向传播和反向传播,利用反向传播给出的梯度来更新模型参数。\n",
"注意,反向传播重复利用前向传播中存储的中间值,以避免重复计算。\n",
"带来的影响之一是我们需要保留中间值,直到反向传播完成。\n",
"这也是训练比单纯的预测需要更多的内存(显存)的原因之一。\n",
"此外,这些中间值的大小与网络层的数量和批量的大小大致成正比。\n",
"因此,使用更大的批量来训练更深层次的网络更容易导致*内存不足*(out of memory)错误。\n",
"\n",
"## 小结\n",
"\n",
"* 前向传播在神经网络定义的计算图中按顺序计算和存储中间变量,它的顺序是从输入层到输出层。\n",
"* 反向传播按相反的顺序(从输出层到输入层)计算和存储神经网络的中间变量和参数的梯度。\n",
"* 在训练深度学习模型时,前向传播和反向传播是相互依赖的。\n",
"* 训练比预测需要更多的内存。\n",
"\n",
"## 练习\n",
"\n",
"1. 假设一些标量函数$\\mathbf{X}$的输入$\\mathbf{X}$是$n \\times m$矩阵。$f$相对于$\\mathbf{X}$的梯度维数是多少?\n",
"1. 向本节中描述的模型的隐藏层添加偏置项(不需要在正则化项中包含偏置项)。\n",
" 1. 画出相应的计算图。\n",
" 1. 推导正向和反向传播方程。\n",
"1. 计算本节所描述的模型,用于训练和预测的内存占用。\n",
"1. 假设想计算二阶导数。计算图发生了什么?预计计算需要多长时间?\n",
"1. 假设计算图对当前拥有的GPU来说太大了。\n",
" 1. 请试着把它划分到多个GPU上。\n",
" 1. 与小批量训练相比,有哪些优点和缺点?\n",
"\n",
"[Discussions](https://discuss.d2l.ai/t/5769)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"required_libs": []
},
"nbformat": 4,
"nbformat_minor": 5
}
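The forward and backward equations in the backprop notebook above can be checked numerically. What follows is a minimal NumPy sketch and not part of the original notebook: it assumes a tanh activation for $\phi$ and a squared-error loss for $l$ (the text leaves both unspecified), and the helper names `forward`, `backward`, and `numerical_grad` are hypothetical. It compares the analytic gradients from the chain-rule derivation with central finite differences.

```python
import numpy as np

def forward(W1, W2, x, y, lam):
    """Forward pass: returns the objective J and the cached intermediates."""
    z = W1 @ x                       # z = W1 x, shape (h,)
    h = np.tanh(z)                   # h = phi(z); tanh is an assumed choice of phi
    o = W2 @ h                       # o = W2 h, shape (q,)
    L = 0.5 * np.sum((o - y) ** 2)   # assumed loss l(o, y): squared error
    s = 0.5 * lam * (np.sum(W1 ** 2) + np.sum(W2 ** 2))  # L2 penalty
    return L + s, (z, h, o)

def backward(W1, W2, x, y, lam, cache):
    """Backward pass: applies the chain rule from J back to W2 and W1."""
    z, h, o = cache
    dJ_do = o - y                            # dL/do for the squared error
    dJ_dW2 = np.outer(dJ_do, h) + lam * W2   # (dJ/do) h^T + lambda W2
    dJ_dh = W2.T @ dJ_do                     # W2^T (dJ/do)
    dJ_dz = dJ_dh * (1.0 - h ** 2)           # elementwise product with phi'(z)
    dJ_dW1 = np.outer(dJ_dz, x) + lam * W1   # (dJ/dz) x^T + lambda W1
    return dJ_dW1, dJ_dW2

def numerical_grad(objective, W, eps=1e-6):
    """Central finite differences of objective() with respect to entries of W."""
    grad = np.zeros_like(W)
    for idx in np.ndindex(*W.shape):
        old = W[idx]
        W[idx] = old + eps
        J_plus = objective()
        W[idx] = old - eps
        J_minus = objective()
        W[idx] = old                         # restore the entry
        grad[idx] = (J_plus - J_minus) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
d, h_dim, q, lam = 4, 3, 2, 0.1              # illustrative sizes only
W1, W2 = rng.normal(size=(h_dim, d)), rng.normal(size=(q, h_dim))
x, y = rng.normal(size=d), rng.normal(size=q)

J, cache = forward(W1, W2, x, y, lam)
dW1, dW2 = backward(W1, W2, x, y, lam, cache)
objective = lambda: forward(W1, W2, x, y, lam)[0]
print(np.allclose(dW1, numerical_grad(objective, W1), atol=1e-5))
print(np.allclose(dW2, numerical_grad(objective, W2), atol=1e-5))
```

Note that `backward` reuses `z`, `h`, and `o` from the cache rather than recomputing them, which is the memory trade-off the notebook describes: the intermediates must stay alive until the backward pass has finished.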
File diff suppressed because it is too large
@@ -0,0 +1,498 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "1658580a",
"metadata": {
"origin_pos": 0
},
"source": [
"# 环境和分布偏移\n",
"\n",
"前面我们学习了许多机器学习的实际应用,将模型拟合各种数据集。\n",
"然而,我们从来没有想过数据最初从哪里来?以及我们计划最终如何处理模型的输出?\n",
"通常情况下,开发人员会拥有一些数据且急于开发模型,而不关注这些基本问题。\n",
"\n",
"许多失败的机器学习部署(即实际应用)都可以追究到这种方式。\n",
"有时,根据测试集的精度衡量,模型表现得非常出色。\n",
"但是当数据分布突然改变时,模型在部署中会出现灾难性的失败。\n",
"更隐蔽的是,有时模型的部署本身就是扰乱数据分布的催化剂。\n",
"举一个有点荒谬却可能真实存在的例子。\n",
"假设我们训练了一个贷款申请人违约风险模型,用来预测谁将偿还贷款或违约。\n",
"这个模型发现申请人的鞋子与违约风险相关(穿牛津鞋申请人会偿还,穿运动鞋申请人会违约)。\n",
"此后,这个模型可能倾向于向所有穿着牛津鞋的申请人发放贷款,并拒绝所有穿着运动鞋的申请人。\n",
"\n",
"这种情况可能会带来灾难性的后果。\n",
"首先,一旦模型开始根据鞋类做出决定,顾客就会理解并改变他们的行为。\n",
"不久,所有的申请者都会穿牛津鞋,而信用度却没有相应的提高。\n",
"总而言之,机器学习的许多应用中都存在类似的问题:\n",
"通过将基于模型的决策引入环境,我们可能会破坏模型。\n",
"\n",
"虽然我们不可能在一节中讨论全部的问题,但我们希望揭示一些常见的问题,\n",
"并激发批判性思考,以便及早发现这些情况,减轻灾难性的损害。\n",
"有些解决方案很简单(要求“正确”的数据),有些在技术上很困难(实施强化学习系统),\n",
"还有一些解决方案要求我们完全跳出统计预测,解决一些棘手的、与算法伦理应用有关的哲学问题。\n",
"\n",
"## 分布偏移的类型\n",
"\n",
"首先,我们考虑数据分布可能发生变化的各种方式,以及为挽救模型性能可能采取的措施。\n",
"在一个经典的情景中,假设训练数据是从某个分布$p_S(\\mathbf{x},y)$中采样的,\n",
"但是测试数据将包含从不同分布$p_T(\\mathbf{x},y)$中抽取的未标记样本。\n",
"一个清醒的现实是:如果没有任何关于$p_S$和$p_T$之间相互关系的假设,\n",
"学习到一个分类器是不可能的。\n",
"\n",
"考虑一个二元分类问题:区分狗和猫。\n",
"如果分布可以以任意方式偏移,那么我们的情景允许病态的情况,\n",
"即输入的分布保持不变:$p_S(\\mathbf{x}) = p_T(\\mathbf{x})$\n",
"但标签全部翻转:$p_S(y | \\mathbf{x}) = 1 - p_T(y | \\mathbf{x})$。\n",
"换言之,如果将来所有的“猫”现在都是狗,而我们以前所说的“狗”现在是猫。\n",
"而此时输入$p(\\mathbf{x})$的分布没有任何改变,\n",
"那么我们就不可能将这种情景与分布完全没有变化的情景区分开。\n",
"\n",
"幸运的是,在对未来我们的数据可能发生变化的一些限制性假设下,\n",
"有些算法可以检测这种偏移,甚至可以动态调整,提高原始分类器的精度。\n",
"\n",
"### 协变量偏移\n",
"\n",
"在不同分布偏移中,协变量偏移可能是最为广泛研究的。\n",
"这里我们假设:虽然输入的分布可能随时间而改变,\n",
"但标签函数(即条件分布$P(y \\mid \\mathbf{x})$)没有改变。\n",
"统计学家称之为*协变量偏移*covariate shift),\n",
"因为这个问题是由于协变量(特征)分布的变化而产生的。\n",
"虽然有时我们可以在不引用因果关系的情况下对分布偏移进行推断,\n",
"但在我们认为$\\mathbf{x}$导致$y$的情况下,协变量偏移是一种自然假设。\n",
"\n",
"考虑一下区分猫和狗的问题:训练数据包括 :numref:`fig_cat-dog-train`中的图像。\n",
"\n",
"![区分猫和狗的训练数据](../img/cat-dog-train.svg)\n",
":label:`fig_cat-dog-train`\n",
"\n",
"在测试时,我们被要求对 :numref:`fig_cat-dog-test`中的图像进行分类。\n",
"\n",
"![区分猫和狗的测试数据](../img/cat-dog-test.svg)\n",
":label:`fig_cat-dog-test`\n",
"\n",
"训练集由真实照片组成,而测试集只包含卡通图片。\n",
"假设在一个与测试集的特征有着本质不同的数据集上进行训练,\n",
"如果没有方法来适应新的领域,可能会有麻烦。\n",
"\n",
"### 标签偏移\n",
"\n",
"*标签偏移*label shift)描述了与协变量偏移相反的问题。\n",
"这里我们假设标签边缘概率$P(y)$可以改变,\n",
"但是类别条件分布$P(\\mathbf{x} \\mid y)$在不同的领域之间保持不变。\n",
"当我们认为$y$导致$\\mathbf{x}$时,标签偏移是一个合理的假设。\n",
"例如,预测患者的疾病,我们可能根据症状来判断,\n",
"即使疾病的相对流行率随着时间的推移而变化。\n",
"标签偏移在这里是恰当的假设,因为疾病会引起症状。\n",
"在另一些情况下,标签偏移和协变量偏移假设可以同时成立。\n",
"例如,当标签是确定的,即使$y$导致$\\mathbf{x}$,协变量偏移假设也会得到满足。\n",
"有趣的是,在这些情况下,使用基于标签偏移假设的方法通常是有利的。\n",
"这是因为这些方法倾向于包含看起来像标签(通常是低维)的对象,\n",
"而不是像输入(通常是高维的)对象。\n",
"\n",
"### 概念偏移\n",
"\n",
"我们也可能会遇到*概念偏移*concept shift):\n",
"当标签的定义发生变化时,就会出现这种问题。\n",
"这听起来很奇怪——一只猫就是一只猫,不是吗?\n",
"然而,其他类别会随着不同时间的用法而发生变化。\n",
"精神疾病的诊断标准、所谓的时髦、以及工作头衔等等,都是概念偏移的日常映射。\n",
"事实证明,假如我们环游美国,根据所在的地理位置改变我们的数据来源,\n",
"我们会发现关于“软饮”名称的分布发生了相当大的概念偏移,\n",
"如 :numref:`fig_popvssoda` 所示。\n",
"\n",
"![美国软饮名称的概念偏移](../img/popvssoda.png)\n",
":width:`400px`\n",
":label:`fig_popvssoda`\n",
"\n",
"如果我们要建立一个机器翻译系统,\n",
"$P(y \\mid \\mathbf{x})$的分布可能会因我们的位置不同而得到不同的翻译。\n",
"这个问题可能很难被发现。\n",
"所以,我们最好可以利用在时间或空间上逐渐发生偏移的知识。\n",
"\n",
"## 分布偏移示例\n",
"\n",
"在深入研究形式体系和算法之前,我们可以讨论一些协变量偏移或概念偏移可能并不明显的具体情况。\n",
"\n",
"### 医学诊断\n",
"\n",
"假设我们想设计一个检测癌症的算法,从健康人和病人那里收集数据,然后训练算法。\n",
"它工作得很好,有很高的精度,然后我们得出了已经准备好在医疗诊断上取得成功的结论。\n",
"请先别着急。\n",
"\n",
"收集训练数据的分布和在实际中遇到的数据分布可能有很大的不同。\n",
"这件事在一个不幸的初创公司身上发生过,我们中的一些作者几年前和他们合作过。\n",
"他们正在研究一种血液检测方法,主要针对一种影响老年男性的疾病,\n",
"并希望利用他们从病人身上采集的血液样本进行研究。\n",
"然而,从健康男性身上获取血样比从系统中已有的病人身上获取要困难得多。\n",
"作为补偿,这家初创公司向一所大学校园内的学生征集献血,作为开发测试的健康对照样本。\n",
"然后这家初创公司问我们是否可以帮助他们建立一个用于检测疾病的分类器。\n",
"\n",
"正如我们向他们解释的那样,用近乎完美的精度来区分健康和患病人群确实很容易。\n",
"然而,这可能是因为受试者在年龄、激素水平、体力活动、\n",
"饮食、饮酒以及其他许多与疾病无关的因素上存在差异。\n",
"这对检测疾病的分类器可能并不适用。\n",
"这些抽样可能会遇到极端的协变量偏移。\n",
"此外,这种情况不太可能通过常规方法加以纠正。\n",
"简言之,他们浪费了一大笔钱。\n",
"\n",
"### 自动驾驶汽车\n",
"\n",
"对于一家想利用机器学习来开发自动驾驶汽车的公司,一个关键部件是“路沿检测器”。\n",
"由于真实的注释数据获取成本很高,他们想出了一个“聪明”的想法:\n",
"将游戏渲染引擎中的合成数据用作额外的训练数据。\n",
"这对从渲染引擎中抽取的“测试数据”非常有效,但应用在一辆真正的汽车里真是一场灾难。\n",
"正如事实证明的那样,路沿被渲染成一种非常简单的纹理。\n",
"更重要的是,所有的路沿都被渲染成了相同的纹理,路沿检测器很快就学习到了这个“特征”。\n",
"\n",
"当美军第一次试图在森林中探测坦克时,也发生了类似的事情。\n",
"他们在没有坦克的情况下拍摄了森林的航拍照片,然后把坦克开进森林,拍摄了另一组照片。\n",
"使用这两组数据训练的分类器似乎工作得很好。\n",
"不幸的是,分类器仅仅学会了如何区分有阴影的树和没有阴影的树:\n",
"第一组照片是在清晨拍摄的,而第二组是在中午拍摄的。\n",
"\n",
"### 非平稳分布\n",
"\n",
"当分布变化缓慢并且模型没有得到充分更新时,就会出现更微妙的情况:\n",
"*非平稳分布*nonstationary distribution)。\n",
"以下是一些典型例子:\n",
"\n",
"* 训练一个计算广告模型,但却没有经常更新(例如,一个2009年训练的模型不知道一个叫iPad的不知名新设备刚刚上市);\n",
"* 建立一个垃圾邮件过滤器,它能很好地检测到所有垃圾邮件。但是,垃圾邮件发送者们变得聪明起来,制造出新的信息,看起来不像我们以前见过的任何垃圾邮件;\n",
"* 建立一个产品推荐系统,它在整个冬天都有效,但圣诞节过后很久还会继续推荐圣诞帽。\n",
"\n",
"### 更多轶事\n",
"\n",
"* 建立一个人脸检测器,它在所有基准测试中都能很好地工作,但是它在测试数据上失败了:有问题的例子是人脸充满了整个图像的特写镜头(训练集中没有这样的数据)。\n",
"* 为美国市场建立了一个网络搜索引擎,并希望将其部署到英国。\n",
"* 通过在一个大的数据集来训练图像分类器,其中每一个大类的数量在数据集近乎是平均的,比如1000个类别,每个类别由1000个图像表示。但是将该系统部署到真实世界中,照片的实际标签分布显然是不均匀的。\n",
"\n",
"## 分布偏移纠正\n",
"\n",
"正如我们所讨论的,在许多情况下训练和测试分布$P(\\mathbf{x}, y)$是不同的。\n",
"在一些情况下,我们很幸运,不管协变量、标签或概念如何发生偏移,模型都能正常工作。\n",
"在另一些情况下,我们可以通过运用策略来应对这种偏移,从而做得更好。\n",
"本节的其余部分将着重于应对这种偏移的技术细节。\n",
"\n",
"### 经验风险与实际风险\n",
":label:`subsec_empirical-risk-and-risk`\n",
"\n",
"首先我们反思一下在模型训练期间到底发生了什么?\n",
"训练数据$\\{(\\mathbf{x}_1, y_1), \\ldots, (\\mathbf{x}_n, y_n)\\}$\n",
"的特征和相关的标签经过迭代,在每一个小批量之后更新模型$f$的参数。\n",
"为了简单起见,我们不考虑正则化,因此极大地降低了训练损失:\n",
"\n",
"$$\\mathop{\\mathrm{minimize}}_f \\frac{1}{n} \\sum_{i=1}^n l(f(\\mathbf{x}_i), y_i),$$\n",
":eqlabel:`eq_empirical-risk-min`\n",
"\n",
"其中$l$是损失函数,用来度量:\n",
"给定标签$y_i$,预测$f(\\mathbf{x}_i)$的“糟糕程度”。\n",
"统计学家称 :eqref:`eq_empirical-risk-min`中的这一项为经验风险。\n",
"*经验风险*empirical risk)是为了近似 *真实风险*true risk),\n",
"整个训练数据上的平均损失,即从其真实分布$p(\\mathbf{x},y)$中\n",
"抽取的所有数据的总体损失的期望值:\n",
"\n",
"$$E_{p(\\mathbf{x}, y)} [l(f(\\mathbf{x}), y)] = \\int\\int l(f(\\mathbf{x}), y) p(\\mathbf{x}, y) \\;d\\mathbf{x}dy.$$\n",
":eqlabel:`eq_true-risk`\n",
"\n",
"然而在实践中,我们通常无法获得总体数据。\n",
"因此,*经验风险最小化*即在 :eqref:`eq_empirical-risk-min`中最小化经验风险,\n",
"是一种实用的机器学习策略,希望能近似最小化真实风险。\n",
"\n",
"### 协变量偏移纠正\n",
":label:`subsec_covariate-shift-correction`\n",
"\n",
"假设对于带标签的数据$(\\mathbf{x}_i, y_i)$\n",
"我们要评估$P(y \\mid \\mathbf{x})$。\n",
"然而观测值$\\mathbf{x}_i$是从某些*源分布*$q(\\mathbf{x})$中得出的,\n",
"而不是从*目标分布*$p(\\mathbf{x})$中得出的。\n",
"幸运的是,依赖性假设意味着条件分布保持不变,即:\n",
"$p(y \\mid \\mathbf{x}) = q(y \\mid \\mathbf{x})$。\n",
"如果源分布$q(\\mathbf{x})$是“错误的”,\n",
"我们可以通过在真实风险的计算中,使用以下简单的恒等式来进行纠正:\n",
"\n",
"$$\n",
"\\begin{aligned}\n",
"\\int\\int l(f(\\mathbf{x}), y) p(y \\mid \\mathbf{x})p(\\mathbf{x}) \\;d\\mathbf{x}dy =\n",
"\\int\\int l(f(\\mathbf{x}), y) q(y \\mid \\mathbf{x})q(\\mathbf{x})\\frac{p(\\mathbf{x})}{q(\\mathbf{x})} \\;d\\mathbf{x}dy.\n",
"\\end{aligned}\n",
"$$\n",
"\n",
"换句话说,我们需要根据数据来自正确分布与来自错误分布的概率之比,\n",
"来重新衡量每个数据样本的权重:\n",
"\n",
"$$\\beta_i \\stackrel{\\mathrm{def}}{=} \\frac{p(\\mathbf{x}_i)}{q(\\mathbf{x}_i)}.$$\n",
"\n",
"将权重$\\beta_i$代入到每个数据样本$(\\mathbf{x}_i, y_i)$中,\n",
"我们可以使用”加权经验风险最小化“来训练模型:\n",
"\n",
"$$\\mathop{\\mathrm{minimize}}_f \\frac{1}{n} \\sum_{i=1}^n \\beta_i l(f(\\mathbf{x}_i), y_i).$$\n",
":eqlabel:`eq_weighted-empirical-risk-min`\n",
"\n",
"由于不知道这个比率,我们需要估计它。\n",
"有许多方法都可以用,包括一些花哨的算子理论方法,\n",
"试图直接使用最小范数或最大熵原理重新校准期望算子。\n",
"对于任意一种这样的方法,我们都需要从两个分布中抽取样本:\n",
"“真实”的分布$p$,通过访问测试数据获取;\n",
"训练集$q$,通过人工合成的很容易获得。\n",
"请注意,我们只需要特征$\\mathbf{x} \\sim p(\\mathbf{x})$\n",
"不需要访问标签$y \\sim p(y)$。\n",
"\n",
"在这种情况下,有一种非常有效的方法可以得到几乎与原始方法一样好的结果:\n",
"*对数几率回归*logistic regression)。\n",
"这是用于二元分类的softmax回归(见 :numref:`sec_softmax`)的一个特例。\n",
"综上所述,我们学习了一个分类器来区分从$p(\\mathbf{x})$抽取的数据\n",
"和从$q(\\mathbf{x})$抽取的数据。\n",
"如果无法区分这两个分布,则意味着相关的样本可能来自这两个分布中的任何一个。\n",
"另一方面,任何可以很好区分的样本都应该相应地显著增加或减少权重。\n",
"\n",
"为了简单起见,假设我们分别从$p(\\mathbf{x})$和$q(\\mathbf{x})$\n",
"两个分布中抽取相同数量的样本。\n",
"现在用$z$标签表示:从$p$抽取的数据为$1$,从$q$抽取的数据为$-1$。\n",
"然后,混合数据集中的概率由下式给出\n",
"\n",
"$$P(z=1 \\mid \\mathbf{x}) = \\frac{p(\\mathbf{x})}{p(\\mathbf{x})+q(\\mathbf{x})} \\text{ and hence } \\frac{P(z=1 \\mid \\mathbf{x})}{P(z=-1 \\mid \\mathbf{x})} = \\frac{p(\\mathbf{x})}{q(\\mathbf{x})}.$$\n",
"\n",
"因此,如果我们使用对数几率回归方法,其中\n",
"$P(z=1 \\mid \\mathbf{x})=\\frac{1}{1+\\exp(-h(\\mathbf{x}))}$\n",
"($h$是一个参数化函数),则很自然有:\n",
"\n",
"$$\n",
"\\beta_i = \\frac{1/(1 + \\exp(-h(\\mathbf{x}_i)))}{\\exp(-h(\\mathbf{x}_i))/(1 + \\exp(-h(\\mathbf{x}_i)))} = \\exp(h(\\mathbf{x}_i)).\n",
"$$\n",
"\n",
"因此,我们需要解决两个问题:\n",
"第一个问题是关于区分来自两个分布的数据;\n",
"第二个问题是关于 :eqref:`eq_weighted-empirical-risk-min`\n",
"中的加权经验风险的最小化问题。\n",
"在这个问题中,我们将对其中的项加权$\\beta_i$。\n",
"\n",
"现在,我们来看一下完整的协变量偏移纠正算法。\n",
"假设我们有一个训练集$\\{(\\mathbf{x}_1, y_1), \\ldots, (\\mathbf{x}_n, y_n)\\}$\n",
"和一个未标记的测试集$\\{\\mathbf{u}_1, \\ldots, \\mathbf{u}_m\\}$。\n",
"对于协变量偏移,我们假设$1 \\leq i \\leq n$的$\\mathbf{x}_i$来自某个源分布,\n",
"$\\mathbf{u}_i$来自目标分布。\n",
"以下是纠正协变量偏移的典型算法:\n",
"\n",
"1. 生成一个二元分类训练集:$\\{(\\mathbf{x}_1, -1), \\ldots, (\\mathbf{x}_n, -1), (\\mathbf{u}_1, 1), \\ldots, (\\mathbf{u}_m, 1)\\}$。\n",
"1. 用对数几率回归训练二元分类器得到函数$h$。\n",
"1. 使用$\\beta_i = \\exp(h(\\mathbf{x}_i))$或更好的$\\beta_i = \\min(\\exp(h(\\mathbf{x}_i)), c)$$c$为常量)对训练数据进行加权。\n",
"1. 使用权重$\\beta_i$进行 :eqref:`eq_weighted-empirical-risk-min` 中$\\{(\\mathbf{x}_1, y_1), \\ldots, (\\mathbf{x}_n, y_n)\\}$的训练。\n",
"\n",
"请注意,上述算法依赖于一个重要的假设:\n",
"需要目标分布(例如,测试分布)中的每个数据样本在训练时出现的概率非零。\n",
"如果我们找到$p(\\mathbf{x}) > 0$但$q(\\mathbf{x}) = 0$的点,\n",
"那么相应的重要性权重会是无穷大。\n",
"\n",
"### 标签偏移纠正\n",
"\n",
"假设我们处理的是$k$个类别的分类任务。\n",
"使用 :numref:`subsec_covariate-shift-correction`中相同符号,\n",
"$q$和$p$中分别是源分布(例如训练时的分布)和目标分布(例如测试时的分布)。\n",
"假设标签的分布随时间变化:$q(y) \\neq p(y)$\n",
"但类别条件分布保持不变:$q(\\mathbf{x} \\mid y)=p(\\mathbf{x} \\mid y)$。\n",
"如果源分布$q(y)$是“错误的”,\n",
"我们可以根据 :eqref:`eq_true-risk`中定义的真实风险中的恒等式进行更正:\n",
"\n",
"$$\n",
"\\begin{aligned}\n",
"\\int\\int l(f(\\mathbf{x}), y) p(\\mathbf{x} \\mid y)p(y) \\;d\\mathbf{x}dy =\n",
"\\int\\int l(f(\\mathbf{x}), y) q(\\mathbf{x} \\mid y)q(y)\\frac{p(y)}{q(y)} \\;d\\mathbf{x}dy.\n",
"\\end{aligned}\n",
"$$\n",
"\n",
"这里,重要性权重将对应于标签似然比率\n",
"\n",
"$$\\beta_i \\stackrel{\\mathrm{def}}{=} \\frac{p(y_i)}{q(y_i)}.$$\n",
"\n",
"标签偏移的一个好处是,如果我们在源分布上有一个相当好的模型,\n",
"那么我们可以得到对这些权重的一致估计,而不需要处理周边的其他维度。\n",
"在深度学习中,输入往往是高维对象(如图像),而标签通常是低维(如类别)。\n",
"\n",
"为了估计目标标签分布,我们首先采用性能相当好的现成的分类器(通常基于训练数据进行训练),\n",
"并使用验证集(也来自训练分布)计算其混淆矩阵。\n",
"混淆矩阵$\\mathbf{C}$是一个$k \\times k$矩阵,\n",
"其中每列对应于标签类别,每行对应于模型的预测类别。\n",
"每个单元格的值$c_{ij}$是验证集中,真实标签为$j$\n",
"而我们的模型预测为$i$的样本数量所占的比例。\n",
"\n",
"现在,我们不能直接计算目标数据上的混淆矩阵,\n",
"因为我们无法看到真实环境下的样本的标签,\n",
"除非我们再搭建一个复杂的实时标注流程。\n",
"然而,我们所能做的是将所有模型在测试时的预测取平均数,\n",
"得到平均模型输出$\\mu(\\hat{\\mathbf{y}}) \\in \\mathbb{R}^k$\n",
"其中第$i$个元素$\\mu(\\hat{y}_i)$是我们模型预测测试集中$i$的总预测分数。\n",
"\n",
"结果表明,如果我们的分类器一开始就相当准确,\n",
"并且目标数据只包含我们以前见过的类别,\n",
"以及如果标签偏移假设成立(这里最强的假设),\n",
"我们就可以通过求解一个简单的线性系统来估计测试集的标签分布\n",
"\n",
"$$\\mathbf{C} p(\\mathbf{y}) = \\mu(\\hat{\\mathbf{y}}),$$\n",
"\n",
"因为作为一个估计,$\\sum_{j=1}^k c_{ij} p(y_j) = \\mu(\\hat{y}_i)$\n",
"对所有$1 \\leq i \\leq k$成立,\n",
"其中$p(y_j)$是$k$维标签分布向量$p(\\mathbf{y})$的第$j^\\mathrm{th}$元素。\n",
"如果我们的分类器一开始就足够精确,那么混淆矩阵$\\mathbf{C}$将是可逆的,\n",
"进而我们可以得到一个解$p(\\mathbf{y}) = \\mathbf{C}^{-1} \\mu(\\hat{\\mathbf{y}})$。\n",
"\n",
"因为我们观测源数据上的标签,所以很容易估计分布$q(y)$。\n",
"那么对于标签为$y_i$的任何训练样本$i$\n",
"我们可以使用我们估计的$p(y_i)/q(y_i)$比率来计算权重$\\beta_i$\n",
"并将其代入 :eqref:`eq_weighted-empirical-risk-min`中的加权经验风险最小化中。\n",
"\n",
"### 概念偏移纠正\n",
"\n",
"概念偏移很难用原则性的方式解决。\n",
"例如,在一个问题突然从“区分猫和狗”偏移为“区分白色和黑色动物”的情况下,\n",
"除了从零开始收集新标签和训练,别无妙方。\n",
"幸运的是,在实践中这种极端的偏移是罕见的。\n",
"相反,通常情况下,概念的变化总是缓慢的。\n",
"比如下面是一些例子:\n",
"\n",
"* 在计算广告中,新产品推出后,旧产品变得不那么受欢迎了。这意味着广告的分布和受欢迎程度是逐渐变化的,任何点击率预测器都需要随之逐渐变化;\n",
"* 由于环境的磨损,交通摄像头的镜头会逐渐退化,影响摄像头的图像质量;\n",
"* 新闻内容逐渐变化(即新新闻的出现)。\n",
"\n",
"在这种情况下,我们可以使用与训练网络相同的方法,使其适应数据的变化。\n",
"换言之,我们使用新数据更新现有的网络权重,而不是从头开始训练。\n",
"\n",
"## 学习问题的分类法\n",
"\n",
"有了如何处理分布变化的知识,我们现在可以考虑机器学习问题形式化的其他方面。\n",
"\n",
"### 批量学习\n",
"\n",
"在*批量学习*batch learning)中,我们可以访问一组训练特征和标签\n",
"$\\{(\\mathbf{x}_1, y_1), \\ldots, (\\mathbf{x}_n, y_n)\\}$\n",
"我们使用这些特性和标签训练$f(\\mathbf{x})$。\n",
"然后,我们部署此模型来对来自同一分布的新数据$(\\mathbf{x}, y)$进行评分。\n",
"例如,我们可以根据猫和狗的大量图片训练猫检测器。\n",
"一旦我们训练了它,我们就把它作为智能猫门计算视觉系统的一部分,来控制只允许猫进入。\n",
"然后这个系统会被安装在客户家中,基本再也不会更新。\n",
"\n",
"### 在线学习\n",
"\n",
"除了“批量”地学习,我们还可以单个“在线”学习数据$(\\mathbf{x}_i, y_i)$。\n",
"更具体地说,我们首先观测到$\\mathbf{x}_i$\n",
"然后我们得出一个估计值$f(\\mathbf{x}_i)$\n",
"只有当我们做到这一点后,我们才观测到$y_i$。\n",
"然后根据我们的决定,我们会得到奖励或损失。\n",
"许多实际问题都属于这一类。\n",
"例如,我们需要预测明天的股票价格,\n",
"这样我们就可以根据这个预测进行交易。\n",
"在一天结束时,我们会评估我们的预测是否盈利。\n",
"换句话说,在*在线学习*online learning)中,我们有以下的循环。\n",
"在这个循环中,给定新的观测结果,我们会不断地改进我们的模型。\n",
"\n",
"$$\n",
"\\mathrm{model} ~ f_t \\longrightarrow\n",
"\\mathrm{data} ~ \\mathbf{x}_t \\longrightarrow\n",
"\\mathrm{estimate} ~ f_t(\\mathbf{x}_t) \\longrightarrow\n",
"\\mathrm{observation} ~ y_t \\longrightarrow\n",
"\\mathrm{loss} ~ l(y_t, f_t(\\mathbf{x}_t)) \\longrightarrow\n",
"\\mathrm{model} ~ f_{t+1}\n",
"$$\n",
"\n",
"### 老虎机\n",
"\n",
"*老虎机*(bandits)是上述问题的一个特例。\n",
"虽然在大多数学习问题中,我们有一个连续参数化的函数$f$(例如,一个深度网络)。\n",
"但在一个*老虎机*问题中,我们只有有限数量的手臂可以拉动。\n",
"也就是说,我们可以采取的行动是有限的。\n",
"对于这个更简单的问题,可以获得更强的最优性理论保证,这并不令人惊讶。\n",
"我们之所以列出它,主要是因为这个问题经常被视为一个单独的学习问题的情景。\n",
"\n",
"### 控制\n",
"\n",
"在很多情况下,环境会记住我们所做的事。\n",
"不一定是以一种对抗的方式,但它会记住,而且它的反应将取决于之前发生的事情。\n",
"例如,咖啡锅炉控制器将根据之前是否加热锅炉来观测到不同的温度。\n",
"在这种情况下,PID(比例—积分—微分)控制器算法是一个流行的选择。\n",
"同样,一个用户在新闻网站上的行为将取决于之前向她展示的内容(例如,大多数新闻她只阅读一次)。\n",
"许多这样的算法形成了一个环境模型,在这个模型中,他们的行为使得他们的决策看起来不那么随机。\n",
"近年来,控制理论(如PID的变体)也被用于自动调整超参数,\n",
"以获得更好的解构和重建质量,提高生成文本的多样性和生成图像的重建质量\n",
" :cite:`Shao.Yao.Sun.ea.2020`。\n",
"\n",
"### 强化学习\n",
"\n",
"*强化学习*reinforcement learning)强调如何基于环境而行动,以取得最大化的预期利益。\n",
"国际象棋、围棋、西洋双陆棋或星际争霸都是强化学习的应用实例。\n",
"再比如,为自动驾驶汽车制造一个控制器,或者以其他方式对自动驾驶汽车的驾驶方式做出反应\n",
"(例如,试图避开某物体,试图造成事故,或者试图与其合作)。\n",
"\n",
"### 考虑到环境\n",
"\n",
"上述不同情况之间的一个关键区别是:\n",
"在静止环境中可能一直有效的相同策略,\n",
"在环境能够改变的情况下可能不会始终有效。\n",
"例如,一个交易者发现的套利机会很可能在他开始利用它时就消失了。\n",
"环境变化的速度和方式在很大程度上决定了我们可以采用的算法类型。\n",
"例如,如果我们知道事情只会缓慢地变化,\n",
"就可以迫使任何估计也只能缓慢地发生改变。\n",
"如果我们知道环境可能会瞬间发生变化,但这种变化非常罕见,\n",
"我们就可以在使用算法时考虑到这一点。\n",
"当一个数据科学家试图解决的问题会随着时间的推移而发生变化时,\n",
"这些类型的知识至关重要。\n",
"\n",
"## 机器学习中的公平、责任和透明度\n",
"\n",
"最后,重要的是,当我们部署机器学习系统时,\n",
"不仅仅是在优化一个预测模型,\n",
"而通常是在提供一个会被用来(部分或完全)进行自动化决策的工具。\n",
"这些技术系统可能会通过其进行的决定而影响到每个人的生活。\n",
"\n",
"从考虑预测到决策的飞跃不仅提出了新的技术问题,\n",
"而且还提出了一系列必须仔细考虑的伦理问题。\n",
"如果我们正在部署一个医疗诊断系统,我们需要知道它可能适用于哪些人群,哪些人群可能无效。\n",
"忽视对一个亚群体的幸福的可预见风险可能会导致我们执行劣质的护理水平。\n",
"此外,一旦我们规划整个决策系统,我们必须退后一步,重新考虑如何评估我们的技术。\n",
"在这个视野变化所导致的结果中,我们会发现精度很少成为合适的衡量标准。\n",
"例如,当我们将预测转化为行动时,我们通常会考虑到各种方式犯错的潜在成本敏感性。\n",
"举个例子:将图像错误地分到某一类别可能被视为种族歧视,而错误地分到另一个类别是无害的,\n",
"那么我们可能需要相应地调整我们的阈值,在设计决策方式时考虑到这些社会价值。\n",
"我们还需要注意预测系统如何导致反馈循环。\n",
"例如,考虑预测性警务系统,它将巡逻人员分配到预测犯罪率较高的地区。\n",
"很容易看出一种令人担忧的模式是如何出现的:\n",
"\n",
" 1. 犯罪率高的社区会得到更多的巡逻;\n",
" 2. 因此,在这些社区中会发现更多的犯罪行为,输入可用于未来迭代的训练数据;\n",
" 3. 面对更多的积极因素,该模型预测这些社区还会有更多的犯罪;\n",
" 4. 下一次迭代中,更新后的模型会更加倾向于针对同一个地区,这会导致更多的犯罪行为被发现等等。\n",
"\n",
"通常,在建模纠正过程中,模型的预测与训练数据耦合的各种机制都没有得到解释,\n",
"研究人员称之为“失控反馈循环”的现象。\n",
"此外,我们首先要注意我们是否解决了正确的问题。\n",
"比如,预测算法现在在信息传播中起着巨大的中介作用,\n",
"个人看到的新闻应该由他们喜欢的Facebook页面决定吗?\n",
"这些只是在机器学习职业生涯中可能遇到的令人感到“压力山大”的道德困境中的一小部分。\n",
"\n",
"## 小结\n",
"\n",
"* 在许多情况下,训练集和测试集并不来自同一个分布。这就是所谓的分布偏移。\n",
"* 真实风险是从真实分布中抽取的所有数据的总体损失的预期。然而,这个数据总体通常是无法获得的。经验风险是训练数据的平均损失,用于近似真实风险。在实践中,我们进行经验风险最小化。\n",
"* 在相应的假设条件下,可以在测试时检测并纠正协变量偏移和标签偏移。在测试时,不考虑这种偏移可能会成为问题。\n",
"* 在某些情况下,环境可能会记住自动操作并以令人惊讶的方式做出响应。在构建模型时,我们必须考虑到这种可能性,并继续监控实时系统,并对我们的模型和环境以意想不到的方式纠缠在一起的可能性持开放态度。\n",
"\n",
"## 练习\n",
"\n",
"1. 当我们改变搜索引擎的行为时会发生什么?用户可能会做什么?广告商呢?\n",
"2. 实现一个协变量偏移检测器。提示:构建一个分类器。\n",
"3. 实现协变量偏移纠正。\n",
"4. 除了分布偏移,还有什么会影响经验风险接近真实风险的程度?\n",
"\n",
"[Discussions](https://discuss.d2l.ai/t/1822)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"required_libs": []
},
"nbformat": 4,
"nbformat_minor": 5
}
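The four-step covariate shift correction algorithm listed in the environment notebook above can be sketched as follows. This is a minimal NumPy illustration and not part of the original notebook: it assumes a linear $h(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} + b$ fitted by plain gradient descent, equal numbers of source and target examples (as the derivation assumes), and synthetic one-dimensional Gaussian data. The names `fit_logistic`, `importance_weights`, and `weighted_empirical_risk`, and the clipping constant `c=10.0`, are hypothetical choices.

```python
import numpy as np

def fit_logistic(X, z, lr=0.1, steps=2000):
    """Fit h(x) = x @ w + b by gradient descent on the logistic loss, z in {-1, +1}."""
    n, d = X.shape
    w, b = np.zeros(d), 0.0
    for _ in range(steps):
        margin = z * (X @ w + b)
        g = -z / (1.0 + np.exp(margin))   # derivative of log(1 + exp(-z h)) w.r.t. h
        w -= lr * (X.T @ g) / n
        b -= lr * g.mean()
    return w, b

def importance_weights(X_src, X_tgt, c=10.0):
    """Steps 1-3: label source as -1 and target as +1, fit h, return min(exp(h(x_i)), c)."""
    X = np.vstack([X_src, X_tgt])
    z = np.concatenate([-np.ones(len(X_src)), np.ones(len(X_tgt))])
    w, b = fit_logistic(X, z)
    return np.minimum(np.exp(X_src @ w + b), c)

def weighted_empirical_risk(beta, losses):
    """Step 4 objective: (1/n) * sum_i beta_i * l(f(x_i), y_i)."""
    return np.mean(beta * losses)

rng = np.random.default_rng(0)
X_src = rng.normal(loc=0.0, scale=1.0, size=(500, 1))  # samples from q(x)
X_tgt = rng.normal(loc=1.0, scale=1.0, size=(500, 1))  # samples from p(x)
beta = importance_weights(X_src, X_tgt)

# Source examples that look more like the target should receive larger weights.
print(np.corrcoef(X_src[:, 0], beta)[0, 1] > 0)
```

In a training loop, `beta` would multiply the per-example losses before averaging, as in `weighted_empirical_risk`. If the two sample sizes differ, the estimated ratio is off by a constant factor, which this sketch does not correct for.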
@@ -0,0 +1,48 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "f0f1791b",
"metadata": {
"origin_pos": 0
},
"source": [
"# 多层感知机\n",
":label:`chap_perceptrons`\n",
"\n",
"在本章中,我们将第一次介绍真正的*深度*网络。\n",
"最简单的深度网络称为*多层感知机*。多层感知机由多层神经元组成,\n",
"每一层与它的上一层相连,从中接收输入;\n",
"同时每一层也与它的下一层相连,影响当前层的神经元。\n",
"当我们训练容量较大的模型时,我们面临着*过拟合*的风险。\n",
"因此,本章将从基本的概念介绍开始讲起,包括*过拟合*、*欠拟合*和模型选择。\n",
"为了解决这些问题,本章将介绍*权重衰减*和*暂退法*等正则化技术。\n",
"我们还将讨论数值稳定性和参数初始化相关的问题,\n",
"这些问题是成功训练深度网络的关键。\n",
"在本章的最后,我们将把所介绍的内容应用到一个真实的案例:房价预测。\n",
"关于模型计算性能、可伸缩性和效率相关的问题,我们将放在后面的章节中讨论。\n",
"\n",
":begin_tab:toc\n",
" - [mlp](mlp.ipynb)\n",
" - [mlp-scratch](mlp-scratch.ipynb)\n",
" - [mlp-concise](mlp-concise.ipynb)\n",
" - [underfit-overfit](underfit-overfit.ipynb)\n",
" - [weight-decay](weight-decay.ipynb)\n",
" - [dropout](dropout.ipynb)\n",
" - [backprop](backprop.ipynb)\n",
" - [numerical-stability-and-init](numerical-stability-and-init.ipynb)\n",
" - [environment](environment.ipynb)\n",
" - [kaggle-house-price](kaggle-house-price.ipynb)\n",
":end_tab:\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"required_libs": []
},
"nbformat": 4,
"nbformat_minor": 5
}
File diff suppressed because it is too large
@@ -0,0 +1,976 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "d5217b24",
"metadata": {
"origin_pos": 0
},
"source": [
"# 多层感知机的简洁实现\n",
":label:`sec_mlp_concise`\n",
"\n",
"本节将介绍(**通过高级API更简洁地实现多层感知机**)。\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f4b9d183",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:04:20.711610Z",
"iopub.status.busy": "2023-08-18T07:04:20.711337Z",
"iopub.status.idle": "2023-08-18T07:04:22.715766Z",
"shell.execute_reply": "2023-08-18T07:04:22.714884Z"
},
"origin_pos": 2,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "d1b8af0c",
"metadata": {
"origin_pos": 5
},
"source": [
"## 模型\n",
"\n",
"与softmax回归的简洁实现( :numref:`sec_softmax_concise`)相比,\n",
"唯一的区别是我们添加了2个全连接层(之前我们只添加了1个全连接层)。\n",
"第一层是[**隐藏层**],它(**包含256个隐藏单元,并使用了ReLU激活函数**)。\n",
"第二层是输出层。\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a11cfbe9",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:04:22.719981Z",
"iopub.status.busy": "2023-08-18T07:04:22.719298Z",
"iopub.status.idle": "2023-08-18T07:04:22.748628Z",
"shell.execute_reply": "2023-08-18T07:04:22.747813Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"net = nn.Sequential(nn.Flatten(),\n",
" nn.Linear(784, 256),\n",
" nn.ReLU(),\n",
" nn.Linear(256, 10))\n",
"\n",
"def init_weights(m):\n",
" if type(m) == nn.Linear:\n",
" nn.init.normal_(m.weight, std=0.01)\n",
"\n",
"net.apply(init_weights);"
]
},
{
"cell_type": "markdown",
"id": "f5aceed6",
"metadata": {
"origin_pos": 10
},
"source": [
"[**训练过程**]的实现与我们实现softmax回归时完全相同,\n",
"这种模块化设计使我们能够将与模型架构有关的内容独立出来。\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b23e8ab9",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:04:22.753701Z",
"iopub.status.busy": "2023-08-18T07:04:22.753406Z",
"iopub.status.idle": "2023-08-18T07:04:22.758051Z",
"shell.execute_reply": "2023-08-18T07:04:22.757284Z"
},
"origin_pos": 12,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"batch_size, lr, num_epochs = 256, 0.1, 10\n",
"loss = nn.CrossEntropyLoss(reduction='none')\n",
"trainer = torch.optim.SGD(net.parameters(), lr=lr)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "78ac9bf1",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:04:22.761842Z",
"iopub.status.busy": "2023-08-18T07:04:22.761295Z",
"iopub.status.idle": "2023-08-18T07:05:05.308680Z",
"shell.execute_reply": "2023-08-18T07:05:05.307786Z"
},
"origin_pos": 15,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"238.965625pt\" height=\"180.65625pt\" viewBox=\"0 0 238.965625 180.65625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
" <metadata>\n",
" <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
" <cc:Work>\n",
" <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
" <dc:date>2023-08-18T07:05:05.270258</dc:date>\n",
" <dc:format>image/svg+xml</dc:format>\n",
" <dc:creator>\n",
" <cc:Agent>\n",
" <dc:title>Matplotlib v3.5.1, https://matplotlib.org/</dc:title>\n",
" </cc:Agent>\n",
" </dc:creator>\n",
" </cc:Work>\n",
" </rdf:RDF>\n",
" </metadata>\n",
" <defs>\n",
" <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
" </defs>\n",
" <g id=\"figure_1\">\n",
" <g id=\"patch_1\">\n",
" <path d=\"M 0 180.65625 \n",
"L 238.965625 180.65625 \n",
"L 238.965625 0 \n",
"L 0 0 \n",
"L 0 180.65625 \n",
"z\n",
"\" style=\"fill: none\"/>\n",
" </g>\n",
" <g id=\"axes_1\">\n",
" <g id=\"patch_2\">\n",
" <path d=\"M 30.103125 143.1 \n",
"L 225.403125 143.1 \n",
"L 225.403125 7.2 \n",
"L 30.103125 7.2 \n",
"z\n",
"\" style=\"fill: #ffffff\"/>\n",
" </g>\n",
" <g id=\"matplotlib.axis_1\">\n",
" <g id=\"xtick_1\">\n",
" <g id=\"line2d_1\">\n",
" <path d=\"M 51.803125 143.1 \n",
"L 51.803125 7.2 \n",
"\" clip-path=\"url(#p38f7277f50)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_2\">\n",
" <defs>\n",
" <path id=\"m69cc5df15a\" d=\"M 0 0 \n",
"L 0 3.5 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#m69cc5df15a\" x=\"51.803125\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_1\">\n",
" <!-- 2 -->\n",
" <g transform=\"translate(48.621875 157.698438)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
"L 3431 531 \n",
"L 3431 0 \n",
"L 469 0 \n",
"L 469 531 \n",
"Q 828 903 1448 1529 \n",
"Q 2069 2156 2228 2338 \n",
"Q 2531 2678 2651 2914 \n",
"Q 2772 3150 2772 3378 \n",
"Q 2772 3750 2511 3984 \n",
"Q 2250 4219 1831 4219 \n",
"Q 1534 4219 1204 4116 \n",
"Q 875 4013 500 3803 \n",
"L 500 4441 \n",
"Q 881 4594 1212 4672 \n",
"Q 1544 4750 1819 4750 \n",
"Q 2544 4750 2975 4387 \n",
"Q 3406 4025 3406 3419 \n",
"Q 3406 3131 3298 2873 \n",
"Q 3191 2616 2906 2266 \n",
"Q 2828 2175 2409 1742 \n",
"Q 1991 1309 1228 531 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_2\">\n",
" <g id=\"line2d_3\">\n",
" <path d=\"M 95.203125 143.1 \n",
"L 95.203125 7.2 \n",
"\" clip-path=\"url(#p38f7277f50)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_4\">\n",
" <g>\n",
" <use xlink:href=\"#m69cc5df15a\" x=\"95.203125\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_2\">\n",
" <!-- 4 -->\n",
" <g transform=\"translate(92.021875 157.698438)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \n",
"L 825 1625 \n",
"L 2419 1625 \n",
"L 2419 4116 \n",
"z\n",
"M 2253 4666 \n",
"L 3047 4666 \n",
"L 3047 1625 \n",
"L 3713 1625 \n",
"L 3713 1100 \n",
"L 3047 1100 \n",
"L 3047 0 \n",
"L 2419 0 \n",
"L 2419 1100 \n",
"L 313 1100 \n",
"L 313 1709 \n",
"L 2253 4666 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-34\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_3\">\n",
" <g id=\"line2d_5\">\n",
" <path d=\"M 138.603125 143.1 \n",
"L 138.603125 7.2 \n",
"\" clip-path=\"url(#p38f7277f50)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_6\">\n",
" <g>\n",
" <use xlink:href=\"#m69cc5df15a\" x=\"138.603125\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_3\">\n",
" <!-- 6 -->\n",
" <g transform=\"translate(135.421875 157.698438)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-36\" d=\"M 2113 2584 \n",
"Q 1688 2584 1439 2293 \n",
"Q 1191 2003 1191 1497 \n",
"Q 1191 994 1439 701 \n",
"Q 1688 409 2113 409 \n",
"Q 2538 409 2786 701 \n",
"Q 3034 994 3034 1497 \n",
"Q 3034 2003 2786 2293 \n",
"Q 2538 2584 2113 2584 \n",
"z\n",
"M 3366 4563 \n",
"L 3366 3988 \n",
"Q 3128 4100 2886 4159 \n",
"Q 2644 4219 2406 4219 \n",
"Q 1781 4219 1451 3797 \n",
"Q 1122 3375 1075 2522 \n",
"Q 1259 2794 1537 2939 \n",
"Q 1816 3084 2150 3084 \n",
"Q 2853 3084 3261 2657 \n",
"Q 3669 2231 3669 1497 \n",
"Q 3669 778 3244 343 \n",
"Q 2819 -91 2113 -91 \n",
"Q 1303 -91 875 529 \n",
"Q 447 1150 447 2328 \n",
"Q 447 3434 972 4092 \n",
"Q 1497 4750 2381 4750 \n",
"Q 2619 4750 2861 4703 \n",
"Q 3103 4656 3366 4563 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-36\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_4\">\n",
" <g id=\"line2d_7\">\n",
" <path d=\"M 182.003125 143.1 \n",
"L 182.003125 7.2 \n",
"\" clip-path=\"url(#p38f7277f50)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_8\">\n",
" <g>\n",
" <use xlink:href=\"#m69cc5df15a\" x=\"182.003125\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_4\">\n",
" <!-- 8 -->\n",
" <g transform=\"translate(178.821875 157.698438)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-38\" d=\"M 2034 2216 \n",
"Q 1584 2216 1326 1975 \n",
"Q 1069 1734 1069 1313 \n",
"Q 1069 891 1326 650 \n",
"Q 1584 409 2034 409 \n",
"Q 2484 409 2743 651 \n",
"Q 3003 894 3003 1313 \n",
"Q 3003 1734 2745 1975 \n",
"Q 2488 2216 2034 2216 \n",
"z\n",
"M 1403 2484 \n",
"Q 997 2584 770 2862 \n",
"Q 544 3141 544 3541 \n",
"Q 544 4100 942 4425 \n",
"Q 1341 4750 2034 4750 \n",
"Q 2731 4750 3128 4425 \n",
"Q 3525 4100 3525 3541 \n",
"Q 3525 3141 3298 2862 \n",
"Q 3072 2584 2669 2484 \n",
"Q 3125 2378 3379 2068 \n",
"Q 3634 1759 3634 1313 \n",
"Q 3634 634 3220 271 \n",
"Q 2806 -91 2034 -91 \n",
"Q 1263 -91 848 271 \n",
"Q 434 634 434 1313 \n",
"Q 434 1759 690 2068 \n",
"Q 947 2378 1403 2484 \n",
"z\n",
"M 1172 3481 \n",
"Q 1172 3119 1398 2916 \n",
"Q 1625 2713 2034 2713 \n",
"Q 2441 2713 2670 2916 \n",
"Q 2900 3119 2900 3481 \n",
"Q 2900 3844 2670 4047 \n",
"Q 2441 4250 2034 4250 \n",
"Q 1625 4250 1398 4047 \n",
"Q 1172 3844 1172 3481 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-38\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_5\">\n",
" <g id=\"line2d_9\">\n",
" <path d=\"M 225.403125 143.1 \n",
"L 225.403125 7.2 \n",
"\" clip-path=\"url(#p38f7277f50)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_10\">\n",
" <g>\n",
" <use xlink:href=\"#m69cc5df15a\" x=\"225.403125\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_5\">\n",
" <!-- 10 -->\n",
" <g transform=\"translate(219.040625 157.698438)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
"L 1825 531 \n",
"L 1825 4091 \n",
"L 703 3866 \n",
"L 703 4441 \n",
"L 1819 4666 \n",
"L 2450 4666 \n",
"L 2450 531 \n",
"L 3481 531 \n",
"L 3481 0 \n",
"L 794 0 \n",
"L 794 531 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
"Q 1547 4250 1301 3770 \n",
"Q 1056 3291 1056 2328 \n",
"Q 1056 1369 1301 889 \n",
"Q 1547 409 2034 409 \n",
"Q 2525 409 2770 889 \n",
"Q 3016 1369 3016 2328 \n",
"Q 3016 3291 2770 3770 \n",
"Q 2525 4250 2034 4250 \n",
"z\n",
"M 2034 4750 \n",
"Q 2819 4750 3233 4129 \n",
"Q 3647 3509 3647 2328 \n",
"Q 3647 1150 3233 529 \n",
"Q 2819 -91 2034 -91 \n",
"Q 1250 -91 836 529 \n",
"Q 422 1150 422 2328 \n",
"Q 422 3509 836 4129 \n",
"Q 1250 4750 2034 4750 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_6\">\n",
" <!-- epoch -->\n",
" <g transform=\"translate(112.525 171.376563)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-65\" d=\"M 3597 1894 \n",
"L 3597 1613 \n",
"L 953 1613 \n",
"Q 991 1019 1311 708 \n",
"Q 1631 397 2203 397 \n",
"Q 2534 397 2845 478 \n",
"Q 3156 559 3463 722 \n",
"L 3463 178 \n",
"Q 3153 47 2828 -22 \n",
"Q 2503 -91 2169 -91 \n",
"Q 1331 -91 842 396 \n",
"Q 353 884 353 1716 \n",
"Q 353 2575 817 3079 \n",
"Q 1281 3584 2069 3584 \n",
"Q 2775 3584 3186 3129 \n",
"Q 3597 2675 3597 1894 \n",
"z\n",
"M 3022 2063 \n",
"Q 3016 2534 2758 2815 \n",
"Q 2500 3097 2075 3097 \n",
"Q 1594 3097 1305 2825 \n",
"Q 1016 2553 972 2059 \n",
"L 3022 2063 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-70\" d=\"M 1159 525 \n",
"L 1159 -1331 \n",
"L 581 -1331 \n",
"L 581 3500 \n",
"L 1159 3500 \n",
"L 1159 2969 \n",
"Q 1341 3281 1617 3432 \n",
"Q 1894 3584 2278 3584 \n",
"Q 2916 3584 3314 3078 \n",
"Q 3713 2572 3713 1747 \n",
"Q 3713 922 3314 415 \n",
"Q 2916 -91 2278 -91 \n",
"Q 1894 -91 1617 61 \n",
"Q 1341 213 1159 525 \n",
"z\n",
"M 3116 1747 \n",
"Q 3116 2381 2855 2742 \n",
"Q 2594 3103 2138 3103 \n",
"Q 1681 3103 1420 2742 \n",
"Q 1159 2381 1159 1747 \n",
"Q 1159 1113 1420 752 \n",
"Q 1681 391 2138 391 \n",
"Q 2594 391 2855 752 \n",
"Q 3116 1113 3116 1747 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-6f\" d=\"M 1959 3097 \n",
"Q 1497 3097 1228 2736 \n",
"Q 959 2375 959 1747 \n",
"Q 959 1119 1226 758 \n",
"Q 1494 397 1959 397 \n",
"Q 2419 397 2687 759 \n",
"Q 2956 1122 2956 1747 \n",
"Q 2956 2369 2687 2733 \n",
"Q 2419 3097 1959 3097 \n",
"z\n",
"M 1959 3584 \n",
"Q 2709 3584 3137 3096 \n",
"Q 3566 2609 3566 1747 \n",
"Q 3566 888 3137 398 \n",
"Q 2709 -91 1959 -91 \n",
"Q 1206 -91 779 398 \n",
"Q 353 888 353 1747 \n",
"Q 353 2609 779 3096 \n",
"Q 1206 3584 1959 3584 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-63\" d=\"M 3122 3366 \n",
"L 3122 2828 \n",
"Q 2878 2963 2633 3030 \n",
"Q 2388 3097 2138 3097 \n",
"Q 1578 3097 1268 2742 \n",
"Q 959 2388 959 1747 \n",
"Q 959 1106 1268 751 \n",
"Q 1578 397 2138 397 \n",
"Q 2388 397 2633 464 \n",
"Q 2878 531 3122 666 \n",
"L 3122 134 \n",
"Q 2881 22 2623 -34 \n",
"Q 2366 -91 2075 -91 \n",
"Q 1284 -91 818 406 \n",
"Q 353 903 353 1747 \n",
"Q 353 2603 823 3093 \n",
"Q 1294 3584 2113 3584 \n",
"Q 2378 3584 2631 3529 \n",
"Q 2884 3475 3122 3366 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-68\" d=\"M 3513 2113 \n",
"L 3513 0 \n",
"L 2938 0 \n",
"L 2938 2094 \n",
"Q 2938 2591 2744 2837 \n",
"Q 2550 3084 2163 3084 \n",
"Q 1697 3084 1428 2787 \n",
"Q 1159 2491 1159 1978 \n",
"L 1159 0 \n",
"L 581 0 \n",
"L 581 4863 \n",
"L 1159 4863 \n",
"L 1159 2956 \n",
"Q 1366 3272 1645 3428 \n",
"Q 1925 3584 2291 3584 \n",
"Q 2894 3584 3203 3211 \n",
"Q 3513 2838 3513 2113 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-65\"/>\n",
" <use xlink:href=\"#DejaVuSans-70\" x=\"61.523438\"/>\n",
" <use xlink:href=\"#DejaVuSans-6f\" x=\"125\"/>\n",
" <use xlink:href=\"#DejaVuSans-63\" x=\"186.181641\"/>\n",
" <use xlink:href=\"#DejaVuSans-68\" x=\"241.162109\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"matplotlib.axis_2\">\n",
" <g id=\"ytick_1\">\n",
" <g id=\"line2d_11\">\n",
" <path d=\"M 30.103125 120.45 \n",
"L 225.403125 120.45 \n",
"\" clip-path=\"url(#p38f7277f50)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_12\">\n",
" <defs>\n",
" <path id=\"m0ca26dcbeb\" d=\"M 0 0 \n",
"L -3.5 0 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#m0ca26dcbeb\" x=\"30.103125\" y=\"120.45\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_7\">\n",
" <!-- 0.4 -->\n",
" <g transform=\"translate(7.2 124.249219)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-2e\" d=\"M 684 794 \n",
"L 1344 794 \n",
"L 1344 0 \n",
"L 684 0 \n",
"L 684 794 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
" <use xlink:href=\"#DejaVuSans-34\" x=\"95.410156\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_2\">\n",
" <g id=\"line2d_13\">\n",
" <path d=\"M 30.103125 75.15 \n",
"L 225.403125 75.15 \n",
"\" clip-path=\"url(#p38f7277f50)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_14\">\n",
" <g>\n",
" <use xlink:href=\"#m0ca26dcbeb\" x=\"30.103125\" y=\"75.15\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_8\">\n",
" <!-- 0.6 -->\n",
" <g transform=\"translate(7.2 78.949219)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
" <use xlink:href=\"#DejaVuSans-36\" x=\"95.410156\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_3\">\n",
" <g id=\"line2d_15\">\n",
" <path d=\"M 30.103125 29.85 \n",
"L 225.403125 29.85 \n",
"\" clip-path=\"url(#p38f7277f50)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_16\">\n",
" <g>\n",
" <use xlink:href=\"#m0ca26dcbeb\" x=\"30.103125\" y=\"29.85\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_9\">\n",
" <!-- 0.8 -->\n",
" <g transform=\"translate(7.2 33.649219)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
" <use xlink:href=\"#DejaVuSans-38\" x=\"95.410156\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"line2d_17\">\n",
" <path d=\"M 35.272884 -1 \n",
"L 51.803125 75.61021 \n",
"L 73.503125 93.672344 \n",
"L 95.203125 102.778348 \n",
"L 116.903125 107.632437 \n",
"L 138.603125 112.487156 \n",
"L 160.303125 116.4354 \n",
"L 182.003125 119.040329 \n",
"L 203.703125 121.424263 \n",
"L 225.403125 124.527028 \n",
"\" clip-path=\"url(#p38f7277f50)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_18\">\n",
" <path d=\"M 30.103125 65.6219 \n",
"L 51.803125 32.179175 \n",
"L 73.503125 25.7881 \n",
"L 95.203125 22.432125 \n",
"L 116.903125 21.005175 \n",
"L 138.603125 18.959125 \n",
"L 160.303125 18.0418 \n",
"L 182.003125 17.124475 \n",
"L 203.703125 16.0939 \n",
"L 225.403125 15.08975 \n",
"\" clip-path=\"url(#p38f7277f50)\" style=\"fill: none; stroke-dasharray: 5.55,2.4; stroke-dashoffset: 0; stroke: #bf00bf; stroke-width: 1.5\"/>\n",
" </g>\n",
" <g id=\"line2d_19\">\n",
" <path d=\"M 30.103125 41.6733 \n",
"L 51.803125 32.77185 \n",
"L 73.503125 25.11615 \n",
"L 95.203125 23.84775 \n",
"L 116.903125 27.3585 \n",
"L 138.603125 22.5567 \n",
"L 160.303125 23.84775 \n",
"L 182.003125 19.49895 \n",
"L 203.703125 22.7832 \n",
"L 225.403125 21.1977 \n",
"\" clip-path=\"url(#p38f7277f50)\" style=\"fill: none; stroke-dasharray: 9.6,2.4,1.5,2.4; stroke-dashoffset: 0; stroke: #008000; stroke-width: 1.5\"/>\n",
" </g>\n",
" <g id=\"patch_3\">\n",
" <path d=\"M 30.103125 143.1 \n",
"L 30.103125 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_4\">\n",
" <path d=\"M 225.403125 143.1 \n",
"L 225.403125 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_5\">\n",
" <path d=\"M 30.103125 143.1 \n",
"L 225.403125 143.1 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_6\">\n",
" <path d=\"M 30.103125 7.2 \n",
"L 225.403125 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"legend_1\">\n",
" <g id=\"patch_7\">\n",
" <path d=\"M 140.634375 98.667187 \n",
"L 218.403125 98.667187 \n",
"Q 220.403125 98.667187 220.403125 96.667187 \n",
"L 220.403125 53.632812 \n",
"Q 220.403125 51.632812 218.403125 51.632812 \n",
"L 140.634375 51.632812 \n",
"Q 138.634375 51.632812 138.634375 53.632812 \n",
"L 138.634375 96.667187 \n",
"Q 138.634375 98.667187 140.634375 98.667187 \n",
"z\n",
"\" style=\"fill: #ffffff; opacity: 0.8; stroke: #cccccc; stroke-linejoin: miter\"/>\n",
" </g>\n",
" <g id=\"line2d_20\">\n",
" <path d=\"M 142.634375 59.73125 \n",
"L 152.634375 59.73125 \n",
"L 162.634375 59.73125 \n",
"\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"text_10\">\n",
" <!-- train loss -->\n",
" <g transform=\"translate(170.634375 63.23125)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-74\" d=\"M 1172 4494 \n",
"L 1172 3500 \n",
"L 2356 3500 \n",
"L 2356 3053 \n",
"L 1172 3053 \n",
"L 1172 1153 \n",
"Q 1172 725 1289 603 \n",
"Q 1406 481 1766 481 \n",
"L 2356 481 \n",
"L 2356 0 \n",
"L 1766 0 \n",
"Q 1100 0 847 248 \n",
"Q 594 497 594 1153 \n",
"L 594 3053 \n",
"L 172 3053 \n",
"L 172 3500 \n",
"L 594 3500 \n",
"L 594 4494 \n",
"L 1172 4494 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-72\" d=\"M 2631 2963 \n",
"Q 2534 3019 2420 3045 \n",
"Q 2306 3072 2169 3072 \n",
"Q 1681 3072 1420 2755 \n",
"Q 1159 2438 1159 1844 \n",
"L 1159 0 \n",
"L 581 0 \n",
"L 581 3500 \n",
"L 1159 3500 \n",
"L 1159 2956 \n",
"Q 1341 3275 1631 3429 \n",
"Q 1922 3584 2338 3584 \n",
"Q 2397 3584 2469 3576 \n",
"Q 2541 3569 2628 3553 \n",
"L 2631 2963 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-61\" d=\"M 2194 1759 \n",
"Q 1497 1759 1228 1600 \n",
"Q 959 1441 959 1056 \n",
"Q 959 750 1161 570 \n",
"Q 1363 391 1709 391 \n",
"Q 2188 391 2477 730 \n",
"Q 2766 1069 2766 1631 \n",
"L 2766 1759 \n",
"L 2194 1759 \n",
"z\n",
"M 3341 1997 \n",
"L 3341 0 \n",
"L 2766 0 \n",
"L 2766 531 \n",
"Q 2569 213 2275 61 \n",
"Q 1981 -91 1556 -91 \n",
"Q 1019 -91 701 211 \n",
"Q 384 513 384 1019 \n",
"Q 384 1609 779 1909 \n",
"Q 1175 2209 1959 2209 \n",
"L 2766 2209 \n",
"L 2766 2266 \n",
"Q 2766 2663 2505 2880 \n",
"Q 2244 3097 1772 3097 \n",
"Q 1472 3097 1187 3025 \n",
"Q 903 2953 641 2809 \n",
"L 641 3341 \n",
"Q 956 3463 1253 3523 \n",
"Q 1550 3584 1831 3584 \n",
"Q 2591 3584 2966 3190 \n",
"Q 3341 2797 3341 1997 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-69\" d=\"M 603 3500 \n",
"L 1178 3500 \n",
"L 1178 0 \n",
"L 603 0 \n",
"L 603 3500 \n",
"z\n",
"M 603 4863 \n",
"L 1178 4863 \n",
"L 1178 4134 \n",
"L 603 4134 \n",
"L 603 4863 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-6e\" d=\"M 3513 2113 \n",
"L 3513 0 \n",
"L 2938 0 \n",
"L 2938 2094 \n",
"Q 2938 2591 2744 2837 \n",
"Q 2550 3084 2163 3084 \n",
"Q 1697 3084 1428 2787 \n",
"Q 1159 2491 1159 1978 \n",
"L 1159 0 \n",
"L 581 0 \n",
"L 581 3500 \n",
"L 1159 3500 \n",
"L 1159 2956 \n",
"Q 1366 3272 1645 3428 \n",
"Q 1925 3584 2291 3584 \n",
"Q 2894 3584 3203 3211 \n",
"Q 3513 2838 3513 2113 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-20\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-6c\" d=\"M 603 4863 \n",
"L 1178 4863 \n",
"L 1178 0 \n",
"L 603 0 \n",
"L 603 4863 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-73\" d=\"M 2834 3397 \n",
"L 2834 2853 \n",
"Q 2591 2978 2328 3040 \n",
"Q 2066 3103 1784 3103 \n",
"Q 1356 3103 1142 2972 \n",
"Q 928 2841 928 2578 \n",
"Q 928 2378 1081 2264 \n",
"Q 1234 2150 1697 2047 \n",
"L 1894 2003 \n",
"Q 2506 1872 2764 1633 \n",
"Q 3022 1394 3022 966 \n",
"Q 3022 478 2636 193 \n",
"Q 2250 -91 1575 -91 \n",
"Q 1294 -91 989 -36 \n",
"Q 684 19 347 128 \n",
"L 347 722 \n",
"Q 666 556 975 473 \n",
"Q 1284 391 1588 391 \n",
"Q 1994 391 2212 530 \n",
"Q 2431 669 2431 922 \n",
"Q 2431 1156 2273 1281 \n",
"Q 2116 1406 1581 1522 \n",
"L 1381 1569 \n",
"Q 847 1681 609 1914 \n",
"Q 372 2147 372 2553 \n",
"Q 372 3047 722 3315 \n",
"Q 1072 3584 1716 3584 \n",
"Q 2034 3584 2315 3537 \n",
"Q 2597 3491 2834 3397 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-74\"/>\n",
" <use xlink:href=\"#DejaVuSans-72\" x=\"39.208984\"/>\n",
" <use xlink:href=\"#DejaVuSans-61\" x=\"80.322266\"/>\n",
" <use xlink:href=\"#DejaVuSans-69\" x=\"141.601562\"/>\n",
" <use xlink:href=\"#DejaVuSans-6e\" x=\"169.384766\"/>\n",
" <use xlink:href=\"#DejaVuSans-20\" x=\"232.763672\"/>\n",
" <use xlink:href=\"#DejaVuSans-6c\" x=\"264.550781\"/>\n",
" <use xlink:href=\"#DejaVuSans-6f\" x=\"292.333984\"/>\n",
" <use xlink:href=\"#DejaVuSans-73\" x=\"353.515625\"/>\n",
" <use xlink:href=\"#DejaVuSans-73\" x=\"405.615234\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"line2d_21\">\n",
" <path d=\"M 142.634375 74.409375 \n",
"L 152.634375 74.409375 \n",
"L 162.634375 74.409375 \n",
"\" style=\"fill: none; stroke-dasharray: 5.55,2.4; stroke-dashoffset: 0; stroke: #bf00bf; stroke-width: 1.5\"/>\n",
" </g>\n",
" <g id=\"text_11\">\n",
" <!-- train acc -->\n",
" <g transform=\"translate(170.634375 77.909375)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-74\"/>\n",
" <use xlink:href=\"#DejaVuSans-72\" x=\"39.208984\"/>\n",
" <use xlink:href=\"#DejaVuSans-61\" x=\"80.322266\"/>\n",
" <use xlink:href=\"#DejaVuSans-69\" x=\"141.601562\"/>\n",
" <use xlink:href=\"#DejaVuSans-6e\" x=\"169.384766\"/>\n",
" <use xlink:href=\"#DejaVuSans-20\" x=\"232.763672\"/>\n",
" <use xlink:href=\"#DejaVuSans-61\" x=\"264.550781\"/>\n",
" <use xlink:href=\"#DejaVuSans-63\" x=\"325.830078\"/>\n",
" <use xlink:href=\"#DejaVuSans-63\" x=\"380.810547\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"line2d_22\">\n",
" <path d=\"M 142.634375 89.0875 \n",
"L 152.634375 89.0875 \n",
"L 162.634375 89.0875 \n",
"\" style=\"fill: none; stroke-dasharray: 9.6,2.4,1.5,2.4; stroke-dashoffset: 0; stroke: #008000; stroke-width: 1.5\"/>\n",
" </g>\n",
" <g id=\"text_12\">\n",
" <!-- test acc -->\n",
" <g transform=\"translate(170.634375 92.5875)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-74\"/>\n",
" <use xlink:href=\"#DejaVuSans-65\" x=\"39.208984\"/>\n",
" <use xlink:href=\"#DejaVuSans-73\" x=\"100.732422\"/>\n",
" <use xlink:href=\"#DejaVuSans-74\" x=\"152.832031\"/>\n",
" <use xlink:href=\"#DejaVuSans-20\" x=\"192.041016\"/>\n",
" <use xlink:href=\"#DejaVuSans-61\" x=\"223.828125\"/>\n",
" <use xlink:href=\"#DejaVuSans-63\" x=\"285.107422\"/>\n",
" <use xlink:href=\"#DejaVuSans-63\" x=\"340.087891\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <defs>\n",
" <clipPath id=\"p38f7277f50\">\n",
" <rect x=\"30.103125\" y=\"7.2\" width=\"195.3\" height=\"135.9\"/>\n",
" </clipPath>\n",
" </defs>\n",
"</svg>\n"
],
"text/plain": [
"<Figure size 252x180 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)\n",
"d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)"
]
},
{
"cell_type": "markdown",
"id": "9b636c57",
"metadata": {
"origin_pos": 16
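   },
   "source": [
    "## Summary\n",
    "\n",
    "* We can use high-level APIs to implement a multilayer perceptron more concisely.\n",
    "* For the same classification problem, the implementation of a multilayer perceptron is the same as that of softmax regression, except that it adds hidden layers with activation functions.\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. Try adding different numbers of hidden layers (you can also modify the learning rate). Which setting works best?\n",
    "1. Try different activation functions. Which one works best?\n",
    "1. Try different schemes for initializing the weights. Which method works best?\n"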
]
},
{
"cell_type": "markdown",
"id": "36201fb3",
"metadata": {
"origin_pos": 18,
"tab": [
"pytorch"
]
},
"source": [
"[Discussions](https://discuss.d2l.ai/t/1802)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"required_libs": []
},
"nbformat": 4,
"nbformat_minor": 5
}
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large Load Diff