{ "cells": [ { "cell_type": "markdown", "id": "1464cc27", "metadata": { "origin_pos": 0 }, "source": [ "# softmax回归的简洁实现\n", ":label:`sec_softmax_concise`\n", "\n", "在 :numref:`sec_linear_concise`中,\n", "我们发现(**通过深度学习框架的高级API能够使实现**)\n", "(~~softmax~~)\n", "线性(**回归变得更加容易**)。\n", "同样,通过深度学习框架的高级API也能更方便地实现softmax回归模型。\n", "本节如在 :numref:`sec_softmax_scratch`中一样,\n", "继续使用Fashion-MNIST数据集,并保持批量大小为256。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "7f81001f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:57:16.212083Z", "iopub.status.busy": "2023-08-18T06:57:16.211369Z", "iopub.status.idle": "2023-08-18T06:57:18.227520Z", "shell.execute_reply": "2023-08-18T06:57:18.226314Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "code", "execution_count": 2, "id": "92e395a8", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:57:18.231806Z", "iopub.status.busy": "2023-08-18T06:57:18.230933Z", "iopub.status.idle": "2023-08-18T06:57:18.337514Z", "shell.execute_reply": "2023-08-18T06:57:18.336238Z" }, "origin_pos": 5, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "batch_size = 256\n", "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)" ] }, { "cell_type": "markdown", "id": "fb2b9199", "metadata": { "origin_pos": 6 }, "source": [ "## 初始化模型参数\n", "\n", "如我们在 :numref:`sec_softmax`所述,\n", "[**softmax回归的输出层是一个全连接层**]。\n", "因此,为了实现我们的模型,\n", "我们只需在`Sequential`中添加一个带有10个输出的全连接层。\n", "同样,在这里`Sequential`并不是必要的,\n", "但它是实现深度模型的基础。\n", "我们仍然以均值0和标准差0.01随机初始化权重。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "ebf37311", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:57:18.342288Z", "iopub.status.busy": "2023-08-18T06:57:18.342007Z", "iopub.status.idle": "2023-08-18T06:57:18.349431Z", "shell.execute_reply": "2023-08-18T06:57:18.348277Z" }, "origin_pos": 8, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "# PyTorch不会隐式地调整输入的形状。因此,\n", "# 我们在线性层前定义了展平层(flatten),来调整网络输入的形状\n", "net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))\n", "\n", "def init_weights(m):\n", " if type(m) == nn.Linear:\n", " nn.init.normal_(m.weight, std=0.01)\n", "\n", "net.apply(init_weights);" ] }, { "cell_type": "markdown", "id": "30a6c3c5", "metadata": { "origin_pos": 11 }, "source": [ "## 重新审视Softmax的实现\n", ":label:`subsec_softmax-implementation-revisited`\n", "\n", "在前面 :numref:`sec_softmax_scratch`的例子中,\n", "我们计算了模型的输出,然后将此输出送入交叉熵损失。\n", "从数学上讲,这是一件完全合理的事情。\n", "然而,从计算角度来看,指数可能会造成数值稳定性问题。\n", "\n", "回想一下,softmax函数$\\hat y_j = \\frac{\\exp(o_j)}{\\sum_k \\exp(o_k)}$,\n", "其中$\\hat y_j$是预测的概率分布。\n", "$o_j$是未规范化的预测$\\mathbf{o}$的第$j$个元素。\n", "如果$o_k$中的一些数值非常大,\n", "那么$\\exp(o_k)$可能大于数据类型容许的最大数字,即*上溢*(overflow)。\n", "这将使分母或分子变为`inf`(无穷大),\n", "最后得到的是0、`inf`或`nan`(不是数字)的$\\hat y_j$。\n", "在这些情况下,我们无法得到一个明确定义的交叉熵值。\n", "\n", "解决这个问题的一个技巧是:\n", "在继续softmax计算之前,先从所有$o_k$中减去$\\max(o_k)$。\n", "这里可以看到每个$o_k$按常数进行的移动不会改变softmax的返回值:\n", "\n", "$$\n", "\\begin{aligned}\n", "\\hat y_j & = \\frac{\\exp(o_j - \\max(o_k))\\exp(\\max(o_k))}{\\sum_k \\exp(o_k - \\max(o_k))\\exp(\\max(o_k))} \\\\\n", "& = \\frac{\\exp(o_j - \\max(o_k))}{\\sum_k \\exp(o_k - \\max(o_k))}.\n", "\\end{aligned}\n", "$$\n", "\n", "\n", "在减法和规范化步骤之后,可能有些$o_j - \\max(o_k)$具有较大的负值。\n", "由于精度受限,$\\exp(o_j - \\max(o_k))$将有接近零的值,即*下溢*(underflow)。\n", "这些值可能会四舍五入为零,使$\\hat y_j$为零,\n", "并且使得$\\log(\\hat y_j)$的值为`-inf`。\n", "反向传播几步后,我们可能会发现自己面对一屏幕可怕的`nan`结果。\n", "\n", "尽管我们要计算指数函数,但我们最终在计算交叉熵损失时会取它们的对数。\n", "通过将softmax和交叉熵结合在一起,可以避免反向传播过程中可能会困扰我们的数值稳定性问题。\n", "如下面的等式所示,我们避免计算$\\exp(o_j - \\max(o_k))$,\n", "而可以直接使用$o_j - \\max(o_k)$,因为$\\log(\\exp(\\cdot))$被抵消了。\n", "\n", "$$\n", "\\begin{aligned}\n", "\\log{(\\hat y_j)} & = \\log\\left( \\frac{\\exp(o_j - \\max(o_k))}{\\sum_k \\exp(o_k - \\max(o_k))}\\right) \\\\\n", "& = \\log{(\\exp(o_j - \\max(o_k)))}-\\log{\\left( \\sum_k \\exp(o_k - \\max(o_k)) \\right)} \\\\\n", "& = o_j - \\max(o_k) -\\log{\\left( \\sum_k \\exp(o_k - \\max(o_k)) \\right)}.\n", "\\end{aligned}\n", "$$\n", "\n", "我们也希望保留传统的softmax函数,以备我们需要评估通过模型输出的概率。\n", "但是,我们没有将softmax概率传递到损失函数中,\n", "而是[**在交叉熵损失函数中传递未规范化的预测,并同时计算softmax及其对数**],\n", "这是一种类似[\"LogSumExp技巧\"](https://en.wikipedia.org/wiki/LogSumExp)的聪明方式。\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "91c3ac45", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:57:18.353684Z", "iopub.status.busy": "2023-08-18T06:57:18.353410Z", "iopub.status.idle": "2023-08-18T06:57:18.358187Z", "shell.execute_reply": "2023-08-18T06:57:18.357079Z" }, "origin_pos": 13, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "loss = nn.CrossEntropyLoss(reduction='none')" ] }, { "cell_type": "markdown", "id": "c347cec2", "metadata": { "origin_pos": 15 }, "source": [ "## 优化算法\n", "\n", "在这里,我们(**使用学习率为0.1的小批量随机梯度下降作为优化算法**)。\n", "这与我们在线性回归例子中的相同,这说明了优化器的普适性。\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "c4849ef8", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:57:18.362274Z", "iopub.status.busy": "2023-08-18T06:57:18.361998Z", "iopub.status.idle": "2023-08-18T06:57:18.366991Z", "shell.execute_reply": "2023-08-18T06:57:18.365798Z" }, "origin_pos": 17, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "trainer = torch.optim.SGD(net.parameters(), lr=0.1)" ] }, { "cell_type": "markdown", "id": "c2cf8941", "metadata": { "origin_pos": 20 }, "source": [ "## 训练\n", "\n", "接下来我们[**调用**] :numref:`sec_softmax_scratch`中(~~之前~~)\n", "(**定义的训练函数来训练模型**)。\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "5acea90d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:57:18.371133Z", "iopub.status.busy": "2023-08-18T06:57:18.370849Z", "iopub.status.idle": "2023-08-18T06:58:00.716532Z", "shell.execute_reply": "2023-08-18T06:58:00.715223Z" }, "origin_pos": 21, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T06:58:00.674658\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "num_epochs = 10\n", "d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)" ] }, { "cell_type": "markdown", "id": "e09d7d1c", "metadata": { "origin_pos": 22 }, "source": [ "和以前一样,这个算法使结果收敛到一个相当高的精度,而且这次的代码比之前更精简了。\n", "\n", "## 小结\n", "\n", "* 使用深度学习框架的高级API,我们可以更简洁地实现softmax回归。\n", "* 从计算的角度来看,实现softmax回归比较复杂。在许多情况下,深度学习框架在这些著名的技巧之外采取了额外的预防措施,来确保数值的稳定性。这使我们避免了在实践中从零开始编写模型时可能遇到的陷阱。\n", "\n", "## 练习\n", "\n", "1. 尝试调整超参数,例如批量大小、迭代周期数和学习率,并查看结果。\n", "1. 增加迭代周期的数量。为什么测试精度会在一段时间后降低?我们怎么解决这个问题?\n" ] }, { "cell_type": "markdown", "id": "e81d17a2", "metadata": { "origin_pos": 24, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/1793)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }