{ "cells": [ { "cell_type": "markdown", "id": "099f849e", "metadata": { "origin_pos": 0 }, "source": [ "# 预训练BERT\n", ":label:`sec_bert-pretraining`\n", "\n", "利用 :numref:`sec_bert`中实现的BERT模型和 :numref:`sec_bert-dataset`中从WikiText-2数据集生成的预训练样本,我们将在本节中在WikiText-2数据集上对BERT进行预训练。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "8c0979b7", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:04:26.170037Z", "iopub.status.busy": "2023-08-18T07:04:26.168910Z", "iopub.status.idle": "2023-08-18T07:04:28.547324Z", "shell.execute_reply": "2023-08-18T07:04:28.546158Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "898d6f91", "metadata": { "origin_pos": 4 }, "source": [ "首先,我们加载WikiText-2数据集作为小批量的预训练样本,用于遮蔽语言模型和下一句预测。批量大小是512,BERT输入序列的最大长度是64。注意,在原始BERT模型中,最大长度是512。\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "95571e6a", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:04:28.552742Z", "iopub.status.busy": "2023-08-18T07:04:28.552374Z", "iopub.status.idle": "2023-08-18T07:04:38.456343Z", "shell.execute_reply": "2023-08-18T07:04:38.455141Z" }, "origin_pos": 5, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "batch_size, max_len = 512, 64\n", "train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)" ] }, { "cell_type": "markdown", "id": "cfb22b86", "metadata": { "origin_pos": 7 }, "source": [ "## 预训练BERT\n", "\n", "原始BERT :cite:`Devlin.Chang.Lee.ea.2018`有两个不同模型尺寸的版本。基本模型($\\text{BERT}_{\\text{BASE}}$)使用12层(Transformer编码器块),768个隐藏单元(隐藏大小)和12个自注意头。大模型($\\text{BERT}_{\\text{LARGE}}$)使用24层,1024个隐藏单元和16个自注意头。值得注意的是,前者有1.1亿个参数,后者有3.4亿个参数。为了便于演示,我们定义了一个小的BERT,使用了2层、128个隐藏单元和2个自注意头。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "3cc34825", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:04:38.461166Z", "iopub.status.busy": "2023-08-18T07:04:38.460802Z", "iopub.status.idle": "2023-08-18T07:04:38.581653Z", "shell.execute_reply": "2023-08-18T07:04:38.580139Z" }, "origin_pos": 9, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "net = d2l.BERTModel(len(vocab), num_hiddens=128, norm_shape=[128],\n", " ffn_num_input=128, ffn_num_hiddens=256, num_heads=2,\n", " num_layers=2, dropout=0.2, key_size=128, query_size=128,\n", " value_size=128, hid_in_features=128, mlm_in_features=128,\n", " nsp_in_features=128)\n", "devices = d2l.try_all_gpus()\n", "loss = nn.CrossEntropyLoss()" ] }, { "cell_type": "markdown", "id": "be063421", "metadata": { "origin_pos": 10 }, "source": [ "在定义训练代码实现之前,我们定义了一个辅助函数`_get_batch_loss_bert`。给定训练样本,该函数计算遮蔽语言模型和下一句子预测任务的损失。请注意,BERT预训练的最终损失是遮蔽语言模型损失和下一句预测损失的和。\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "64b2c84b", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:04:38.586837Z", "iopub.status.busy": "2023-08-18T07:04:38.585868Z", "iopub.status.idle": "2023-08-18T07:04:38.594572Z", "shell.execute_reply": "2023-08-18T07:04:38.593478Z" }, "origin_pos": 12, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,\n", " segments_X, valid_lens_x,\n", " pred_positions_X, mlm_weights_X,\n", " mlm_Y, nsp_y):\n", " # 前向传播\n", " _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,\n", " valid_lens_x.reshape(-1),\n", " pred_positions_X)\n", " # 计算遮蔽语言模型损失\n", " mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\\\n", " mlm_weights_X.reshape(-1, 1)\n", " mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)\n", " # 计算下一句子预测任务的损失\n", " nsp_l = loss(nsp_Y_hat, nsp_y)\n", " l = mlm_l + nsp_l\n", " return mlm_l, nsp_l, l" ] }, { "cell_type": "markdown", "id": "4e553304", "metadata": { "origin_pos": 14 }, "source": [ "通过调用上述两个辅助函数,下面的`train_bert`函数定义了在WikiText-2(`train_iter`)数据集上预训练BERT(`net`)的过程。训练BERT可能需要很长时间。以下函数的输入`num_steps`指定了训练的迭代步数,而不是像`train_ch13`函数那样指定训练的轮数(参见 :numref:`sec_image_augmentation`)。\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "6cd43502", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:04:38.599431Z", "iopub.status.busy": "2023-08-18T07:04:38.598650Z", "iopub.status.idle": "2023-08-18T07:04:38.614756Z", "shell.execute_reply": "2023-08-18T07:04:38.613328Z" }, "origin_pos": 16, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):\n", " net = nn.DataParallel(net, device_ids=devices).to(devices[0])\n", " trainer = torch.optim.Adam(net.parameters(), lr=0.01)\n", " step, timer = 0, d2l.Timer()\n", " animator = d2l.Animator(xlabel='step', ylabel='loss',\n", " xlim=[1, num_steps], legend=['mlm', 'nsp'])\n", " # 遮蔽语言模型损失的和,下一句预测任务损失的和,句子对的数量,计数\n", " metric = d2l.Accumulator(4)\n", " num_steps_reached = False\n", " while step < num_steps and not num_steps_reached:\n", " for tokens_X, segments_X, valid_lens_x, pred_positions_X,\\\n", " mlm_weights_X, mlm_Y, nsp_y in train_iter:\n", " tokens_X = tokens_X.to(devices[0])\n", " segments_X = segments_X.to(devices[0])\n", " valid_lens_x = valid_lens_x.to(devices[0])\n", " pred_positions_X = pred_positions_X.to(devices[0])\n", " mlm_weights_X = mlm_weights_X.to(devices[0])\n", " mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])\n", " trainer.zero_grad()\n", " timer.start()\n", " mlm_l, nsp_l, l = _get_batch_loss_bert(\n", " net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,\n", " pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)\n", " l.backward()\n", " trainer.step()\n", " metric.add(mlm_l, nsp_l, tokens_X.shape[0], 1)\n", " timer.stop()\n", " animator.add(step + 1,\n", " (metric[0] / metric[3], metric[1] / metric[3]))\n", " step += 1\n", " if step == num_steps:\n", " num_steps_reached = True\n", " break\n", "\n", " print(f'MLM loss {metric[0] / metric[3]:.3f}, '\n", " f'NSP loss {metric[1] / metric[3]:.3f}')\n", " print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '\n", " f'{str(devices)}')" ] }, { "cell_type": "markdown", "id": "08640bff", "metadata": { "origin_pos": 18 }, "source": [ "在预训练过程中,我们可以绘制出遮蔽语言模型损失和下一句预测损失。\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "35e856a0", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:04:38.619952Z", "iopub.status.busy": "2023-08-18T07:04:38.619192Z", "iopub.status.idle": "2023-08-18T07:05:00.659514Z", "shell.execute_reply": "2023-08-18T07:05:00.658404Z" }, "origin_pos": 19, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MLM loss 5.425, NSP loss 0.775\n", "3485.7 sentence pairs/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:05:00.617374\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "train_bert(train_iter, net, loss, len(vocab), devices, 50)" ] }, { "cell_type": "markdown", "id": "ede604ea", "metadata": { "origin_pos": 21 }, "source": [ "## 用BERT表示文本\n", "\n", "在预训练BERT之后,我们可以用它来表示单个文本、文本对或其中的任何词元。下面的函数返回`tokens_a`和`tokens_b`中所有词元的BERT(`net`)表示。\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "77f3b8e4", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:00.663916Z", "iopub.status.busy": "2023-08-18T07:05:00.663281Z", "iopub.status.idle": "2023-08-18T07:05:00.669609Z", "shell.execute_reply": "2023-08-18T07:05:00.668549Z" }, "origin_pos": 23, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def get_bert_encoding(net, tokens_a, tokens_b=None):\n", " tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)\n", " token_ids = torch.tensor(vocab[tokens], device=devices[0]).unsqueeze(0)\n", " segments = torch.tensor(segments, device=devices[0]).unsqueeze(0)\n", " valid_len = torch.tensor(len(tokens), device=devices[0]).unsqueeze(0)\n", " encoded_X, _, _ = net(token_ids, segments, valid_len)\n", " return encoded_X" ] }, { "cell_type": "markdown", "id": "25e0697e", "metadata": { "origin_pos": 25 }, "source": [ "考虑“a crane is flying”这句话。回想一下 :numref:`subsec_bert_input_rep`中讨论的BERT的输入表示。插入特殊标记“<cls>”(用于分类)和“<sep>”(用于分隔)后,BERT输入序列的长度为6。因为零是“<cls>”词元,`encoded_text[:, 0, :]`是整个输入语句的BERT表示。为了评估一词多义词元“crane”,我们还打印出了该词元的BERT表示的前三个元素。\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "1081fda9", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:00.673428Z", "iopub.status.busy": "2023-08-18T07:05:00.672675Z", "iopub.status.idle": "2023-08-18T07:05:00.690133Z", "shell.execute_reply": "2023-08-18T07:05:00.689347Z" }, "origin_pos": 26, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1, 6, 128]),\n", " torch.Size([1, 128]),\n", " tensor([-0.5007, -1.0034, 0.8718], device='cuda:0', grad_fn=))" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokens_a = ['a', 'crane', 'is', 'flying']\n", "encoded_text = get_bert_encoding(net, tokens_a)\n", "# 词元:'','a','crane','is','flying',''\n", "encoded_text_cls = encoded_text[:, 0, :]\n", "encoded_text_crane = encoded_text[:, 2, :]\n", "encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0][:3]" ] }, { "cell_type": "markdown", "id": "203ca198", "metadata": { "origin_pos": 27 }, "source": [ "现在考虑一个句子“a crane driver came”和“he just left”。类似地,`encoded_pair[:, 0, :]`是来自预训练BERT的整个句子对的编码结果。注意,多义词元“crane”的前三个元素与上下文不同时的元素不同。这支持了BERT表示是上下文敏感的。\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "960c3aa2", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:00.694637Z", "iopub.status.busy": "2023-08-18T07:05:00.694061Z", "iopub.status.idle": "2023-08-18T07:05:00.708881Z", "shell.execute_reply": "2023-08-18T07:05:00.707778Z" }, "origin_pos": 28, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1, 10, 128]),\n", " torch.Size([1, 128]),\n", " tensor([ 0.5101, -0.4041, -1.2749], device='cuda:0', grad_fn=))" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokens_a, tokens_b = ['a', 'crane', 'driver', 'came'], ['he', 'just', 'left']\n", "encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)\n", "# 词元:'','a','crane','driver','came','','he','just',\n", "# 'left',''\n", "encoded_pair_cls = encoded_pair[:, 0, :]\n", "encoded_pair_crane = encoded_pair[:, 2, :]\n", "encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0][:3]" ] }, { "cell_type": "markdown", "id": "d642486c", "metadata": { "origin_pos": 29 }, "source": [ "在 :numref:`chap_nlp_app`中,我们将为下游自然语言处理应用微调预训练的BERT模型。\n", "\n", "## 小结\n", "\n", "* 原始的BERT有两个版本,其中基本模型有1.1亿个参数,大模型有3.4亿个参数。\n", "* 在预训练BERT之后,我们可以用它来表示单个文本、文本对或其中的任何词元。\n", "* 在实验中,同一个词元在不同的上下文中具有不同的BERT表示。这支持BERT表示是上下文敏感的。\n", "\n", "## 练习\n", "\n", "1. 在实验中,我们可以看到遮蔽语言模型损失明显高于下一句预测损失。为什么?\n", "2. 将BERT输入序列的最大长度设置为512(与原始BERT模型相同)。使用原始BERT模型的配置,如$\\text{BERT}_{\\text{LARGE}}$。运行此部分时是否遇到错误?为什么?\n" ] }, { "cell_type": "markdown", "id": "9f6249ab", "metadata": { "origin_pos": 31, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/5743)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }