{ "cells": [ { "cell_type": "markdown", "id": "6efa4466", "metadata": { "origin_pos": 0 }, "source": [ "# RMSProp算法\n", ":label:`sec_rmsprop`\n", "\n", " :numref:`sec_adagrad`中的关键问题之一,是学习率按预定时间表$\\mathcal{O}(t^{-\\frac{1}{2}})$显著降低。\n", "虽然这通常适用于凸问题,但对于深度学习中遇到的非凸问题,可能并不理想。\n", "但是,作为一个预处理器,Adagrad算法按坐标顺序的适应性是非常可取的。\n", "\n", " :cite:`Tieleman.Hinton.2012`建议以RMSProp算法作为将速率调度与坐标自适应学习率分离的简单修复方法。\n", "问题在于,Adagrad算法将梯度$\\mathbf{g}_t$的平方累加成状态矢量$\\mathbf{s}_t = \\mathbf{s}_{t-1} + \\mathbf{g}_t^2$。\n", "因此,由于缺乏规范化,没有约束力,$\\mathbf{s}_t$持续增长,几乎上是在算法收敛时呈线性递增。\n", "\n", "解决此问题的一种方法是使用$\\mathbf{s}_t / t$。\n", "对$\\mathbf{g}_t$的合理分布来说,它将收敛。\n", "遗憾的是,限制行为生效可能需要很长时间,因为该流程记住了值的完整轨迹。\n", "另一种方法是按动量法中的方式使用泄漏平均值,即$\\mathbf{s}_t \\leftarrow \\gamma \\mathbf{s}_{t-1} + (1-\\gamma) \\mathbf{g}_t^2$,其中参数$\\gamma > 0$。\n", "保持所有其它部分不变就产生了RMSProp算法。\n", "\n", "## 算法\n", "\n", "让我们详细写出这些方程式。\n", "\n", "$$\\begin{aligned}\n", " \\mathbf{s}_t & \\leftarrow \\gamma \\mathbf{s}_{t-1} + (1 - \\gamma) \\mathbf{g}_t^2, \\\\\n", " \\mathbf{x}_t & \\leftarrow \\mathbf{x}_{t-1} - \\frac{\\eta}{\\sqrt{\\mathbf{s}_t + \\epsilon}} \\odot \\mathbf{g}_t.\n", "\\end{aligned}$$\n", "\n", "常数$\\epsilon > 0$通常设置为$10^{-6}$,以确保我们不会因除以零或步长过大而受到影响。\n", "鉴于这种扩展,我们现在可以自由控制学习率$\\eta$,而不考虑基于每个坐标应用的缩放。\n", "就泄漏平均值而言,我们可以采用与之前在动量法中适用的相同推理。\n", "扩展$\\mathbf{s}_t$定义可获得\n", "\n", "$$\n", "\\begin{aligned}\n", "\\mathbf{s}_t & = (1 - \\gamma) \\mathbf{g}_t^2 + \\gamma \\mathbf{s}_{t-1} \\\\\n", "& = (1 - \\gamma) \\left(\\mathbf{g}_t^2 + \\gamma \\mathbf{g}_{t-1}^2 + \\gamma^2 \\mathbf{g}_{t-2} + \\ldots, \\right).\n", "\\end{aligned}\n", "$$\n", "\n", "同之前在 :numref:`sec_momentum`小节一样,我们使用$1 + \\gamma + \\gamma^2 + \\ldots, = \\frac{1}{1-\\gamma}$。\n", "因此,权重总和标准化为$1$且观测值的半衰期为$\\gamma^{-1}$。\n", "让我们图像化各种数值的$\\gamma$在过去40个时间步长的权重。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "30751083", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:32.547846Z", "iopub.status.busy": "2023-08-18T07:05:32.547295Z", "iopub.status.idle": "2023-08-18T07:05:34.484898Z", "shell.execute_reply": "2023-08-18T07:05:34.483633Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import math\n", "import torch\n", "from d2l import torch as d2l" ] }, { "cell_type": "code", "execution_count": 2, "id": "254f2129", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:34.489092Z", "iopub.status.busy": "2023-08-18T07:05:34.488387Z", "iopub.status.idle": "2023-08-18T07:05:34.615400Z", "shell.execute_reply": "2023-08-18T07:05:34.614588Z" }, "origin_pos": 5, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:05:34.587794\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "d2l.set_figsize()\n", "gammas = [0.95, 0.9, 0.8, 0.7]\n", "for gamma in gammas:\n", " x = torch.arange(40).detach().numpy()\n", " d2l.plt.plot(x, (1-gamma) * gamma ** x, label=f'gamma = {gamma:.2f}')\n", "d2l.plt.xlabel('time');" ] }, { "cell_type": "markdown", "id": "19893348", "metadata": { "origin_pos": 6 }, "source": [ "## 从零开始实现\n", "\n", "和之前一样,我们使用二次函数$f(\\mathbf{x})=0.1x_1^2+2x_2^2$来观察RMSProp算法的轨迹。\n", "回想在 :numref:`sec_adagrad`一节中,当我们使用学习率为0.4的Adagrad算法时,变量在算法的后期阶段移动非常缓慢,因为学习率衰减太快。\n", "RMSProp算法中不会发生这种情况,因为$\\eta$是单独控制的。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "c3f8b14e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:34.622343Z", "iopub.status.busy": "2023-08-18T07:05:34.621764Z", "iopub.status.idle": "2023-08-18T07:05:34.731199Z", "shell.execute_reply": "2023-08-18T07:05:34.730392Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 20, x1: -0.010599, x2: 0.000000\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:05:34.702544\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def rmsprop_2d(x1, x2, s1, s2):\n", " g1, g2, eps = 0.2 * x1, 4 * x2, 1e-6\n", " s1 = gamma * s1 + (1 - gamma) * g1 ** 2\n", " s2 = gamma * s2 + (1 - gamma) * g2 ** 2\n", " x1 -= eta / math.sqrt(s1 + eps) * g1\n", " x2 -= eta / math.sqrt(s2 + eps) * g2\n", " return x1, x2, s1, s2\n", "\n", "def f_2d(x1, x2):\n", " return 0.1 * x1 ** 2 + 2 * x2 ** 2\n", "\n", "eta, gamma = 0.4, 0.9\n", "d2l.show_trace_2d(f_2d, d2l.train_2d(rmsprop_2d))" ] }, { "cell_type": "markdown", "id": "40cd5b7f", "metadata": { "origin_pos": 8 }, "source": [ "接下来,我们在深度网络中实现RMSProp算法。\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "26c70a5c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:34.737342Z", "iopub.status.busy": "2023-08-18T07:05:34.736751Z", "iopub.status.idle": "2023-08-18T07:05:34.740975Z", "shell.execute_reply": "2023-08-18T07:05:34.740221Z" }, "origin_pos": 9, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def init_rmsprop_states(feature_dim):\n", " s_w = torch.zeros((feature_dim, 1))\n", " s_b = torch.zeros(1)\n", " return (s_w, s_b)" ] }, { "cell_type": "code", "execution_count": 5, "id": "3fc7c2f1", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:34.745617Z", "iopub.status.busy": "2023-08-18T07:05:34.744961Z", "iopub.status.idle": "2023-08-18T07:05:34.749992Z", "shell.execute_reply": "2023-08-18T07:05:34.749197Z" }, "origin_pos": 13, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def rmsprop(params, states, hyperparams):\n", " gamma, eps = hyperparams['gamma'], 1e-6\n", " for p, s in zip(params, states):\n", " with torch.no_grad():\n", " s[:] = gamma * s + (1 - gamma) * torch.square(p.grad)\n", " p[:] -= hyperparams['lr'] * p.grad / torch.sqrt(s + eps)\n", " p.grad.data.zero_()" ] }, { "cell_type": "markdown", "id": "2d099676", "metadata": { "origin_pos": 16 }, "source": [ "我们将初始学习率设置为0.01,加权项$\\gamma$设置为0.9。\n", "也就是说,$\\mathbf{s}$累加了过去的$1/(1-\\gamma) = 10$次平方梯度观测值的平均值。\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "692fe2a3", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:34.754673Z", "iopub.status.busy": "2023-08-18T07:05:34.754066Z", "iopub.status.idle": "2023-08-18T07:05:37.535518Z", "shell.execute_reply": "2023-08-18T07:05:37.534481Z" }, "origin_pos": 17, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss: 0.247, 0.014 sec/epoch\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:05:37.487309\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "data_iter, feature_dim = d2l.get_data_ch11(batch_size=10)\n", "d2l.train_ch11(rmsprop, init_rmsprop_states(feature_dim),\n", " {'lr': 0.01, 'gamma': 0.9}, data_iter, feature_dim);" ] }, { "cell_type": "markdown", "id": "30b0a9cd", "metadata": { "origin_pos": 18 }, "source": [ "## 简洁实现\n", "\n", "我们可直接使用深度学习框架中提供的RMSProp算法来训练模型。\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "32e89d5e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:37.540843Z", "iopub.status.busy": "2023-08-18T07:05:37.540147Z", "iopub.status.idle": "2023-08-18T07:05:44.802579Z", "shell.execute_reply": "2023-08-18T07:05:44.801508Z" }, "origin_pos": 20, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss: 0.244, 0.017 sec/epoch\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:05:44.767183\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "trainer = torch.optim.RMSprop\n", "d2l.train_concise_ch11(trainer, {'lr': 0.01, 'alpha': 0.9},\n", " data_iter)" ] }, { "cell_type": "markdown", "id": "c00792e2", "metadata": { "origin_pos": 23 }, "source": [ "## 小结\n", "\n", "* RMSProp算法与Adagrad算法非常相似,因为两者都使用梯度的平方来缩放系数。\n", "* RMSProp算法与动量法都使用泄漏平均值。但是,RMSProp算法使用该技术来调整按系数顺序的预处理器。\n", "* 在实验中,学习率需要由实验者调度。\n", "* 系数$\\gamma$决定了在调整每坐标比例时历史记录的时长。\n", "\n", "## 练习\n", "\n", "1. 如果我们设置$\\gamma = 1$,实验会发生什么?为什么?\n", "1. 旋转优化问题以最小化$f(\\mathbf{x}) = 0.1 (x_1 + x_2)^2 + 2 (x_1 - x_2)^2$。收敛会发生什么?\n", "1. 试试在真正的机器学习问题上应用RMSProp算法会发生什么,例如在Fashion-MNIST上的训练。试验不同的取值来调整学习率。\n", "1. 随着优化的进展,需要调整$\\gamma$吗?RMSProp算法对此有多敏感?\n" ] }, { "cell_type": "markdown", "id": "04ff42ae", "metadata": { "origin_pos": 25, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/4322)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }