{ "cells": [ { "cell_type": "markdown", "id": "c6b5f483", "metadata": { "origin_pos": 0 }, "source": [ "# Adadelta\n", ":label:`sec_adadelta`\n", "\n", "\n", "Adadelta是AdaGrad的另一种变体( :numref:`sec_adagrad`),\n", "主要区别在于前者减少了学习率适应坐标的数量。\n", "此外,广义上Adadelta被称为没有学习率,因为它使用变化量本身作为未来变化的校准。\n", "Adadelta算法是在 :cite:`Zeiler.2012`中提出的。\n", "\n", "## Adadelta算法\n", "\n", "简而言之,Adadelta使用两个状态变量,$\\mathbf{s}_t$用于存储梯度二阶导数的泄露平均值,$\\Delta\\mathbf{x}_t$用于存储模型本身中参数变化二阶导数的泄露平均值。请注意,为了与其他出版物和实现的兼容性,我们使用作者的原始符号和命名(没有其它真正理由让大家使用不同的希腊变量来表示在动量法、AdaGrad、RMSProp和Adadelta中用于相同用途的参数)。\n", "\n", "以下是Adadelta的技术细节。鉴于参数du jour是$\\rho$,我们获得了与 :numref:`sec_rmsprop`类似的以下泄漏更新:\n", "\n", "$$\\begin{aligned}\n", " \\mathbf{s}_t & = \\rho \\mathbf{s}_{t-1} + (1 - \\rho) \\mathbf{g}_t^2.\n", "\\end{aligned}$$\n", "\n", "与 :numref:`sec_rmsprop`的区别在于,我们使用重新缩放的梯度$\\mathbf{g}_t'$执行更新,即\n", "\n", "$$\\begin{aligned}\n", " \\mathbf{x}_t & = \\mathbf{x}_{t-1} - \\mathbf{g}_t'. \\\\\n", "\\end{aligned}$$\n", "\n", "那么,调整后的梯度$\\mathbf{g}_t'$是什么?我们可以按如下方式计算它:\n", "\n", "$$\\begin{aligned}\n", " \\mathbf{g}_t' & = \\frac{\\sqrt{\\Delta\\mathbf{x}_{t-1} + \\epsilon}}{\\sqrt{{\\mathbf{s}_t + \\epsilon}}} \\odot \\mathbf{g}_t, \\\\\n", "\\end{aligned}$$\n", "\n", "其中$\\Delta \\mathbf{x}_{t-1}$是重新缩放梯度的平方$\\mathbf{g}_t'$的泄漏平均值。我们将$\\Delta \\mathbf{x}_{0}$初始化为$0$,然后在每个步骤中使用$\\mathbf{g}_t'$更新它,即\n", "\n", "$$\\begin{aligned}\n", " \\Delta \\mathbf{x}_t & = \\rho \\Delta\\mathbf{x}_{t-1} + (1 - \\rho) {\\mathbf{g}_t'}^2,\n", "\\end{aligned}$$\n", "\n", "和$\\epsilon$(例如$10^{-5}$这样的小值)是为了保持数字稳定性而加入的。\n", "\n", "## 代码实现\n", "\n", "Adadelta需要为每个变量维护两个状态变量,即$\\mathbf{s}_t$和$\\Delta\\mathbf{x}_t$。这将产生以下实现。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "b249f128", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:09.332337Z", "iopub.status.busy": "2023-08-18T07:05:09.331524Z", "iopub.status.idle": "2023-08-18T07:05:11.424308Z", "shell.execute_reply": "2023-08-18T07:05:11.423239Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "import torch\n", "from d2l import torch as d2l\n", "\n", "\n", "def init_adadelta_states(feature_dim):\n", " s_w, s_b = torch.zeros((feature_dim, 1)), torch.zeros(1)\n", " delta_w, delta_b = torch.zeros((feature_dim, 1)), torch.zeros(1)\n", " return ((s_w, delta_w), (s_b, delta_b))\n", "\n", "def adadelta(params, states, hyperparams):\n", " rho, eps = hyperparams['rho'], 1e-5\n", " for p, (s, delta) in zip(params, states):\n", " with torch.no_grad():\n", " # In-placeupdatesvia[:]\n", " s[:] = rho * s + (1 - rho) * torch.square(p.grad)\n", " g = (torch.sqrt(delta + eps) / torch.sqrt(s + eps)) * p.grad\n", " p[:] -= g\n", " delta[:] = rho * delta + (1 - rho) * g * g\n", " p.grad.data.zero_()" ] }, { "cell_type": "markdown", "id": "4411aabf", "metadata": { "origin_pos": 5 }, "source": [ "对于每次参数更新,选择$\\rho = 0.9$相当于10个半衰期。由此我们得到:\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "4f8025df", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:11.429258Z", "iopub.status.busy": "2023-08-18T07:05:11.428414Z", "iopub.status.idle": "2023-08-18T07:05:14.081998Z", "shell.execute_reply": "2023-08-18T07:05:14.081152Z" }, "origin_pos": 6, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss: 0.243, 0.014 sec/epoch\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:05:14.047429\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "data_iter, feature_dim = d2l.get_data_ch11(batch_size=10)\n", "d2l.train_ch11(adadelta, init_adadelta_states(feature_dim),\n", " {'rho': 0.9}, data_iter, feature_dim);" ] }, { "cell_type": "markdown", "id": "a8cd0fd5", "metadata": { "origin_pos": 7 }, "source": [ "为了简洁实现,我们只需使用高级API中的Adadelta算法。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "25f0fd32", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:14.085468Z", "iopub.status.busy": "2023-08-18T07:05:14.085189Z", "iopub.status.idle": "2023-08-18T07:05:19.137299Z", "shell.execute_reply": "2023-08-18T07:05:19.136478Z" }, "origin_pos": 9, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss: 0.243, 0.013 sec/epoch\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:05:19.103343\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "trainer = torch.optim.Adadelta\n", "d2l.train_concise_ch11(trainer, {'rho': 0.9}, data_iter)" ] }, { "cell_type": "markdown", "id": "ff4cd9a5", "metadata": { "origin_pos": 12 }, "source": [ "## 小结\n", "\n", "* Adadelta没有学习率参数。相反,它使用参数本身的变化率来调整学习率。\n", "* Adadelta需要两个状态变量来存储梯度的二阶导数和参数的变化。\n", "* Adadelta使用泄漏的平均值来保持对适当统计数据的运行估计。\n", "\n", "## 练习\n", "\n", "1. 调整$\\rho$的值,会发生什么?\n", "1. 展示如何在不使用$\\mathbf{g}_t'$的情况下实现算法。为什么这是个好主意?\n", "1. Adadelta真的是学习率为0吗?能找到Adadelta无法解决的优化问题吗?\n", "1. 将Adadelta的收敛行为与AdaGrad和RMSProp进行比较。\n" ] }, { "cell_type": "markdown", "id": "074f9479", "metadata": { "origin_pos": 14, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/5772)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }