1645 lines
61 KiB
Plaintext
1645 lines
61 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "74f92c25",
|
||
"metadata": {
|
||
"origin_pos": 0
|
||
},
|
||
"source": [
|
||
"# 序列到序列学习(seq2seq)\n",
|
||
":label:`sec_seq2seq`\n",
|
||
"\n",
|
||
"正如我们在 :numref:`sec_machine_translation`中看到的,\n",
|
||
"机器翻译中的输入序列和输出序列都是长度可变的。\n",
|
||
"为了解决这类问题,我们在 :numref:`sec_encoder-decoder`中\n",
|
||
"设计了一个通用的”编码器-解码器“架构。\n",
|
||
"本节,我们将使用两个循环神经网络的编码器和解码器,\n",
|
||
"并将其应用于*序列到序列*(sequence to sequence,seq2seq)类的学习任务\n",
|
||
" :cite:`Sutskever.Vinyals.Le.2014,Cho.Van-Merrienboer.Gulcehre.ea.2014`。\n",
|
||
"\n",
|
||
"遵循编码器-解码器架构的设计原则,\n",
|
||
"循环神经网络编码器使用长度可变的序列作为输入,\n",
|
||
"将其转换为固定形状的隐状态。\n",
|
||
"换言之,输入序列的信息被*编码*到循环神经网络编码器的隐状态中。\n",
|
||
"为了连续生成输出序列的词元,\n",
|
||
"独立的循环神经网络解码器是基于输入序列的编码信息\n",
|
||
"和输出序列已经看见的或者生成的词元来预测下一个词元。\n",
|
||
" :numref:`fig_seq2seq`演示了\n",
|
||
"如何在机器翻译中使用两个循环神经网络进行序列到序列学习。\n",
|
||
"\n",
|
||
"\n",
|
||
":label:`fig_seq2seq`\n",
|
||
"\n",
|
||
"在 :numref:`fig_seq2seq`中,\n",
|
||
"特定的“<eos>”表示序列结束词元。\n",
|
||
"一旦输出序列生成此词元,模型就会停止预测。\n",
|
||
"在循环神经网络解码器的初始化时间步,有两个特定的设计决定:\n",
|
||
"首先,特定的“<bos>”表示序列开始词元,它是解码器的输入序列的第一个词元。\n",
|
||
"其次,使用循环神经网络编码器最终的隐状态来初始化解码器的隐状态。\n",
|
||
"例如,在 :cite:`Sutskever.Vinyals.Le.2014`的设计中,\n",
|
||
"正是基于这种设计将输入序列的编码信息送入到解码器中来生成输出序列的。\n",
|
||
"在其他一些设计中 :cite:`Cho.Van-Merrienboer.Gulcehre.ea.2014`,\n",
|
||
"如 :numref:`fig_seq2seq`所示,\n",
|
||
"编码器最终的隐状态在每一个时间步都作为解码器的输入序列的一部分。\n",
|
||
"类似于 :numref:`sec_language_model`中语言模型的训练,\n",
|
||
"可以允许标签成为原始的输出序列,\n",
|
||
"从源序列词元“<bos>”“Ils”“regardent”“.”\n",
|
||
"到新序列词元\n",
|
||
"“Ils”“regardent”“.”“<eos>”来移动预测的位置。\n",
|
||
"\n",
|
||
"下面,我们动手构建 :numref:`fig_seq2seq`的设计,\n",
|
||
"并将基于 :numref:`sec_machine_translation`中\n",
|
||
"介绍的“英-法”数据集来训练这个机器翻译模型。\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"id": "bc9aa4b5",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:15:31.967521Z",
|
||
"iopub.status.busy": "2023-08-18T07:15:31.966534Z",
|
||
"iopub.status.idle": "2023-08-18T07:15:33.959337Z",
|
||
"shell.execute_reply": "2023-08-18T07:15:33.958486Z"
|
||
},
|
||
"origin_pos": 2,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"import collections\n",
|
||
"import math\n",
|
||
"import torch\n",
|
||
"from torch import nn\n",
|
||
"from d2l import torch as d2l"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "cb6cae12",
|
||
"metadata": {
|
||
"origin_pos": 5
|
||
},
|
||
"source": [
|
||
"## 编码器\n",
|
||
"\n",
|
||
"从技术上讲,编码器将长度可变的输入序列转换成\n",
|
||
"形状固定的上下文变量$\\mathbf{c}$,\n",
|
||
"并且将输入序列的信息在该上下文变量中进行编码。\n",
|
||
"如 :numref:`fig_seq2seq`所示,可以使用循环神经网络来设计编码器。\n",
|
||
"\n",
|
||
"考虑由一个序列组成的样本(批量大小是$1$)。\n",
|
||
"假设输入序列是$x_1, \\ldots, x_T$,\n",
|
||
"其中$x_t$是输入文本序列中的第$t$个词元。\n",
|
||
"在时间步$t$,循环神经网络将词元$x_t$的输入特征向量\n",
|
||
"$\\mathbf{x}_t$和$\\mathbf{h} _{t-1}$(即上一时间步的隐状态)\n",
|
||
"转换为$\\mathbf{h}_t$(即当前步的隐状态)。\n",
|
||
"使用一个函数$f$来描述循环神经网络的循环层所做的变换:\n",
|
||
"\n",
|
||
"$$\\mathbf{h}_t = f(\\mathbf{x}_t, \\mathbf{h}_{t-1}). $$\n",
|
||
"\n",
|
||
"总之,编码器通过选定的函数$q$,\n",
|
||
"将所有时间步的隐状态转换为上下文变量:\n",
|
||
"\n",
|
||
"$$\\mathbf{c} = q(\\mathbf{h}_1, \\ldots, \\mathbf{h}_T).$$\n",
|
||
"\n",
|
||
"比如,当选择$q(\\mathbf{h}_1, \\ldots, \\mathbf{h}_T) = \\mathbf{h}_T$时\n",
|
||
"(就像 :numref:`fig_seq2seq`中一样),\n",
|
||
"上下文变量仅仅是输入序列在最后时间步的隐状态$\\mathbf{h}_T$。\n",
|
||
"\n",
|
||
"到目前为止,我们使用的是一个单向循环神经网络来设计编码器,\n",
|
||
"其中隐状态只依赖于输入子序列,\n",
|
||
"这个子序列是由输入序列的开始位置到隐状态所在的时间步的位置\n",
|
||
"(包括隐状态所在的时间步)组成。\n",
|
||
"我们也可以使用双向循环神经网络构造编码器,\n",
|
||
"其中隐状态依赖于两个输入子序列,\n",
|
||
"两个子序列是由隐状态所在的时间步的位置之前的序列和之后的序列\n",
|
||
"(包括隐状态所在的时间步),\n",
|
||
"因此隐状态对整个序列的信息都进行了编码。\n",
|
||
"\n",
|
||
"现在,让我们[**实现循环神经网络编码器**]。\n",
|
||
"注意,我们使用了*嵌入层*(embedding layer)\n",
|
||
"来获得输入序列中每个词元的特征向量。\n",
|
||
"嵌入层的权重是一个矩阵,\n",
|
||
"其行数等于输入词表的大小(`vocab_size`),\n",
|
||
"其列数等于特征向量的维度(`embed_size`)。\n",
|
||
"对于任意输入词元的索引$i$,\n",
|
||
"嵌入层获取权重矩阵的第$i$行(从$0$开始)以返回其特征向量。\n",
|
||
"另外,本文选择了一个多层门控循环单元来实现编码器。\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"id": "3dbfb3ed",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:15:33.963601Z",
|
||
"iopub.status.busy": "2023-08-18T07:15:33.962917Z",
|
||
"iopub.status.idle": "2023-08-18T07:15:33.969272Z",
|
||
"shell.execute_reply": "2023-08-18T07:15:33.968489Z"
|
||
},
|
||
"origin_pos": 7,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"#@save\n",
|
||
"class Seq2SeqEncoder(d2l.Encoder):\n",
|
||
" \"\"\"用于序列到序列学习的循环神经网络编码器\"\"\"\n",
|
||
" def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,\n",
|
||
" dropout=0, **kwargs):\n",
|
||
" super(Seq2SeqEncoder, self).__init__(**kwargs)\n",
|
||
" # 嵌入层\n",
|
||
" self.embedding = nn.Embedding(vocab_size, embed_size)\n",
|
||
" self.rnn = nn.GRU(embed_size, num_hiddens, num_layers,\n",
|
||
" dropout=dropout)\n",
|
||
"\n",
|
||
" def forward(self, X, *args):\n",
|
||
" # 输出'X'的形状:(batch_size,num_steps,embed_size)\n",
|
||
" X = self.embedding(X)\n",
|
||
" # 在循环神经网络模型中,第一个轴对应于时间步\n",
|
||
" X = X.permute(1, 0, 2)\n",
|
||
" # 如果未提及状态,则默认为0\n",
|
||
" output, state = self.rnn(X)\n",
|
||
" # output的形状:(num_steps,batch_size,num_hiddens)\n",
|
||
" # state的形状:(num_layers,batch_size,num_hiddens)\n",
|
||
" return output, state"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "9ba9f2c0",
|
||
"metadata": {
|
||
"origin_pos": 10
|
||
},
|
||
"source": [
|
||
"循环层返回变量的说明可以参考 :numref:`sec_rnn-concise`。\n",
|
||
"\n",
|
||
"下面,我们实例化[**上述编码器的实现**]:\n",
|
||
"我们使用一个两层门控循环单元编码器,其隐藏单元数为$16$。\n",
|
||
"给定一小批量的输入序列`X`(批量大小为$4$,时间步为$7$)。\n",
|
||
"在完成所有时间步后,\n",
|
||
"最后一层的隐状态的输出是一个张量(`output`由编码器的循环层返回),\n",
|
||
"其形状为(时间步数,批量大小,隐藏单元数)。\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"id": "1780ca82",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:15:33.972667Z",
|
||
"iopub.status.busy": "2023-08-18T07:15:33.972142Z",
|
||
"iopub.status.idle": "2023-08-18T07:15:34.003637Z",
|
||
"shell.execute_reply": "2023-08-18T07:15:34.002907Z"
|
||
},
|
||
"origin_pos": 12,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"torch.Size([7, 4, 16])"
|
||
]
|
||
},
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,\n",
|
||
" num_layers=2)\n",
|
||
"encoder.eval()\n",
|
||
"X = torch.zeros((4, 7), dtype=torch.long)\n",
|
||
"output, state = encoder(X)\n",
|
||
"output.shape"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "b6eea9d1",
|
||
"metadata": {
|
||
"origin_pos": 15
|
||
},
|
||
"source": [
|
||
"由于这里使用的是门控循环单元,\n",
|
||
"所以在最后一个时间步的多层隐状态的形状是\n",
|
||
"(隐藏层的数量,批量大小,隐藏单元的数量)。\n",
|
||
"如果使用长短期记忆网络,`state`中还将包含记忆单元信息。\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"id": "32a2c1d8",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:15:34.007123Z",
|
||
"iopub.status.busy": "2023-08-18T07:15:34.006595Z",
|
||
"iopub.status.idle": "2023-08-18T07:15:34.011456Z",
|
||
"shell.execute_reply": "2023-08-18T07:15:34.010716Z"
|
||
},
|
||
"origin_pos": 17,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"torch.Size([2, 4, 16])"
|
||
]
|
||
},
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"state.shape"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "131eaa96",
|
||
"metadata": {
|
||
"origin_pos": 19
|
||
},
|
||
"source": [
|
||
"## [**解码器**]\n",
|
||
":label:`sec_seq2seq_decoder`\n",
|
||
"\n",
|
||
"正如上文提到的,编码器输出的上下文变量$\\mathbf{c}$\n",
|
||
"对整个输入序列$x_1, \\ldots, x_T$进行编码。\n",
|
||
"来自训练数据集的输出序列$y_1, y_2, \\ldots, y_{T'}$,\n",
|
||
"对于每个时间步$t'$(与输入序列或编码器的时间步$t$不同),\n",
|
||
"解码器输出$y_{t'}$的概率取决于先前的输出子序列\n",
|
||
"$y_1, \\ldots, y_{t'-1}$和上下文变量$\\mathbf{c}$,\n",
|
||
"即$P(y_{t'} \\mid y_1, \\ldots, y_{t'-1}, \\mathbf{c})$。\n",
|
||
"\n",
|
||
"为了在序列上模型化这种条件概率,\n",
|
||
"我们可以使用另一个循环神经网络作为解码器。\n",
|
||
"在输出序列上的任意时间步$t^\\prime$,\n",
|
||
"循环神经网络将来自上一时间步的输出$y_{t^\\prime-1}$\n",
|
||
"和上下文变量$\\mathbf{c}$作为其输入,\n",
|
||
"然后在当前时间步将它们和上一隐状态\n",
|
||
"$\\mathbf{s}_{t^\\prime-1}$转换为\n",
|
||
"隐状态$\\mathbf{s}_{t^\\prime}$。\n",
|
||
"因此,可以使用函数$g$来表示解码器的隐藏层的变换:\n",
|
||
"\n",
|
||
"$$\\mathbf{s}_{t^\\prime} = g(y_{t^\\prime-1}, \\mathbf{c}, \\mathbf{s}_{t^\\prime-1}).$$\n",
|
||
":eqlabel:`eq_seq2seq_s_t`\n",
|
||
"\n",
|
||
"在获得解码器的隐状态之后,\n",
|
||
"我们可以使用输出层和softmax操作\n",
|
||
"来计算在时间步$t^\\prime$时输出$y_{t^\\prime}$的条件概率分布\n",
|
||
"$P(y_{t^\\prime} \\mid y_1, \\ldots, y_{t^\\prime-1}, \\mathbf{c})$。\n",
|
||
"\n",
|
||
"根据 :numref:`fig_seq2seq`,当实现解码器时,\n",
|
||
"我们直接使用编码器最后一个时间步的隐状态来初始化解码器的隐状态。\n",
|
||
"这就要求使用循环神经网络实现的编码器和解码器具有相同数量的层和隐藏单元。\n",
|
||
"为了进一步包含经过编码的输入序列的信息,\n",
|
||
"上下文变量在所有的时间步与解码器的输入进行拼接(concatenate)。\n",
|
||
"为了预测输出词元的概率分布,\n",
|
||
"在循环神经网络解码器的最后一层使用全连接层来变换隐状态。\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"id": "09143bb3",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:15:34.014841Z",
|
||
"iopub.status.busy": "2023-08-18T07:15:34.014327Z",
|
||
"iopub.status.idle": "2023-08-18T07:15:34.021372Z",
|
||
"shell.execute_reply": "2023-08-18T07:15:34.020591Z"
|
||
},
|
||
"origin_pos": 21,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"class Seq2SeqDecoder(d2l.Decoder):\n",
|
||
" \"\"\"用于序列到序列学习的循环神经网络解码器\"\"\"\n",
|
||
" def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,\n",
|
||
" dropout=0, **kwargs):\n",
|
||
" super(Seq2SeqDecoder, self).__init__(**kwargs)\n",
|
||
" self.embedding = nn.Embedding(vocab_size, embed_size)\n",
|
||
" self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers,\n",
|
||
" dropout=dropout)\n",
|
||
" self.dense = nn.Linear(num_hiddens, vocab_size)\n",
|
||
"\n",
|
||
" def init_state(self, enc_outputs, *args):\n",
|
||
" return enc_outputs[1]\n",
|
||
"\n",
|
||
" def forward(self, X, state):\n",
|
||
" # 输出'X'的形状:(batch_size,num_steps,embed_size)\n",
|
||
" X = self.embedding(X).permute(1, 0, 2)\n",
|
||
" # 广播context,使其具有与X相同的num_steps\n",
|
||
" context = state[-1].repeat(X.shape[0], 1, 1)\n",
|
||
" X_and_context = torch.cat((X, context), 2)\n",
|
||
" output, state = self.rnn(X_and_context, state)\n",
|
||
" output = self.dense(output).permute(1, 0, 2)\n",
|
||
" # output的形状:(batch_size,num_steps,vocab_size)\n",
|
||
" # state的形状:(num_layers,batch_size,num_hiddens)\n",
|
||
" return output, state"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "7eb99248",
|
||
"metadata": {
|
||
"origin_pos": 24
|
||
},
|
||
"source": [
|
||
"下面,我们用与前面提到的编码器中相同的超参数来[**实例化解码器**]。\n",
|
||
"如我们所见,解码器的输出形状变为(批量大小,时间步数,词表大小),\n",
|
||
"其中张量的最后一个维度存储预测的词元分布。\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"id": "ad17a24d",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:15:34.024844Z",
|
||
"iopub.status.busy": "2023-08-18T07:15:34.024212Z",
|
||
"iopub.status.idle": "2023-08-18T07:15:34.034277Z",
|
||
"shell.execute_reply": "2023-08-18T07:15:34.033517Z"
|
||
},
|
||
"origin_pos": 26,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(torch.Size([4, 7, 10]), torch.Size([2, 4, 16]))"
|
||
]
|
||
},
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=16,\n",
|
||
" num_layers=2)\n",
|
||
"decoder.eval()\n",
|
||
"state = decoder.init_state(encoder(X))\n",
|
||
"output, state = decoder(X, state)\n",
|
||
"output.shape, state.shape"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "eb13a51a",
|
||
"metadata": {
|
||
"origin_pos": 29
|
||
},
|
||
"source": [
|
||
"总之,上述循环神经网络“编码器-解码器”模型中的各层如\n",
|
||
" :numref:`fig_seq2seq_details`所示。\n",
|
||
"\n",
|
||
"\n",
|
||
":label:`fig_seq2seq_details`\n",
|
||
"\n",
|
||
"## 损失函数\n",
|
||
"\n",
|
||
"在每个时间步,解码器预测了输出词元的概率分布。\n",
|
||
"类似于语言模型,可以使用softmax来获得分布,\n",
|
||
"并通过计算交叉熵损失函数来进行优化。\n",
|
||
"回想一下 :numref:`sec_machine_translation`中,\n",
|
||
"特定的填充词元被添加到序列的末尾,\n",
|
||
"因此不同长度的序列可以以相同形状的小批量加载。\n",
|
||
"但是,我们应该将填充词元的预测排除在损失函数的计算之外。\n",
|
||
"\n",
|
||
"为此,我们可以使用下面的`sequence_mask`函数\n",
|
||
"[**通过零值化屏蔽不相关的项**],\n",
|
||
"以便后面任何不相关预测的计算都是与零的乘积,结果都等于零。\n",
|
||
"例如,如果两个序列的有效长度(不包括填充词元)分别为$1$和$2$,\n",
|
||
"则第一个序列的第一项和第二个序列的前两项之后的剩余项将被清除为零。\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"id": "57c5a5f4",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:15:34.037911Z",
|
||
"iopub.status.busy": "2023-08-18T07:15:34.037256Z",
|
||
"iopub.status.idle": "2023-08-18T07:15:34.044866Z",
|
||
"shell.execute_reply": "2023-08-18T07:15:34.044120Z"
|
||
},
|
||
"origin_pos": 31,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([[1, 0, 0],\n",
|
||
" [4, 5, 0]])"
|
||
]
|
||
},
|
||
"execution_count": 7,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"#@save\n",
|
||
"def sequence_mask(X, valid_len, value=0):\n",
|
||
" \"\"\"在序列中屏蔽不相关的项\"\"\"\n",
|
||
" maxlen = X.size(1)\n",
|
||
" mask = torch.arange((maxlen), dtype=torch.float32,\n",
|
||
" device=X.device)[None, :] < valid_len[:, None]\n",
|
||
" X[~mask] = value\n",
|
||
" return X\n",
|
||
"\n",
|
||
"X = torch.tensor([[1, 2, 3], [4, 5, 6]])\n",
|
||
"sequence_mask(X, torch.tensor([1, 2]))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "f6b25824",
|
||
"metadata": {
|
||
"origin_pos": 34
|
||
},
|
||
"source": [
|
||
"(**我们还可以使用此函数屏蔽最后几个轴上的所有项。**)如果愿意,也可以使用指定的非零值来替换这些项。\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"id": "fbb003c2",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:15:34.048373Z",
|
||
"iopub.status.busy": "2023-08-18T07:15:34.047745Z",
|
||
"iopub.status.idle": "2023-08-18T07:15:34.054283Z",
|
||
"shell.execute_reply": "2023-08-18T07:15:34.053539Z"
|
||
},
|
||
"origin_pos": 36,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([[[ 1., 1., 1., 1.],\n",
|
||
" [-1., -1., -1., -1.],\n",
|
||
" [-1., -1., -1., -1.]],\n",
|
||
"\n",
|
||
" [[ 1., 1., 1., 1.],\n",
|
||
" [ 1., 1., 1., 1.],\n",
|
||
" [-1., -1., -1., -1.]]])"
|
||
]
|
||
},
|
||
"execution_count": 8,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"X = torch.ones(2, 3, 4)\n",
|
||
"sequence_mask(X, torch.tensor([1, 2]), value=-1)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "ecf893af",
|
||
"metadata": {
|
||
"origin_pos": 39
|
||
},
|
||
"source": [
|
||
"现在,我们可以[**通过扩展softmax交叉熵损失函数来遮蔽不相关的预测**]。\n",
|
||
"最初,所有预测词元的掩码都设置为1。\n",
|
||
"一旦给定了有效长度,与填充词元对应的掩码将被设置为0。\n",
|
||
"最后,将所有词元的损失乘以掩码,以过滤掉损失中填充词元产生的不相关预测。\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"id": "0da33ae4",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:15:34.057946Z",
|
||
"iopub.status.busy": "2023-08-18T07:15:34.057267Z",
|
||
"iopub.status.idle": "2023-08-18T07:15:34.062428Z",
|
||
"shell.execute_reply": "2023-08-18T07:15:34.061664Z"
|
||
},
|
||
"origin_pos": 41,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"#@save\n",
|
||
"class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):\n",
|
||
" \"\"\"带遮蔽的softmax交叉熵损失函数\"\"\"\n",
|
||
" # pred的形状:(batch_size,num_steps,vocab_size)\n",
|
||
" # label的形状:(batch_size,num_steps)\n",
|
||
" # valid_len的形状:(batch_size,)\n",
|
||
" def forward(self, pred, label, valid_len):\n",
|
||
" weights = torch.ones_like(label)\n",
|
||
" weights = sequence_mask(weights, valid_len)\n",
|
||
" self.reduction='none'\n",
|
||
" unweighted_loss = super(MaskedSoftmaxCELoss, self).forward(\n",
|
||
" pred.permute(0, 2, 1), label)\n",
|
||
" weighted_loss = (unweighted_loss * weights).mean(dim=1)\n",
|
||
" return weighted_loss"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "f6b7600c",
|
||
"metadata": {
|
||
"origin_pos": 44
|
||
},
|
||
"source": [
|
||
"我们可以创建三个相同的序列来进行[**代码健全性检查**],\n",
|
||
"然后分别指定这些序列的有效长度为$4$、$2$和$0$。\n",
|
||
"结果就是,第一个序列的损失应为第二个序列的两倍,而第三个序列的损失应为零。\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"id": "65239ee5",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:15:34.065956Z",
|
||
"iopub.status.busy": "2023-08-18T07:15:34.065339Z",
|
||
"iopub.status.idle": "2023-08-18T07:15:34.073758Z",
|
||
"shell.execute_reply": "2023-08-18T07:15:34.072755Z"
|
||
},
|
||
"origin_pos": 46,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([2.3026, 1.1513, 0.0000])"
|
||
]
|
||
},
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"loss = MaskedSoftmaxCELoss()\n",
|
||
"loss(torch.ones(3, 4, 10), torch.ones((3, 4), dtype=torch.long),\n",
|
||
" torch.tensor([4, 2, 0]))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "6bdc4e96",
|
||
"metadata": {
|
||
"origin_pos": 49
|
||
},
|
||
"source": [
|
||
"## [**训练**]\n",
|
||
":label:`sec_seq2seq_training`\n",
|
||
"\n",
|
||
"在下面的循环训练过程中,如 :numref:`fig_seq2seq`所示,\n",
|
||
"特定的序列开始词元(“<bos>”)和\n",
|
||
"原始的输出序列(不包括序列结束词元“<eos>”)\n",
|
||
"拼接在一起作为解码器的输入。\n",
|
||
"这被称为*强制教学*(teacher forcing),\n",
|
||
"因为原始的输出序列(词元的标签)被送入解码器。\n",
|
||
"或者,将来自上一个时间步的*预测*得到的词元作为解码器的当前输入。\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"id": "9d7b922e",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:15:34.077404Z",
|
||
"iopub.status.busy": "2023-08-18T07:15:34.076756Z",
|
||
"iopub.status.idle": "2023-08-18T07:15:34.087405Z",
|
||
"shell.execute_reply": "2023-08-18T07:15:34.086461Z"
|
||
},
|
||
"origin_pos": 51,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"#@save\n",
|
||
"def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):\n",
|
||
" \"\"\"训练序列到序列模型\"\"\"\n",
|
||
" def xavier_init_weights(m):\n",
|
||
" if type(m) == nn.Linear:\n",
|
||
" nn.init.xavier_uniform_(m.weight)\n",
|
||
" if type(m) == nn.GRU:\n",
|
||
" for param in m._flat_weights_names:\n",
|
||
" if \"weight\" in param:\n",
|
||
" nn.init.xavier_uniform_(m._parameters[param])\n",
|
||
"\n",
|
||
" net.apply(xavier_init_weights)\n",
|
||
" net.to(device)\n",
|
||
" optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
|
||
" loss = MaskedSoftmaxCELoss()\n",
|
||
" net.train()\n",
|
||
" animator = d2l.Animator(xlabel='epoch', ylabel='loss',\n",
|
||
" xlim=[10, num_epochs])\n",
|
||
" for epoch in range(num_epochs):\n",
|
||
" timer = d2l.Timer()\n",
|
||
" metric = d2l.Accumulator(2) # 训练损失总和,词元数量\n",
|
||
" for batch in data_iter:\n",
|
||
" optimizer.zero_grad()\n",
|
||
" X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]\n",
|
||
" bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],\n",
|
||
" device=device).reshape(-1, 1)\n",
|
||
" dec_input = torch.cat([bos, Y[:, :-1]], 1) # 强制教学\n",
|
||
" Y_hat, _ = net(X, dec_input, X_valid_len)\n",
|
||
" l = loss(Y_hat, Y, Y_valid_len)\n",
|
||
" l.sum().backward()\t# 损失函数的标量进行“反向传播”\n",
|
||
" d2l.grad_clipping(net, 1)\n",
|
||
" num_tokens = Y_valid_len.sum()\n",
|
||
" optimizer.step()\n",
|
||
" with torch.no_grad():\n",
|
||
" metric.add(l.sum(), num_tokens)\n",
|
||
" if (epoch + 1) % 10 == 0:\n",
|
||
" animator.add(epoch + 1, (metric[0] / metric[1],))\n",
|
||
" print(f'loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} '\n",
|
||
" f'tokens/sec on {str(device)}')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "fe583c85",
|
||
"metadata": {
|
||
"origin_pos": 54
|
||
},
|
||
"source": [
|
||
"现在,在机器翻译数据集上,我们可以\n",
|
||
"[**创建和训练一个循环神经网络“编码器-解码器”模型**]用于序列到序列的学习。\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"id": "79f585d6",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:15:34.091791Z",
|
||
"iopub.status.busy": "2023-08-18T07:15:34.090975Z",
|
||
"iopub.status.idle": "2023-08-18T07:16:11.767145Z",
|
||
"shell.execute_reply": "2023-08-18T07:16:11.765998Z"
|
||
},
|
||
"origin_pos": 55,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"loss 0.019, 12745.1 tokens/sec on cuda:0\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"image/svg+xml": [
|
||
"<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
|
||
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
|
||
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
|
||
"<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"262.1875pt\" height=\"180.65625pt\" viewBox=\"0 0 262.1875 180.65625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
|
||
" <metadata>\n",
|
||
" <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
|
||
" <cc:Work>\n",
|
||
" <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
|
||
" <dc:date>2023-08-18T07:16:11.733851</dc:date>\n",
|
||
" <dc:format>image/svg+xml</dc:format>\n",
|
||
" <dc:creator>\n",
|
||
" <cc:Agent>\n",
|
||
" <dc:title>Matplotlib v3.5.1, https://matplotlib.org/</dc:title>\n",
|
||
" </cc:Agent>\n",
|
||
" </dc:creator>\n",
|
||
" </cc:Work>\n",
|
||
" </rdf:RDF>\n",
|
||
" </metadata>\n",
|
||
" <defs>\n",
|
||
" <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
|
||
" </defs>\n",
|
||
" <g id=\"figure_1\">\n",
|
||
" <g id=\"patch_1\">\n",
|
||
" <path d=\"M 0 180.65625 \n",
|
||
"L 262.1875 180.65625 \n",
|
||
"L 262.1875 0 \n",
|
||
"L 0 0 \n",
|
||
"L 0 180.65625 \n",
|
||
"z\n",
|
||
"\" style=\"fill: none\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"axes_1\">\n",
|
||
" <g id=\"patch_2\">\n",
|
||
" <path d=\"M 50.14375 143.1 \n",
|
||
"L 245.44375 143.1 \n",
|
||
"L 245.44375 7.2 \n",
|
||
"L 50.14375 7.2 \n",
|
||
"z\n",
|
||
"\" style=\"fill: #ffffff\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"matplotlib.axis_1\">\n",
|
||
" <g id=\"xtick_1\">\n",
|
||
" <g id=\"line2d_1\">\n",
|
||
" <path d=\"M 77.081681 143.1 \n",
|
||
"L 77.081681 7.2 \n",
|
||
"\" clip-path=\"url(#p97e5e7769b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_2\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"m011d5a03dd\" d=\"M 0 0 \n",
|
||
"L 0 3.5 \n",
|
||
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </defs>\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#m011d5a03dd\" x=\"77.081681\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_1\">\n",
|
||
" <!-- 50 -->\n",
|
||
" <g transform=\"translate(70.719181 157.698438)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-35\" d=\"M 691 4666 \n",
|
||
"L 3169 4666 \n",
|
||
"L 3169 4134 \n",
|
||
"L 1269 4134 \n",
|
||
"L 1269 2991 \n",
|
||
"Q 1406 3038 1543 3061 \n",
|
||
"Q 1681 3084 1819 3084 \n",
|
||
"Q 2600 3084 3056 2656 \n",
|
||
"Q 3513 2228 3513 1497 \n",
|
||
"Q 3513 744 3044 326 \n",
|
||
"Q 2575 -91 1722 -91 \n",
|
||
"Q 1428 -91 1123 -41 \n",
|
||
"Q 819 9 494 109 \n",
|
||
"L 494 744 \n",
|
||
"Q 775 591 1075 516 \n",
|
||
"Q 1375 441 1709 441 \n",
|
||
"Q 2250 441 2565 725 \n",
|
||
"Q 2881 1009 2881 1497 \n",
|
||
"Q 2881 1984 2565 2268 \n",
|
||
"Q 2250 2553 1709 2553 \n",
|
||
"Q 1456 2553 1204 2497 \n",
|
||
"Q 953 2441 691 2322 \n",
|
||
"L 691 4666 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
|
||
"Q 1547 4250 1301 3770 \n",
|
||
"Q 1056 3291 1056 2328 \n",
|
||
"Q 1056 1369 1301 889 \n",
|
||
"Q 1547 409 2034 409 \n",
|
||
"Q 2525 409 2770 889 \n",
|
||
"Q 3016 1369 3016 2328 \n",
|
||
"Q 3016 3291 2770 3770 \n",
|
||
"Q 2525 4250 2034 4250 \n",
|
||
"z\n",
|
||
"M 2034 4750 \n",
|
||
"Q 2819 4750 3233 4129 \n",
|
||
"Q 3647 3509 3647 2328 \n",
|
||
"Q 3647 1150 3233 529 \n",
|
||
"Q 2819 -91 2034 -91 \n",
|
||
"Q 1250 -91 836 529 \n",
|
||
"Q 422 1150 422 2328 \n",
|
||
"Q 422 3509 836 4129 \n",
|
||
"Q 1250 4750 2034 4750 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-35\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"xtick_2\">\n",
|
||
" <g id=\"line2d_3\">\n",
|
||
" <path d=\"M 110.754095 143.1 \n",
|
||
"L 110.754095 7.2 \n",
|
||
"\" clip-path=\"url(#p97e5e7769b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_4\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#m011d5a03dd\" x=\"110.754095\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_2\">\n",
|
||
" <!-- 100 -->\n",
|
||
" <g transform=\"translate(101.210345 157.698438)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
|
||
"L 1825 531 \n",
|
||
"L 1825 4091 \n",
|
||
"L 703 3866 \n",
|
||
"L 703 4441 \n",
|
||
"L 1819 4666 \n",
|
||
"L 2450 4666 \n",
|
||
"L 2450 531 \n",
|
||
"L 3481 531 \n",
|
||
"L 3481 0 \n",
|
||
"L 794 0 \n",
|
||
"L 794 531 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"xtick_3\">\n",
|
||
" <g id=\"line2d_5\">\n",
|
||
" <path d=\"M 144.426509 143.1 \n",
|
||
"L 144.426509 7.2 \n",
|
||
"\" clip-path=\"url(#p97e5e7769b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_6\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#m011d5a03dd\" x=\"144.426509\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_3\">\n",
|
||
" <!-- 150 -->\n",
|
||
" <g transform=\"translate(134.882759 157.698438)scale(0.1 -0.1)\">\n",
|
||
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-35\" x=\"63.623047\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"xtick_4\">\n",
|
||
" <g id=\"line2d_7\">\n",
|
||
" <path d=\"M 178.098922 143.1 \n",
|
||
"L 178.098922 7.2 \n",
|
||
"\" clip-path=\"url(#p97e5e7769b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_8\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#m011d5a03dd\" x=\"178.098922\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_4\">\n",
|
||
" <!-- 200 -->\n",
|
||
" <g transform=\"translate(168.555172 157.698438)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
|
||
"L 3431 531 \n",
|
||
"L 3431 0 \n",
|
||
"L 469 0 \n",
|
||
"L 469 531 \n",
|
||
"Q 828 903 1448 1529 \n",
|
||
"Q 2069 2156 2228 2338 \n",
|
||
"Q 2531 2678 2651 2914 \n",
|
||
"Q 2772 3150 2772 3378 \n",
|
||
"Q 2772 3750 2511 3984 \n",
|
||
"Q 2250 4219 1831 4219 \n",
|
||
"Q 1534 4219 1204 4116 \n",
|
||
"Q 875 4013 500 3803 \n",
|
||
"L 500 4441 \n",
|
||
"Q 881 4594 1212 4672 \n",
|
||
"Q 1544 4750 1819 4750 \n",
|
||
"Q 2544 4750 2975 4387 \n",
|
||
"Q 3406 4025 3406 3419 \n",
|
||
"Q 3406 3131 3298 2873 \n",
|
||
"Q 3191 2616 2906 2266 \n",
|
||
"Q 2828 2175 2409 1742 \n",
|
||
"Q 1991 1309 1228 531 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"xtick_5\">\n",
|
||
" <g id=\"line2d_9\">\n",
|
||
" <path d=\"M 211.771336 143.1 \n",
|
||
"L 211.771336 7.2 \n",
|
||
"\" clip-path=\"url(#p97e5e7769b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_10\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#m011d5a03dd\" x=\"211.771336\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_5\">\n",
|
||
" <!-- 250 -->\n",
|
||
" <g transform=\"translate(202.227586 157.698438)scale(0.1 -0.1)\">\n",
|
||
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-35\" x=\"63.623047\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"xtick_6\">\n",
|
||
" <g id=\"line2d_11\">\n",
|
||
" <path d=\"M 245.44375 143.1 \n",
|
||
"L 245.44375 7.2 \n",
|
||
"\" clip-path=\"url(#p97e5e7769b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_12\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#m011d5a03dd\" x=\"245.44375\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_6\">\n",
|
||
" <!-- 300 -->\n",
|
||
" <g transform=\"translate(235.9 157.698438)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-33\" d=\"M 2597 2516 \n",
|
||
"Q 3050 2419 3304 2112 \n",
|
||
"Q 3559 1806 3559 1356 \n",
|
||
"Q 3559 666 3084 287 \n",
|
||
"Q 2609 -91 1734 -91 \n",
|
||
"Q 1441 -91 1130 -33 \n",
|
||
"Q 819 25 488 141 \n",
|
||
"L 488 750 \n",
|
||
"Q 750 597 1062 519 \n",
|
||
"Q 1375 441 1716 441 \n",
|
||
"Q 2309 441 2620 675 \n",
|
||
"Q 2931 909 2931 1356 \n",
|
||
"Q 2931 1769 2642 2001 \n",
|
||
"Q 2353 2234 1838 2234 \n",
|
||
"L 1294 2234 \n",
|
||
"L 1294 2753 \n",
|
||
"L 1863 2753 \n",
|
||
"Q 2328 2753 2575 2939 \n",
|
||
"Q 2822 3125 2822 3475 \n",
|
||
"Q 2822 3834 2567 4026 \n",
|
||
"Q 2313 4219 1838 4219 \n",
|
||
"Q 1578 4219 1281 4162 \n",
|
||
"Q 984 4106 628 3988 \n",
|
||
"L 628 4550 \n",
|
||
"Q 988 4650 1302 4700 \n",
|
||
"Q 1616 4750 1894 4750 \n",
|
||
"Q 2613 4750 3031 4423 \n",
|
||
"Q 3450 4097 3450 3541 \n",
|
||
"Q 3450 3153 3228 2886 \n",
|
||
"Q 3006 2619 2597 2516 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-33\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_7\">\n",
|
||
" <!-- epoch -->\n",
|
||
" <g transform=\"translate(132.565625 171.376563)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-65\" d=\"M 3597 1894 \n",
|
||
"L 3597 1613 \n",
|
||
"L 953 1613 \n",
|
||
"Q 991 1019 1311 708 \n",
|
||
"Q 1631 397 2203 397 \n",
|
||
"Q 2534 397 2845 478 \n",
|
||
"Q 3156 559 3463 722 \n",
|
||
"L 3463 178 \n",
|
||
"Q 3153 47 2828 -22 \n",
|
||
"Q 2503 -91 2169 -91 \n",
|
||
"Q 1331 -91 842 396 \n",
|
||
"Q 353 884 353 1716 \n",
|
||
"Q 353 2575 817 3079 \n",
|
||
"Q 1281 3584 2069 3584 \n",
|
||
"Q 2775 3584 3186 3129 \n",
|
||
"Q 3597 2675 3597 1894 \n",
|
||
"z\n",
|
||
"M 3022 2063 \n",
|
||
"Q 3016 2534 2758 2815 \n",
|
||
"Q 2500 3097 2075 3097 \n",
|
||
"Q 1594 3097 1305 2825 \n",
|
||
"Q 1016 2553 972 2059 \n",
|
||
"L 3022 2063 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" <path id=\"DejaVuSans-70\" d=\"M 1159 525 \n",
|
||
"L 1159 -1331 \n",
|
||
"L 581 -1331 \n",
|
||
"L 581 3500 \n",
|
||
"L 1159 3500 \n",
|
||
"L 1159 2969 \n",
|
||
"Q 1341 3281 1617 3432 \n",
|
||
"Q 1894 3584 2278 3584 \n",
|
||
"Q 2916 3584 3314 3078 \n",
|
||
"Q 3713 2572 3713 1747 \n",
|
||
"Q 3713 922 3314 415 \n",
|
||
"Q 2916 -91 2278 -91 \n",
|
||
"Q 1894 -91 1617 61 \n",
|
||
"Q 1341 213 1159 525 \n",
|
||
"z\n",
|
||
"M 3116 1747 \n",
|
||
"Q 3116 2381 2855 2742 \n",
|
||
"Q 2594 3103 2138 3103 \n",
|
||
"Q 1681 3103 1420 2742 \n",
|
||
"Q 1159 2381 1159 1747 \n",
|
||
"Q 1159 1113 1420 752 \n",
|
||
"Q 1681 391 2138 391 \n",
|
||
"Q 2594 391 2855 752 \n",
|
||
"Q 3116 1113 3116 1747 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" <path id=\"DejaVuSans-6f\" d=\"M 1959 3097 \n",
|
||
"Q 1497 3097 1228 2736 \n",
|
||
"Q 959 2375 959 1747 \n",
|
||
"Q 959 1119 1226 758 \n",
|
||
"Q 1494 397 1959 397 \n",
|
||
"Q 2419 397 2687 759 \n",
|
||
"Q 2956 1122 2956 1747 \n",
|
||
"Q 2956 2369 2687 2733 \n",
|
||
"Q 2419 3097 1959 3097 \n",
|
||
"z\n",
|
||
"M 1959 3584 \n",
|
||
"Q 2709 3584 3137 3096 \n",
|
||
"Q 3566 2609 3566 1747 \n",
|
||
"Q 3566 888 3137 398 \n",
|
||
"Q 2709 -91 1959 -91 \n",
|
||
"Q 1206 -91 779 398 \n",
|
||
"Q 353 888 353 1747 \n",
|
||
"Q 353 2609 779 3096 \n",
|
||
"Q 1206 3584 1959 3584 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" <path id=\"DejaVuSans-63\" d=\"M 3122 3366 \n",
|
||
"L 3122 2828 \n",
|
||
"Q 2878 2963 2633 3030 \n",
|
||
"Q 2388 3097 2138 3097 \n",
|
||
"Q 1578 3097 1268 2742 \n",
|
||
"Q 959 2388 959 1747 \n",
|
||
"Q 959 1106 1268 751 \n",
|
||
"Q 1578 397 2138 397 \n",
|
||
"Q 2388 397 2633 464 \n",
|
||
"Q 2878 531 3122 666 \n",
|
||
"L 3122 134 \n",
|
||
"Q 2881 22 2623 -34 \n",
|
||
"Q 2366 -91 2075 -91 \n",
|
||
"Q 1284 -91 818 406 \n",
|
||
"Q 353 903 353 1747 \n",
|
||
"Q 353 2603 823 3093 \n",
|
||
"Q 1294 3584 2113 3584 \n",
|
||
"Q 2378 3584 2631 3529 \n",
|
||
"Q 2884 3475 3122 3366 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" <path id=\"DejaVuSans-68\" d=\"M 3513 2113 \n",
|
||
"L 3513 0 \n",
|
||
"L 2938 0 \n",
|
||
"L 2938 2094 \n",
|
||
"Q 2938 2591 2744 2837 \n",
|
||
"Q 2550 3084 2163 3084 \n",
|
||
"Q 1697 3084 1428 2787 \n",
|
||
"Q 1159 2491 1159 1978 \n",
|
||
"L 1159 0 \n",
|
||
"L 581 0 \n",
|
||
"L 581 4863 \n",
|
||
"L 1159 4863 \n",
|
||
"L 1159 2956 \n",
|
||
"Q 1366 3272 1645 3428 \n",
|
||
"Q 1925 3584 2291 3584 \n",
|
||
"Q 2894 3584 3203 3211 \n",
|
||
"Q 3513 2838 3513 2113 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-65\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-70\" x=\"61.523438\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-6f\" x=\"125\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-63\" x=\"186.181641\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-68\" x=\"241.162109\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"matplotlib.axis_2\">\n",
|
||
" <g id=\"ytick_1\">\n",
|
||
" <g id=\"line2d_13\">\n",
|
||
" <path d=\"M 50.14375 116.334147 \n",
|
||
"L 245.44375 116.334147 \n",
|
||
"\" clip-path=\"url(#p97e5e7769b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_14\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"mb09ffc4a21\" d=\"M 0 0 \n",
|
||
"L -3.5 0 \n",
|
||
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </defs>\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#mb09ffc4a21\" x=\"50.14375\" y=\"116.334147\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_8\">\n",
|
||
" <!-- 0.05 -->\n",
|
||
" <g transform=\"translate(20.878125 120.133366)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-2e\" d=\"M 684 794 \n",
|
||
"L 1344 794 \n",
|
||
"L 1344 0 \n",
|
||
"L 684 0 \n",
|
||
"L 684 794 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"95.410156\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-35\" x=\"159.033203\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"ytick_2\">\n",
|
||
" <g id=\"line2d_15\">\n",
|
||
" <path d=\"M 50.14375 83.069068 \n",
|
||
"L 245.44375 83.069068 \n",
|
||
"\" clip-path=\"url(#p97e5e7769b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_16\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#mb09ffc4a21\" x=\"50.14375\" y=\"83.069068\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_9\">\n",
|
||
" <!-- 0.10 -->\n",
|
||
" <g transform=\"translate(20.878125 86.868287)scale(0.1 -0.1)\">\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-31\" x=\"95.410156\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"159.033203\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"ytick_3\">\n",
|
||
" <g id=\"line2d_17\">\n",
|
||
" <path d=\"M 50.14375 49.803989 \n",
|
||
"L 245.44375 49.803989 \n",
|
||
"\" clip-path=\"url(#p97e5e7769b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_18\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#mb09ffc4a21\" x=\"50.14375\" y=\"49.803989\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_10\">\n",
|
||
" <!-- 0.15 -->\n",
|
||
" <g transform=\"translate(20.878125 53.603208)scale(0.1 -0.1)\">\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-31\" x=\"95.410156\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-35\" x=\"159.033203\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"ytick_4\">\n",
|
||
" <g id=\"line2d_19\">\n",
|
||
" <path d=\"M 50.14375 16.53891 \n",
|
||
"L 245.44375 16.53891 \n",
|
||
"\" clip-path=\"url(#p97e5e7769b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_20\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#mb09ffc4a21\" x=\"50.14375\" y=\"16.53891\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_11\">\n",
|
||
" <!-- 0.20 -->\n",
|
||
" <g transform=\"translate(20.878125 20.338128)scale(0.1 -0.1)\">\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-32\" x=\"95.410156\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"159.033203\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_12\">\n",
|
||
" <!-- loss -->\n",
|
||
" <g transform=\"translate(14.798437 84.807812)rotate(-90)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-6c\" d=\"M 603 4863 \n",
|
||
"L 1178 4863 \n",
|
||
"L 1178 0 \n",
|
||
"L 603 0 \n",
|
||
"L 603 4863 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" <path id=\"DejaVuSans-73\" d=\"M 2834 3397 \n",
|
||
"L 2834 2853 \n",
|
||
"Q 2591 2978 2328 3040 \n",
|
||
"Q 2066 3103 1784 3103 \n",
|
||
"Q 1356 3103 1142 2972 \n",
|
||
"Q 928 2841 928 2578 \n",
|
||
"Q 928 2378 1081 2264 \n",
|
||
"Q 1234 2150 1697 2047 \n",
|
||
"L 1894 2003 \n",
|
||
"Q 2506 1872 2764 1633 \n",
|
||
"Q 3022 1394 3022 966 \n",
|
||
"Q 3022 478 2636 193 \n",
|
||
"Q 2250 -91 1575 -91 \n",
|
||
"Q 1294 -91 989 -36 \n",
|
||
"Q 684 19 347 128 \n",
|
||
"L 347 722 \n",
|
||
"Q 666 556 975 473 \n",
|
||
"Q 1284 391 1588 391 \n",
|
||
"Q 1994 391 2212 530 \n",
|
||
"Q 2431 669 2431 922 \n",
|
||
"Q 2431 1156 2273 1281 \n",
|
||
"Q 2116 1406 1581 1522 \n",
|
||
"L 1381 1569 \n",
|
||
"Q 847 1681 609 1914 \n",
|
||
"Q 372 2147 372 2553 \n",
|
||
"Q 372 3047 722 3315 \n",
|
||
"Q 1072 3584 1716 3584 \n",
|
||
"Q 2034 3584 2315 3537 \n",
|
||
"Q 2597 3491 2834 3397 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-6c\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-6f\" x=\"27.783203\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-73\" x=\"88.964844\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-73\" x=\"141.064453\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_21\">\n",
|
||
" <path d=\"M 50.14375 13.377273 \n",
|
||
"L 56.878233 53.801248 \n",
|
||
"L 63.612716 76.94721 \n",
|
||
"L 70.347198 93.478629 \n",
|
||
"L 77.081681 105.328161 \n",
|
||
"L 83.816164 113.311737 \n",
|
||
"L 90.550647 118.804818 \n",
|
||
"L 97.285129 122.329577 \n",
|
||
"L 104.019612 125.712693 \n",
|
||
"L 110.754095 127.561094 \n",
|
||
"L 117.488578 129.463581 \n",
|
||
"L 124.22306 130.93093 \n",
|
||
"L 130.957543 131.820335 \n",
|
||
"L 137.692026 132.84396 \n",
|
||
"L 144.426509 133.235857 \n",
|
||
"L 151.160991 133.933191 \n",
|
||
"L 157.895474 134.564336 \n",
|
||
"L 164.629957 134.964614 \n",
|
||
"L 171.36444 135.063178 \n",
|
||
"L 178.098922 135.718194 \n",
|
||
"L 184.833405 136.190944 \n",
|
||
"L 191.567888 136.099791 \n",
|
||
"L 198.302371 136.272226 \n",
|
||
"L 205.036853 136.148948 \n",
|
||
"L 211.771336 135.969755 \n",
|
||
"L 218.505819 136.56661 \n",
|
||
"L 225.240302 136.514543 \n",
|
||
"L 231.974784 136.690111 \n",
|
||
"L 238.709267 136.84549 \n",
|
||
"L 245.44375 136.922727 \n",
|
||
"\" clip-path=\"url(#p97e5e7769b)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"patch_3\">\n",
|
||
" <path d=\"M 50.14375 143.1 \n",
|
||
"L 50.14375 7.2 \n",
|
||
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"patch_4\">\n",
|
||
" <path d=\"M 245.44375 143.1 \n",
|
||
"L 245.44375 7.2 \n",
|
||
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"patch_5\">\n",
|
||
" <path d=\"M 50.14375 143.1 \n",
|
||
"L 245.44375 143.1 \n",
|
||
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"patch_6\">\n",
|
||
" <path d=\"M 50.14375 7.2 \n",
|
||
"L 245.44375 7.2 \n",
|
||
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <defs>\n",
|
||
" <clipPath id=\"p97e5e7769b\">\n",
|
||
" <rect x=\"50.14375\" y=\"7.2\" width=\"195.3\" height=\"135.9\"/>\n",
|
||
" </clipPath>\n",
|
||
" </defs>\n",
|
||
"</svg>\n"
|
||
],
|
||
"text/plain": [
|
||
"<Figure size 252x180 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {
|
||
"needs_background": "light"
|
||
},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1\n",
|
||
"batch_size, num_steps = 64, 10\n",
|
||
"lr, num_epochs, device = 0.005, 300, d2l.try_gpu()\n",
|
||
"\n",
|
||
"train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)\n",
|
||
"encoder = Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers,\n",
|
||
" dropout)\n",
|
||
"decoder = Seq2SeqDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers,\n",
|
||
" dropout)\n",
|
||
"net = d2l.EncoderDecoder(encoder, decoder)\n",
|
||
"train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "b24c26df",
|
||
"metadata": {
|
||
"origin_pos": 56
|
||
},
|
||
"source": [
|
||
"## [**预测**]\n",
|
||
"\n",
|
||
"为了采用一个接着一个词元的方式预测输出序列,\n",
|
||
"每个解码器当前时间步的输入都将来自于前一时间步的预测词元。\n",
|
||
"与训练类似,序列开始词元(“<bos>”)\n",
|
||
"在初始时间步被输入到解码器中。\n",
|
||
"该预测过程如 :numref:`fig_seq2seq_predict`所示,\n",
|
||
"当输出序列的预测遇到序列结束词元(“<eos>”)时,预测就结束了。\n",
|
||
"\n",
|
||
"\n",
|
||
":label:`fig_seq2seq_predict`\n",
|
||
"\n",
|
||
"我们将在 :numref:`sec_beam-search`中介绍不同的序列生成策略。\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"id": "7510bee7",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:16:11.771151Z",
|
||
"iopub.status.busy": "2023-08-18T07:16:11.770496Z",
|
||
"iopub.status.idle": "2023-08-18T07:16:11.779631Z",
|
||
"shell.execute_reply": "2023-08-18T07:16:11.778678Z"
|
||
},
|
||
"origin_pos": 58,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"#@save\n",
|
||
"def predict_seq2seq(net, src_sentence, src_vocab, tgt_vocab, num_steps,\n",
|
||
" device, save_attention_weights=False):\n",
|
||
" \"\"\"序列到序列模型的预测\"\"\"\n",
|
||
" # 在预测时将net设置为评估模式\n",
|
||
" net.eval()\n",
|
||
" src_tokens = src_vocab[src_sentence.lower().split(' ')] + [\n",
|
||
" src_vocab['<eos>']]\n",
|
||
" enc_valid_len = torch.tensor([len(src_tokens)], device=device)\n",
|
||
" src_tokens = d2l.truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])\n",
|
||
" # 添加批量轴\n",
|
||
" enc_X = torch.unsqueeze(\n",
|
||
" torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)\n",
|
||
" enc_outputs = net.encoder(enc_X, enc_valid_len)\n",
|
||
" dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)\n",
|
||
" # 添加批量轴\n",
|
||
" dec_X = torch.unsqueeze(torch.tensor(\n",
|
||
" [tgt_vocab['<bos>']], dtype=torch.long, device=device), dim=0)\n",
|
||
" output_seq, attention_weight_seq = [], []\n",
|
||
" for _ in range(num_steps):\n",
|
||
" Y, dec_state = net.decoder(dec_X, dec_state)\n",
|
||
" # 我们使用具有预测最高可能性的词元,作为解码器在下一时间步的输入\n",
|
||
" dec_X = Y.argmax(dim=2)\n",
|
||
" pred = dec_X.squeeze(dim=0).type(torch.int32).item()\n",
|
||
" # 保存注意力权重(稍后讨论)\n",
|
||
" if save_attention_weights:\n",
|
||
" attention_weight_seq.append(net.decoder.attention_weights)\n",
|
||
" # 一旦序列结束词元被预测,输出序列的生成就完成了\n",
|
||
" if pred == tgt_vocab['<eos>']:\n",
|
||
" break\n",
|
||
" output_seq.append(pred)\n",
|
||
" return ' '.join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "71773ad1",
|
||
"metadata": {
|
||
"origin_pos": 61
|
||
},
|
||
"source": [
|
||
"## 预测序列的评估\n",
|
||
"\n",
|
||
"我们可以通过与真实的标签序列进行比较来评估预测序列。\n",
|
||
"虽然 :cite:`Papineni.Roukos.Ward.ea.2002`\n",
|
||
"提出的BLEU(bilingual evaluation understudy)\n",
|
||
"最先是用于评估机器翻译的结果,\n",
|
||
"但现在它已经被广泛用于测量许多应用的输出序列的质量。\n",
|
||
"原则上说,对于预测序列中的任意$n$元语法(n-grams),\n",
|
||
"BLEU的评估都是这个$n$元语法是否出现在标签序列中。\n",
|
||
"\n",
|
||
"我们将BLEU定义为:\n",
|
||
"\n",
|
||
"$$ \\exp\\left(\\min\\left(0, 1 - \\frac{\\mathrm{len}_{\\text{label}}}{\\mathrm{len}_{\\text{pred}}}\\right)\\right) \\prod_{n=1}^k p_n^{1/2^n},$$\n",
|
||
":eqlabel:`eq_bleu`\n",
|
||
"\n",
|
||
"其中$\\mathrm{len}_{\\text{label}}$表示标签序列中的词元数和\n",
|
||
"$\\mathrm{len}_{\\text{pred}}$表示预测序列中的词元数,\n",
|
||
"$k$是用于匹配的最长的$n$元语法。\n",
|
||
"另外,用$p_n$表示$n$元语法的精确度,它是两个数量的比值:\n",
|
||
"第一个是预测序列与标签序列中匹配的$n$元语法的数量,\n",
|
||
"第二个是预测序列中$n$元语法的数量的比率。\n",
|
||
"具体地说,给定标签序列$A$、$B$、$C$、$D$、$E$、$F$\n",
|
||
"和预测序列$A$、$B$、$B$、$C$、$D$,\n",
|
||
"我们有$p_1 = 4/5$、$p_2 = 3/4$、$p_3 = 1/3$和$p_4 = 0$。\n",
|
||
"\n",
|
||
"根据 :eqref:`eq_bleu`中BLEU的定义,\n",
|
||
"当预测序列与标签序列完全相同时,BLEU为$1$。\n",
|
||
"此外,由于$n$元语法越长则匹配难度越大,\n",
|
||
"所以BLEU为更长的$n$元语法的精确度分配更大的权重。\n",
|
||
"具体来说,当$p_n$固定时,$p_n^{1/2^n}$\n",
|
||
"会随着$n$的增长而增加(原始论文使用$p_n^{1/n}$)。\n",
|
||
"而且,由于预测的序列越短获得的$p_n$值越高,\n",
|
||
"所以 :eqref:`eq_bleu`中乘法项之前的系数用于惩罚较短的预测序列。\n",
|
||
"例如,当$k=2$时,给定标签序列$A$、$B$、$C$、$D$、$E$、$F$\n",
|
||
"和预测序列$A$、$B$,尽管$p_1 = p_2 = 1$,\n",
|
||
"惩罚因子$\\exp(1-6/2) \\approx 0.14$会降低BLEU。\n",
|
||
"\n",
|
||
"[**BLEU的代码实现**]如下。\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"id": "9135ade0",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:16:11.784109Z",
|
||
"iopub.status.busy": "2023-08-18T07:16:11.783827Z",
|
||
"iopub.status.idle": "2023-08-18T07:16:11.791568Z",
|
||
"shell.execute_reply": "2023-08-18T07:16:11.790396Z"
|
||
},
|
||
"origin_pos": 62,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"def bleu(pred_seq, label_seq, k): #@save\n",
|
||
" \"\"\"计算BLEU\"\"\"\n",
|
||
" pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')\n",
|
||
" len_pred, len_label = len(pred_tokens), len(label_tokens)\n",
|
||
" score = math.exp(min(0, 1 - len_label / len_pred))\n",
|
||
" for n in range(1, k + 1):\n",
|
||
" num_matches, label_subs = 0, collections.defaultdict(int)\n",
|
||
" for i in range(len_label - n + 1):\n",
|
||
" label_subs[' '.join(label_tokens[i: i + n])] += 1\n",
|
||
" for i in range(len_pred - n + 1):\n",
|
||
" if label_subs[' '.join(pred_tokens[i: i + n])] > 0:\n",
|
||
" num_matches += 1\n",
|
||
" label_subs[' '.join(pred_tokens[i: i + n])] -= 1\n",
|
||
" score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))\n",
|
||
" return score"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "16c57898",
|
||
"metadata": {
|
||
"origin_pos": 63
|
||
},
|
||
"source": [
|
||
"最后,利用训练好的循环神经网络“编码器-解码器”模型,\n",
|
||
"[**将几个英语句子翻译成法语**],并计算BLEU的最终结果。\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"id": "653f0dd4",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:16:11.796025Z",
|
||
"iopub.status.busy": "2023-08-18T07:16:11.795107Z",
|
||
"iopub.status.idle": "2023-08-18T07:16:11.818936Z",
|
||
"shell.execute_reply": "2023-08-18T07:16:11.817788Z"
|
||
},
|
||
"origin_pos": 64,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"go . => va !, bleu 1.000\n",
|
||
"i lost . => j'ai perdu ., bleu 1.000\n",
|
||
"he's calm . => il est riche ., bleu 0.658\n",
|
||
"i'm home . => je suis en retard ?, bleu 0.447\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"engs = ['go .', \"i lost .\", 'he\\'s calm .', 'i\\'m home .']\n",
|
||
"fras = ['va !', 'j\\'ai perdu .', 'il est calme .', 'je suis chez moi .']\n",
|
||
"for eng, fra in zip(engs, fras):\n",
|
||
" translation, attention_weight_seq = predict_seq2seq(\n",
|
||
" net, eng, src_vocab, tgt_vocab, num_steps, device)\n",
|
||
" print(f'{eng} => {translation}, bleu {bleu(translation, fra, k=2):.3f}')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "2448f426",
|
||
"metadata": {
|
||
"origin_pos": 66
|
||
},
|
||
"source": [
|
||
"## 小结\n",
|
||
"\n",
|
||
"* 根据“编码器-解码器”架构的设计,\n",
|
||
" 我们可以使用两个循环神经网络来设计一个序列到序列学习的模型。\n",
|
||
"* 在实现编码器和解码器时,我们可以使用多层循环神经网络。\n",
|
||
"* 我们可以使用遮蔽来过滤不相关的计算,例如在计算损失时。\n",
|
||
"* 在“编码器-解码器”训练中,强制教学方法将原始输出序列(而非预测结果)输入解码器。\n",
|
||
"* BLEU是一种常用的评估方法,它通过测量预测序列和标签序列之间的$n$元语法的匹配度来评估预测。\n",
|
||
"\n",
|
||
"## 练习\n",
|
||
"\n",
|
||
"1. 试着通过调整超参数来改善翻译效果。\n",
|
||
"1. 重新运行实验并在计算损失时不使用遮蔽,可以观察到什么结果?为什么会有这个结果?\n",
|
||
"1. 如果编码器和解码器的层数或者隐藏单元数不同,那么如何初始化解码器的隐状态?\n",
|
||
"1. 在训练中,如果用前一时间步的预测输入到解码器来代替强制教学,对性能有何影响?\n",
|
||
"1. 用长短期记忆网络替换门控循环单元重新运行实验。\n",
|
||
"1. 有没有其他方法来设计解码器的输出层?\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "89b706ef",
|
||
"metadata": {
|
||
"origin_pos": 68,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"source": [
|
||
"[Discussions](https://discuss.d2l.ai/t/2782)\n"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"language_info": {
|
||
"name": "python"
|
||
},
|
||
"required_libs": []
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
} |