{ "cells": [ { "cell_type": "markdown", "id": "12918bab", "metadata": { "origin_pos": 0 }, "source": [ "# 优化和深度学习\n", "\n", "本节将讨论优化与深度学习之间的关系以及在深度学习中使用优化的挑战。对于深度学习问题,我们通常会先定义*损失函数*。一旦我们有了损失函数,我们就可以使用优化算法来尝试最小化损失。在优化中,损失函数通常被称为优化问题的*目标函数*。按照传统惯例,大多数优化算法都关注的是*最小化*。如果我们需要最大化目标,那么有一个简单的解决方案:在目标函数前加负号即可。\n", "\n", "## 优化的目标\n", "\n", "尽管优化提供了一种最大限度地减少深度学习损失函数的方法,但本质上,优化和深度学习的目标是根本不同的。前者主要关注的是最小化目标,后者则关注在给定有限数据量的情况下寻找合适的模型。在 :numref:`sec_model_selection`中,我们详细讨论了这两个目标之间的区别。例如,训练误差和泛化误差通常不同:由于优化算法的目标函数通常是基于训练数据集的损失函数,因此优化的目标是减少训练误差。但是,深度学习(或更广义地说,统计推断)的目标是减少泛化误差。为了实现后者,除了使用优化算法来减少训练误差之外,我们还需要注意过拟合。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "3a6091f3", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:06.516685Z", "iopub.status.busy": "2023-08-18T07:07:06.515981Z", "iopub.status.idle": "2023-08-18T07:07:08.496750Z", "shell.execute_reply": "2023-08-18T07:07:08.495844Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "import numpy as np\n", "import torch\n", "from mpl_toolkits import mplot3d\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "9f760af7", "metadata": { "origin_pos": 5 }, "source": [ "为了说明上述不同的目标,引入两个概念*风险*和*经验风险*。如 :numref:`subsec_empirical-risk-and-risk`所述,经验风险是训练数据集的平均损失,而风险则是整个数据群的预期损失。下面我们定义了两个函数:风险函数`f`和经验风险函数`g`。假设我们只有有限的训练数据。因此,这里的`g`不如`f`平滑。\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "d5f7a949", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:08.500835Z", "iopub.status.busy": "2023-08-18T07:07:08.500449Z", "iopub.status.idle": "2023-08-18T07:07:08.505447Z", "shell.execute_reply": "2023-08-18T07:07:08.504568Z" }, "origin_pos": 6, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def f(x):\n", " return x * torch.cos(np.pi * x)\n", "\n", "def g(x):\n", " return f(x) + 0.2 * torch.cos(5 * np.pi * x)" ] }, { "cell_type": "markdown", "id": "46724f04", "metadata": { "origin_pos": 7 }, "source": [ "下图说明,训练数据集的最低经验风险可能与最低风险(泛化误差)不同。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "66a76e85", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:08.509166Z", "iopub.status.busy": "2023-08-18T07:07:08.508582Z", "iopub.status.idle": "2023-08-18T07:07:08.781243Z", "shell.execute_reply": "2023-08-18T07:07:08.779982Z" }, "origin_pos": 8, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:07:08.726154\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def annotate(text, xy, xytext): #@save\n", " d2l.plt.gca().annotate(text, xy=xy, xytext=xytext,\n", " arrowprops=dict(arrowstyle='->'))\n", "\n", "x = torch.arange(0.5, 1.5, 0.01)\n", "d2l.set_figsize((4.5, 2.5))\n", "d2l.plot(x, [f(x), g(x)], 'x', 'risk')\n", "annotate('min of\\nempirical risk', (1.0, -1.2), (0.5, -1.1))\n", "annotate('min of risk', (1.1, -1.05), (0.95, -0.5))" ] }, { "cell_type": "markdown", "id": "1fbcc94f", "metadata": { "origin_pos": 10 }, "source": [ "## 深度学习中的优化挑战\n", "\n", "本章将关注优化算法在最小化目标函数方面的性能,而不是模型的泛化误差。在 :numref:`sec_linear_regression`中,我们区分了优化问题中的解析解和数值解。在深度学习中,大多数目标函数都很复杂,没有解析解。相反,我们必须使用数值优化算法。本章中的优化算法都属于此类别。\n", "\n", "深度学习优化存在许多挑战。其中最令人烦恼的是局部最小值、鞍点和梯度消失。\n", "\n", "### 局部最小值\n", "\n", "对于任何目标函数$f(x)$,如果在$x$处对应的$f(x)$值小于在$x$附近任意其他点的$f(x)$值,那么$f(x)$可能是局部最小值。如果$f(x)$在$x$处的值是整个域中目标函数的最小值,那么$f(x)$是全局最小值。\n", "\n", "例如,给定函数\n", "\n", "$$f(x) = x \\cdot \\text{cos}(\\pi x) \\text{ for } -1.0 \\leq x \\leq 2.0,$$\n", "\n", "我们可以近似该函数的局部最小值和全局最小值。\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "0adadab1", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:08.784743Z", "iopub.status.busy": "2023-08-18T07:07:08.784448Z", "iopub.status.idle": "2023-08-18T07:07:09.024530Z", "shell.execute_reply": "2023-08-18T07:07:09.023594Z" }, "origin_pos": 11, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:07:08.979118\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "x = torch.arange(-1.0, 2.0, 0.01)\n", "d2l.plot(x, [f(x), ], 'x', 'f(x)')\n", "annotate('local minimum', (-0.3, -0.25), (-0.77, -1.0))\n", "annotate('global minimum', (1.1, -0.95), (0.6, 0.8))" ] }, { "cell_type": "markdown", "id": "407b7125", "metadata": { "origin_pos": 13 }, "source": [ "深度学习模型的目标函数通常有许多局部最优解。当优化问题的数值解接近局部最优值时,随着目标函数解的梯度接近或变为零,通过最终迭代获得的数值解可能仅使目标函数*局部*最优,而不是*全局*最优。只有一定程度的噪声可能会使参数跳出局部最小值。事实上,这是小批量随机梯度下降的有利特性之一。在这种情况下,小批量上梯度的自然变化能够将参数从局部极小值中跳出。\n", "\n", "### 鞍点\n", "\n", "除了局部最小值之外,鞍点是梯度消失的另一个原因。*鞍点*(saddle point)是指函数的所有梯度都消失但既不是全局最小值也不是局部最小值的任何位置。考虑这个函数$f(x) = x^3$。它的一阶和二阶导数在$x=0$时消失。这时优化可能会停止,尽管它不是最小值。\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "26f5d6d0", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:09.028686Z", "iopub.status.busy": "2023-08-18T07:07:09.027907Z", "iopub.status.idle": "2023-08-18T07:07:09.220430Z", "shell.execute_reply": "2023-08-18T07:07:09.219542Z" }, "origin_pos": 14, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:07:09.183352\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "x = torch.arange(-2.0, 2.0, 0.01)\n", "d2l.plot(x, [x**3], 'x', 'f(x)')\n", "annotate('saddle point', (0, -0.2), (-0.52, -5.0))" ] }, { "cell_type": "markdown", "id": "1539e180", "metadata": { "origin_pos": 16 }, "source": [ "如下例所示,较高维度的鞍点甚至更加隐蔽。考虑这个函数$f(x, y) = x^2 - y^2$。它的鞍点为$(0, 0)$。这是关于$y$的最大值,也是关于$x$的最小值。此外,它看起来像个马鞍,这就是鞍点的名字由来。\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "52623051", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:09.225145Z", "iopub.status.busy": "2023-08-18T07:07:09.224532Z", "iopub.status.idle": "2023-08-18T07:07:09.358354Z", "shell.execute_reply": "2023-08-18T07:07:09.357475Z" }, "origin_pos": 18, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:07:09.320886\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "x, y = torch.meshgrid(\n", " torch.linspace(-1.0, 1.0, 101), torch.linspace(-1.0, 1.0, 101))\n", "z = x**2 - y**2\n", "\n", "ax = d2l.plt.figure().add_subplot(111, projection='3d')\n", "ax.plot_wireframe(x, y, z, **{'rstride': 10, 'cstride': 10})\n", "ax.plot([0], [0], [0], 'rx')\n", "ticks = [-1, 0, 1]\n", "d2l.plt.xticks(ticks)\n", "d2l.plt.yticks(ticks)\n", "ax.set_zticks(ticks)\n", "d2l.plt.xlabel('x')\n", "d2l.plt.ylabel('y');" ] }, { "cell_type": "markdown", "id": "5878b89a", "metadata": { "origin_pos": 19 }, "source": [ "我们假设函数的输入是$k$维向量,其输出是标量,因此其Hessian矩阵(也称黑塞矩阵)将有$k$个特征值(参考[特征分解的在线附录](https://d2l.ai/chapter_appendix-mathematics-for-deep-learning/eigendecomposition.html))。函数的解可能是局部最小值、局部最大值或函数梯度为零位置处的鞍点:\n", "\n", "* 当函数在零梯度位置处的Hessian矩阵的特征值全部为正值时,我们有该函数的局部最小值;\n", "* 当函数在零梯度位置处的Hessian矩阵的特征值全部为负值时,我们有该函数的局部最大值;\n", "* 当函数在零梯度位置处的Hessian矩阵的特征值为负值和正值时,我们有该函数的一个鞍点。\n", "\n", "对于高维度问题,至少*部分*特征值为负的可能性相当高。这使得鞍点比局部最小值更有可能出现。我们将在下一节介绍凸性时讨论这种情况的一些例外。简而言之,凸函数是Hessian函数的特征值永远不为负值的函数。不幸的是,大多数深度学习问题并不属于这一类。尽管如此,它还是研究优化算法的一个很好的工具。\n", "\n", "### 梯度消失\n", "\n", "可能遇到的最隐蔽问题是梯度消失。回想一下我们在 :numref:`subsec_activation_functions`中常用的激活函数及其衍生函数。例如,假设我们想最小化函数$f(x) = \\tanh(x)$,然后我们恰好从$x = 4$开始。正如我们所看到的那样,$f$的梯度接近零。更具体地说,$f'(x) = 1 - \\tanh^2(x)$,因此是$f'(4) = 0.0013$。因此,在我们取得进展之前,优化将会停滞很长一段时间。事实证明,这是在引入ReLU激活函数之前训练深度学习模型相当棘手的原因之一。\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "5aa88d1a", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:09.362938Z", "iopub.status.busy": "2023-08-18T07:07:09.362153Z", "iopub.status.idle": "2023-08-18T07:07:09.566124Z", "shell.execute_reply": "2023-08-18T07:07:09.565284Z" }, "origin_pos": 20, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:07:09.526643\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "x = torch.arange(-2.0, 5.0, 0.01)\n", "d2l.plot(x, [torch.tanh(x)], 'x', 'f(x)')\n", "annotate('vanishing gradient', (4, 1), (2, 0.0))" ] }, { "cell_type": "markdown", "id": "7cc9c3a2", "metadata": { "origin_pos": 22 }, "source": [ "正如我们所看到的那样,深度学习的优化充满挑战。幸运的是,有一系列强大的算法表现良好,即使对于初学者也很容易使用。此外,没有必要找到最优解。局部最优解或其近似解仍然非常有用。\n", "\n", "## 小结\n", "\n", "* 最小化训练误差并*不能*保证我们找到最佳的参数集来最小化泛化误差。\n", "* 优化问题可能有许多局部最小值。\n", "* 一个问题可能有很多的鞍点,因为问题通常不是凸的。\n", "* 梯度消失可能会导致优化停滞,重参数化通常会有所帮助。对参数进行良好的初始化也可能是有益的。\n", "\n", "## 练习\n", "\n", "1. 考虑一个简单的MLP,它有一个隐藏层,比如,隐藏层中维度为$d$和一个输出。证明对于任何局部最小值,至少有$d!$个等效方案。\n", "1. 假设我们有一个对称随机矩阵$\\mathbf{M}$,其中条目$M_{ij} = M_{ji}$各自从某种概率分布$p_{ij}$中抽取。此外,假设$p_{ij}(x) = p_{ij}(-x)$,即分布是对称的(详情请参见 :cite:`Wigner.1958`)。\n", " 1. 证明特征值的分布也是对称的。也就是说,对于任何特征向量$\\mathbf{v}$,关联的特征值$\\lambda$满足$P(\\lambda > 0) = P(\\lambda < 0)$的概率为$P(\\lambda > 0) = P(\\lambda < 0)$。\n", " 1. 为什么以上*没有*暗示$P(\\lambda > 0) = 0.5$?\n", "1. 你能想到深度学习优化还涉及哪些其他挑战?\n", "1. 假设你想在(真实的)鞍上平衡一个(真实的)球。\n", " 1. 为什么这很难?\n", " 1. 能利用这种效应来优化算法吗?\n" ] }, { "cell_type": "markdown", "id": "ea89bc0e", "metadata": { "origin_pos": 24, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/3841)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }