{ "cells": [ { "cell_type": "markdown", "id": "6709af5b", "metadata": { "origin_pos": 0 }, "source": [ "# Bahdanau 注意力\n", ":label:`sec_seq2seq_attention`\n", "\n", " :numref:`sec_seq2seq`中探讨了机器翻译问题:\n", "通过设计一个基于两个循环神经网络的编码器-解码器架构,\n", "用于序列到序列学习。\n", "具体来说,循环神经网络编码器将长度可变的序列转换为固定形状的上下文变量,\n", "然后循环神经网络解码器根据生成的词元和上下文变量\n", "按词元生成输出(目标)序列词元。\n", "然而,即使并非所有输入(源)词元都对解码某个词元都有用,\n", "在每个解码步骤中仍使用编码*相同*的上下文变量。\n", "有什么方法能改变上下文变量呢?\n", "\n", "我们试着从 :cite:`Graves.2013`中找到灵感:\n", "在为给定文本序列生成手写的挑战中,\n", "Graves设计了一种可微注意力模型,\n", "将文本字符与更长的笔迹对齐,\n", "其中对齐方式仅向一个方向移动。\n", "受学习对齐想法的启发,Bahdanau等人提出了一个没有严格单向对齐限制的\n", "可微注意力模型 :cite:`Bahdanau.Cho.Bengio.2014`。\n", "在预测词元时,如果不是所有输入词元都相关,模型将仅对齐(或参与)输入序列中与当前预测相关的部分。这是通过将上下文变量视为注意力集中的输出来实现的。\n", "\n", "## 模型\n", "\n", "下面描述的Bahdanau注意力模型\n", "将遵循 :numref:`sec_seq2seq`中的相同符号表达。\n", "这个新的基于注意力的模型与 :numref:`sec_seq2seq`中的模型相同,\n", "只不过 :eqref:`eq_seq2seq_s_t`中的上下文变量$\\mathbf{c}$\n", "在任何解码时间步$t'$都会被$\\mathbf{c}_{t'}$替换。\n", "假设输入序列中有$T$个词元,\n", "解码时间步$t'$的上下文变量是注意力集中的输出:\n", "\n", "$$\\mathbf{c}_{t'} = \\sum_{t=1}^T \\alpha(\\mathbf{s}_{t' - 1}, \\mathbf{h}_t) \\mathbf{h}_t,$$\n", "\n", "其中,时间步$t' - 1$时的解码器隐状态$\\mathbf{s}_{t' - 1}$是查询,\n", "编码器隐状态$\\mathbf{h}_t$既是键,也是值,\n", "注意力权重$\\alpha$是使用 :eqref:`eq_attn-scoring-alpha`\n", "所定义的加性注意力打分函数计算的。\n", "\n", "与 :numref:`fig_seq2seq_details`中的循环神经网络编码器-解码器架构略有不同,\n", " :numref:`fig_s2s_attention_details`描述了Bahdanau注意力的架构。\n", "\n", "![一个带有Bahdanau注意力的循环神经网络编码器-解码器模型](../img/seq2seq-attention-details.svg)\n", ":label:`fig_s2s_attention_details`\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "578eec9e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:15:29.787913Z", "iopub.status.busy": "2023-08-18T07:15:29.787398Z", "iopub.status.idle": "2023-08-18T07:15:31.837042Z", "shell.execute_reply": "2023-08-18T07:15:31.836075Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "9b4707cd", "metadata": { "origin_pos": 5 }, "source": [ "## 定义注意力解码器\n", "\n", "下面看看如何定义Bahdanau注意力,实现循环神经网络编码器-解码器。\n", "其实,我们只需重新定义解码器即可。\n", "为了更方便地显示学习的注意力权重,\n", "以下`AttentionDecoder`类定义了[**带有注意力机制解码器的基本接口**]。\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "ca599f83", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:15:31.875080Z", "iopub.status.busy": "2023-08-18T07:15:31.874396Z", "iopub.status.idle": "2023-08-18T07:15:31.879397Z", "shell.execute_reply": "2023-08-18T07:15:31.878617Z" }, "origin_pos": 6, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "class AttentionDecoder(d2l.Decoder):\n", " \"\"\"带有注意力机制解码器的基本接口\"\"\"\n", " def __init__(self, **kwargs):\n", " super(AttentionDecoder, self).__init__(**kwargs)\n", "\n", " @property\n", " def attention_weights(self):\n", " raise NotImplementedError" ] }, { "cell_type": "markdown", "id": "fa6f1f6d", "metadata": { "origin_pos": 7 }, "source": [ "接下来,让我们在接下来的`Seq2SeqAttentionDecoder`类中\n", "[**实现带有Bahdanau注意力的循环神经网络解码器**]。\n", "首先,初始化解码器的状态,需要下面的输入:\n", "\n", "1. 编码器在所有时间步的最终层隐状态,将作为注意力的键和值;\n", "1. 上一时间步的编码器全层隐状态,将作为初始化解码器的隐状态;\n", "1. 编码器有效长度(排除在注意力池中填充词元)。\n", "\n", "在每个解码时间步骤中,解码器上一个时间步的最终层隐状态将用作查询。\n", "因此,注意力输出和输入嵌入都连结为循环神经网络解码器的输入。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "1d21b004", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:15:31.883127Z", "iopub.status.busy": "2023-08-18T07:15:31.882496Z", "iopub.status.idle": "2023-08-18T07:15:31.892223Z", "shell.execute_reply": "2023-08-18T07:15:31.891475Z" }, "origin_pos": 9, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class Seq2SeqAttentionDecoder(AttentionDecoder):\n", " def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,\n", " dropout=0, **kwargs):\n", " super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)\n", " self.attention = d2l.AdditiveAttention(\n", " num_hiddens, num_hiddens, num_hiddens, dropout)\n", " self.embedding = nn.Embedding(vocab_size, embed_size)\n", " self.rnn = nn.GRU(\n", " embed_size + num_hiddens, num_hiddens, num_layers,\n", " dropout=dropout)\n", " self.dense = nn.Linear(num_hiddens, vocab_size)\n", "\n", " def init_state(self, enc_outputs, enc_valid_lens, *args):\n", " # outputs的形状为(batch_size,num_steps,num_hiddens).\n", " # hidden_state的形状为(num_layers,batch_size,num_hiddens)\n", " outputs, hidden_state = enc_outputs\n", " return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)\n", "\n", " def forward(self, X, state):\n", " # enc_outputs的形状为(batch_size,num_steps,num_hiddens).\n", " # hidden_state的形状为(num_layers,batch_size,\n", " # num_hiddens)\n", " enc_outputs, hidden_state, enc_valid_lens = state\n", " # 输出X的形状为(num_steps,batch_size,embed_size)\n", " X = self.embedding(X).permute(1, 0, 2)\n", " outputs, self._attention_weights = [], []\n", " for x in X:\n", " # query的形状为(batch_size,1,num_hiddens)\n", " query = torch.unsqueeze(hidden_state[-1], dim=1)\n", " # context的形状为(batch_size,1,num_hiddens)\n", " context = self.attention(\n", " query, enc_outputs, enc_outputs, enc_valid_lens)\n", " # 在特征维度上连结\n", " x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)\n", " # 将x变形为(1,batch_size,embed_size+num_hiddens)\n", " out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)\n", " outputs.append(out)\n", " self._attention_weights.append(self.attention.attention_weights)\n", " # 全连接层变换后,outputs的形状为\n", " # (num_steps,batch_size,vocab_size)\n", " outputs = self.dense(torch.cat(outputs, dim=0))\n", " return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,\n", " enc_valid_lens]\n", "\n", " @property\n", " def attention_weights(self):\n", " return self._attention_weights" ] }, { "cell_type": "markdown", "id": "a1d2dbea", "metadata": { "origin_pos": 12 }, "source": [ "接下来,使用包含7个时间步的4个序列输入的小批量[**测试Bahdanau注意力解码器**]。\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "9cd7ffa9", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:15:31.895756Z", "iopub.status.busy": "2023-08-18T07:15:31.895141Z", "iopub.status.idle": "2023-08-18T07:15:31.935399Z", "shell.execute_reply": "2023-08-18T07:15:31.934579Z" }, "origin_pos": 14, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,\n", " num_layers=2)\n", "encoder.eval()\n", "decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16,\n", " num_layers=2)\n", "decoder.eval()\n", "X = torch.zeros((4, 7), dtype=torch.long) # (batch_size,num_steps)\n", "state = decoder.init_state(encoder(X), None)\n", "output, state = decoder(X, state)\n", "output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape" ] }, { "cell_type": "markdown", "id": "5d9a6381", "metadata": { "origin_pos": 17 }, "source": [ "## [**训练**]\n", "\n", "与 :numref:`sec_seq2seq_training`类似,\n", "我们在这里指定超参数,实例化一个带有Bahdanau注意力的编码器和解码器,\n", "并对这个模型进行机器翻译训练。\n", "由于新增的注意力机制,训练要比没有注意力机制的\n", " :numref:`sec_seq2seq_training`慢得多。\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "44dd619c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:15:31.939197Z", "iopub.status.busy": "2023-08-18T07:15:31.938585Z", "iopub.status.idle": "2023-08-18T07:17:05.733764Z", "shell.execute_reply": "2023-08-18T07:17:05.732879Z" }, "origin_pos": 18, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss 0.021, 4948.7 tokens/sec on cuda:0\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:17:05.700844\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1\n", "batch_size, num_steps = 64, 10\n", "lr, num_epochs, device = 0.005, 250, d2l.try_gpu()\n", "\n", "train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)\n", "encoder = d2l.Seq2SeqEncoder(\n", " len(src_vocab), embed_size, num_hiddens, num_layers, dropout)\n", "decoder = Seq2SeqAttentionDecoder(\n", " len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)\n", "net = d2l.EncoderDecoder(encoder, decoder)\n", "d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)" ] }, { "cell_type": "markdown", "id": "9b038fc0", "metadata": { "origin_pos": 19 }, "source": [ "模型训练后,我们用它[**将几个英语句子翻译成法语**]并计算它们的BLEU分数。\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "b449b8a1", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:05.738743Z", "iopub.status.busy": "2023-08-18T07:17:05.738202Z", "iopub.status.idle": "2023-08-18T07:17:05.773736Z", "shell.execute_reply": "2023-08-18T07:17:05.772225Z" }, "origin_pos": 20, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "go . => va !, bleu 1.000\n", "i lost . => j'ai perdu ., bleu 1.000\n", "he's calm . => il est paresseux ., bleu 0.658\n", "i'm home . => je suis chez moi ., bleu 1.000\n" ] } ], "source": [ "engs = ['go .', \"i lost .\", 'he\\'s calm .', 'i\\'m home .']\n", "fras = ['va !', 'j\\'ai perdu .', 'il est calme .', 'je suis chez moi .']\n", "for eng, fra in zip(engs, fras):\n", " translation, dec_attention_weight_seq = d2l.predict_seq2seq(\n", " net, eng, src_vocab, tgt_vocab, num_steps, device, True)\n", " print(f'{eng} => {translation}, ',\n", " f'bleu {d2l.bleu(translation, fra, k=2):.3f}')" ] }, { "cell_type": "code", "execution_count": 7, "id": "703a029f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:05.780446Z", "iopub.status.busy": "2023-08-18T07:17:05.779931Z", "iopub.status.idle": "2023-08-18T07:17:05.800143Z", "shell.execute_reply": "2023-08-18T07:17:05.798893Z" }, "origin_pos": 22, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "attention_weights = torch.cat([step[0][0][0] for step in dec_attention_weight_seq], 0).reshape((\n", " 1, 1, -1, num_steps))" ] }, { "cell_type": "markdown", "id": "8bb0b7ce", "metadata": { "origin_pos": 23 }, "source": [ "训练结束后,下面通过[**可视化注意力权重**]\n", "会发现,每个查询都会在键值对上分配不同的权重,这说明\n", "在每个解码步中,输入序列的不同部分被选择性地聚集在注意力池中。\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "8074a1a6", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:05.806859Z", "iopub.status.busy": "2023-08-18T07:17:05.805665Z", "iopub.status.idle": "2023-08-18T07:17:06.012470Z", "shell.execute_reply": "2023-08-18T07:17:06.011495Z" }, "origin_pos": 25, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:17:05.968894\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 加上一个包含序列结束词元\n", "d2l.show_heatmaps(\n", " attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(),\n", " xlabel='Key positions', ylabel='Query positions')" ] }, { "cell_type": "markdown", "id": "30eb6f85", "metadata": { "origin_pos": 27 }, "source": [ "## 小结\n", "\n", "* 在预测词元时,如果不是所有输入词元都是相关的,那么具有Bahdanau注意力的循环神经网络编码器-解码器会有选择地统计输入序列的不同部分。这是通过将上下文变量视为加性注意力池化的输出来实现的。\n", "* 在循环神经网络编码器-解码器中,Bahdanau注意力将上一时间步的解码器隐状态视为查询,在所有时间步的编码器隐状态同时视为键和值。\n", "\n", "## 练习\n", "\n", "1. 在实验中用LSTM替换GRU。\n", "1. 修改实验以将加性注意力打分函数替换为缩放点积注意力,它如何影响训练效率?\n" ] }, { "cell_type": "markdown", "id": "4d7906bc", "metadata": { "origin_pos": 29, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/5754)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }