Files
2025-12-16 09:23:53 +08:00

408 lines
10 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "bec47e64",
"metadata": {
"origin_pos": 0
},
"source": [
"# 读写文件\n",
"\n",
"到目前为止,我们讨论了如何处理数据,\n",
"以及如何构建、训练和测试深度学习模型。\n",
"然而,有时我们希望保存训练的模型,\n",
"以备将来在各种环境中使用(比如在部署中进行预测)。\n",
"此外,当运行一个耗时较长的训练过程时,\n",
"最佳的做法是定期保存中间结果,\n",
"以确保在服务器电源被不小心断掉时,我们不会损失几天的计算结果。\n",
"因此,现在是时候学习如何加载和存储权重向量和整个模型了。\n",
"\n",
"## (**加载和保存张量**)\n",
"\n",
"对于单个张量,我们可以直接调用`load`和`save`函数分别读写它们。\n",
"这两个函数都要求我们提供一个名称,`save`要求将要保存的变量作为输入。\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "9b319fd3",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:56:42.668559Z",
"iopub.status.busy": "2023-08-18T06:56:42.667248Z",
"iopub.status.idle": "2023-08-18T06:56:43.728764Z",
"shell.execute_reply": "2023-08-18T06:56:43.727885Z"
},
"origin_pos": 2,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.nn import functional as F\n",
"\n",
"x = torch.arange(4)\n",
"torch.save(x, 'x-file')"
]
},
{
"cell_type": "markdown",
"id": "e4f44ac7",
"metadata": {
"origin_pos": 5
},
"source": [
"我们现在可以将存储在文件中的数据读回内存。\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1ab53461",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:56:43.733002Z",
"iopub.status.busy": "2023-08-18T06:56:43.732347Z",
"iopub.status.idle": "2023-08-18T06:56:43.741208Z",
"shell.execute_reply": "2023-08-18T06:56:43.740416Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0, 1, 2, 3])"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x2 = torch.load('x-file')\n",
"x2"
]
},
{
"cell_type": "markdown",
"id": "44d4a111",
"metadata": {
"origin_pos": 10
},
"source": [
"我们可以[**存储一个张量列表,然后把它们读回内存。**]\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "81027fe1",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:56:43.744676Z",
"iopub.status.busy": "2023-08-18T06:56:43.744140Z",
"iopub.status.idle": "2023-08-18T06:56:43.751376Z",
"shell.execute_reply": "2023-08-18T06:56:43.750630Z"
},
"origin_pos": 12,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y = torch.zeros(4)\n",
"torch.save([x, y],'x-files')\n",
"x2, y2 = torch.load('x-files')\n",
"(x2, y2)"
]
},
{
"cell_type": "markdown",
"id": "b060dd48",
"metadata": {
"origin_pos": 15
},
"source": [
"我们甚至可以(**写入或读取从字符串映射到张量的字典**)。\n",
"当我们要读取或写入模型中的所有权重时,这很方便。\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "fde1cb33",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:56:43.754777Z",
"iopub.status.busy": "2023-08-18T06:56:43.754313Z",
"iopub.status.idle": "2023-08-18T06:56:43.761150Z",
"shell.execute_reply": "2023-08-18T06:56:43.760369Z"
},
"origin_pos": 17,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mydict = {'x': x, 'y': y}\n",
"torch.save(mydict, 'mydict')\n",
"mydict2 = torch.load('mydict')\n",
"mydict2"
]
},
{
"cell_type": "markdown",
"id": "afa857bf",
"metadata": {
"origin_pos": 20
},
"source": [
"## [**加载和保存模型参数**]\n",
"\n",
"保存单个权重向量(或其他张量)确实有用,\n",
"但是如果我们想保存整个模型,并在以后加载它们,\n",
"单独保存每个向量则会变得很麻烦。\n",
"毕竟,我们可能有数百个参数散布在各处。\n",
"因此,深度学习框架提供了内置函数来保存和加载整个网络。\n",
"需要注意的一个重要细节是,这将保存模型的参数而不是保存整个模型。\n",
"例如,如果我们有一个3层多层感知机,我们需要单独指定架构。\n",
"因为模型本身可以包含任意代码,所以模型本身难以序列化。\n",
"因此,为了恢复模型,我们需要用代码生成架构,\n",
"然后从磁盘加载参数。\n",
"让我们从熟悉的多层感知机开始尝试一下。\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2672b5c2",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:56:43.764609Z",
"iopub.status.busy": "2023-08-18T06:56:43.764090Z",
"iopub.status.idle": "2023-08-18T06:56:43.773070Z",
"shell.execute_reply": "2023-08-18T06:56:43.772277Z"
},
"origin_pos": 22,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class MLP(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.hidden = nn.Linear(20, 256)\n",
" self.output = nn.Linear(256, 10)\n",
"\n",
" def forward(self, x):\n",
" return self.output(F.relu(self.hidden(x)))\n",
"\n",
"net = MLP()\n",
"X = torch.randn(size=(2, 20))\n",
"Y = net(X)"
]
},
{
"cell_type": "markdown",
"id": "697ceed0",
"metadata": {
"origin_pos": 25
},
"source": [
"接下来,我们[**将模型的参数存储在一个叫做“mlp.params”的文件中。**]\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a53c1315",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:56:43.776452Z",
"iopub.status.busy": "2023-08-18T06:56:43.775942Z",
"iopub.status.idle": "2023-08-18T06:56:43.780387Z",
"shell.execute_reply": "2023-08-18T06:56:43.779636Z"
},
"origin_pos": 27,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"torch.save(net.state_dict(), 'mlp.params')"
]
},
{
"cell_type": "markdown",
"id": "b6df754a",
"metadata": {
"origin_pos": 30
},
"source": [
"为了恢复模型,我们[**实例化了原始多层感知机模型的一个备份。**]\n",
"这里我们不需要随机初始化模型参数,而是(**直接读取文件中存储的参数。**)\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "da5e1b3f",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:56:43.783850Z",
"iopub.status.busy": "2023-08-18T06:56:43.783240Z",
"iopub.status.idle": "2023-08-18T06:56:43.789905Z",
"shell.execute_reply": "2023-08-18T06:56:43.789164Z"
},
"origin_pos": 32,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"MLP(\n",
" (hidden): Linear(in_features=20, out_features=256, bias=True)\n",
" (output): Linear(in_features=256, out_features=10, bias=True)\n",
")"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clone = MLP()\n",
"clone.load_state_dict(torch.load('mlp.params'))\n",
"clone.eval()"
]
},
{
"cell_type": "markdown",
"id": "65076662",
"metadata": {
"origin_pos": 35
},
"source": [
"由于两个实例具有相同的模型参数,在输入相同的`X`时,\n",
"两个实例的计算结果应该相同。\n",
"让我们来验证一下。\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a25ba1f1",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:56:43.793400Z",
"iopub.status.busy": "2023-08-18T06:56:43.792788Z",
"iopub.status.idle": "2023-08-18T06:56:43.798329Z",
"shell.execute_reply": "2023-08-18T06:56:43.797576Z"
},
"origin_pos": 37,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[True, True, True, True, True, True, True, True, True, True],\n",
" [True, True, True, True, True, True, True, True, True, True]])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Y_clone = clone(X)\n",
"Y_clone == Y"
]
},
{
"cell_type": "markdown",
"id": "7a65b1e2",
"metadata": {
"origin_pos": 39
},
"source": [
"## 小结\n",
"\n",
"* `save`和`load`函数可用于张量对象的文件读写。\n",
"* 我们可以通过参数字典保存和加载网络的全部参数。\n",
"* 保存架构必须在代码中完成,而不是在参数中完成。\n",
"\n",
"## 练习\n",
"\n",
"1. 即使不需要将经过训练的模型部署到不同的设备上,存储模型参数还有什么实际的好处?\n",
"1. 假设我们只想复用网络的一部分,以将其合并到不同的网络架构中。比如想在一个新的网络中使用之前网络的前两层,该怎么做?\n",
"1. 如何同时保存网络架构和参数?需要对架构加上什么限制?\n"
]
},
{
"cell_type": "markdown",
"id": "d803f301",
"metadata": {
"origin_pos": 41,
"tab": [
"pytorch"
]
},
"source": [
"[Discussions](https://discuss.d2l.ai/t/1839)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"required_libs": []
},
"nbformat": 4,
"nbformat_minor": 5
}