{
"cells": [
{
"cell_type": "markdown",
"id": "12605036",
"metadata": {
"origin_pos": 0
},
"source": [
"# 注意力汇聚:Nadaraya-Watson 核回归\n",
":label:`sec_nadaraya-watson`\n",
"\n",
"上节介绍了框架下的注意力机制的主要成分 :numref:`fig_qkv`:\n",
"查询(自主提示)和键(非自主提示)之间的交互形成了注意力汇聚;\n",
"注意力汇聚有选择地聚合了值(感官输入)以生成最终的输出。\n",
"本节将介绍注意力汇聚的更多细节,\n",
"以便从宏观上了解注意力机制在实践中的运作方式。\n",
"具体来说,1964年提出的Nadaraya-Watson核回归模型\n",
"是一个简单但完整的例子,可以用于演示具有注意力机制的机器学习。\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "47b48700",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:07:14.699824Z",
"iopub.status.busy": "2023-08-18T07:07:14.699278Z",
"iopub.status.idle": "2023-08-18T07:07:16.694044Z",
"shell.execute_reply": "2023-08-18T07:07:16.693186Z"
},
"origin_pos": 2,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "45448f28",
"metadata": {
"origin_pos": 5
},
"source": [
"## [**生成数据集**]\n",
"\n",
"简单起见,考虑下面这个回归问题:\n",
"给定的成对的“输入-输出”数据集\n",
"$\\{(x_1, y_1), \\ldots, (x_n, y_n)\\}$,\n",
"如何学习$f$来预测任意新输入$x$的输出$\\hat{y} = f(x)$?\n",
"\n",
"根据下面的非线性函数生成一个人工数据集,\n",
"其中加入的噪声项为$\\epsilon$:\n",
"\n",
"$$y_i = 2\\sin(x_i) + x_i^{0.8} + \\epsilon,$$\n",
"\n",
"其中$\\epsilon$服从均值为$0$和标准差为$0.5$的正态分布。\n",
"在这里生成了$50$个训练样本和$50$个测试样本。\n",
"为了更好地可视化之后的注意力模式,需要将训练样本进行排序。\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "77ea63dd",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:07:16.698617Z",
"iopub.status.busy": "2023-08-18T07:07:16.697928Z",
"iopub.status.idle": "2023-08-18T07:07:16.720614Z",
"shell.execute_reply": "2023-08-18T07:07:16.719799Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"n_train = 50 # 训练样本数\n",
"x_train, _ = torch.sort(torch.rand(n_train) * 5) # 排序后的训练样本"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2b36fd68",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:07:16.724389Z",
"iopub.status.busy": "2023-08-18T07:07:16.723850Z",
"iopub.status.idle": "2023-08-18T07:07:16.734529Z",
"shell.execute_reply": "2023-08-18T07:07:16.733732Z"
},
"origin_pos": 11,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"50"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def f(x):\n",
" return 2 * torch.sin(x) + x**0.8\n",
"\n",
"y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,)) # 训练样本的输出\n",
"x_test = torch.arange(0, 5, 0.1) # 测试样本\n",
"y_truth = f(x_test) # 测试样本的真实输出\n",
"n_test = len(x_test) # 测试样本数\n",
"n_test"
]
},
{
"cell_type": "markdown",
"id": "a8cd762c",
"metadata": {
"origin_pos": 14
},
"source": [
"下面的函数将绘制所有的训练样本(样本由圆圈表示),\n",
"不带噪声项的真实数据生成函数$f$(标记为“Truth”),\n",
"以及学习得到的预测函数(标记为“Pred”)。\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "84166e26",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:07:16.738117Z",
"iopub.status.busy": "2023-08-18T07:07:16.737614Z",
"iopub.status.idle": "2023-08-18T07:07:16.742118Z",
"shell.execute_reply": "2023-08-18T07:07:16.741332Z"
},
"origin_pos": 15,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def plot_kernel_reg(y_hat):\n",
" d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],\n",
" xlim=[0, 5], ylim=[-1, 5])\n",
" d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);"
]
},