{ "cells": [ { "cell_type": "markdown", "id": "ad81766b", "metadata": { "origin_pos": 0 }, "source": [ "# Transformer\n", ":label:`sec_transformer`\n", "\n", " :numref:`subsec_cnn-rnn-self-attention`中比较了卷积神经网络(CNN)、循环神经网络(RNN)和自注意力(self-attention)。值得注意的是,自注意力同时具有并行计算和最短的最大路径长度这两个优势。因此,使用自注意力来设计深度架构是很有吸引力的。对比之前仍然依赖循环神经网络实现输入表示的自注意力模型 :cite:`Cheng.Dong.Lapata.2016,Lin.Feng.Santos.ea.2017,Paulus.Xiong.Socher.2017`,Transformer模型完全基于注意力机制,没有任何卷积层或循环神经网络层 :cite:`Vaswani.Shazeer.Parmar.ea.2017`。尽管Transformer最初是应用于在文本数据上的序列到序列学习,但现在已经推广到各种现代的深度学习中,例如语言、视觉、语音和强化学习领域。\n", "\n", "## 模型\n", "\n", "Transformer作为编码器-解码器架构的一个实例,其整体架构图在 :numref:`fig_transformer`中展示。正如所见到的,Transformer是由编码器和解码器组成的。与 :numref:`fig_s2s_attention_details`中基于Bahdanau注意力实现的序列到序列的学习相比,Transformer的编码器和解码器是基于自注意力的模块叠加而成的,源(输入)序列和目标(输出)序列的*嵌入*(embedding)表示将加上*位置编码*(positional encoding),再分别输入到编码器和解码器中。\n", "\n", "![transformer架构](../img/transformer.svg)\n", ":width:`500px`\n", ":label:`fig_transformer`\n", "\n", "图 :numref:`fig_transformer`中概述了Transformer的架构。从宏观角度来看,Transformer的编码器是由多个相同的层叠加而成的,每个层都有两个子层(子层表示为$\\mathrm{sublayer}$)。第一个子层是*多头自注意力*(multi-head self-attention)汇聚;第二个子层是*基于位置的前馈网络*(positionwise feed-forward network)。具体来说,在计算编码器的自注意力时,查询、键和值都来自前一个编码器层的输出。受 :numref:`sec_resnet`中残差网络的启发,每个子层都采用了*残差连接*(residual connection)。在Transformer中,对于序列中任何位置的任何输入$\\mathbf{x} \\in \\mathbb{R}^d$,都要求满足$\\mathrm{sublayer}(\\mathbf{x}) \\in \\mathbb{R}^d$,以便残差连接满足$\\mathbf{x} + \\mathrm{sublayer}(\\mathbf{x}) \\in \\mathbb{R}^d$。在残差连接的加法计算之后,紧接着应用*层规范化*(layer normalization) :cite:`Ba.Kiros.Hinton.2016`。因此,输入序列对应的每个位置,Transformer编码器都将输出一个$d$维表示向量。\n", "\n", "Transformer解码器也是由多个相同的层叠加而成的,并且层中使用了残差连接和层规范化。除了编码器中描述的两个子层之外,解码器还在这两个子层之间插入了第三个子层,称为*编码器-解码器注意力*(encoder-decoder attention)层。在编码器-解码器注意力中,查询来自前一个解码器层的输出,而键和值来自整个编码器的输出。在解码器自注意力中,查询、键和值都来自上一个解码器层的输出。但是,解码器中的每个位置只能考虑该位置之前的所有位置。这种*掩蔽*(masked)注意力保留了*自回归*(auto-regressive)属性,确保预测仅依赖于已生成的输出词元。\n", "\n", "在此之前已经描述并实现了基于缩放点积多头注意力 :numref:`sec_multihead-attention`和位置编码 :numref:`subsec_positional-encoding`。接下来将实现Transformer模型的剩余部分。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "9309c9bd", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:21:02.524217Z", "iopub.status.busy": "2023-08-18T07:21:02.523886Z", "iopub.status.idle": "2023-08-18T07:21:05.481998Z", "shell.execute_reply": "2023-08-18T07:21:05.481087Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import math\n", "import pandas as pd\n", "import torch\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "660c2057", "metadata": { "origin_pos": 5 }, "source": [ "## [**基于位置的前馈网络**]\n", "\n", "基于位置的前馈网络对序列中的所有位置的表示进行变换时使用的是同一个多层感知机(MLP),这就是称前馈网络是*基于位置的*(positionwise)的原因。在下面的实现中,输入`X`的形状(批量大小,时间步数或序列长度,隐单元数或特征维度)将被一个两层的感知机转换成形状为(批量大小,时间步数,`ffn_num_outputs`)的输出张量。\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "5a030ddc", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:21:05.487583Z", "iopub.status.busy": "2023-08-18T07:21:05.486906Z", "iopub.status.idle": "2023-08-18T07:21:05.492826Z", "shell.execute_reply": "2023-08-18T07:21:05.492063Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "class PositionWiseFFN(nn.Module):\n", " \"\"\"基于位置的前馈网络\"\"\"\n", " def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,\n", " **kwargs):\n", " super(PositionWiseFFN, self).__init__(**kwargs)\n", " self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)\n", " self.relu = nn.ReLU()\n", " self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)\n", "\n", " def forward(self, X):\n", " return self.dense2(self.relu(self.dense1(X)))" ] }, { "cell_type": "markdown", "id": "d5d02e37", "metadata": { "origin_pos": 10 }, "source": [ "下面的例子显示,[**改变张量的最里层维度的尺寸**],会改变成基于位置的前馈网络的输出尺寸。因为用同一个多层感知机对所有位置上的输入进行变换,所以当所有这些位置的输入相同时,它们的输出也是相同的。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "be51755b", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:21:05.498028Z", "iopub.status.busy": "2023-08-18T07:21:05.497721Z", "iopub.status.idle": "2023-08-18T07:21:05.527774Z", "shell.execute_reply": "2023-08-18T07:21:05.526996Z" }, "origin_pos": 12, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.8290, 1.0067, 0.3619, 0.3594, -0.5328, 0.2712, 0.7394, 0.0747],\n", " [-0.8290, 1.0067, 0.3619, 0.3594, -0.5328, 0.2712, 0.7394, 0.0747],\n", " [-0.8290, 1.0067, 0.3619, 0.3594, -0.5328, 0.2712, 0.7394, 0.0747]],\n", " grad_fn=)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ffn = PositionWiseFFN(4, 4, 8)\n", "ffn.eval()\n", "ffn(torch.ones((2, 3, 4)))[0]" ] }, { "cell_type": "markdown", "id": "3b4b7f94", "metadata": { "origin_pos": 15 }, "source": [ "## 残差连接和层规范化\n", "\n", "现在让我们关注 :numref:`fig_transformer`中的*加法和规范化*(add&norm)组件。正如在本节开头所述,这是由残差连接和紧随其后的层规范化组成的。两者都是构建有效的深度架构的关键。\n", "\n", " :numref:`sec_batch_norm`中解释了在一个小批量的样本内基于批量规范化对数据进行重新中心化和重新缩放的调整。层规范化和批量规范化的目标相同,但层规范化是基于特征维度进行规范化。尽管批量规范化在计算机视觉中被广泛应用,但在自然语言处理任务中(输入通常是变长序列)批量规范化通常不如层规范化的效果好。\n", "\n", "以下代码[**对比不同维度的层规范化和批量规范化的效果**]。\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "1da6d2d6", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:21:05.531626Z", "iopub.status.busy": "2023-08-18T07:21:05.531338Z", "iopub.status.idle": "2023-08-18T07:21:05.541716Z", "shell.execute_reply": "2023-08-18T07:21:05.540694Z" }, "origin_pos": 17, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "layer norm: tensor([[-1.0000, 1.0000],\n", " [-1.0000, 1.0000]], grad_fn=) \n", "batch norm: tensor([[-1.0000, -1.0000],\n", " [ 1.0000, 1.0000]], grad_fn=)\n" ] } ], "source": [ "ln = nn.LayerNorm(2)\n", "bn = nn.BatchNorm1d(2)\n", "X = torch.tensor([[1, 2], [2, 3]], dtype=torch.float32)\n", "# 在训练模式下计算X的均值和方差\n", "print('layer norm:', ln(X), '\\nbatch norm:', bn(X))" ] }, { "cell_type": "markdown", "id": "a7438b9a", "metadata": { "origin_pos": 20 }, "source": [ "现在可以[**使用残差连接和层规范化**]来实现`AddNorm`类。暂退法也被作为正则化方法使用。\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "41dc88c3", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:21:05.547230Z", "iopub.status.busy": "2023-08-18T07:21:05.546508Z", "iopub.status.idle": "2023-08-18T07:21:05.553654Z", "shell.execute_reply": "2023-08-18T07:21:05.552629Z" }, "origin_pos": 22, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "class AddNorm(nn.Module):\n", " \"\"\"残差连接后进行层规范化\"\"\"\n", " def __init__(self, normalized_shape, dropout, **kwargs):\n", " super(AddNorm, self).__init__(**kwargs)\n", " self.dropout = nn.Dropout(dropout)\n", " self.ln = nn.LayerNorm(normalized_shape)\n", "\n", " def forward(self, X, Y):\n", " return self.ln(self.dropout(Y) + X)" ] }, { "cell_type": "markdown", "id": "17336160", "metadata": { "origin_pos": 25 }, "source": [ "残差连接要求两个输入的形状相同,以便[**加法操作后输出张量的形状相同**]。\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "9ddbf972", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:21:05.559356Z", "iopub.status.busy": "2023-08-18T07:21:05.558397Z", "iopub.status.idle": "2023-08-18T07:21:05.568412Z", "shell.execute_reply": "2023-08-18T07:21:05.567252Z" }, "origin_pos": 27, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 3, 4])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "add_norm = AddNorm([3, 4], 0.5)\n", "add_norm.eval()\n", "add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape" ] }, { "cell_type": "markdown", "id": "af623091", "metadata": { "origin_pos": 29 }, "source": [ "## 编码器\n", "\n", "有了组成Transformer编码器的基础组件,现在可以先[**实现编码器中的一个层**]。下面的`EncoderBlock`类包含两个子层:多头自注意力和基于位置的前馈网络,这两个子层都使用了残差连接和紧随的层规范化。\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "3a4bae18", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:21:05.572991Z", "iopub.status.busy": "2023-08-18T07:21:05.572138Z", "iopub.status.idle": "2023-08-18T07:21:05.581937Z", "shell.execute_reply": "2023-08-18T07:21:05.580861Z" }, "origin_pos": 31, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "class EncoderBlock(nn.Module):\n", " \"\"\"Transformer编码器块\"\"\"\n", " def __init__(self, key_size, query_size, value_size, num_hiddens,\n", " norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,\n", " dropout, use_bias=False, **kwargs):\n", " super(EncoderBlock, self).__init__(**kwargs)\n", " self.attention = d2l.MultiHeadAttention(\n", " key_size, query_size, value_size, num_hiddens, num_heads, dropout,\n", " use_bias)\n", " self.addnorm1 = AddNorm(norm_shape, dropout)\n", " self.ffn = PositionWiseFFN(\n", " ffn_num_input, ffn_num_hiddens, num_hiddens)\n", " self.addnorm2 = AddNorm(norm_shape, dropout)\n", "\n", " def forward(self, X, valid_lens):\n", " Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))\n", " return self.addnorm2(Y, self.ffn(Y))" ] }, { "cell_type": "markdown", "id": "683dbd77", "metadata": { "origin_pos": 34 }, "source": [ "正如从代码中所看到的,[**Transformer编码器中的任何层都不会改变其输入的形状**]。\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "f55eefee", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:21:05.586090Z", "iopub.status.busy": "2023-08-18T07:21:05.585452Z", "iopub.status.idle": "2023-08-18T07:21:05.601167Z", "shell.execute_reply": "2023-08-18T07:21:05.599824Z" }, "origin_pos": 36, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 100, 24])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X = torch.ones((2, 100, 24))\n", "valid_lens = torch.tensor([3, 2])\n", "encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5)\n", "encoder_blk.eval()\n", "encoder_blk(X, valid_lens).shape" ] }, { "cell_type": "markdown", "id": "cde53f8e", "metadata": { "origin_pos": 38 }, "source": [ "下面实现的[**Transformer编码器**]的代码中,堆叠了`num_layers`个`EncoderBlock`类的实例。由于这里使用的是值范围在$-1$和$1$之间的固定位置编码,因此通过学习得到的输入的嵌入表示的值需要先乘以嵌入维度的平方根进行重新缩放,然后再与位置编码相加。\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "94c5fb5d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:21:05.605447Z", "iopub.status.busy": "2023-08-18T07:21:05.604941Z", "iopub.status.idle": "2023-08-18T07:21:05.618289Z", "shell.execute_reply": "2023-08-18T07:21:05.616871Z" }, "origin_pos": 40, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "class TransformerEncoder(d2l.Encoder):\n", " \"\"\"Transformer编码器\"\"\"\n", " def __init__(self, vocab_size, key_size, query_size, value_size,\n", " num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,\n", " num_heads, num_layers, dropout, use_bias=False, **kwargs):\n", " super(TransformerEncoder, self).__init__(**kwargs)\n", " self.num_hiddens = num_hiddens\n", " self.embedding = nn.Embedding(vocab_size, num_hiddens)\n", " self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)\n", " self.blks = nn.Sequential()\n", " for i in range(num_layers):\n", " self.blks.add_module(\"block\"+str(i),\n", " EncoderBlock(key_size, query_size, value_size, num_hiddens,\n", " norm_shape, ffn_num_input, ffn_num_hiddens,\n", " num_heads, dropout, use_bias))\n", "\n", " def forward(self, X, valid_lens, *args):\n", " # 因为位置编码值在-1和1之间,\n", " # 因此嵌入值乘以嵌入维度的平方根进行缩放,\n", " # 然后再与位置编码相加。\n", " X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))\n", " self.attention_weights = [None] * len(self.blks)\n", " for i, blk in enumerate(self.blks):\n", " X = blk(X, valid_lens)\n", " self.attention_weights[\n", " i] = blk.attention.attention.attention_weights\n", " return X" ] }, { "cell_type": "markdown", "id": "a0025bec", "metadata": { "origin_pos": 43 }, "source": [ "下面我们指定了超参数来[**创建一个两层的Transformer编码器**]。\n", "Transformer编码器输出的形状是(批量大小,时间步数目,`num_hiddens`)。\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "b16ed63f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:21:05.624985Z", "iopub.status.busy": "2023-08-18T07:21:05.623935Z", "iopub.status.idle": "2023-08-18T07:21:05.643729Z", "shell.execute_reply": "2023-08-18T07:21:05.642680Z" }, "origin_pos": 45, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 100, 24])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "encoder = TransformerEncoder(\n", " 200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)\n", "encoder.eval()\n", "encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape" ] }, { "cell_type": "markdown", "id": "db537625", "metadata": { "origin_pos": 48 }, "source": [ "## 解码器\n", "\n", "如 :numref:`fig_transformer`所示,[**Transformer解码器也是由多个相同的层组成**]。在`DecoderBlock`类中实现的每个层包含了三个子层:解码器自注意力、“编码器-解码器”注意力和基于位置的前馈网络。这些子层也都被残差连接和紧随的层规范化围绕。\n", "\n", "正如在本节前面所述,在掩蔽多头解码器自注意力层(第一个子层)中,查询、键和值都来自上一个解码器层的输出。关于*序列到序列模型*(sequence-to-sequence model),在训练阶段,其输出序列的所有位置(时间步)的词元都是已知的;然而,在预测阶段,其输出序列的词元是逐个生成的。因此,在任何解码器时间步中,只有生成的词元才能用于解码器的自注意力计算中。为了在解码器中保留自回归的属性,其掩蔽自注意力设定了参数`dec_valid_lens`,以便任何查询都只会与解码器中所有已经生成词元的位置(即直到该查询位置为止)进行注意力计算。\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "b97dbb3a", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:21:05.648104Z", "iopub.status.busy": "2023-08-18T07:21:05.647725Z", "iopub.status.idle": "2023-08-18T07:21:05.662950Z", "shell.execute_reply": "2023-08-18T07:21:05.661933Z" }, "origin_pos": 50, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class DecoderBlock(nn.Module):\n", " \"\"\"解码器中第i个块\"\"\"\n", " def __init__(self, key_size, query_size, value_size, num_hiddens,\n", " norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,\n", " dropout, i, **kwargs):\n", " super(DecoderBlock, self).__init__(**kwargs)\n", " self.i = i\n", " self.attention1 = d2l.MultiHeadAttention(\n", " key_size, query_size, value_size, num_hiddens, num_heads, dropout)\n", " self.addnorm1 = AddNorm(norm_shape, dropout)\n", " self.attention2 = d2l.MultiHeadAttention(\n", " key_size, query_size, value_size, num_hiddens, num_heads, dropout)\n", " self.addnorm2 = AddNorm(norm_shape, dropout)\n", " self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,\n", " num_hiddens)\n", " self.addnorm3 = AddNorm(norm_shape, dropout)\n", "\n", " def forward(self, X, state):\n", " enc_outputs, enc_valid_lens = state[0], state[1]\n", " # 训练阶段,输出序列的所有词元都在同一时间处理,\n", " # 因此state[2][self.i]初始化为None。\n", " # 预测阶段,输出序列是通过词元一个接着一个解码的,\n", " # 因此state[2][self.i]包含着直到当前时间步第i个块解码的输出表示\n", " if state[2][self.i] is None:\n", " key_values = X\n", " else:\n", " key_values = torch.cat((state[2][self.i], X), axis=1)\n", " state[2][self.i] = key_values\n", " if self.training:\n", " batch_size, num_steps, _ = X.shape\n", " # dec_valid_lens的开头:(batch_size,num_steps),\n", " # 其中每一行是[1,2,...,num_steps]\n", " dec_valid_lens = torch.arange(\n", " 1, num_steps + 1, device=X.device).repeat(batch_size, 1)\n", " else:\n", " dec_valid_lens = None\n", "\n", " # 自注意力\n", " X2 = self.attention1(X, key_values, key_values, dec_valid_lens)\n", " Y = self.addnorm1(X, X2)\n", " # 编码器-解码器注意力。\n", " # enc_outputs的开头:(batch_size,num_steps,num_hiddens)\n", " Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)\n", " Z = self.addnorm2(Y, Y2)\n", " return self.addnorm3(Z, self.ffn(Z)), state" ] }, { "cell_type": "markdown", "id": "860e7d0c", "metadata": { "origin_pos": 53 }, "source": [ "为了便于在“编码器-解码器”注意力中进行缩放点积计算和残差连接中进行加法计算,[**编码器和解码器的特征维度都是`num_hiddens`。**]\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "92119d80", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:21:05.666987Z", "iopub.status.busy": "2023-08-18T07:21:05.666595Z", "iopub.status.idle": "2023-08-18T07:21:05.685799Z", "shell.execute_reply": "2023-08-18T07:21:05.684460Z" }, "origin_pos": 55, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 100, 24])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "decoder_blk = DecoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5, 0)\n", "decoder_blk.eval()\n", "X = torch.ones((2, 100, 24))\n", "state = [encoder_blk(X, valid_lens), valid_lens, [None]]\n", "decoder_blk(X, state)[0].shape" ] }, { "cell_type": "markdown", "id": "7b3b57e3", "metadata": { "origin_pos": 57 }, "source": [ "现在我们构建了由`num_layers`个`DecoderBlock`实例组成的完整的[**Transformer解码器**]。最后,通过一个全连接层计算所有`vocab_size`个可能的输出词元的预测值。解码器的自注意力权重和编码器解码器注意力权重都被存储下来,方便日后可视化的需要。\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "1fa8eade", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:21:05.691062Z", "iopub.status.busy": "2023-08-18T07:21:05.690657Z", "iopub.status.idle": "2023-08-18T07:21:05.706405Z", "shell.execute_reply": "2023-08-18T07:21:05.704976Z" }, "origin_pos": 59, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class TransformerDecoder(d2l.AttentionDecoder):\n", " def __init__(self, vocab_size, key_size, query_size, value_size,\n", " num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,\n", " num_heads, num_layers, dropout, **kwargs):\n", " super(TransformerDecoder, self).__init__(**kwargs)\n", " self.num_hiddens = num_hiddens\n", " self.num_layers = num_layers\n", " self.embedding = nn.Embedding(vocab_size, num_hiddens)\n", " self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)\n", " self.blks = nn.Sequential()\n", " for i in range(num_layers):\n", " self.blks.add_module(\"block\"+str(i),\n", " DecoderBlock(key_size, query_size, value_size, num_hiddens,\n", " norm_shape, ffn_num_input, ffn_num_hiddens,\n", " num_heads, dropout, i))\n", " self.dense = nn.Linear(num_hiddens, vocab_size)\n", "\n", " def init_state(self, enc_outputs, enc_valid_lens, *args):\n", " return [enc_outputs, enc_valid_lens, [None] * self.num_layers]\n", "\n", " def forward(self, X, state):\n", " X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))\n", " self._attention_weights = [[None] * len(self.blks) for _ in range (2)]\n", " for i, blk in enumerate(self.blks):\n", " X, state = blk(X, state)\n", " # 解码器自注意力权重\n", " self._attention_weights[0][\n", " i] = blk.attention1.attention.attention_weights\n", " # “编码器-解码器”自注意力权重\n", " self._attention_weights[1][\n", " i] = blk.attention2.attention.attention_weights\n", " return self.dense(X), state\n", "\n", " @property\n", " def attention_weights(self):\n", " return self._attention_weights" ] }, { "cell_type": "markdown", "id": "3adc5f9d", "metadata": { "origin_pos": 62 }, "source": [ "## [**训练**]\n", "\n", "依照Transformer架构来实例化编码器-解码器模型。在这里,指定Transformer的编码器和解码器都是2层,都使用4头注意力。与 :numref:`sec_seq2seq_training`类似,为了进行序列到序列的学习,下面在“英语-法语”机器翻译数据集上训练Transformer模型。\n" ] }, { "cell_type": "code", "execution_count": 14, "id": "1adc9e94", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:21:05.711753Z", "iopub.status.busy": "2023-08-18T07:21:05.710942Z", "iopub.status.idle": "2023-08-18T07:22:31.086151Z", "shell.execute_reply": "2023-08-18T07:22:31.084931Z" }, "origin_pos": 64, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss 0.030, 5202.9 tokens/sec on cuda:0\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:22:31.056238\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10\n", "lr, num_epochs, device = 0.005, 200, d2l.try_gpu()\n", "ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4\n", "key_size, query_size, value_size = 32, 32, 32\n", "norm_shape = [32]\n", "\n", "train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)\n", "\n", "encoder = TransformerEncoder(\n", " len(src_vocab), key_size, query_size, value_size, num_hiddens,\n", " norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,\n", " num_layers, dropout)\n", "decoder = TransformerDecoder(\n", " len(tgt_vocab), key_size, query_size, value_size, num_hiddens,\n", " norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,\n", " num_layers, dropout)\n", "net = d2l.EncoderDecoder(encoder, decoder)\n", "d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)" ] }, { "cell_type": "markdown", "id": "feb9d093", "metadata": { "origin_pos": 67 }, "source": [ "训练结束后,使用Transformer模型[**将一些英语句子翻译成法语**],并且计算它们的BLEU分数。\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "a21c8779", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:22:31.089687Z", "iopub.status.busy": "2023-08-18T07:22:31.089409Z", "iopub.status.idle": "2023-08-18T07:22:31.190689Z", "shell.execute_reply": "2023-08-18T07:22:31.189851Z" }, "origin_pos": 68, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "go . => va !, bleu 1.000\n", "i lost . => j'ai perdu ., bleu 1.000\n", "he's calm . => il est calme ., bleu 1.000\n", "i'm home . => je suis chez moi ., bleu 1.000\n" ] } ], "source": [ "engs = ['go .', \"i lost .\", 'he\\'s calm .', 'i\\'m home .']\n", "fras = ['va !', 'j\\'ai perdu .', 'il est calme .', 'je suis chez moi .']\n", "for eng, fra in zip(engs, fras):\n", " translation, dec_attention_weight_seq = d2l.predict_seq2seq(\n", " net, eng, src_vocab, tgt_vocab, num_steps, device, True)\n", " print(f'{eng} => {translation}, ',\n", " f'bleu {d2l.bleu(translation, fra, k=2):.3f}')" ] }, { "cell_type": "markdown", "id": "a9d3170e", "metadata": { "origin_pos": 70 }, "source": [ "当进行最后一个英语到法语的句子翻译工作时,让我们[**可视化Transformer的注意力权重**]。编码器自注意力权重的形状为(编码器层数,注意力头数,`num_steps`或查询的数目,`num_steps`或“键-值”对的数目)。\n" ] }, { "cell_type": "code", "execution_count": 16, "id": "937e4ab4", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:22:31.194154Z", "iopub.status.busy": "2023-08-18T07:22:31.193869Z", "iopub.status.idle": "2023-08-18T07:22:31.200327Z", "shell.execute_reply": "2023-08-18T07:22:31.199543Z" }, "origin_pos": 71, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 4, 10, 10])" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "enc_attention_weights = torch.cat(net.encoder.attention_weights, 0).reshape((num_layers, num_heads,\n", " -1, num_steps))\n", "enc_attention_weights.shape" ] }, { "cell_type": "markdown", "id": "ca2a5bf2", "metadata": { "origin_pos": 72 }, "source": [ "在编码器的自注意力中,查询和键都来自相同的输入序列。因为填充词元是不携带信息的,因此通过指定输入序列的有效长度可以避免查询与使用填充词元的位置计算注意力。接下来,将逐行呈现两层多头注意力的权重。每个注意力头都根据查询、键和值的不同的表示子空间来表示不同的注意力。\n" ] }, { "cell_type": "code", "execution_count": 17, "id": "0b408905", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:22:31.204505Z", "iopub.status.busy": "2023-08-18T07:22:31.203733Z", "iopub.status.idle": "2023-08-18T07:22:31.970964Z", "shell.execute_reply": "2023-08-18T07:22:31.969901Z" }, "origin_pos": 74, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:22:31.736938\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "d2l.show_heatmaps(\n", " enc_attention_weights.cpu(), xlabel='Key positions',\n", " ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],\n", " figsize=(7, 3.5))" ] }, { "cell_type": "markdown", "id": "5943a056", "metadata": { "origin_pos": 75 }, "source": [ "[**为了可视化解码器的自注意力权重和“编码器-解码器”的注意力权重,我们需要完成更多的数据操作工作。**]例如用零填充被掩蔽住的注意力权重。值得注意的是,解码器的自注意力权重和“编码器-解码器”的注意力权重都有相同的查询:即以*序列开始词元*(beginning-of-sequence,BOS)打头,再与后续输出的词元共同组成序列。\n" ] }, { "cell_type": "code", "execution_count": 18, "id": "b4898564", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:22:31.975096Z", "iopub.status.busy": "2023-08-18T07:22:31.974177Z", "iopub.status.idle": "2023-08-18T07:22:31.987757Z", "shell.execute_reply": "2023-08-18T07:22:31.986824Z" }, "origin_pos": 77, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "(torch.Size([2, 4, 6, 10]), torch.Size([2, 4, 6, 10]))" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dec_attention_weights_2d = [head[0].tolist()\n", " for step in dec_attention_weight_seq\n", " for attn in step for blk in attn for head in blk]\n", "dec_attention_weights_filled = torch.tensor(\n", " pd.DataFrame(dec_attention_weights_2d).fillna(0.0).values)\n", "dec_attention_weights = dec_attention_weights_filled.reshape((-1, 2, num_layers, num_heads, num_steps))\n", "dec_self_attention_weights, dec_inter_attention_weights = \\\n", " dec_attention_weights.permute(1, 2, 3, 0, 4)\n", "dec_self_attention_weights.shape, dec_inter_attention_weights.shape" ] }, { "cell_type": "markdown", "id": "3e4e5d7a", "metadata": { "origin_pos": 80 }, "source": [ "由于解码器自注意力的自回归属性,查询不会对当前位置之后的“键-值”对进行注意力计算。\n" ] }, { "cell_type": "code", "execution_count": 19, "id": "8f162beb", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:22:31.991865Z", "iopub.status.busy": "2023-08-18T07:22:31.990973Z", "iopub.status.idle": "2023-08-18T07:22:32.664530Z", "shell.execute_reply": "2023-08-18T07:22:32.663454Z" }, "origin_pos": 81, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:22:32.516045\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plusonetoincludethebeginning-of-sequencetoken\n", "d2l.show_heatmaps(\n", " dec_self_attention_weights[:, :, :, :len(translation.split()) + 1],\n", " xlabel='Key positions', ylabel='Query positions',\n", " titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))" ] }, { "cell_type": "markdown", "id": "ddb7c476", "metadata": { "origin_pos": 82 }, "source": [ "与编码器的自注意力的情况类似,通过指定输入序列的有效长度,[**输出序列的查询不会与输入序列中填充位置的词元进行注意力计算**]。\n" ] }, { "cell_type": "code", "execution_count": 20, "id": "910d950f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:22:32.668822Z", "iopub.status.busy": "2023-08-18T07:22:32.667894Z", "iopub.status.idle": "2023-08-18T07:22:33.326167Z", "shell.execute_reply": "2023-08-18T07:22:33.325372Z" }, "origin_pos": 83, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:22:33.186096\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "d2l.show_heatmaps(\n", " dec_inter_attention_weights, xlabel='Key positions',\n", " ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],\n", " figsize=(7, 3.5))" ] }, { "cell_type": "markdown", "id": "771a0d76", "metadata": { "origin_pos": 84 }, "source": [ "尽管Transformer架构是为了*序列到序列*的学习而提出的,但正如本书后面将提及的那样,Transformer编码器或Transformer解码器通常被单独用于不同的深度学习任务中。\n", "\n", "## 小结\n", "\n", "* Transformer是编码器-解码器架构的一个实践,尽管在实际情况中编码器或解码器可以单独使用。\n", "* 在Transformer中,多头自注意力用于表示输入序列和输出序列,不过解码器必须通过掩蔽机制来保留自回归属性。\n", "* Transformer中的残差连接和层规范化是训练非常深度模型的重要工具。\n", "* Transformer模型中基于位置的前馈网络使用同一个多层感知机,作用是对所有序列位置的表示进行转换。\n", "\n", "## 练习\n", "\n", "1. 在实验中训练更深的Transformer将如何影响训练速度和翻译效果?\n", "1. 在Transformer中使用加性注意力取代缩放点积注意力是不是个好办法?为什么?\n", "1. 对于语言模型,应该使用Transformer的编码器还是解码器,或者两者都用?如何设计?\n", "1. 如果输入序列很长,Transformer会面临什么挑战?为什么?\n", "1. 如何提高Transformer的计算速度和内存使用效率?提示:可以参考论文 :cite:`Tay.Dehghani.Bahri.ea.2020`。\n", "1. 如果不使用卷积神经网络,如何设计基于Transformer模型的图像分类任务?提示:可以参考Vision Transformer :cite:`Dosovitskiy.Beyer.Kolesnikov.ea.2021`。\n" ] }, { "cell_type": "markdown", "id": "b36a7e4f", "metadata": { "origin_pos": 86, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/5756)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }