{ "cells": [ { "cell_type": "markdown", "id": "21eb8780", "metadata": { "origin_pos": 0 }, "source": [ "# 学习率调度器\n", ":label:`sec_scheduler`\n", "\n", "到目前为止,我们主要关注如何更新权重向量的优化算法,而不是它们的更新速率。\n", "然而,调整学习率通常与实际算法同样重要,有如下几方面需要考虑:\n", "\n", "* 首先,学习率的大小很重要。如果它太大,优化就会发散;如果它太小,训练就会需要过长时间,或者我们最终只能得到次优的结果。我们之前看到问题的条件数很重要(有关详细信息,请参见 :numref:`sec_momentum`)。直观地说,这是最不敏感与最敏感方向的变化量的比率。\n", "* 其次,衰减速率同样很重要。如果学习率持续过高,我们可能最终会在最小值附近弹跳,从而无法达到最优解。 :numref:`sec_minibatch_sgd`比较详细地讨论了这一点,在 :numref:`sec_sgd`中我们则分析了性能保证。简而言之,我们希望速率衰减,但要比$\\mathcal{O}(t^{-\\frac{1}{2}})$慢,这样能成为解决凸问题的不错选择。\n", "* 另一个同样重要的方面是初始化。这既涉及参数最初的设置方式(详情请参阅 :numref:`sec_numerical_stability`),又关系到它们最初的演变方式。这被戏称为*预热*(warmup),即我们最初开始向着解决方案迈进的速度有多快。一开始的大步可能没有好处,特别是因为最初的参数集是随机的。最初的更新方向可能也是毫无意义的。\n", "* 最后,还有许多优化变体可以执行周期性学习率调整。这超出了本章的范围,我们建议读者阅读 :cite:`Izmailov.Podoprikhin.Garipov.ea.2018`来了解个中细节。例如,如何通过对整个路径参数求平均值来获得更好的解。\n", "\n", "鉴于管理学习率需要很多细节,因此大多数深度学习框架都有自动应对这个问题的工具。\n", "在本章中,我们将梳理不同的调度策略对准确性的影响,并展示如何通过*学习率调度器*(learning rate scheduler)来有效管理。\n", "\n", "## 一个简单的问题\n", "\n", "我们从一个简单的问题开始,这个问题可以轻松计算,但足以说明要义。\n", "为此,我们选择了一个稍微现代化的LeNet版本(激活函数使用`relu`而不是`sigmoid`,汇聚层使用最大汇聚层而不是平均汇聚层),并应用于Fashion-MNIST数据集。\n", "此外,我们混合网络以提高性能。\n", "由于大多数代码都是标准的,我们只介绍基础知识,而不做进一步的详细讨论。如果需要,请参阅 :numref:`chap_cnn`进行复习。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "fa35f5a3", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:22:40.885445Z", "iopub.status.busy": "2023-08-18T07:22:40.884804Z", "iopub.status.idle": "2023-08-18T07:22:43.950999Z", "shell.execute_reply": "2023-08-18T07:22:43.950124Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "import math\n", "import torch\n", "from torch import nn\n", "from torch.optim import lr_scheduler\n", "from d2l import torch as d2l\n", "\n", "\n", "def net_fn():\n", " model = nn.Sequential(\n", " nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=2, stride=2),\n", " nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=2, stride=2),\n", " nn.Flatten(),\n", " nn.Linear(16 * 5 * 5, 120), nn.ReLU(),\n", " nn.Linear(120, 84), nn.ReLU(),\n", " nn.Linear(84, 10))\n", "\n", " return model\n", "\n", "loss = nn.CrossEntropyLoss()\n", "device = d2l.try_gpu()\n", "\n", "batch_size = 256\n", "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)\n", "\n", "# 代码几乎与d2l.train_ch6定义在卷积神经网络一章LeNet一节中的相同\n", "def train(net, train_iter, test_iter, num_epochs, loss, trainer, device,\n", " scheduler=None):\n", " net.to(device)\n", " animator = d2l.Animator(xlabel='epoch', xlim=[0, num_epochs],\n", " legend=['train loss', 'train acc', 'test acc'])\n", "\n", " for epoch in range(num_epochs):\n", " metric = d2l.Accumulator(3) # train_loss,train_acc,num_examples\n", " for i, (X, y) in enumerate(train_iter):\n", " net.train()\n", " trainer.zero_grad()\n", " X, y = X.to(device), y.to(device)\n", " y_hat = net(X)\n", " l = loss(y_hat, y)\n", " l.backward()\n", " trainer.step()\n", " with torch.no_grad():\n", " metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])\n", " train_loss = metric[0] / metric[2]\n", " train_acc = metric[1] / metric[2]\n", " if (i + 1) % 50 == 0:\n", " animator.add(epoch + i / len(train_iter),\n", " (train_loss, train_acc, None))\n", "\n", " test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)\n", " animator.add(epoch+1, (None, None, test_acc))\n", "\n", " if scheduler:\n", " if scheduler.__module__ == lr_scheduler.__name__:\n", " # UsingPyTorchIn-Builtscheduler\n", " scheduler.step()\n", " else:\n", " # Usingcustomdefinedscheduler\n", " for param_group in trainer.param_groups:\n", " param_group['lr'] = scheduler(epoch)\n", "\n", " print(f'train loss {train_loss:.3f}, train acc {train_acc:.3f}, '\n", " f'test acc {test_acc:.3f}')" ] }, { "cell_type": "markdown", "id": "32c9a5d9", "metadata": { "origin_pos": 5 }, "source": [ "让我们来看看如果使用默认设置,调用此算法会发生什么。\n", "例如设学习率为$0.3$并训练$30$次迭代。\n", "留意在超过了某点、测试准确度方面的进展停滞时,训练准确度将如何继续提高。\n", "两条曲线之间的间隙表示过拟合。\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "c830419f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:22:43.955399Z", "iopub.status.busy": "2023-08-18T07:22:43.954810Z", "iopub.status.idle": "2023-08-18T07:24:50.626624Z", "shell.execute_reply": "2023-08-18T07:24:50.625712Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train loss 0.128, train acc 0.951, test acc 0.885\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:24:50.584572\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "lr, num_epochs = 0.3, 30\n", "net = net_fn()\n", "trainer = torch.optim.SGD(net.parameters(), lr=lr)\n", "train(net, train_iter, test_iter, num_epochs, loss, trainer, device)" ] }, { "cell_type": "markdown", "id": "52310b75", "metadata": { "origin_pos": 10 }, "source": [ "## 学习率调度器\n", "\n", "我们可以在每个迭代轮数(甚至在每个小批量)之后向下调整学习率。\n", "例如,以动态的方式来响应优化的进展情况。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "f849cce9", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:24:50.630736Z", "iopub.status.busy": "2023-08-18T07:24:50.630110Z", "iopub.status.idle": "2023-08-18T07:24:50.636043Z", "shell.execute_reply": "2023-08-18T07:24:50.635027Z" }, "origin_pos": 12, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "learning rate is now 0.10\n" ] } ], "source": [ "lr = 0.1\n", "trainer.param_groups[0][\"lr\"] = lr\n", "print(f'learning rate is now {trainer.param_groups[0][\"lr\"]:.2f}')" ] }, { "cell_type": "markdown", "id": "4f8fbbe5", "metadata": { "origin_pos": 15 }, "source": [ "更通常而言,我们应该定义一个调度器。\n", "当调用更新次数时,它将返回学习率的适当值。\n", "让我们定义一个简单的方法,将学习率设置为$\\eta = \\eta_0 (t + 1)^{-\\frac{1}{2}}$。\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "082e8fdc", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:24:50.640033Z", "iopub.status.busy": "2023-08-18T07:24:50.639295Z", "iopub.status.idle": "2023-08-18T07:24:50.644640Z", "shell.execute_reply": "2023-08-18T07:24:50.643726Z" }, "origin_pos": 16, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class SquareRootScheduler:\n", " def __init__(self, lr=0.1):\n", " self.lr = lr\n", "\n", " def __call__(self, num_update):\n", " return self.lr * pow(num_update + 1.0, -0.5)" ] }, { "cell_type": "markdown", "id": "555add43", "metadata": { "origin_pos": 17 }, "source": [ "让我们在一系列值上绘制它的行为。\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "80c65d0e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:24:50.648321Z", "iopub.status.busy": "2023-08-18T07:24:50.647670Z", "iopub.status.idle": "2023-08-18T07:24:50.813569Z", "shell.execute_reply": "2023-08-18T07:24:50.812039Z" }, "origin_pos": 18, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:24:50.770558\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "scheduler = SquareRootScheduler(lr=0.1)\n", "d2l.plot(torch.arange(num_epochs), [scheduler(t) for t in range(num_epochs)])" ] }, { "cell_type": "markdown", "id": "8f51d518", "metadata": { "origin_pos": 19 }, "source": [ "现在让我们来看看这对在Fashion-MNIST数据集上的训练有何影响。\n", "我们只是提供调度器作为训练算法的额外参数。\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "97efd99f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:24:50.818216Z", "iopub.status.busy": "2023-08-18T07:24:50.817433Z", "iopub.status.idle": "2023-08-18T07:26:48.861761Z", "shell.execute_reply": "2023-08-18T07:26:48.860789Z" }, "origin_pos": 21, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train loss 0.270, train acc 0.901, test acc 0.876\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:26:48.818017\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "net = net_fn()\n", "trainer = torch.optim.SGD(net.parameters(), lr)\n", "train(net, train_iter, test_iter, num_epochs, loss, trainer, device,\n", " scheduler)" ] }, { "cell_type": "markdown", "id": "cf2f8f37", "metadata": { "origin_pos": 24 }, "source": [ "这比以前好一些:曲线比以前更加平滑,并且过拟合更小了。\n", "遗憾的是,关于为什么在理论上某些策略会导致较轻的过拟合,有一些观点认为,较小的步长将导致参数更接近零,因此更简单。\n", "但是,这并不能完全解释这种现象,因为我们并没有真正地提前停止,而只是轻柔地降低了学习率。\n", "\n", "## 策略\n", "\n", "虽然我们不可能涵盖所有类型的学习率调度器,但我们会尝试在下面简要概述常用的策略:多项式衰减和分段常数表。\n", "此外,余弦学习率调度在实践中的一些问题上运行效果很好。\n", "在某些问题上,最好在使用较高的学习率之前预热优化器。\n", "\n", "### 单因子调度器\n", "\n", "多项式衰减的一种替代方案是乘法衰减,即$\\eta_{t+1} \\leftarrow \\eta_t \\cdot \\alpha$其中$\\alpha \\in (0, 1)$。\n", "为了防止学习率衰减到一个合理的下界之下,\n", "更新方程经常修改为$\\eta_{t+1} \\leftarrow \\mathop{\\mathrm{max}}(\\eta_{\\mathrm{min}}, \\eta_t \\cdot \\alpha)$。\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "1cb0ffae", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:26:48.867074Z", "iopub.status.busy": "2023-08-18T07:26:48.866476Z", "iopub.status.idle": "2023-08-18T07:26:49.021313Z", "shell.execute_reply": "2023-08-18T07:26:49.020038Z" }, "origin_pos": 25, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:26:48.990371\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "class FactorScheduler:\n", " def __init__(self, factor=1, stop_factor_lr=1e-7, base_lr=0.1):\n", " self.factor = factor\n", " self.stop_factor_lr = stop_factor_lr\n", " self.base_lr = base_lr\n", "\n", " def __call__(self, num_update):\n", " self.base_lr = max(self.stop_factor_lr, self.base_lr * self.factor)\n", " return self.base_lr\n", "\n", "scheduler = FactorScheduler(factor=0.9, stop_factor_lr=1e-2, base_lr=2.0)\n", "d2l.plot(torch.arange(50), [scheduler(t) for t in range(50)])" ] }, { "cell_type": "markdown", "id": "191f4c5d", "metadata": { "origin_pos": 26 }, "source": [ "接下来,我们将使用内置的调度器,但在这里仅解释它们的功能。\n", "\n", "### 多因子调度器\n", "\n", "训练深度网络的常见策略之一是保持学习率为一组分段的常量,并且不时地按给定的参数对学习率做乘法衰减。\n", "具体地说,给定一组降低学习率的时间点,例如$s = \\{5, 10, 20\\}$,\n", "每当$t \\in s$时,降低$\\eta_{t+1} \\leftarrow \\eta_t \\cdot \\alpha$。\n", "假设每步中的值减半,我们可以按如下方式实现这一点。\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "5c235d55", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:26:49.025730Z", "iopub.status.busy": "2023-08-18T07:26:49.024933Z", "iopub.status.idle": "2023-08-18T07:26:49.204184Z", "shell.execute_reply": "2023-08-18T07:26:49.203112Z" }, "origin_pos": 28, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:26:49.173634\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "net = net_fn()\n", "trainer = torch.optim.SGD(net.parameters(), lr=0.5)\n", "scheduler = lr_scheduler.MultiStepLR(trainer, milestones=[15, 30], gamma=0.5)\n", "\n", "def get_lr(trainer, scheduler):\n", " lr = scheduler.get_last_lr()[0]\n", " trainer.step()\n", " scheduler.step()\n", " return lr\n", "\n", "d2l.plot(torch.arange(num_epochs), [get_lr(trainer, scheduler)\n", " for t in range(num_epochs)])" ] }, { "cell_type": "markdown", "id": "e80a8605", "metadata": { "origin_pos": 31 }, "source": [ "这种分段恒定学习率调度背后的直觉是,让优化持续进行,直到权重向量的分布达到一个驻点。\n", "此时,我们才将学习率降低,以获得更高质量的代理来达到一个良好的局部最小值。\n", "下面的例子展示了如何使用这种方法产生更好的解决方案。\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "f999a96a", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:26:49.208130Z", "iopub.status.busy": "2023-08-18T07:26:49.207506Z", "iopub.status.idle": "2023-08-18T07:28:47.899463Z", "shell.execute_reply": "2023-08-18T07:28:47.898543Z" }, "origin_pos": 33, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train loss 0.191, train acc 0.928, test acc 0.889\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:28:47.855546\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "train(net, train_iter, test_iter, num_epochs, loss, trainer, device,\n", " scheduler)" ] }, { "cell_type": "markdown", "id": "655df579", "metadata": { "origin_pos": 36 }, "source": [ "### 余弦调度器\n", "\n", "余弦调度器是 :cite:`Loshchilov.Hutter.2016`提出的一种启发式算法。\n", "它所依据的观点是:我们可能不想在一开始就太大地降低学习率,而且可能希望最终能用非常小的学习率来“改进”解决方案。\n", "这产生了一个类似于余弦的调度,函数形式如下所示,学习率的值在$t \\in [0, T]$之间。\n", "\n", "$$\\eta_t = \\eta_T + \\frac{\\eta_0 - \\eta_T}{2} \\left(1 + \\cos(\\pi t/T)\\right)$$\n", "\n", "这里$\\eta_0$是初始学习率,$\\eta_T$是当$T$时的目标学习率。\n", "此外,对于$t > T$,我们只需将值固定到$\\eta_T$而不再增加它。\n", "在下面的示例中,我们设置了最大更新步数$T = 20$。\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "ccd9120f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:28:47.903646Z", "iopub.status.busy": "2023-08-18T07:28:47.903033Z", "iopub.status.idle": "2023-08-18T07:28:48.107251Z", "shell.execute_reply": "2023-08-18T07:28:48.106198Z" }, "origin_pos": 38, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:28:48.066245\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "class CosineScheduler:\n", " def __init__(self, max_update, base_lr=0.01, final_lr=0,\n", " warmup_steps=0, warmup_begin_lr=0):\n", " self.base_lr_orig = base_lr\n", " self.max_update = max_update\n", " self.final_lr = final_lr\n", " self.warmup_steps = warmup_steps\n", " self.warmup_begin_lr = warmup_begin_lr\n", " self.max_steps = self.max_update - self.warmup_steps\n", "\n", " def get_warmup_lr(self, epoch):\n", " increase = (self.base_lr_orig - self.warmup_begin_lr) \\\n", " * float(epoch) / float(self.warmup_steps)\n", " return self.warmup_begin_lr + increase\n", "\n", " def __call__(self, epoch):\n", " if epoch < self.warmup_steps:\n", " return self.get_warmup_lr(epoch)\n", " if epoch <= self.max_update:\n", " self.base_lr = self.final_lr + (\n", " self.base_lr_orig - self.final_lr) * (1 + math.cos(\n", " math.pi * (epoch - self.warmup_steps) / self.max_steps)) / 2\n", " return self.base_lr\n", "\n", "scheduler = CosineScheduler(max_update=20, base_lr=0.3, final_lr=0.01)\n", "d2l.plot(torch.arange(num_epochs), [scheduler(t) for t in range(num_epochs)])" ] }, { "cell_type": "markdown", "id": "e732fa0f", "metadata": { "origin_pos": 39 }, "source": [ "在计算机视觉的背景下,这个调度方式可能产生改进的结果。\n", "但请注意,如下所示,这种改进并不一定成立。\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "6f53d237", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:28:48.113142Z", "iopub.status.busy": "2023-08-18T07:28:48.112226Z", "iopub.status.idle": "2023-08-18T07:30:43.364866Z", "shell.execute_reply": "2023-08-18T07:30:43.363967Z" }, "origin_pos": 41, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train loss 0.207, train acc 0.923, test acc 0.892\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:30:43.323106\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "net = net_fn()\n", "trainer = torch.optim.SGD(net.parameters(), lr=0.3)\n", "train(net, train_iter, test_iter, num_epochs, loss, trainer, device,\n", " scheduler)" ] }, { "cell_type": "markdown", "id": "bc2e52ff", "metadata": { "origin_pos": 44 }, "source": [ "### 预热\n", "\n", "在某些情况下,初始化参数不足以得到良好的解。\n", "这对某些高级网络设计来说尤其棘手,可能导致不稳定的优化结果。\n", "对此,一方面,我们可以选择一个足够小的学习率,\n", "从而防止一开始发散,然而这样进展太缓慢。\n", "另一方面,较高的学习率最初就会导致发散。\n", "\n", "解决这种困境的一个相当简单的解决方法是使用预热期,在此期间学习率将增加至初始最大值,然后冷却直到优化过程结束。\n", "为了简单起见,通常使用线性递增。\n", "这引出了如下表所示的时间表。\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "89ebf1cb", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:30:43.368785Z", "iopub.status.busy": "2023-08-18T07:30:43.368179Z", "iopub.status.idle": "2023-08-18T07:30:43.508443Z", "shell.execute_reply": "2023-08-18T07:30:43.507219Z" }, "origin_pos": 46, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:30:43.481487\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "scheduler = CosineScheduler(20, warmup_steps=5, base_lr=0.3, final_lr=0.01)\n", "d2l.plot(torch.arange(num_epochs), [scheduler(t) for t in range(num_epochs)])" ] }, { "cell_type": "markdown", "id": "b104d0e4", "metadata": { "origin_pos": 47 }, "source": [ "注意,观察前5个迭代轮数的性能,网络最初收敛得更好。\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "f9da4ce3", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:30:43.513191Z", "iopub.status.busy": "2023-08-18T07:30:43.512099Z", "iopub.status.idle": "2023-08-18T07:32:34.996401Z", "shell.execute_reply": "2023-08-18T07:32:34.995507Z" }, "origin_pos": 49, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train loss 0.261, train acc 0.904, test acc 0.878\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:32:34.954473\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "net = net_fn()\n", "trainer = torch.optim.SGD(net.parameters(), lr=0.3)\n", "train(net, train_iter, test_iter, num_epochs, loss, trainer, device,\n", " scheduler)" ] }, { "cell_type": "markdown", "id": "dcf78d30", "metadata": { "origin_pos": 52 }, "source": [ "预热可以应用于任何调度器,而不仅仅是余弦。\n", "有关学习率调度的更多实验和更详细讨论,请参阅 :cite:`Gotmare.Keskar.Xiong.ea.2018`。\n", "其中,这篇论文的点睛之笔的发现:预热阶段限制了非常深的网络中参数的发散程度 。\n", "这在直觉上是有道理的:在网络中那些一开始花费最多时间取得进展的部分,随机初始化会产生巨大的发散。\n", "\n", "## 小结\n", "\n", "* 在训练期间逐步降低学习率可以提高准确性,并且减少模型的过拟合。\n", "* 在实验中,每当进展趋于稳定时就降低学习率,这是很有效的。从本质上说,这可以确保我们有效地收敛到一个适当的解,也只有这样才能通过降低学习率来减小参数的固有方差。\n", "* 余弦调度器在某些计算机视觉问题中很受欢迎。\n", "* 优化之前的预热期可以防止发散。\n", "* 优化在深度学习中有多种用途。对于同样的训练误差而言,选择不同的优化算法和学习率调度,除了最大限度地减少训练时间,可以导致测试集上不同的泛化和过拟合量。\n", "\n", "## 练习\n", "\n", "1. 试验给定固定学习率的优化行为。这种情况下可以获得的最佳模型是什么?\n", "1. 如果改变学习率下降的指数,收敛性会如何改变?在实验中方便起见,使用`PolyScheduler`。\n", "1. 将余弦调度器应用于大型计算机视觉问题,例如训练ImageNet数据集。与其他调度器相比,它如何影响性能?\n", "1. 预热应该持续多长时间?\n", "1. 可以试着把优化和采样联系起来吗?首先,在随机梯度朗之万动力学上使用 :cite:`Welling.Teh.2011`的结果。\n" ] }, { "cell_type": "markdown", "id": "d09d87dd", "metadata": { "origin_pos": 54, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/4334)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }