{ "cells": [ { "cell_type": "markdown", "id": "92354008", "metadata": { "origin_pos": 0 }, "source": [ "# softmax回归的从零开始实现\n", ":label:`sec_softmax_scratch`\n", "\n", "(**就像我们从零开始实现线性回归一样,**)\n", "我们认为softmax回归也是重要的基础,因此(**应该知道实现softmax回归的细节**)。\n", "本节我们将使用刚刚在 :numref:`sec_fashion_mnist`中引入的Fashion-MNIST数据集,\n", "并设置数据迭代器的批量大小为256。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "d5454103", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:34.568185Z", "iopub.status.busy": "2023-08-18T07:05:34.567550Z", "iopub.status.idle": "2023-08-18T07:05:36.481085Z", "shell.execute_reply": "2023-08-18T07:05:36.480189Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import torch\n", "from IPython import display\n", "from d2l import torch as d2l" ] }, { "cell_type": "code", "execution_count": 2, "id": "b8bd138c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:36.486713Z", "iopub.status.busy": "2023-08-18T07:05:36.486051Z", "iopub.status.idle": "2023-08-18T07:05:36.589161Z", "shell.execute_reply": "2023-08-18T07:05:36.588107Z" }, "origin_pos": 5, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "batch_size = 256\n", "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)" ] }, { "cell_type": "markdown", "id": "0c00d722", "metadata": { "origin_pos": 6 }, "source": [ "## 初始化模型参数\n", "\n", "和之前线性回归的例子一样,这里的每个样本都将用固定长度的向量表示。\n", "原始数据集中的每个样本都是$28 \\times 28$的图像。\n", "本节[**将展平每个图像,把它们看作长度为784的向量。**]\n", "在后面的章节中,我们将讨论能够利用图像空间结构的特征,\n", "但现在我们暂时只把每个像素位置看作一个特征。\n", "\n", "回想一下,在softmax回归中,我们的输出与类别一样多。\n", "(**因为我们的数据集有10个类别,所以网络输出维度为10**)。\n", "因此,权重将构成一个$784 \\times 10$的矩阵,\n", "偏置将构成一个$1 \\times 10$的行向量。\n", "与线性回归一样,我们将使用正态分布初始化我们的权重`W`,偏置初始化为0。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "4016fe6d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:36.594606Z", "iopub.status.busy": "2023-08-18T07:05:36.594134Z", "iopub.status.idle": "2023-08-18T07:05:36.599637Z", "shell.execute_reply": "2023-08-18T07:05:36.598552Z" }, "origin_pos": 8, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "num_inputs = 784\n", "num_outputs = 10\n", "\n", "W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)\n", "b = torch.zeros(num_outputs, requires_grad=True)" ] }, { "cell_type": "markdown", "id": "bcc89948", "metadata": { "origin_pos": 11 }, "source": [ "## 定义softmax操作\n", "\n", "在实现softmax回归模型之前,我们简要回顾一下`sum`运算符如何沿着张量中的特定维度工作。\n", "如 :numref:`subseq_lin-alg-reduction`和\n", " :numref:`subseq_lin-alg-non-reduction`所述,\n", " [**给定一个矩阵`X`,我们可以对所有元素求和**](默认情况下)。\n", " 也可以只求同一个轴上的元素,即同一列(轴0)或同一行(轴1)。\n", " 如果`X`是一个形状为`(2, 3)`的张量,我们对列进行求和,\n", " 则结果将是一个具有形状`(3,)`的向量。\n", " 当调用`sum`运算符时,我们可以指定保持在原始张量的轴数,而不折叠求和的维度。\n", " 这将产生一个具有形状`(1, 3)`的二维张量。\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "c0df140e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:36.604982Z", "iopub.status.busy": "2023-08-18T07:05:36.604096Z", "iopub.status.idle": "2023-08-18T07:05:36.615513Z", "shell.execute_reply": "2023-08-18T07:05:36.614620Z" }, "origin_pos": 12, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "(tensor([[5., 7., 9.]]),\n", " tensor([[ 6.],\n", " [15.]]))" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])\n", "X.sum(0, keepdim=True), X.sum(1, keepdim=True)" ] }, { "cell_type": "markdown", "id": "3b78565c", "metadata": { "origin_pos": 14 }, "source": [ "回想一下,[**实现softmax**]由三个步骤组成:\n", "\n", "1. 对每个项求幂(使用`exp`);\n", "1. 对每一行求和(小批量中每个样本是一行),得到每个样本的规范化常数;\n", "1. 将每一行除以其规范化常数,确保结果的和为1。\n", "\n", "在查看代码之前,我们回顾一下这个表达式:\n", "\n", "(**\n", "$$\n", "\\mathrm{softmax}(\\mathbf{X})_{ij} = \\frac{\\exp(\\mathbf{X}_{ij})}{\\sum_k \\exp(\\mathbf{X}_{ik})}.\n", "$$\n", "**)\n", "\n", "分母或规范化常数,有时也称为*配分函数*(其对数称为对数-配分函数)。\n", "该名称来自[统计物理学](https://en.wikipedia.org/wiki/Partition_function_(statistical_mechanics))中一个模拟粒子群分布的方程。\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "4c0be245", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:36.620749Z", "iopub.status.busy": "2023-08-18T07:05:36.620003Z", "iopub.status.idle": "2023-08-18T07:05:36.624603Z", "shell.execute_reply": "2023-08-18T07:05:36.623701Z" }, "origin_pos": 16, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def softmax(X):\n", " X_exp = torch.exp(X)\n", " partition = X_exp.sum(1, keepdim=True)\n", " return X_exp / partition # 这里应用了广播机制" ] }, { "cell_type": "markdown", "id": "b641b9eb", "metadata": { "origin_pos": 17 }, "source": [ "正如上述代码,对于任何随机输入,[**我们将每个元素变成一个非负数。\n", "此外,依据概率原理,每行总和为1**]。\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "a357bb20", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:36.629240Z", "iopub.status.busy": "2023-08-18T07:05:36.628878Z", "iopub.status.idle": "2023-08-18T07:05:36.640613Z", "shell.execute_reply": "2023-08-18T07:05:36.639677Z" }, "origin_pos": 18, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "(tensor([[0.1686, 0.4055, 0.0849, 0.1064, 0.2347],\n", " [0.0217, 0.2652, 0.6354, 0.0457, 0.0321]]),\n", " tensor([1.0000, 1.0000]))" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X = torch.normal(0, 1, (2, 5))\n", "X_prob = softmax(X)\n", "X_prob, X_prob.sum(1)" ] }, { "cell_type": "markdown", "id": "b5943861", "metadata": { "origin_pos": 20 }, "source": [ "注意,虽然这在数学上看起来是正确的,但我们在代码实现中有点草率。\n", "矩阵中的非常大或非常小的元素可能造成数值上溢或下溢,但我们没有采取措施来防止这点。\n", "\n", "## 定义模型\n", "\n", "定义softmax操作后,我们可以[**实现softmax回归模型**]。\n", "下面的代码定义了输入如何通过网络映射到输出。\n", "注意,将数据传递到模型之前,我们使用`reshape`函数将每张原始图像展平为向量。\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "098246b8", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:36.644224Z", "iopub.status.busy": "2023-08-18T07:05:36.643949Z", "iopub.status.idle": "2023-08-18T07:05:36.648644Z", "shell.execute_reply": "2023-08-18T07:05:36.647745Z" }, "origin_pos": 21, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def net(X):\n", " return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)" ] }, { "cell_type": "markdown", "id": "e46f8133", "metadata": { "origin_pos": 22 }, "source": [ "## 定义损失函数\n", "\n", "接下来,我们实现 :numref:`sec_softmax`中引入的交叉熵损失函数。\n", "这可能是深度学习中最常见的损失函数,因为目前分类问题的数量远远超过回归问题的数量。\n", "\n", "回顾一下,交叉熵采用真实标签的预测概率的负对数似然。\n", "这里我们不使用Python的for循环迭代预测(这往往是低效的),\n", "而是通过一个运算符选择所有元素。\n", "下面,我们[**创建一个数据样本`y_hat`,其中包含2个样本在3个类别的预测概率,\n", "以及它们对应的标签`y`。**]\n", "有了`y`,我们知道在第一个样本中,第一类是正确的预测;\n", "而在第二个样本中,第三类是正确的预测。\n", "然后(**使用`y`作为`y_hat`中概率的索引**),\n", "我们选择第一个样本中第一个类的概率和第二个样本中第三个类的概率。\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "d7196ba4", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:36.652705Z", "iopub.status.busy": "2023-08-18T07:05:36.652434Z", "iopub.status.idle": "2023-08-18T07:05:36.660790Z", "shell.execute_reply": "2023-08-18T07:05:36.659617Z" }, "origin_pos": 23, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([0.1000, 0.5000])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y = torch.tensor([0, 2])\n", "y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])\n", "y_hat[[0, 1], y]" ] }, { "cell_type": "markdown", "id": "f60bb6e4", "metadata": { "origin_pos": 25 }, "source": [ "现在我们只需一行代码就可以[**实现交叉熵损失函数**]。\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "8a2ec204", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:36.665898Z", "iopub.status.busy": "2023-08-18T07:05:36.665109Z", "iopub.status.idle": "2023-08-18T07:05:36.672113Z", "shell.execute_reply": "2023-08-18T07:05:36.671215Z" }, "origin_pos": 26, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([2.3026, 0.6931])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def cross_entropy(y_hat, y):\n", " return - torch.log(y_hat[range(len(y_hat)), y])\n", "\n", "cross_entropy(y_hat, y)" ] }, { "cell_type": "markdown", "id": "889a4000", "metadata": { "origin_pos": 29 }, "source": [ "## 分类精度\n", "\n", "给定预测概率分布`y_hat`,当我们必须输出硬预测(hard prediction)时,\n", "我们通常选择预测概率最高的类。\n", "许多应用都要求我们做出选择。如Gmail必须将电子邮件分类为“Primary(主要邮件)”、\n", "“Social(社交邮件)”“Updates(更新邮件)”或“Forums(论坛邮件)”。\n", "Gmail做分类时可能在内部估计概率,但最终它必须在类中选择一个。\n", "\n", "当预测与标签分类`y`一致时,即是正确的。\n", "分类精度即正确预测数量与总预测数量之比。\n", "虽然直接优化精度可能很困难(因为精度的计算不可导),\n", "但精度通常是我们最关心的性能衡量标准,我们在训练分类器时几乎总会关注它。\n", "\n", "为了计算精度,我们执行以下操作。\n", "首先,如果`y_hat`是矩阵,那么假定第二个维度存储每个类的预测分数。\n", "我们使用`argmax`获得每行中最大元素的索引来获得预测类别。\n", "然后我们[**将预测类别与真实`y`元素进行比较**]。\n", "由于等式运算符“`==`”对数据类型很敏感,\n", "因此我们将`y_hat`的数据类型转换为与`y`的数据类型一致。\n", "结果是一个包含0(错)和1(对)的张量。\n", "最后,我们求和会得到正确预测的数量。\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "2038b97d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:36.676633Z", "iopub.status.busy": "2023-08-18T07:05:36.676080Z", "iopub.status.idle": "2023-08-18T07:05:36.681962Z", "shell.execute_reply": "2023-08-18T07:05:36.680997Z" }, "origin_pos": 30, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def accuracy(y_hat, y): #@save\n", " \"\"\"计算预测正确的数量\"\"\"\n", " if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:\n", " y_hat = y_hat.argmax(axis=1)\n", " cmp = y_hat.type(y.dtype) == y\n", " return float(cmp.type(y.dtype).sum())" ] }, { "cell_type": "markdown", "id": "51a65a85", "metadata": { "origin_pos": 32 }, "source": [ "我们将继续使用之前定义的变量`y_hat`和`y`分别作为预测的概率分布和标签。\n", "可以看到,第一个样本的预测类别是2(该行的最大元素为0.6,索引为2),这与实际标签0不一致。\n", "第二个样本的预测类别是2(该行的最大元素为0.5,索引为2),这与实际标签2一致。\n", "因此,这两个样本的分类精度率为0.5。\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "6337adf4", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:36.686076Z", "iopub.status.busy": "2023-08-18T07:05:36.685804Z", "iopub.status.idle": "2023-08-18T07:05:36.692192Z", "shell.execute_reply": "2023-08-18T07:05:36.691298Z" }, "origin_pos": 33, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "0.5" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "accuracy(y_hat, y) / len(y)" ] }, { "cell_type": "markdown", "id": "f553b37b", "metadata": { "origin_pos": 34 }, "source": [ "同样,对于任意数据迭代器`data_iter`可访问的数据集,\n", "[**我们可以评估在任意模型`net`的精度**]。\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "41ea8ca1", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:36.696515Z", "iopub.status.busy": "2023-08-18T07:05:36.696074Z", "iopub.status.idle": "2023-08-18T07:05:36.702503Z", "shell.execute_reply": "2023-08-18T07:05:36.701545Z" }, "origin_pos": 36, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def evaluate_accuracy(net, data_iter): #@save\n", " \"\"\"计算在指定数据集上模型的精度\"\"\"\n", " if isinstance(net, torch.nn.Module):\n", " net.eval() # 将模型设置为评估模式\n", " metric = Accumulator(2) # 正确预测数、预测总数\n", " with torch.no_grad():\n", " for X, y in data_iter:\n", " metric.add(accuracy(net(X), y), y.numel())\n", " return metric[0] / metric[1]" ] }, { "cell_type": "markdown", "id": "eb10ad98", "metadata": { "origin_pos": 38 }, "source": [ "这里定义一个实用程序类`Accumulator`,用于对多个变量进行累加。\n", "在上面的`evaluate_accuracy`函数中,\n", "我们在(**`Accumulator`实例中创建了2个变量,\n", "分别用于存储正确预测的数量和预测的总数量**)。\n", "当我们遍历数据集时,两者都将随着时间的推移而累加。\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "381e6f11", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:36.707084Z", "iopub.status.busy": "2023-08-18T07:05:36.706353Z", "iopub.status.idle": "2023-08-18T07:05:36.712280Z", "shell.execute_reply": "2023-08-18T07:05:36.711359Z" }, "origin_pos": 39, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class Accumulator: #@save\n", " \"\"\"在n个变量上累加\"\"\"\n", " def __init__(self, n):\n", " self.data = [0.0] * n\n", "\n", " def add(self, *args):\n", " self.data = [a + float(b) for a, b in zip(self.data, args)]\n", "\n", " def reset(self):\n", " self.data = [0.0] * len(self.data)\n", "\n", " def __getitem__(self, idx):\n", " return self.data[idx]" ] }, { "cell_type": "markdown", "id": "cd7411c0", "metadata": { "origin_pos": 40 }, "source": [ "由于我们使用随机权重初始化`net`模型,\n", "因此该模型的精度应接近于随机猜测。\n", "例如在有10个类别情况下的精度为0.1。\n" ] }, { "cell_type": "code", "execution_count": 14, "id": "77706f95", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:36.716926Z", "iopub.status.busy": "2023-08-18T07:05:36.716179Z", "iopub.status.idle": "2023-08-18T07:05:37.338754Z", "shell.execute_reply": "2023-08-18T07:05:37.337496Z" }, "origin_pos": 41, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "0.0625" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "evaluate_accuracy(net, test_iter)" ] }, { "cell_type": "markdown", "id": "7eba262d", "metadata": { "origin_pos": 42 }, "source": [ "## 训练\n", "\n", "在我们看过 :numref:`sec_linear_scratch`中的线性回归实现,\n", "[**softmax回归的训练**]过程代码应该看起来非常眼熟。\n", "在这里,我们重构训练过程的实现以使其可重复使用。\n", "首先,我们定义一个函数来训练一个迭代周期。\n", "请注意,`updater`是更新模型参数的常用函数,它接受批量大小作为参数。\n", "它可以是`d2l.sgd`函数,也可以是框架的内置优化函数。\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "a2e8f2ba", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:37.344329Z", "iopub.status.busy": "2023-08-18T07:05:37.343921Z", "iopub.status.idle": "2023-08-18T07:05:37.354464Z", "shell.execute_reply": "2023-08-18T07:05:37.353391Z" }, "origin_pos": 44, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def train_epoch_ch3(net, train_iter, loss, updater): #@save\n", " \"\"\"训练模型一个迭代周期(定义见第3章)\"\"\"\n", " # 将模型设置为训练模式\n", " if isinstance(net, torch.nn.Module):\n", " net.train()\n", " # 训练损失总和、训练准确度总和、样本数\n", " metric = Accumulator(3)\n", " for X, y in train_iter:\n", " # 计算梯度并更新参数\n", " y_hat = net(X)\n", " l = loss(y_hat, y)\n", " if isinstance(updater, torch.optim.Optimizer):\n", " # 使用PyTorch内置的优化器和损失函数\n", " updater.zero_grad()\n", " l.mean().backward()\n", " updater.step()\n", " else:\n", " # 使用定制的优化器和损失函数\n", " l.sum().backward()\n", " updater(X.shape[0])\n", " metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())\n", " # 返回训练损失和训练精度\n", " return metric[0] / metric[2], metric[1] / metric[2]" ] }, { "cell_type": "markdown", "id": "041a8166", "metadata": { "origin_pos": 47 }, "source": [ "在展示训练函数的实现之前,我们[**定义一个在动画中绘制数据的实用程序类**]`Animator`,\n", "它能够简化本书其余部分的代码。\n" ] }, { "cell_type": "code", "execution_count": 16, "id": "9d3bab29", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:37.360211Z", "iopub.status.busy": "2023-08-18T07:05:37.359378Z", "iopub.status.idle": "2023-08-18T07:05:37.375759Z", "shell.execute_reply": "2023-08-18T07:05:37.374685Z" }, "origin_pos": 48, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class Animator: #@save\n", " \"\"\"在动画中绘制数据\"\"\"\n", " def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,\n", " ylim=None, xscale='linear', yscale='linear',\n", " fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,\n", " figsize=(3.5, 2.5)):\n", " # 增量地绘制多条线\n", " if legend is None:\n", " legend = []\n", " d2l.use_svg_display()\n", " self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)\n", " if nrows * ncols == 1:\n", " self.axes = [self.axes, ]\n", " # 使用lambda函数捕获参数\n", " self.config_axes = lambda: d2l.set_axes(\n", " self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)\n", " self.X, self.Y, self.fmts = None, None, fmts\n", "\n", " def add(self, x, y):\n", " # 向图表中添加多个数据点\n", " if not hasattr(y, \"__len__\"):\n", " y = [y]\n", " n = len(y)\n", " if not hasattr(x, \"__len__\"):\n", " x = [x] * n\n", " if not self.X:\n", " self.X = [[] for _ in range(n)]\n", " if not self.Y:\n", " self.Y = [[] for _ in range(n)]\n", " for i, (a, b) in enumerate(zip(x, y)):\n", " if a is not None and b is not None:\n", " self.X[i].append(a)\n", " self.Y[i].append(b)\n", " self.axes[0].cla()\n", " for x, y, fmt in zip(self.X, self.Y, self.fmts):\n", " self.axes[0].plot(x, y, fmt)\n", " self.config_axes()\n", " display.display(self.fig)\n", " display.clear_output(wait=True)" ] }, { "cell_type": "markdown", "id": "7c09ec7c", "metadata": { "origin_pos": 49 }, "source": [ "接下来我们实现一个[**训练函数**],\n", "它会在`train_iter`访问到的训练数据集上训练一个模型`net`。\n", "该训练函数将会运行多个迭代周期(由`num_epochs`指定)。\n", "在每个迭代周期结束时,利用`test_iter`访问到的测试数据集对模型进行评估。\n", "我们将利用`Animator`类来可视化训练进度。\n" ] }, { "cell_type": "code", "execution_count": 17, "id": "7ff0a317", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:37.381304Z", "iopub.status.busy": "2023-08-18T07:05:37.380550Z", "iopub.status.idle": "2023-08-18T07:05:37.389072Z", "shell.execute_reply": "2023-08-18T07:05:37.387971Z" }, "origin_pos": 50, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater): #@save\n", " \"\"\"训练模型(定义见第3章)\"\"\"\n", " animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],\n", " legend=['train loss', 'train acc', 'test acc'])\n", " for epoch in range(num_epochs):\n", " train_metrics = train_epoch_ch3(net, train_iter, loss, updater)\n", " test_acc = evaluate_accuracy(net, test_iter)\n", " animator.add(epoch + 1, train_metrics + (test_acc,))\n", " train_loss, train_acc = train_metrics\n", " assert train_loss < 0.5, train_loss\n", " assert train_acc <= 1 and train_acc > 0.7, train_acc\n", " assert test_acc <= 1 and test_acc > 0.7, test_acc" ] }, { "cell_type": "markdown", "id": "a5add373", "metadata": { "origin_pos": 51 }, "source": [ "作为一个从零开始的实现,我们使用 :numref:`sec_linear_scratch`中定义的\n", "[**小批量随机梯度下降来优化模型的损失函数**],设置学习率为0.1。\n" ] }, { "cell_type": "code", "execution_count": 18, "id": "44cfab15", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:37.393966Z", "iopub.status.busy": "2023-08-18T07:05:37.393127Z", "iopub.status.idle": "2023-08-18T07:05:37.398492Z", "shell.execute_reply": "2023-08-18T07:05:37.397420Z" }, "origin_pos": 52, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "lr = 0.1\n", "\n", "def updater(batch_size):\n", " return d2l.sgd([W, b], lr, batch_size)" ] }, { "cell_type": "markdown", "id": "0291691f", "metadata": { "origin_pos": 54 }, "source": [ "现在,我们[**训练模型10个迭代周期**]。\n", "请注意,迭代周期(`num_epochs`)和学习率(`lr`)都是可调节的超参数。\n", "通过更改它们的值,我们可以提高模型的分类精度。\n" ] }, { "cell_type": "code", "execution_count": 19, "id": "fb9c12f8", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:05:37.403054Z", "iopub.status.busy": "2023-08-18T07:05:37.402682Z", "iopub.status.idle": "2023-08-18T07:06:16.273679Z", "shell.execute_reply": "2023-08-18T07:06:16.272655Z" }, "origin_pos": 55, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:06:16.233360\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "num_epochs = 10\n", "train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)" ] }, { "cell_type": "markdown", "id": "a8121f40", "metadata": { "origin_pos": 56 }, "source": [ "## 预测\n", "\n", "现在训练已经完成,我们的模型已经准备好[**对图像进行分类预测**]。\n", "给定一系列图像,我们将比较它们的实际标签(文本输出的第一行)和模型预测(文本输出的第二行)。\n" ] }, { "cell_type": "code", "execution_count": 20, "id": "74ba2d12", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:06:16.277808Z", "iopub.status.busy": "2023-08-18T07:06:16.277179Z", "iopub.status.idle": "2023-08-18T07:06:16.734243Z", "shell.execute_reply": "2023-08-18T07:06:16.733343Z" }, "origin_pos": 57, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:06:16.683770\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def predict_ch3(net, test_iter, n=6): #@save\n", " \"\"\"预测标签(定义见第3章)\"\"\"\n", " for X, y in test_iter:\n", " break\n", " trues = d2l.get_fashion_mnist_labels(y)\n", " preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))\n", " titles = [true +'\\n' + pred for true, pred in zip(trues, preds)]\n", " d2l.show_images(\n", " X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])\n", "\n", "predict_ch3(net, test_iter)" ] }, { "cell_type": "markdown", "id": "fd6f7fb0", "metadata": { "origin_pos": 58 }, "source": [ "## 小结\n", "\n", "* 借助softmax回归,我们可以训练多分类的模型。\n", "* 训练softmax回归循环模型与训练线性回归模型非常相似:先读取数据,再定义模型和损失函数,然后使用优化算法训练模型。大多数常见的深度学习模型都有类似的训练过程。\n", "\n", "## 练习\n", "\n", "1. 本节直接实现了基于数学定义softmax运算的`softmax`函数。这可能会导致什么问题?提示:尝试计算$\\exp(50)$的大小。\n", "1. 本节中的函数`cross_entropy`是根据交叉熵损失函数的定义实现的。它可能有什么问题?提示:考虑对数的定义域。\n", "1. 请想一个解决方案来解决上述两个问题。\n", "1. 返回概率最大的分类标签总是最优解吗?例如,医疗诊断场景下可以这样做吗?\n", "1. 假设我们使用softmax回归来预测下一个单词,可选取的单词数目过多可能会带来哪些问题?\n" ] }, { "cell_type": "markdown", "id": "ed8131a1", "metadata": { "origin_pos": 60, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/1789)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }