{ "cells": [ { "cell_type": "markdown", "id": "bec47e64", "metadata": { "origin_pos": 0 }, "source": [ "# File I/O\n", "\n", "So far we have discussed how to process data\n", "and how to build, train, and test deep learning models.\n", "However, at times we want to save a trained model\n", "for later use in various contexts (e.g., making predictions in deployment).\n", "Additionally, when running long training processes,\n", "the best practice is to periodically save intermediate results,\n", "so that we do not lose several days' worth of computation\n", "if the server's power is accidentally cut.\n", "Thus it is time to learn how to load and store\n", "both individual weight vectors and entire models.\n", "\n", "## (**Loading and Saving Tensors**)\n", "\n", "For individual tensors, we can directly invoke the `load` and `save` functions\n", "to read and write them respectively.\n", "Both functions require that we supply a name,\n", "and `save` requires as input the variable to be saved.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "9b319fd3", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:56:42.668559Z", "iopub.status.busy": "2023-08-18T06:56:42.667248Z", "iopub.status.idle": "2023-08-18T06:56:43.728764Z", "shell.execute_reply": "2023-08-18T06:56:43.727885Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from torch.nn import functional as F\n", "\n", "x = torch.arange(4)\n", "torch.save(x, 'x-file')" ] }, { "cell_type": "markdown", "id": "e4f44ac7", "metadata": { "origin_pos": 5 }, "source": [ "We can now read the data stored in the file back into memory.\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "1ab53461", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:56:43.733002Z", "iopub.status.busy": "2023-08-18T06:56:43.732347Z", "iopub.status.idle": "2023-08-18T06:56:43.741208Z", "shell.execute_reply": "2023-08-18T06:56:43.740416Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([0, 1, 2, 3])" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x2 = torch.load('x-file')\n", "x2" ] }, { "cell_type": "markdown", "id": "44d4a111", "metadata": { "origin_pos": 10 }, "source": [ "We can [**store a list of tensors and read them back into memory.**]\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "81027fe1", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:56:43.744676Z", "iopub.status.busy": "2023-08-18T06:56:43.744140Z", 
"iopub.status.idle": "2023-08-18T06:56:43.751376Z", "shell.execute_reply": "2023-08-18T06:56:43.750630Z" }, "origin_pos": 12, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y = torch.zeros(4)\n", "torch.save([x, y], 'x-files')\n", "x2, y2 = torch.load('x-files')\n", "(x2, y2)" ] }, { "cell_type": "markdown", "id": "b060dd48", "metadata": { "origin_pos": 15 }, "source": [ "We can even (**write and read a dictionary that maps from strings to tensors**).\n", "This is convenient when we want to read or write all the weights in a model.\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "fde1cb33", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:56:43.754777Z", "iopub.status.busy": "2023-08-18T06:56:43.754313Z", "iopub.status.idle": "2023-08-18T06:56:43.761150Z", "shell.execute_reply": "2023-08-18T06:56:43.760369Z" }, "origin_pos": 17, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mydict = {'x': x, 'y': y}\n", "torch.save(mydict, 'mydict')\n", "mydict2 = torch.load('mydict')\n", "mydict2" ] }, { "cell_type": "markdown", "id": "afa857bf", "metadata": { "origin_pos": 20 }, "source": [ "## [**Loading and Saving Model Parameters**]\n", "\n", "Saving individual weight vectors (or other tensors) is useful,\n", "but it becomes very tedious if we want to save an entire model\n", "and load it later on,\n", "since we may have hundreds of parameters sprinkled throughout.\n", "For this reason, deep learning frameworks provide built-in functionality\n", "to save and load entire networks.\n", "An important detail to note is that this saves the model's *parameters*\n", "and not the entire model.\n", "For example, if we have a 3-layer MLP,\n", "we need to specify the architecture separately.\n", "Since models themselves can contain arbitrary code,\n", "they cannot be serialized as easily.\n", "Thus, in order to restore a model, we need to generate the architecture in code\n", "and then load the parameters from disk.\n", "Let us try this out with our familiar MLP.\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "2672b5c2", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:56:43.764609Z", "iopub.status.busy": "2023-08-18T06:56:43.764090Z", "iopub.status.idle": "2023-08-18T06:56:43.773070Z", "shell.execute_reply": 
"2023-08-18T06:56:43.772277Z" }, "origin_pos": 22, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class MLP(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.hidden = nn.Linear(20, 256)\n", "        self.output = nn.Linear(256, 10)\n", "\n", "    def forward(self, x):\n", "        return self.output(F.relu(self.hidden(x)))\n", "\n", "net = MLP()\n", "X = torch.randn(size=(2, 20))\n", "Y = net(X)" ] }, { "cell_type": "markdown", "id": "697ceed0", "metadata": { "origin_pos": 25 }, "source": [ "Next, we [**store the parameters of the model in a file named `mlp.params`.**]\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "a53c1315", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:56:43.776452Z", "iopub.status.busy": "2023-08-18T06:56:43.775942Z", "iopub.status.idle": "2023-08-18T06:56:43.780387Z", "shell.execute_reply": "2023-08-18T06:56:43.779636Z" }, "origin_pos": 27, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "torch.save(net.state_dict(), 'mlp.params')" ] }, { "cell_type": "markdown", "id": "b6df754a", "metadata": { "origin_pos": 30 }, "source": [ "To restore the model, we [**instantiate a clone of the original MLP model.**]\n", "Instead of randomly initializing the model parameters,\n", "we (**read the parameters stored in the file directly.**)\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "da5e1b3f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:56:43.783850Z", "iopub.status.busy": "2023-08-18T06:56:43.783240Z", "iopub.status.idle": "2023-08-18T06:56:43.789905Z", "shell.execute_reply": "2023-08-18T06:56:43.789164Z" }, "origin_pos": 32, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "MLP(\n", "  (hidden): Linear(in_features=20, out_features=256, bias=True)\n", "  (output): Linear(in_features=256, out_features=10, bias=True)\n", ")" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clone = MLP()\n", "clone.load_state_dict(torch.load('mlp.params'))\n", "clone.eval()" ] }, { "cell_type": "markdown", "id": "65076662", "metadata": { "origin_pos": 35 }, "source": [ 
"Since both instances have the same model parameters,\n", "the computational results for the same input `X` should be identical.\n", "Let us verify this.\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "a25ba1f1", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:56:43.793400Z", "iopub.status.busy": "2023-08-18T06:56:43.792788Z", "iopub.status.idle": "2023-08-18T06:56:43.798329Z", "shell.execute_reply": "2023-08-18T06:56:43.797576Z" }, "origin_pos": 37, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([[True, True, True, True, True, True, True, True, True, True],\n", "        [True, True, True, True, True, True, True, True, True, True]])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Y_clone = clone(X)\n", "Y_clone == Y" ] }, { "cell_type": "markdown", "id": "7a65b1e2", "metadata": { "origin_pos": 39 }, "source": [ "## Summary\n", "\n", "* The `save` and `load` functions can be used to perform file I/O for tensor objects.\n", "* We can save and load all the parameters of a network via a parameter dictionary.\n", "* Saving the architecture has to be done in code rather than in parameters.\n", "\n", "## Exercises\n", "\n", "1. Even if there is no need to deploy trained models to a different device, what are the practical benefits of storing model parameters?\n", "1. Assume that we want to reuse only parts of a network to be incorporated into a network of a different architecture. How would you go about using, say, the first two layers from a previous network in a new network?\n", "1. How would you go about saving the network architecture and parameters together? What restrictions would you impose on the architecture?\n" ] }, { "cell_type": "markdown", "id": "d803f301", "metadata": { "origin_pos": 41, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/1839)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }