{ "cells": [ { "cell_type": "markdown", "id": "15c5cd33", "metadata": { "origin_pos": 0 }, "source": [ "# 自然语言推断与数据集\n", ":label:`sec_natural-language-inference-and-dataset`\n", "\n", "在 :numref:`sec_sentiment`中,我们讨论了情感分析问题。这个任务的目的是将单个文本序列分类到预定义的类别中,例如一组情感极性中。然而,当需要决定一个句子是否可以从另一个句子推断出来,或者需要通过识别语义等价的句子来消除句子间冗余时,知道如何对一个文本序列进行分类是不够的。相反,我们需要能够对成对的文本序列进行推断。\n", "\n", "## 自然语言推断\n", "\n", "*自然语言推断*(natural language inference)主要研究\n", "*假设*(hypothesis)是否可以从*前提*(premise)中推断出来,\n", "其中两者都是文本序列。\n", "换言之,自然语言推断决定了一对文本序列之间的逻辑关系。这类关系通常分为三种类型:\n", "\n", "* *蕴涵*(entailment):假设可以从前提中推断出来。\n", "* *矛盾*(contradiction):假设的否定可以从前提中推断出来。\n", "* *中性*(neutral):所有其他情况。\n", "\n", "自然语言推断也被称为识别文本蕴涵任务。\n", "例如,下面的一个文本对将被贴上“蕴涵”的标签,因为假设中的“表白”可以从前提中的“拥抱”中推断出来。\n", "\n", ">前提:两个女人拥抱在一起。\n", "\n", ">假设:两个女人在示爱。\n", "\n", "下面是一个“矛盾”的例子,因为“运行编码示例”表示“不睡觉”,而不是“睡觉”。\n", "\n", ">前提:一名男子正在运行Dive Into Deep Learning的编码示例。\n", "\n", ">假设:该男子正在睡觉。\n", "\n", "第三个例子显示了一种“中性”关系,因为“正在为我们表演”这一事实无法推断出“出名”或“不出名”。\n", "\n", ">前提:音乐家们正在为我们表演。\n", "\n", ">假设:音乐家很有名。\n", "\n", "自然语言推断一直是理解自然语言的中心话题。它有着广泛的应用,从信息检索到开放领域的问答。为了研究这个问题,我们将首先研究一个流行的自然语言推断基准数据集。\n", "\n", "## 斯坦福自然语言推断(SNLI)数据集\n", "\n", "[**斯坦福自然语言推断语料库(Stanford Natural Language Inference,SNLI)**]是由500000多个带标签的英语句子对组成的集合 :cite:`Bowman.Angeli.Potts.ea.2015`。我们在路径`../data/snli_1.0`中下载并存储提取的SNLI数据集。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "85ccbfd4", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:06:00.201212Z", "iopub.status.busy": "2023-08-18T07:06:00.200144Z", "iopub.status.idle": "2023-08-18T07:06:09.370822Z", "shell.execute_reply": "2023-08-18T07:06:09.368591Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import os\n", "import re\n", "import torch\n", "from torch import nn\n", "from d2l import torch as d2l\n", "\n", "#@save\n", "d2l.DATA_HUB['SNLI'] = (\n", " 'https://nlp.stanford.edu/projects/snli/snli_1.0.zip',\n", " '9fcde07509c7e87ec61c640c1b2753d9041758e4')\n", "\n", "data_dir = d2l.download_extract('SNLI')" ] }, { "cell_type": "markdown", "id": "5e647396", "metadata": { "origin_pos": 4 }, "source": [ "### [**读取数据集**]\n", "\n", "原始的SNLI数据集包含的信息比我们在实验中真正需要的信息丰富得多。因此,我们定义函数`read_snli`以仅提取数据集的一部分,然后返回前提、假设及其标签的列表。\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "fa839f80", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:06:09.377922Z", "iopub.status.busy": "2023-08-18T07:06:09.377380Z", "iopub.status.idle": "2023-08-18T07:06:09.392203Z", "shell.execute_reply": "2023-08-18T07:06:09.390984Z" }, "origin_pos": 5, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def read_snli(data_dir, is_train):\n", " \"\"\"将SNLI数据集解析为前提、假设和标签\"\"\"\n", " def extract_text(s):\n", " # 删除我们不会使用的信息\n", " s = re.sub('\\\\(', '', s)\n", " s = re.sub('\\\\)', '', s)\n", " # 用一个空格替换两个或多个连续的空格\n", " s = re.sub('\\\\s{2,}', ' ', s)\n", " return s.strip()\n", " label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}\n", " file_name = os.path.join(data_dir, 'snli_1.0_train.txt'\n", " if is_train else 'snli_1.0_test.txt')\n", " with open(file_name, 'r') as f:\n", " rows = [row.split('\\t') for row in f.readlines()[1:]]\n", " premises = [extract_text(row[1]) for row in rows if row[0] in label_set]\n", " hypotheses = [extract_text(row[2]) for row in rows if row[0] \\\n", " in label_set]\n", " labels = [label_set[row[0]] for row in rows if row[0] in label_set]\n", " return premises, hypotheses, labels" ] }, { "cell_type": "markdown", "id": "607a64fd", "metadata": { "origin_pos": 6 }, "source": [ "现在让我们[**打印前3对**]前提和假设,以及它们的标签(“0”“1”和“2”分别对应于“蕴涵”“矛盾”和“中性”)。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "19101f9e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:06:09.397297Z", "iopub.status.busy": "2023-08-18T07:06:09.396407Z", "iopub.status.idle": "2023-08-18T07:06:23.206512Z", "shell.execute_reply": "2023-08-18T07:06:23.205574Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "前提: A person on a horse jumps over a broken down airplane .\n", "假设: A person is training his horse for a competition .\n", "标签: 2\n", "前提: A person on a horse jumps over a broken down airplane .\n", "假设: A person is at a diner , ordering an omelette .\n", "标签: 1\n", "前提: A person on a horse jumps over a broken down airplane .\n", "假设: A person is outdoors , on a horse .\n", "标签: 0\n" ] } ], "source": [ "train_data = read_snli(data_dir, is_train=True)\n", "for x0, x1, y in zip(train_data[0][:3], train_data[1][:3], train_data[2][:3]):\n", " print('前提:', x0)\n", " print('假设:', x1)\n", " print('标签:', y)" ] }, { "cell_type": "markdown", "id": "f09b2cf4", "metadata": { "origin_pos": 8 }, "source": [ "训练集约有550000对,测试集约有10000对。下面显示了训练集和测试集中的三个[**标签“蕴涵”“矛盾”和“中性”是平衡的**]。\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "972ca3d1", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:06:23.210300Z", "iopub.status.busy": "2023-08-18T07:06:23.209728Z", "iopub.status.idle": "2023-08-18T07:06:23.531128Z", "shell.execute_reply": "2023-08-18T07:06:23.530246Z" }, "origin_pos": 9, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[183416, 183187, 182764]\n", "[3368, 3237, 3219]\n" ] } ], "source": [ "test_data = read_snli(data_dir, is_train=False)\n", "for data in [train_data, test_data]:\n", " print([[row for row in data[2]].count(i) for i in range(3)])" ] }, { "cell_type": "markdown", "id": "e7ab2708", "metadata": { "origin_pos": 10 }, "source": [ "### [**定义用于加载数据集的类**]\n", "\n", "下面我们来定义一个用于加载SNLI数据集的类。类构造函数中的变量`num_steps`指定文本序列的长度,使得每个小批量序列将具有相同的形状。换句话说,在较长序列中的前`num_steps`个标记之后的标记被截断,而特殊标记“<pad>”将被附加到较短的序列后,直到它们的长度变为`num_steps`。通过实现`__getitem__`功能,我们可以任意访问带有索引`idx`的前提、假设和标签。\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "b8b15f65", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:06:23.534933Z", "iopub.status.busy": "2023-08-18T07:06:23.534365Z", "iopub.status.idle": "2023-08-18T07:06:23.542550Z", "shell.execute_reply": "2023-08-18T07:06:23.541714Z" }, "origin_pos": 12, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "class SNLIDataset(torch.utils.data.Dataset):\n", " \"\"\"用于加载SNLI数据集的自定义数据集\"\"\"\n", " def __init__(self, dataset, num_steps, vocab=None):\n", " self.num_steps = num_steps\n", " all_premise_tokens = d2l.tokenize(dataset[0])\n", " all_hypothesis_tokens = d2l.tokenize(dataset[1])\n", " if vocab is None:\n", " self.vocab = d2l.Vocab(all_premise_tokens + \\\n", " all_hypothesis_tokens, min_freq=5, reserved_tokens=[''])\n", " else:\n", " self.vocab = vocab\n", " self.premises = self._pad(all_premise_tokens)\n", " self.hypotheses = self._pad(all_hypothesis_tokens)\n", " self.labels = torch.tensor(dataset[2])\n", " print('read ' + str(len(self.premises)) + ' examples')\n", "\n", " def _pad(self, lines):\n", " return torch.tensor([d2l.truncate_pad(\n", " self.vocab[line], self.num_steps, self.vocab[''])\n", " for line in lines])\n", "\n", " def __getitem__(self, idx):\n", " return (self.premises[idx], self.hypotheses[idx]), self.labels[idx]\n", "\n", " def __len__(self):\n", " return len(self.premises)" ] }, { "cell_type": "markdown", "id": "f5efd5df", "metadata": { "origin_pos": 14 }, "source": [ "### [**整合代码**]\n", "\n", "现在,我们可以调用`read_snli`函数和`SNLIDataset`类来下载SNLI数据集,并返回训练集和测试集的`DataLoader`实例,以及训练集的词表。值得注意的是,我们必须使用从训练集构造的词表作为测试集的词表。因此,在训练集中训练的模型将不知道来自测试集的任何新词元。\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "96c46f53", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:06:23.546033Z", "iopub.status.busy": "2023-08-18T07:06:23.545509Z", "iopub.status.idle": "2023-08-18T07:06:23.551107Z", "shell.execute_reply": "2023-08-18T07:06:23.550286Z" }, "origin_pos": 16, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def load_data_snli(batch_size, num_steps=50):\n", " \"\"\"下载SNLI数据集并返回数据迭代器和词表\"\"\"\n", " num_workers = d2l.get_dataloader_workers()\n", " data_dir = d2l.download_extract('SNLI')\n", " train_data = read_snli(data_dir, True)\n", " test_data = read_snli(data_dir, False)\n", " train_set = SNLIDataset(train_data, num_steps)\n", " test_set = SNLIDataset(test_data, num_steps, train_set.vocab)\n", " train_iter = torch.utils.data.DataLoader(train_set, batch_size,\n", " shuffle=True,\n", " num_workers=num_workers)\n", " test_iter = torch.utils.data.DataLoader(test_set, batch_size,\n", " shuffle=False,\n", " num_workers=num_workers)\n", " return train_iter, test_iter, train_set.vocab" ] }, { "cell_type": "markdown", "id": "16d0cddb", "metadata": { "origin_pos": 18 }, "source": [ "在这里,我们将批量大小设置为128时,将序列长度设置为50,并调用`load_data_snli`函数来获取数据迭代器和词表。然后我们打印词表大小。\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "08d0c755", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:06:23.554839Z", "iopub.status.busy": "2023-08-18T07:06:23.554288Z", "iopub.status.idle": "2023-08-18T07:07:02.488484Z", "shell.execute_reply": "2023-08-18T07:07:02.487658Z" }, "origin_pos": 19, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "read 549367 examples\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "read 9824 examples\n" ] }, { "data": { "text/plain": [ "18678" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_iter, test_iter, vocab = load_data_snli(128, 50)\n", "len(vocab)" ] }, { "cell_type": "markdown", "id": "783f8d2d", "metadata": { "origin_pos": 20 }, "source": [ "现在我们打印第一个小批量的形状。与情感分析相反,我们有分别代表前提和假设的两个输入`X[0]`和`X[1]`。\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "d7411a33", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:07:02.492220Z", "iopub.status.busy": "2023-08-18T07:07:02.491909Z", "iopub.status.idle": "2023-08-18T07:07:02.966465Z", "shell.execute_reply": "2023-08-18T07:07:02.965137Z" }, "origin_pos": 21, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([128, 50])\n", "torch.Size([128, 50])\n", "torch.Size([128])\n" ] } ], "source": [ "for X, Y in train_iter:\n", " print(X[0].shape)\n", " print(X[1].shape)\n", " print(Y.shape)\n", " break" ] }, { "cell_type": "markdown", "id": "2cdcfd40", "metadata": { "origin_pos": 22 }, "source": [ "## 小结\n", "\n", "* 自然语言推断研究“假设”是否可以从“前提”推断出来,其中两者都是文本序列。\n", "* 在自然语言推断中,前提和假设之间的关系包括蕴涵关系、矛盾关系和中性关系。\n", "* 斯坦福自然语言推断(SNLI)语料库是一个比较流行的自然语言推断基准数据集。\n", "\n", "## 练习\n", "\n", "1. 机器翻译长期以来一直是基于翻译输出和翻译真实值之间的表面$n$元语法匹配来进行评估的。可以设计一种用自然语言推断来评价机器翻译结果的方法吗?\n", "1. 我们如何更改超参数以减小词表大小?\n" ] }, { "cell_type": "markdown", "id": "d452fb1d", "metadata": { "origin_pos": 24, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/5722)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }