{ "cells": [ { "cell_type": "markdown", "id": "310ac5cd", "metadata": { "origin_pos": 0 }, "source": [ "# 长短期记忆网络(LSTM)\n", ":label:`sec_lstm`\n", "\n", "长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题。\n", "解决这一问题的最早方法之一是长短期存储器(long short-term memory,LSTM)\n", " :cite:`Hochreiter.Schmidhuber.1997`。\n", "它有许多与门控循环单元( :numref:`sec_gru`)一样的属性。\n", "有趣的是,长短期记忆网络的设计比门控循环单元稍微复杂一些,\n", "却比门控循环单元早诞生了近20年。\n", "\n", "## 门控记忆元\n", "\n", "可以说,长短期记忆网络的设计灵感来自于计算机的逻辑门。\n", "长短期记忆网络引入了*记忆元*(memory cell),或简称为*单元*(cell)。\n", "有些文献认为记忆元是隐状态的一种特殊类型,\n", "它们与隐状态具有相同的形状,其设计目的是用于记录附加的信息。\n", "为了控制记忆元,我们需要许多门。\n", "其中一个门用来从单元中输出条目,我们将其称为*输出门*(output gate)。\n", "另外一个门用来决定何时将数据读入单元,我们将其称为*输入门*(input gate)。\n", "我们还需要一种机制来重置单元的内容,由*遗忘门*(forget gate)来管理,\n", "这种设计的动机与门控循环单元相同,\n", "能够通过专用机制决定什么时候记忆或忽略隐状态中的输入。\n", "让我们看看这在实践中是如何运作的。\n", "\n", "### 输入门、忘记门和输出门\n", "\n", "就如在门控循环单元中一样,\n", "当前时间步的输入和前一个时间步的隐状态\n", "作为数据送入长短期记忆网络的门中,\n", "如 :numref:`lstm_0`所示。\n", "它们由三个具有sigmoid激活函数的全连接层处理,\n", "以计算输入门、遗忘门和输出门的值。\n", "因此,这三个门的值都在$(0, 1)$的范围内。\n", "\n", "![长短期记忆模型中的输入门、遗忘门和输出门](../img/lstm-0.svg)\n", ":label:`lstm_0`\n", "\n", "我们来细化一下长短期记忆网络的数学表达。\n", "假设有$h$个隐藏单元,批量大小为$n$,输入数为$d$。\n", "因此,输入为$\\mathbf{X}_t \\in \\mathbb{R}^{n \\times d}$,\n", "前一时间步的隐状态为$\\mathbf{H}_{t-1} \\in \\mathbb{R}^{n \\times h}$。\n", "相应地,时间步$t$的门被定义如下:\n", "输入门是$\\mathbf{I}_t \\in \\mathbb{R}^{n \\times h}$,\n", "遗忘门是$\\mathbf{F}_t \\in \\mathbb{R}^{n \\times h}$,\n", "输出门是$\\mathbf{O}_t \\in \\mathbb{R}^{n \\times h}$。\n", "它们的计算方法如下:\n", "\n", "$$\n", "\\begin{aligned}\n", "\\mathbf{I}_t &= \\sigma(\\mathbf{X}_t \\mathbf{W}_{xi} + \\mathbf{H}_{t-1} \\mathbf{W}_{hi} + \\mathbf{b}_i),\\\\\n", "\\mathbf{F}_t &= \\sigma(\\mathbf{X}_t \\mathbf{W}_{xf} + \\mathbf{H}_{t-1} \\mathbf{W}_{hf} + \\mathbf{b}_f),\\\\\n", "\\mathbf{O}_t &= \\sigma(\\mathbf{X}_t \\mathbf{W}_{xo} + \\mathbf{H}_{t-1} \\mathbf{W}_{ho} + \\mathbf{b}_o),\n", "\\end{aligned}\n", "$$\n", "\n", "其中$\\mathbf{W}_{xi}, \\mathbf{W}_{xf}, \\mathbf{W}_{xo} \\in \\mathbb{R}^{d \\times h}$\n", "和$\\mathbf{W}_{hi}, \\mathbf{W}_{hf}, \\mathbf{W}_{ho} \\in \\mathbb{R}^{h \\times h}$是权重参数,\n", "$\\mathbf{b}_i, \\mathbf{b}_f, \\mathbf{b}_o \\in \\mathbb{R}^{1 \\times h}$是偏置参数。\n", "\n", "### 候选记忆元\n", "\n", "由于还没有指定各种门的操作,所以先介绍*候选记忆元*(candidate memory cell)\n", "$\\tilde{\\mathbf{C}}_t \\in \\mathbb{R}^{n \\times h}$。\n", "它的计算与上面描述的三个门的计算类似,\n", "但是使用$\\tanh$函数作为激活函数,函数的值范围为$(-1, 1)$。\n", "下面导出在时间步$t$处的方程:\n", "\n", "$$\\tilde{\\mathbf{C}}_t = \\text{tanh}(\\mathbf{X}_t \\mathbf{W}_{xc} + \\mathbf{H}_{t-1} \\mathbf{W}_{hc} + \\mathbf{b}_c),$$\n", "\n", "其中$\\mathbf{W}_{xc} \\in \\mathbb{R}^{d \\times h}$和\n", "$\\mathbf{W}_{hc} \\in \\mathbb{R}^{h \\times h}$是权重参数,\n", "$\\mathbf{b}_c \\in \\mathbb{R}^{1 \\times h}$是偏置参数。\n", "\n", "候选记忆元的如 :numref:`lstm_1`所示。\n", "\n", "![长短期记忆模型中的候选记忆元](../img/lstm-1.svg)\n", ":label:`lstm_1`\n", "\n", "### 记忆元\n", "\n", "在门控循环单元中,有一种机制来控制输入和遗忘(或跳过)。\n", "类似地,在长短期记忆网络中,也有两个门用于这样的目的:\n", "输入门$\\mathbf{I}_t$控制采用多少来自$\\tilde{\\mathbf{C}}_t$的新数据,\n", "而遗忘门$\\mathbf{F}_t$控制保留多少过去的\n", "记忆元$\\mathbf{C}_{t-1} \\in \\mathbb{R}^{n \\times h}$的内容。\n", "使用按元素乘法,得出:\n", "\n", "$$\\mathbf{C}_t = \\mathbf{F}_t \\odot \\mathbf{C}_{t-1} + \\mathbf{I}_t \\odot \\tilde{\\mathbf{C}}_t.$$\n", "\n", "如果遗忘门始终为$1$且输入门始终为$0$,\n", "则过去的记忆元$\\mathbf{C}_{t-1}$\n", "将随时间被保存并传递到当前时间步。\n", "引入这种设计是为了缓解梯度消失问题,\n", "并更好地捕获序列中的长距离依赖关系。\n", "\n", "这样我们就得到了计算记忆元的流程图,如 :numref:`lstm_2`。\n", "\n", "![在长短期记忆网络模型中计算记忆元](../img/lstm-2.svg)\n", "\n", ":label:`lstm_2`\n", "\n", "### 隐状态\n", "\n", "最后,我们需要定义如何计算隐状态\n", "$\\mathbf{H}_t \\in \\mathbb{R}^{n \\times h}$,\n", "这就是输出门发挥作用的地方。\n", "在长短期记忆网络中,它仅仅是记忆元的$\\tanh$的门控版本。\n", "这就确保了$\\mathbf{H}_t$的值始终在区间$(-1, 1)$内:\n", "\n", "$$\\mathbf{H}_t = \\mathbf{O}_t \\odot \\tanh(\\mathbf{C}_t).$$\n", "\n", "只要输出门接近$1$,我们就能够有效地将所有记忆信息传递给预测部分,\n", "而对于输出门接近$0$,我们只保留记忆元内的所有信息,而不需要更新隐状态。\n", "\n", " :numref:`lstm_3`提供了数据流的图形化演示。\n", "\n", "![在长短期记忆模型中计算隐状态](../img/lstm-3.svg)\n", ":label:`lstm_3`\n", "\n", "## 从零开始实现\n", "\n", "现在,我们从零开始实现长短期记忆网络。\n", "与 :numref:`sec_rnn_scratch`中的实验相同,\n", "我们首先加载时光机器数据集。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "96a7231c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:24:18.324326Z", "iopub.status.busy": "2023-08-18T07:24:18.323673Z", "iopub.status.idle": "2023-08-18T07:24:21.607373Z", "shell.execute_reply": "2023-08-18T07:24:21.606483Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from d2l import torch as d2l\n", "\n", "batch_size, num_steps = 32, 35\n", "train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)" ] }, { "cell_type": "markdown", "id": "3ef97237", "metadata": { "origin_pos": 5 }, "source": [ "### [**初始化模型参数**]\n", "\n", "接下来,我们需要定义和初始化模型参数。\n", "如前所述,超参数`num_hiddens`定义隐藏单元的数量。\n", "我们按照标准差$0.01$的高斯分布初始化权重,并将偏置项设为$0$。\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "2d90745d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:24:21.612618Z", "iopub.status.busy": "2023-08-18T07:24:21.611785Z", "iopub.status.idle": "2023-08-18T07:24:21.619127Z", "shell.execute_reply": "2023-08-18T07:24:21.618267Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def get_lstm_params(vocab_size, num_hiddens, device):\n", " num_inputs = num_outputs = vocab_size\n", "\n", " def normal(shape):\n", " return torch.randn(size=shape, device=device)*0.01\n", "\n", " def three():\n", " return (normal((num_inputs, num_hiddens)),\n", " normal((num_hiddens, num_hiddens)),\n", " torch.zeros(num_hiddens, device=device))\n", "\n", " W_xi, W_hi, b_i = three() # 输入门参数\n", " W_xf, W_hf, b_f = three() # 遗忘门参数\n", " W_xo, W_ho, b_o = three() # 输出门参数\n", " W_xc, W_hc, b_c = three() # 候选记忆元参数\n", " # 输出层参数\n", " W_hq = normal((num_hiddens, num_outputs))\n", " b_q = torch.zeros(num_outputs, device=device)\n", " # 附加梯度\n", " params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,\n", " b_c, W_hq, b_q]\n", " for param in params:\n", " param.requires_grad_(True)\n", " return params" ] }, { "cell_type": "markdown", "id": "57aecd87", "metadata": { "origin_pos": 10 }, "source": [ "### 定义模型\n", "\n", "在[**初始化函数**]中,\n", "长短期记忆网络的隐状态需要返回一个*额外*的记忆元,\n", "单元的值为0,形状为(批量大小,隐藏单元数)。\n", "因此,我们得到以下的状态初始化。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "beeae3b9", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:24:21.623788Z", "iopub.status.busy": "2023-08-18T07:24:21.623183Z", "iopub.status.idle": "2023-08-18T07:24:21.628423Z", "shell.execute_reply": "2023-08-18T07:24:21.627603Z" }, "origin_pos": 12, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def init_lstm_state(batch_size, num_hiddens, device):\n", " return (torch.zeros((batch_size, num_hiddens), device=device),\n", " torch.zeros((batch_size, num_hiddens), device=device))" ] }, { "cell_type": "markdown", "id": "9c45563d", "metadata": { "origin_pos": 15 }, "source": [ "[**实际模型**]的定义与我们前面讨论的一样:\n", "提供三个门和一个额外的记忆元。\n", "请注意,只有隐状态才会传递到输出层,\n", "而记忆元$\\mathbf{C}_t$不直接参与输出计算。\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "50de4e9c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:24:21.632468Z", "iopub.status.busy": "2023-08-18T07:24:21.631955Z", "iopub.status.idle": "2023-08-18T07:24:21.638292Z", "shell.execute_reply": "2023-08-18T07:24:21.637498Z" }, "origin_pos": 17, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def lstm(inputs, state, params):\n", " [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,\n", " W_hq, b_q] = params\n", " (H, C) = state\n", " outputs = []\n", " for X in inputs:\n", " I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)\n", " F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)\n", " O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)\n", " C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)\n", " C = F * C + I * C_tilda\n", " H = O * torch.tanh(C)\n", " Y = (H @ W_hq) + b_q\n", " outputs.append(Y)\n", " return torch.cat(outputs, dim=0), (H, C)" ] }, { "cell_type": "markdown", "id": "c36361aa", "metadata": { "origin_pos": 20 }, "source": [ "### [**训练**]和预测\n", "\n", "让我们通过实例化 :numref:`sec_rnn_scratch`中\n", "引入的`RNNModelScratch`类来训练一个长短期记忆网络,\n", "就如我们在 :numref:`sec_gru`中所做的一样。\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "e2025d7a", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:24:21.642387Z", "iopub.status.busy": "2023-08-18T07:24:21.641860Z", "iopub.status.idle": "2023-08-18T07:28:54.778673Z", "shell.execute_reply": "2023-08-18T07:28:54.777826Z" }, "origin_pos": 21, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "perplexity 1.3, 17736.0 tokens/sec on cuda:0\n", "time traveller for so it will leong go it we melenot ir cove i s\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "traveller care be can so i ngrecpely as along the time dime\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:28:54.745221\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()\n", "num_epochs, lr = 500, 1\n", "model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,\n", " init_lstm_state, lstm)\n", "d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)" ] }, { "cell_type": "markdown", "id": "245cff43", "metadata": { "origin_pos": 24 }, "source": [ "## [**简洁实现**]\n", "\n", "使用高级API,我们可以直接实例化`LSTM`模型。\n", "高级API封装了前文介绍的所有配置细节。\n", "这段代码的运行速度要快得多,\n", "因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "ca7c37b2", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:28:54.782381Z", "iopub.status.busy": "2023-08-18T07:28:54.781811Z", "iopub.status.idle": "2023-08-18T07:29:19.488133Z", "shell.execute_reply": "2023-08-18T07:29:19.487180Z" }, "origin_pos": 26, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "perplexity 1.1, 234815.0 tokens/sec on cuda:0\n", "time traveller for so it will be convenient to speak of himwas e\n", "travelleryou can show black is white by argument said filby\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:29:19.454029\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "num_inputs = vocab_size\n", "lstm_layer = nn.LSTM(num_inputs, num_hiddens)\n", "model = d2l.RNNModel(lstm_layer, len(vocab))\n", "model = model.to(device)\n", "d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)" ] }, { "cell_type": "markdown", "id": "bc39a6bb", "metadata": { "origin_pos": 29 }, "source": [ "长短期记忆网络是典型的具有重要状态控制的隐变量自回归模型。\n", "多年来已经提出了其许多变体,例如,多层、残差连接、不同类型的正则化。\n", "然而,由于序列的长距离依赖性,训练长短期记忆网络\n", "和其他序列模型(例如门控循环单元)的成本是相当高的。\n", "在后面的内容中,我们将讲述更高级的替代模型,如Transformer。\n", "\n", "## 小结\n", "\n", "* 长短期记忆网络有三种类型的门:输入门、遗忘门和输出门。\n", "* 长短期记忆网络的隐藏层输出包括“隐状态”和“记忆元”。只有隐状态会传递到输出层,而记忆元完全属于内部信息。\n", "* 长短期记忆网络可以缓解梯度消失和梯度爆炸。\n", "\n", "\n", "## 练习\n", "\n", "1. 调整和分析超参数对运行时间、困惑度和输出顺序的影响。\n", "1. 如何更改模型以生成适当的单词,而不是字符序列?\n", "1. 在给定隐藏层维度的情况下,比较门控循环单元、长短期记忆网络和常规循环神经网络的计算成本。要特别注意训练和推断成本。\n", "1. 既然候选记忆元通过使用$\\tanh$函数来确保值范围在$(-1,1)$之间,那么为什么隐状态需要再次使用$\\tanh$函数来确保输出值范围在$(-1,1)$之间呢?\n", "1. 实现一个能够基于时间序列进行预测而不是基于字符序列进行预测的长短期记忆网络模型。\n" ] }, { "cell_type": "markdown", "id": "416f6654", "metadata": { "origin_pos": 31, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/2768)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }