Files
Python/d2l/d2l-zh/pytorch/chapter_natural-language-processing-pretraining/similarity-analogy.ipynb
T
2025-12-16 09:23:53 +08:00

718 lines
19 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "markdown",
"id": "0160c8de",
"metadata": {
"origin_pos": 0
},
"source": [
"# 词的相似性和类比任务\n",
":label:`sec_synonyms`\n",
"\n",
"在 :numref:`sec_word2vec_pretraining`中,我们在一个小的数据集上训练了一个word2vec模型,并使用它为一个输入词寻找语义相似的词。实际上,在大型语料库上预先训练的词向量可以应用于下游的自然语言处理任务,这将在后面的 :numref:`chap_nlp_app`中讨论。为了直观地演示大型语料库中预训练词向量的语义,让我们将预训练词向量应用到词的相似性和类比任务中。\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f23dc33a",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:41.256400Z",
"iopub.status.busy": "2023-08-18T07:06:41.255749Z",
"iopub.status.idle": "2023-08-18T07:06:43.288113Z",
"shell.execute_reply": "2023-08-18T07:06:43.287240Z"
},
"origin_pos": 2,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "ce6db3d6",
"metadata": {
"origin_pos": 4
},
"source": [
"## 加载预训练词向量\n",
"\n",
"以下列出维度为50、100和300的预训练GloVe嵌入,可从[GloVe网站](https://nlp.stanford.edu/projects/glove/)下载。预训练的fastText嵌入有多种语言。这里我们使用可以从[fastText网站](https://fasttext.cc/)下载300维度的英文版本(“wiki.en”)。\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "89f705ca",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:43.292543Z",
"iopub.status.busy": "2023-08-18T07:06:43.291837Z",
"iopub.status.idle": "2023-08-18T07:06:43.297097Z",
"shell.execute_reply": "2023-08-18T07:06:43.296299Z"
},
"origin_pos": 5,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"#@save\n",
"d2l.DATA_HUB['glove.6b.50d'] = (d2l.DATA_URL + 'glove.6B.50d.zip',\n",
" '0b8703943ccdb6eb788e6f091b8946e82231bc4d')\n",
"\n",
"#@save\n",
"d2l.DATA_HUB['glove.6b.100d'] = (d2l.DATA_URL + 'glove.6B.100d.zip',\n",
" 'cd43bfb07e44e6f27cbcc7bc9ae3d80284fdaf5a')\n",
"\n",
"#@save\n",
"d2l.DATA_HUB['glove.42b.300d'] = (d2l.DATA_URL + 'glove.42B.300d.zip',\n",
" 'b5116e234e9eb9076672cfeabf5469f3eec904fa')\n",
"\n",
"#@save\n",
"d2l.DATA_HUB['wiki.en'] = (d2l.DATA_URL + 'wiki.en.zip',\n",
" 'c1816da3821ae9f43899be655002f6c723e91b88')"
]
},
{
"cell_type": "markdown",
"id": "8368bbae",
"metadata": {
"origin_pos": 6
},
"source": [
"为了加载这些预训练的GloVe和fastText嵌入,我们定义了以下`TokenEmbedding`类。\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "cd54118c",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:43.300883Z",
"iopub.status.busy": "2023-08-18T07:06:43.300205Z",
"iopub.status.idle": "2023-08-18T07:06:43.309328Z",
"shell.execute_reply": "2023-08-18T07:06:43.308481Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"#@save\n",
"class TokenEmbedding:\n",
" \"\"\"GloVe嵌入\"\"\"\n",
" def __init__(self, embedding_name):\n",
" self.idx_to_token, self.idx_to_vec = self._load_embedding(\n",
" embedding_name)\n",
" self.unknown_idx = 0\n",
" self.token_to_idx = {token: idx for idx, token in\n",
" enumerate(self.idx_to_token)}\n",
"\n",
" def _load_embedding(self, embedding_name):\n",
" idx_to_token, idx_to_vec = ['<unk>'], []\n",
" data_dir = d2l.download_extract(embedding_name)\n",
" # GloVe网站:https://nlp.stanford.edu/projects/glove/\n",
" # fastText网站:https://fasttext.cc/\n",
" with open(os.path.join(data_dir, 'vec.txt'), 'r') as f:\n",
" for line in f:\n",
" elems = line.rstrip().split(' ')\n",
" token, elems = elems[0], [float(elem) for elem in elems[1:]]\n",
" # 跳过标题信息,例如fastText中的首行\n",
" if len(elems) > 1:\n",
" idx_to_token.append(token)\n",
" idx_to_vec.append(elems)\n",
" idx_to_vec = [[0] * len(idx_to_vec[0])] + idx_to_vec\n",
" return idx_to_token, torch.tensor(idx_to_vec)\n",
"\n",
" def __getitem__(self, tokens):\n",
" indices = [self.token_to_idx.get(token, self.unknown_idx)\n",
" for token in tokens]\n",
" vecs = self.idx_to_vec[torch.tensor(indices)]\n",
" return vecs\n",
"\n",
" def __len__(self):\n",
" return len(self.idx_to_token)"
]
},
{
"cell_type": "markdown",
"id": "6375fd2e",
"metadata": {
"origin_pos": 8
},
"source": [
"下面我们加载50维GloVe嵌入(在维基百科的子集上预训练)。创建`TokenEmbedding`实例时,如果尚未下载指定的嵌入文件,则必须下载该文件。\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ac49581b",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:43.312986Z",
"iopub.status.busy": "2023-08-18T07:06:43.312409Z",
"iopub.status.idle": "2023-08-18T07:06:54.396038Z",
"shell.execute_reply": "2023-08-18T07:06:54.395176Z"
},
"origin_pos": 9,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading ../data/glove.6B.50d.zip from http://d2l-data.s3-accelerate.amazonaws.com/glove.6B.50d.zip...\n"
]
}
],
"source": [
"glove_6b50d = TokenEmbedding('glove.6b.50d')"
]
},
{
"cell_type": "markdown",
"id": "57f30d4e",
"metadata": {
"origin_pos": 10
},
"source": [
"输出词表大小。词表包含400000个词(词元)和一个特殊的未知词元。\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "5d91a982",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:54.400162Z",
"iopub.status.busy": "2023-08-18T07:06:54.399579Z",
"iopub.status.idle": "2023-08-18T07:06:54.405466Z",
"shell.execute_reply": "2023-08-18T07:06:54.404676Z"
},
"origin_pos": 11,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"400001"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(glove_6b50d)"
]
},
{
"cell_type": "markdown",
"id": "867f2106",
"metadata": {
"origin_pos": 12
},
"source": [
"我们可以得到词表中一个单词的索引,反之亦然。\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "6e10f262",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:54.408746Z",
"iopub.status.busy": "2023-08-18T07:06:54.408294Z",
"iopub.status.idle": "2023-08-18T07:06:54.413468Z",
"shell.execute_reply": "2023-08-18T07:06:54.412687Z"
},
"origin_pos": 13,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"(3367, 'beautiful')"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"glove_6b50d.token_to_idx['beautiful'], glove_6b50d.idx_to_token[3367]"
]
},
{
"cell_type": "markdown",
"id": "92b6c303",
"metadata": {
"origin_pos": 14
},
"source": [
"## 应用预训练词向量\n",
"\n",
"使用加载的GloVe向量,我们将通过下面的词相似性和类比任务中来展示词向量的语义。\n",
"\n",
"### 词相似度\n",
"\n",
"与 :numref:`subsec_apply-word-embed`类似,为了根据词向量之间的余弦相似性为输入词查找语义相似的词,我们实现了以下`knn`($k$近邻)函数。\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "2da78732",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:54.416901Z",
"iopub.status.busy": "2023-08-18T07:06:54.416268Z",
"iopub.status.idle": "2023-08-18T07:06:54.421648Z",
"shell.execute_reply": "2023-08-18T07:06:54.420466Z"
},
"origin_pos": 16,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def knn(W, x, k):\n",
" # 增加1e-9以获得数值稳定性\n",
" cos = torch.mv(W, x.reshape(-1,)) / (\n",
" torch.sqrt(torch.sum(W * W, axis=1) + 1e-9) *\n",
" torch.sqrt((x * x).sum()))\n",
" _, topk = torch.topk(cos, k=k)\n",
" return topk, [cos[int(i)] for i in topk]"
]
},
{
"cell_type": "markdown",
"id": "644a758d",
"metadata": {
"origin_pos": 18
},
"source": [
"然后,我们使用`TokenEmbedding`的实例`embed`中预训练好的词向量来搜索相似的词。\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "7b1da561",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:54.425376Z",
"iopub.status.busy": "2023-08-18T07:06:54.424618Z",
"iopub.status.idle": "2023-08-18T07:06:54.430025Z",
"shell.execute_reply": "2023-08-18T07:06:54.428981Z"
},
"origin_pos": 19,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def get_similar_tokens(query_token, k, embed):\n",
" topk, cos = knn(embed.idx_to_vec, embed[[query_token]], k + 1)\n",
" for i, c in zip(topk[1:], cos[1:]): # 排除输入词\n",
" print(f'{embed.idx_to_token[int(i)]}cosine相似度={float(c):.3f}')"
]
},
{
"cell_type": "markdown",
"id": "6ba6f5c8",
"metadata": {
"origin_pos": 20
},
"source": [
"`glove_6b50d`中预训练词向量的词表包含400000个词和一个特殊的未知词元。排除输入词和未知词元后,我们在词表中找到与“chip”一词语义最相似的三个词。\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "623bc4a9",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:54.433258Z",
"iopub.status.busy": "2023-08-18T07:06:54.432943Z",
"iopub.status.idle": "2023-08-18T07:06:54.481827Z",
"shell.execute_reply": "2023-08-18T07:06:54.480628Z"
},
"origin_pos": 21,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"chipscosine相似度=0.856\n",
"intelcosine相似度=0.749\n",
"electronicscosine相似度=0.749\n"
]
}
],
"source": [
"get_similar_tokens('chip', 3, glove_6b50d)"
]
},
{
"cell_type": "markdown",
"id": "c18fa17a",
"metadata": {
"origin_pos": 22
},
"source": [
"下面输出与“baby”和“beautiful”相似的词。\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "d2fd5e8f",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:54.486458Z",
"iopub.status.busy": "2023-08-18T07:06:54.485962Z",
"iopub.status.idle": "2023-08-18T07:06:54.508991Z",
"shell.execute_reply": "2023-08-18T07:06:54.507938Z"
},
"origin_pos": 23,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"babiescosine相似度=0.839\n",
"boycosine相似度=0.800\n",
"girlcosine相似度=0.792\n"
]
}
],
"source": [
"get_similar_tokens('baby', 3, glove_6b50d)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "faa9e2e2",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:54.513356Z",
"iopub.status.busy": "2023-08-18T07:06:54.512976Z",
"iopub.status.idle": "2023-08-18T07:06:54.534489Z",
"shell.execute_reply": "2023-08-18T07:06:54.533425Z"
},
"origin_pos": 24,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"lovelycosine相似度=0.921\n",
"gorgeouscosine相似度=0.893\n",
"wonderfulcosine相似度=0.830\n"
]
}
],
"source": [
"get_similar_tokens('beautiful', 3, glove_6b50d)"
]
},
{
"cell_type": "markdown",
"id": "5cc0553d",
"metadata": {
"origin_pos": 25
},
"source": [
"### 词类比\n",
"\n",
"除了找到相似的词,我们还可以将词向量应用到词类比任务中。\n",
"例如,“man” : “woman” :: “son” : “daughter”是一个词的类比。\n",
"“man”是对“woman”的类比,“son”是对“daughter”的类比。\n",
"具体来说,词类比任务可以定义为:\n",
"对于单词类比$a : b :: c : d$,给出前三个词$a$、$b$和$c$,找到$d$。\n",
"用$\\text{vec}(w)$表示词$w$的向量,\n",
"为了完成这个类比,我们将找到一个词,\n",
"其向量与$\\text{vec}(c)+\\text{vec}(b)-\\text{vec}(a)$的结果最相似。\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "e5340469",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:54.539108Z",
"iopub.status.busy": "2023-08-18T07:06:54.538593Z",
"iopub.status.idle": "2023-08-18T07:06:54.544150Z",
"shell.execute_reply": "2023-08-18T07:06:54.543191Z"
},
"origin_pos": 26,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def get_analogy(token_a, token_b, token_c, embed):\n",
" vecs = embed[[token_a, token_b, token_c]]\n",
" x = vecs[1] - vecs[0] + vecs[2]\n",
" topk, cos = knn(embed.idx_to_vec, x, 1)\n",
" return embed.idx_to_token[int(topk[0])] # 删除未知词"
]
},
{
"cell_type": "markdown",
"id": "df8f2721",
"metadata": {
"origin_pos": 27
},
"source": [
"让我们使用加载的词向量来验证“male-female”类比。\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "e91de1ce",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:54.548236Z",
"iopub.status.busy": "2023-08-18T07:06:54.547963Z",
"iopub.status.idle": "2023-08-18T07:06:54.569097Z",
"shell.execute_reply": "2023-08-18T07:06:54.568018Z"
},
"origin_pos": 28,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"'daughter'"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_analogy('man', 'woman', 'son', glove_6b50d)"
]
},
{
"cell_type": "markdown",
"id": "d9b1ce80",
"metadata": {
"origin_pos": 29
},
"source": [
"下面完成一个“首都-国家”的类比:\n",
"“beijing” : “china” :: “tokyo” : “japan”。\n",
"这说明了预训练词向量中的语义。\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "16eb56d3",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:54.573551Z",
"iopub.status.busy": "2023-08-18T07:06:54.573270Z",
"iopub.status.idle": "2023-08-18T07:06:54.595104Z",
"shell.execute_reply": "2023-08-18T07:06:54.594092Z"
},
"origin_pos": 30,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"'japan'"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_analogy('beijing', 'china', 'tokyo', glove_6b50d)"
]
},
{
"cell_type": "markdown",
"id": "595634f2",
"metadata": {
"origin_pos": 31
},
"source": [
"另外,对于“bad” : “worst” :: “big” : “biggest”等“形容词-形容词最高级”的比喻,预训练词向量可以捕捉到句法信息。\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "b8d6395b",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:54.599698Z",
"iopub.status.busy": "2023-08-18T07:06:54.599313Z",
"iopub.status.idle": "2023-08-18T07:06:54.621533Z",
"shell.execute_reply": "2023-08-18T07:06:54.620486Z"
},
"origin_pos": 32,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"'biggest'"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_analogy('bad', 'worst', 'big', glove_6b50d)"
]
},
{
"cell_type": "markdown",
"id": "a6555f30",
"metadata": {
"origin_pos": 33
},
"source": [
"为了演示在预训练词向量中捕捉到的过去式概念,我们可以使用“现在式-过去式”的类比来测试句法:“do” : “did” :: “go” : “went”。\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "986fa401",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:06:54.626086Z",
"iopub.status.busy": "2023-08-18T07:06:54.625554Z",
"iopub.status.idle": "2023-08-18T07:06:54.647570Z",
"shell.execute_reply": "2023-08-18T07:06:54.646604Z"
},
"origin_pos": 34,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"'went'"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_analogy('do', 'did', 'go', glove_6b50d)"
]
},
{
"cell_type": "markdown",
"id": "61371af5",
"metadata": {
"origin_pos": 35
},
"source": [
"## 小结\n",
"\n",
"* 在实践中,在大型语料库上预先练的词向量可以应用于下游的自然语言处理任务。\n",
"* 预训练的词向量可以应用于词的相似性和类比任务。\n",
"\n",
"## 练习\n",
"\n",
"1. 使用`TokenEmbedding('wiki.en')`测试fastText结果。\n",
"1. 当词表非常大时,我们怎样才能更快地找到相似的词或完成一个词的类比呢?\n"
]
},
{
"cell_type": "markdown",
"id": "bfebf384",
"metadata": {
"origin_pos": 37,
"tab": [
"pytorch"
]
},
"source": [
"[Discussions](https://discuss.d2l.ai/t/5746)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"required_libs": []
},
"nbformat": 4,
"nbformat_minor": 5
}