Files
2025-12-16 09:23:53 +08:00

967 lines
37 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "markdown",
"id": "9122670f",
"metadata": {
"origin_pos": 0
},
"source": [
"# 注意力提示\n",
":label:`sec_attention-cues`\n",
"\n",
"感谢读者对本书的关注,因为读者的注意力是一种稀缺的资源:\n",
"此刻读者正在阅读本书(而忽略了其他的书),\n",
"因此读者的注意力是用机会成本(与金钱类似)来支付的。\n",
"为了确保读者现在投入的注意力是值得的,\n",
"作者们尽全力(全部的注意力)创作一本好书。\n",
"\n",
"自经济学研究稀缺资源分配以来,人们正处在“注意力经济”时代,\n",
"即人类的注意力被视为可以交换的、有限的、有价值的且稀缺的商品。\n",
"许多商业模式也被开发出来去利用这一点:\n",
"在音乐或视频流媒体服务上,人们要么消耗注意力在广告上,要么付钱来隐藏广告;\n",
"为了在网络游戏世界的成长,人们要么消耗注意力在游戏战斗中,\n",
"从而帮助吸引新的玩家,要么付钱立即变得强大。\n",
"总之,注意力不是免费的。\n",
"\n",
"注意力是稀缺的,而环境中的干扰注意力的信息却并不少。\n",
"比如人类的视觉神经系统大约每秒收到$10^8$位的信息,\n",
"这远远超过了大脑能够完全处理的水平。\n",
"幸运的是,人类的祖先已经从经验(也称为数据)中认识到\n",
"“并非感官的所有输入都是一样的”。\n",
"在整个人类历史中,这种只将注意力引向感兴趣的一小部分信息的能力,\n",
"使人类的大脑能够更明智地分配资源来生存、成长和社交,\n",
"例如发现天敌、找寻食物和伴侣。\n",
"\n",
"## 生物学中的注意力提示\n",
"\n",
"注意力是如何应用于视觉世界中的呢?\n",
"这要从当今十分普及的*双组件*(two-component)的框架开始讲起:\n",
"这个框架的出现可以追溯到19世纪90年代的威廉·詹姆斯,\n",
"他被认为是“美国心理学之父” :cite:`James.2007`。\n",
"在这个框架中,受试者基于*非自主性提示*和*自主性提示*\n",
"有选择地引导注意力的焦点。\n",
"\n",
"非自主性提示是基于环境中物体的突出性和易见性。\n",
"想象一下,假如我们面前有五个物品:\n",
"一份报纸、一篇研究论文、一杯咖啡、一本笔记本和一本书,\n",
"就像 :numref:`fig_eye-coffee`。\n",
"所有纸制品都是黑白印刷的,但咖啡杯是红色的。\n",
"换句话说,这个咖啡杯在这种视觉环境中是突出和显眼的,\n",
"不由自主地引起人们的注意。\n",
"所以我们会把视力最敏锐的地方放到咖啡上,\n",
"如 :numref:`fig_eye-coffee`所示。\n",
"\n",
"![由于突出性的非自主性提示(红杯子),注意力不自主地指向了咖啡杯](../img/eye-coffee.svg)\n",
":width:`400px`\n",
":label:`fig_eye-coffee`\n",
"\n",
"喝咖啡后,我们会变得兴奋并想读书,\n",
"所以转过头,重新聚焦眼睛,然后看看书,\n",
"就像 :numref:`fig_eye-book`中描述那样。\n",
"与 :numref:`fig_eye-coffee`中由于突出性导致的选择不同,\n",
"此时选择书是受到了认知和意识的控制,\n",
"因此注意力在基于自主性提示去辅助选择时将更为谨慎。\n",
"受试者的主观意愿推动,选择的力量也就更强大。\n",
"\n",
"![依赖于任务的意志提示(想读一本书),注意力被自主引导到书上](../img/eye-book.svg)\n",
":width:`400px`\n",
":label:`fig_eye-book`\n",
"\n",
"## 查询、键和值\n",
"\n",
"自主性的与非自主性的注意力提示解释了人类的注意力的方式,\n",
"下面来看看如何通过这两种注意力提示,\n",
"用神经网络来设计注意力机制的框架,\n",
"\n",
"首先,考虑一个相对简单的状况,\n",
"即只使用非自主性提示。\n",
"要想将选择偏向于感官输入,\n",
"则可以简单地使用参数化的全连接层,\n",
"甚至是非参数化的最大汇聚层或平均汇聚层。\n",
"\n",
"因此,“是否包含自主性提示”将注意力机制与全连接层或汇聚层区别开来。\n",
"在注意力机制的背景下,自主性提示被称为*查询*(query)。\n",
"给定任何查询,注意力机制通过*注意力汇聚*(attention pooling\n",
"将选择引导至*感官输入*sensory inputs,例如中间特征表示)。\n",
"在注意力机制中,这些感官输入被称为*值*(value)。\n",
"更通俗的解释,每个值都与一个*键*(key)配对,\n",
"这可以想象为感官输入的非自主提示。\n",
"如 :numref:`fig_qkv`所示,可以通过设计注意力汇聚的方式,\n",
"便于给定的查询(自主性提示)与键(非自主性提示)进行匹配,\n",
"这将引导得出最匹配的值(感官输入)。\n",
"\n",
"![注意力机制通过注意力汇聚将*查询*(自主性提示)和*键*(非自主性提示)结合在一起,实现对*值*(感官输入)的选择倾向](../img/qkv.svg)\n",
":label:`fig_qkv`\n",
"\n",
"鉴于上面所提框架在 :numref:`fig_qkv`中的主导地位,\n",
"因此这个框架下的模型将成为本章的中心。\n",
"然而,注意力机制的设计有许多替代方案。\n",
"例如可以设计一个不可微的注意力模型,\n",
"该模型可以使用强化学习方法 :cite:`Mnih.Heess.Graves.ea.2014`进行训练。\n",
"\n",
"## 注意力的可视化\n",
"\n",
"平均汇聚层可以被视为输入的加权平均值,\n",
"其中各输入的权重是一样的。\n",
"实际上,注意力汇聚得到的是加权平均的总和值,\n",
"其中权重是在给定的查询和不同的键之间计算得出的。\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "58a7898f",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:03:03.521946Z",
"iopub.status.busy": "2023-08-18T07:03:03.521507Z",
"iopub.status.idle": "2023-08-18T07:03:05.621623Z",
"shell.execute_reply": "2023-08-18T07:03:05.620583Z"
},
"origin_pos": 2,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "6e9e92fc",
"metadata": {
"origin_pos": 5
},
"source": [
"为了可视化注意力权重,需要定义一个`show_heatmaps`函数。\n",
"其输入`matrices`的形状是\n",
"(要显示的行数,要显示的列数,查询的数目,键的数目)。\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3c30d535",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:03:05.627152Z",
"iopub.status.busy": "2023-08-18T07:03:05.626530Z",
"iopub.status.idle": "2023-08-18T07:03:05.634951Z",
"shell.execute_reply": "2023-08-18T07:03:05.633763Z"
},
"origin_pos": 6,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"#@save\n",
"def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),\n",
" cmap='Reds'):\n",
" \"\"\"显示矩阵热图\"\"\"\n",
" d2l.use_svg_display()\n",
" num_rows, num_cols = matrices.shape[0], matrices.shape[1]\n",
" fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,\n",
" sharex=True, sharey=True, squeeze=False)\n",
" for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):\n",
" for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):\n",
" pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)\n",
" if i == num_rows - 1:\n",
" ax.set_xlabel(xlabel)\n",
" if j == 0:\n",
" ax.set_ylabel(ylabel)\n",
" if titles:\n",
" ax.set_title(titles[j])\n",
" fig.colorbar(pcm, ax=axes, shrink=0.6);"
]
},
{
"cell_type": "markdown",
"id": "f48978d9",
"metadata": {
"origin_pos": 7
},
"source": [
"下面使用一个简单的例子进行演示。\n",
"在本例子中,仅当查询和键相同时,注意力权重为1,否则为0。\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "bbabe8f3",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:03:05.640096Z",
"iopub.status.busy": "2023-08-18T07:03:05.639355Z",
"iopub.status.idle": "2023-08-18T07:03:05.886353Z",
"shell.execute_reply": "2023-08-18T07:03:05.885235Z"
},
"origin_pos": 8,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"193.35825pt\" height=\"156.35625pt\" viewBox=\"0 0 193.35825 156.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
" <metadata>\n",
" <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
" <cc:Work>\n",
" <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
" <dc:date>2023-08-18T07:03:05.823629</dc:date>\n",
" <dc:format>image/svg+xml</dc:format>\n",
" <dc:creator>\n",
" <cc:Agent>\n",
" <dc:title>Matplotlib v3.5.1, https://matplotlib.org/</dc:title>\n",
" </cc:Agent>\n",
" </dc:creator>\n",
" </cc:Work>\n",
" </rdf:RDF>\n",
" </metadata>\n",
" <defs>\n",
" <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
" </defs>\n",
" <g id=\"figure_1\">\n",
" <g id=\"patch_1\">\n",
" <path d=\"M -0 156.35625 \n",
"L 193.35825 156.35625 \n",
"L 193.35825 0 \n",
"L -0 0 \n",
"L -0 156.35625 \n",
"z\n",
"\" style=\"fill: none\"/>\n",
" </g>\n",
" <g id=\"axes_1\">\n",
" <g id=\"patch_2\">\n",
" <path d=\"M 34.240625 118.8 \n",
"L 145.840625 118.8 \n",
"L 145.840625 7.2 \n",
"L 34.240625 7.2 \n",
"z\n",
"\" style=\"fill: #ffffff\"/>\n",
" </g>\n",
" <g clip-path=\"url(#p3a36c213e6)\">\n",
" <image xlink:href=\"data:image/png;base64,\n",
"iVBORw0KGgoAAAANSUhEUgAAAHAAAABwCAYAAADG4PRLAAABc0lEQVR4nO3dsW3CUBhGURPRJSNkPxgpA2YAGmqnpfRDMnnXnFO7sHT1N58s+bTeb+vCVK6f35uf/djxPXgBAeMEjBMwTsA4AeMEjBMwTsA4AePO//0C72JkHvu5/25+1gXGCRgnYJyAcQLGCRgnYJyAcQLGCRhnSnvSyDS2LGPz2AgXGCdgnIBxAsYJGCdgnIBxAsYJGCdgnIBxttAHe336tycXGCdgnIBxAsYJGCdgnIBxAsYJGCdg3OGntOI8NsIFxgkYJ2CcgHECxgkYJ2CcgHECxgkYl5zSjj6PjXCBcQLGCRgnYJyAcQLGCRgnYJyAcQLGTTOlmcee4wLjBIwTME7AOAHjBIwTME7AOAHjBIzbbUqb5b8KR+cC4wSMEzBOwDgB4wSMEzBOwDgB4wSMEzBuaAv16d98XGCcgHECxgkYJ2CcgHECxgkYJ2CcgHGny/K1bn3YPDYfFxgnYJyAcQLGCRgnYJyAcQLGCRgnYNwfkEghRAiKZdAAAAAASUVORK5CYII=\" id=\"image8a96b358df\" transform=\"scale(1 -1)translate(0 -112)\" x=\"34.240625\" y=\"-6.8\" width=\"112\" height=\"112\"/>\n",
" </g>\n",
" <g id=\"matplotlib.axis_1\">\n",
" <g id=\"xtick_1\">\n",
" <g id=\"line2d_1\">\n",
" <defs>\n",
" <path id=\"mdcf2e5a8c4\" d=\"M 0 0 \n",
"L 0 3.5 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#mdcf2e5a8c4\" x=\"39.820625\" y=\"118.8\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_1\">\n",
" <!-- 0 -->\n",
" <g transform=\"translate(36.639375 133.398438)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
"Q 1547 4250 1301 3770 \n",
"Q 1056 3291 1056 2328 \n",
"Q 1056 1369 1301 889 \n",
"Q 1547 409 2034 409 \n",
"Q 2525 409 2770 889 \n",
"Q 3016 1369 3016 2328 \n",
"Q 3016 3291 2770 3770 \n",
"Q 2525 4250 2034 4250 \n",
"z\n",
"M 2034 4750 \n",
"Q 2819 4750 3233 4129 \n",
"Q 3647 3509 3647 2328 \n",
"Q 3647 1150 3233 529 \n",
"Q 2819 -91 2034 -91 \n",
"Q 1250 -91 836 529 \n",
"Q 422 1150 422 2328 \n",
"Q 422 3509 836 4129 \n",
"Q 1250 4750 2034 4750 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_2\">\n",
" <g id=\"line2d_2\">\n",
" <g>\n",
" <use xlink:href=\"#mdcf2e5a8c4\" x=\"95.620625\" y=\"118.8\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_2\">\n",
" <!-- 5 -->\n",
" <g transform=\"translate(92.439375 133.398438)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-35\" d=\"M 691 4666 \n",
"L 3169 4666 \n",
"L 3169 4134 \n",
"L 1269 4134 \n",
"L 1269 2991 \n",
"Q 1406 3038 1543 3061 \n",
"Q 1681 3084 1819 3084 \n",
"Q 2600 3084 3056 2656 \n",
"Q 3513 2228 3513 1497 \n",
"Q 3513 744 3044 326 \n",
"Q 2575 -91 1722 -91 \n",
"Q 1428 -91 1123 -41 \n",
"Q 819 9 494 109 \n",
"L 494 744 \n",
"Q 775 591 1075 516 \n",
"Q 1375 441 1709 441 \n",
"Q 2250 441 2565 725 \n",
"Q 2881 1009 2881 1497 \n",
"Q 2881 1984 2565 2268 \n",
"Q 2250 2553 1709 2553 \n",
"Q 1456 2553 1204 2497 \n",
"Q 953 2441 691 2322 \n",
"L 691 4666 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_3\">\n",
" <!-- Keys -->\n",
" <g transform=\"translate(78.371094 147.076563)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-4b\" d=\"M 628 4666 \n",
"L 1259 4666 \n",
"L 1259 2694 \n",
"L 3353 4666 \n",
"L 4166 4666 \n",
"L 1850 2491 \n",
"L 4331 0 \n",
"L 3500 0 \n",
"L 1259 2247 \n",
"L 1259 0 \n",
"L 628 0 \n",
"L 628 4666 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-65\" d=\"M 3597 1894 \n",
"L 3597 1613 \n",
"L 953 1613 \n",
"Q 991 1019 1311 708 \n",
"Q 1631 397 2203 397 \n",
"Q 2534 397 2845 478 \n",
"Q 3156 559 3463 722 \n",
"L 3463 178 \n",
"Q 3153 47 2828 -22 \n",
"Q 2503 -91 2169 -91 \n",
"Q 1331 -91 842 396 \n",
"Q 353 884 353 1716 \n",
"Q 353 2575 817 3079 \n",
"Q 1281 3584 2069 3584 \n",
"Q 2775 3584 3186 3129 \n",
"Q 3597 2675 3597 1894 \n",
"z\n",
"M 3022 2063 \n",
"Q 3016 2534 2758 2815 \n",
"Q 2500 3097 2075 3097 \n",
"Q 1594 3097 1305 2825 \n",
"Q 1016 2553 972 2059 \n",
"L 3022 2063 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-79\" d=\"M 2059 -325 \n",
"Q 1816 -950 1584 -1140 \n",
"Q 1353 -1331 966 -1331 \n",
"L 506 -1331 \n",
"L 506 -850 \n",
"L 844 -850 \n",
"Q 1081 -850 1212 -737 \n",
"Q 1344 -625 1503 -206 \n",
"L 1606 56 \n",
"L 191 3500 \n",
"L 800 3500 \n",
"L 1894 763 \n",
"L 2988 3500 \n",
"L 3597 3500 \n",
"L 2059 -325 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-73\" d=\"M 2834 3397 \n",
"L 2834 2853 \n",
"Q 2591 2978 2328 3040 \n",
"Q 2066 3103 1784 3103 \n",
"Q 1356 3103 1142 2972 \n",
"Q 928 2841 928 2578 \n",
"Q 928 2378 1081 2264 \n",
"Q 1234 2150 1697 2047 \n",
"L 1894 2003 \n",
"Q 2506 1872 2764 1633 \n",
"Q 3022 1394 3022 966 \n",
"Q 3022 478 2636 193 \n",
"Q 2250 -91 1575 -91 \n",
"Q 1294 -91 989 -36 \n",
"Q 684 19 347 128 \n",
"L 347 722 \n",
"Q 666 556 975 473 \n",
"Q 1284 391 1588 391 \n",
"Q 1994 391 2212 530 \n",
"Q 2431 669 2431 922 \n",
"Q 2431 1156 2273 1281 \n",
"Q 2116 1406 1581 1522 \n",
"L 1381 1569 \n",
"Q 847 1681 609 1914 \n",
"Q 372 2147 372 2553 \n",
"Q 372 3047 722 3315 \n",
"Q 1072 3584 1716 3584 \n",
"Q 2034 3584 2315 3537 \n",
"Q 2597 3491 2834 3397 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-4b\"/>\n",
" <use xlink:href=\"#DejaVuSans-65\" x=\"60.576172\"/>\n",
" <use xlink:href=\"#DejaVuSans-79\" x=\"122.099609\"/>\n",
" <use xlink:href=\"#DejaVuSans-73\" x=\"181.279297\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"matplotlib.axis_2\">\n",
" <g id=\"ytick_1\">\n",
" <g id=\"line2d_3\">\n",
" <defs>\n",
" <path id=\"m473666877a\" d=\"M 0 0 \n",
"L -3.5 0 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#m473666877a\" x=\"34.240625\" y=\"12.78\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_4\">\n",
" <!-- 0 -->\n",
" <g transform=\"translate(20.878125 16.579219)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_2\">\n",
" <g id=\"line2d_4\">\n",
" <g>\n",
" <use xlink:href=\"#m473666877a\" x=\"34.240625\" y=\"35.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_5\">\n",
" <!-- 2 -->\n",
" <g transform=\"translate(20.878125 38.899219)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
"L 3431 531 \n",
"L 3431 0 \n",
"L 469 0 \n",
"L 469 531 \n",
"Q 828 903 1448 1529 \n",
"Q 2069 2156 2228 2338 \n",
"Q 2531 2678 2651 2914 \n",
"Q 2772 3150 2772 3378 \n",
"Q 2772 3750 2511 3984 \n",
"Q 2250 4219 1831 4219 \n",
"Q 1534 4219 1204 4116 \n",
"Q 875 4013 500 3803 \n",
"L 500 4441 \n",
"Q 881 4594 1212 4672 \n",
"Q 1544 4750 1819 4750 \n",
"Q 2544 4750 2975 4387 \n",
"Q 3406 4025 3406 3419 \n",
"Q 3406 3131 3298 2873 \n",
"Q 3191 2616 2906 2266 \n",
"Q 2828 2175 2409 1742 \n",
"Q 1991 1309 1228 531 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_3\">\n",
" <g id=\"line2d_5\">\n",
" <g>\n",
" <use xlink:href=\"#m473666877a\" x=\"34.240625\" y=\"57.42\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_6\">\n",
" <!-- 4 -->\n",
" <g transform=\"translate(20.878125 61.219219)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \n",
"L 825 1625 \n",
"L 2419 1625 \n",
"L 2419 4116 \n",
"z\n",
"M 2253 4666 \n",
"L 3047 4666 \n",
"L 3047 1625 \n",
"L 3713 1625 \n",
"L 3713 1100 \n",
"L 3047 1100 \n",
"L 3047 0 \n",
"L 2419 0 \n",
"L 2419 1100 \n",
"L 313 1100 \n",
"L 313 1709 \n",
"L 2253 4666 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-34\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_4\">\n",
" <g id=\"line2d_6\">\n",
" <g>\n",
" <use xlink:href=\"#m473666877a\" x=\"34.240625\" y=\"79.74\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_7\">\n",
" <!-- 6 -->\n",
" <g transform=\"translate(20.878125 83.539219)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-36\" d=\"M 2113 2584 \n",
"Q 1688 2584 1439 2293 \n",
"Q 1191 2003 1191 1497 \n",
"Q 1191 994 1439 701 \n",
"Q 1688 409 2113 409 \n",
"Q 2538 409 2786 701 \n",
"Q 3034 994 3034 1497 \n",
"Q 3034 2003 2786 2293 \n",
"Q 2538 2584 2113 2584 \n",
"z\n",
"M 3366 4563 \n",
"L 3366 3988 \n",
"Q 3128 4100 2886 4159 \n",
"Q 2644 4219 2406 4219 \n",
"Q 1781 4219 1451 3797 \n",
"Q 1122 3375 1075 2522 \n",
"Q 1259 2794 1537 2939 \n",
"Q 1816 3084 2150 3084 \n",
"Q 2853 3084 3261 2657 \n",
"Q 3669 2231 3669 1497 \n",
"Q 3669 778 3244 343 \n",
"Q 2819 -91 2113 -91 \n",
"Q 1303 -91 875 529 \n",
"Q 447 1150 447 2328 \n",
"Q 447 3434 972 4092 \n",
"Q 1497 4750 2381 4750 \n",
"Q 2619 4750 2861 4703 \n",
"Q 3103 4656 3366 4563 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-36\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_5\">\n",
" <g id=\"line2d_7\">\n",
" <g>\n",
" <use xlink:href=\"#m473666877a\" x=\"34.240625\" y=\"102.06\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_8\">\n",
" <!-- 8 -->\n",
" <g transform=\"translate(20.878125 105.859219)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-38\" d=\"M 2034 2216 \n",
"Q 1584 2216 1326 1975 \n",
"Q 1069 1734 1069 1313 \n",
"Q 1069 891 1326 650 \n",
"Q 1584 409 2034 409 \n",
"Q 2484 409 2743 651 \n",
"Q 3003 894 3003 1313 \n",
"Q 3003 1734 2745 1975 \n",
"Q 2488 2216 2034 2216 \n",
"z\n",
"M 1403 2484 \n",
"Q 997 2584 770 2862 \n",
"Q 544 3141 544 3541 \n",
"Q 544 4100 942 4425 \n",
"Q 1341 4750 2034 4750 \n",
"Q 2731 4750 3128 4425 \n",
"Q 3525 4100 3525 3541 \n",
"Q 3525 3141 3298 2862 \n",
"Q 3072 2584 2669 2484 \n",
"Q 3125 2378 3379 2068 \n",
"Q 3634 1759 3634 1313 \n",
"Q 3634 634 3220 271 \n",
"Q 2806 -91 2034 -91 \n",
"Q 1263 -91 848 271 \n",
"Q 434 634 434 1313 \n",
"Q 434 1759 690 2068 \n",
"Q 947 2378 1403 2484 \n",
"z\n",
"M 1172 3481 \n",
"Q 1172 3119 1398 2916 \n",
"Q 1625 2713 2034 2713 \n",
"Q 2441 2713 2670 2916 \n",
"Q 2900 3119 2900 3481 \n",
"Q 2900 3844 2670 4047 \n",
"Q 2441 4250 2034 4250 \n",
"Q 1625 4250 1398 4047 \n",
"Q 1172 3844 1172 3481 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-38\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_9\">\n",
" <!-- Queries -->\n",
" <g transform=\"translate(14.798437 82.307031)rotate(-90)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-51\" d=\"M 2522 4238 \n",
"Q 1834 4238 1429 3725 \n",
"Q 1025 3213 1025 2328 \n",
"Q 1025 1447 1429 934 \n",
"Q 1834 422 2522 422 \n",
"Q 3209 422 3611 934 \n",
"Q 4013 1447 4013 2328 \n",
"Q 4013 3213 3611 3725 \n",
"Q 3209 4238 2522 4238 \n",
"z\n",
"M 3406 84 \n",
"L 4238 -825 \n",
"L 3475 -825 \n",
"L 2784 -78 \n",
"Q 2681 -84 2626 -87 \n",
"Q 2572 -91 2522 -91 \n",
"Q 1538 -91 948 567 \n",
"Q 359 1225 359 2328 \n",
"Q 359 3434 948 4092 \n",
"Q 1538 4750 2522 4750 \n",
"Q 3503 4750 4090 4092 \n",
"Q 4678 3434 4678 2328 \n",
"Q 4678 1516 4351 937 \n",
"Q 4025 359 3406 84 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-75\" d=\"M 544 1381 \n",
"L 544 3500 \n",
"L 1119 3500 \n",
"L 1119 1403 \n",
"Q 1119 906 1312 657 \n",
"Q 1506 409 1894 409 \n",
"Q 2359 409 2629 706 \n",
"Q 2900 1003 2900 1516 \n",
"L 2900 3500 \n",
"L 3475 3500 \n",
"L 3475 0 \n",
"L 2900 0 \n",
"L 2900 538 \n",
"Q 2691 219 2414 64 \n",
"Q 2138 -91 1772 -91 \n",
"Q 1169 -91 856 284 \n",
"Q 544 659 544 1381 \n",
"z\n",
"M 1991 3584 \n",
"L 1991 3584 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-72\" d=\"M 2631 2963 \n",
"Q 2534 3019 2420 3045 \n",
"Q 2306 3072 2169 3072 \n",
"Q 1681 3072 1420 2755 \n",
"Q 1159 2438 1159 1844 \n",
"L 1159 0 \n",
"L 581 0 \n",
"L 581 3500 \n",
"L 1159 3500 \n",
"L 1159 2956 \n",
"Q 1341 3275 1631 3429 \n",
"Q 1922 3584 2338 3584 \n",
"Q 2397 3584 2469 3576 \n",
"Q 2541 3569 2628 3553 \n",
"L 2631 2963 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-69\" d=\"M 603 3500 \n",
"L 1178 3500 \n",
"L 1178 0 \n",
"L 603 0 \n",
"L 603 3500 \n",
"z\n",
"M 603 4863 \n",
"L 1178 4863 \n",
"L 1178 4134 \n",
"L 603 4134 \n",
"L 603 4863 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-51\"/>\n",
" <use xlink:href=\"#DejaVuSans-75\" x=\"78.710938\"/>\n",
" <use xlink:href=\"#DejaVuSans-65\" x=\"142.089844\"/>\n",
" <use xlink:href=\"#DejaVuSans-72\" x=\"203.613281\"/>\n",
" <use xlink:href=\"#DejaVuSans-69\" x=\"244.726562\"/>\n",
" <use xlink:href=\"#DejaVuSans-65\" x=\"272.509766\"/>\n",
" <use xlink:href=\"#DejaVuSans-73\" x=\"334.033203\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"patch_3\">\n",
" <path d=\"M 34.240625 118.8 \n",
"L 34.240625 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_4\">\n",
" <path d=\"M 145.840625 118.8 \n",
"L 145.840625 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_5\">\n",
" <path d=\"M 34.240625 118.8 \n",
"L 145.840625 118.8 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_6\">\n",
" <path d=\"M 34.240625 7.2 \n",
"L 145.840625 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"axes_2\">\n",
" <g id=\"patch_7\">\n",
" <path d=\"M 152.815625 103.77 \n",
"L 156.892625 103.77 \n",
"L 156.892625 22.23 \n",
"L 152.815625 22.23 \n",
"z\n",
"\" style=\"fill: #ffffff\"/>\n",
" </g>\n",
" <g id=\"patch_8\">\n",
" <path clip-path=\"url(#p81ac4143b7)\" style=\"fill: #ffffff; stroke: #ffffff; stroke-width: 0.01; stroke-linejoin: miter\"/>\n",
" </g>\n",
" <image xlink:href=\"data:image/png;base64,\n",
"iVBORw0KGgoAAAANSUhEUgAAAAQAAABRCAYAAAD1sgc6AAAAnklEQVR4nJ2Suw7CQAwEjZT//1QqCnR+0XKzJxmSLqPx7krJo1/Ptq/nsgzbQRVBDqBH4wCkhTtkWDhqR8OSIHjii6H/Z2jtaLD2EDpm3DCkNvldftgB0GJIqM/TBfAfE+AAJcaSpZJRg1Fs0QyCjN5BRA0gkycEmTyRFhox7simsb/blSbGDWCDcdgB4MxwGu8CWAJ4shgqJ9IiBms/jwXJt9gA8G4AAAAASUVORK5CYII=\" id=\"image52587ed27e\" transform=\"scale(1 -1)translate(0 -81)\" x=\"153\" y=\"-22\" width=\"4\" height=\"81\"/>\n",
" <g id=\"matplotlib.axis_3\">\n",
" <g id=\"ytick_6\">\n",
" <g id=\"line2d_8\">\n",
" <defs>\n",
" <path id=\"m4194a474b4\" d=\"M 0 0 \n",
"L 3.5 0 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#m4194a474b4\" x=\"156.892625\" y=\"103.77\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_10\">\n",
" <!-- 0.00 -->\n",
" <g transform=\"translate(163.892625 107.569219)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-2e\" d=\"M 684 794 \n",
"L 1344 794 \n",
"L 1344 0 \n",
"L 684 0 \n",
"L 684 794 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" x=\"95.410156\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" x=\"159.033203\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_7\">\n",
" <g id=\"line2d_9\">\n",
" <g>\n",
" <use xlink:href=\"#m4194a474b4\" x=\"156.892625\" y=\"83.385\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_11\">\n",
" <!-- 0.25 -->\n",
" <g transform=\"translate(163.892625 87.184219)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
" <use xlink:href=\"#DejaVuSans-32\" x=\"95.410156\"/>\n",
" <use xlink:href=\"#DejaVuSans-35\" x=\"159.033203\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_8\">\n",
" <g id=\"line2d_10\">\n",
" <g>\n",
" <use xlink:href=\"#m4194a474b4\" x=\"156.892625\" y=\"63\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_12\">\n",
" <!-- 0.50 -->\n",
" <g transform=\"translate(163.892625 66.799219)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
" <use xlink:href=\"#DejaVuSans-35\" x=\"95.410156\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" x=\"159.033203\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_9\">\n",
" <g id=\"line2d_11\">\n",
" <g>\n",
" <use xlink:href=\"#m4194a474b4\" x=\"156.892625\" y=\"42.615\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_13\">\n",
" <!-- 0.75 -->\n",
" <g transform=\"translate(163.892625 46.414219)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-37\" d=\"M 525 4666 \n",
"L 3525 4666 \n",
"L 3525 4397 \n",
"L 1831 0 \n",
"L 1172 0 \n",
"L 2766 4134 \n",
"L 525 4134 \n",
"L 525 4666 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
" <use xlink:href=\"#DejaVuSans-37\" x=\"95.410156\"/>\n",
" <use xlink:href=\"#DejaVuSans-35\" x=\"159.033203\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_10\">\n",
" <g id=\"line2d_12\">\n",
" <g>\n",
" <use xlink:href=\"#m4194a474b4\" x=\"156.892625\" y=\"22.23\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_14\">\n",
" <!-- 1.00 -->\n",
" <g transform=\"translate(163.892625 26.029219)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
"L 1825 531 \n",
"L 1825 4091 \n",
"L 703 3866 \n",
"L 703 4441 \n",
"L 1819 4666 \n",
"L 2450 4666 \n",
"L 2450 531 \n",
"L 3481 531 \n",
"L 3481 0 \n",
"L 794 0 \n",
"L 794 531 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
" <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" x=\"95.410156\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" x=\"159.033203\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"LineCollection_1\"/>\n",
" <g id=\"patch_9\">\n",
" <path d=\"M 152.815625 103.77 \n",
"L 154.854125 103.77 \n",
"L 156.892625 103.77 \n",
"L 156.892625 22.23 \n",
"L 154.854125 22.23 \n",
"L 152.815625 22.23 \n",
"L 152.815625 103.77 \n",
"z\n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <defs>\n",
" <clipPath id=\"p3a36c213e6\">\n",
" <rect x=\"34.240625\" y=\"7.2\" width=\"111.6\" height=\"111.6\"/>\n",
" </clipPath>\n",
" <clipPath id=\"p81ac4143b7\">\n",
" <rect x=\"152.815625\" y=\"22.23\" width=\"4.077\" height=\"81.54\"/>\n",
" </clipPath>\n",
" </defs>\n",
"</svg>\n"
],
"text/plain": [
"<Figure size 180x180 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"attention_weights = torch.eye(10).reshape((1, 1, 10, 10))\n",
"show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')"
]
},
{
"cell_type": "markdown",
"id": "0f6c23cb",
"metadata": {
"origin_pos": 9
},
"source": [
"后面的章节内容将经常调用`show_heatmaps`函数来显示注意力权重。\n",
"\n",
"## 小结\n",
"\n",
"* 人类的注意力是有限的、有价值和稀缺的资源。\n",
"* 受试者使用非自主性和自主性提示有选择性地引导注意力。前者基于突出性,后者则依赖于意识。\n",
"* 注意力机制与全连接层或者汇聚层的区别源于增加的自主提示。\n",
"* 由于包含了自主性提示,注意力机制与全连接的层或汇聚层不同。\n",
"* 注意力机制通过注意力汇聚使选择偏向于值(感官输入),其中包含查询(自主性提示)和键(非自主性提示)。键和值是成对的。\n",
"* 可视化查询和键之间的注意力权重是可行的。\n",
"\n",
"## 练习\n",
"\n",
"1. 在机器翻译中通过解码序列词元时,其自主性提示可能是什么?非自主性提示和感官输入又是什么?\n",
"1. 随机生成一个$10 \\times 10$矩阵并使用`softmax`运算来确保每行都是有效的概率分布,然后可视化输出注意力权重。\n"
]
},
{
"cell_type": "markdown",
"id": "675bab48",
"metadata": {
"origin_pos": 11,
"tab": [
"pytorch"
]
},
"source": [
"[Discussions](https://discuss.d2l.ai/t/5764)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"required_libs": []
},
"nbformat": 4,
"nbformat_minor": 5
}