408 lines
10 KiB
Plaintext
408 lines
10 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "bec47e64",
|
|
"metadata": {
|
|
"origin_pos": 0
|
|
},
|
|
"source": [
|
|
"# 读写文件\n",
|
|
"\n",
|
|
"到目前为止,我们讨论了如何处理数据,\n",
|
|
"以及如何构建、训练和测试深度学习模型。\n",
|
|
"然而,有时我们希望保存训练的模型,\n",
|
|
"以备将来在各种环境中使用(比如在部署中进行预测)。\n",
|
|
"此外,当运行一个耗时较长的训练过程时,\n",
|
|
"最佳的做法是定期保存中间结果,\n",
|
|
"以确保在服务器电源被不小心断掉时,我们不会损失几天的计算结果。\n",
|
|
"因此,现在是时候学习如何加载和存储权重向量和整个模型了。\n",
|
|
"\n",
|
|
"## (**加载和保存张量**)\n",
|
|
"\n",
|
|
"对于单个张量,我们可以直接调用`load`和`save`函数分别读写它们。\n",
|
|
"这两个函数都要求我们提供一个名称,`save`要求将要保存的变量作为输入。\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "9b319fd3",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2023-08-18T06:56:42.668559Z",
|
|
"iopub.status.busy": "2023-08-18T06:56:42.667248Z",
|
|
"iopub.status.idle": "2023-08-18T06:56:43.728764Z",
|
|
"shell.execute_reply": "2023-08-18T06:56:43.727885Z"
|
|
},
|
|
"origin_pos": 2,
|
|
"tab": [
|
|
"pytorch"
|
|
]
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"from torch import nn\n",
|
|
"from torch.nn import functional as F\n",
|
|
"\n",
|
|
"x = torch.arange(4)\n",
|
|
"torch.save(x, 'x-file')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "e4f44ac7",
|
|
"metadata": {
|
|
"origin_pos": 5
|
|
},
|
|
"source": [
|
|
"我们现在可以将存储在文件中的数据读回内存。\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "1ab53461",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2023-08-18T06:56:43.733002Z",
|
|
"iopub.status.busy": "2023-08-18T06:56:43.732347Z",
|
|
"iopub.status.idle": "2023-08-18T06:56:43.741208Z",
|
|
"shell.execute_reply": "2023-08-18T06:56:43.740416Z"
|
|
},
|
|
"origin_pos": 7,
|
|
"tab": [
|
|
"pytorch"
|
|
]
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([0, 1, 2, 3])"
|
|
]
|
|
},
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"x2 = torch.load('x-file')\n",
|
|
"x2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "44d4a111",
|
|
"metadata": {
|
|
"origin_pos": 10
|
|
},
|
|
"source": [
|
|
"我们可以[**存储一个张量列表,然后把它们读回内存。**]\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "81027fe1",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2023-08-18T06:56:43.744676Z",
|
|
"iopub.status.busy": "2023-08-18T06:56:43.744140Z",
|
|
"iopub.status.idle": "2023-08-18T06:56:43.751376Z",
|
|
"shell.execute_reply": "2023-08-18T06:56:43.750630Z"
|
|
},
|
|
"origin_pos": 12,
|
|
"tab": [
|
|
"pytorch"
|
|
]
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))"
|
|
]
|
|
},
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"y = torch.zeros(4)\n",
|
|
"torch.save([x, y],'x-files')\n",
|
|
"x2, y2 = torch.load('x-files')\n",
|
|
"(x2, y2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "b060dd48",
|
|
"metadata": {
|
|
"origin_pos": 15
|
|
},
|
|
"source": [
|
|
"我们甚至可以(**写入或读取从字符串映射到张量的字典**)。\n",
|
|
"当我们要读取或写入模型中的所有权重时,这很方便。\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "fde1cb33",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2023-08-18T06:56:43.754777Z",
|
|
"iopub.status.busy": "2023-08-18T06:56:43.754313Z",
|
|
"iopub.status.idle": "2023-08-18T06:56:43.761150Z",
|
|
"shell.execute_reply": "2023-08-18T06:56:43.760369Z"
|
|
},
|
|
"origin_pos": 17,
|
|
"tab": [
|
|
"pytorch"
|
|
]
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}"
|
|
]
|
|
},
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"mydict = {'x': x, 'y': y}\n",
|
|
"torch.save(mydict, 'mydict')\n",
|
|
"mydict2 = torch.load('mydict')\n",
|
|
"mydict2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "afa857bf",
|
|
"metadata": {
|
|
"origin_pos": 20
|
|
},
|
|
"source": [
|
|
"## [**加载和保存模型参数**]\n",
|
|
"\n",
|
|
"保存单个权重向量(或其他张量)确实有用,\n",
|
|
"但是如果我们想保存整个模型,并在以后加载它们,\n",
|
|
"单独保存每个向量则会变得很麻烦。\n",
|
|
"毕竟,我们可能有数百个参数散布在各处。\n",
|
|
"因此,深度学习框架提供了内置函数来保存和加载整个网络。\n",
|
|
"需要注意的一个重要细节是,这将保存模型的参数而不是保存整个模型。\n",
|
|
"例如,如果我们有一个3层多层感知机,我们需要单独指定架构。\n",
|
|
"因为模型本身可以包含任意代码,所以模型本身难以序列化。\n",
|
|
"因此,为了恢复模型,我们需要用代码生成架构,\n",
|
|
"然后从磁盘加载参数。\n",
|
|
"让我们从熟悉的多层感知机开始尝试一下。\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "2672b5c2",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2023-08-18T06:56:43.764609Z",
|
|
"iopub.status.busy": "2023-08-18T06:56:43.764090Z",
|
|
"iopub.status.idle": "2023-08-18T06:56:43.773070Z",
|
|
"shell.execute_reply": "2023-08-18T06:56:43.772277Z"
|
|
},
|
|
"origin_pos": 22,
|
|
"tab": [
|
|
"pytorch"
|
|
]
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"class MLP(nn.Module):\n",
|
|
" def __init__(self):\n",
|
|
" super().__init__()\n",
|
|
" self.hidden = nn.Linear(20, 256)\n",
|
|
" self.output = nn.Linear(256, 10)\n",
|
|
"\n",
|
|
" def forward(self, x):\n",
|
|
" return self.output(F.relu(self.hidden(x)))\n",
|
|
"\n",
|
|
"net = MLP()\n",
|
|
"X = torch.randn(size=(2, 20))\n",
|
|
"Y = net(X)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "697ceed0",
|
|
"metadata": {
|
|
"origin_pos": 25
|
|
},
|
|
"source": [
|
|
"接下来,我们[**将模型的参数存储在一个叫做“mlp.params”的文件中。**]\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "a53c1315",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2023-08-18T06:56:43.776452Z",
|
|
"iopub.status.busy": "2023-08-18T06:56:43.775942Z",
|
|
"iopub.status.idle": "2023-08-18T06:56:43.780387Z",
|
|
"shell.execute_reply": "2023-08-18T06:56:43.779636Z"
|
|
},
|
|
"origin_pos": 27,
|
|
"tab": [
|
|
"pytorch"
|
|
]
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"torch.save(net.state_dict(), 'mlp.params')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "b6df754a",
|
|
"metadata": {
|
|
"origin_pos": 30
|
|
},
|
|
"source": [
|
|
"为了恢复模型,我们[**实例化了原始多层感知机模型的一个备份。**]\n",
|
|
"这里我们不需要随机初始化模型参数,而是(**直接读取文件中存储的参数。**)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "da5e1b3f",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2023-08-18T06:56:43.783850Z",
|
|
"iopub.status.busy": "2023-08-18T06:56:43.783240Z",
|
|
"iopub.status.idle": "2023-08-18T06:56:43.789905Z",
|
|
"shell.execute_reply": "2023-08-18T06:56:43.789164Z"
|
|
},
|
|
"origin_pos": 32,
|
|
"tab": [
|
|
"pytorch"
|
|
]
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"MLP(\n",
|
|
" (hidden): Linear(in_features=20, out_features=256, bias=True)\n",
|
|
" (output): Linear(in_features=256, out_features=10, bias=True)\n",
|
|
")"
|
|
]
|
|
},
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"clone = MLP()\n",
|
|
"clone.load_state_dict(torch.load('mlp.params'))\n",
|
|
"clone.eval()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "65076662",
|
|
"metadata": {
|
|
"origin_pos": 35
|
|
},
|
|
"source": [
|
|
"由于两个实例具有相同的模型参数,在输入相同的`X`时,\n",
|
|
"两个实例的计算结果应该相同。\n",
|
|
"让我们来验证一下。\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "a25ba1f1",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2023-08-18T06:56:43.793400Z",
|
|
"iopub.status.busy": "2023-08-18T06:56:43.792788Z",
|
|
"iopub.status.idle": "2023-08-18T06:56:43.798329Z",
|
|
"shell.execute_reply": "2023-08-18T06:56:43.797576Z"
|
|
},
|
|
"origin_pos": 37,
|
|
"tab": [
|
|
"pytorch"
|
|
]
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[True, True, True, True, True, True, True, True, True, True],\n",
|
|
" [True, True, True, True, True, True, True, True, True, True]])"
|
|
]
|
|
},
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"Y_clone = clone(X)\n",
|
|
"Y_clone == Y"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "7a65b1e2",
|
|
"metadata": {
|
|
"origin_pos": 39
|
|
},
|
|
"source": [
|
|
"## 小结\n",
|
|
"\n",
|
|
"* `save`和`load`函数可用于张量对象的文件读写。\n",
|
|
"* 我们可以通过参数字典保存和加载网络的全部参数。\n",
|
|
"* 保存架构必须在代码中完成,而不是在参数中完成。\n",
|
|
"\n",
|
|
"## 练习\n",
|
|
"\n",
|
|
"1. 即使不需要将经过训练的模型部署到不同的设备上,存储模型参数还有什么实际的好处?\n",
|
|
"1. 假设我们只想复用网络的一部分,以将其合并到不同的网络架构中。比如想在一个新的网络中使用之前网络的前两层,该怎么做?\n",
|
|
"1. 如何同时保存网络架构和参数?需要对架构加上什么限制?\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d803f301",
|
|
"metadata": {
|
|
"origin_pos": 41,
|
|
"tab": [
|
|
"pytorch"
|
|
]
|
|
},
|
|
"source": [
|
|
"[Discussions](https://discuss.d2l.ai/t/1839)\n"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"language_info": {
|
|
"name": "python"
|
|
},
|
|
"required_libs": []
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
} |