{ "cells": [ { "cell_type": "markdown", "id": "82f9eacb", "metadata": { "origin_pos": 0 }, "source": [ "# 用于预训练词嵌入的数据集\n", ":label:`sec_word2vec_data`\n", "\n", "现在我们已经了解了word2vec模型的技术细节和大致的训练方法,让我们来看看它们的实现。具体地说,我们将以 :numref:`sec_word2vec`的跳元模型和 :numref:`sec_approx_train`的负采样为例。本节从用于预训练词嵌入模型的数据集开始:数据的原始格式将被转换为可以在训练期间迭代的小批量。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "596ed133", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:38.933299Z", "iopub.status.busy": "2023-08-18T07:01:38.932361Z", "iopub.status.idle": "2023-08-18T07:01:41.929964Z", "shell.execute_reply": "2023-08-18T07:01:41.928691Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import math\n", "import os\n", "import random\n", "import torch\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "8286adf0", "metadata": { "origin_pos": 4 }, "source": [ "## 读取数据集\n", "\n", "我们在这里使用的数据集是[Penn Tree Bank(PTB)](https://catalog.ldc.upenn.edu/LDC99T42)。该语料库取自“华尔街日报”的文章,分为训练集、验证集和测试集。在原始格式中,文本文件的每一行表示由空格分隔的一句话。在这里,我们将每个单词视为一个词元。\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "cc6c9b2e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:41.935897Z", "iopub.status.busy": "2023-08-18T07:01:41.934975Z", "iopub.status.idle": "2023-08-18T07:01:42.345380Z", "shell.execute_reply": "2023-08-18T07:01:42.344041Z" }, "origin_pos": 5, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading ../data/ptb.zip from http://d2l-data.s3-accelerate.amazonaws.com/ptb.zip...\n" ] }, { "data": { "text/plain": [ "'# sentences数: 42069'" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#@save\n", "d2l.DATA_HUB['ptb'] = (d2l.DATA_URL + 'ptb.zip',\n", " '319d85e578af0cdc590547f26231e4e31cdf1e42')\n", "\n", "#@save\n", "def read_ptb():\n", " \"\"\"将PTB数据集加载到文本行的列表中\"\"\"\n", " data_dir = d2l.download_extract('ptb')\n", " # Readthetrainingset.\n", " with open(os.path.join(data_dir, 'ptb.train.txt')) as f:\n", " raw_text = f.read()\n", " return [line.split() for line in raw_text.split('\\n')]\n", "\n", "sentences = read_ptb()\n", "f'# sentences数: {len(sentences)}'" ] }, { "cell_type": "markdown", "id": "e7290de5", "metadata": { "origin_pos": 6 }, "source": [ "在读取训练集之后,我们为语料库构建了一个词表,其中出现次数少于10次的任何单词都将由“<unk>”词元替换。请注意,原始数据集还包含表示稀有(未知)单词的“<unk>”词元。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "04285c2d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:42.350103Z", "iopub.status.busy": "2023-08-18T07:01:42.349586Z", "iopub.status.idle": "2023-08-18T07:01:42.520737Z", "shell.execute_reply": "2023-08-18T07:01:42.519523Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "'vocab size: 6719'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vocab = d2l.Vocab(sentences, min_freq=10)\n", "f'vocab size: {len(vocab)}'" ] }, { "cell_type": "markdown", "id": "0bba2291", "metadata": { "origin_pos": 8 }, "source": [ "## 下采样\n", "\n", "文本数据通常有“the”“a”和“in”等高频词:它们在非常大的语料库中甚至可能出现数十亿次。然而,这些词经常在上下文窗口中与许多不同的词共同出现,提供的有用信息很少。例如,考虑上下文窗口中的词“chip”:直观地说,它与低频单词“intel”的共现比与高频单词“a”的共现在训练中更有用。此外,大量(高频)单词的训练速度很慢。因此,当训练词嵌入模型时,可以对高频单词进行*下采样* :cite:`Mikolov.Sutskever.Chen.ea.2013`。具体地说,数据集中的每个词$w_i$将有概率地被丢弃\n", "\n", "$$ P(w_i) = \\max\\left(1 - \\sqrt{\\frac{t}{f(w_i)}}, 0\\right),$$\n", "\n", "其中$f(w_i)$是$w_i$的词数与数据集中的总词数的比率,常量$t$是超参数(在实验中为$10^{-4}$)。我们可以看到,只有当相对比率$f(w_i) > t$时,(高频)词$w_i$才能被丢弃,且该词的相对比率越高,被丢弃的概率就越大。\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "88d0f9c2", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:42.524901Z", "iopub.status.busy": "2023-08-18T07:01:42.524245Z", "iopub.status.idle": "2023-08-18T07:01:44.019122Z", "shell.execute_reply": "2023-08-18T07:01:44.017912Z" }, "origin_pos": 9, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def subsample(sentences, vocab):\n", " \"\"\"下采样高频词\"\"\"\n", " # 排除未知词元''\n", " sentences = [[token for token in line if vocab[token] != vocab.unk]\n", " for line in sentences]\n", " counter = d2l.count_corpus(sentences)\n", " num_tokens = sum(counter.values())\n", "\n", " # 如果在下采样期间保留词元,则返回True\n", " def keep(token):\n", " return(random.uniform(0, 1) <\n", " math.sqrt(1e-4 / counter[token] * num_tokens))\n", "\n", " return ([[token for token in line if keep(token)] for line in sentences],\n", " counter)\n", "\n", "subsampled, counter = subsample(sentences, vocab)" ] }, { "cell_type": "markdown", "id": "5c892ade", "metadata": { "origin_pos": 10 }, "source": [ "下面的代码片段绘制了下采样前后每句话的词元数量的直方图。正如预期的那样,下采样通过删除高频词来显著缩短句子,这将使训练加速。\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "5dd0b4f6", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:44.024294Z", "iopub.status.busy": "2023-08-18T07:01:44.023765Z", "iopub.status.idle": "2023-08-18T07:01:44.272889Z", "shell.execute_reply": "2023-08-18T07:01:44.271933Z" }, "origin_pos": 11, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:01:44.224820\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "d2l.show_list_len_pair_hist(\n", " ['origin', 'subsampled'], '# tokens per sentence',\n", " 'count', sentences, subsampled);" ] }, { "cell_type": "markdown", "id": "80da6e9d", "metadata": { "origin_pos": 12 }, "source": [ "对于单个词元,高频词“the”的采样率不到1/20。\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "2ac63b1a", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:44.277942Z", "iopub.status.busy": "2023-08-18T07:01:44.277661Z", "iopub.status.idle": "2023-08-18T07:01:44.319135Z", "shell.execute_reply": "2023-08-18T07:01:44.317982Z" }, "origin_pos": 13, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "'\"the\"的数量:之前=50770, 之后=2056'" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def compare_counts(token):\n", " return (f'\"{token}\"的数量:'\n", " f'之前={sum([l.count(token) for l in sentences])}, '\n", " f'之后={sum([l.count(token) for l in subsampled])}')\n", "\n", "compare_counts('the')" ] }, { "cell_type": "markdown", "id": "73ef69ef", "metadata": { "origin_pos": 14 }, "source": [ "相比之下,低频词“join”则被完全保留。\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "9307cb04", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:44.324650Z", "iopub.status.busy": "2023-08-18T07:01:44.323726Z", "iopub.status.idle": "2023-08-18T07:01:44.366586Z", "shell.execute_reply": "2023-08-18T07:01:44.365449Z" }, "origin_pos": 15, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "'\"join\"的数量:之前=45, 之后=45'" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "compare_counts('join')" ] }, { "cell_type": "markdown", "id": "38762200", "metadata": { "origin_pos": 16 }, "source": [ "在下采样之后,我们将词元映射到它们在语料库中的索引。\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "ed59e4d0", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:44.371681Z", "iopub.status.busy": "2023-08-18T07:01:44.370695Z", "iopub.status.idle": "2023-08-18T07:01:44.930927Z", "shell.execute_reply": "2023-08-18T07:01:44.929824Z" }, "origin_pos": 17, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "[[], [2115, 274, 406], [140, 3, 5277, 3054, 1580]]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "corpus = [vocab[line] for line in subsampled]\n", "corpus[:3]" ] }, { "cell_type": "markdown", "id": "fe5918fb", "metadata": { "origin_pos": 18 }, "source": [ "## 中心词和上下文词的提取\n", "\n", "下面的`get_centers_and_contexts`函数从`corpus`中提取所有中心词及其上下文词。它随机采样1到`max_window_size`之间的整数作为上下文窗口。对于任一中心词,与其距离不超过采样上下文窗口大小的词为其上下文词。\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "d4a20ba3", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:44.935833Z", "iopub.status.busy": "2023-08-18T07:01:44.935066Z", "iopub.status.idle": "2023-08-18T07:01:44.944963Z", "shell.execute_reply": "2023-08-18T07:01:44.943901Z" }, "origin_pos": 19, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def get_centers_and_contexts(corpus, max_window_size):\n", " \"\"\"返回跳元模型中的中心词和上下文词\"\"\"\n", " centers, contexts = [], []\n", " for line in corpus:\n", " # 要形成“中心词-上下文词”对,每个句子至少需要有2个词\n", " if len(line) < 2:\n", " continue\n", " centers += line\n", " for i in range(len(line)): # 上下文窗口中间i\n", " window_size = random.randint(1, max_window_size)\n", " indices = list(range(max(0, i - window_size),\n", " min(len(line), i + 1 + window_size)))\n", " # 从上下文词中排除中心词\n", " indices.remove(i)\n", " contexts.append([line[idx] for idx in indices])\n", " return centers, contexts" ] }, { "cell_type": "markdown", "id": "86fba895", "metadata": { "origin_pos": 20 }, "source": [ "接下来,我们创建一个人工数据集,分别包含7个和3个单词的两个句子。设置最大上下文窗口大小为2,并打印所有中心词及其上下文词。\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "fae4771b", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:44.948910Z", "iopub.status.busy": "2023-08-18T07:01:44.948190Z", "iopub.status.idle": "2023-08-18T07:01:44.955563Z", "shell.execute_reply": "2023-08-18T07:01:44.954488Z" }, "origin_pos": 21, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "数据集 [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]\n", "中心词 0 的上下文词是 [1]\n", "中心词 1 的上下文词是 [0, 2]\n", "中心词 2 的上下文词是 [0, 1, 3, 4]\n", "中心词 3 的上下文词是 [2, 4]\n", "中心词 4 的上下文词是 [3, 5]\n", "中心词 5 的上下文词是 [4, 6]\n", "中心词 6 的上下文词是 [5]\n", "中心词 7 的上下文词是 [8, 9]\n", "中心词 8 的上下文词是 [7, 9]\n", "中心词 9 的上下文词是 [7, 8]\n" ] } ], "source": [ "tiny_dataset = [list(range(7)), list(range(7, 10))]\n", "print('数据集', tiny_dataset)\n", "for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):\n", " print('中心词', center, '的上下文词是', context)" ] }, { "cell_type": "markdown", "id": "e21272fc", "metadata": { "origin_pos": 22 }, "source": [ "在PTB数据集上进行训练时,我们将最大上下文窗口大小设置为5。下面提取数据集中的所有中心词及其上下文词。\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "ec92f27e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:44.960145Z", "iopub.status.busy": "2023-08-18T07:01:44.959231Z", "iopub.status.idle": "2023-08-18T07:01:47.218796Z", "shell.execute_reply": "2023-08-18T07:01:47.217626Z" }, "origin_pos": 23, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "'# “中心词-上下文词对”的数量: 1499984'" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "all_centers, all_contexts = get_centers_and_contexts(corpus, 5)\n", "f'# “中心词-上下文词对”的数量: {sum([len(contexts) for contexts in all_contexts])}'" ] }, { "cell_type": "markdown", "id": "f48c535f", "metadata": { "origin_pos": 24 }, "source": [ "## 负采样\n", "\n", "我们使用负采样进行近似训练。为了根据预定义的分布对噪声词进行采样,我们定义以下`RandomGenerator`类,其中(可能未规范化的)采样分布通过变量`sampling_weights`传递。\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "365189a2", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:47.223801Z", "iopub.status.busy": "2023-08-18T07:01:47.223354Z", "iopub.status.idle": "2023-08-18T07:01:47.232254Z", "shell.execute_reply": "2023-08-18T07:01:47.231166Z" }, "origin_pos": 25, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "class RandomGenerator:\n", " \"\"\"根据n个采样权重在{1,...,n}中随机抽取\"\"\"\n", " def __init__(self, sampling_weights):\n", " # Exclude\n", " self.population = list(range(1, len(sampling_weights) + 1))\n", " self.sampling_weights = sampling_weights\n", " self.candidates = []\n", " self.i = 0\n", "\n", " def draw(self):\n", " if self.i == len(self.candidates):\n", " # 缓存k个随机采样结果\n", " self.candidates = random.choices(\n", " self.population, self.sampling_weights, k=10000)\n", " self.i = 0\n", " self.i += 1\n", " return self.candidates[self.i - 1]" ] }, { "cell_type": "markdown", "id": "f886ada9", "metadata": { "origin_pos": 26 }, "source": [ "例如,我们可以在索引1、2和3中绘制10个随机变量$X$,采样概率为$P(X=1)=2/9, P(X=2)=3/9$和$P(X=3)=4/9$,如下所示。\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "f534865c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:47.237153Z", "iopub.status.busy": "2023-08-18T07:01:47.236381Z", "iopub.status.idle": "2023-08-18T07:01:47.251510Z", "shell.execute_reply": "2023-08-18T07:01:47.250435Z" }, "origin_pos": 27, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "[1, 2, 2, 3, 3, 3, 3, 2, 1, 2]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#@save\n", "generator = RandomGenerator([2, 3, 4])\n", "[generator.draw() for _ in range(10)]" ] }, { "cell_type": "markdown", "id": "fe4049d4", "metadata": { "origin_pos": 28 }, "source": [ "对于一对中心词和上下文词,我们随机抽取了`K`个(实验中为5个)噪声词。根据word2vec论文中的建议,将噪声词$w$的采样概率$P(w)$设置为其在字典中的相对频率,其幂为0.75 :cite:`Mikolov.Sutskever.Chen.ea.2013`。\n" ] }, { "cell_type": "code", "execution_count": 14, "id": "21950025", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:47.256344Z", "iopub.status.busy": "2023-08-18T07:01:47.255586Z", "iopub.status.idle": "2023-08-18T07:01:59.259799Z", "shell.execute_reply": "2023-08-18T07:01:59.258793Z" }, "origin_pos": 29, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def get_negatives(all_contexts, vocab, counter, K):\n", " \"\"\"返回负采样中的噪声词\"\"\"\n", " # 索引为1、2、...(索引0是词表中排除的未知标记)\n", " sampling_weights = [counter[vocab.to_tokens(i)]**0.75\n", " for i in range(1, len(vocab))]\n", " all_negatives, generator = [], RandomGenerator(sampling_weights)\n", " for contexts in all_contexts:\n", " negatives = []\n", " while len(negatives) < len(contexts) * K:\n", " neg = generator.draw()\n", " # 噪声词不能是上下文词\n", " if neg not in contexts:\n", " negatives.append(neg)\n", " all_negatives.append(negatives)\n", " return all_negatives\n", "\n", "all_negatives = get_negatives(all_contexts, vocab, counter, 5)" ] }, { "cell_type": "markdown", "id": "8aa17e2d", "metadata": { "origin_pos": 30 }, "source": [ "## 小批量加载训练实例\n", ":label:`subsec_word2vec-minibatch-loading`\n", "\n", "在提取所有中心词及其上下文词和采样噪声词后,将它们转换成小批量的样本,在训练过程中可以迭代加载。\n", "\n", "在小批量中,$i^\\mathrm{th}$个样本包括中心词及其$n_i$个上下文词和$m_i$个噪声词。由于上下文窗口大小不同,$n_i+m_i$对于不同的$i$是不同的。因此,对于每个样本,我们在`contexts_negatives`个变量中将其上下文词和噪声词连结起来,并填充零,直到连结长度达到$\\max_i n_i+m_i$(`max_len`)。为了在计算损失时排除填充,我们定义了掩码变量`masks`。在`masks`中的元素和`contexts_negatives`中的元素之间存在一一对应关系,其中`masks`中的0(否则为1)对应于`contexts_negatives`中的填充。\n", "\n", "为了区分正反例,我们在`contexts_negatives`中通过一个`labels`变量将上下文词与噪声词分开。类似于`masks`,在`labels`中的元素和`contexts_negatives`中的元素之间也存在一一对应关系,其中`labels`中的1(否则为0)对应于`contexts_negatives`中的上下文词的正例。\n", "\n", "上述思想在下面的`batchify`函数中实现。其输入`data`是长度等于批量大小的列表,其中每个元素是由中心词`center`、其上下文词`context`和其噪声词`negative`组成的样本。此函数返回一个可以在训练期间加载用于计算的小批量,例如包括掩码变量。\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "8e92a65e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:59.264970Z", "iopub.status.busy": "2023-08-18T07:01:59.264337Z", "iopub.status.idle": "2023-08-18T07:01:59.271417Z", "shell.execute_reply": "2023-08-18T07:01:59.270518Z" }, "origin_pos": 31, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def batchify(data):\n", " \"\"\"返回带有负采样的跳元模型的小批量样本\"\"\"\n", " max_len = max(len(c) + len(n) for _, c, n in data)\n", " centers, contexts_negatives, masks, labels = [], [], [], []\n", " for center, context, negative in data:\n", " cur_len = len(context) + len(negative)\n", " centers += [center]\n", " contexts_negatives += \\\n", " [context + negative + [0] * (max_len - cur_len)]\n", " masks += [[1] * cur_len + [0] * (max_len - cur_len)]\n", " labels += [[1] * len(context) + [0] * (max_len - len(context))]\n", " return (torch.tensor(centers).reshape((-1, 1)), torch.tensor(\n", " contexts_negatives), torch.tensor(masks), torch.tensor(labels))" ] }, { "cell_type": "markdown", "id": "7aeb5c51", "metadata": { "origin_pos": 32 }, "source": [ "让我们使用一个小批量的两个样本来测试此函数。\n" ] }, { "cell_type": "code", "execution_count": 16, "id": "e14e34ce", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:59.276193Z", "iopub.status.busy": "2023-08-18T07:01:59.275387Z", "iopub.status.idle": "2023-08-18T07:01:59.282832Z", "shell.execute_reply": "2023-08-18T07:01:59.281912Z" }, "origin_pos": 33, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "centers = tensor([[1],\n", " [1]])\n", "contexts_negatives = tensor([[2, 2, 3, 3, 3, 3],\n", " [2, 2, 2, 3, 3, 0]])\n", "masks = tensor([[1, 1, 1, 1, 1, 1],\n", " [1, 1, 1, 1, 1, 0]])\n", "labels = tensor([[1, 1, 0, 0, 0, 0],\n", " [1, 1, 1, 0, 0, 0]])\n" ] } ], "source": [ "x_1 = (1, [2, 2], [3, 3, 3, 3])\n", "x_2 = (1, [2, 2, 2], [3, 3])\n", "batch = batchify((x_1, x_2))\n", "\n", "names = ['centers', 'contexts_negatives', 'masks', 'labels']\n", "for name, data in zip(names, batch):\n", " print(name, '=', data)" ] }, { "cell_type": "markdown", "id": "e1eef3d8", "metadata": { "origin_pos": 34 }, "source": [ "## 整合代码\n", "\n", "最后,我们定义了读取PTB数据集并返回数据迭代器和词表的`load_data_ptb`函数。\n" ] }, { "cell_type": "code", "execution_count": 17, "id": "8ddfb20d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:59.287587Z", "iopub.status.busy": "2023-08-18T07:01:59.286823Z", "iopub.status.idle": "2023-08-18T07:01:59.296040Z", "shell.execute_reply": "2023-08-18T07:01:59.294978Z" }, "origin_pos": 36, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def load_data_ptb(batch_size, max_window_size, num_noise_words):\n", " \"\"\"下载PTB数据集,然后将其加载到内存中\"\"\"\n", " num_workers = d2l.get_dataloader_workers()\n", " sentences = read_ptb()\n", " vocab = d2l.Vocab(sentences, min_freq=10)\n", " subsampled, counter = subsample(sentences, vocab)\n", " corpus = [vocab[line] for line in subsampled]\n", " all_centers, all_contexts = get_centers_and_contexts(\n", " corpus, max_window_size)\n", " all_negatives = get_negatives(\n", " all_contexts, vocab, counter, num_noise_words)\n", "\n", " class PTBDataset(torch.utils.data.Dataset):\n", " def __init__(self, centers, contexts, negatives):\n", " assert len(centers) == len(contexts) == len(negatives)\n", " self.centers = centers\n", " self.contexts = contexts\n", " self.negatives = negatives\n", "\n", " def __getitem__(self, index):\n", " return (self.centers[index], self.contexts[index],\n", " self.negatives[index])\n", "\n", " def __len__(self):\n", " return len(self.centers)\n", "\n", " dataset = PTBDataset(all_centers, all_contexts, all_negatives)\n", "\n", " data_iter = torch.utils.data.DataLoader(\n", " dataset, batch_size, shuffle=True,\n", " collate_fn=batchify, num_workers=num_workers)\n", " return data_iter, vocab" ] }, { "cell_type": "markdown", "id": "97991d10", "metadata": { "origin_pos": 38 }, "source": [ "让我们打印数据迭代器的第一个小批量。\n" ] }, { "cell_type": "code", "execution_count": 18, "id": "5115b257", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:59.300574Z", "iopub.status.busy": "2023-08-18T07:01:59.299960Z", "iopub.status.idle": "2023-08-18T07:02:13.672095Z", "shell.execute_reply": "2023-08-18T07:02:13.671142Z" }, "origin_pos": 39, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "centers shape: torch.Size([512, 1])\n", "contexts_negatives shape: torch.Size([512, 60])\n", "masks shape: torch.Size([512, 60])\n", "labels shape: torch.Size([512, 60])\n" ] } ], "source": [ "data_iter, vocab = load_data_ptb(512, 5, 5)\n", "for batch in data_iter:\n", " for name, data in zip(names, batch):\n", " print(name, 'shape:', data.shape)\n", " break" ] }, { "cell_type": "markdown", "id": "cfc03f54", "metadata": { "origin_pos": 40 }, "source": [ "## 小结\n", "\n", "* 高频词在训练中可能不是那么有用。我们可以对他们进行下采样,以便在训练中加快速度。\n", "* 为了提高计算效率,我们以小批量方式加载样本。我们可以定义其他变量来区分填充标记和非填充标记,以及正例和负例。\n", "\n", "## 练习\n", "\n", "1. 如果不使用下采样,本节中代码的运行时间会发生什么变化?\n", "1. `RandomGenerator`类缓存`k`个随机采样结果。将`k`设置为其他值,看看它如何影响数据加载速度。\n", "1. 本节代码中的哪些其他超参数可能会影响数据加载速度?\n" ] }, { "cell_type": "markdown", "id": "9e415387", "metadata": { "origin_pos": 42, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/5735)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }