更新
This commit is contained in:
@@ -0,0 +1,218 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "962e28eb",
|
||||
"metadata": {
|
||||
"origin_pos": 0
|
||||
},
|
||||
"source": [
|
||||
"# 编码器-解码器架构\n",
|
||||
":label:`sec_encoder-decoder`\n",
|
||||
"\n",
|
||||
"正如我们在 :numref:`sec_machine_translation`中所讨论的,\n",
|
||||
"机器翻译是序列转换模型的一个核心问题,\n",
|
||||
"其输入和输出都是长度可变的序列。\n",
|
||||
"为了处理这种类型的输入和输出,\n",
|
||||
"我们可以设计一个包含两个主要组件的架构:\n",
|
||||
"第一个组件是一个*编码器*(encoder):\n",
|
||||
"它接受一个长度可变的序列作为输入,\n",
|
||||
"并将其转换为具有固定形状的编码状态。\n",
|
||||
"第二个组件是*解码器*(decoder):\n",
|
||||
"它将固定形状的编码状态映射到长度可变的序列。\n",
|
||||
"这被称为*编码器-解码器*(encoder-decoder)架构,\n",
|
||||
"如 :numref:`fig_encoder_decoder` 所示。\n",
|
||||
"\n",
|
||||
"\n",
|
||||
":label:`fig_encoder_decoder`\n",
|
||||
"\n",
|
||||
"我们以英语到法语的机器翻译为例:\n",
|
||||
"给定一个英文的输入序列:“They”“are”“watching”“.”。\n",
|
||||
"首先,这种“编码器-解码器”架构将长度可变的输入序列编码成一个“状态”,\n",
|
||||
"然后对该状态进行解码,\n",
|
||||
"一个词元接着一个词元地生成翻译后的序列作为输出:\n",
|
||||
"“Ils”“regordent”“.”。\n",
|
||||
"由于“编码器-解码器”架构是形成后续章节中不同序列转换模型的基础,\n",
|
||||
"因此本节将把这个架构转换为接口方便后面的代码实现。\n",
|
||||
"\n",
|
||||
"## (**编码器**)\n",
|
||||
"\n",
|
||||
"在编码器接口中,我们只指定长度可变的序列作为编码器的输入`X`。\n",
|
||||
"任何继承这个`Encoder`基类的模型将完成代码实现。\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "17f77c60",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2023-08-18T07:05:48.406295Z",
|
||||
"iopub.status.busy": "2023-08-18T07:05:48.405469Z",
|
||||
"iopub.status.idle": "2023-08-18T07:05:49.653322Z",
|
||||
"shell.execute_reply": "2023-08-18T07:05:49.651979Z"
|
||||
},
|
||||
"origin_pos": 2,
|
||||
"tab": [
|
||||
"pytorch"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torch import nn\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"#@save\n",
|
||||
"class Encoder(nn.Module):\n",
|
||||
" \"\"\"编码器-解码器架构的基本编码器接口\"\"\"\n",
|
||||
" def __init__(self, **kwargs):\n",
|
||||
" super(Encoder, self).__init__(**kwargs)\n",
|
||||
"\n",
|
||||
" def forward(self, X, *args):\n",
|
||||
" raise NotImplementedError"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "de7f0caf",
|
||||
"metadata": {
|
||||
"origin_pos": 5
|
||||
},
|
||||
"source": [
|
||||
"## [**解码器**]\n",
|
||||
"\n",
|
||||
"在下面的解码器接口中,我们新增一个`init_state`函数,\n",
|
||||
"用于将编码器的输出(`enc_outputs`)转换为编码后的状态。\n",
|
||||
"注意,此步骤可能需要额外的输入,例如:输入序列的有效长度,\n",
|
||||
"这在 :numref:`subsec_mt_data_loading`中进行了解释。\n",
|
||||
"为了逐个地生成长度可变的词元序列,\n",
|
||||
"解码器在每个时间步都会将输入\n",
|
||||
"(例如:在前一时间步生成的词元)和编码后的状态\n",
|
||||
"映射成当前时间步的输出词元。\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "5c7a6471",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2023-08-18T07:05:49.659889Z",
|
||||
"iopub.status.busy": "2023-08-18T07:05:49.659020Z",
|
||||
"iopub.status.idle": "2023-08-18T07:05:49.666360Z",
|
||||
"shell.execute_reply": "2023-08-18T07:05:49.665230Z"
|
||||
},
|
||||
"origin_pos": 7,
|
||||
"tab": [
|
||||
"pytorch"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@save\n",
|
||||
"class Decoder(nn.Module):\n",
|
||||
" \"\"\"编码器-解码器架构的基本解码器接口\"\"\"\n",
|
||||
" def __init__(self, **kwargs):\n",
|
||||
" super(Decoder, self).__init__(**kwargs)\n",
|
||||
"\n",
|
||||
" def init_state(self, enc_outputs, *args):\n",
|
||||
" raise NotImplementedError\n",
|
||||
"\n",
|
||||
" def forward(self, X, state):\n",
|
||||
" raise NotImplementedError"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6e0548de",
|
||||
"metadata": {
|
||||
"origin_pos": 10
|
||||
},
|
||||
"source": [
|
||||
"## [**合并编码器和解码器**]\n",
|
||||
"\n",
|
||||
"总而言之,“编码器-解码器”架构包含了一个编码器和一个解码器,\n",
|
||||
"并且还拥有可选的额外的参数。\n",
|
||||
"在前向传播中,编码器的输出用于生成编码状态,\n",
|
||||
"这个状态又被解码器作为其输入的一部分。\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "53fb0929",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2023-08-18T07:05:49.671685Z",
|
||||
"iopub.status.busy": "2023-08-18T07:05:49.670944Z",
|
||||
"iopub.status.idle": "2023-08-18T07:05:49.678831Z",
|
||||
"shell.execute_reply": "2023-08-18T07:05:49.677718Z"
|
||||
},
|
||||
"origin_pos": 12,
|
||||
"tab": [
|
||||
"pytorch"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@save\n",
|
||||
"class EncoderDecoder(nn.Module):\n",
|
||||
" \"\"\"编码器-解码器架构的基类\"\"\"\n",
|
||||
" def __init__(self, encoder, decoder, **kwargs):\n",
|
||||
" super(EncoderDecoder, self).__init__(**kwargs)\n",
|
||||
" self.encoder = encoder\n",
|
||||
" self.decoder = decoder\n",
|
||||
"\n",
|
||||
" def forward(self, enc_X, dec_X, *args):\n",
|
||||
" enc_outputs = self.encoder(enc_X, *args)\n",
|
||||
" dec_state = self.decoder.init_state(enc_outputs, *args)\n",
|
||||
" return self.decoder(dec_X, dec_state)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "dce5eb8e",
|
||||
"metadata": {
|
||||
"origin_pos": 15
|
||||
},
|
||||
"source": [
|
||||
"“编码器-解码器”体系架构中的术语*状态*\n",
|
||||
"会启发人们使用具有状态的神经网络来实现该架构。\n",
|
||||
"在下一节中,我们将学习如何应用循环神经网络,\n",
|
||||
"来设计基于“编码器-解码器”架构的序列转换模型。\n",
|
||||
"\n",
|
||||
"## 小结\n",
|
||||
"\n",
|
||||
"* “编码器-解码器”架构可以将长度可变的序列作为输入和输出,因此适用于机器翻译等序列转换问题。\n",
|
||||
"* 编码器将长度可变的序列作为输入,并将其转换为具有固定形状的编码状态。\n",
|
||||
"* 解码器将具有固定形状的编码状态映射为长度可变的序列。\n",
|
||||
"\n",
|
||||
"## 练习\n",
|
||||
"\n",
|
||||
"1. 假设我们使用神经网络来实现“编码器-解码器”架构,那么编码器和解码器必须是同一类型的神经网络吗?\n",
|
||||
"1. 除了机器翻译,还有其它可以适用于”编码器-解码器“架构的应用吗?\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "99846b42",
|
||||
"metadata": {
|
||||
"origin_pos": 17,
|
||||
"tab": [
|
||||
"pytorch"
|
||||
]
|
||||
},
|
||||
"source": [
|
||||
"[Discussions](https://discuss.d2l.ai/t/2779)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
},
|
||||
"required_libs": []
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user