Files
Python/d2l/d2l-zh/pytorch/chapter_natural-language-processing-applications/natural-language-inference-and-dataset.ipynb
T
2025-12-16 09:23:53 +08:00

479 lines
16 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "markdown",
"id": "15c5cd33",
"metadata": {
"origin_pos": 0
},
"source": [
"# 自然语言推断与数据集\n",
":label:`sec_natural-language-inference-and-dataset`\n",
"\n",
"在 :numref:`sec_sentiment`中,我们讨论了情感分析问题。这个任务的目的是将单个文本序列分类到预定义的类别中,例如一组情感极性中。然而,当需要决定一个句子是否可以从另一个句子推断出来,或者需要通过识别语义等价的句子来消除句子间冗余时,知道如何对一个文本序列进行分类是不够的。相反,我们需要能够对成对的文本序列进行推断。\n",
"\n",
"## 自然语言推断\n",
"\n",
"*自然语言推断*natural language inference)主要研究\n",
"*假设*hypothesis)是否可以从*前提*premise)中推断出来,\n",
"其中两者都是文本序列。\n",
"换言之,自然语言推断决定了一对文本序列之间的逻辑关系。这类关系通常分为三种类型:\n",
"\n",
"* *蕴涵*entailment):假设可以从前提中推断出来。\n",
"* *矛盾*contradiction):假设的否定可以从前提中推断出来。\n",
"* *中性*neutral):所有其他情况。\n",
"\n",
"自然语言推断也被称为识别文本蕴涵任务。\n",
"例如,下面的一个文本对将被贴上“蕴涵”的标签,因为假设中的“表白”可以从前提中的“拥抱”中推断出来。\n",
"\n",
">前提:两个女人拥抱在一起。\n",
"\n",
">假设:两个女人在示爱。\n",
"\n",
"下面是一个“矛盾”的例子,因为“运行编码示例”表示“不睡觉”,而不是“睡觉”。\n",
"\n",
">前提:一名男子正在运行Dive Into Deep Learning的编码示例。\n",
"\n",
">假设:该男子正在睡觉。\n",
"\n",
"第三个例子显示了一种“中性”关系,因为“正在为我们表演”这一事实无法推断出“出名”或“不出名”。\n",
"\n",
">前提:音乐家们正在为我们表演。\n",
"\n",
">假设:音乐家很有名。\n",
"\n",
"自然语言推断一直是理解自然语言的中心话题。它有着广泛的应用,从信息检索到开放领域的问答。为了研究这个问题,我们将首先研究一个流行的自然语言推断基准数据集。\n",
"\n",
"## 斯坦福自然语言推断(SNLI)数据集\n",
"\n",
"[**斯坦福自然语言推断语料库(Stanford Natural Language InferenceSNLI**]是由500000多个带标签的英语句子对组成的集合 :cite:`Bowman.Angeli.Potts.ea.2015`。我们在路径`../data/snli_1.0`中下载并存储提取的SNLI数据集。\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "85ccbfd4",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:00.201212Z",
"iopub.status.busy": "2023-08-18T07:06:00.200144Z",
"iopub.status.idle": "2023-08-18T07:06:09.370822Z",
"shell.execute_reply": "2023-08-18T07:06:09.368591Z"
},
"origin_pos": 2,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import os\n",
"import re\n",
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l\n",
"\n",
"#@save\n",
"d2l.DATA_HUB['SNLI'] = (\n",
" 'https://nlp.stanford.edu/projects/snli/snli_1.0.zip',\n",
" '9fcde07509c7e87ec61c640c1b2753d9041758e4')\n",
"\n",
"data_dir = d2l.download_extract('SNLI')"
]
},
{
"cell_type": "markdown",
"id": "5e647396",
"metadata": {
"origin_pos": 4
},
"source": [
"### [**读取数据集**]\n",
"\n",
"原始的SNLI数据集包含的信息比我们在实验中真正需要的信息丰富得多。因此,我们定义函数`read_snli`以仅提取数据集的一部分,然后返回前提、假设及其标签的列表。\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fa839f80",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:09.377922Z",
"iopub.status.busy": "2023-08-18T07:06:09.377380Z",
"iopub.status.idle": "2023-08-18T07:06:09.392203Z",
"shell.execute_reply": "2023-08-18T07:06:09.390984Z"
},
"origin_pos": 5,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"#@save\n",
"def read_snli(data_dir, is_train):\n",
" \"\"\"将SNLI数据集解析为前提、假设和标签\"\"\"\n",
" def extract_text(s):\n",
" # 删除我们不会使用的信息\n",
" s = re.sub('\\\\(', '', s)\n",
" s = re.sub('\\\\)', '', s)\n",
" # 用一个空格替换两个或多个连续的空格\n",
" s = re.sub('\\\\s{2,}', ' ', s)\n",
" return s.strip()\n",
" label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}\n",
" file_name = os.path.join(data_dir, 'snli_1.0_train.txt'\n",
" if is_train else 'snli_1.0_test.txt')\n",
" with open(file_name, 'r') as f:\n",
" rows = [row.split('\\t') for row in f.readlines()[1:]]\n",
" premises = [extract_text(row[1]) for row in rows if row[0] in label_set]\n",
" hypotheses = [extract_text(row[2]) for row in rows if row[0] \\\n",
" in label_set]\n",
" labels = [label_set[row[0]] for row in rows if row[0] in label_set]\n",
" return premises, hypotheses, labels"
]
},
{
"cell_type": "markdown",
"id": "607a64fd",
"metadata": {
"origin_pos": 6
},
"source": [
"现在让我们[**打印前3对**]前提和假设,以及它们的标签(“0”“1”和“2”分别对应于“蕴涵”“矛盾”和“中性”)。\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "19101f9e",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:09.397297Z",
"iopub.status.busy": "2023-08-18T07:06:09.396407Z",
"iopub.status.idle": "2023-08-18T07:06:23.206512Z",
"shell.execute_reply": "2023-08-18T07:06:23.205574Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"前提: A person on a horse jumps over a broken down airplane .\n",
"假设: A person is training his horse for a competition .\n",
"标签: 2\n",
"前提: A person on a horse jumps over a broken down airplane .\n",
"假设: A person is at a diner , ordering an omelette .\n",
"标签: 1\n",
"前提: A person on a horse jumps over a broken down airplane .\n",
"假设: A person is outdoors , on a horse .\n",
"标签: 0\n"
]
}
],
"source": [
"train_data = read_snli(data_dir, is_train=True)\n",
"for x0, x1, y in zip(train_data[0][:3], train_data[1][:3], train_data[2][:3]):\n",
" print('前提:', x0)\n",
" print('假设:', x1)\n",
" print('标签:', y)"
]
},
{
"cell_type": "markdown",
"id": "f09b2cf4",
"metadata": {
"origin_pos": 8
},
"source": [
"训练集约有550000对,测试集约有10000对。下面显示了训练集和测试集中的三个[**标签“蕴涵”“矛盾”和“中性”是平衡的**]。\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "972ca3d1",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:23.210300Z",
"iopub.status.busy": "2023-08-18T07:06:23.209728Z",
"iopub.status.idle": "2023-08-18T07:06:23.531128Z",
"shell.execute_reply": "2023-08-18T07:06:23.530246Z"
},
"origin_pos": 9,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[183416, 183187, 182764]\n",
"[3368, 3237, 3219]\n"
]
}
],
"source": [
"test_data = read_snli(data_dir, is_train=False)\n",
"for data in [train_data, test_data]:\n",
" print([[row for row in data[2]].count(i) for i in range(3)])"
]
},
{
"cell_type": "markdown",
"id": "e7ab2708",
"metadata": {
"origin_pos": 10
},
"source": [
"### [**定义用于加载数据集的类**]\n",
"\n",
"下面我们来定义一个用于加载SNLI数据集的类。类构造函数中的变量`num_steps`指定文本序列的长度,使得每个小批量序列将具有相同的形状。换句话说,在较长序列中的前`num_steps`个标记之后的标记被截断,而特殊标记“<pad>”将被附加到较短的序列后,直到它们的长度变为`num_steps`。通过实现`__getitem__`功能,我们可以任意访问带有索引`idx`的前提、假设和标签。\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b8b15f65",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:23.534933Z",
"iopub.status.busy": "2023-08-18T07:06:23.534365Z",
"iopub.status.idle": "2023-08-18T07:06:23.542550Z",
"shell.execute_reply": "2023-08-18T07:06:23.541714Z"
},
"origin_pos": 12,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"#@save\n",
"class SNLIDataset(torch.utils.data.Dataset):\n",
" \"\"\"用于加载SNLI数据集的自定义数据集\"\"\"\n",
" def __init__(self, dataset, num_steps, vocab=None):\n",
" self.num_steps = num_steps\n",
" all_premise_tokens = d2l.tokenize(dataset[0])\n",
" all_hypothesis_tokens = d2l.tokenize(dataset[1])\n",
" if vocab is None:\n",
" self.vocab = d2l.Vocab(all_premise_tokens + \\\n",
" all_hypothesis_tokens, min_freq=5, reserved_tokens=['<pad>'])\n",
" else:\n",
" self.vocab = vocab\n",
" self.premises = self._pad(all_premise_tokens)\n",
" self.hypotheses = self._pad(all_hypothesis_tokens)\n",
" self.labels = torch.tensor(dataset[2])\n",
" print('read ' + str(len(self.premises)) + ' examples')\n",
"\n",
" def _pad(self, lines):\n",
" return torch.tensor([d2l.truncate_pad(\n",
" self.vocab[line], self.num_steps, self.vocab['<pad>'])\n",
" for line in lines])\n",
"\n",
" def __getitem__(self, idx):\n",
" return (self.premises[idx], self.hypotheses[idx]), self.labels[idx]\n",
"\n",
" def __len__(self):\n",
" return len(self.premises)"
]
},
{
"cell_type": "markdown",
"id": "f5efd5df",
"metadata": {
"origin_pos": 14
},
"source": [
"### [**整合代码**]\n",
"\n",
"现在,我们可以调用`read_snli`函数和`SNLIDataset`类来下载SNLI数据集,并返回训练集和测试集的`DataLoader`实例,以及训练集的词表。值得注意的是,我们必须使用从训练集构造的词表作为测试集的词表。因此,在训练集中训练的模型将不知道来自测试集的任何新词元。\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "96c46f53",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:23.546033Z",
"iopub.status.busy": "2023-08-18T07:06:23.545509Z",
"iopub.status.idle": "2023-08-18T07:06:23.551107Z",
"shell.execute_reply": "2023-08-18T07:06:23.550286Z"
},
"origin_pos": 16,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"#@save\n",
"def load_data_snli(batch_size, num_steps=50):\n",
" \"\"\"下载SNLI数据集并返回数据迭代器和词表\"\"\"\n",
" num_workers = d2l.get_dataloader_workers()\n",
" data_dir = d2l.download_extract('SNLI')\n",
" train_data = read_snli(data_dir, True)\n",
" test_data = read_snli(data_dir, False)\n",
" train_set = SNLIDataset(train_data, num_steps)\n",
" test_set = SNLIDataset(test_data, num_steps, train_set.vocab)\n",
" train_iter = torch.utils.data.DataLoader(train_set, batch_size,\n",
" shuffle=True,\n",
" num_workers=num_workers)\n",
" test_iter = torch.utils.data.DataLoader(test_set, batch_size,\n",
" shuffle=False,\n",
" num_workers=num_workers)\n",
" return train_iter, test_iter, train_set.vocab"
]
},
{
"cell_type": "markdown",
"id": "16d0cddb",
"metadata": {
"origin_pos": 18
},
"source": [
"在这里,我们将批量大小设置为128时,将序列长度设置为50,并调用`load_data_snli`函数来获取数据迭代器和词表。然后我们打印词表大小。\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "08d0c755",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:23.554839Z",
"iopub.status.busy": "2023-08-18T07:06:23.554288Z",
"iopub.status.idle": "2023-08-18T07:07:02.488484Z",
"shell.execute_reply": "2023-08-18T07:07:02.487658Z"
},
"origin_pos": 19,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"read 549367 examples\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"read 9824 examples\n"
]
},
{
"data": {
"text/plain": [
"18678"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_iter, test_iter, vocab = load_data_snli(128, 50)\n",
"len(vocab)"
]
},
{
"cell_type": "markdown",
"id": "783f8d2d",
"metadata": {
"origin_pos": 20
},
"source": [
"现在我们打印第一个小批量的形状。与情感分析相反,我们有分别代表前提和假设的两个输入`X[0]`和`X[1]`。\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "d7411a33",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:07:02.492220Z",
"iopub.status.busy": "2023-08-18T07:07:02.491909Z",
"iopub.status.idle": "2023-08-18T07:07:02.966465Z",
"shell.execute_reply": "2023-08-18T07:07:02.965137Z"
},
"origin_pos": 21,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([128, 50])\n",
"torch.Size([128, 50])\n",
"torch.Size([128])\n"
]
}
],
"source": [
"for X, Y in train_iter:\n",
" print(X[0].shape)\n",
" print(X[1].shape)\n",
" print(Y.shape)\n",
" break"
]
},
{
"cell_type": "markdown",
"id": "2cdcfd40",
"metadata": {
"origin_pos": 22
},
"source": [
"## 小结\n",
"\n",
"* 自然语言推断研究“假设”是否可以从“前提”推断出来,其中两者都是文本序列。\n",
"* 在自然语言推断中,前提和假设之间的关系包括蕴涵关系、矛盾关系和中性关系。\n",
"* 斯坦福自然语言推断(SNLI)语料库是一个比较流行的自然语言推断基准数据集。\n",
"\n",
"## 练习\n",
"\n",
"1. 机器翻译长期以来一直是基于翻译输出和翻译真实值之间的表面$n$元语法匹配来进行评估的。可以设计一种用自然语言推断来评价机器翻译结果的方法吗?\n",
"1. 我们如何更改超参数以减小词表大小?\n"
]
},
{
"cell_type": "markdown",
"id": "d452fb1d",
"metadata": {
"origin_pos": 24,
"tab": [
"pytorch"
]
},
"source": [
"[Discussions](https://discuss.d2l.ai/t/5722)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"required_libs": []
},
"nbformat": 4,
"nbformat_minor": 5
}