
{
"cell_type": "markdown",
"id": "83acd3e5",
"metadata": {
"origin_pos": 16
},
"source": [
"## 平均汇聚\n",
"\n",
"先使用最简单的估计器来解决回归问题。\n",
"基于平均汇聚来计算所有训练样本输出值的平均值:\n",
"\n",
"$$f(x) = \\frac{1}{n}\\sum_{i=1}^n y_i,$$\n",
":eqlabel:`eq_avg-pooling`\n",
"\n",
"如下图所示,这个估计器确实不够聪明。\n",
"真实函数$f$(“Truth”)和预测函数(“Pred”)相差很大。\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b5227412",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:07:16.745659Z",
"iopub.status.busy": "2023-08-18T07:07:16.745145Z",
"iopub.status.idle": "2023-08-18T07:07:16.921666Z",
"shell.execute_reply": "2023-08-18T07:07:16.920767Z"
},
"origin_pos": 18,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"y_hat = torch.repeat_interleave(y_train.mean(), n_test)\n",
"plot_kernel_reg(y_hat)"
]
},
{
"cell_type": "markdown",
"id": "2d4155fa",
"metadata": {
"origin_pos": 21
},
"source": [
"## [**非参数注意力汇聚**]\n",
"\n",
"显然,平均汇聚忽略了输入$x_i$。\n",
"于是Nadaraya :cite:`Nadaraya.1964`和\n",
"Watson :cite:`Watson.1964`提出了一个更好的想法,\n",
"根据输入的位置对输出$y_i$进行加权:\n",
"\n",
"$$f(x) = \\sum_{i=1}^n \\frac{K(x - x_i)}{\\sum_{j=1}^n K(x - x_j)} y_i,$$\n",
":eqlabel:`eq_nadaraya-watson`\n",
"\n",
"其中$K$是*核*(kernel)。\n",
"公式 :eqref:`eq_nadaraya-watson`所描述的估计器被称为\n",
"*Nadaraya-Watson核回归*(Nadaraya-Watson kernel regression)。\n",
"这里不会深入讨论核函数的细节,\n",
"但受此启发,\n",
"我们可以从 :numref:`fig_qkv`中的注意力机制框架的角度\n",
"重写 :eqref:`eq_nadaraya-watson`,\n",
"成为一个更加通用的*注意力汇聚*(attention pooling)公式:\n",
"\n",
"$$f(x) = \\sum_{i=1}^n \\alpha(x, x_i) y_i,$$\n",
":eqlabel:`eq_attn-pooling`\n",
"\n",
"其中$x$是查询,$(x_i, y_i)$是键值对。\n",
"比较 :eqref:`eq_attn-pooling`和 :eqref:`eq_avg-pooling`,\n",
"注意力汇聚是$y_i$的加权平均。\n",
"将查询$x$和键$x_i$之间的关系建模为\n",
"*注意力权重*(attention weight)$\\alpha(x, x_i)$,\n",
"如 :eqref:`eq_attn-pooling`所示,\n",
"这个权重将被分配给每一个对应值$y_i$。\n",
"对于任何查询,模型在所有键值对注意力权重都是一个有效的概率分布:\n",
"它们是非负的,并且总和为1。\n",
"\n",
"为了更好地理解注意力汇聚,\n",
"下面考虑一个*高斯核*(Gaussian kernel),其定义为:\n",
"\n",
"$$K(u) = \\frac{1}{\\sqrt{2\\pi}} \\exp(-\\frac{u^2}{2}).$$\n",
"\n",
"将高斯核代入 :eqref:`eq_attn-pooling`和\n",
" :eqref:`eq_nadaraya-watson`可以得到:\n",
"\n",
"$$\\begin{aligned} f(x) &=\\sum_{i=1}^n \\alpha(x, x_i) y_i\\\\ &= \\sum_{i=1}^n \\frac{\\exp\\left(-\\frac{1}{2}(x - x_i)^2\\right)}{\\sum_{j=1}^n \\exp\\left(-\\frac{1}{2}(x - x_j)^2\\right)} y_i \\\\&= \\sum_{i=1}^n \\mathrm{softmax}\\left(-\\frac{1}{2}(x - x_i)^2\\right) y_i. \\end{aligned}$$\n",
":eqlabel:`eq_nadaraya-watson-gaussian`\n",
"\n",
"在 :eqref:`eq_nadaraya-watson-gaussian`中,\n",
"如果一个键$x_i$越是接近给定的查询$x$,\n",
"那么分配给这个键对应值$y_i$的注意力权重就会越大,\n",
"也就“获得了更多的注意力”。\n",
"\n",
"值得注意的是,Nadaraya-Watson核回归是一个非参数模型。\n",
"因此, :eqref:`eq_nadaraya-watson-gaussian`是\n",
"*非参数的注意力汇聚*(nonparametric attention pooling)模型。\n",
"接下来,我们将基于这个非参数的注意力汇聚模型来绘制预测结果。\n",
"从绘制的结果会发现新的模型预测线是平滑的,并且比平均汇聚的预测更接近真实。\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "892e92a9",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:07:16.925532Z",
"iopub.status.busy": "2023-08-18T07:07:16.924886Z",
"iopub.status.idle": "2023-08-18T07:07:17.137585Z",
"shell.execute_reply": "2023-08-18T07:07:17.136490Z"
},
"origin_pos": 23,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# X_repeat的形状:(n_test,n_train),\n",
"# 每一行都包含着相同的测试输入(例如:同样的查询)\n",
"X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))\n",
"# x_train包含着键。attention_weights的形状:(n_test,n_train),\n",
"# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重\n",
"attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)\n",
"# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重\n",
"y_hat = torch.matmul(attention_weights, y_train)\n",
"plot_kernel_reg(y_hat)"
]
},