{ "cells": [ { "cell_type": "markdown", "id": "502c118b", "metadata": { "origin_pos": 0 }, "source": [ "# 小批量随机梯度下降\n", ":label:`sec_minibatch_sgd`\n", "\n", "到目前为止,我们在基于梯度的学习方法中遇到了两个极端情况:\n", " :numref:`sec_gd`中使用完整数据集来计算梯度并更新参数,\n", " :numref:`sec_sgd`中一次处理一个训练样本来取得进展。\n", "二者各有利弊:每当数据非常相似时,梯度下降并不是非常“数据高效”。\n", "而由于CPU和GPU无法充分利用向量化,随机梯度下降并不特别“计算高效”。\n", "这暗示了两者之间可能有折中方案,这便涉及到*小批量随机梯度下降*(minibatch gradient descent)。\n", "\n", "## 向量化和缓存\n", "\n", "使用小批量的决策的核心是计算效率。\n", "当考虑与多个GPU和多台服务器并行处理时,这一点最容易被理解。在这种情况下,我们需要向每个GPU发送至少一张图像。\n", "有了每台服务器8个GPU和16台服务器,我们就能得到大小为128的小批量。\n", "\n", "当涉及到单个GPU甚至CPU时,事情会更微妙一些:\n", "这些设备有多种类型的内存、通常情况下多种类型的计算单元以及在它们之间不同的带宽限制。\n", "例如,一个CPU有少量寄存器(register),L1和L2缓存,以及L3缓存(在不同的处理器内核之间共享)。\n", "随着缓存的大小的增加,它们的延迟也在增加,同时带宽在减少。\n", "可以说,处理器能够执行的操作远比主内存接口所能提供的多得多。\n", "\n", "首先,具有16个内核和AVX-512向量化的2GHz CPU每秒可处理高达$2 \\cdot 10^9 \\cdot 16 \\cdot 32 = 10^{12}$个字节。\n", "同时,GPU的性能很容易超过该数字100倍。\n", "而另一方面,中端服务器处理器的带宽可能不超过100Gb/s,即不到处理器满负荷所需的十分之一。\n", "更糟糕的是,并非所有的内存入口都是相等的:内存接口通常为64位或更宽(例如,在最多384位的GPU上)。\n", "因此读取单个字节会导致由于更宽的存取而产生的代价。\n", "\n", "其次,第一次存取的额外开销很大,而按序存取(sequential access)或突发读取(burst read)相对开销较小。\n", "有关更深入的讨论,请参阅此[维基百科文章](https://en.wikipedia.org/wiki/Cache_hierarchy)。\n", "\n", "减轻这些限制的方法是使用足够快的CPU缓存层次结构来为处理器提供数据。\n", "这是深度学习中批量处理背后的推动力。\n", "举一个简单的例子:矩阵-矩阵乘法。\n", "比如$\\mathbf{A} = \\mathbf{B}\\mathbf{C}$,我们有很多方法来计算$\\mathbf{A}$。例如,我们可以尝试以下方法:\n", "\n", "1. 我们可以计算$\\mathbf{A}_{ij} = \\mathbf{B}_{i,:} \\mathbf{C}_{:,j}^\\top$,也就是说,我们可以通过点积进行逐元素计算。\n", "1. 我们可以计算$\\mathbf{A}_{:,j} = \\mathbf{B} \\mathbf{C}_{:,j}^\\top$,也就是说,我们可以一次计算一列。同样,我们可以一次计算$\\mathbf{A}$一行$\\mathbf{A}_{i,:}$。\n", "1. 我们可以简单地计算$\\mathbf{A} = \\mathbf{B} \\mathbf{C}$。\n", "1. 我们可以将$\\mathbf{B}$和$\\mathbf{C}$分成较小的区块矩阵,然后一次计算$\\mathbf{A}$的一个区块。\n", "\n", "如果我们使用第一个选择,每次我们计算一个元素$\\mathbf{A}_{ij}$时,都需要将一行和一列向量复制到CPU中。\n", "更糟糕的是,由于矩阵元素是按顺序对齐的,因此当从内存中读取它们时,我们需要访问两个向量中许多不相交的位置。\n", "第二种选择相对更有利:我们能够在遍历$\\mathbf{B}$的同时,将列向量$\\mathbf{C}_{:,j}$保留在CPU缓存中。\n", "它将内存带宽需求减半,相应地提高了访问速度。\n", "第三种选择表面上是最可取的,然而大多数矩阵可能不能完全放入缓存中。\n", "第四种选择提供了一个实践上很有用的方案:我们可以将矩阵的区块移到缓存中然后在本地将它们相乘。\n", "让我们来看看这些操作在实践中的效率如何。\n", "\n", "除了计算效率之外,Python和深度学习框架本身带来的额外开销也是相当大的。\n", "回想一下,每次我们执行代码时,Python解释器都会向深度学习框架发送一个命令,要求将其插入到计算图中并在调度过程中处理它。\n", "这样的额外开销可能是非常不利的。\n", "总而言之,我们最好用向量化(和矩阵)。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "74d0b8bd", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:50.345234Z", "iopub.status.busy": "2023-08-18T07:00:50.344241Z", "iopub.status.idle": "2023-08-18T07:00:52.828095Z", "shell.execute_reply": "2023-08-18T07:00:52.826894Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "import numpy as np\n", "import torch\n", "from torch import nn\n", "from d2l import torch as d2l\n", "\n", "timer = d2l.Timer()\n", "A = torch.zeros(256, 256)\n", "B = torch.randn(256, 256)\n", "C = torch.randn(256, 256)" ] }, { "cell_type": "markdown", "id": "d11f4087", "metadata": { "origin_pos": 5 }, "source": [ "按元素分配只需遍历分别为$\\mathbf{B}$和$\\mathbf{C}$的所有行和列,即可将该值分配给$\\mathbf{A}$。\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "cf7f7a46", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:52.834373Z", "iopub.status.busy": "2023-08-18T07:00:52.833156Z", "iopub.status.idle": "2023-08-18T07:00:54.504095Z", "shell.execute_reply": "2023-08-18T07:00:54.502961Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "1.6609790325164795" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 逐元素计算A=BC\n", "timer.start()\n", "for i in range(256):\n", " for j in range(256):\n", " A[i, j] = torch.dot(B[i, :], C[:, j])\n", "timer.stop()" ] }, { "cell_type": "markdown", "id": "a28f6b4d", "metadata": { "origin_pos": 10 }, "source": [ "更快的策略是执行按列分配。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "312257cf", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:54.509396Z", "iopub.status.busy": "2023-08-18T07:00:54.508607Z", "iopub.status.idle": "2023-08-18T07:00:54.539525Z", "shell.execute_reply": "2023-08-18T07:00:54.537980Z" }, "origin_pos": 12, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "0.022588253021240234" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 逐列计算A=BC\n", "timer.start()\n", "for j in range(256):\n", " A[:, j] = torch.mv(B, C[:, j])\n", "timer.stop()" ] }, { "cell_type": "markdown", "id": "afb99637", "metadata": { "origin_pos": 15 }, "source": [ "最有效的方法是在一个区块中执行整个操作。让我们看看它们各自的操作速度是多少。\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "5735b43b", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:54.544547Z", "iopub.status.busy": "2023-08-18T07:00:54.543948Z", "iopub.status.idle": "2023-08-18T07:00:54.561974Z", "shell.execute_reply": "2023-08-18T07:00:54.560964Z" }, "origin_pos": 17, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "performance in Gigaflops: element 1.204, column 88.542, full 195.625\n" ] } ], "source": [ "# 一次性计算A=BC\n", "timer.start()\n", "A = torch.mm(B, C)\n", "timer.stop()\n", "\n", "# 乘法和加法作为单独的操作(在实践中融合)\n", "gigaflops = [2/i for i in timer.times]\n", "print(f'performance in Gigaflops: element {gigaflops[0]:.3f}, '\n", " f'column {gigaflops[1]:.3f}, full {gigaflops[2]:.3f}')" ] }, { "cell_type": "markdown", "id": "d94851f8", "metadata": { "origin_pos": 20 }, "source": [ "## 小批量\n", "\n", ":label:`sec_minibatches`\n", "\n", "之前我们会理所当然地读取数据的*小批量*,而不是观测单个数据来更新参数,现在简要解释一下原因。\n", "处理单个观测值需要我们执行许多单一矩阵-矢量(甚至矢量-矢量)乘法,这耗费相当大,而且对应深度学习框架也要巨大的开销。\n", "这既适用于计算梯度以更新参数时,也适用于用神经网络预测。\n", "也就是说,每当我们执行$\\mathbf{w} \\leftarrow \\mathbf{w} - \\eta_t \\mathbf{g}_t$时,消耗巨大。其中\n", "\n", "$$\\mathbf{g}_t = \\partial_{\\mathbf{w}} f(\\mathbf{x}_{t}, \\mathbf{w}).$$\n", "\n", "我们可以通过将其应用于一个小批量观测值来提高此操作的*计算*效率。\n", "也就是说,我们将梯度$\\mathbf{g}_t$替换为一个小批量而不是单个观测值\n", "\n", "$$\\mathbf{g}_t = \\partial_{\\mathbf{w}} \\frac{1}{|\\mathcal{B}_t|} \\sum_{i \\in \\mathcal{B}_t} f(\\mathbf{x}_{i}, \\mathbf{w}).$$\n", "\n", "让我们看看这对$\\mathbf{g}_t$的统计属性有什么影响:由于$\\mathbf{x}_t$和小批量$\\mathcal{B}_t$的所有元素都是从训练集中随机抽出的,因此梯度的期望保持不变。\n", "另一方面,方差显著降低。\n", "由于小批量梯度由正在被平均计算的$b := |\\mathcal{B}_t|$个独立梯度组成,其标准差降低了$b^{-\\frac{1}{2}}$。\n", "这本身就是一件好事,因为这意味着更新与完整的梯度更接近了。\n", "\n", "直观来说,这表明选择大型的小批量$\\mathcal{B}_t$将是普遍可行的。\n", "然而,经过一段时间后,与计算代价的线性增长相比,标准差的额外减少是微乎其微的。\n", "在实践中我们选择一个足够大的小批量,它可以提供良好的计算效率同时仍适合GPU的内存。\n", "下面,我们来看看这些高效的代码。\n", "在里面我们执行相同的矩阵-矩阵乘法,但是这次我们将其一次性分为64列的“小批量”。\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "1abe0207", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:54.568121Z", "iopub.status.busy": "2023-08-18T07:00:54.567048Z", "iopub.status.idle": "2023-08-18T07:00:54.575140Z", "shell.execute_reply": "2023-08-18T07:00:54.573950Z" }, "origin_pos": 22, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "performance in Gigaflops: block 2056.535\n" ] } ], "source": [ "timer.start()\n", "for j in range(0, 256, 64):\n", " A[:, j:j+64] = torch.mm(B, C[:, j:j+64])\n", "timer.stop()\n", "print(f'performance in Gigaflops: block {2 / timer.times[3]:.3f}')" ] }, { "cell_type": "markdown", "id": "350a97c4", "metadata": { "origin_pos": 25 }, "source": [ "显而易见,小批量上的计算基本上与完整矩阵一样有效。\n", "需要注意的是,在 :numref:`sec_batch_norm`中,我们使用了一种在很大程度上取决于小批量中的方差的正则化。\n", "随着后者增加,方差会减少,随之而来的是批量规范化带来的噪声注入的好处。\n", "关于实例,请参阅 :cite:`Ioffe.2017`,了解有关如何重新缩放并计算适当项目。\n", "\n", "## 读取数据集\n", "\n", "让我们来看看如何从数据中有效地生成小批量。\n", "下面我们使用NASA开发的测试机翼的数据集[不同飞行器产生的噪声](https://archive.ics.uci.edu/ml/datasets/Airfoil+Self-Noise)来比较这些优化算法。\n", "为方便起见,我们只使用前$1,500$样本。\n", "数据已作预处理:我们移除了均值并将方差重新缩放到每个坐标为$1$。\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "2965fbc9", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:54.580353Z", "iopub.status.busy": "2023-08-18T07:00:54.579977Z", "iopub.status.idle": "2023-08-18T07:00:54.587997Z", "shell.execute_reply": "2023-08-18T07:00:54.586803Z" }, "origin_pos": 27, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "d2l.DATA_HUB['airfoil'] = (d2l.DATA_URL + 'airfoil_self_noise.dat',\n", " '76e5be1548fd8222e5074cf0faae75edff8cf93f')\n", "\n", "#@save\n", "def get_data_ch11(batch_size=10, n=1500):\n", " data = np.genfromtxt(d2l.download('airfoil'),\n", " dtype=np.float32, delimiter='\\t')\n", " data = torch.from_numpy((data - data.mean(axis=0)) / data.std(axis=0))\n", " data_iter = d2l.load_array((data[:n, :-1], data[:n, -1]),\n", " batch_size, is_train=True)\n", " return data_iter, data.shape[1]-1" ] }, { "cell_type": "markdown", "id": "cfec2ad4", "metadata": { "origin_pos": 30 }, "source": [ "## 从零开始实现\n", "\n", " :numref:`sec_linear_scratch`一节中已经实现过小批量随机梯度下降算法。\n", "我们在这里将它的输入参数变得更加通用,主要是为了方便本章后面介绍的其他优化算法也可以使用同样的输入。\n", "具体来说,我们添加了一个状态输入`states`并将超参数放在字典`hyperparams`中。\n", "此外,我们将在训练函数里对各个小批量样本的损失求平均,因此优化算法中的梯度不需要除以批量大小。\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "f1624e05", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:54.595789Z", "iopub.status.busy": "2023-08-18T07:00:54.595436Z", "iopub.status.idle": "2023-08-18T07:00:54.604397Z", "shell.execute_reply": "2023-08-18T07:00:54.603217Z" }, "origin_pos": 32, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def sgd(params, states, hyperparams):\n", " for p in params:\n", " p.data.sub_(hyperparams['lr'] * p.grad)\n", " p.grad.data.zero_()" ] }, { "cell_type": "markdown", "id": "2b4c7291", "metadata": { "origin_pos": 35 }, "source": [ "下面实现一个通用的训练函数,以方便本章后面介绍的其他优化算法使用。\n", "它初始化了一个线性回归模型,然后可以使用小批量随机梯度下降以及后续小节介绍的其他算法来训练模型。\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "72ee9f70", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:54.610003Z", "iopub.status.busy": "2023-08-18T07:00:54.609641Z", "iopub.status.idle": "2023-08-18T07:00:54.631566Z", "shell.execute_reply": "2023-08-18T07:00:54.630601Z" }, "origin_pos": 37, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def train_ch11(trainer_fn, states, hyperparams, data_iter,\n", " feature_dim, num_epochs=2):\n", " # 初始化模型\n", " w = torch.normal(mean=0.0, std=0.01, size=(feature_dim, 1),\n", " requires_grad=True)\n", " b = torch.zeros((1), requires_grad=True)\n", " net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss\n", " # 训练模型\n", " animator = d2l.Animator(xlabel='epoch', ylabel='loss',\n", " xlim=[0, num_epochs], ylim=[0.22, 0.35])\n", " n, timer = 0, d2l.Timer()\n", " for _ in range(num_epochs):\n", " for X, y in data_iter:\n", " l = loss(net(X), y).mean()\n", " l.backward()\n", " trainer_fn([w, b], states, hyperparams)\n", " n += X.shape[0]\n", " if n % 200 == 0:\n", " timer.stop()\n", " animator.add(n/X.shape[0]/len(data_iter),\n", " (d2l.evaluate_loss(net, data_iter, loss),))\n", " timer.start()\n", " print(f'loss: {animator.Y[0][-1]:.3f}, {timer.avg():.3f} sec/epoch')\n", " return timer.cumsum(), animator.Y[0]" ] }, { "cell_type": "markdown", "id": "6082ed33", "metadata": { "origin_pos": 40 }, "source": [ "让我们来看看批量梯度下降的优化是如何进行的。\n", "这可以通过将小批量设置为1500(即样本总数)来实现。\n", "因此,模型参数每个迭代轮数只迭代一次。\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "9d3c9c35", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:54.637080Z", "iopub.status.busy": "2023-08-18T07:00:54.636377Z", "iopub.status.idle": "2023-08-18T07:00:56.139533Z", "shell.execute_reply": "2023-08-18T07:00:56.138355Z" }, "origin_pos": 41, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss: 0.243, 0.055 sec/epoch\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:00:56.087079\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def train_sgd(lr, batch_size, num_epochs=2):\n", " data_iter, feature_dim = get_data_ch11(batch_size)\n", " return train_ch11(\n", " sgd, None, {'lr': lr}, data_iter, feature_dim, num_epochs)\n", "\n", "gd_res = train_sgd(1, 1500, 10)" ] }, { "cell_type": "markdown", "id": "305d75ad", "metadata": { "origin_pos": 42 }, "source": [ "当批量大小为1时,优化使用的是随机梯度下降。\n", "为了简化实现,我们选择了很小的学习率。\n", "在随机梯度下降的实验中,每当一个样本被处理,模型参数都会更新。\n", "在这个例子中,这相当于每个迭代轮数有1500次更新。\n", "可以看到,目标函数值的下降在1个迭代轮数后就变得较为平缓。\n", "尽管两个例子在一个迭代轮数内都处理了1500个样本,但实验中随机梯度下降的一个迭代轮数耗时更多。\n", "这是因为随机梯度下降更频繁地更新了参数,而且一次处理单个观测值效率较低。\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "978aaaa7", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:56.144769Z", "iopub.status.busy": "2023-08-18T07:00:56.144058Z", "iopub.status.idle": "2023-08-18T07:01:05.324085Z", "shell.execute_reply": "2023-08-18T07:01:05.323113Z" }, "origin_pos": 43, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss: 0.246, 0.102 sec/epoch\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:01:05.288019\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "sgd_res = train_sgd(0.005, 1)" ] }, { "cell_type": "markdown", "id": "41566c84", "metadata": { "origin_pos": 44 }, "source": [ "最后,当批量大小等于100时,我们使用小批量随机梯度下降进行优化。\n", "每个迭代轮数所需的时间比随机梯度下降和批量梯度下降所需的时间短。\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "40dc467c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:05.328798Z", "iopub.status.busy": "2023-08-18T07:01:05.328160Z", "iopub.status.idle": "2023-08-18T07:01:07.629782Z", "shell.execute_reply": "2023-08-18T07:01:07.628750Z" }, "origin_pos": 45, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss: 0.243, 0.003 sec/epoch\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:01:07.594528\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "mini1_res = train_sgd(.4, 100)" ] }, { "cell_type": "markdown", "id": "1c4fe0ad", "metadata": { "origin_pos": 46 }, "source": [ "将批量大小减少到10,每个迭代轮数的时间都会增加,因为每批工作负载的执行效率变得更低。\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "063d6493", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:07.634814Z", "iopub.status.busy": "2023-08-18T07:01:07.633816Z", "iopub.status.idle": "2023-08-18T07:01:11.029070Z", "shell.execute_reply": "2023-08-18T07:01:11.027884Z" }, "origin_pos": 47, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss: 0.243, 0.013 sec/epoch\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:01:10.978618\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "mini2_res = train_sgd(.05, 10)" ] }, { "cell_type": "markdown", "id": "f496be01", "metadata": { "origin_pos": 48 }, "source": [ "现在我们可以比较前四个实验的时间与损失。\n", "可以看出,尽管在处理的样本数方面,随机梯度下降的收敛速度快于梯度下降,但与梯度下降相比,它需要更多的时间来达到同样的损失,因为逐个样本来计算梯度并不那么有效。\n", "小批量随机梯度下降能够平衡收敛速度和计算效率。\n", "大小为10的小批量比随机梯度下降更有效;\n", "大小为100的小批量在运行时间上甚至优于梯度下降。\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "2c707764", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:11.034447Z", "iopub.status.busy": "2023-08-18T07:01:11.033671Z", "iopub.status.idle": "2023-08-18T07:01:11.811055Z", "shell.execute_reply": "2023-08-18T07:01:11.809850Z" }, "origin_pos": 49, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:01:11.695943\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "d2l.set_figsize([6, 3])\n", "d2l.plot(*list(map(list, zip(gd_res, sgd_res, mini1_res, mini2_res))),\n", " 'time (sec)', 'loss', xlim=[1e-2, 10],\n", " legend=['gd', 'sgd', 'batch size=100', 'batch size=10'])\n", "d2l.plt.gca().set_xscale('log')" ] }, { "cell_type": "markdown", "id": "566c339b", "metadata": { "origin_pos": 50 }, "source": [ "## 简洁实现\n", "\n", "下面用深度学习框架自带算法实现一个通用的训练函数,我们将在本章中其它小节使用它。\n" ] }, { "cell_type": "code", "execution_count": 14, "id": "4fde0834", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:11.815684Z", "iopub.status.busy": "2023-08-18T07:01:11.814821Z", "iopub.status.idle": "2023-08-18T07:01:11.824095Z", "shell.execute_reply": "2023-08-18T07:01:11.822924Z" }, "origin_pos": 52, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def train_concise_ch11(trainer_fn, hyperparams, data_iter, num_epochs=4):\n", " # 初始化模型\n", " net = nn.Sequential(nn.Linear(5, 1))\n", " def init_weights(m):\n", " if type(m) == nn.Linear:\n", " torch.nn.init.normal_(m.weight, std=0.01)\n", " net.apply(init_weights)\n", "\n", " optimizer = trainer_fn(net.parameters(), **hyperparams)\n", " loss = nn.MSELoss(reduction='none')\n", " animator = d2l.Animator(xlabel='epoch', ylabel='loss',\n", " xlim=[0, num_epochs], ylim=[0.22, 0.35])\n", " n, timer = 0, d2l.Timer()\n", " for _ in range(num_epochs):\n", " for X, y in data_iter:\n", " optimizer.zero_grad()\n", " out = net(X)\n", " y = y.reshape(out.shape)\n", " l = loss(out, y)\n", " l.mean().backward()\n", " optimizer.step()\n", " n += X.shape[0]\n", " if n % 200 == 0:\n", " timer.stop()\n", " # MSELoss计算平方误差时不带系数1/2\n", " animator.add(n/X.shape[0]/len(data_iter),\n", " (d2l.evaluate_loss(net, data_iter, loss) / 2,))\n", " timer.start()\n", " print(f'loss: {animator.Y[0][-1]:.3f}, {timer.avg():.3f} sec/epoch')" ] }, { "cell_type": "markdown", "id": "973d2140", "metadata": { "origin_pos": 55 }, "source": [ "下面使用这个训练函数,复现之前的实验。\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "c443d323", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:11.828129Z", "iopub.status.busy": "2023-08-18T07:01:11.827413Z", "iopub.status.idle": "2023-08-18T07:01:18.659993Z", "shell.execute_reply": "2023-08-18T07:01:18.658861Z" }, "origin_pos": 57, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss: 0.243, 0.015 sec/epoch\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:01:18.610093\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "data_iter, _ = get_data_ch11(10)\n", "trainer = torch.optim.SGD\n", "train_concise_ch11(trainer, {'lr': 0.01}, data_iter)" ] }, { "cell_type": "markdown", "id": "fd1152ec", "metadata": { "origin_pos": 60 }, "source": [ "## 小结\n", "\n", "* 由于减少了深度学习框架的额外开销,使用更好的内存定位以及CPU和GPU上的缓存,向量化使代码更加高效。\n", "* 随机梯度下降的“统计效率”与大批量一次处理数据的“计算效率”之间存在权衡。小批量随机梯度下降提供了两全其美的答案:计算和统计效率。\n", "* 在小批量随机梯度下降中,我们处理通过训练数据的随机排列获得的批量数据(即每个观测值只处理一次,但按随机顺序)。\n", "* 在训练期间降低学习率有助于训练。\n", "* 一般来说,小批量随机梯度下降比随机梯度下降和梯度下降的速度快,收敛风险较小。\n", "\n", "## 练习\n", "\n", "1. 修改批量大小和学习率,并观察目标函数值的下降率以及每个迭代轮数消耗的时间。\n", "1. 将小批量随机梯度下降与实际从训练集中*取样替换*的变体进行比较。会看出什么?\n", "1. 一个邪恶的精灵在没通知你的情况下复制了你的数据集(即每个观测发生两次,数据集增加到原始大小的两倍,但没有人告诉你)。随机梯度下降、小批量随机梯度下降和梯度下降的表现将如何变化?\n" ] }, { "cell_type": "markdown", "id": "0e60c65b", "metadata": { "origin_pos": 62, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/4325)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }