{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "6709af5b",
|
||
"metadata": {
|
||
"origin_pos": 0
|
||
},
|
||
"source": [
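"# Bahdanau Attention\n",
":label:`sec_seq2seq_attention`\n",
"\n",
"In :numref:`sec_seq2seq` we studied the machine translation problem\n",
"by designing an encoder-decoder architecture based on two RNNs\n",
"for sequence-to-sequence learning.\n",
"Specifically, the RNN encoder transforms a variable-length sequence into a fixed-shape context variable,\n",
"and the RNN decoder then generates the output (target) sequence token by token\n",
"based on the generated tokens and the context variable.\n",
"However, even though not all input (source) tokens are useful for decoding a given token,\n",
"the *same* context variable, which encodes the entire input sequence, is still used at every decoding step.\n",
"Is there a way to let the context variable change from step to step?\n",
"\n",
"We look for inspiration in :cite:`Graves.2013`:\n",
"for the challenge of generating handwriting for a given text sequence,\n",
"Graves designed a differentiable attention model\n",
"that aligns text characters with the much longer pen trace,\n",
"where the alignment moves in only one direction.\n",
"Inspired by the idea of learning to align, Bahdanau et al. proposed a differentiable attention model\n",
"without the strict unidirectional alignment limitation :cite:`Bahdanau.Cho.Bengio.2014`.\n",
"When predicting a token, if not all input tokens are relevant, the model aligns with (or attends to) only the parts of the input sequence that are relevant to the current prediction. This is achieved by treating the context variable as the output of attention pooling.\n",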
|
||
"\n",
"## Model\n",
"\n",
"The Bahdanau attention model described below\n",
"follows the same notation as :numref:`sec_seq2seq`.\n",
"The new attention-based model is the same as the one in :numref:`sec_seq2seq`,\n",
"except that the context variable $\\mathbf{c}$ in :eqref:`eq_seq2seq_s_t`\n",
"is replaced by $\\mathbf{c}_{t'}$ at every decoding time step $t'$.\n",
"Suppose the input sequence has $T$ tokens.\n",
"The context variable at decoding time step $t'$ is the output of attention pooling:\n",
"\n",
"$$\\mathbf{c}_{t'} = \\sum_{t=1}^T \\alpha(\\mathbf{s}_{t' - 1}, \\mathbf{h}_t) \\mathbf{h}_t,$$\n",
"\n",
"where the decoder hidden state $\\mathbf{s}_{t' - 1}$ at time step $t' - 1$ is the query,\n",
"the encoder hidden states $\\mathbf{h}_t$ are both the keys and the values,\n",
"and the attention weight $\\alpha$ is computed with the additive attention scoring function\n",
"defined by :eqref:`eq_attn-scoring-alpha`.\n",
"\n",
"Slightly different from the RNN encoder-decoder architecture in :numref:`fig_seq2seq_details`,\n",
"the architecture with Bahdanau attention is depicted in :numref:`fig_s2s_attention_details`.\n",
"\n",
"\n",
":label:`fig_s2s_attention_details`\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"id": "578eec9e",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:15:29.787913Z",
|
||
"iopub.status.busy": "2023-08-18T07:15:29.787398Z",
|
||
"iopub.status.idle": "2023-08-18T07:15:31.837042Z",
|
||
"shell.execute_reply": "2023-08-18T07:15:31.836075Z"
|
||
},
|
||
"origin_pos": 2,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"import torch\n",
|
||
"from torch import nn\n",
|
||
"from d2l import torch as d2l"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "9b4707cd",
|
||
"metadata": {
|
||
"origin_pos": 5
|
||
},
|
||
"source": [
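"## Defining the Decoder with Attention\n",
"\n",
"Let us see how to define Bahdanau attention and implement the RNN encoder-decoder.\n",
"In fact, we only need to redefine the decoder.\n",
"To make it easier to visualize the learned attention weights,\n",
"the following `AttentionDecoder` class defines [**the base interface for decoders with attention mechanisms**].\n"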
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"id": "ca599f83",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:15:31.875080Z",
|
||
"iopub.status.busy": "2023-08-18T07:15:31.874396Z",
|
||
"iopub.status.idle": "2023-08-18T07:15:31.879397Z",
|
||
"shell.execute_reply": "2023-08-18T07:15:31.878617Z"
|
||
},
|
||
"origin_pos": 6,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [],
|
||
"source": [
"#@save\n",
"class AttentionDecoder(d2l.Decoder):\n",
"    \"\"\"The base interface for decoders with attention mechanisms.\"\"\"\n",
"    def __init__(self, **kwargs):\n",
"        super(AttentionDecoder, self).__init__(**kwargs)\n",
"\n",
"    @property\n",
"    def attention_weights(self):\n",
"        raise NotImplementedError"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "fa6f1f6d",
|
||
"metadata": {
|
||
"origin_pos": 7
|
||
},
|
||
"source": [
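"Next, let us [**implement the RNN decoder with Bahdanau attention**]\n",
"in the following `Seq2SeqAttentionDecoder` class.\n",
"Initializing the decoder state requires the following inputs:\n",
"\n",
"1. the encoder's final-layer hidden states at all time steps, which serve as the keys and values of the attention;\n",
"1. the encoder's hidden state for all layers at the final time step, which initializes the hidden state of the decoder;\n",
"1. the encoder valid lengths, which exclude padding tokens from attention pooling.\n",
"\n",
"At each decoding time step, the decoder's final-layer hidden state from the previous time step is used as the query.\n",
"The attention output and the input embedding are then concatenated to form the input of the RNN decoder.\n"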
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"id": "1d21b004",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:15:31.883127Z",
|
||
"iopub.status.busy": "2023-08-18T07:15:31.882496Z",
|
||
"iopub.status.idle": "2023-08-18T07:15:31.892223Z",
|
||
"shell.execute_reply": "2023-08-18T07:15:31.891475Z"
|
||
},
|
||
"origin_pos": 9,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [],
|
||
"source": [
"class Seq2SeqAttentionDecoder(AttentionDecoder):\n",
"    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,\n",
"                 dropout=0, **kwargs):\n",
"        super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)\n",
"        self.attention = d2l.AdditiveAttention(\n",
"            num_hiddens, num_hiddens, num_hiddens, dropout)\n",
"        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
"        self.rnn = nn.GRU(\n",
"            embed_size + num_hiddens, num_hiddens, num_layers,\n",
"            dropout=dropout)\n",
"        self.dense = nn.Linear(num_hiddens, vocab_size)\n",
"\n",
"    def init_state(self, enc_outputs, enc_valid_lens, *args):\n",
"        # Shape of the encoder outputs: (num_steps, batch_size, num_hiddens);\n",
"        # they are permuted to (batch_size, num_steps, num_hiddens) below.\n",
"        # Shape of hidden_state: (num_layers, batch_size, num_hiddens)\n",
"        outputs, hidden_state = enc_outputs\n",
"        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)\n",
"\n",
"    def forward(self, X, state):\n",
"        # Shape of enc_outputs: (batch_size, num_steps, num_hiddens).\n",
"        # Shape of hidden_state: (num_layers, batch_size,\n",
"        # num_hiddens)\n",
"        enc_outputs, hidden_state, enc_valid_lens = state\n",
"        # Shape of the output X: (num_steps, batch_size, embed_size)\n",
"        X = self.embedding(X).permute(1, 0, 2)\n",
"        outputs, self._attention_weights = [], []\n",
"        for x in X:\n",
"            # Shape of query: (batch_size, 1, num_hiddens)\n",
"            query = torch.unsqueeze(hidden_state[-1], dim=1)\n",
"            # Shape of context: (batch_size, 1, num_hiddens)\n",
"            context = self.attention(\n",
"                query, enc_outputs, enc_outputs, enc_valid_lens)\n",
"            # Concatenate on the feature dimension\n",
"            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)\n",
"            # Reshape x to (1, batch_size, embed_size + num_hiddens)\n",
"            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)\n",
"            outputs.append(out)\n",
"            self._attention_weights.append(self.attention.attention_weights)\n",
"        # After the fully connected layer, the shape of outputs is\n",
"        # (num_steps, batch_size, vocab_size)\n",
"        outputs = self.dense(torch.cat(outputs, dim=0))\n",
"        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,\n",
"                                          enc_valid_lens]\n",
"\n",
"    @property\n",
"    def attention_weights(self):\n",
"        return self._attention_weights"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "a1d2dbea",
|
||
"metadata": {
|
||
"origin_pos": 12
|
||
},
|
||
"source": [
"Next, we [**test the Bahdanau attention decoder**] with a minibatch of 4 sequence inputs of 7 time steps each.\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"id": "9cd7ffa9",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:15:31.895756Z",
|
||
"iopub.status.busy": "2023-08-18T07:15:31.895141Z",
|
||
"iopub.status.idle": "2023-08-18T07:15:31.935399Z",
|
||
"shell.execute_reply": "2023-08-18T07:15:31.934579Z"
|
||
},
|
||
"origin_pos": 14,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))"
|
||
]
|
||
},
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
"encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,\n",
"                             num_layers=2)\n",
"encoder.eval()\n",
"decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16,\n",
"                                  num_layers=2)\n",
"decoder.eval()\n",
"X = torch.zeros((4, 7), dtype=torch.long)  # (batch_size,num_steps)\n",
"state = decoder.init_state(encoder(X), None)\n",
"output, state = decoder(X, state)\n",
"output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "5d9a6381",
|
||
"metadata": {
|
||
"origin_pos": 17
|
||
},
|
||
"source": [
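"## [**Training**]\n",
"\n",
"Similar to :numref:`sec_seq2seq_training`,\n",
"here we specify the hyperparameters, instantiate an encoder and a decoder with Bahdanau attention,\n",
"and train this model for machine translation.\n",
"Because of the newly added attention mechanism, training is much slower than\n",
"in :numref:`sec_seq2seq_training`, which has no attention mechanism.\n"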
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"id": "44dd619c",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:15:31.939197Z",
|
||
"iopub.status.busy": "2023-08-18T07:15:31.938585Z",
|
||
"iopub.status.idle": "2023-08-18T07:17:05.733764Z",
|
||
"shell.execute_reply": "2023-08-18T07:17:05.732879Z"
|
||
},
|
||
"origin_pos": 18,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"loss 0.021, 4948.7 tokens/sec on cuda:0\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"image/svg+xml": [
|
||
"<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
|
||
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
|
||
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
|
||
"<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"262.1875pt\" height=\"180.65625pt\" viewBox=\"0 0 262.1875 180.65625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
|
||
" <metadata>\n",
|
||
" <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
|
||
" <cc:Work>\n",
|
||
" <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
|
||
" <dc:date>2023-08-18T07:17:05.700844</dc:date>\n",
|
||
" <dc:format>image/svg+xml</dc:format>\n",
|
||
" <dc:creator>\n",
|
||
" <cc:Agent>\n",
|
||
" <dc:title>Matplotlib v3.5.1, https://matplotlib.org/</dc:title>\n",
|
||
" </cc:Agent>\n",
|
||
" </dc:creator>\n",
|
||
" </cc:Work>\n",
|
||
" </rdf:RDF>\n",
|
||
" </metadata>\n",
|
||
" <defs>\n",
|
||
" <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
|
||
" </defs>\n",
|
||
" <g id=\"figure_1\">\n",
|
||
" <g id=\"patch_1\">\n",
|
||
" <path d=\"M 0 180.65625 \n",
|
||
"L 262.1875 180.65625 \n",
|
||
"L 262.1875 0 \n",
|
||
"L 0 0 \n",
|
||
"L 0 180.65625 \n",
|
||
"z\n",
|
||
"\" style=\"fill: none\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"axes_1\">\n",
|
||
" <g id=\"patch_2\">\n",
|
||
" <path d=\"M 50.14375 143.1 \n",
|
||
"L 245.44375 143.1 \n",
|
||
"L 245.44375 7.2 \n",
|
||
"L 50.14375 7.2 \n",
|
||
"z\n",
|
||
"\" style=\"fill: #ffffff\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"matplotlib.axis_1\">\n",
|
||
" <g id=\"xtick_1\">\n",
|
||
" <g id=\"line2d_1\">\n",
|
||
" <path d=\"M 82.69375 143.1 \n",
|
||
"L 82.69375 7.2 \n",
|
||
"\" clip-path=\"url(#p2a062d3940)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_2\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"m1447da81e9\" d=\"M 0 0 \n",
|
||
"L 0 3.5 \n",
|
||
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </defs>\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#m1447da81e9\" x=\"82.69375\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_1\">\n",
|
||
" <!-- 50 -->\n",
|
||
" <g transform=\"translate(76.33125 157.698438)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-35\" d=\"M 691 4666 \n",
|
||
"L 3169 4666 \n",
|
||
"L 3169 4134 \n",
|
||
"L 1269 4134 \n",
|
||
"L 1269 2991 \n",
|
||
"Q 1406 3038 1543 3061 \n",
|
||
"Q 1681 3084 1819 3084 \n",
|
||
"Q 2600 3084 3056 2656 \n",
|
||
"Q 3513 2228 3513 1497 \n",
|
||
"Q 3513 744 3044 326 \n",
|
||
"Q 2575 -91 1722 -91 \n",
|
||
"Q 1428 -91 1123 -41 \n",
|
||
"Q 819 9 494 109 \n",
|
||
"L 494 744 \n",
|
||
"Q 775 591 1075 516 \n",
|
||
"Q 1375 441 1709 441 \n",
|
||
"Q 2250 441 2565 725 \n",
|
||
"Q 2881 1009 2881 1497 \n",
|
||
"Q 2881 1984 2565 2268 \n",
|
||
"Q 2250 2553 1709 2553 \n",
|
||
"Q 1456 2553 1204 2497 \n",
|
||
"Q 953 2441 691 2322 \n",
|
||
"L 691 4666 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
|
||
"Q 1547 4250 1301 3770 \n",
|
||
"Q 1056 3291 1056 2328 \n",
|
||
"Q 1056 1369 1301 889 \n",
|
||
"Q 1547 409 2034 409 \n",
|
||
"Q 2525 409 2770 889 \n",
|
||
"Q 3016 1369 3016 2328 \n",
|
||
"Q 3016 3291 2770 3770 \n",
|
||
"Q 2525 4250 2034 4250 \n",
|
||
"z\n",
|
||
"M 2034 4750 \n",
|
||
"Q 2819 4750 3233 4129 \n",
|
||
"Q 3647 3509 3647 2328 \n",
|
||
"Q 3647 1150 3233 529 \n",
|
||
"Q 2819 -91 2034 -91 \n",
|
||
"Q 1250 -91 836 529 \n",
|
||
"Q 422 1150 422 2328 \n",
|
||
"Q 422 3509 836 4129 \n",
|
||
"Q 1250 4750 2034 4750 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-35\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"xtick_2\">\n",
|
||
" <g id=\"line2d_3\">\n",
|
||
" <path d=\"M 123.38125 143.1 \n",
|
||
"L 123.38125 7.2 \n",
|
||
"\" clip-path=\"url(#p2a062d3940)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_4\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#m1447da81e9\" x=\"123.38125\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_2\">\n",
|
||
" <!-- 100 -->\n",
|
||
" <g transform=\"translate(113.8375 157.698438)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
|
||
"L 1825 531 \n",
|
||
"L 1825 4091 \n",
|
||
"L 703 3866 \n",
|
||
"L 703 4441 \n",
|
||
"L 1819 4666 \n",
|
||
"L 2450 4666 \n",
|
||
"L 2450 531 \n",
|
||
"L 3481 531 \n",
|
||
"L 3481 0 \n",
|
||
"L 794 0 \n",
|
||
"L 794 531 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"xtick_3\">\n",
|
||
" <g id=\"line2d_5\">\n",
|
||
" <path d=\"M 164.06875 143.1 \n",
|
||
"L 164.06875 7.2 \n",
|
||
"\" clip-path=\"url(#p2a062d3940)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_6\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#m1447da81e9\" x=\"164.06875\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_3\">\n",
|
||
" <!-- 150 -->\n",
|
||
" <g transform=\"translate(154.525 157.698438)scale(0.1 -0.1)\">\n",
|
||
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-35\" x=\"63.623047\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"xtick_4\">\n",
|
||
" <g id=\"line2d_7\">\n",
|
||
" <path d=\"M 204.75625 143.1 \n",
|
||
"L 204.75625 7.2 \n",
|
||
"\" clip-path=\"url(#p2a062d3940)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_8\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#m1447da81e9\" x=\"204.75625\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_4\">\n",
|
||
" <!-- 200 -->\n",
|
||
" <g transform=\"translate(195.2125 157.698438)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
|
||
"L 3431 531 \n",
|
||
"L 3431 0 \n",
|
||
"L 469 0 \n",
|
||
"L 469 531 \n",
|
||
"Q 828 903 1448 1529 \n",
|
||
"Q 2069 2156 2228 2338 \n",
|
||
"Q 2531 2678 2651 2914 \n",
|
||
"Q 2772 3150 2772 3378 \n",
|
||
"Q 2772 3750 2511 3984 \n",
|
||
"Q 2250 4219 1831 4219 \n",
|
||
"Q 1534 4219 1204 4116 \n",
|
||
"Q 875 4013 500 3803 \n",
|
||
"L 500 4441 \n",
|
||
"Q 881 4594 1212 4672 \n",
|
||
"Q 1544 4750 1819 4750 \n",
|
||
"Q 2544 4750 2975 4387 \n",
|
||
"Q 3406 4025 3406 3419 \n",
|
||
"Q 3406 3131 3298 2873 \n",
|
||
"Q 3191 2616 2906 2266 \n",
|
||
"Q 2828 2175 2409 1742 \n",
|
||
"Q 1991 1309 1228 531 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"xtick_5\">\n",
|
||
" <g id=\"line2d_9\">\n",
|
||
" <path d=\"M 245.44375 143.1 \n",
|
||
"L 245.44375 7.2 \n",
|
||
"\" clip-path=\"url(#p2a062d3940)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_10\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#m1447da81e9\" x=\"245.44375\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_5\">\n",
|
||
" <!-- 250 -->\n",
|
||
" <g transform=\"translate(235.9 157.698438)scale(0.1 -0.1)\">\n",
|
||
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-35\" x=\"63.623047\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_6\">\n",
|
||
" <!-- epoch -->\n",
|
||
" <g transform=\"translate(132.565625 171.376563)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-65\" d=\"M 3597 1894 \n",
|
||
"L 3597 1613 \n",
|
||
"L 953 1613 \n",
|
||
"Q 991 1019 1311 708 \n",
|
||
"Q 1631 397 2203 397 \n",
|
||
"Q 2534 397 2845 478 \n",
|
||
"Q 3156 559 3463 722 \n",
|
||
"L 3463 178 \n",
|
||
"Q 3153 47 2828 -22 \n",
|
||
"Q 2503 -91 2169 -91 \n",
|
||
"Q 1331 -91 842 396 \n",
|
||
"Q 353 884 353 1716 \n",
|
||
"Q 353 2575 817 3079 \n",
|
||
"Q 1281 3584 2069 3584 \n",
|
||
"Q 2775 3584 3186 3129 \n",
|
||
"Q 3597 2675 3597 1894 \n",
|
||
"z\n",
|
||
"M 3022 2063 \n",
|
||
"Q 3016 2534 2758 2815 \n",
|
||
"Q 2500 3097 2075 3097 \n",
|
||
"Q 1594 3097 1305 2825 \n",
|
||
"Q 1016 2553 972 2059 \n",
|
||
"L 3022 2063 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" <path id=\"DejaVuSans-70\" d=\"M 1159 525 \n",
|
||
"L 1159 -1331 \n",
|
||
"L 581 -1331 \n",
|
||
"L 581 3500 \n",
|
||
"L 1159 3500 \n",
|
||
"L 1159 2969 \n",
|
||
"Q 1341 3281 1617 3432 \n",
|
||
"Q 1894 3584 2278 3584 \n",
|
||
"Q 2916 3584 3314 3078 \n",
|
||
"Q 3713 2572 3713 1747 \n",
|
||
"Q 3713 922 3314 415 \n",
|
||
"Q 2916 -91 2278 -91 \n",
|
||
"Q 1894 -91 1617 61 \n",
|
||
"Q 1341 213 1159 525 \n",
|
||
"z\n",
|
||
"M 3116 1747 \n",
|
||
"Q 3116 2381 2855 2742 \n",
|
||
"Q 2594 3103 2138 3103 \n",
|
||
"Q 1681 3103 1420 2742 \n",
|
||
"Q 1159 2381 1159 1747 \n",
|
||
"Q 1159 1113 1420 752 \n",
|
||
"Q 1681 391 2138 391 \n",
|
||
"Q 2594 391 2855 752 \n",
|
||
"Q 3116 1113 3116 1747 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" <path id=\"DejaVuSans-6f\" d=\"M 1959 3097 \n",
|
||
"Q 1497 3097 1228 2736 \n",
|
||
"Q 959 2375 959 1747 \n",
|
||
"Q 959 1119 1226 758 \n",
|
||
"Q 1494 397 1959 397 \n",
|
||
"Q 2419 397 2687 759 \n",
|
||
"Q 2956 1122 2956 1747 \n",
|
||
"Q 2956 2369 2687 2733 \n",
|
||
"Q 2419 3097 1959 3097 \n",
|
||
"z\n",
|
||
"M 1959 3584 \n",
|
||
"Q 2709 3584 3137 3096 \n",
|
||
"Q 3566 2609 3566 1747 \n",
|
||
"Q 3566 888 3137 398 \n",
|
||
"Q 2709 -91 1959 -91 \n",
|
||
"Q 1206 -91 779 398 \n",
|
||
"Q 353 888 353 1747 \n",
|
||
"Q 353 2609 779 3096 \n",
|
||
"Q 1206 3584 1959 3584 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" <path id=\"DejaVuSans-63\" d=\"M 3122 3366 \n",
|
||
"L 3122 2828 \n",
|
||
"Q 2878 2963 2633 3030 \n",
|
||
"Q 2388 3097 2138 3097 \n",
|
||
"Q 1578 3097 1268 2742 \n",
|
||
"Q 959 2388 959 1747 \n",
|
||
"Q 959 1106 1268 751 \n",
|
||
"Q 1578 397 2138 397 \n",
|
||
"Q 2388 397 2633 464 \n",
|
||
"Q 2878 531 3122 666 \n",
|
||
"L 3122 134 \n",
|
||
"Q 2881 22 2623 -34 \n",
|
||
"Q 2366 -91 2075 -91 \n",
|
||
"Q 1284 -91 818 406 \n",
|
||
"Q 353 903 353 1747 \n",
|
||
"Q 353 2603 823 3093 \n",
|
||
"Q 1294 3584 2113 3584 \n",
|
||
"Q 2378 3584 2631 3529 \n",
|
||
"Q 2884 3475 3122 3366 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" <path id=\"DejaVuSans-68\" d=\"M 3513 2113 \n",
|
||
"L 3513 0 \n",
|
||
"L 2938 0 \n",
|
||
"L 2938 2094 \n",
|
||
"Q 2938 2591 2744 2837 \n",
|
||
"Q 2550 3084 2163 3084 \n",
|
||
"Q 1697 3084 1428 2787 \n",
|
||
"Q 1159 2491 1159 1978 \n",
|
||
"L 1159 0 \n",
|
||
"L 581 0 \n",
|
||
"L 581 4863 \n",
|
||
"L 1159 4863 \n",
|
||
"L 1159 2956 \n",
|
||
"Q 1366 3272 1645 3428 \n",
|
||
"Q 1925 3584 2291 3584 \n",
|
||
"Q 2894 3584 3203 3211 \n",
|
||
"Q 3513 2838 3513 2113 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-65\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-70\" x=\"61.523438\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-6f\" x=\"125\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-63\" x=\"186.181641\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-68\" x=\"241.162109\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"matplotlib.axis_2\">\n",
|
||
" <g id=\"ytick_1\">\n",
|
||
" <g id=\"line2d_11\">\n",
|
||
" <path d=\"M 50.14375 116.885299 \n",
|
||
"L 245.44375 116.885299 \n",
|
||
"\" clip-path=\"url(#p2a062d3940)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_12\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"mb4dd508465\" d=\"M 0 0 \n",
|
||
"L -3.5 0 \n",
|
||
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </defs>\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#mb4dd508465\" x=\"50.14375\" y=\"116.885299\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_7\">\n",
|
||
" <!-- 0.05 -->\n",
|
||
" <g transform=\"translate(20.878125 120.684518)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-2e\" d=\"M 684 794 \n",
|
||
"L 1344 794 \n",
|
||
"L 1344 0 \n",
|
||
"L 684 0 \n",
|
||
"L 684 794 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"95.410156\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-35\" x=\"159.033203\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"ytick_2\">\n",
|
||
" <g id=\"line2d_13\">\n",
|
||
" <path d=\"M 50.14375 82.94084 \n",
|
||
"L 245.44375 82.94084 \n",
|
||
"\" clip-path=\"url(#p2a062d3940)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_14\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#mb4dd508465\" x=\"50.14375\" y=\"82.94084\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_8\">\n",
|
||
" <!-- 0.10 -->\n",
|
||
" <g transform=\"translate(20.878125 86.740059)scale(0.1 -0.1)\">\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-31\" x=\"95.410156\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"159.033203\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"ytick_3\">\n",
|
||
" <g id=\"line2d_15\">\n",
|
||
" <path d=\"M 50.14375 48.996381 \n",
|
||
"L 245.44375 48.996381 \n",
|
||
"\" clip-path=\"url(#p2a062d3940)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_16\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#mb4dd508465\" x=\"50.14375\" y=\"48.996381\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_9\">\n",
|
||
" <!-- 0.15 -->\n",
|
||
" <g transform=\"translate(20.878125 52.7956)scale(0.1 -0.1)\">\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-31\" x=\"95.410156\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-35\" x=\"159.033203\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"ytick_4\">\n",
|
||
" <g id=\"line2d_17\">\n",
|
||
" <path d=\"M 50.14375 15.051922 \n",
|
||
"L 245.44375 15.051922 \n",
|
||
"\" clip-path=\"url(#p2a062d3940)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_18\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#mb4dd508465\" x=\"50.14375\" y=\"15.051922\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_10\">\n",
|
||
" <!-- 0.20 -->\n",
|
||
" <g transform=\"translate(20.878125 18.851141)scale(0.1 -0.1)\">\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-32\" x=\"95.410156\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"159.033203\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_11\">\n",
|
||
" <!-- loss -->\n",
|
||
" <g transform=\"translate(14.798437 84.807812)rotate(-90)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-6c\" d=\"M 603 4863 \n",
|
||
"L 1178 4863 \n",
|
||
"L 1178 0 \n",
|
||
"L 603 0 \n",
|
||
"L 603 4863 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" <path id=\"DejaVuSans-73\" d=\"M 2834 3397 \n",
|
||
"L 2834 2853 \n",
|
||
"Q 2591 2978 2328 3040 \n",
|
||
"Q 2066 3103 1784 3103 \n",
|
||
"Q 1356 3103 1142 2972 \n",
|
||
"Q 928 2841 928 2578 \n",
|
||
"Q 928 2378 1081 2264 \n",
|
||
"Q 1234 2150 1697 2047 \n",
|
||
"L 1894 2003 \n",
|
||
"Q 2506 1872 2764 1633 \n",
|
||
"Q 3022 1394 3022 966 \n",
|
||
"Q 3022 478 2636 193 \n",
|
||
"Q 2250 -91 1575 -91 \n",
|
||
"Q 1294 -91 989 -36 \n",
|
||
"Q 684 19 347 128 \n",
|
||
"L 347 722 \n",
|
||
"Q 666 556 975 473 \n",
|
||
"Q 1284 391 1588 391 \n",
|
||
"Q 1994 391 2212 530 \n",
|
||
"Q 2431 669 2431 922 \n",
|
||
"Q 2431 1156 2273 1281 \n",
|
||
"Q 2116 1406 1581 1522 \n",
|
||
"L 1381 1569 \n",
|
||
"Q 847 1681 609 1914 \n",
|
||
"Q 372 2147 372 2553 \n",
|
||
"Q 372 3047 722 3315 \n",
|
||
"Q 1072 3584 1716 3584 \n",
|
||
"Q 2034 3584 2315 3537 \n",
|
||
"Q 2597 3491 2834 3397 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-6c\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-6f\" x=\"27.783203\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-73\" x=\"88.964844\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-73\" x=\"141.064453\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_19\">\n",
|
||
" <path d=\"M 50.14375 13.377273 \n",
|
||
"L 58.28125 59.149776 \n",
|
||
"L 66.41875 82.645683 \n",
|
||
"L 74.55625 98.327398 \n",
|
||
"L 82.69375 108.725129 \n",
|
||
"L 90.83125 115.918917 \n",
|
||
"L 98.96875 120.467924 \n",
|
||
"L 107.10625 123.806795 \n",
|
||
"L 115.24375 127.208336 \n",
|
||
"L 123.38125 128.293605 \n",
|
||
"L 131.51875 130.661535 \n",
|
||
"L 139.65625 131.764555 \n",
|
||
"L 147.79375 132.848189 \n",
|
||
"L 155.93125 133.955715 \n",
|
||
"L 164.06875 134.020737 \n",
|
||
"L 172.20625 134.799663 \n",
|
||
"L 180.34375 135.084322 \n",
|
||
"L 188.48125 135.743767 \n",
|
||
"L 196.61875 136.085946 \n",
|
||
"L 204.75625 136.172514 \n",
|
||
"L 212.89375 136.484072 \n",
|
||
"L 221.03125 136.460007 \n",
|
||
"L 229.16875 136.456094 \n",
|
||
"L 237.30625 136.922727 \n",
|
||
"L 245.44375 136.690915 \n",
|
||
"\" clip-path=\"url(#p2a062d3940)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"patch_3\">\n",
|
||
" <path d=\"M 50.14375 143.1 \n",
|
||
"L 50.14375 7.2 \n",
|
||
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"patch_4\">\n",
|
||
" <path d=\"M 245.44375 143.1 \n",
|
||
"L 245.44375 7.2 \n",
|
||
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"patch_5\">\n",
|
||
" <path d=\"M 50.14375 143.1 \n",
|
||
"L 245.44375 143.1 \n",
|
||
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"patch_6\">\n",
|
||
" <path d=\"M 50.14375 7.2 \n",
|
||
"L 245.44375 7.2 \n",
|
||
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <defs>\n",
|
||
" <clipPath id=\"p2a062d3940\">\n",
|
||
" <rect x=\"50.14375\" y=\"7.2\" width=\"195.3\" height=\"135.9\"/>\n",
|
||
" </clipPath>\n",
|
||
" </defs>\n",
|
||
"</svg>\n"
|
||
],
|
||
"text/plain": [
|
||
"<Figure size 252x180 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {
|
||
"needs_background": "light"
|
||
},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
"embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1\n",
"batch_size, num_steps = 64, 10\n",
"lr, num_epochs, device = 0.005, 250, d2l.try_gpu()\n",
"\n",
"train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)\n",
"encoder = d2l.Seq2SeqEncoder(\n",
"    len(src_vocab), embed_size, num_hiddens, num_layers, dropout)\n",
"decoder = Seq2SeqAttentionDecoder(\n",
"    len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)\n",
"net = d2l.EncoderDecoder(encoder, decoder)\n",
"d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "9b038fc0",
|
||
"metadata": {
|
||
"origin_pos": 19
|
||
},
|
||
"source": [
"After the model is trained, we use it to [**translate a few English sentences into French**] and compute their BLEU scores.\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"id": "b449b8a1",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:17:05.738743Z",
|
||
"iopub.status.busy": "2023-08-18T07:17:05.738202Z",
|
||
"iopub.status.idle": "2023-08-18T07:17:05.773736Z",
|
||
"shell.execute_reply": "2023-08-18T07:17:05.772225Z"
|
||
},
|
||
"origin_pos": 20,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"go . => va !, bleu 1.000\n",
|
||
"i lost . => j'ai perdu ., bleu 1.000\n",
|
||
"he's calm . => il est paresseux ., bleu 0.658\n",
|
||
"i'm home . => je suis chez moi ., bleu 1.000\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
"engs = ['go .', \"i lost .\", 'he\\'s calm .', 'i\\'m home .']\n",
"fras = ['va !', 'j\\'ai perdu .', 'il est calme .', 'je suis chez moi .']\n",
"for eng, fra in zip(engs, fras):\n",
"    translation, dec_attention_weight_seq = d2l.predict_seq2seq(\n",
"        net, eng, src_vocab, tgt_vocab, num_steps, device, True)\n",
"    print(f'{eng} => {translation}, ',\n",
"          f'bleu {d2l.bleu(translation, fra, k=2):.3f}')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"id": "703a029f",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:17:05.780446Z",
|
||
"iopub.status.busy": "2023-08-18T07:17:05.779931Z",
|
||
"iopub.status.idle": "2023-08-18T07:17:05.800143Z",
|
||
"shell.execute_reply": "2023-08-18T07:17:05.798893Z"
|
||
},
|
||
"origin_pos": 22,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [],
|
||
"source": [
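"# Stack the attention weights recorded at each decoding step of the last\n",
"# translated sentence into a tensor of shape (1, 1, number of steps, num_steps)\n",
"attention_weights = torch.cat([step[0][0][0] for step in dec_attention_weight_seq], 0).reshape((\n",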
|
||
" 1, 1, -1, num_steps))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "8bb0b7ce",
|
||
"metadata": {
|
||
"origin_pos": 23
|
||
},
|
||
"source": [
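"After training, [**visualizing the attention weights**]\n",
"shows that each query assigns different weights over the key-value pairs.\n",
"In other words, at each decoding step, different parts of the input sequence are selectively aggregated in the attention pooling.\n"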
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"id": "8074a1a6",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:17:05.806859Z",
|
||
"iopub.status.busy": "2023-08-18T07:17:05.805665Z",
|
||
"iopub.status.idle": "2023-08-18T07:17:06.012470Z",
|
||
"shell.execute_reply": "2023-08-18T07:17:06.011495Z"
|
||
},
|
||
"origin_pos": 25,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/svg+xml": [
"<svg xmlns=\"http://www.w3.org/2000/svg\" width=\"165.99575pt\" height=\"180.65625pt\" viewBox=\"0 0 165.99575 180.65625\" version=\"1.1\">\n",
" <title>Attention-weight heatmap (x-axis: Key positions; y-axis: Query positions; colorbar ticks at 0.2 and 0.4)</title>\n",
"</svg>\n"
|
||
],
|
||
"text/plain": [
|
||
"<Figure size 180x180 with 2 Axes>"
|
||
]
|
||
},
|
||
"metadata": {
|
||
"needs_background": "light"
|
||
},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
"# Add one to the source length to include the end-of-sequence token\n",
|
||
"d2l.show_heatmaps(\n",
|
||
" attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(),\n",
|
||
" xlabel='Key positions', ylabel='Query positions')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "30eb6f85",
|
||
"metadata": {
|
||
"origin_pos": 27
|
||
},
|
||
"source": [
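"## Summary\n",
"\n",
"* When predicting a token, if not all input tokens are relevant, the RNN encoder-decoder with Bahdanau attention selectively aggregates different parts of the input sequence. This is achieved by treating the context variable as the output of additive attention pooling.\n",
"* In the RNN encoder-decoder, Bahdanau attention treats the decoder hidden state at the previous time step as the query, and the encoder hidden states at all time steps as both the keys and the values.\n",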
|
||
"\n",
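"## Exercises\n",
"\n",
"1. Replace the GRU with an LSTM in the experiment.\n",
"1. Modify the experiment to replace the additive attention scoring function with scaled dot-product attention. How does this affect training efficiency?\n"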
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "4d7906bc",
|
||
"metadata": {
|
||
"origin_pos": 29,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"source": [
|
||
"[Discussions](https://discuss.d2l.ai/t/5754)\n"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"language_info": {
|
||
"name": "python"
|
||
},
|
||
"required_libs": []
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
} |