{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "36341c1f",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# Multi-Head Attention\n",
    ":label:`sec_multihead-attention`\n",
    "\n",
    "In practice, given the same set of queries, keys, and values,\n",
    "we may want our model to learn different behaviors\n",
    "based on the same attention mechanism,\n",
    "and then combine these behaviors as knowledge to capture\n",
    "dependencies of various ranges within a sequence\n",
    "(e.g., shorter-range vs. longer-range dependencies).\n",
    "Thus, it may be beneficial to allow our attention mechanism\n",
    "to jointly use different\n",
    "*representation subspaces* of queries, keys, and values.\n",
    "\n",
    "To this end, instead of performing a single attention pooling,\n",
    "queries, keys, and values can be transformed with $h$\n",
    "independently learned *linear projections*.\n",
    "Then these $h$ projected queries, keys, and values\n",
    "are fed into attention pooling in parallel.\n",
    "In the end, the $h$ attention-pooling outputs are concatenated\n",
    "and transformed with another learnable linear projection\n",
    "to produce the final output.\n",
    "This design is called *multi-head attention*\n",
    ":cite:`Vaswani.Shazeer.Parmar.ea.2017`,\n",
    "and each of the $h$ attention-pooling outputs is called a *head*.\n",
    ":numref:`fig_multi-head-attention`\n",
    "shows multi-head attention, using fully connected layers\n",
    "to implement the learnable linear transformations.\n",
    "\n",
    "(figure: multi-head attention, using fully connected layers to implement the learnable linear transformations)\n",
    ":label:`fig_multi-head-attention`\n",
    "\n",
    "## Model\n",
    "\n",
    "Before providing the implementation of multi-head attention,\n",
    "let us formalize this model mathematically.\n",
    "Given a query $\\mathbf{q} \\in \\mathbb{R}^{d_q}$,\n",
    "a key $\\mathbf{k} \\in \\mathbb{R}^{d_k}$,\n",
    "and a value $\\mathbf{v} \\in \\mathbb{R}^{d_v}$,\n",
    "each attention head $\\mathbf{h}_i$ ($i = 1, \\ldots, h$)\n",
    "is computed as\n",
    "\n",
    "$$\\mathbf{h}_i = f(\\mathbf W_i^{(q)}\\mathbf q, \\mathbf W_i^{(k)}\\mathbf k,\\mathbf W_i^{(v)}\\mathbf v) \\in \\mathbb R^{p_v},$$\n",
    "\n",
    "where the learnable parameters are\n",
    "$\\mathbf W_i^{(q)}\\in\\mathbb R^{p_q\\times d_q}$,\n",
    "$\\mathbf W_i^{(k)}\\in\\mathbb R^{p_k\\times d_k}$, and\n",
    "$\\mathbf W_i^{(v)}\\in\\mathbb R^{p_v\\times d_v}$,\n",
    "together with $f$, the attention pooling,\n",
    "which can be additive attention or scaled dot-product attention\n",
    "from :numref:`sec_attention-scoring-functions`.\n",
    "The multi-head attention output is another linear transformation\n",
    "of the concatenation of the $h$ heads,\n",
    "so its learnable parameters are\n",
    "$\\mathbf W_o\\in\\mathbb R^{p_o\\times h p_v}$:\n",
    "\n",
    "$$\\mathbf W_o \\begin{bmatrix}\\mathbf h_1\\\\\\vdots\\\\\\mathbf h_h\\end{bmatrix} \\in \\mathbb{R}^{p_o}.$$\n",
    "\n",
    "Based on this design, each head may attend to different parts of the input,\n",
    "so functions more complex than a simple weighted average can be expressed.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "dc55ba33",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:01:32.189972Z",
     "iopub.status.busy": "2023-08-18T07:01:32.189240Z",
     "iopub.status.idle": "2023-08-18T07:01:34.516491Z",
     "shell.execute_reply": "2023-08-18T07:01:34.515475Z"
    },
    "origin_pos": 2,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l"
   ]
  },
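  {
   "cell_type": "markdown",
   "id": "3fa1b2c4",
   "metadata": {},
   "source": [
    "Before turning to the vectorized implementation, the following is a minimal, loop-based sketch of the model above: each head $i$ applies its own projections and attention pooling $f$, and $\\mathbf W_o$ maps the concatenated heads to the final output. All dimensions here are small arbitrary choices for illustration, and $f$ is taken to be scaled dot-product attention; this sketch is not the implementation used in the rest of this section.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c9d0e1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A literal rendering of the math above (illustration only)\n",
    "h, m = 2, 4                      # no. of heads, no. of key-value pairs\n",
    "d_q = d_k = d_v = 6              # input dimensions\n",
    "p_q = p_k = p_v = 3              # per-head projection dimensions\n",
    "p_o = 5                          # output dimension\n",
    "q = torch.randn(d_q)             # a single query\n",
    "K, V = torch.randn(m, d_k), torch.randn(m, d_v)\n",
    "W_q = [torch.randn(p_q, d_q) for _ in range(h)]\n",
    "W_k = [torch.randn(p_k, d_k) for _ in range(h)]\n",
    "W_v = [torch.randn(p_v, d_v) for _ in range(h)]\n",
    "W_o = torch.randn(p_o, h * p_v)\n",
    "\n",
    "def f(q_i, K_i, V_i):\n",
    "    # Scaled dot-product attention: weights over the m key-value\n",
    "    # pairs, then a weighted sum of the projected values\n",
    "    scores = K_i @ q_i / math.sqrt(q_i.shape[0])\n",
    "    return torch.softmax(scores, dim=0) @ V_i\n",
    "\n",
    "heads = [f(W_q[i] @ q, K @ W_k[i].T, V @ W_v[i].T) for i in range(h)]\n",
    "W_o @ torch.cat(heads)           # final output, of shape (p_o,)"
   ]
  },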
  {
   "cell_type": "markdown",
   "id": "b51ca181",
   "metadata": {
    "origin_pos": 5
   },
   "source": [
    "## Implementation\n",
    "\n",
    "In our implementation, we [**choose scaled dot-product attention for each head**].\n",
    "To avoid significant growth of computational cost and parameterization cost,\n",
    "we set $p_q = p_k = p_v = p_o / h$.\n",
    "Note that $h$ heads can be computed in parallel\n",
    "if we set the numbers of outputs of the linear transformations\n",
    "for the query, key, and value to\n",
    "$p_q h = p_k h = p_v h = p_o$.\n",
    "In the following implementation,\n",
    "$p_o$ is specified via the argument `num_hiddens`.\n"
   ]
  },
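  {
   "cell_type": "markdown",
   "id": "5d6e7f80",
   "metadata": {},
   "source": [
    "A small sanity check on this dimension bookkeeping (using the same `num_hiddens` and `num_heads` values as the test later in this section): since each head operates in $p_o / h$ dimensions, `num_hiddens` must be divisible by `num_heads`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a8b7c6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each head works in num_hiddens / num_heads dimensions\n",
    "num_hiddens, num_heads = 100, 5\n",
    "assert num_hiddens % num_heads == 0\n",
    "num_hiddens // num_heads  # per-head dimension p_q = p_k = p_v"
   ]
  },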
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1bb10990",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:01:34.521491Z",
     "iopub.status.busy": "2023-08-18T07:01:34.521131Z",
     "iopub.status.idle": "2023-08-18T07:01:34.530492Z",
     "shell.execute_reply": "2023-08-18T07:01:34.529556Z"
    },
    "origin_pos": 7,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#@save\n",
    "class MultiHeadAttention(nn.Module):\n",
    "    \"\"\"Multi-head attention\"\"\"\n",
    "    def __init__(self, key_size, query_size, value_size, num_hiddens,\n",
    "                 num_heads, dropout, bias=False, **kwargs):\n",
    "        super(MultiHeadAttention, self).__init__(**kwargs)\n",
    "        self.num_heads = num_heads\n",
    "        self.attention = d2l.DotProductAttention(dropout)\n",
    "        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)\n",
    "        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)\n",
    "        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)\n",
    "        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)\n",
    "\n",
    "    def forward(self, queries, keys, values, valid_lens):\n",
    "        # Shape of queries, keys, or values:\n",
    "        # (batch_size, no. of queries or key-value pairs, num_hiddens)\n",
    "        # Shape of valid_lens:\n",
    "        # (batch_size,) or (batch_size, no. of queries)\n",
    "        # After transposing, shape of output queries, keys, or values:\n",
    "        # (batch_size * num_heads, no. of queries or key-value pairs,\n",
    "        # num_hiddens / num_heads)\n",
    "        queries = transpose_qkv(self.W_q(queries), self.num_heads)\n",
    "        keys = transpose_qkv(self.W_k(keys), self.num_heads)\n",
    "        values = transpose_qkv(self.W_v(values), self.num_heads)\n",
    "\n",
    "        if valid_lens is not None:\n",
    "            # On axis 0, copy the first item (scalar or vector)\n",
    "            # num_heads times, then copy the next item, and so on\n",
    "            valid_lens = torch.repeat_interleave(\n",
    "                valid_lens, repeats=self.num_heads, dim=0)\n",
    "\n",
    "        # Shape of output: (batch_size * num_heads, no. of queries,\n",
    "        # num_hiddens / num_heads)\n",
    "        output = self.attention(queries, keys, values, valid_lens)\n",
    "\n",
    "        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)\n",
    "        output_concat = transpose_output(output, self.num_heads)\n",
    "        return self.W_o(output_concat)"
   ]
  },
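  {
   "cell_type": "markdown",
   "id": "2b3c4d5e",
   "metadata": {},
   "source": [
    "One subtle step above is the `torch.repeat_interleave` call: after the transposition, axis 0 holds `batch_size * num_heads` slices, with the `num_heads` slices of each example adjacent, so each example's valid length must be repeated `num_heads` times in a row. A tiny illustration with made-up numbers:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f7a8b9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# With batch_size=2 and num_heads=3, every head of an example shares\n",
    "# the same valid length; expected result: tensor([3, 3, 3, 2, 2, 2])\n",
    "torch.repeat_interleave(torch.tensor([3, 2]), repeats=3, dim=0)"
   ]
  },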
  {
   "cell_type": "markdown",
   "id": "9ab1c33b",
   "metadata": {
    "origin_pos": 10
   },
   "source": [
    "To allow for [**parallel computation of multiple heads**],\n",
    "the above `MultiHeadAttention` class uses the two transposition functions defined below.\n",
    "Specifically, the `transpose_output` function reverses the operation\n",
    "of the `transpose_qkv` function.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b2af5ed8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:01:34.534820Z",
     "iopub.status.busy": "2023-08-18T07:01:34.534308Z",
     "iopub.status.idle": "2023-08-18T07:01:34.540852Z",
     "shell.execute_reply": "2023-08-18T07:01:34.539927Z"
    },
    "origin_pos": 12,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#@save\n",
    "def transpose_qkv(X, num_heads):\n",
    "    \"\"\"Transposition for parallel computation of multiple attention heads\"\"\"\n",
    "    # Shape of input X:\n",
    "    # (batch_size, no. of queries or key-value pairs, num_hiddens)\n",
    "    # Shape of output X:\n",
    "    # (batch_size, no. of queries or key-value pairs, num_heads,\n",
    "    # num_hiddens / num_heads)\n",
    "    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)\n",
    "\n",
    "    # Shape of output X:\n",
    "    # (batch_size, num_heads, no. of queries or key-value pairs,\n",
    "    # num_hiddens / num_heads)\n",
    "    X = X.permute(0, 2, 1, 3)\n",
    "\n",
    "    # Shape of final output:\n",
    "    # (batch_size * num_heads, no. of queries or key-value pairs,\n",
    "    # num_hiddens / num_heads)\n",
    "    return X.reshape(-1, X.shape[2], X.shape[3])\n",
    "\n",
    "\n",
    "#@save\n",
    "def transpose_output(X, num_heads):\n",
    "    \"\"\"Reverse the operation of transpose_qkv\"\"\"\n",
    "    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])\n",
    "    X = X.permute(0, 2, 1, 3)\n",
    "    return X.reshape(X.shape[0], X.shape[1], -1)"
   ]
  },
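  {
   "cell_type": "markdown",
   "id": "4e5f6a7b",
   "metadata": {},
   "source": [
    "As a quick sanity check (added here with arbitrary shapes), a round trip through `transpose_qkv` and `transpose_output` should reproduce the input tensor exactly:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c9d0e1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# transpose_output should undo transpose_qkv exactly\n",
    "X = torch.randn(2, 4, 100)      # (batch_size, no. of queries, num_hiddens)\n",
    "Y = transpose_qkv(X, 5)         # shape (2 * 5, 4, 20)\n",
    "Z = transpose_output(Y, 5)      # back to shape (2, 4, 100)\n",
    "Y.shape, Z.shape, torch.equal(X, Z)"
   ]
  },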
  {
   "cell_type": "markdown",
   "id": "0e31b376",
   "metadata": {
    "origin_pos": 15
   },
   "source": [
    "Let us [**test**] our implemented `MultiHeadAttention` class\n",
    "using a toy example where keys and values are the same.\n",
    "The multi-head attention output has the shape\n",
    "(`batch_size`, `num_queries`, `num_hiddens`).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d06baadf",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:01:34.545405Z",
     "iopub.status.busy": "2023-08-18T07:01:34.544605Z",
     "iopub.status.idle": "2023-08-18T07:01:34.571251Z",
     "shell.execute_reply": "2023-08-18T07:01:34.570476Z"
    },
    "origin_pos": 17,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MultiHeadAttention(\n",
       "  (attention): DotProductAttention(\n",
       "    (dropout): Dropout(p=0.5, inplace=False)\n",
       "  )\n",
       "  (W_q): Linear(in_features=100, out_features=100, bias=False)\n",
       "  (W_k): Linear(in_features=100, out_features=100, bias=False)\n",
       "  (W_v): Linear(in_features=100, out_features=100, bias=False)\n",
       "  (W_o): Linear(in_features=100, out_features=100, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "num_hiddens, num_heads = 100, 5\n",
    "attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,\n",
    "                               num_hiddens, num_heads, 0.5)\n",
    "attention.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8da65afc",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:01:34.574642Z",
     "iopub.status.busy": "2023-08-18T07:01:34.574021Z",
     "iopub.status.idle": "2023-08-18T07:01:34.588848Z",
     "shell.execute_reply": "2023-08-18T07:01:34.587945Z"
    },
    "origin_pos": 20,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 4, 100])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_size, num_queries = 2, 4\n",
    "num_kvpairs, valid_lens = 6, torch.tensor([3, 2])\n",
    "X = torch.ones((batch_size, num_queries, num_hiddens))\n",
    "Y = torch.ones((batch_size, num_kvpairs, num_hiddens))\n",
    "attention(X, Y, Y, valid_lens).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c228d916",
   "metadata": {
    "origin_pos": 22
   },
   "source": [
    "## Summary\n",
    "\n",
    "* Multi-head attention combines knowledge from multiple attention poolings; this knowledge differs because the same queries, keys, and values are projected into different representation subspaces.\n",
    "* Based on appropriate tensor manipulation, the heads of multi-head attention can be computed in parallel.\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. Visualize the attention weights of each of the multiple heads in this experiment.\n",
    "1. Suppose that we have a trained model based on multi-head attention and we want to prune the least important attention heads to increase prediction speed. How can we design experiments to measure the importance of an attention head?\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bfae5c77",
   "metadata": {
    "origin_pos": 24,
    "tab": [
     "pytorch"
    ]
   },
   "source": [
    "[Discussions](https://discuss.d2l.ai/t/5758)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "required_libs": []
 },
 "nbformat": 4,
 "nbformat_minor": 5
}