
{
"cell_type": "markdown",
"id": "60a40ebc",
"metadata": {
"origin_pos": 26
},
"source": [
"现在来观察注意力的权重。\n",
"这里测试数据的输入相当于查询,而训练数据的输入相当于键。\n",
"因为两个输入都是经过排序的,因此由观察可知“查询-键”对越接近,\n",
"注意力汇聚的[**注意力权重**]就越高。\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5068c7e7",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:07:17.141869Z",
"iopub.status.busy": "2023-08-18T07:07:17.141398Z",
"iopub.status.idle": "2023-08-18T07:07:17.325925Z",
"shell.execute_reply": "2023-08-18T07:07:17.325070Z"
},
"origin_pos": 28,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),\n",
" xlabel='Sorted training inputs',\n",
" ylabel='Sorted testing inputs')"
]
},
{
"cell_type": "markdown",
"id": "73a72797",
"metadata": {
"origin_pos": 31
},
"source": [
"## [**带参数注意力汇聚**]\n",
"\n",
"非参数的Nadaraya-Watson核回归具有*一致性*(consistency)的优点:\n",
"如果有足够的数据,此模型会收敛到最优结果。\n",
"尽管如此,我们还是可以轻松地将可学习的参数集成到注意力汇聚中。\n",
"\n",
"例如,与 :eqref:`eq_nadaraya-watson-gaussian`略有不同,\n",
"在下面的查询$x$和键$x_i$之间的距离乘以可学习参数$w$:\n",
"\n",
"$$\\begin{aligned}f(x) &= \\sum_{i=1}^n \\alpha(x, x_i) y_i \\\\&= \\sum_{i=1}^n \\frac{\\exp\\left(-\\frac{1}{2}((x - x_i)w)^2\\right)}{\\sum_{j=1}^n \\exp\\left(-\\frac{1}{2}((x - x_j)w)^2\\right)} y_i \\\\&= \\sum_{i=1}^n \\mathrm{softmax}\\left(-\\frac{1}{2}((x - x_i)w)^2\\right) y_i.\\end{aligned}$$\n",
":eqlabel:`eq_nadaraya-watson-gaussian-para`\n",
"\n",
"本节的余下部分将通过训练这个模型\n",
" :eqref:`eq_nadaraya-watson-gaussian-para`来学习注意力汇聚的参数。\n",
"\n",
"### 批量矩阵乘法\n",
"\n",
":label:`subsec_batch_dot`\n",
"\n",
"为了更有效地计算小批量数据的注意力,\n",
"我们可以利用深度学习开发框架中提供的批量矩阵乘法。\n",
"\n",
"假设第一个小批量数据包含$n$个矩阵$\\mathbf{X}_1,\\ldots, \\mathbf{X}_n$,\n",
"形状为$a\\times b$,\n",
"第二个小批量包含$n$个矩阵$\\mathbf{Y}_1, \\ldots, \\mathbf{Y}_n$,\n",
"形状为$b\\times c$。\n",
"它们的批量矩阵乘法得到$n$个矩阵\n",
"$\\mathbf{X}_1\\mathbf{Y}_1, \\ldots, \\mathbf{X}_n\\mathbf{Y}_n$,\n",
"形状为$a\\times c$。\n",
"因此,[**假定两个张量的形状分别是$(n,a,b)$和$(n,b,c)$,\n",
"它们的批量矩阵乘法输出的形状为$(n,a,c)$**]。\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "c5a4dc9e",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:07:17.329873Z",
"iopub.status.busy": "2023-08-18T07:07:17.329293Z",
"iopub.status.idle": "2023-08-18T07:07:17.336079Z",
"shell.execute_reply": "2023-08-18T07:07:17.335148Z"
},
"origin_pos": 33,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 1, 6])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = torch.ones((2, 1, 4))\n",
"Y = torch.ones((2, 4, 6))\n",
"torch.bmm(X, Y).shape"
]
},
{
"cell_type": "markdown",
"id": "acac0804",
"metadata": {
"origin_pos": 36
},
"source": [
"在注意力机制的背景中,我们可以[**使用小批量矩阵乘法来计算小批量数据中的加权平均值**]。\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "7161cf35",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:07:17.341257Z",
"iopub.status.busy": "2023-08-18T07:07:17.340341Z",
"iopub.status.idle": "2023-08-18T07:07:17.348773Z",
"shell.execute_reply": "2023-08-18T07:07:17.347847Z"
},
"origin_pos": 38,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 4.5000]],\n",
"\n",
" [[14.5000]]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"weights = torch.ones((2, 10)) * 0.1\n",
"values = torch.arange(20.0).reshape((2, 10))\n",
"torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1))"
]
},
{
"cell_type": "markdown",
"id": "f31cb449",
"metadata": {
"origin_pos": 41
},
"source": [
"### 定义模型\n",
"\n",
"基于 :eqref:`eq_nadaraya-watson-gaussian-para`中的\n",
"[**带参数的注意力汇聚**],使用小批量矩阵乘法,\n",
"定义Nadaraya-Watson核回归的带参数版本为:\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e7aee504",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:07:17.353382Z",
"iopub.status.busy": "2023-08-18T07:07:17.353094Z",
"iopub.status.idle": "2023-08-18T07:07:17.359677Z",
"shell.execute_reply": "2023-08-18T07:07:17.358720Z"
},
"origin_pos": 43,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class NWKernelRegression(nn.Module):\n",
" def __init__(self, **kwargs):\n",
" super().__init__(**kwargs)\n",
" self.w = nn.Parameter(torch.rand((1,), requires_grad=True))\n",
"\n",
" def forward(self, queries, keys, values):\n",
" # queries和attention_weights的形状为(查询个数,“键-值”对个数)\n",
" queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))\n",
" self.attention_weights = nn.functional.softmax(\n",
" -((queries - keys) * self.w)**2 / 2, dim=1)\n",
" # values的形状为(查询个数,“键-值”对个数)\n",
" return torch.bmm(self.attention_weights.unsqueeze(1),\n",
" values.unsqueeze(-1)).reshape(-1)"
]
},
{
"cell_type": "markdown",
"id": "192a922f",
"metadata": {
"origin_pos": 46
},
"source": [
"### 训练\n",
"\n",
"接下来,[**将训练数据集变换为键和值**]用于训练注意力模型。\n",
"在带参数的注意力汇聚模型中,\n",
"任何一个训练样本的输入都会和除自己以外的所有训练样本的“键-值”对进行计算,\n",
"从而得到其对应的预测输出。\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c738e178",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:07:17.363589Z",
"iopub.status.busy": "2023-08-18T07:07:17.362749Z",
"iopub.status.idle": "2023-08-18T07:07:17.369359Z",
"shell.execute_reply": "2023-08-18T07:07:17.368357Z"
},
"origin_pos": 48,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入\n",
"X_tile = x_train.repeat((n_train, 1))\n",
"# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出\n",
"Y_tile = y_train.repeat((n_train, 1))\n",
"# keys的形状:('n_train','n_train'-1)\n",
"keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))\n",
"# values的形状:('n_train','n_train'-1)\n",
"values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))"
]
},
{
"cell_type": "markdown",
"id": "245d8e27",
"metadata": {
"origin_pos": 51
},
"source": [
"[**训练带参数的注意力汇聚模型**]时,使用平方损失函数和随机梯度下降。\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "37b732ff",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:07:17.373285Z",
"iopub.status.busy": "2023-08-18T07:07:17.372386Z",
"iopub.status.idle": "2023-08-18T07:07:18.050395Z",
"shell.execute_reply": "2023-08-18T07:07:18.049441Z"
},
"origin_pos": 53,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"net = NWKernelRegression()\n",
"loss = nn.MSELoss(reduction='none')\n",
"trainer = torch.optim.SGD(net.parameters(), lr=0.5)\n",
"animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])\n",
"\n",
"for epoch in range(5):\n",
" trainer.zero_grad()\n",
" l = loss(net(x_train, keys, values), y_train)\n",
" l.sum().backward()\n",
" trainer.step()\n",
" print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')\n",
" animator.add(epoch + 1, float(l.sum()))"
]
},
{
"cell_type": "markdown",
"id": "bcd73f8a",
"metadata": {
"origin_pos": 56
},
"source": [
"如下所示,训练完带参数的注意力汇聚模型后可以发现:\n",
"在尝试拟合带噪声的训练数据时,\n",
"[**预测结果绘制**]的线不如之前非参数模型的平滑。\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "79e57e0d",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:07:18.054215Z",
"iopub.status.busy": "2023-08-18T07:07:18.053635Z",
"iopub.status.idle": "2023-08-18T07:07:18.224763Z",
"shell.execute_reply": "2023-08-18T07:07:18.223767Z"
},
"origin_pos": 58,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)\n",
"keys = x_train.repeat((n_test, 1))\n",
"# value的形状:(n_test,n_train)\n",
"values = y_train.repeat((n_test, 1))\n",
"y_hat = net(x_test, keys, values).unsqueeze(1).detach()\n",
"plot_kernel_reg(y_hat)"
]
},
{
"cell_type": "markdown",
"id": "962bf3fe",
"metadata": {
"origin_pos": 61
},
"source": [
"为什么新的模型更不平滑了呢?\n",
"下面看一下输出结果的绘制图:\n",
"与非参数的注意力汇聚模型相比,\n",
"带参数的模型加入可学习的参数后,\n",
"[**曲线在注意力权重较大的区域变得更不平滑**]。\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "15949408",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:07:18.228629Z",
"iopub.status.busy": "2023-08-18T07:07:18.227958Z",
"iopub.status.idle": "2023-08-18T07:07:18.391875Z",
"shell.execute_reply": "2023-08-18T07:07:18.391062Z"
},
"origin_pos": 63,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),\n",
" xlabel='Sorted training inputs',\n",
" ylabel='Sorted testing inputs')"
]
},
{
"cell_type": "markdown",
"id": "0e4f7d57",
"metadata": {
"origin_pos": 66
},
"source": [
"## 小结\n",
"\n",
"* Nadaraya-Watson核回归是具有注意力机制的机器学习范例。\n",
"* Nadaraya-Watson核回归的注意力汇聚是对训练数据中输出的加权平均。从注意力的角度来看,分配给每个值的注意力权重取决于将值所对应的键和查询作为输入的函数。\n",
"* 注意力汇聚可以分为非参数型和带参数型。\n",
"\n",
"## 练习\n",
"\n",
"1. 增加训练数据的样本数量,能否得到更好的非参数的Nadaraya-Watson核回归模型?\n",
"1. 在带参数的注意力汇聚的实验中学习得到的参数$w$的价值是什么?为什么在可视化注意力权重时,它会使加权区域更加尖锐?\n",
"1. 如何将超参数添加到非参数的Nadaraya-Watson核回归中以实现更好地预测结果?\n",
"1. 为本节的核回归设计一个新的带参数的注意力汇聚模型。训练这个新模型并可视化其注意力权重。\n"
]
},
{
"cell_type": "markdown",
"id": "72ded07b",
"metadata": {
"origin_pos": 68,
"tab": [
"pytorch"
]
},
"source": [
"[Discussions](https://discuss.d2l.ai/t/5760)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"required_libs": []
},
"nbformat": 4,
"nbformat_minor": 5
}