Update
+88
@@ -0,0 +1,88 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "2731ad59",
"metadata": {
"origin_pos": 0
},
"source": [
"# Fine-Tuning BERT for Sequence-Level and Token-Level Applications\n",
":label:`sec_finetuning-bert`\n",
"\n",
"In the previous sections of this chapter, we designed different models for natural language processing applications, based on architectures such as RNNs, CNNs, attention, and MLPs. These models are helpful when there is a space or time constraint; however, handcrafting a specific model for every natural language processing task is practically infeasible. In :numref:`sec_bert`, we introduced a pretrained model, BERT, that requires minimal architecture changes for a wide range of natural language processing tasks. On the one hand, at the time of its proposal, BERT improved the state of the art on various natural language processing tasks. On the other hand, as noted in :numref:`sec_bert-pretraining`, the two versions of the original BERT model come with 110 million and 340 million parameters, respectively. Thus, when there are sufficient computational resources, we may consider fine-tuning BERT for downstream natural language processing applications.\n",
"\n",
"In the following, we generalize a subset of natural language processing applications as sequence-level and token-level. On the sequence level, we introduce how to transform the BERT representation of the text input into the output label for single text classification and text pair classification (or regression). On the token level, we briefly introduce new applications such as text tagging and question answering, and shed light on how BERT can represent their inputs and be transformed into output labels. During fine-tuning, the “minimal architecture changes” required by BERT across different applications are the extra fully connected layers. During supervised learning of a downstream application, the parameters of the extra layers are learned from scratch, while all the parameters in the pretrained BERT model are fine-tuned.\n",
"\n",
"## Single Text Classification\n",
"\n",
"*Single text classification* takes a single text sequence as input and outputs its classification result.\n",
"Besides the sentiment analysis studied in this chapter, the Corpus of Linguistic Acceptability (CoLA) is also a dataset for single text classification, judging whether a given sentence is grammatically acceptable or not :cite:`Warstadt.Singh.Bowman.2019`. For instance, “I should study.” is acceptable, but “I should studying.” is not.\n",
"\n",
"\n",
":label:`fig_bert-one-seq`\n",
"\n",
":numref:`sec_bert` describes the input representation of BERT. A BERT input sequence unambiguously represents both single text and text pairs, where the special classification token “<cls>” is used for sequence classification and the special token “<sep>” marks the end of single text or separates a pair of text. As shown in :numref:`fig_bert-one-seq`, in single text classification applications, the BERT representation of the special classification token “<cls>” encodes the information of the entire input text sequence. As the representation of the input single text, it is fed into a small MLP consisting of fully connected (dense) layers to output the distribution over all discrete label values.\n",
"\n",
"## Text Pair Classification or Regression\n",
"\n",
"We have also examined natural language inference in this chapter. It belongs to *text pair classification*, a type of application that classifies a pair of text.\n",
"\n",
"Taking a pair of text as input but outputting a continuous value, *semantic textual similarity* is a popular *text pair regression* task.\n",
"This task measures the semantic similarity of sentences. For instance, in the Semantic Textual Similarity Benchmark dataset, the similarity score of a pair of sentences is an ordinal scale ranging from 0 (no meaning overlap) to 5 (meaning equivalence) :cite:`Cer.Diab.Agirre.ea.2017`. The goal is to predict these scores. Examples from this dataset include (sentence 1, sentence 2, similarity score):\n",
"\n",
"* “A plane is taking off.”, “An air plane is taking off.”, 5.000;\n",
"* “A woman is eating something.”, “A woman is eating meat.”, 3.000;\n",
"* “A woman is dancing.”, “A man is talking.”, 0.000.\n",
"\n",
"\n",
":label:`fig_bert-two-seqs`\n",
"\n",
"Compared with single text classification in :numref:`fig_bert-one-seq`, fine-tuning BERT for text pair classification in :numref:`fig_bert-two-seqs` differs in the input representation. For text pair regression tasks such as semantic textual similarity, trivial changes can be applied, such as outputting a continuous label value and using the mean squared loss, both of which are common in regression.\n",
"\n",
"## Text Tagging\n",
"\n",
"Now let's consider token-level tasks, such as *text tagging*, where each token is assigned a label. Among text tagging tasks, *part-of-speech tagging* assigns each word a part-of-speech tag (e.g., adjective or determiner) according to the role of the word in the sentence. For example, according to the Penn Treebank II tag set, the sentence “John Smith's car is new” should be tagged as “NNP (noun, proper singular) NNP POS (possessive ending) NN (noun, singular or mass) VB (verb, base form) JJ (adjective)”.\n",
"\n",
"\n",
":label:`fig_bert-tagging`\n",
"\n",
"Fine-tuning BERT for text tagging applications is illustrated in :numref:`fig_bert-tagging`. Compared with :numref:`fig_bert-one-seq`, the only distinction is that in text tagging, the BERT representation of *every token* of the input text is fed into the same extra fully connected layer to output the label of the token, such as a part-of-speech tag.\n",
"\n",
"## Question Answering\n",
"\n",
"As another token-level application, *question answering* reflects capabilities of reading comprehension.\n",
"For example, the Stanford Question Answering Dataset (SQuAD v1.1) consists of reading passages and questions, where the answer to every question is just a segment of text (a text span) from the passage that the question is about :cite:`Rajpurkar.Zhang.Lopyrev.ea.2016`. To explain, consider a passage “Some experts report that a mask's efficacy is inconclusive. However, mask makers insist that their products, such as N95 respirator masks, can guard against the virus.” and a question “Who say that N95 respirator masks can guard against the virus?”. The answer should be the text span “mask makers” in the passage. Thus, the goal in SQuAD v1.1 is to predict the start and end of the text span in the passage, given a question and a passage.\n",
"\n",
"\n",
":label:`fig_bert-qa`\n",
"\n",
"To fine-tune BERT for question answering, the question and the passage are packed as the first and second text sequence, respectively, in the BERT input. To predict the position of the start of the text span, the same additional fully connected layer transforms the BERT representation of the token at any position $i$ of the passage into a scalar score $s_i$. The scores of all the passage tokens are further transformed by softmax into a probability distribution, so that each token position $i$ in the passage is assigned a probability $p_i$ of being the start of the text span. Predicting the end of the text span is the same as above, except that the parameters in its additional fully connected layer are independent from those for predicting the start. When predicting the end, the passage token at any position $i$ is transformed by this fully connected layer into a scalar score $e_i$. :numref:`fig_bert-qa` depicts fine-tuning BERT for question answering.\n",
"\n",
"For question answering, the training objective of the supervised learning is as straightforward as maximizing the log-likelihoods of the ground-truth start and end positions. When predicting the span, we can compute the score $s_i + e_j$ for a valid span from position $i$ to position $j$ ($i \\leq j$), and output the span with the highest score.\n",
"\n",
"## Summary\n",
"\n",
"* For sequence-level and token-level natural language processing applications, BERT requires minimal architecture changes (extra fully connected layers), whether it is single text classification (e.g., sentiment analysis and testing linguistic acceptability), text pair classification or regression (e.g., natural language inference and semantic textual similarity), text tagging (e.g., part-of-speech tagging), or question answering.\n",
"* During supervised learning of a downstream application, the parameters of the extra layers are learned from scratch, while all the parameters in the pretrained BERT model are fine-tuned.\n",
"\n",
"## Exercises\n",
"\n",
"1. Let's design a search engine algorithm for news articles. When the system receives a query (e.g., “oil industry during the coronavirus outbreak”), it should return a ranked list of news articles that are most relevant to the query. Suppose that we have a huge pool of news articles and a large number of queries. To simplify the problem, suppose that the most relevant article has been labeled for each query. How can we apply negative sampling (see :numref:`subsec_negative-sampling`) and BERT in the algorithm design?\n",
"1. How can we leverage BERT in training language models?\n",
"1. Can we leverage BERT in machine translation?\n",
"\n",
"[Discussions](https://discuss.d2l.ai/t/5729)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"required_libs": []
},
"nbformat": 4,
"nbformat_minor": 5
}
@@ -0,0 +1,73 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "cd1572d4",
"metadata": {
"origin_pos": 0
},
"source": [
"# Natural Language Processing: Applications\n",
":label:`chap_nlp_app`\n",
"\n",
"We have seen how to represent tokens in text sequences\n",
"and trained their representations in :numref:`chap_nlp_pretrain`.\n",
"Such pretrained text representations can be fed into different models\n",
"for different downstream natural language processing tasks.\n",
"\n",
"The previous chapters touched on some natural language processing applications\n",
"without pretraining, just for explaining deep learning architectures.\n",
"For instance, in :numref:`chap_rnn`,\n",
"we relied on RNNs to design language models that generate novella-like text.\n",
"In :numref:`chap_modern_rnn` and :numref:`chap_attention`,\n",
"we also designed machine translation models based on RNNs and attention mechanisms.\n",
"\n",
"However, this book does not intend to cover all such applications comprehensively.\n",
"Instead, our focus is on *how to apply deep representation learning of languages to natural language processing problems*.\n",
"Given the pretrained text representations,\n",
"this chapter explores two popular and representative downstream natural language processing tasks:\n",
"sentiment analysis and natural language inference,\n",
"which analyze single text and relationships between text pairs, respectively.\n",
"\n",
"\n",
":label:`fig_nlp-map-app`\n",
"\n",
"As depicted in :numref:`fig_nlp-map-app`,\n",
"this chapter focuses on describing how to design natural language processing models\n",
"using different types of deep learning architectures\n",
"(such as MLPs, CNNs, RNNs, and attention).\n",
"Though it is possible in :numref:`fig_nlp-map-app`\n",
"to combine any pretrained text representations with any architecture for either application,\n",
"we select a few representative combinations.\n",
"Specifically, we will explore popular architectures based on RNNs and CNNs for sentiment analysis.\n",
"For natural language inference, we choose attention and MLPs to demonstrate how to analyze text pairs.\n",
"Finally, we introduce how to fine-tune a pretrained BERT model\n",
"for a wide range of natural language processing applications,\n",
"such as on the sequence level (single text classification and text pair classification)\n",
"and the token level (text tagging and question answering).\n",
"As a concrete empirical case, we will fine-tune BERT for natural language inference.\n",
"\n",
"As we introduced in :numref:`sec_bert`,\n",
"BERT requires minimal architecture changes for a wide range of natural language processing applications.\n",
"However, this benefit comes at the cost of fine-tuning a huge number of BERT parameters\n",
"for the downstream applications.\n",
"When space or time is limited, carefully crafted models based on MLPs, CNNs, RNNs,\n",
"and attention are more feasible.\n",
"In the following, we start with the sentiment analysis application\n",
"and illustrate model design based on RNNs and CNNs, respectively.\n",
"\n",
":begin_tab:toc\n",
" - [sentiment-analysis-and-dataset](sentiment-analysis-and-dataset.ipynb)\n",
" - [sentiment-analysis-rnn](sentiment-analysis-rnn.ipynb)\n",
" - [sentiment-analysis-cnn](sentiment-analysis-cnn.ipynb)\n",
" - [natural-language-inference-and-dataset](natural-language-inference-and-dataset.ipynb)\n",
" - [natural-language-inference-attention](natural-language-inference-attention.ipynb)\n",
" - [finetuning-bert](finetuning-bert.ipynb)\n",
" - [natural-language-inference-bert](natural-language-inference-bert.ipynb)\n",
":end_tab:\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"required_libs": []
},
"nbformat": 4,
"nbformat_minor": 5
}
+479
@@ -0,0 +1,479 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "15c5cd33",
"metadata": {
"origin_pos": 0
},
"source": [
"# Natural Language Inference and the Dataset\n",
":label:`sec_natural-language-inference-and-dataset`\n",
"\n",
"In :numref:`sec_sentiment`, we discussed the problem of sentiment analysis. This task aims to classify a single text sequence into predefined categories, such as a set of sentiment polarities. However, when there is a need to decide whether one sentence can be inferred from another, or to eliminate redundancy by identifying sentences that are semantically equivalent, knowing how to classify one text sequence is insufficient. Instead, we need to be able to reason over pairs of text sequences.\n",
"\n",
"## Natural Language Inference\n",
"\n",
"*Natural language inference* studies whether a *hypothesis*\n",
"can be inferred from a *premise*,\n",
"where both are a text sequence.\n",
"In other words, natural language inference determines the logical relationship between a pair of text sequences. Such relationships usually fall into three types:\n",
"\n",
"* *Entailment*: the hypothesis can be inferred from the premise.\n",
"* *Contradiction*: the negation of the hypothesis can be inferred from the premise.\n",
"* *Neutral*: all the other cases.\n",
"\n",
"Natural language inference is also known as the recognizing textual entailment task.\n",
"For example, the following pair will be labeled as “entailment” because “showing affection” in the hypothesis can be inferred from “hugging” in the premise.\n",
"\n",
">Premise: Two women are hugging each other.\n",
"\n",
">Hypothesis: Two women are showing affection.\n",
"\n",
"The following is an example of “contradiction”, as “running the coding example” indicates “not sleeping” rather than “sleeping”.\n",
"\n",
">Premise: A man is running the coding example from Dive Into Deep Learning.\n",
"\n",
">Hypothesis: The man is sleeping.\n",
"\n",
"The third example shows a “neutral” relationship, because neither “famous” nor “not famous” can be inferred from the fact that “are performing for us”.\n",
"\n",
">Premise: The musicians are performing for us.\n",
"\n",
">Hypothesis: The musicians are famous.\n",
"\n",
"Natural language inference has been a central topic for understanding natural language. It enjoys wide applications ranging from information retrieval to open-domain question answering. To study this problem, we begin by investigating a popular natural language inference benchmark dataset.\n",
"\n",
"## The Stanford Natural Language Inference (SNLI) Dataset\n",
"\n",
"The [**Stanford Natural Language Inference (SNLI) Corpus**] is a collection of over 500000 labeled English sentence pairs :cite:`Bowman.Angeli.Potts.ea.2015`. We download and store the extracted SNLI dataset in the path `../data/snli_1.0`.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "85ccbfd4",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:00.201212Z",
"iopub.status.busy": "2023-08-18T07:06:00.200144Z",
"iopub.status.idle": "2023-08-18T07:06:09.370822Z",
"shell.execute_reply": "2023-08-18T07:06:09.368591Z"
},
"origin_pos": 2,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import os\n",
"import re\n",
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l\n",
"\n",
"#@save\n",
"d2l.DATA_HUB['SNLI'] = (\n",
"    'https://nlp.stanford.edu/projects/snli/snli_1.0.zip',\n",
"    '9fcde07509c7e87ec61c640c1b2753d9041758e4')\n",
"\n",
"data_dir = d2l.download_extract('SNLI')"
]
},
{
"cell_type": "markdown",
"id": "5e647396",
"metadata": {
"origin_pos": 4
},
"source": [
"### [**Reading the Dataset**]\n",
"\n",
"The original SNLI dataset contains much richer information than what we really need in our experiments. Thus, we define a function `read_snli` to extract only part of the dataset, then return lists of premises, hypotheses, and their labels.\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fa839f80",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:09.377922Z",
"iopub.status.busy": "2023-08-18T07:06:09.377380Z",
"iopub.status.idle": "2023-08-18T07:06:09.392203Z",
"shell.execute_reply": "2023-08-18T07:06:09.390984Z"
},
"origin_pos": 5,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"#@save\n",
"def read_snli(data_dir, is_train):\n",
"    \"\"\"Parse the SNLI dataset into premises, hypotheses, and labels\"\"\"\n",
"    def extract_text(s):\n",
"        # Remove information that we will not use\n",
"        s = re.sub('\\\\(', '', s)\n",
"        s = re.sub('\\\\)', '', s)\n",
"        # Substitute two or more consecutive whitespace with a single space\n",
"        s = re.sub('\\\\s{2,}', ' ', s)\n",
"        return s.strip()\n",
"    label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}\n",
"    file_name = os.path.join(data_dir, 'snli_1.0_train.txt'\n",
"                             if is_train else 'snli_1.0_test.txt')\n",
"    with open(file_name, 'r') as f:\n",
"        rows = [row.split('\\t') for row in f.readlines()[1:]]\n",
"    premises = [extract_text(row[1]) for row in rows if row[0] in label_set]\n",
"    hypotheses = [extract_text(row[2]) for row in rows if row[0] \\\n",
"                  in label_set]\n",
"    labels = [label_set[row[0]] for row in rows if row[0] in label_set]\n",
"    return premises, hypotheses, labels"
]
},
{
"cell_type": "markdown",
"id": "607a64fd",
"metadata": {
"origin_pos": 6
},
"source": [
"Now let's [**print the first 3 pairs**] of premises and hypotheses, as well as their labels (“0”, “1”, and “2” correspond to “entailment”, “contradiction”, and “neutral”, respectively).\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "19101f9e",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:09.397297Z",
"iopub.status.busy": "2023-08-18T07:06:09.396407Z",
"iopub.status.idle": "2023-08-18T07:06:23.206512Z",
"shell.execute_reply": "2023-08-18T07:06:23.205574Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"premise: A person on a horse jumps over a broken down airplane .\n",
"hypothesis: A person is training his horse for a competition .\n",
"label: 2\n",
"premise: A person on a horse jumps over a broken down airplane .\n",
"hypothesis: A person is at a diner , ordering an omelette .\n",
"label: 1\n",
"premise: A person on a horse jumps over a broken down airplane .\n",
"hypothesis: A person is outdoors , on a horse .\n",
"label: 0\n"
]
}
],
"source": [
"train_data = read_snli(data_dir, is_train=True)\n",
"for x0, x1, y in zip(train_data[0][:3], train_data[1][:3], train_data[2][:3]):\n",
"    print('premise:', x0)\n",
"    print('hypothesis:', x1)\n",
"    print('label:', y)"
]
},
{
"cell_type": "markdown",
"id": "f09b2cf4",
"metadata": {
"origin_pos": 8
},
"source": [
"The training set has about 550000 pairs, and the testing set has about 10000 pairs. The following shows that the three [**labels “entailment”, “contradiction”, and “neutral” are balanced**] in both the training set and the testing set.\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "972ca3d1",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:23.210300Z",
"iopub.status.busy": "2023-08-18T07:06:23.209728Z",
"iopub.status.idle": "2023-08-18T07:06:23.531128Z",
"shell.execute_reply": "2023-08-18T07:06:23.530246Z"
},
"origin_pos": 9,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[183416, 183187, 182764]\n",
"[3368, 3237, 3219]\n"
]
}
],
"source": [
"test_data = read_snli(data_dir, is_train=False)\n",
"for data in [train_data, test_data]:\n",
"    print([[row for row in data[2]].count(i) for i in range(3)])"
]
},
{
"cell_type": "markdown",
"id": "e7ab2708",
"metadata": {
"origin_pos": 10
},
"source": [
"### [**Defining a Class for Loading the Dataset**]\n",
"\n",
"Below we define a class for loading the SNLI dataset. The variable `num_steps` in the class constructor specifies the length of a text sequence so that each minibatch of sequences has the same shape. In other words, tokens after the first `num_steps` ones in any longer sequence are trimmed, while the special token “<pad>” is appended to shorter sequences until their length becomes `num_steps`. By implementing the `__getitem__` function, we can arbitrarily access the premise, hypothesis, and label with the index `idx`.\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b8b15f65",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:23.534933Z",
"iopub.status.busy": "2023-08-18T07:06:23.534365Z",
"iopub.status.idle": "2023-08-18T07:06:23.542550Z",
"shell.execute_reply": "2023-08-18T07:06:23.541714Z"
},
"origin_pos": 12,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"#@save\n",
"class SNLIDataset(torch.utils.data.Dataset):\n",
"    \"\"\"A customized dataset to load the SNLI dataset\"\"\"\n",
"    def __init__(self, dataset, num_steps, vocab=None):\n",
"        self.num_steps = num_steps\n",
"        all_premise_tokens = d2l.tokenize(dataset[0])\n",
"        all_hypothesis_tokens = d2l.tokenize(dataset[1])\n",
"        if vocab is None:\n",
"            self.vocab = d2l.Vocab(all_premise_tokens + \\\n",
"                all_hypothesis_tokens, min_freq=5, reserved_tokens=['<pad>'])\n",
"        else:\n",
"            self.vocab = vocab\n",
"        self.premises = self._pad(all_premise_tokens)\n",
"        self.hypotheses = self._pad(all_hypothesis_tokens)\n",
"        self.labels = torch.tensor(dataset[2])\n",
"        print('read ' + str(len(self.premises)) + ' examples')\n",
"\n",
"    def _pad(self, lines):\n",
"        return torch.tensor([d2l.truncate_pad(\n",
"            self.vocab[line], self.num_steps, self.vocab['<pad>'])\n",
"                             for line in lines])\n",
"\n",
"    def __getitem__(self, idx):\n",
"        return (self.premises[idx], self.hypotheses[idx]), self.labels[idx]\n",
"\n",
"    def __len__(self):\n",
"        return len(self.premises)"
]
},
{
"cell_type": "markdown",
"id": "f5efd5df",
"metadata": {
"origin_pos": 14
},
"source": [
"### [**Putting It All Together**]\n",
"\n",
"Now we can invoke the `read_snli` function and the `SNLIDataset` class to download the SNLI dataset and return `DataLoader` instances for both the training and testing sets, together with the vocabulary of the training set. It is noteworthy that we must use the vocabulary constructed from the training set as that of the testing set. As a result, any new token from the testing set will be unknown to the model trained on the training set.\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "96c46f53",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:23.546033Z",
"iopub.status.busy": "2023-08-18T07:06:23.545509Z",
"iopub.status.idle": "2023-08-18T07:06:23.551107Z",
"shell.execute_reply": "2023-08-18T07:06:23.550286Z"
},
"origin_pos": 16,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"#@save\n",
"def load_data_snli(batch_size, num_steps=50):\n",
"    \"\"\"Download the SNLI dataset and return data iterators and vocabulary\"\"\"\n",
"    num_workers = d2l.get_dataloader_workers()\n",
"    data_dir = d2l.download_extract('SNLI')\n",
"    train_data = read_snli(data_dir, True)\n",
"    test_data = read_snli(data_dir, False)\n",
"    train_set = SNLIDataset(train_data, num_steps)\n",
"    test_set = SNLIDataset(test_data, num_steps, train_set.vocab)\n",
"    train_iter = torch.utils.data.DataLoader(train_set, batch_size,\n",
"                                             shuffle=True,\n",
"                                             num_workers=num_workers)\n",
"    test_iter = torch.utils.data.DataLoader(test_set, batch_size,\n",
"                                            shuffle=False,\n",
"                                            num_workers=num_workers)\n",
"    return train_iter, test_iter, train_set.vocab"
]
},
{
"cell_type": "markdown",
"id": "16d0cddb",
"metadata": {
"origin_pos": 18
},
"source": [
"Here we set the batch size to 128 and the sequence length to 50, and invoke the `load_data_snli` function to get the data iterators and the vocabulary. Then we print the vocabulary size.\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "08d0c755",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:23.554839Z",
"iopub.status.busy": "2023-08-18T07:06:23.554288Z",
"iopub.status.idle": "2023-08-18T07:07:02.488484Z",
"shell.execute_reply": "2023-08-18T07:07:02.487658Z"
},
"origin_pos": 19,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"read 549367 examples\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"read 9824 examples\n"
]
},
{
"data": {
"text/plain": [
"18678"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_iter, test_iter, vocab = load_data_snli(128, 50)\n",
"len(vocab)"
]
},
{
"cell_type": "markdown",
"id": "783f8d2d",
"metadata": {
"origin_pos": 20
},
"source": [
"Now we print the shape of the first minibatch. Contrary to sentiment analysis, we have two inputs `X[0]` and `X[1]` representing pairs of premises and hypotheses.\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "d7411a33",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:07:02.492220Z",
"iopub.status.busy": "2023-08-18T07:07:02.491909Z",
"iopub.status.idle": "2023-08-18T07:07:02.966465Z",
"shell.execute_reply": "2023-08-18T07:07:02.965137Z"
},
"origin_pos": 21,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([128, 50])\n",
"torch.Size([128, 50])\n",
"torch.Size([128])\n"
]
}
],
"source": [
"for X, Y in train_iter:\n",
"    print(X[0].shape)\n",
"    print(X[1].shape)\n",
"    print(Y.shape)\n",
"    break"
]
},
{
"cell_type": "markdown",
"id": "2cdcfd40",
"metadata": {
"origin_pos": 22
},
"source": [
"## Summary\n",
"\n",
"* Natural language inference studies whether a hypothesis can be inferred from a premise, where both are a text sequence.\n",
"* In natural language inference, relationships between premises and hypotheses include entailment, contradiction, and neutral.\n",
"* The Stanford Natural Language Inference (SNLI) Corpus is a popular benchmark dataset of natural language inference.\n",
"\n",
"## Exercises\n",
"\n",
"1. Machine translation has long been evaluated based on superficial $n$-gram matching between an output translation and a ground-truth translation. Can you design a measure for evaluating machine translation results by using natural language inference?\n",
"1. How can we change hyperparameters to reduce the vocabulary size?\n"
]
},
{
"cell_type": "markdown",
"id": "d452fb1d",
"metadata": {
"origin_pos": 24,
"tab": [
"pytorch"
]
},
"source": [
"[Discussions](https://discuss.d2l.ai/t/5722)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"required_libs": []
},
"nbformat": 4,
"nbformat_minor": 5
}