{ "cells": [ { "cell_type": "markdown", "id": "5df6c10f", "metadata": { "origin_pos": 0 }, "source": [ "# 序列模型\n", ":label:`sec_sequence`\n", "\n", "想象一下有人正在看网飞(Netflix,一个国外的视频网站)上的电影。\n", "一名忠实的用户会对每一部电影都给出评价,\n", "毕竟一部好电影需要更多的支持和认可。\n", "然而事实证明,事情并不那么简单。\n", "随着时间的推移,人们对电影的看法会发生很大的变化。\n", "事实上,心理学家甚至对这些现象起了名字:\n", "\n", "* *锚定*(anchoring)效应:基于其他人的意见做出评价。\n", " 例如,奥斯卡颁奖后,受到关注的电影的评分会上升,尽管它还是原来那部电影。\n", " 这种影响将持续几个月,直到人们忘记了这部电影曾经获得的奖项。\n", " 结果表明( :cite:`Wu.Ahmed.Beutel.ea.2017`),这种效应会使评分提高半个百分点以上。\n", "* *享乐适应*(hedonic adaption):人们迅速接受并且适应一种更好或者更坏的情况\n", " 作为新的常态。\n", " 例如,在看了很多好电影之后,人们会强烈期望下部电影会更好。\n", " 因此,在许多精彩的电影被看过之后,即使是一部普通的也可能被认为是糟糕的。\n", "* *季节性*(seasonality):少有观众喜欢在八月看圣诞老人的电影。\n", "* 有时,电影会由于导演或演员在制作中的不当行为变得不受欢迎。\n", "* 有些电影因为其极度糟糕只能成为小众电影。*Plan9from Outer Space*和*Troll2*就因为这个原因而臭名昭著的。\n", "\n", "简而言之,电影评分决不是固定不变的。\n", "因此,使用时间动力学可以得到更准确的电影推荐 :cite:`Koren.2009`。\n", "当然,序列数据不仅仅是关于电影评分的。\n", "下面给出了更多的场景。\n", "\n", "* 在使用程序时,许多用户都有很强的特定习惯。\n", " 例如,在学生放学后社交媒体应用更受欢迎。在市场开放时股市交易软件更常用。\n", "* 预测明天的股价要比过去的股价更困难,尽管两者都只是估计一个数字。\n", " 毕竟,先见之明比事后诸葛亮难得多。\n", " 在统计学中,前者(对超出已知观测范围进行预测)称为*外推法*(extrapolation),\n", " 而后者(在现有观测值之间进行估计)称为*内插法*(interpolation)。\n", "* 在本质上,音乐、语音、文本和视频都是连续的。\n", " 如果它们的序列被我们重排,那么就会失去原有的意义。\n", " 比如,一个文本标题“狗咬人”远没有“人咬狗”那么令人惊讶,尽管组成两句话的字完全相同。\n", "* 地震具有很强的相关性,即大地震发生后,很可能会有几次小余震,\n", " 这些余震的强度比非大地震后的余震要大得多。\n", " 事实上,地震是时空相关的,即余震通常发生在很短的时间跨度和很近的距离内。\n", "* 人类之间的互动也是连续的,这可以从微博上的争吵和辩论中看出。\n", "\n", "## 统计工具\n", "\n", "处理序列数据需要统计工具和新的深度神经网络架构。\n", "为了简单起见,我们以 :numref:`fig_ftse100`所示的股票价格(富时100指数)为例。\n", "\n", "![近30年的富时100指数](../img/ftse100.png)\n", ":width:`400px`\n", ":label:`fig_ftse100`\n", "\n", "其中,用$x_t$表示价格,即在*时间步*(time step)\n", "$t \\in \\mathbb{Z}^+$时,观察到的价格$x_t$。\n", "请注意,$t$对于本文中的序列通常是离散的,并在整数或其子集上变化。\n", "假设一个交易员想在$t$日的股市中表现良好,于是通过以下途径预测$x_t$:\n", "\n", "$$x_t \\sim P(x_t \\mid x_{t-1}, \\ldots, x_1).$$\n", "\n", "### 自回归模型\n", "\n", "为了实现这个预测,交易员可以使用回归模型,\n", "例如在 :numref:`sec_linear_concise`中训练的模型。\n", "仅有一个主要问题:输入数据的数量,\n", "输入$x_{t-1}, \\ldots, x_1$本身因$t$而异。\n", "也就是说,输入数据的数量这个数字将会随着我们遇到的数据量的增加而增加,\n", "因此需要一个近似方法来使这个计算变得容易处理。\n", "本章后面的大部分内容将围绕着如何有效估计\n", "$P(x_t \\mid x_{t-1}, \\ldots, x_1)$展开。\n", "简单地说,它归结为以下两种策略。\n", "\n", "第一种策略,假设在现实情况下相当长的序列\n", "$x_{t-1}, \\ldots, x_1$可能是不必要的,\n", "因此我们只需要满足某个长度为$\\tau$的时间跨度,\n", "即使用观测序列$x_{t-1}, \\ldots, x_{t-\\tau}$。\n", "当下获得的最直接的好处就是参数的数量总是不变的,\n", "至少在$t > \\tau$时如此,这就使我们能够训练一个上面提及的深度网络。\n", "这种模型被称为*自回归模型*(autoregressive models),\n", "因为它们是对自己执行回归。\n", "\n", "第二种策略,如 :numref:`fig_sequence-model`所示,\n", "是保留一些对过去观测的总结$h_t$,\n", "并且同时更新预测$\\hat{x}_t$和总结$h_t$。\n", "这就产生了基于$\\hat{x}_t = P(x_t \\mid h_{t})$估计$x_t$,\n", "以及公式$h_t = g(h_{t-1}, x_{t-1})$更新的模型。\n", "由于$h_t$从未被观测到,这类模型也被称为\n", "*隐变量自回归模型*(latent autoregressive models)。\n", "\n", "![隐变量自回归模型](../img/sequence-model.svg)\n", ":label:`fig_sequence-model`\n", "\n", "这两种情况都有一个显而易见的问题:如何生成训练数据?\n", "一个经典方法是使用历史观测来预测下一个未来观测。\n", "显然,我们并不指望时间会停滞不前。\n", "然而,一个常见的假设是虽然特定值$x_t$可能会改变,\n", "但是序列本身的动力学不会改变。\n", "这样的假设是合理的,因为新的动力学一定受新的数据影响,\n", "而我们不可能用目前所掌握的数据来预测新的动力学。\n", "统计学家称不变的动力学为*静止的*(stationary)。\n", "因此,整个序列的估计值都将通过以下的方式获得:\n", "\n", "$$P(x_1, \\ldots, x_T) = \\prod_{t=1}^T P(x_t \\mid x_{t-1}, \\ldots, x_1).$$\n", "\n", "注意,如果我们处理的是离散的对象(如单词),\n", "而不是连续的数字,则上述的考虑仍然有效。\n", "唯一的差别是,对于离散的对象,\n", "我们需要使用分类器而不是回归模型来估计$P(x_t \\mid x_{t-1}, \\ldots, x_1)$。\n", "\n", "### 马尔可夫模型\n", "\n", "回想一下,在自回归模型的近似法中,\n", "我们使用$x_{t-1}, \\ldots, x_{t-\\tau}$\n", "而不是$x_{t-1}, \\ldots, x_1$来估计$x_t$。\n", "只要这种是近似精确的,我们就说序列满足*马尔可夫条件*(Markov condition)。\n", "特别是,如果$\\tau = 1$,得到一个\n", "*一阶马尔可夫模型*(first-order Markov model),\n", "$P(x)$由下式给出:\n", "\n", "$$P(x_1, \\ldots, x_T) = \\prod_{t=1}^T P(x_t \\mid x_{t-1}) \\text{ 当 } P(x_1 \\mid x_0) = P(x_1).$$\n", "\n", "当假设$x_t$仅是离散值时,这样的模型特别棒,\n", "因为在这种情况下,使用动态规划可以沿着马尔可夫链精确地计算结果。\n", "例如,我们可以高效地计算$P(x_{t+1} \\mid x_{t-1})$:\n", "\n", "$$\n", "\\begin{aligned}\n", "P(x_{t+1} \\mid x_{t-1})\n", "&= \\frac{\\sum_{x_t} P(x_{t+1}, x_t, x_{t-1})}{P(x_{t-1})}\\\\\n", "&= \\frac{\\sum_{x_t} P(x_{t+1} \\mid x_t, x_{t-1}) P(x_t, x_{t-1})}{P(x_{t-1})}\\\\\n", "&= \\sum_{x_t} P(x_{t+1} \\mid x_t) P(x_t \\mid x_{t-1})\n", "\\end{aligned}\n", "$$\n", "\n", "利用这一事实,我们只需要考虑过去观察中的一个非常短的历史:\n", "$P(x_{t+1} \\mid x_t, x_{t-1}) = P(x_{t+1} \\mid x_t)$。\n", "隐马尔可夫模型中的动态规划超出了本节的范围\n", "(我们将在 :numref:`sec_bi_rnn`再次遇到),\n", "而动态规划这些计算工具已经在控制算法和强化学习算法广泛使用。\n", "\n", "### 因果关系\n", "\n", "原则上,将$P(x_1, \\ldots, x_T)$倒序展开也没什么问题。\n", "毕竟,基于条件概率公式,我们总是可以写出:\n", "\n", "$$P(x_1, \\ldots, x_T) = \\prod_{t=T}^1 P(x_t \\mid x_{t+1}, \\ldots, x_T).$$\n", "\n", "事实上,如果基于一个马尔可夫模型,\n", "我们还可以得到一个反向的条件概率分布。\n", "然而,在许多情况下,数据存在一个自然的方向,即在时间上是前进的。\n", "很明显,未来的事件不能影响过去。\n", "因此,如果我们改变$x_t$,可能会影响未来发生的事情$x_{t+1}$,但不能反过来。\n", "也就是说,如果我们改变$x_t$,基于过去事件得到的分布不会改变。\n", "因此,解释$P(x_{t+1} \\mid x_t)$应该比解释$P(x_t \\mid x_{t+1})$更容易。\n", "例如,在某些情况下,对于某些可加性噪声$\\epsilon$,\n", "显然我们可以找到$x_{t+1} = f(x_t) + \\epsilon$,\n", "而反之则不行 :cite:`Hoyer.Janzing.Mooij.ea.2009`。\n", "而这个向前推进的方向恰好也是我们通常感兴趣的方向。\n", "彼得斯等人 :cite:`Peters.Janzing.Scholkopf.2017`\n", "对该主题的更多内容做了详尽的解释,而我们的上述讨论只是其中的冰山一角。\n", "\n", "## 训练\n", "\n", "在了解了上述统计工具后,让我们在实践中尝试一下!\n", "首先,我们生成一些数据:(**使用正弦函数和一些可加性噪声来生成序列数据,\n", "时间步为$1, 2, \\ldots, 1000$。**)\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "cec552f1", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:29.737097Z", "iopub.status.busy": "2023-08-18T07:00:29.736329Z", "iopub.status.idle": "2023-08-18T07:00:32.767900Z", "shell.execute_reply": "2023-08-18T07:00:32.766800Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "import torch\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "code", "execution_count": 2, "id": "8703181c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:32.773582Z", "iopub.status.busy": "2023-08-18T07:00:32.772564Z", "iopub.status.idle": "2023-08-18T07:00:33.054920Z", "shell.execute_reply": "2023-08-18T07:00:33.053899Z" }, "origin_pos": 5, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:00:33.001240\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "T = 1000 # 总共产生1000个点\n", "time = torch.arange(1, T + 1, dtype=torch.float32)\n", "x = torch.sin(0.01 * time) + torch.normal(0, 0.2, (T,))\n", "d2l.plot(time, [x], 'time', 'x', xlim=[1, 1000], figsize=(6, 3))" ] }, { "cell_type": "markdown", "id": "91a4997c", "metadata": { "origin_pos": 7 }, "source": [ "接下来,我们将这个序列转换为模型的*特征-标签*(feature-label)对。\n", "基于嵌入维度$\\tau$,我们[**将数据映射为数据对$y_t = x_t$\n", "和$\\mathbf{x}_t = [x_{t-\\tau}, \\ldots, x_{t-1}]$。**]\n", "这比我们提供的数据样本少了$\\tau$个,\n", "因为我们没有足够的历史记录来描述前$\\tau$个数据样本。\n", "一个简单的解决办法是:如果拥有足够长的序列就丢弃这几项;\n", "另一个方法是用零填充序列。\n", "在这里,我们仅使用前600个“特征-标签”对进行训练。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "15071cc2", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:33.062134Z", "iopub.status.busy": "2023-08-18T07:00:33.061290Z", "iopub.status.idle": "2023-08-18T07:00:33.068807Z", "shell.execute_reply": "2023-08-18T07:00:33.067785Z" }, "origin_pos": 8, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "tau = 4\n", "features = torch.zeros((T - tau, tau))\n", "for i in range(tau):\n", " features[:, i] = x[i: T - tau + i]\n", "labels = x[tau:].reshape((-1, 1))" ] }, { "cell_type": "code", "execution_count": 4, "id": "de5cf9b0", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:33.073641Z", "iopub.status.busy": "2023-08-18T07:00:33.072850Z", "iopub.status.idle": "2023-08-18T07:00:33.078757Z", "shell.execute_reply": "2023-08-18T07:00:33.077578Z" }, "origin_pos": 10, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "batch_size, n_train = 16, 600\n", "# 只有前n_train个样本用于训练\n", "train_iter = d2l.load_array((features[:n_train], labels[:n_train]),\n", " batch_size, is_train=True)" ] }, { "cell_type": "markdown", "id": "3f830b1f", "metadata": { "origin_pos": 11 }, "source": [ "在这里,我们[**使用一个相当简单的架构训练模型:\n", "一个拥有两个全连接层的多层感知机**],ReLU激活函数和平方损失。\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "bf897572", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:33.083521Z", "iopub.status.busy": "2023-08-18T07:00:33.082814Z", "iopub.status.idle": "2023-08-18T07:00:33.089936Z", "shell.execute_reply": "2023-08-18T07:00:33.088945Z" }, "origin_pos": 13, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "# 初始化网络权重的函数\n", "def init_weights(m):\n", " if type(m) == nn.Linear:\n", " nn.init.xavier_uniform_(m.weight)\n", "\n", "# 一个简单的多层感知机\n", "def get_net():\n", " net = nn.Sequential(nn.Linear(4, 10),\n", " nn.ReLU(),\n", " nn.Linear(10, 1))\n", " net.apply(init_weights)\n", " return net\n", "\n", "# 平方损失。注意:MSELoss计算平方误差时不带系数1/2\n", "loss = nn.MSELoss(reduction='none')" ] }, { "cell_type": "markdown", "id": "62592eee", "metadata": { "origin_pos": 16 }, "source": [ "现在,准备[**训练模型**]了。实现下面的训练代码的方式与前面几节(如 :numref:`sec_linear_concise`)中的循环训练基本相同。因此,我们不会深入探讨太多细节。\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "8842e1a5", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:33.095006Z", "iopub.status.busy": "2023-08-18T07:00:33.094051Z", "iopub.status.idle": "2023-08-18T07:00:33.409280Z", "shell.execute_reply": "2023-08-18T07:00:33.408040Z" }, "origin_pos": 18, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 1, loss: 0.076846\n", "epoch 2, loss: 0.056340\n", "epoch 3, loss: 0.053779\n", "epoch 4, loss: 0.056320\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch 5, loss: 0.051650\n" ] } ], "source": [ "def train(net, train_iter, loss, epochs, lr):\n", " trainer = torch.optim.Adam(net.parameters(), lr)\n", " for epoch in range(epochs):\n", " for X, y in train_iter:\n", " trainer.zero_grad()\n", " l = loss(net(X), y)\n", " l.sum().backward()\n", " trainer.step()\n", " print(f'epoch {epoch + 1}, '\n", " f'loss: {d2l.evaluate_loss(net, train_iter, loss):f}')\n", "\n", "net = get_net()\n", "train(net, train_iter, loss, 5, 0.01)" ] }, { "cell_type": "markdown", "id": "e60966e7", "metadata": { "origin_pos": 21 }, "source": [ "## 预测\n", "\n", "由于训练损失很小,因此我们期望模型能有很好的工作效果。\n", "让我们看看这在实践中意味着什么。\n", "首先是检查[**模型预测下一个时间步**]的能力,\n", "也就是*单步预测*(one-step-ahead prediction)。\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "c8fc264a", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:33.413628Z", "iopub.status.busy": "2023-08-18T07:00:33.413334Z", "iopub.status.idle": "2023-08-18T07:00:33.703430Z", "shell.execute_reply": "2023-08-18T07:00:33.702292Z" }, "origin_pos": 22, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:00:33.652151\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "onestep_preds = net(features)\n", "d2l.plot([time, time[tau:]],\n", " [x.detach().numpy(), onestep_preds.detach().numpy()], 'time',\n", " 'x', legend=['data', '1-step preds'], xlim=[1, 1000],\n", " figsize=(6, 3))" ] }, { "cell_type": "markdown", "id": "8170de92", "metadata": { "origin_pos": 23 }, "source": [ "正如我们所料,单步预测效果不错。\n", "即使这些预测的时间步超过了$600+4$(`n_train + tau`),\n", "其结果看起来仍然是可信的。\n", "然而有一个小问题:如果数据观察序列的时间步只到$604$,\n", "我们需要一步一步地向前迈进:\n", "$$\n", "\\hat{x}_{605} = f(x_{601}, x_{602}, x_{603}, x_{604}), \\\\\n", "\\hat{x}_{606} = f(x_{602}, x_{603}, x_{604}, \\hat{x}_{605}), \\\\\n", "\\hat{x}_{607} = f(x_{603}, x_{604}, \\hat{x}_{605}, \\hat{x}_{606}),\\\\\n", "\\hat{x}_{608} = f(x_{604}, \\hat{x}_{605}, \\hat{x}_{606}, \\hat{x}_{607}),\\\\\n", "\\hat{x}_{609} = f(\\hat{x}_{605}, \\hat{x}_{606}, \\hat{x}_{607}, \\hat{x}_{608}),\\\\\n", "\\ldots\n", "$$\n", "\n", "通常,对于直到$x_t$的观测序列,其在时间步$t+k$处的预测输出$\\hat{x}_{t+k}$\n", "称为$k$*步预测*($k$-step-ahead-prediction)。\n", "由于我们的观察已经到了$x_{604}$,它的$k$步预测是$\\hat{x}_{604+k}$。\n", "换句话说,我们必须使用我们自己的预测(而不是原始数据)来[**进行多步预测**]。\n", "让我们看看效果如何。\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "89ba8713", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:33.710400Z", "iopub.status.busy": "2023-08-18T07:00:33.709599Z", "iopub.status.idle": "2023-08-18T07:00:33.765864Z", "shell.execute_reply": "2023-08-18T07:00:33.764650Z" }, "origin_pos": 24, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "multistep_preds = torch.zeros(T)\n", "multistep_preds[: n_train + tau] = x[: n_train + tau]\n", "for i in range(n_train + tau, T):\n", " multistep_preds[i] = net(\n", " multistep_preds[i - tau:i].reshape((1, -1)))" ] }, { "cell_type": "code", "execution_count": 9, "id": "f29da1fa", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:33.770459Z", "iopub.status.busy": "2023-08-18T07:00:33.769838Z", "iopub.status.idle": "2023-08-18T07:00:34.140732Z", "shell.execute_reply": "2023-08-18T07:00:34.139616Z" }, "origin_pos": 27, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:00:34.059922\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "d2l.plot([time, time[tau:], time[n_train + tau:]],\n", " [x.detach().numpy(), onestep_preds.detach().numpy(),\n", " multistep_preds[n_train + tau:].detach().numpy()], 'time',\n", " 'x', legend=['data', '1-step preds', 'multistep preds'],\n", " xlim=[1, 1000], figsize=(6, 3))" ] }, { "cell_type": "markdown", "id": "5e25e464", "metadata": { "origin_pos": 28 }, "source": [ "如上面的例子所示,绿线的预测显然并不理想。\n", "经过几个预测步骤之后,预测的结果很快就会衰减到一个常数。\n", "为什么这个算法效果这么差呢?事实是由于错误的累积:\n", "假设在步骤$1$之后,我们积累了一些错误$\\epsilon_1 = \\bar\\epsilon$。\n", "于是,步骤$2$的输入被扰动了$\\epsilon_1$,\n", "结果积累的误差是依照次序的$\\epsilon_2 = \\bar\\epsilon + c \\epsilon_1$,\n", "其中$c$为某个常数,后面的预测误差依此类推。\n", "因此误差可能会相当快地偏离真实的观测结果。\n", "例如,未来$24$小时的天气预报往往相当准确,\n", "但超过这一点,精度就会迅速下降。\n", "我们将在本章及后续章节中讨论如何改进这一点。\n", "\n", "基于$k = 1, 4, 16, 64$,通过对整个序列预测的计算,\n", "让我们[**更仔细地看一下$k$步预测**]的困难。\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "86fd62af", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:34.147991Z", "iopub.status.busy": "2023-08-18T07:00:34.147118Z", "iopub.status.idle": "2023-08-18T07:00:34.152314Z", "shell.execute_reply": "2023-08-18T07:00:34.151283Z" }, "origin_pos": 29, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "max_steps = 64" ] }, { "cell_type": "code", "execution_count": 11, "id": "b5468e42", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:34.156964Z", "iopub.status.busy": "2023-08-18T07:00:34.156220Z", "iopub.status.idle": "2023-08-18T07:00:34.188969Z", "shell.execute_reply": "2023-08-18T07:00:34.187836Z" }, "origin_pos": 30, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "features = torch.zeros((T - tau - max_steps + 1, tau + max_steps))\n", "# 列i(i=tau)是来自(i-tau+1)步的预测,其时间步从(i)到(i+T-tau-max_steps+1)\n", "for i in range(tau, tau + max_steps):\n", " features[:, i] = net(features[:, i - tau:i]).reshape(-1)" ] }, { "cell_type": "code", "execution_count": 12, "id": "e41dbccd", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:34.193832Z", "iopub.status.busy": "2023-08-18T07:00:34.193416Z", "iopub.status.idle": "2023-08-18T07:00:34.567071Z", "shell.execute_reply": "2023-08-18T07:00:34.566267Z" }, "origin_pos": 33, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:00:34.491900\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "steps = (1, 4, 16, 64)\n", "d2l.plot([time[tau + i - 1: T - max_steps + i] for i in steps],\n", " [features[:, (tau + i - 1)].detach().numpy() for i in steps], 'time', 'x',\n", " legend=[f'{i}-step preds' for i in steps], xlim=[5, 1000],\n", " figsize=(6, 3))" ] }, { "cell_type": "markdown", "id": "13e7b084", "metadata": { "origin_pos": 34 }, "source": [ "以上例子清楚地说明了当我们试图预测更远的未来时,预测的质量是如何变化的。\n", "虽然“$4$步预测”看起来仍然不错,但超过这个跨度的任何预测几乎都是无用的。\n", "\n", "## 小结\n", "\n", "* 内插法(在现有观测值之间进行估计)和外推法(对超出已知观测范围进行预测)在实践的难度上差别很大。因此,对于所拥有的序列数据,在训练时始终要尊重其时间顺序,即最好不要基于未来的数据进行训练。\n", "* 序列模型的估计需要专门的统计工具,两种较流行的选择是自回归模型和隐变量自回归模型。\n", "* 对于时间是向前推进的因果模型,正向估计通常比反向估计更容易。\n", "* 对于直到时间步$t$的观测序列,其在时间步$t+k$的预测输出是“$k$步预测”。随着我们对预测时间$k$值的增加,会造成误差的快速累积和预测质量的极速下降。\n", "\n", "## 练习\n", "\n", "1. 改进本节实验中的模型。\n", " 1. 是否包含了过去$4$个以上的观测结果?真实值需要是多少个?\n", " 1. 如果没有噪音,需要多少个过去的观测结果?提示:把$\\sin$和$\\cos$写成微分方程。\n", " 1. 可以在保持特征总数不变的情况下合并旧的观察结果吗?这能提高正确度吗?为什么?\n", " 1. 改变神经网络架构并评估其性能。\n", "1. 一位投资者想要找到一种好的证券来购买。他查看过去的回报,以决定哪一种可能是表现良好的。这一策略可能会出什么问题呢?\n", "1. 时间是向前推进的因果模型在多大程度上适用于文本呢?\n", "1. 举例说明什么时候可能需要隐变量自回归模型来捕捉数据的动力学模型。\n" ] }, { "cell_type": "markdown", "id": "0938a128", "metadata": { "origin_pos": 36, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/2091)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }