{
"cells": [
{
"cell_type": "markdown",
"id": "099f849e",
"metadata": {
"origin_pos": 0
},
"source": [
"# Pretraining BERT\n",
":label:`sec_bert-pretraining`\n",
"\n",
"With the BERT model implemented in :numref:`sec_bert` and the pretraining examples generated from the WikiText-2 dataset in :numref:`sec_bert-dataset`, this section pretrains BERT on the WikiText-2 dataset.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "8c0979b7",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:04:26.170037Z",
"iopub.status.busy": "2023-08-18T07:04:26.168910Z",
"iopub.status.idle": "2023-08-18T07:04:28.547324Z",
"shell.execute_reply": "2023-08-18T07:04:28.546158Z"
},
"origin_pos": 2,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
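{
"cell_type": "markdown",
"id": "898d6f91",
"metadata": {
"origin_pos": 4
},
"source": [
"To start, we load the WikiText-2 dataset as minibatches of pretraining examples for masked language modeling and next sentence prediction. The batch size is 512 and the maximum length of a BERT input sequence is 64. Note that in the original BERT model, the maximum length is 512.\n"
]
},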
{
"cell_type": "code",
"execution_count": 2,
"id": "95571e6a",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:04:28.552742Z",
"iopub.status.busy": "2023-08-18T07:04:28.552374Z",
"iopub.status.idle": "2023-08-18T07:04:38.456343Z",
"shell.execute_reply": "2023-08-18T07:04:38.455141Z"
},
"origin_pos": 5,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"batch_size, max_len = 512, 64\n",
"train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)"
]
},
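{
"cell_type": "markdown",
"id": "cfb22b86",
"metadata": {
"origin_pos": 7
},
"source": [
"## Pretraining BERT\n",
"\n",
"The original BERT :cite:`Devlin.Chang.Lee.ea.2018` comes in two model sizes. The base model ($\\text{BERT}_{\\text{BASE}}$) uses 12 layers (Transformer encoder blocks) with 768 hidden units (hidden size) and 12 self-attention heads. The large model ($\\text{BERT}_{\\text{LARGE}}$) uses 24 layers with 1024 hidden units and 16 self-attention heads. Notably, the former has 110 million parameters while the latter has 340 million parameters. For ease of demonstration, we define a small BERT with 2 layers, 128 hidden units, and 2 self-attention heads.\n"
]
},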
{
"cell_type": "code",
"execution_count": 3,
"id": "3cc34825",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:04:38.461166Z",
"iopub.status.busy": "2023-08-18T07:04:38.460802Z",
"iopub.status.idle": "2023-08-18T07:04:38.581653Z",
"shell.execute_reply": "2023-08-18T07:04:38.580139Z"
},
"origin_pos": 9,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"net = d2l.BERTModel(len(vocab), num_hiddens=128, norm_shape=[128],\n",
"                    ffn_num_input=128, ffn_num_hiddens=256, num_heads=2,\n",
"                    num_layers=2, dropout=0.2, key_size=128, query_size=128,\n",
"                    value_size=128, hid_in_features=128, mlm_in_features=128,\n",
"                    nsp_in_features=128)\n",
"devices = d2l.try_all_gpus()\n",
"loss = nn.CrossEntropyLoss()"
]
},
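{
"cell_type": "markdown",
"id": "be063421",
"metadata": {
"origin_pos": 10
},
"source": [
"Before defining the training loop, we define a helper function `_get_batch_loss_bert`. Given a minibatch of training examples, this function computes the loss for both the masked language modeling and next sentence prediction tasks. Note that the final loss of BERT pretraining is simply the sum of the masked language modeling loss and the next sentence prediction loss.\n"
]
},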
{
"cell_type": "code",
"execution_count": 4,
"id": "64b2c84b",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:04:38.586837Z",
"iopub.status.busy": "2023-08-18T07:04:38.585868Z",
"iopub.status.idle": "2023-08-18T07:04:38.594572Z",
"shell.execute_reply": "2023-08-18T07:04:38.593478Z"
},
"origin_pos": 12,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"#@save\n",
"def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,\n",
"                         segments_X, valid_lens_x,\n",
"                         pred_positions_X, mlm_weights_X,\n",
"                         mlm_Y, nsp_y):\n",
"    # Forward pass\n",
"    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,\n",
"                                  valid_lens_x.reshape(-1),\n",
"                                  pred_positions_X)\n",
"    # Compute the masked language model loss.\n",
"    # Note: if `loss` uses reduction='mean' (the nn.CrossEntropyLoss default),\n",
"    # it returns a scalar, and the weighting below leaves it unchanged.\n",
"    mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\\\n",
"    mlm_weights_X.reshape(-1, 1)\n",
"    mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)\n",
"    # Compute the next sentence prediction loss\n",
"    nsp_l = loss(nsp_Y_hat, nsp_y)\n",
"    l = mlm_l + nsp_l\n",
"    return mlm_l, nsp_l, l"
]
},
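{
"cell_type": "markdown",
"id": "4e553304",
"metadata": {
"origin_pos": 14
},
"source": [
"Using the helper function above, the following `train_bert` function defines the procedure to pretrain BERT (`net`) on the WikiText-2 (`train_iter`) dataset. Training BERT can take a very long time. Instead of specifying the number of training epochs as the `train_ch13` function does (see :numref:`sec_image_augmentation`), the input `num_steps` of the following function specifies the number of training iteration steps.\n"
]
},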
{
"cell_type": "code",
"execution_count": 5,
"id": "6cd43502",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:04:38.599431Z",
"iopub.status.busy": "2023-08-18T07:04:38.598650Z",
"iopub.status.idle": "2023-08-18T07:04:38.614756Z",
"shell.execute_reply": "2023-08-18T07:04:38.613328Z"
},
"origin_pos": 16,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):\n",
"    net = nn.DataParallel(net, device_ids=devices).to(devices[0])\n",
"    trainer = torch.optim.Adam(net.parameters(), lr=0.01)\n",
"    step, timer = 0, d2l.Timer()\n",
"    animator = d2l.Animator(xlabel='step', ylabel='loss',\n",
"                            xlim=[1, num_steps], legend=['mlm', 'nsp'])\n",
"    # Sum of masked language modeling losses, sum of next sentence prediction\n",
"    # losses, number of sentence pairs, count\n",
"    metric = d2l.Accumulator(4)\n",
"    num_steps_reached = False\n",
"    while step < num_steps and not num_steps_reached:\n",
"        for tokens_X, segments_X, valid_lens_x, pred_positions_X,\\\n",
"            mlm_weights_X, mlm_Y, nsp_y in train_iter:\n",
"            tokens_X = tokens_X.to(devices[0])\n",
"            segments_X = segments_X.to(devices[0])\n",
"            valid_lens_x = valid_lens_x.to(devices[0])\n",
"            pred_positions_X = pred_positions_X.to(devices[0])\n",
"            mlm_weights_X = mlm_weights_X.to(devices[0])\n",
"            mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])\n",
"            trainer.zero_grad()\n",
"            timer.start()\n",
"            mlm_l, nsp_l, l = _get_batch_loss_bert(\n",
"                net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,\n",
"                pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)\n",
"            l.backward()\n",
"            trainer.step()\n",
"            metric.add(mlm_l, nsp_l, tokens_X.shape[0], 1)\n",
"            timer.stop()\n",
"            animator.add(step + 1,\n",
"                         (metric[0] / metric[3], metric[1] / metric[3]))\n",
"            step += 1\n",
"            if step == num_steps:\n",
"                num_steps_reached = True\n",
"                break\n",
"\n",
"    print(f'MLM loss {metric[0] / metric[3]:.3f}, '\n",
"          f'NSP loss {metric[1] / metric[3]:.3f}')\n",
"    print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '\n",
"          f'{str(devices)}')"
]
},
{
"cell_type": "markdown",
"id": "08640bff",
"metadata": {
"origin_pos": 18
},
"source": [
"We can plot both the masked language modeling loss and the next sentence prediction loss during BERT pretraining.\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "35e856a0",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:04:38.619952Z",
"iopub.status.busy": "2023-08-18T07:04:38.619192Z",
"iopub.status.idle": "2023-08-18T07:05:00.659514Z",
"shell.execute_reply": "2023-08-18T07:05:00.658404Z"
},
"origin_pos": 19,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MLM loss 5.425, NSP loss 0.775\n",
"3485.7 sentence pairs/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]\n"
]
},
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
|
||
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
|
||
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
|
||
"<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"249.465625pt\" height=\"180.65625pt\" viewBox=\"0 0 249.465625 180.65625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
|
||
" <metadata>\n",
|
||
" <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
|
||
" <cc:Work>\n",
|
||
" <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
|
||
" <dc:date>2023-08-18T07:05:00.617374</dc:date>\n",
|
||
" <dc:format>image/svg+xml</dc:format>\n",
|
||
" <dc:creator>\n",
|
||
" <cc:Agent>\n",
|
||
" <dc:title>Matplotlib v3.5.1, https://matplotlib.org/</dc:title>\n",
|
||
" </cc:Agent>\n",
|
||
" </dc:creator>\n",
|
||
" </cc:Work>\n",
|
||
" </rdf:RDF>\n",
|
||
" </metadata>\n",
|
||
" <defs>\n",
|
||
" <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
|
||
" </defs>\n",
|
||
" <g id=\"figure_1\">\n",
|
||
" <g id=\"patch_1\">\n",
|
||
" <path d=\"M 0 180.65625 \n",
|
||
"L 249.465625 180.65625 \n",
|
||
"L 249.465625 0 \n",
|
||
"L 0 0 \n",
|
||
"L 0 180.65625 \n",
|
||
"z\n",
|
||
"\" style=\"fill: none\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"axes_1\">\n",
|
||
" <g id=\"patch_2\">\n",
|
||
" <path d=\"M 40.603125 143.1 \n",
|
||
"L 235.903125 143.1 \n",
|
||
"L 235.903125 7.2 \n",
|
||
"L 40.603125 7.2 \n",
|
||
"z\n",
|
||
"\" style=\"fill: #ffffff\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"matplotlib.axis_1\">\n",
|
||
" <g id=\"xtick_1\">\n",
|
||
" <g id=\"line2d_1\">\n",
|
||
" <path d=\"M 76.474554 143.1 \n",
|
||
"L 76.474554 7.2 \n",
|
||
"\" clip-path=\"url(#p1b68231c3f)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_2\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"ma84ba932a6\" d=\"M 0 0 \n",
|
||
"L 0 3.5 \n",
|
||
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </defs>\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#ma84ba932a6\" x=\"76.474554\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_1\">\n",
|
||
" <!-- 10 -->\n",
|
||
" <g transform=\"translate(70.112054 157.698438)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
|
||
"L 1825 531 \n",
|
||
"L 1825 4091 \n",
|
||
"L 703 3866 \n",
|
||
"L 703 4441 \n",
|
||
"L 1819 4666 \n",
|
||
"L 2450 4666 \n",
|
||
"L 2450 531 \n",
|
||
"L 3481 531 \n",
|
||
"L 3481 0 \n",
|
||
"L 794 0 \n",
|
||
"L 794 531 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
|
||
"Q 1547 4250 1301 3770 \n",
|
||
"Q 1056 3291 1056 2328 \n",
|
||
"Q 1056 1369 1301 889 \n",
|
||
"Q 1547 409 2034 409 \n",
|
||
"Q 2525 409 2770 889 \n",
|
||
"Q 3016 1369 3016 2328 \n",
|
||
"Q 3016 3291 2770 3770 \n",
|
||
"Q 2525 4250 2034 4250 \n",
|
||
"z\n",
|
||
"M 2034 4750 \n",
|
||
"Q 2819 4750 3233 4129 \n",
|
||
"Q 3647 3509 3647 2328 \n",
|
||
"Q 3647 1150 3233 529 \n",
|
||
"Q 2819 -91 2034 -91 \n",
|
||
"Q 1250 -91 836 529 \n",
|
||
"Q 422 1150 422 2328 \n",
|
||
"Q 422 3509 836 4129 \n",
|
||
"Q 1250 4750 2034 4750 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"xtick_2\">\n",
|
||
" <g id=\"line2d_3\">\n",
|
||
" <path d=\"M 116.331696 143.1 \n",
|
||
"L 116.331696 7.2 \n",
|
||
"\" clip-path=\"url(#p1b68231c3f)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_4\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#ma84ba932a6\" x=\"116.331696\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_2\">\n",
|
||
" <!-- 20 -->\n",
|
||
" <g transform=\"translate(109.969196 157.698438)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
|
||
"L 3431 531 \n",
|
||
"L 3431 0 \n",
|
||
"L 469 0 \n",
|
||
"L 469 531 \n",
|
||
"Q 828 903 1448 1529 \n",
|
||
"Q 2069 2156 2228 2338 \n",
|
||
"Q 2531 2678 2651 2914 \n",
|
||
"Q 2772 3150 2772 3378 \n",
|
||
"Q 2772 3750 2511 3984 \n",
|
||
"Q 2250 4219 1831 4219 \n",
|
||
"Q 1534 4219 1204 4116 \n",
|
||
"Q 875 4013 500 3803 \n",
|
||
"L 500 4441 \n",
|
||
"Q 881 4594 1212 4672 \n",
|
||
"Q 1544 4750 1819 4750 \n",
|
||
"Q 2544 4750 2975 4387 \n",
|
||
"Q 3406 4025 3406 3419 \n",
|
||
"Q 3406 3131 3298 2873 \n",
|
||
"Q 3191 2616 2906 2266 \n",
|
||
"Q 2828 2175 2409 1742 \n",
|
||
"Q 1991 1309 1228 531 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"xtick_3\">\n",
|
||
" <g id=\"line2d_5\">\n",
|
||
" <path d=\"M 156.188839 143.1 \n",
|
||
"L 156.188839 7.2 \n",
|
||
"\" clip-path=\"url(#p1b68231c3f)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_6\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#ma84ba932a6\" x=\"156.188839\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_3\">\n",
|
||
" <!-- 30 -->\n",
|
||
" <g transform=\"translate(149.826339 157.698438)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-33\" d=\"M 2597 2516 \n",
|
||
"Q 3050 2419 3304 2112 \n",
|
||
"Q 3559 1806 3559 1356 \n",
|
||
"Q 3559 666 3084 287 \n",
|
||
"Q 2609 -91 1734 -91 \n",
|
||
"Q 1441 -91 1130 -33 \n",
|
||
"Q 819 25 488 141 \n",
|
||
"L 488 750 \n",
|
||
"Q 750 597 1062 519 \n",
|
||
"Q 1375 441 1716 441 \n",
|
||
"Q 2309 441 2620 675 \n",
|
||
"Q 2931 909 2931 1356 \n",
|
||
"Q 2931 1769 2642 2001 \n",
|
||
"Q 2353 2234 1838 2234 \n",
|
||
"L 1294 2234 \n",
|
||
"L 1294 2753 \n",
|
||
"L 1863 2753 \n",
|
||
"Q 2328 2753 2575 2939 \n",
|
||
"Q 2822 3125 2822 3475 \n",
|
||
"Q 2822 3834 2567 4026 \n",
|
||
"Q 2313 4219 1838 4219 \n",
|
||
"Q 1578 4219 1281 4162 \n",
|
||
"Q 984 4106 628 3988 \n",
|
||
"L 628 4550 \n",
|
||
"Q 988 4650 1302 4700 \n",
|
||
"Q 1616 4750 1894 4750 \n",
|
||
"Q 2613 4750 3031 4423 \n",
|
||
"Q 3450 4097 3450 3541 \n",
|
||
"Q 3450 3153 3228 2886 \n",
|
||
"Q 3006 2619 2597 2516 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-33\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"xtick_4\">\n",
|
||
" <g id=\"line2d_7\">\n",
|
||
" <path d=\"M 196.045982 143.1 \n",
|
||
"L 196.045982 7.2 \n",
|
||
"\" clip-path=\"url(#p1b68231c3f)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_8\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#ma84ba932a6\" x=\"196.045982\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_4\">\n",
|
||
" <!-- 40 -->\n",
|
||
" <g transform=\"translate(189.683482 157.698438)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \n",
|
||
"L 825 1625 \n",
|
||
"L 2419 1625 \n",
|
||
"L 2419 4116 \n",
|
||
"z\n",
|
||
"M 2253 4666 \n",
|
||
"L 3047 4666 \n",
|
||
"L 3047 1625 \n",
|
||
"L 3713 1625 \n",
|
||
"L 3713 1100 \n",
|
||
"L 3047 1100 \n",
|
||
"L 3047 0 \n",
|
||
"L 2419 0 \n",
|
||
"L 2419 1100 \n",
|
||
"L 313 1100 \n",
|
||
"L 313 1709 \n",
|
||
"L 2253 4666 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-34\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"xtick_5\">\n",
|
||
" <g id=\"line2d_9\">\n",
|
||
" <path d=\"M 235.903125 143.1 \n",
|
||
"L 235.903125 7.2 \n",
|
||
"\" clip-path=\"url(#p1b68231c3f)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_10\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#ma84ba932a6\" x=\"235.903125\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_5\">\n",
|
||
" <!-- 50 -->\n",
|
||
" <g transform=\"translate(229.540625 157.698438)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-35\" d=\"M 691 4666 \n",
|
||
"L 3169 4666 \n",
|
||
"L 3169 4134 \n",
|
||
"L 1269 4134 \n",
|
||
"L 1269 2991 \n",
|
||
"Q 1406 3038 1543 3061 \n",
|
||
"Q 1681 3084 1819 3084 \n",
|
||
"Q 2600 3084 3056 2656 \n",
|
||
"Q 3513 2228 3513 1497 \n",
|
||
"Q 3513 744 3044 326 \n",
|
||
"Q 2575 -91 1722 -91 \n",
|
||
"Q 1428 -91 1123 -41 \n",
|
||
"Q 819 9 494 109 \n",
|
||
"L 494 744 \n",
|
||
"Q 775 591 1075 516 \n",
|
||
"Q 1375 441 1709 441 \n",
|
||
"Q 2250 441 2565 725 \n",
|
||
"Q 2881 1009 2881 1497 \n",
|
||
"Q 2881 1984 2565 2268 \n",
|
||
"Q 2250 2553 1709 2553 \n",
|
||
"Q 1456 2553 1204 2497 \n",
|
||
"Q 953 2441 691 2322 \n",
|
||
"L 691 4666 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-35\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_6\">\n",
|
||
" <!-- step -->\n",
|
||
" <g transform=\"translate(127.4375 171.376563)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-73\" d=\"M 2834 3397 \n",
|
||
"L 2834 2853 \n",
|
||
"Q 2591 2978 2328 3040 \n",
|
||
"Q 2066 3103 1784 3103 \n",
|
||
"Q 1356 3103 1142 2972 \n",
|
||
"Q 928 2841 928 2578 \n",
|
||
"Q 928 2378 1081 2264 \n",
|
||
"Q 1234 2150 1697 2047 \n",
|
||
"L 1894 2003 \n",
|
||
"Q 2506 1872 2764 1633 \n",
|
||
"Q 3022 1394 3022 966 \n",
|
||
"Q 3022 478 2636 193 \n",
|
||
"Q 2250 -91 1575 -91 \n",
|
||
"Q 1294 -91 989 -36 \n",
|
||
"Q 684 19 347 128 \n",
|
||
"L 347 722 \n",
|
||
"Q 666 556 975 473 \n",
|
||
"Q 1284 391 1588 391 \n",
|
||
"Q 1994 391 2212 530 \n",
|
||
"Q 2431 669 2431 922 \n",
|
||
"Q 2431 1156 2273 1281 \n",
|
||
"Q 2116 1406 1581 1522 \n",
|
||
"L 1381 1569 \n",
|
||
"Q 847 1681 609 1914 \n",
|
||
"Q 372 2147 372 2553 \n",
|
||
"Q 372 3047 722 3315 \n",
|
||
"Q 1072 3584 1716 3584 \n",
|
||
"Q 2034 3584 2315 3537 \n",
|
||
"Q 2597 3491 2834 3397 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" <path id=\"DejaVuSans-74\" d=\"M 1172 4494 \n",
|
||
"L 1172 3500 \n",
|
||
"L 2356 3500 \n",
|
||
"L 2356 3053 \n",
|
||
"L 1172 3053 \n",
|
||
"L 1172 1153 \n",
|
||
"Q 1172 725 1289 603 \n",
|
||
"Q 1406 481 1766 481 \n",
|
||
"L 2356 481 \n",
|
||
"L 2356 0 \n",
|
||
"L 1766 0 \n",
|
||
"Q 1100 0 847 248 \n",
|
||
"Q 594 497 594 1153 \n",
|
||
"L 594 3053 \n",
|
||
"L 172 3053 \n",
|
||
"L 172 3500 \n",
|
||
"L 594 3500 \n",
|
||
"L 594 4494 \n",
|
||
"L 1172 4494 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" <path id=\"DejaVuSans-65\" d=\"M 3597 1894 \n",
|
||
"L 3597 1613 \n",
|
||
"L 953 1613 \n",
|
||
"Q 991 1019 1311 708 \n",
|
||
"Q 1631 397 2203 397 \n",
|
||
"Q 2534 397 2845 478 \n",
|
||
"Q 3156 559 3463 722 \n",
|
||
"L 3463 178 \n",
|
||
"Q 3153 47 2828 -22 \n",
|
||
"Q 2503 -91 2169 -91 \n",
|
||
"Q 1331 -91 842 396 \n",
|
||
"Q 353 884 353 1716 \n",
|
||
"Q 353 2575 817 3079 \n",
|
||
"Q 1281 3584 2069 3584 \n",
|
||
"Q 2775 3584 3186 3129 \n",
|
||
"Q 3597 2675 3597 1894 \n",
|
||
"z\n",
|
||
"M 3022 2063 \n",
|
||
"Q 3016 2534 2758 2815 \n",
|
||
"Q 2500 3097 2075 3097 \n",
|
||
"Q 1594 3097 1305 2825 \n",
|
||
"Q 1016 2553 972 2059 \n",
|
||
"L 3022 2063 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" <path id=\"DejaVuSans-70\" d=\"M 1159 525 \n",
|
||
"L 1159 -1331 \n",
|
||
"L 581 -1331 \n",
|
||
"L 581 3500 \n",
|
||
"L 1159 3500 \n",
|
||
"L 1159 2969 \n",
|
||
"Q 1341 3281 1617 3432 \n",
|
||
"Q 1894 3584 2278 3584 \n",
|
||
"Q 2916 3584 3314 3078 \n",
|
||
"Q 3713 2572 3713 1747 \n",
|
||
"Q 3713 922 3314 415 \n",
|
||
"Q 2916 -91 2278 -91 \n",
|
||
"Q 1894 -91 1617 61 \n",
|
||
"Q 1341 213 1159 525 \n",
|
||
"z\n",
|
||
"M 3116 1747 \n",
|
||
"Q 3116 2381 2855 2742 \n",
|
||
"Q 2594 3103 2138 3103 \n",
|
||
"Q 1681 3103 1420 2742 \n",
|
||
"Q 1159 2381 1159 1747 \n",
|
||
"Q 1159 1113 1420 752 \n",
|
||
"Q 1681 391 2138 391 \n",
|
||
"Q 2594 391 2855 752 \n",
|
||
"Q 3116 1113 3116 1747 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-73\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-74\" x=\"52.099609\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-65\" x=\"91.308594\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-70\" x=\"152.832031\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"matplotlib.axis_2\">\n",
|
||
" <g id=\"ytick_1\">\n",
|
||
" <g id=\"line2d_11\">\n",
|
||
" <path d=\"M 40.603125 120.23593 \n",
|
||
"L 235.903125 120.23593 \n",
|
||
"\" clip-path=\"url(#p1b68231c3f)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_12\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"m4be01386ce\" d=\"M 0 0 \n",
|
||
"L -3.5 0 \n",
|
||
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </defs>\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#m4be01386ce\" x=\"40.603125\" y=\"120.23593\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_7\">\n",
|
||
" <!-- 2 -->\n",
|
||
" <g transform=\"translate(27.240625 124.035149)scale(0.1 -0.1)\">\n",
|
||
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"ytick_2\">\n",
|
||
" <g id=\"line2d_13\">\n",
|
||
" <path d=\"M 40.603125 93.951993 \n",
|
||
"L 235.903125 93.951993 \n",
|
||
"\" clip-path=\"url(#p1b68231c3f)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_14\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#m4be01386ce\" x=\"40.603125\" y=\"93.951993\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_8\">\n",
|
||
" <!-- 4 -->\n",
|
||
" <g transform=\"translate(27.240625 97.751212)scale(0.1 -0.1)\">\n",
|
||
" <use xlink:href=\"#DejaVuSans-34\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"ytick_3\">\n",
|
||
" <g id=\"line2d_15\">\n",
|
||
" <path d=\"M 40.603125 67.668056 \n",
|
||
"L 235.903125 67.668056 \n",
|
||
"\" clip-path=\"url(#p1b68231c3f)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_16\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#m4be01386ce\" x=\"40.603125\" y=\"67.668056\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_9\">\n",
|
||
" <!-- 6 -->\n",
|
||
" <g transform=\"translate(27.240625 71.467274)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-36\" d=\"M 2113 2584 \n",
|
||
"Q 1688 2584 1439 2293 \n",
|
||
"Q 1191 2003 1191 1497 \n",
|
||
"Q 1191 994 1439 701 \n",
|
||
"Q 1688 409 2113 409 \n",
|
||
"Q 2538 409 2786 701 \n",
|
||
"Q 3034 994 3034 1497 \n",
|
||
"Q 3034 2003 2786 2293 \n",
|
||
"Q 2538 2584 2113 2584 \n",
|
||
"z\n",
|
||
"M 3366 4563 \n",
|
||
"L 3366 3988 \n",
|
||
"Q 3128 4100 2886 4159 \n",
|
||
"Q 2644 4219 2406 4219 \n",
|
||
"Q 1781 4219 1451 3797 \n",
|
||
"Q 1122 3375 1075 2522 \n",
|
||
"Q 1259 2794 1537 2939 \n",
|
||
"Q 1816 3084 2150 3084 \n",
|
||
"Q 2853 3084 3261 2657 \n",
|
||
"Q 3669 2231 3669 1497 \n",
|
||
"Q 3669 778 3244 343 \n",
|
||
"Q 2819 -91 2113 -91 \n",
|
||
"Q 1303 -91 875 529 \n",
|
||
"Q 447 1150 447 2328 \n",
|
||
"Q 447 3434 972 4092 \n",
|
||
"Q 1497 4750 2381 4750 \n",
|
||
"Q 2619 4750 2861 4703 \n",
|
||
"Q 3103 4656 3366 4563 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-36\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"ytick_4\">\n",
|
||
" <g id=\"line2d_17\">\n",
|
||
" <path d=\"M 40.603125 41.384118 \n",
|
||
"L 235.903125 41.384118 \n",
|
||
"\" clip-path=\"url(#p1b68231c3f)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_18\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#m4be01386ce\" x=\"40.603125\" y=\"41.384118\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_10\">\n",
|
||
" <!-- 8 -->\n",
|
||
" <g transform=\"translate(27.240625 45.183337)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-38\" d=\"M 2034 2216 \n",
|
||
"Q 1584 2216 1326 1975 \n",
|
||
"Q 1069 1734 1069 1313 \n",
|
||
"Q 1069 891 1326 650 \n",
|
||
"Q 1584 409 2034 409 \n",
|
||
"Q 2484 409 2743 651 \n",
|
||
"Q 3003 894 3003 1313 \n",
|
||
"Q 3003 1734 2745 1975 \n",
|
||
"Q 2488 2216 2034 2216 \n",
|
||
"z\n",
|
||
"M 1403 2484 \n",
|
||
"Q 997 2584 770 2862 \n",
|
||
"Q 544 3141 544 3541 \n",
|
||
"Q 544 4100 942 4425 \n",
|
||
"Q 1341 4750 2034 4750 \n",
|
||
"Q 2731 4750 3128 4425 \n",
|
||
"Q 3525 4100 3525 3541 \n",
|
||
"Q 3525 3141 3298 2862 \n",
|
||
"Q 3072 2584 2669 2484 \n",
|
||
"Q 3125 2378 3379 2068 \n",
|
||
"Q 3634 1759 3634 1313 \n",
|
||
"Q 3634 634 3220 271 \n",
|
||
"Q 2806 -91 2034 -91 \n",
|
||
"Q 1263 -91 848 271 \n",
|
||
"Q 434 634 434 1313 \n",
|
||
"Q 434 1759 690 2068 \n",
|
||
"Q 947 2378 1403 2484 \n",
|
||
"z\n",
|
||
"M 1172 3481 \n",
|
||
"Q 1172 3119 1398 2916 \n",
|
||
"Q 1625 2713 2034 2713 \n",
|
||
"Q 2441 2713 2670 2916 \n",
|
||
"Q 2900 3119 2900 3481 \n",
|
||
"Q 2900 3844 2670 4047 \n",
|
||
"Q 2441 4250 2034 4250 \n",
|
||
"Q 1625 4250 1398 4047 \n",
|
||
"Q 1172 3844 1172 3481 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-38\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"ytick_5\">\n",
|
||
" <g id=\"line2d_19\">\n",
|
||
" <path d=\"M 40.603125 15.100181 \n",
|
||
"L 235.903125 15.100181 \n",
|
||
"\" clip-path=\"url(#p1b68231c3f)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_20\">\n",
|
||
" <g>\n",
|
||
" <use xlink:href=\"#m4be01386ce\" x=\"40.603125\" y=\"15.100181\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_11\">\n",
|
||
" <!-- 10 -->\n",
|
||
" <g transform=\"translate(20.878125 18.8994)scale(0.1 -0.1)\">\n",
|
||
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_12\">\n",
|
||
" <!-- loss -->\n",
|
||
" <g transform=\"translate(14.798437 84.807812)rotate(-90)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-6c\" d=\"M 603 4863 \n",
|
||
"L 1178 4863 \n",
|
||
"L 1178 0 \n",
|
||
"L 603 0 \n",
|
||
"L 603 4863 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" <path id=\"DejaVuSans-6f\" d=\"M 1959 3097 \n",
|
||
"Q 1497 3097 1228 2736 \n",
|
||
"Q 959 2375 959 1747 \n",
|
||
"Q 959 1119 1226 758 \n",
|
||
"Q 1494 397 1959 397 \n",
|
||
"Q 2419 397 2687 759 \n",
|
||
"Q 2956 1122 2956 1747 \n",
|
||
"Q 2956 2369 2687 2733 \n",
|
||
"Q 2419 3097 1959 3097 \n",
|
||
"z\n",
|
||
"M 1959 3584 \n",
|
||
"Q 2709 3584 3137 3096 \n",
|
||
"Q 3566 2609 3566 1747 \n",
|
||
"Q 3566 888 3137 398 \n",
|
||
"Q 2709 -91 1959 -91 \n",
|
||
"Q 1206 -91 779 398 \n",
|
||
"Q 353 888 353 1747 \n",
|
||
"Q 353 2609 779 3096 \n",
|
||
"Q 1206 3584 1959 3584 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-6c\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-6f\" x=\"27.783203\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-73\" x=\"88.964844\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-73\" x=\"141.064453\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_21\">\n",
|
||
" <path d=\"M 40.603125 13.377273 \n",
|
||
"L 44.588839 24.224345 \n",
|
||
"L 48.574554 30.80551 \n",
|
||
"L 52.560268 38.316224 \n",
|
||
"L 56.545982 44.516487 \n",
|
||
"L 60.531696 48.937694 \n",
|
||
"L 64.517411 52.484771 \n",
|
||
"L 68.503125 54.778538 \n",
|
||
"L 72.488839 56.364095 \n",
|
||
"L 76.474554 58.041318 \n",
|
||
"L 80.460268 59.35448 \n",
|
||
"L 84.445982 60.441956 \n",
|
||
"L 88.431696 61.379153 \n",
|
||
"L 92.417411 62.229215 \n",
|
||
"L 96.403125 63.065199 \n",
|
||
"L 100.388839 63.813791 \n",
|
||
"L 104.374554 64.5639 \n",
|
||
"L 108.360268 65.2638 \n",
|
||
"L 112.345982 65.952618 \n",
|
||
"L 116.331696 66.525922 \n",
|
||
"L 120.317411 67.195551 \n",
|
||
"L 124.303125 67.80039 \n",
|
||
"L 128.288839 68.37399 \n",
|
||
"L 132.274554 68.872608 \n",
|
||
"L 136.260268 69.328405 \n",
|
||
"L 140.245982 69.706822 \n",
|
||
"L 144.231696 70.093209 \n",
|
||
"L 148.217411 70.458095 \n",
|
||
"L 152.203125 70.844132 \n",
|
||
"L 156.188839 71.180242 \n",
|
||
"L 160.174554 71.452899 \n",
|
||
"L 164.160268 71.713348 \n",
|
||
"L 168.145982 71.965907 \n",
|
||
"L 172.131696 72.21415 \n",
|
||
"L 176.117411 72.456634 \n",
|
||
"L 180.103125 72.653849 \n",
|
||
"L 184.088839 72.885751 \n",
|
||
"L 188.074554 73.104599 \n",
|
||
"L 192.060268 73.272813 \n",
|
||
"L 196.045982 73.533779 \n",
|
||
"L 200.031696 73.744679 \n",
|
||
"L 204.017411 73.939854 \n",
|
||
"L 208.003125 74.09068 \n",
|
||
"L 211.988839 74.286236 \n",
|
||
"L 215.974554 74.459222 \n",
|
||
"L 219.960268 74.639756 \n",
|
||
"L 223.945982 74.783027 \n",
|
||
"L 227.931696 74.93494 \n",
|
||
"L 231.917411 75.082847 \n",
|
||
"L 235.903125 75.22682 \n",
|
||
"\" clip-path=\"url(#p1b68231c3f)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_22\">\n",
|
||
" <path d=\"M 40.603125 136.922727 \n",
|
||
"L 44.588839 133.148485 \n",
|
||
"L 48.574554 126.506679 \n",
|
||
"L 52.560268 127.774556 \n",
|
||
"L 56.545982 129.41543 \n",
|
||
"L 60.531696 130.676493 \n",
|
||
"L 64.517411 131.612863 \n",
|
||
"L 68.503125 132.177482 \n",
|
||
"L 72.488839 132.745921 \n",
|
||
"L 76.474554 132.900444 \n",
|
||
"L 80.460268 133.16383 \n",
|
||
"L 84.445982 133.474624 \n",
|
||
"L 88.431696 133.724839 \n",
|
||
"L 92.417411 133.954133 \n",
|
||
"L 96.403125 134.167357 \n",
|
||
"L 100.388839 134.368487 \n",
|
||
"L 104.374554 134.520358 \n",
|
||
"L 108.360268 134.656883 \n",
|
||
"L 112.345982 134.793594 \n",
|
||
"L 116.331696 134.925427 \n",
|
||
"L 120.317411 135.037142 \n",
|
||
"L 124.303125 135.138294 \n",
|
||
"L 128.288839 135.227811 \n",
|
||
"L 132.274554 135.316948 \n",
|
||
"L 136.260268 135.398903 \n",
|
||
"L 140.245982 135.47614 \n",
|
||
"L 144.231696 135.537726 \n",
|
||
"L 148.217411 135.596796 \n",
|
||
"L 152.203125 135.657255 \n",
|
||
"L 156.188839 135.712856 \n",
|
||
"L 160.174554 135.764527 \n",
|
||
"L 164.160268 135.810959 \n",
|
||
"L 168.145982 135.859114 \n",
|
||
"L 172.131696 135.903682 \n",
|
||
"L 176.117411 135.947898 \n",
|
||
"L 180.103125 135.989192 \n",
|
||
"L 184.088839 136.023708 \n",
|
||
"L 188.074554 136.058565 \n",
|
||
"L 192.060268 136.091532 \n",
|
||
"L 196.045982 136.121414 \n",
|
||
"L 200.031696 136.149649 \n",
|
||
"L 204.017411 136.176424 \n",
|
||
"L 208.003125 136.205532 \n",
|
||
"L 211.988839 136.23183 \n",
|
||
"L 215.974554 136.249744 \n",
|
||
"L 219.960268 136.267683 \n",
|
||
"L 223.945982 136.291637 \n",
|
||
"L 227.931696 136.312131 \n",
|
||
"L 231.917411 136.322669 \n",
|
||
"L 235.903125 136.334804 \n",
|
||
"\" clip-path=\"url(#p1b68231c3f)\" style=\"fill: none; stroke-dasharray: 5.55,2.4; stroke-dashoffset: 0; stroke: #bf00bf; stroke-width: 1.5\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"patch_3\">\n",
|
||
" <path d=\"M 40.603125 143.1 \n",
|
||
"L 40.603125 7.2 \n",
|
||
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"patch_4\">\n",
|
||
" <path d=\"M 235.903125 143.1 \n",
|
||
"L 235.903125 7.2 \n",
|
||
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"patch_5\">\n",
|
||
" <path d=\"M 40.603125 143.1 \n",
|
||
"L 235.903125 143.1 \n",
|
||
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"patch_6\">\n",
|
||
" <path d=\"M 40.603125 7.2 \n",
|
||
"L 235.903125 7.2 \n",
|
||
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"legend_1\">\n",
|
||
" <g id=\"patch_7\">\n",
|
||
" <path d=\"M 174.64375 44.55625 \n",
|
||
"L 228.903125 44.55625 \n",
|
||
"Q 230.903125 44.55625 230.903125 42.55625 \n",
|
||
"L 230.903125 14.2 \n",
|
||
"Q 230.903125 12.2 228.903125 12.2 \n",
|
||
"L 174.64375 12.2 \n",
|
||
"Q 172.64375 12.2 172.64375 14.2 \n",
|
||
"L 172.64375 42.55625 \n",
|
||
"Q 172.64375 44.55625 174.64375 44.55625 \n",
|
||
"z\n",
|
||
"\" style=\"fill: #ffffff; opacity: 0.8; stroke: #cccccc; stroke-linejoin: miter\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_23\">\n",
|
||
" <path d=\"M 176.64375 20.298437 \n",
|
||
"L 186.64375 20.298437 \n",
|
||
"L 196.64375 20.298437 \n",
|
||
"\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_13\">\n",
|
||
" <!-- mlm -->\n",
|
||
" <g transform=\"translate(204.64375 23.798437)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-6d\" d=\"M 3328 2828 \n",
|
||
"Q 3544 3216 3844 3400 \n",
|
||
"Q 4144 3584 4550 3584 \n",
|
||
"Q 5097 3584 5394 3201 \n",
|
||
"Q 5691 2819 5691 2113 \n",
|
||
"L 5691 0 \n",
|
||
"L 5113 0 \n",
|
||
"L 5113 2094 \n",
|
||
"Q 5113 2597 4934 2840 \n",
|
||
"Q 4756 3084 4391 3084 \n",
|
||
"Q 3944 3084 3684 2787 \n",
|
||
"Q 3425 2491 3425 1978 \n",
|
||
"L 3425 0 \n",
|
||
"L 2847 0 \n",
|
||
"L 2847 2094 \n",
|
||
"Q 2847 2600 2669 2842 \n",
|
||
"Q 2491 3084 2119 3084 \n",
|
||
"Q 1678 3084 1418 2786 \n",
|
||
"Q 1159 2488 1159 1978 \n",
|
||
"L 1159 0 \n",
|
||
"L 581 0 \n",
|
||
"L 581 3500 \n",
|
||
"L 1159 3500 \n",
|
||
"L 1159 2956 \n",
|
||
"Q 1356 3278 1631 3431 \n",
|
||
"Q 1906 3584 2284 3584 \n",
|
||
"Q 2666 3584 2933 3390 \n",
|
||
"Q 3200 3197 3328 2828 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-6d\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-6c\" x=\"97.412109\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-6d\" x=\"125.195312\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <g id=\"line2d_24\">\n",
|
||
" <path d=\"M 176.64375 34.976562 \n",
|
||
"L 186.64375 34.976562 \n",
|
||
"L 196.64375 34.976562 \n",
|
||
"\" style=\"fill: none; stroke-dasharray: 5.55,2.4; stroke-dashoffset: 0; stroke: #bf00bf; stroke-width: 1.5\"/>\n",
|
||
" </g>\n",
|
||
" <g id=\"text_14\">\n",
|
||
" <!-- nsp -->\n",
|
||
" <g transform=\"translate(204.64375 38.476562)scale(0.1 -0.1)\">\n",
|
||
" <defs>\n",
|
||
" <path id=\"DejaVuSans-6e\" d=\"M 3513 2113 \n",
|
||
"L 3513 0 \n",
|
||
"L 2938 0 \n",
|
||
"L 2938 2094 \n",
|
||
"Q 2938 2591 2744 2837 \n",
|
||
"Q 2550 3084 2163 3084 \n",
|
||
"Q 1697 3084 1428 2787 \n",
|
||
"Q 1159 2491 1159 1978 \n",
|
||
"L 1159 0 \n",
|
||
"L 581 0 \n",
|
||
"L 581 3500 \n",
|
||
"L 1159 3500 \n",
|
||
"L 1159 2956 \n",
|
||
"Q 1366 3272 1645 3428 \n",
|
||
"Q 1925 3584 2291 3584 \n",
|
||
"Q 2894 3584 3203 3211 \n",
|
||
"Q 3513 2838 3513 2113 \n",
|
||
"z\n",
|
||
"\" transform=\"scale(0.015625)\"/>\n",
|
||
" </defs>\n",
|
||
" <use xlink:href=\"#DejaVuSans-6e\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-73\" x=\"63.378906\"/>\n",
|
||
" <use xlink:href=\"#DejaVuSans-70\" x=\"115.478516\"/>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" </g>\n",
|
||
" <defs>\n",
|
||
" <clipPath id=\"p1b68231c3f\">\n",
|
||
" <rect x=\"40.603125\" y=\"7.2\" width=\"195.3\" height=\"135.9\"/>\n",
|
||
" </clipPath>\n",
|
||
" </defs>\n",
|
||
"</svg>\n"
|
||
],
|
||
"text/plain": [
|
||
"<Figure size 252x180 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {
|
||
"needs_background": "light"
|
||
},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"train_bert(train_iter, net, loss, len(vocab), devices, 50)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "ede604ea",
|
||
"metadata": {
|
||
"origin_pos": 21
|
||
},
|
||
"source": [
|
||
"## 用BERT表示文本\n",
|
||
"\n",
|
||
"在预训练BERT之后,我们可以用它来表示单个文本、文本对或其中的任何词元。下面的函数返回`tokens_a`和`tokens_b`中所有词元的BERT(`net`)表示。\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"id": "77f3b8e4",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:05:00.663916Z",
|
||
"iopub.status.busy": "2023-08-18T07:05:00.663281Z",
|
||
"iopub.status.idle": "2023-08-18T07:05:00.669609Z",
|
||
"shell.execute_reply": "2023-08-18T07:05:00.668549Z"
|
||
},
|
||
"origin_pos": 23,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"def get_bert_encoding(net, tokens_a, tokens_b=None):\n",
|
||
" tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)\n",
|
||
" token_ids = torch.tensor(vocab[tokens], device=devices[0]).unsqueeze(0)\n",
|
||
" segments = torch.tensor(segments, device=devices[0]).unsqueeze(0)\n",
|
||
" valid_len = torch.tensor(len(tokens), device=devices[0]).unsqueeze(0)\n",
|
||
" encoded_X, _, _ = net(token_ids, segments, valid_len)\n",
|
||
" return encoded_X"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "25e0697e",
|
||
"metadata": {
|
||
"origin_pos": 25
|
||
},
|
||
"source": [
|
||
"考虑“a crane is flying”这句话。回想一下 :numref:`subsec_bert_input_rep`中讨论的BERT的输入表示。插入特殊标记“<cls>”(用于分类)和“<sep>”(用于分隔)后,BERT输入序列的长度为6。因为零是“<cls>”词元,`encoded_text[:, 0, :]`是整个输入语句的BERT表示。为了评估一词多义词元“crane”,我们还打印出了该词元的BERT表示的前三个元素。\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"id": "1081fda9",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:05:00.673428Z",
|
||
"iopub.status.busy": "2023-08-18T07:05:00.672675Z",
|
||
"iopub.status.idle": "2023-08-18T07:05:00.690133Z",
|
||
"shell.execute_reply": "2023-08-18T07:05:00.689347Z"
|
||
},
|
||
"origin_pos": 26,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(torch.Size([1, 6, 128]),\n",
|
||
" torch.Size([1, 128]),\n",
|
||
" tensor([-0.5007, -1.0034, 0.8718], device='cuda:0', grad_fn=<SliceBackward0>))"
|
||
]
|
||
},
|
||
"execution_count": 8,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"tokens_a = ['a', 'crane', 'is', 'flying']\n",
|
||
"encoded_text = get_bert_encoding(net, tokens_a)\n",
|
||
"# 词元:'<cls>','a','crane','is','flying','<sep>'\n",
|
||
"encoded_text_cls = encoded_text[:, 0, :]\n",
|
||
"encoded_text_crane = encoded_text[:, 2, :]\n",
|
||
"encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0][:3]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "203ca198",
|
||
"metadata": {
|
||
"origin_pos": 27
|
||
},
|
||
"source": [
|
||
"现在考虑一个句子“a crane driver came”和“he just left”。类似地,`encoded_pair[:, 0, :]`是来自预训练BERT的整个句子对的编码结果。注意,多义词元“crane”的前三个元素与上下文不同时的元素不同。这支持了BERT表示是上下文敏感的。\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"id": "960c3aa2",
|
||
"metadata": {
|
||
"execution": {
|
||
"iopub.execute_input": "2023-08-18T07:05:00.694637Z",
|
||
"iopub.status.busy": "2023-08-18T07:05:00.694061Z",
|
||
"iopub.status.idle": "2023-08-18T07:05:00.708881Z",
|
||
"shell.execute_reply": "2023-08-18T07:05:00.707778Z"
|
||
},
|
||
"origin_pos": 28,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(torch.Size([1, 10, 128]),\n",
|
||
" torch.Size([1, 128]),\n",
|
||
" tensor([ 0.5101, -0.4041, -1.2749], device='cuda:0', grad_fn=<SliceBackward0>))"
|
||
]
|
||
},
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"tokens_a, tokens_b = ['a', 'crane', 'driver', 'came'], ['he', 'just', 'left']\n",
|
||
"encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)\n",
|
||
"# 词元:'<cls>','a','crane','driver','came','<sep>','he','just',\n",
|
||
"# 'left','<sep>'\n",
|
||
"encoded_pair_cls = encoded_pair[:, 0, :]\n",
|
||
"encoded_pair_crane = encoded_pair[:, 2, :]\n",
|
||
"encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0][:3]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "d642486c",
|
||
"metadata": {
|
||
"origin_pos": 29
|
||
},
|
||
"source": [
|
||
"在 :numref:`chap_nlp_app`中,我们将为下游自然语言处理应用微调预训练的BERT模型。\n",
|
||
"\n",
|
||
"## 小结\n",
|
||
"\n",
|
||
"* 原始的BERT有两个版本,其中基本模型有1.1亿个参数,大模型有3.4亿个参数。\n",
|
||
"* 在预训练BERT之后,我们可以用它来表示单个文本、文本对或其中的任何词元。\n",
|
||
"* 在实验中,同一个词元在不同的上下文中具有不同的BERT表示。这支持BERT表示是上下文敏感的。\n",
|
||
"\n",
|
||
"## 练习\n",
|
||
"\n",
|
||
"1. 在实验中,我们可以看到遮蔽语言模型损失明显高于下一句预测损失。为什么?\n",
|
||
"2. 将BERT输入序列的最大长度设置为512(与原始BERT模型相同)。使用原始BERT模型的配置,如$\\text{BERT}_{\\text{LARGE}}$。运行此部分时是否遇到错误?为什么?\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "9f6249ab",
|
||
"metadata": {
|
||
"origin_pos": 31,
|
||
"tab": [
|
||
"pytorch"
|
||
]
|
||
},
|
||
"source": [
|
||
"[Discussions](https://discuss.d2l.ai/t/5743)\n"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"language_info": {
|
||
"name": "python"
|
||
},
|
||
"required_libs": []
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
} |