{ "cells": [ { "cell_type": "markdown", "id": "e4c46021", "metadata": { "origin_pos": 0 }, "source": [ "# 循环神经网络的简洁实现\n", ":label:`sec_rnn-concise`\n", "\n", "虽然 :numref:`sec_rnn_scratch`\n", "对了解循环神经网络的实现方式具有指导意义,但并不方便。\n", "本节将展示如何使用深度学习框架的高级API提供的函数更有效地实现相同的语言模型。\n", "我们仍然从读取时光机器数据集开始。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "e38d82e3", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:22:38.029441Z", "iopub.status.busy": "2023-08-18T07:22:38.028754Z", "iopub.status.idle": "2023-08-18T07:22:42.082845Z", "shell.execute_reply": "2023-08-18T07:22:42.081933Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from torch.nn import functional as F\n", "from d2l import torch as d2l\n", "\n", "batch_size, num_steps = 32, 35\n", "train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)" ] }, { "cell_type": "markdown", "id": "bc8deb24", "metadata": { "origin_pos": 5 }, "source": [ "## [**定义模型**]\n", "\n", "高级API提供了循环神经网络的实现。\n", "我们构造一个具有256个隐藏单元的单隐藏层的循环神经网络层`rnn_layer`。\n", "事实上,我们还没有讨论多层循环神经网络的意义(这将在 :numref:`sec_deep_rnn`中介绍)。\n", "现在仅需要将多层理解为一层循环神经网络的输出被用作下一层循环神经网络的输入就足够了。\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "37dae103", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:22:42.087317Z", "iopub.status.busy": "2023-08-18T07:22:42.086622Z", "iopub.status.idle": "2023-08-18T07:22:42.117225Z", "shell.execute_reply": "2023-08-18T07:22:42.116310Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "num_hiddens = 256\n", "rnn_layer = nn.RNN(len(vocab), num_hiddens)" ] }, { "cell_type": "markdown", "id": "08e36240", "metadata": { "origin_pos": 11, "tab": [ "pytorch" ] }, "source": [ "我们(**使用张量来初始化隐状态**),它的形状是(隐藏层数,批量大小,隐藏单元数)。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "1922fe18", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:22:42.122208Z", "iopub.status.busy": "2023-08-18T07:22:42.121722Z", "iopub.status.idle": "2023-08-18T07:22:42.128343Z", "shell.execute_reply": "2023-08-18T07:22:42.127617Z" }, "origin_pos": 13, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 32, 256])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "state = torch.zeros((1, batch_size, num_hiddens))\n", "state.shape" ] }, { "cell_type": "markdown", "id": "170be239", "metadata": { "origin_pos": 16 }, "source": [ "[**通过一个隐状态和一个输入,我们就可以用更新后的隐状态计算输出。**]\n", "需要强调的是,`rnn_layer`的“输出”(`Y`)不涉及输出层的计算:\n", "它是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "cecfe4c4", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:22:42.132231Z", "iopub.status.busy": "2023-08-18T07:22:42.131762Z", "iopub.status.idle": "2023-08-18T07:22:42.149849Z", "shell.execute_reply": "2023-08-18T07:22:42.148795Z" }, "origin_pos": 19, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "(torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X = torch.rand(size=(num_steps, batch_size, len(vocab)))\n", "Y, state_new = rnn_layer(X, state)\n", "Y.shape, state_new.shape" ] }, { "cell_type": "markdown", "id": "7c919b48", "metadata": { "origin_pos": 22 }, "source": [ "与 :numref:`sec_rnn_scratch`类似,\n", "[**我们为一个完整的循环神经网络模型定义了一个`RNNModel`类**]。\n", "注意,`rnn_layer`只包含隐藏的循环层,我们还需要创建一个单独的输出层。\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "300de81f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:22:42.155516Z", "iopub.status.busy": "2023-08-18T07:22:42.154604Z", "iopub.status.idle": "2023-08-18T07:22:42.171447Z", "shell.execute_reply": "2023-08-18T07:22:42.170040Z" }, "origin_pos": 24, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "class RNNModel(nn.Module):\n", " \"\"\"循环神经网络模型\"\"\"\n", " def __init__(self, rnn_layer, vocab_size, **kwargs):\n", " super(RNNModel, self).__init__(**kwargs)\n", " self.rnn = rnn_layer\n", " self.vocab_size = vocab_size\n", " self.num_hiddens = self.rnn.hidden_size\n", " # 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1\n", " if not self.rnn.bidirectional:\n", " self.num_directions = 1\n", " self.linear = nn.Linear(self.num_hiddens, self.vocab_size)\n", " else:\n", " self.num_directions = 2\n", " self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)\n", "\n", " def forward(self, inputs, state):\n", " X = F.one_hot(inputs.T.long(), self.vocab_size)\n", " X = X.to(torch.float32)\n", " Y, state = self.rnn(X, state)\n", " # 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)\n", " # 它的输出形状是(时间步数*批量大小,词表大小)。\n", " output = self.linear(Y.reshape((-1, Y.shape[-1])))\n", " return output, state\n", "\n", " def begin_state(self, device, batch_size=1):\n", " if not isinstance(self.rnn, nn.LSTM):\n", " # nn.GRU以张量作为隐状态\n", " return torch.zeros((self.num_directions * self.rnn.num_layers,\n", " batch_size, self.num_hiddens),\n", " device=device)\n", " else:\n", " # nn.LSTM以元组作为隐状态\n", " return (torch.zeros((\n", " self.num_directions * self.rnn.num_layers,\n", " batch_size, self.num_hiddens), device=device),\n", " torch.zeros((\n", " self.num_directions * self.rnn.num_layers,\n", " batch_size, self.num_hiddens), device=device))" ] }, { "cell_type": "markdown", "id": "573178db", "metadata": { "origin_pos": 27 }, "source": [ "## 训练与预测\n", "\n", "在训练模型之前,让我们[**基于一个具有随机权重的模型进行预测**]。\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "93c01655", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:22:42.176888Z", "iopub.status.busy": "2023-08-18T07:22:42.175834Z", "iopub.status.idle": "2023-08-18T07:22:45.380419Z", "shell.execute_reply": "2023-08-18T07:22:45.379583Z" }, "origin_pos": 29, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "'time travellerbbabbkabyg'" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device = d2l.try_gpu()\n", "net = RNNModel(rnn_layer, vocab_size=len(vocab))\n", "net = net.to(device)\n", "d2l.predict_ch8('time traveller', 10, net, vocab, device)" ] }, { "cell_type": "markdown", "id": "240f0d45", "metadata": { "origin_pos": 32 }, "source": [ "很明显,这种模型根本不能输出好的结果。\n", "接下来,我们使用 :numref:`sec_rnn_scratch`中\n", "定义的超参数调用`train_ch8`,并且[**使用高级API训练模型**]。\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "97fc8534", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:22:45.384672Z", "iopub.status.busy": "2023-08-18T07:22:45.383787Z", "iopub.status.idle": "2023-08-18T07:23:06.055571Z", "shell.execute_reply": "2023-08-18T07:23:06.054355Z" }, "origin_pos": 34, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "perplexity 1.3, 404413.8 tokens/sec on cuda:0\n", "time travellerit would be remarkably convenient for the historia\n", "travellery of il the hise fupt might and st was it loflers \n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:23:06.012401\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "num_epochs, lr = 500, 1\n", "d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)" ] }, { "cell_type": "markdown", "id": "409494cb", "metadata": { "origin_pos": 37 }, "source": [ "与上一节相比,由于深度学习框架的高级API对代码进行了更多的优化,\n", "该模型在较短的时间内达到了较低的困惑度。\n", "\n", "## 小结\n", "\n", "* 深度学习框架的高级API提供了循环神经网络层的实现。\n", "* 高级API的循环神经网络层返回一个输出和一个更新后的隐状态,我们还需要计算整个模型的输出层。\n", "* 相比从零开始实现的循环神经网络,使用高级API实现可以加速训练。\n", "\n", "## 练习\n", "\n", "1. 尝试使用高级API,能使循环神经网络模型过拟合吗?\n", "1. 如果在循环神经网络模型中增加隐藏层的数量会发生什么?能使模型正常工作吗?\n", "1. 尝试使用循环神经网络实现 :numref:`sec_sequence`的自回归模型。\n" ] }, { "cell_type": "markdown", "id": "d84e3e86", "metadata": { "origin_pos": 39, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/2106)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }