{ "cells": [ { "cell_type": "markdown", "id": "12605036", "metadata": { "origin_pos": 0 }, "source": [ "# 注意力汇聚:Nadaraya-Watson 核回归\n", ":label:`sec_nadaraya-watson`\n", "\n", "上节介绍了框架下的注意力机制的主要成分 :numref:`fig_qkv`:\n", "查询(自主提示)和键(非自主提示)之间的交互形成了注意力汇聚;\n", "注意力汇聚有选择地聚合了值(感官输入)以生成最终的输出。\n", "本节将介绍注意力汇聚的更多细节,\n", "以便从宏观上了解注意力机制在实践中的运作方式。\n", "具体来说,1964年提出的Nadaraya-Watson核回归模型\n", "是一个简单但完整的例子,可以用于演示具有注意力机制的机器学习。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "47b48700", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:14.699824Z", "iopub.status.busy": "2023-08-18T07:07:14.699278Z", "iopub.status.idle": "2023-08-18T07:07:16.694044Z", "shell.execute_reply": "2023-08-18T07:07:16.693186Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "45448f28", "metadata": { "origin_pos": 5 }, "source": [ "## [**生成数据集**]\n", "\n", "简单起见,考虑下面这个回归问题:\n", "给定的成对的“输入-输出”数据集\n", "$\\{(x_1, y_1), \\ldots, (x_n, y_n)\\}$,\n", "如何学习$f$来预测任意新输入$x$的输出$\\hat{y} = f(x)$?\n", "\n", "根据下面的非线性函数生成一个人工数据集,\n", "其中加入的噪声项为$\\epsilon$:\n", "\n", "$$y_i = 2\\sin(x_i) + x_i^{0.8} + \\epsilon,$$\n", "\n", "其中$\\epsilon$服从均值为$0$和标准差为$0.5$的正态分布。\n", "在这里生成了$50$个训练样本和$50$个测试样本。\n", "为了更好地可视化之后的注意力模式,需要将训练样本进行排序。\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "77ea63dd", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:16.698617Z", "iopub.status.busy": "2023-08-18T07:07:16.697928Z", "iopub.status.idle": "2023-08-18T07:07:16.720614Z", "shell.execute_reply": "2023-08-18T07:07:16.719799Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "n_train = 50 # 训练样本数\n", "x_train, _ = torch.sort(torch.rand(n_train) * 5) # 排序后的训练样本" ] }, { "cell_type": "code", "execution_count": 3, "id": "2b36fd68", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:16.724389Z", "iopub.status.busy": "2023-08-18T07:07:16.723850Z", "iopub.status.idle": "2023-08-18T07:07:16.734529Z", "shell.execute_reply": "2023-08-18T07:07:16.733732Z" }, "origin_pos": 11, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "50" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def f(x):\n", " return 2 * torch.sin(x) + x**0.8\n", "\n", "y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,)) # 训练样本的输出\n", "x_test = torch.arange(0, 5, 0.1) # 测试样本\n", "y_truth = f(x_test) # 测试样本的真实输出\n", "n_test = len(x_test) # 测试样本数\n", "n_test" ] }, { "cell_type": "markdown", "id": "a8cd762c", "metadata": { "origin_pos": 14 }, "source": [ "下面的函数将绘制所有的训练样本(样本由圆圈表示),\n", "不带噪声项的真实数据生成函数$f$(标记为“Truth”),\n", "以及学习得到的预测函数(标记为“Pred”)。\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "84166e26", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:16.738117Z", "iopub.status.busy": "2023-08-18T07:07:16.737614Z", "iopub.status.idle": "2023-08-18T07:07:16.742118Z", "shell.execute_reply": "2023-08-18T07:07:16.741332Z" }, "origin_pos": 15, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def plot_kernel_reg(y_hat):\n", " d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],\n", " xlim=[0, 5], ylim=[-1, 5])\n", " d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);" ] }, { "cell_type": "markdown", "id": "83acd3e5", "metadata": { "origin_pos": 16 }, "source": [ "## 平均汇聚\n", "\n", "先使用最简单的估计器来解决回归问题。\n", "基于平均汇聚来计算所有训练样本输出值的平均值:\n", "\n", "$$f(x) = \\frac{1}{n}\\sum_{i=1}^n y_i,$$\n", ":eqlabel:`eq_avg-pooling`\n", "\n", "如下图所示,这个估计器确实不够聪明。\n", "真实函数$f$(“Truth”)和预测函数(“Pred”)相差很大。\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "b5227412", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:16.745659Z", "iopub.status.busy": "2023-08-18T07:07:16.745145Z", "iopub.status.idle": "2023-08-18T07:07:16.921666Z", "shell.execute_reply": "2023-08-18T07:07:16.920767Z" }, "origin_pos": 18, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:07:16.873741\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "y_hat = torch.repeat_interleave(y_train.mean(), n_test)\n", "plot_kernel_reg(y_hat)" ] }, { "cell_type": "markdown", "id": "2d4155fa", "metadata": { "origin_pos": 21 }, "source": [ "## [**非参数注意力汇聚**]\n", "\n", "显然,平均汇聚忽略了输入$x_i$。\n", "于是Nadaraya :cite:`Nadaraya.1964`和\n", "Watson :cite:`Watson.1964`提出了一个更好的想法,\n", "根据输入的位置对输出$y_i$进行加权:\n", "\n", "$$f(x) = \\sum_{i=1}^n \\frac{K(x - x_i)}{\\sum_{j=1}^n K(x - x_j)} y_i,$$\n", ":eqlabel:`eq_nadaraya-watson`\n", "\n", "其中$K$是*核*(kernel)。\n", "公式 :eqref:`eq_nadaraya-watson`所描述的估计器被称为\n", "*Nadaraya-Watson核回归*(Nadaraya-Watson kernel regression)。\n", "这里不会深入讨论核函数的细节,\n", "但受此启发,\n", "我们可以从 :numref:`fig_qkv`中的注意力机制框架的角度\n", "重写 :eqref:`eq_nadaraya-watson`,\n", "成为一个更加通用的*注意力汇聚*(attention pooling)公式:\n", "\n", "$$f(x) = \\sum_{i=1}^n \\alpha(x, x_i) y_i,$$\n", ":eqlabel:`eq_attn-pooling`\n", "\n", "其中$x$是查询,$(x_i, y_i)$是键值对。\n", "比较 :eqref:`eq_attn-pooling`和 :eqref:`eq_avg-pooling`,\n", "注意力汇聚是$y_i$的加权平均。\n", "将查询$x$和键$x_i$之间的关系建模为\n", "*注意力权重*(attention weight)$\\alpha(x, x_i)$,\n", "如 :eqref:`eq_attn-pooling`所示,\n", "这个权重将被分配给每一个对应值$y_i$。\n", "对于任何查询,模型在所有键值对注意力权重都是一个有效的概率分布:\n", "它们是非负的,并且总和为1。\n", "\n", "为了更好地理解注意力汇聚,\n", "下面考虑一个*高斯核*(Gaussian kernel),其定义为:\n", "\n", "$$K(u) = \\frac{1}{\\sqrt{2\\pi}} \\exp(-\\frac{u^2}{2}).$$\n", "\n", "将高斯核代入 :eqref:`eq_attn-pooling`和\n", " :eqref:`eq_nadaraya-watson`可以得到:\n", "\n", "$$\\begin{aligned} f(x) &=\\sum_{i=1}^n \\alpha(x, x_i) y_i\\\\ &= \\sum_{i=1}^n \\frac{\\exp\\left(-\\frac{1}{2}(x - x_i)^2\\right)}{\\sum_{j=1}^n \\exp\\left(-\\frac{1}{2}(x - x_j)^2\\right)} y_i \\\\&= \\sum_{i=1}^n \\mathrm{softmax}\\left(-\\frac{1}{2}(x - x_i)^2\\right) y_i. \\end{aligned}$$\n", ":eqlabel:`eq_nadaraya-watson-gaussian`\n", "\n", "在 :eqref:`eq_nadaraya-watson-gaussian`中,\n", "如果一个键$x_i$越是接近给定的查询$x$,\n", "那么分配给这个键对应值$y_i$的注意力权重就会越大,\n", "也就“获得了更多的注意力”。\n", "\n", "值得注意的是,Nadaraya-Watson核回归是一个非参数模型。\n", "因此, :eqref:`eq_nadaraya-watson-gaussian`是\n", "*非参数的注意力汇聚*(nonparametric attention pooling)模型。\n", "接下来,我们将基于这个非参数的注意力汇聚模型来绘制预测结果。\n", "从绘制的结果会发现新的模型预测线是平滑的,并且比平均汇聚的预测更接近真实。\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "892e92a9", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:16.925532Z", "iopub.status.busy": "2023-08-18T07:07:16.924886Z", "iopub.status.idle": "2023-08-18T07:07:17.137585Z", "shell.execute_reply": "2023-08-18T07:07:17.136490Z" }, "origin_pos": 23, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:07:17.059877\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# X_repeat的形状:(n_test,n_train),\n", "# 每一行都包含着相同的测试输入(例如:同样的查询)\n", "X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))\n", "# x_train包含着键。attention_weights的形状:(n_test,n_train),\n", "# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重\n", "attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)\n", "# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重\n", "y_hat = torch.matmul(attention_weights, y_train)\n", "plot_kernel_reg(y_hat)" ] }, { "cell_type": "markdown", "id": "60a40ebc", "metadata": { "origin_pos": 26 }, "source": [ "现在来观察注意力的权重。\n", "这里测试数据的输入相当于查询,而训练数据的输入相当于键。\n", "因为两个输入都是经过排序的,因此由观察可知“查询-键”对越接近,\n", "注意力汇聚的[**注意力权重**]就越高。\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "5068c7e7", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:17.141869Z", "iopub.status.busy": "2023-08-18T07:07:17.141398Z", "iopub.status.idle": "2023-08-18T07:07:17.325925Z", "shell.execute_reply": "2023-08-18T07:07:17.325070Z" }, "origin_pos": 28, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:07:17.274591\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),\n", " xlabel='Sorted training inputs',\n", " ylabel='Sorted testing inputs')" ] }, { "cell_type": "markdown", "id": "73a72797", "metadata": { "origin_pos": 31 }, "source": [ "## [**带参数注意力汇聚**]\n", "\n", "非参数的Nadaraya-Watson核回归具有*一致性*(consistency)的优点:\n", "如果有足够的数据,此模型会收敛到最优结果。\n", "尽管如此,我们还是可以轻松地将可学习的参数集成到注意力汇聚中。\n", "\n", "例如,与 :eqref:`eq_nadaraya-watson-gaussian`略有不同,\n", "在下面的查询$x$和键$x_i$之间的距离乘以可学习参数$w$:\n", "\n", "$$\\begin{aligned}f(x) &= \\sum_{i=1}^n \\alpha(x, x_i) y_i \\\\&= \\sum_{i=1}^n \\frac{\\exp\\left(-\\frac{1}{2}((x - x_i)w)^2\\right)}{\\sum_{j=1}^n \\exp\\left(-\\frac{1}{2}((x - x_j)w)^2\\right)} y_i \\\\&= \\sum_{i=1}^n \\mathrm{softmax}\\left(-\\frac{1}{2}((x - x_i)w)^2\\right) y_i.\\end{aligned}$$\n", ":eqlabel:`eq_nadaraya-watson-gaussian-para`\n", "\n", "本节的余下部分将通过训练这个模型\n", " :eqref:`eq_nadaraya-watson-gaussian-para`来学习注意力汇聚的参数。\n", "\n", "### 批量矩阵乘法\n", "\n", ":label:`subsec_batch_dot`\n", "\n", "为了更有效地计算小批量数据的注意力,\n", "我们可以利用深度学习开发框架中提供的批量矩阵乘法。\n", "\n", "假设第一个小批量数据包含$n$个矩阵$\\mathbf{X}_1,\\ldots, \\mathbf{X}_n$,\n", "形状为$a\\times b$,\n", "第二个小批量包含$n$个矩阵$\\mathbf{Y}_1, \\ldots, \\mathbf{Y}_n$,\n", "形状为$b\\times c$。\n", "它们的批量矩阵乘法得到$n$个矩阵\n", "$\\mathbf{X}_1\\mathbf{Y}_1, \\ldots, \\mathbf{X}_n\\mathbf{Y}_n$,\n", "形状为$a\\times c$。\n", "因此,[**假定两个张量的形状分别是$(n,a,b)$和$(n,b,c)$,\n", "它们的批量矩阵乘法输出的形状为$(n,a,c)$**]。\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "c5a4dc9e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:17.329873Z", "iopub.status.busy": "2023-08-18T07:07:17.329293Z", "iopub.status.idle": "2023-08-18T07:07:17.336079Z", "shell.execute_reply": "2023-08-18T07:07:17.335148Z" }, "origin_pos": 33, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 1, 6])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X = torch.ones((2, 1, 4))\n", "Y = torch.ones((2, 4, 6))\n", "torch.bmm(X, Y).shape" ] }, { "cell_type": "markdown", "id": "acac0804", "metadata": { "origin_pos": 36 }, "source": [ "在注意力机制的背景中,我们可以[**使用小批量矩阵乘法来计算小批量数据中的加权平均值**]。\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "7161cf35", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:17.341257Z", "iopub.status.busy": "2023-08-18T07:07:17.340341Z", "iopub.status.idle": "2023-08-18T07:07:17.348773Z", "shell.execute_reply": "2023-08-18T07:07:17.347847Z" }, "origin_pos": 38, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([[[ 4.5000]],\n", "\n", " [[14.5000]]])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "weights = torch.ones((2, 10)) * 0.1\n", "values = torch.arange(20.0).reshape((2, 10))\n", "torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1))" ] }, { "cell_type": "markdown", "id": "f31cb449", "metadata": { "origin_pos": 41 }, "source": [ "### 定义模型\n", "\n", "基于 :eqref:`eq_nadaraya-watson-gaussian-para`中的\n", "[**带参数的注意力汇聚**],使用小批量矩阵乘法,\n", "定义Nadaraya-Watson核回归的带参数版本为:\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "e7aee504", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:17.353382Z", "iopub.status.busy": "2023-08-18T07:07:17.353094Z", "iopub.status.idle": "2023-08-18T07:07:17.359677Z", "shell.execute_reply": "2023-08-18T07:07:17.358720Z" }, "origin_pos": 43, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class NWKernelRegression(nn.Module):\n", " def __init__(self, **kwargs):\n", " super().__init__(**kwargs)\n", " self.w = nn.Parameter(torch.rand((1,), requires_grad=True))\n", "\n", " def forward(self, queries, keys, values):\n", " # queries和attention_weights的形状为(查询个数,“键-值”对个数)\n", " queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))\n", " self.attention_weights = nn.functional.softmax(\n", " -((queries - keys) * self.w)**2 / 2, dim=1)\n", " # values的形状为(查询个数,“键-值”对个数)\n", " return torch.bmm(self.attention_weights.unsqueeze(1),\n", " values.unsqueeze(-1)).reshape(-1)" ] }, { "cell_type": "markdown", "id": "192a922f", "metadata": { "origin_pos": 46 }, "source": [ "### 训练\n", "\n", "接下来,[**将训练数据集变换为键和值**]用于训练注意力模型。\n", "在带参数的注意力汇聚模型中,\n", "任何一个训练样本的输入都会和除自己以外的所有训练样本的“键-值”对进行计算,\n", "从而得到其对应的预测输出。\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "c738e178", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:17.363589Z", "iopub.status.busy": "2023-08-18T07:07:17.362749Z", "iopub.status.idle": "2023-08-18T07:07:17.369359Z", "shell.execute_reply": "2023-08-18T07:07:17.368357Z" }, "origin_pos": 48, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入\n", "X_tile = x_train.repeat((n_train, 1))\n", "# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出\n", "Y_tile = y_train.repeat((n_train, 1))\n", "# keys的形状:('n_train','n_train'-1)\n", "keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))\n", "# values的形状:('n_train','n_train'-1)\n", "values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))" ] }, { "cell_type": "markdown", "id": "245d8e27", "metadata": { "origin_pos": 51 }, "source": [ "[**训练带参数的注意力汇聚模型**]时,使用平方损失函数和随机梯度下降。\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "37b732ff", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:17.373285Z", "iopub.status.busy": "2023-08-18T07:07:17.372386Z", "iopub.status.idle": "2023-08-18T07:07:18.050395Z", "shell.execute_reply": "2023-08-18T07:07:18.049441Z" }, "origin_pos": 53, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:07:18.019132\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "net = NWKernelRegression()\n", "loss = nn.MSELoss(reduction='none')\n", "trainer = torch.optim.SGD(net.parameters(), lr=0.5)\n", "animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])\n", "\n", "for epoch in range(5):\n", " trainer.zero_grad()\n", " l = loss(net(x_train, keys, values), y_train)\n", " l.sum().backward()\n", " trainer.step()\n", " print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')\n", " animator.add(epoch + 1, float(l.sum()))" ] }, { "cell_type": "markdown", "id": "bcd73f8a", "metadata": { "origin_pos": 56 }, "source": [ "如下所示,训练完带参数的注意力汇聚模型后可以发现:\n", "在尝试拟合带噪声的训练数据时,\n", "[**预测结果绘制**]的线不如之前非参数模型的平滑。\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "79e57e0d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:18.054215Z", "iopub.status.busy": "2023-08-18T07:07:18.053635Z", "iopub.status.idle": "2023-08-18T07:07:18.224763Z", "shell.execute_reply": "2023-08-18T07:07:18.223767Z" }, "origin_pos": 58, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:07:18.177522\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)\n", "keys = x_train.repeat((n_test, 1))\n", "# value的形状:(n_test,n_train)\n", "values = y_train.repeat((n_test, 1))\n", "y_hat = net(x_test, keys, values).unsqueeze(1).detach()\n", "plot_kernel_reg(y_hat)" ] }, { "cell_type": "markdown", "id": "962bf3fe", "metadata": { "origin_pos": 61 }, "source": [ "为什么新的模型更不平滑了呢?\n", "下面看一下输出结果的绘制图:\n", "与非参数的注意力汇聚模型相比,\n", "带参数的模型加入可学习的参数后,\n", "[**曲线在注意力权重较大的区域变得更不平滑**]。\n" ] }, { "cell_type": "code", "execution_count": 14, "id": "15949408", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:18.228629Z", "iopub.status.busy": "2023-08-18T07:07:18.227958Z", "iopub.status.idle": "2023-08-18T07:07:18.391875Z", "shell.execute_reply": "2023-08-18T07:07:18.391062Z" }, "origin_pos": 63, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:07:18.347398\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),\n", " xlabel='Sorted training inputs',\n", " ylabel='Sorted testing inputs')" ] }, { "cell_type": "markdown", "id": "0e4f7d57", "metadata": { "origin_pos": 66 }, "source": [ "## 小结\n", "\n", "* Nadaraya-Watson核回归是具有注意力机制的机器学习范例。\n", "* Nadaraya-Watson核回归的注意力汇聚是对训练数据中输出的加权平均。从注意力的角度来看,分配给每个值的注意力权重取决于将值所对应的键和查询作为输入的函数。\n", "* 注意力汇聚可以分为非参数型和带参数型。\n", "\n", "## 练习\n", "\n", "1. 增加训练数据的样本数量,能否得到更好的非参数的Nadaraya-Watson核回归模型?\n", "1. 在带参数的注意力汇聚的实验中学习得到的参数$w$的价值是什么?为什么在可视化注意力权重时,它会使加权区域更加尖锐?\n", "1. 如何将超参数添加到非参数的Nadaraya-Watson核回归中以实现更好地预测结果?\n", "1. 为本节的核回归设计一个新的带参数的注意力汇聚模型。训练这个新模型并可视化其注意力权重。\n" ] }, { "cell_type": "markdown", "id": "72ded07b", "metadata": { "origin_pos": 68, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/5760)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }