{ "cells": [ { "cell_type": "markdown", "id": "7a776eb6", "metadata": { "origin_pos": 0 }, "source": [ "# 单发多框检测(SSD)\n", ":label:`sec_ssd`\n", "\n", "在 :numref:`sec_bbox`— :numref:`sec_object-detection-dataset`中,我们分别介绍了边界框、锚框、多尺度目标检测和用于目标检测的数据集。\n", "现在我们已经准备好使用这样的背景知识来设计一个目标检测模型:单发多框检测(SSD) :cite:`Liu.Anguelov.Erhan.ea.2016`。\n", "该模型简单、快速且被广泛使用。尽管这只是其中一种目标检测模型,但本节中的一些设计原则和实现细节也适用于其他模型。\n", "\n", "## 模型\n", "\n", " :numref:`fig_ssd`描述了单发多框检测模型的设计。\n", "此模型主要由基础网络组成,其后是几个多尺度特征块。\n", "基本网络用于从输入图像中提取特征,因此它可以使用深度卷积神经网络。\n", "单发多框检测论文中选用了在分类层之前截断的VGG :cite:`Liu.Anguelov.Erhan.ea.2016`,现在也常用ResNet替代。\n", "我们可以设计基础网络,使它输出的高和宽较大。\n", "这样一来,基于该特征图生成的锚框数量较多,可以用来检测尺寸较小的目标。\n", "接下来的每个多尺度特征块将上一层提供的特征图的高和宽缩小(如减半),并使特征图中每个单元在输入图像上的感受野变得更广阔。\n", "\n", "回想一下在 :numref:`sec_multiscale-object-detection`中,通过深度神经网络分层表示图像的多尺度目标检测的设计。\n", "由于接近 :numref:`fig_ssd`顶部的多尺度特征图较小,但具有较大的感受野,它们适合检测较少但较大的物体。\n", "简而言之,通过多尺度特征块,单发多框检测生成不同大小的锚框,并通过预测边界框的类别和偏移量来检测大小不同的目标,因此这是一个多尺度目标检测模型。\n", "\n", "![单发多框检测模型主要由一个基础网络块和若干多尺度特征块串联而成。](../img/ssd.svg)\n", ":label:`fig_ssd`\n", "\n", "在下面,我们将介绍 :numref:`fig_ssd`中不同块的实施细节。\n", "首先,我们将讨论如何实施类别和边界框预测。\n", "\n", "### [**类别预测层**]\n", "\n", "设目标类别的数量为$q$。这样一来,锚框有$q+1$个类别,其中0类是背景。\n", "在某个尺度下,设特征图的高和宽分别为$h$和$w$。\n", "如果以其中每个单元为中心生成$a$个锚框,那么我们需要对$hwa$个锚框进行分类。\n", "如果使用全连接层作为输出,很容易导致模型参数过多。\n", "回忆 :numref:`sec_nin`一节介绍的使用卷积层的通道来输出类别预测的方法,\n", "单发多框检测采用同样的方法来降低模型复杂度。\n", "\n", "具体来说,类别预测层使用一个保持输入高和宽的卷积层。\n", "这样一来,输出和输入在特征图宽和高上的空间坐标一一对应。\n", "考虑输出和输入同一空间坐标($x$、$y$):输出特征图上($x$、$y$)坐标的通道里包含了以输入特征图($x$、$y$)坐标为中心生成的所有锚框的类别预测。\n", "因此输出通道数为$a(q+1)$,其中索引为$i(q+1) + j$($0 \\leq j \\leq q$)的通道代表了索引为$i$的锚框有关类别索引为$j$的预测。\n", "\n", "在下面,我们定义了这样一个类别预测层,通过参数`num_anchors`和`num_classes`分别指定了$a$和$q$。\n", "该图层使用填充为1的$3\\times3$的卷积层。此卷积层的输入和输出的宽度和高度保持不变。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "563696ed", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:10.621444Z", "iopub.status.busy": "2023-08-18T07:17:10.620782Z", "iopub.status.idle": "2023-08-18T07:17:13.555579Z", "shell.execute_reply": "2023-08-18T07:17:13.554045Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "import torch\n", "import torchvision\n", "from torch import nn\n", "from torch.nn import functional as F\n", "from d2l import torch as d2l\n", "\n", "\n", "def cls_predictor(num_inputs, num_anchors, num_classes):\n", " return nn.Conv2d(num_inputs, num_anchors * (num_classes + 1),\n", " kernel_size=3, padding=1)" ] }, { "cell_type": "markdown", "id": "1e1e0dfc", "metadata": { "origin_pos": 4 }, "source": [ "### (**边界框预测层**)\n", "\n", "边界框预测层的设计与类别预测层的设计类似。\n", "唯一不同的是,这里需要为每个锚框预测4个偏移量,而不是$q+1$个类别。\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "0e7f1560", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:13.561361Z", "iopub.status.busy": "2023-08-18T07:17:13.560758Z", "iopub.status.idle": "2023-08-18T07:17:13.566605Z", "shell.execute_reply": "2023-08-18T07:17:13.565544Z" }, "origin_pos": 6, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def bbox_predictor(num_inputs, num_anchors):\n", " return nn.Conv2d(num_inputs, num_anchors * 4, kernel_size=3, padding=1)" ] }, { "cell_type": "markdown", "id": "2292886e", "metadata": { "origin_pos": 8 }, "source": [ "### [**连结多尺度的预测**]\n", "\n", "正如我们所提到的,单发多框检测使用多尺度特征图来生成锚框并预测其类别和偏移量。\n", "在不同的尺度下,特征图的形状或以同一单元为中心的锚框的数量可能会有所不同。\n", "因此,不同尺度下预测输出的形状可能会有所不同。\n", "\n", "在以下示例中,我们为同一个小批量构建两个不同比例(`Y1`和`Y2`)的特征图,其中`Y2`的高度和宽度是`Y1`的一半。\n", "以类别预测为例,假设`Y1`和`Y2`的每个单元分别生成了$5$个和$3$个锚框。\n", "进一步假设目标类别的数量为$10$,对于特征图`Y1`和`Y2`,类别预测输出中的通道数分别为$5\\times(10+1)=55$和$3\\times(10+1)=33$,其中任一输出的形状是(批量大小,通道数,高度,宽度)。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "1978218f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:13.571902Z", "iopub.status.busy": "2023-08-18T07:17:13.571041Z", "iopub.status.idle": "2023-08-18T07:17:13.700224Z", "shell.execute_reply": "2023-08-18T07:17:13.699037Z" }, "origin_pos": 10, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "(torch.Size([2, 55, 20, 20]), torch.Size([2, 33, 10, 10]))" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def forward(x, block):\n", " return block(x)\n", "\n", "Y1 = forward(torch.zeros((2, 8, 20, 20)), cls_predictor(8, 5, 10))\n", "Y2 = forward(torch.zeros((2, 16, 10, 10)), cls_predictor(16, 3, 10))\n", "Y1.shape, Y2.shape" ] }, { "cell_type": "markdown", "id": "c9abaf3e", "metadata": { "origin_pos": 12 }, "source": [ "正如我们所看到的,除了批量大小这一维度外,其他三个维度都具有不同的尺寸。\n", "为了将这两个预测输出链接起来以提高计算效率,我们将把这些张量转换为更一致的格式。\n", "\n", "通道维包含中心相同的锚框的预测结果。我们首先将通道维移到最后一维。\n", "因为不同尺度下批量大小仍保持不变,我们可以将预测结果转成二维的(批量大小,高$\\times$宽$\\times$通道数)的格式,以方便之后在维度$1$上的连结。\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "11d90c8d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:13.711602Z", "iopub.status.busy": "2023-08-18T07:17:13.706279Z", "iopub.status.idle": "2023-08-18T07:17:13.718777Z", "shell.execute_reply": "2023-08-18T07:17:13.717594Z" }, "origin_pos": 14, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def flatten_pred(pred):\n", " return torch.flatten(pred.permute(0, 2, 3, 1), start_dim=1)\n", "\n", "def concat_preds(preds):\n", " return torch.cat([flatten_pred(p) for p in preds], dim=1)" ] }, { "cell_type": "markdown", "id": "67254e16", "metadata": { "origin_pos": 16 }, "source": [ "这样一来,尽管`Y1`和`Y2`在通道数、高度和宽度方面具有不同的大小,我们仍然可以在同一个小批量的两个不同尺度上连接这两个预测输出。\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "82a882b4", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:13.724224Z", "iopub.status.busy": "2023-08-18T07:17:13.723406Z", "iopub.status.idle": "2023-08-18T07:17:13.731796Z", "shell.execute_reply": "2023-08-18T07:17:13.730728Z" }, "origin_pos": 17, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 25300])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "concat_preds([Y1, Y2]).shape" ] }, { "cell_type": "markdown", "id": "ba3e5c35", "metadata": { "origin_pos": 18 }, "source": [ "### [**高和宽减半块**]\n", "\n", "为了在多个尺度下检测目标,我们在下面定义了高和宽减半块`down_sample_blk`,该模块将输入特征图的高度和宽度减半。\n", "事实上,该块应用了在 :numref:`subsec_vgg-blocks`中的VGG模块设计。\n", "更具体地说,每个高和宽减半块由两个填充为$1$的$3\\times3$的卷积层、以及步幅为$2$的$2\\times2$最大汇聚层组成。\n", "我们知道,填充为$1$的$3\\times3$卷积层不改变特征图的形状。但是,其后的$2\\times2$的最大汇聚层将输入特征图的高度和宽度减少了一半。\n", "对于此高和宽减半块的输入和输出特征图,因为$1\\times 2+(3-1)+(3-1)=6$,所以输出中的每个单元在输入上都有一个$6\\times6$的感受野。因此,高和宽减半块会扩大每个单元在其输出特征图中的感受野。\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "e9fd4f8b", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:13.736409Z", "iopub.status.busy": "2023-08-18T07:17:13.735546Z", "iopub.status.idle": "2023-08-18T07:17:13.743144Z", "shell.execute_reply": "2023-08-18T07:17:13.742156Z" }, "origin_pos": 20, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def down_sample_blk(in_channels, out_channels):\n", " blk = []\n", " for _ in range(2):\n", " blk.append(nn.Conv2d(in_channels, out_channels,\n", " kernel_size=3, padding=1))\n", " blk.append(nn.BatchNorm2d(out_channels))\n", " blk.append(nn.ReLU())\n", " in_channels = out_channels\n", " blk.append(nn.MaxPool2d(2))\n", " return nn.Sequential(*blk)" ] }, { "cell_type": "markdown", "id": "2ad13aeb", "metadata": { "origin_pos": 22 }, "source": [ "在以下示例中,我们构建的高和宽减半块会更改输入通道的数量,并将输入特征图的高度和宽度减半。\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "ead207aa", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:13.747496Z", "iopub.status.busy": "2023-08-18T07:17:13.746637Z", "iopub.status.idle": "2023-08-18T07:17:13.760305Z", "shell.execute_reply": "2023-08-18T07:17:13.759092Z" }, "origin_pos": 24, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 10, 10, 10])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "forward(torch.zeros((2, 3, 20, 20)), down_sample_blk(3, 10)).shape" ] }, { "cell_type": "markdown", "id": "c39b0407", "metadata": { "origin_pos": 26 }, "source": [ "### [**基本网络块**]\n", "\n", "基本网络块用于从输入图像中抽取特征。\n", "为了计算简洁,我们构造了一个小的基础网络,该网络串联3个高和宽减半块,并逐步将通道数翻倍。\n", "给定输入图像的形状为$256\\times256$,此基本网络块输出的特征图形状为$32 \\times 32$($256/2^3=32$)。\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "12d15b30", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:13.765884Z", "iopub.status.busy": "2023-08-18T07:17:13.764976Z", "iopub.status.idle": "2023-08-18T07:17:13.810888Z", "shell.execute_reply": "2023-08-18T07:17:13.809472Z" }, "origin_pos": 28, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 64, 32, 32])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def base_net():\n", " blk = []\n", " num_filters = [3, 16, 32, 64]\n", " for i in range(len(num_filters) - 1):\n", " blk.append(down_sample_blk(num_filters[i], num_filters[i+1]))\n", " return nn.Sequential(*blk)\n", "\n", "forward(torch.zeros((2, 3, 256, 256)), base_net()).shape" ] }, { "cell_type": "markdown", "id": "f2e2905e", "metadata": { "origin_pos": 30 }, "source": [ "### 完整的模型\n", "\n", "[**完整的单发多框检测模型由五个模块组成**]。每个块生成的特征图既用于生成锚框,又用于预测这些锚框的类别和偏移量。在这五个模块中,第一个是基本网络块,第二个到第四个是高和宽减半块,最后一个模块使用全局最大池将高度和宽度都降到1。从技术上讲,第二到第五个区块都是 :numref:`fig_ssd`中的多尺度特征块。\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "a1299bce", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:13.816373Z", "iopub.status.busy": "2023-08-18T07:17:13.815882Z", "iopub.status.idle": "2023-08-18T07:17:13.823049Z", "shell.execute_reply": "2023-08-18T07:17:13.821892Z" }, "origin_pos": 32, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def get_blk(i):\n", " if i == 0:\n", " blk = base_net()\n", " elif i == 1:\n", " blk = down_sample_blk(64, 128)\n", " elif i == 4:\n", " blk = nn.AdaptiveMaxPool2d((1,1))\n", " else:\n", " blk = down_sample_blk(128, 128)\n", " return blk" ] }, { "cell_type": "markdown", "id": "ffe48b79", "metadata": { "origin_pos": 34 }, "source": [ "现在我们[**为每个块定义前向传播**]。与图像分类任务不同,此处的输出包括:CNN特征图`Y`;在当前尺度下根据`Y`生成的锚框;预测的这些锚框的类别和偏移量(基于`Y`)。\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "a32c85fe", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:13.828048Z", "iopub.status.busy": "2023-08-18T07:17:13.827370Z", "iopub.status.idle": "2023-08-18T07:17:13.834162Z", "shell.execute_reply": "2023-08-18T07:17:13.833029Z" }, "origin_pos": 36, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def blk_forward(X, blk, size, ratio, cls_predictor, bbox_predictor):\n", " Y = blk(X)\n", " anchors = d2l.multibox_prior(Y, sizes=size, ratios=ratio)\n", " cls_preds = cls_predictor(Y)\n", " bbox_preds = bbox_predictor(Y)\n", " return (Y, anchors, cls_preds, bbox_preds)" ] }, { "cell_type": "markdown", "id": "e342b537", "metadata": { "origin_pos": 38 }, "source": [ "回想一下,在 :numref:`fig_ssd`中,一个较接近顶部的多尺度特征块是用于检测较大目标的,因此需要生成更大的锚框。\n", "在上面的前向传播中,在每个多尺度特征块上,我们通过调用的`multibox_prior`函数(见 :numref:`sec_anchor`)的`sizes`参数传递两个比例值的列表。\n", "在下面,0.2和1.05之间的区间被均匀分成五个部分,以确定五个模块的在不同尺度下的较小值:0.2、0.37、0.54、0.71和0.88。\n", "之后,他们较大的值由$\\sqrt{0.2 \\times 0.37} = 0.272$、$\\sqrt{0.37 \\times 0.54} = 0.447$等给出。\n", "\n", "[~~超参数~~]\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "c059209e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:13.839681Z", "iopub.status.busy": "2023-08-18T07:17:13.838916Z", "iopub.status.idle": "2023-08-18T07:17:13.845388Z", "shell.execute_reply": "2023-08-18T07:17:13.844279Z" }, "origin_pos": 39, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "sizes = [[0.2, 0.272], [0.37, 0.447], [0.54, 0.619], [0.71, 0.79],\n", " [0.88, 0.961]]\n", "ratios = [[1, 2, 0.5]] * 5\n", "num_anchors = len(sizes[0]) + len(ratios[0]) - 1" ] }, { "cell_type": "markdown", "id": "7eeb21d9", "metadata": { "origin_pos": 40 }, "source": [ "现在,我们就可以按如下方式[**定义完整的模型**]`TinySSD`了。\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "c872fa1d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:13.850809Z", "iopub.status.busy": "2023-08-18T07:17:13.850043Z", "iopub.status.idle": "2023-08-18T07:17:13.862966Z", "shell.execute_reply": "2023-08-18T07:17:13.861799Z" }, "origin_pos": 42, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class TinySSD(nn.Module):\n", " def __init__(self, num_classes, **kwargs):\n", " super(TinySSD, self).__init__(**kwargs)\n", " self.num_classes = num_classes\n", " idx_to_in_channels = [64, 128, 128, 128, 128]\n", " for i in range(5):\n", " # 即赋值语句self.blk_i=get_blk(i)\n", " setattr(self, f'blk_{i}', get_blk(i))\n", " setattr(self, f'cls_{i}', cls_predictor(idx_to_in_channels[i],\n", " num_anchors, num_classes))\n", " setattr(self, f'bbox_{i}', bbox_predictor(idx_to_in_channels[i],\n", " num_anchors))\n", "\n", " def forward(self, X):\n", " anchors, cls_preds, bbox_preds = [None] * 5, [None] * 5, [None] * 5\n", " for i in range(5):\n", " # getattr(self,'blk_%d'%i)即访问self.blk_i\n", " X, anchors[i], cls_preds[i], bbox_preds[i] = blk_forward(\n", " X, getattr(self, f'blk_{i}'), sizes[i], ratios[i],\n", " getattr(self, f'cls_{i}'), getattr(self, f'bbox_{i}'))\n", " anchors = torch.cat(anchors, dim=1)\n", " cls_preds = concat_preds(cls_preds)\n", " cls_preds = cls_preds.reshape(\n", " cls_preds.shape[0], -1, self.num_classes + 1)\n", " bbox_preds = concat_preds(bbox_preds)\n", " return anchors, cls_preds, bbox_preds" ] }, { "cell_type": "markdown", "id": "12faa05a", "metadata": { "origin_pos": 44 }, "source": [ "我们[**创建一个模型实例,然后使用它**]对一个$256 \\times 256$像素的小批量图像`X`(**执行前向传播**)。\n", "\n", "如本节前面部分所示,第一个模块输出特征图的形状为$32 \\times 32$。\n", "回想一下,第二到第四个模块为高和宽减半块,第五个模块为全局汇聚层。\n", "由于以特征图的每个单元为中心有$4$个锚框生成,因此在所有五个尺度下,每个图像总共生成$(32^2 + 16^2 + 8^2 + 4^2 + 1)\\times 4 = 5444$个锚框。\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "4f690e1e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:13.868497Z", "iopub.status.busy": "2023-08-18T07:17:13.867704Z", "iopub.status.idle": "2023-08-18T07:17:14.352257Z", "shell.execute_reply": "2023-08-18T07:17:14.351363Z" }, "origin_pos": 46, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "output anchors: torch.Size([1, 5444, 4])\n", "output class preds: torch.Size([32, 5444, 2])\n", "output bbox preds: torch.Size([32, 21776])\n" ] } ], "source": [ "net = TinySSD(num_classes=1)\n", "X = torch.zeros((32, 3, 256, 256))\n", "anchors, cls_preds, bbox_preds = net(X)\n", "\n", "print('output anchors:', anchors.shape)\n", "print('output class preds:', cls_preds.shape)\n", "print('output bbox preds:', bbox_preds.shape)" ] }, { "cell_type": "markdown", "id": "88c255f7", "metadata": { "origin_pos": 48 }, "source": [ "## 训练模型\n", "\n", "现在,我们将描述如何训练用于目标检测的单发多框检测模型。\n", "\n", "### 读取数据集和初始化\n", "\n", "首先,让我们[**读取**] :numref:`sec_object-detection-dataset`中描述的(**香蕉检测数据集**)。\n" ] }, { "cell_type": "code", "execution_count": 14, "id": "929c68e5", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:14.357283Z", "iopub.status.busy": "2023-08-18T07:17:14.356878Z", "iopub.status.idle": "2023-08-18T07:17:19.113866Z", "shell.execute_reply": "2023-08-18T07:17:19.112868Z" }, "origin_pos": 49, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "read 1000 training examples\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "read 100 validation examples\n" ] } ], "source": [ "batch_size = 32\n", "train_iter, _ = d2l.load_data_bananas(batch_size)" ] }, { "cell_type": "markdown", "id": "c3b335ed", "metadata": { "origin_pos": 50 }, "source": [ "香蕉检测数据集中,目标的类别数为1。\n", "定义好模型后,我们需要(**初始化其参数并定义优化算法**)。\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "1b2d1579", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:19.117581Z", "iopub.status.busy": "2023-08-18T07:17:19.116990Z", "iopub.status.idle": "2023-08-18T07:17:19.138203Z", "shell.execute_reply": "2023-08-18T07:17:19.137289Z" }, "origin_pos": 52, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "device, net = d2l.try_gpu(), TinySSD(num_classes=1)\n", "trainer = torch.optim.SGD(net.parameters(), lr=0.2, weight_decay=5e-4)" ] }, { "cell_type": "markdown", "id": "ff7b1d34", "metadata": { "origin_pos": 54 }, "source": [ "### [**定义损失函数和评价函数**]\n", "\n", "目标检测有两种类型的损失。\n", "第一种有关锚框类别的损失:我们可以简单地复用之前图像分类问题里一直使用的交叉熵损失函数来计算;\n", "第二种有关正类锚框偏移量的损失:预测偏移量是一个回归问题。\n", "但是,对于这个回归问题,我们在这里不使用 :numref:`subsec_normal_distribution_and_squared_loss`中描述的平方损失,而是使用$L_1$范数损失,即预测值和真实值之差的绝对值。\n", "掩码变量`bbox_masks`令负类锚框和填充锚框不参与损失的计算。\n", "最后,我们将锚框类别和偏移量的损失相加,以获得模型的最终损失函数。\n" ] }, { "cell_type": "code", "execution_count": 16, "id": "37ad81e9", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:19.142054Z", "iopub.status.busy": "2023-08-18T07:17:19.141508Z", "iopub.status.idle": "2023-08-18T07:17:19.147193Z", "shell.execute_reply": "2023-08-18T07:17:19.146422Z" }, "origin_pos": 56, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "cls_loss = nn.CrossEntropyLoss(reduction='none')\n", "bbox_loss = nn.L1Loss(reduction='none')\n", "\n", "def calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks):\n", " batch_size, num_classes = cls_preds.shape[0], cls_preds.shape[2]\n", " cls = cls_loss(cls_preds.reshape(-1, num_classes),\n", " cls_labels.reshape(-1)).reshape(batch_size, -1).mean(dim=1)\n", " bbox = bbox_loss(bbox_preds * bbox_masks,\n", " bbox_labels * bbox_masks).mean(dim=1)\n", " return cls + bbox" ] }, { "cell_type": "markdown", "id": "e0d8d1ce", "metadata": { "origin_pos": 58 }, "source": [ "我们可以沿用准确率评价分类结果。\n", "由于偏移量使用了$L_1$范数损失,我们使用*平均绝对误差*来评价边界框的预测结果。这些预测结果是从生成的锚框及其预测偏移量中获得的。\n" ] }, { "cell_type": "code", "execution_count": 17, "id": "df7b0fae", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:19.151032Z", "iopub.status.busy": "2023-08-18T07:17:19.150178Z", "iopub.status.idle": "2023-08-18T07:17:19.156061Z", "shell.execute_reply": "2023-08-18T07:17:19.154909Z" }, "origin_pos": 60, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def cls_eval(cls_preds, cls_labels):\n", " # 由于类别预测结果放在最后一维,argmax需要指定最后一维。\n", " return float((cls_preds.argmax(dim=-1).type(\n", " cls_labels.dtype) == cls_labels).sum())\n", "\n", "def bbox_eval(bbox_preds, bbox_labels, bbox_masks):\n", " return float((torch.abs((bbox_labels - bbox_preds) * bbox_masks)).sum())" ] }, { "cell_type": "markdown", "id": "c7253cca", "metadata": { "origin_pos": 62 }, "source": [ "### [**训练模型**]\n", "\n", "在训练模型时,我们需要在模型的前向传播过程中生成多尺度锚框(`anchors`),并预测其类别(`cls_preds`)和偏移量(`bbox_preds`)。\n", "然后,我们根据标签信息`Y`为生成的锚框标记类别(`cls_labels`)和偏移量(`bbox_labels`)。\n", "最后,我们根据类别和偏移量的预测和标注值计算损失函数。为了代码简洁,这里没有评价测试数据集。\n" ] }, { "cell_type": "code", "execution_count": 18, "id": "6e08a9c2", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:19.159726Z", "iopub.status.busy": "2023-08-18T07:17:19.159190Z", "iopub.status.idle": "2023-08-18T07:18:50.670936Z", "shell.execute_reply": "2023-08-18T07:18:50.670107Z" }, "origin_pos": 64, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "class err 3.22e-03, bbox mae 3.16e-03\n", "3353.9 examples/sec on cuda:0\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:18:50.635153\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "num_epochs, timer = 20, d2l.Timer()\n", "animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],\n", " legend=['class error', 'bbox mae'])\n", "net = net.to(device)\n", "for epoch in range(num_epochs):\n", " # 训练精确度的和,训练精确度的和中的示例数\n", " # 绝对误差的和,绝对误差的和中的示例数\n", " metric = d2l.Accumulator(4)\n", " net.train()\n", " for features, target in train_iter:\n", " timer.start()\n", " trainer.zero_grad()\n", " X, Y = features.to(device), target.to(device)\n", " # 生成多尺度的锚框,为每个锚框预测类别和偏移量\n", " anchors, cls_preds, bbox_preds = net(X)\n", " # 为每个锚框标注类别和偏移量\n", " bbox_labels, bbox_masks, cls_labels = d2l.multibox_target(anchors, Y)\n", " # 根据类别和偏移量的预测和标注值计算损失函数\n", " l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,\n", " bbox_masks)\n", " l.mean().backward()\n", " trainer.step()\n", " metric.add(cls_eval(cls_preds, cls_labels), cls_labels.numel(),\n", " bbox_eval(bbox_preds, bbox_labels, bbox_masks),\n", " bbox_labels.numel())\n", " cls_err, bbox_mae = 1 - metric[0] / metric[1], metric[2] / metric[3]\n", " animator.add(epoch + 1, (cls_err, bbox_mae))\n", "print(f'class err {cls_err:.2e}, bbox mae {bbox_mae:.2e}')\n", "print(f'{len(train_iter.dataset) / timer.stop():.1f} examples/sec on '\n", " f'{str(device)}')" ] }, { "cell_type": "markdown", "id": "9e6b06a4", "metadata": { "origin_pos": 66 }, "source": [ "## [**预测目标**]\n", "\n", "在预测阶段,我们希望能把图像里面所有我们感兴趣的目标检测出来。在下面,我们读取并调整测试图像的大小,然后将其转成卷积层需要的四维格式。\n" ] }, { "cell_type": "code", "execution_count": 19, "id": "290676d2", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:18:50.674365Z", "iopub.status.busy": "2023-08-18T07:18:50.674082Z", "iopub.status.idle": "2023-08-18T07:18:50.679714Z", "shell.execute_reply": "2023-08-18T07:18:50.678773Z" }, "origin_pos": 68, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "X = torchvision.io.read_image('../img/banana.jpg').unsqueeze(0).float()\n", "img = X.squeeze(0).permute(1, 2, 0).long()" ] }, { "cell_type": "markdown", "id": "fb32477b", "metadata": { "origin_pos": 70 }, "source": [ "使用下面的`multibox_detection`函数,我们可以根据锚框及其预测偏移量得到预测边界框。然后,通过非极大值抑制来移除相似的预测边界框。\n" ] }, { "cell_type": "code", "execution_count": 20, "id": "1da1d7a3", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:18:50.683817Z", "iopub.status.busy": "2023-08-18T07:18:50.683541Z", "iopub.status.idle": "2023-08-18T07:18:51.223526Z", "shell.execute_reply": "2023-08-18T07:18:51.222573Z" }, "origin_pos": 72, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def predict(X):\n", " net.eval()\n", " anchors, cls_preds, bbox_preds = net(X.to(device))\n", " cls_probs = F.softmax(cls_preds, dim=2).permute(0, 2, 1)\n", " output = d2l.multibox_detection(cls_probs, bbox_preds, anchors)\n", " idx = [i for i, row in enumerate(output[0]) if row[0] != -1]\n", " return output[0, idx]\n", "\n", "output = predict(X)" ] }, { "cell_type": "markdown", "id": "8f3c20db", "metadata": { "origin_pos": 74 }, "source": [ "最后,我们[**筛选所有置信度不低于0.9的边界框,做为最终输出**]。\n" ] }, { "cell_type": "code", "execution_count": 21, "id": "ddfce357", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:18:51.227738Z", "iopub.status.busy": "2023-08-18T07:18:51.227134Z", "iopub.status.idle": "2023-08-18T07:18:51.460658Z", "shell.execute_reply": "2023-08-18T07:18:51.459830Z" }, "origin_pos": 76, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:18:51.364050\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def display(img, output, threshold):\n", " d2l.set_figsize((5, 5))\n", " fig = d2l.plt.imshow(img)\n", " for row in output:\n", " score = float(row[1])\n", " if score < threshold:\n", " continue\n", " h, w = img.shape[0:2]\n", " bbox = [row[2:6] * torch.tensor((w, h, w, h), device=row.device)]\n", " d2l.show_bboxes(fig.axes, bbox, '%.2f' % score, 'w')\n", "\n", "display(img, output.cpu(), threshold=0.9)" ] }, { "cell_type": "markdown", "id": "133fb05f", "metadata": { "origin_pos": 78 }, "source": [ "## 小结\n", "\n", "* 单发多框检测是一种多尺度目标检测模型。基于基础网络块和各个多尺度特征块,单发多框检测生成不同数量和不同大小的锚框,并通过预测这些锚框的类别和偏移量检测不同大小的目标。\n", "* 在训练单发多框检测模型时,损失函数是根据锚框的类别和偏移量的预测及标注值计算得出的。\n", "\n", "## 练习\n", "\n", "1. 能通过改进损失函数来改进单发多框检测吗?例如,将预测偏移量用到的$L_1$范数损失替换为平滑$L_1$范数损失。它在零点附近使用平方函数从而更加平滑,这是通过一个超参数$\\sigma$来控制平滑区域的:\n", "\n", "$$\n", "f(x) =\n", " \\begin{cases}\n", " (\\sigma x)^2/2,& \\text{if }|x| < 1/\\sigma^2\\\\\n", " |x|-0.5/\\sigma^2,& \\text{otherwise}\n", " \\end{cases}\n", "$$\n", "\n", "当$\\sigma$非常大时,这种损失类似于$L_1$范数损失。当它的值较小时,损失函数较平滑。\n" ] }, { "cell_type": "code", "execution_count": 22, "id": "5ff25a9d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:18:51.466346Z", "iopub.status.busy": "2023-08-18T07:18:51.465756Z", "iopub.status.idle": "2023-08-18T07:18:51.636211Z", "shell.execute_reply": "2023-08-18T07:18:51.635398Z" }, "origin_pos": 80, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:18:51.596999\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def smooth_l1(data, scalar):\n", " out = []\n", " for i in data:\n", " if abs(i) < 1 / (scalar ** 2):\n", " out.append(((scalar * i) ** 2) / 2)\n", " else:\n", " out.append(abs(i) - 0.5 / (scalar ** 2))\n", " return torch.tensor(out)\n", "\n", "sigmas = [10, 1, 0.5]\n", "lines = ['-', '--', '-.']\n", "x = torch.arange(-2, 2, 0.1)\n", "d2l.set_figsize()\n", "\n", "for l, s in zip(lines, sigmas):\n", " y = smooth_l1(x, scalar=s)\n", " d2l.plt.plot(x, y, l, label='sigma=%.1f' % s)\n", "d2l.plt.legend();" ] }, { "cell_type": "markdown", "id": "96f72911", "metadata": { "origin_pos": 82 }, "source": [ "此外,在类别预测时,实验中使用了交叉熵损失:设真实类别$j$的预测概率是$p_j$,交叉熵损失为$-\\log p_j$。我们还可以使用焦点损失 :cite:`Lin.Goyal.Girshick.ea.2017`。给定超参数$\\gamma > 0$和$\\alpha > 0$,此损失的定义为:\n", "\n", "$$ - \\alpha (1-p_j)^{\\gamma} \\log p_j.$$\n", "\n", "可以看到,增大$\\gamma$可以有效地减少正类预测概率较大时(例如$p_j > 0.5$)的相对损失,因此训练可以更集中在那些错误分类的困难示例上。\n" ] }, { "cell_type": "code", "execution_count": 23, "id": "3a5a6fea", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:18:51.639973Z", "iopub.status.busy": "2023-08-18T07:18:51.639373Z", "iopub.status.idle": "2023-08-18T07:18:51.882393Z", "shell.execute_reply": "2023-08-18T07:18:51.881561Z" }, "origin_pos": 84, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:18:51.844535\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def focal_loss(gamma, x):\n", " return -(1 - x) ** gamma * torch.log(x)\n", "\n", "x = torch.arange(0.01, 1, 0.01)\n", "for l, gamma in zip(lines, [0, 1, 5]):\n", " y = d2l.plt.plot(x, focal_loss(gamma, x), l, label='gamma=%.1f' % gamma)\n", "d2l.plt.legend();" ] }, { "cell_type": "markdown", "id": "fd10bf15", "metadata": { "origin_pos": 86 }, "source": [ "2. 由于篇幅限制,我们在本节中省略了单发多框检测模型的一些实现细节。能否从以下几个方面进一步改进模型:\n", " 1. 当目标比图像小得多时,模型可以将输入图像调大;\n", " 1. 通常会存在大量的负锚框。为了使类别分布更加平衡,我们可以将负锚框的高和宽减半;\n", " 1. 在损失函数中,给类别损失和偏移损失设置不同比重的超参数;\n", " 1. 使用其他方法评估目标检测模型,例如单发多框检测论文 :cite:`Liu.Anguelov.Erhan.ea.2016`中的方法。\n" ] }, { "cell_type": "markdown", "id": "03e00148", "metadata": { "origin_pos": 88, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/3204)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }