{ "cells": [ { "cell_type": "markdown", "id": "403799f7", "metadata": { "origin_pos": 0 }, "source": [ "# 锚框\n", ":label:`sec_anchor`\n", "\n", "目标检测算法通常会在输入图像中采样大量的区域,然后判断这些区域中是否包含我们感兴趣的目标,并调整区域边界从而更准确地预测目标的*真实边界框*(ground-truth bounding box)。\n", "不同的模型使用的区域采样方法可能不同。\n", "这里我们介绍其中的一种方法:以每个像素为中心,生成多个缩放比和宽高比(aspect ratio)不同的边界框。\n", "这些边界框被称为*锚框*(anchor box)我们将在 :numref:`sec_ssd`中设计一个基于锚框的目标检测模型。\n", "\n", "首先,让我们修改输出精度,以获得更简洁的输出。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "e079962e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:42.677769Z", "iopub.status.busy": "2023-08-18T07:00:42.676695Z", "iopub.status.idle": "2023-08-18T07:00:45.106116Z", "shell.execute_reply": "2023-08-18T07:00:45.104773Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "import torch\n", "from d2l import torch as d2l\n", "\n", "torch.set_printoptions(2) # 精简输出精度" ] }, { "cell_type": "markdown", "id": "dc5b92b8", "metadata": { "origin_pos": 4 }, "source": [ "## 生成多个锚框\n", "\n", "假设输入图像的高度为$h$,宽度为$w$。\n", "我们以图像的每个像素为中心生成不同形状的锚框:*缩放比*为$s\\in (0, 1]$,*宽高比*为$r > 0$。\n", "那么[**锚框的宽度和高度分别是$hs\\sqrt{r}$和$hs/\\sqrt{r}$。**]\n", "请注意,当中心位置给定时,已知宽和高的锚框是确定的。\n", "\n", "要生成多个不同形状的锚框,让我们设置许多缩放比(scale)取值$s_1,\\ldots, s_n$和许多宽高比(aspect ratio)取值$r_1,\\ldots, r_m$。\n", "当使用这些比例和长宽比的所有组合以每个像素为中心时,输入图像将总共有$whnm$个锚框。\n", "尽管这些锚框可能会覆盖所有真实边界框,但计算复杂性很容易过高。\n", "在实践中,(**我们只考虑**)包含$s_1$或$r_1$的(**组合:**)\n", "\n", "(**\n", "$$(s_1, r_1), (s_1, r_2), \\ldots, (s_1, r_m), (s_2, r_1), (s_3, r_1), \\ldots, (s_n, r_1).$$\n", "**)\n", "\n", "也就是说,以同一像素为中心的锚框的数量是$n+m-1$。\n", "对于整个输入图像,将共生成$wh(n+m-1)$个锚框。\n", "\n", "上述生成锚框的方法在下面的`multibox_prior`函数中实现。\n", "我们指定输入图像、尺寸列表和宽高比列表,然后此函数将返回所有的锚框。\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "4c5fb635", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:45.112186Z", "iopub.status.busy": "2023-08-18T07:00:45.111657Z", "iopub.status.idle": "2023-08-18T07:00:45.126939Z", "shell.execute_reply": "2023-08-18T07:00:45.125859Z" }, "origin_pos": 6, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def multibox_prior(data, sizes, ratios):\n", " \"\"\"生成以每个像素为中心具有不同形状的锚框\"\"\"\n", " in_height, in_width = data.shape[-2:]\n", " device, num_sizes, num_ratios = data.device, len(sizes), len(ratios)\n", " boxes_per_pixel = (num_sizes + num_ratios - 1)\n", " size_tensor = torch.tensor(sizes, device=device)\n", " ratio_tensor = torch.tensor(ratios, device=device)\n", "\n", " # 为了将锚点移动到像素的中心,需要设置偏移量。\n", " # 因为一个像素的高为1且宽为1,我们选择偏移我们的中心0.5\n", " offset_h, offset_w = 0.5, 0.5\n", " steps_h = 1.0 / in_height # 在y轴上缩放步长\n", " steps_w = 1.0 / in_width # 在x轴上缩放步长\n", "\n", " # 生成锚框的所有中心点\n", " center_h = (torch.arange(in_height, device=device) + offset_h) * steps_h\n", " center_w = (torch.arange(in_width, device=device) + offset_w) * steps_w\n", " shift_y, shift_x = torch.meshgrid(center_h, center_w, indexing='ij')\n", " shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1)\n", "\n", " # 生成“boxes_per_pixel”个高和宽,\n", " # 之后用于创建锚框的四角坐标(xmin,xmax,ymin,ymax)\n", " w = torch.cat((size_tensor * torch.sqrt(ratio_tensor[0]),\n", " sizes[0] * torch.sqrt(ratio_tensor[1:])))\\\n", " * in_height / in_width # 处理矩形输入\n", " h = torch.cat((size_tensor / torch.sqrt(ratio_tensor[0]),\n", " sizes[0] / torch.sqrt(ratio_tensor[1:])))\n", " # 除以2来获得半高和半宽\n", " anchor_manipulations = torch.stack((-w, -h, w, h)).T.repeat(\n", " in_height * in_width, 1) / 2\n", "\n", " # 每个中心点都将有“boxes_per_pixel”个锚框,\n", " # 所以生成含所有锚框中心的网格,重复了“boxes_per_pixel”次\n", " out_grid = torch.stack([shift_x, shift_y, shift_x, shift_y],\n", " dim=1).repeat_interleave(boxes_per_pixel, dim=0)\n", " output = out_grid + anchor_manipulations\n", " return output.unsqueeze(0)" ] }, { "cell_type": "markdown", "id": "8f03ef54", "metadata": { "origin_pos": 8 }, "source": [ "可以看到[**返回的锚框变量`Y`的形状**]是(批量大小,锚框的数量,4)。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "f411d4af", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:45.131714Z", "iopub.status.busy": "2023-08-18T07:00:45.131003Z", "iopub.status.idle": "2023-08-18T07:00:45.238891Z", "shell.execute_reply": "2023-08-18T07:00:45.237843Z" }, "origin_pos": 10, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "561 728\n" ] }, { "data": { "text/plain": [ "torch.Size([1, 2042040, 4])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img = d2l.plt.imread('../img/catdog.jpg')\n", "h, w = img.shape[:2]\n", "\n", "print(h, w)\n", "X = torch.rand(size=(1, 3, h, w))\n", "Y = multibox_prior(X, sizes=[0.75, 0.5, 0.25], ratios=[1, 2, 0.5])\n", "Y.shape" ] }, { "cell_type": "markdown", "id": "a368b5c5", "metadata": { "origin_pos": 12 }, "source": [ "将锚框变量`Y`的形状更改为(图像高度,图像宽度,以同一像素为中心的锚框的数量,4)后,我们可以获得以指定像素的位置为中心的所有锚框。\n", "在接下来的内容中,我们[**访问以(250,250)为中心的第一个锚框**]。\n", "它有四个元素:锚框左上角的$(x, y)$轴坐标和右下角的$(x, y)$轴坐标。\n", "输出中两个轴的坐标各分别除以了图像的宽度和高度。\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "a7b7cfa3", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:45.244522Z", "iopub.status.busy": "2023-08-18T07:00:45.243982Z", "iopub.status.idle": "2023-08-18T07:00:45.252916Z", "shell.execute_reply": "2023-08-18T07:00:45.251985Z" }, "origin_pos": 13, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([0.06, 0.07, 0.63, 0.82])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "boxes = Y.reshape(h, w, 5, 4)\n", "boxes[250, 250, 0, :]" ] }, { "cell_type": "markdown", "id": "eb941166", "metadata": { "origin_pos": 15 }, "source": [ "为了[**显示以图像中以某个像素为中心的所有锚框**],定义下面的`show_bboxes`函数来在图像上绘制多个边界框。\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "b2dc5400", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:45.257994Z", "iopub.status.busy": "2023-08-18T07:00:45.257297Z", "iopub.status.idle": "2023-08-18T07:00:45.267572Z", "shell.execute_reply": "2023-08-18T07:00:45.266638Z" }, "origin_pos": 16, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def show_bboxes(axes, bboxes, labels=None, colors=None):\n", " \"\"\"显示所有边界框\"\"\"\n", " def _make_list(obj, default_values=None):\n", " if obj is None:\n", " obj = default_values\n", " elif not isinstance(obj, (list, tuple)):\n", " obj = [obj]\n", " return obj\n", "\n", " labels = _make_list(labels)\n", " colors = _make_list(colors, ['b', 'g', 'r', 'm', 'c'])\n", " for i, bbox in enumerate(bboxes):\n", " color = colors[i % len(colors)]\n", " rect = d2l.bbox_to_rect(bbox.detach().numpy(), color)\n", " axes.add_patch(rect)\n", " if labels and len(labels) > i:\n", " text_color = 'k' if color == 'w' else 'w'\n", " axes.text(rect.xy[0], rect.xy[1], labels[i],\n", " va='center', ha='center', fontsize=9, color=text_color,\n", " bbox=dict(facecolor=color, lw=0))" ] }, { "cell_type": "markdown", "id": "0aefc961", "metadata": { "origin_pos": 17 }, "source": [ "正如从上面代码中所看到的,变量`boxes`中$x$轴和$y$轴的坐标值已分别除以图像的宽度和高度。\n", "绘制锚框时,我们需要恢复它们原始的坐标值。\n", "因此,在下面定义了变量`bbox_scale`。\n", "现在可以绘制出图像中所有以(250,250)为中心的锚框了。\n", "如下所示,缩放比为0.75且宽高比为1的蓝色锚框很好地围绕着图像中的狗。\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "c199e557", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:45.272415Z", "iopub.status.busy": "2023-08-18T07:00:45.271753Z", "iopub.status.idle": "2023-08-18T07:00:45.634073Z", "shell.execute_reply": "2023-08-18T07:00:45.632866Z" }, "origin_pos": 18, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:00:45.532795\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "d2l.set_figsize()\n", "bbox_scale = torch.tensor((w, h, w, h))\n", "fig = d2l.plt.imshow(img)\n", "show_bboxes(fig.axes, boxes[250, 250, :, :] * bbox_scale,\n", " ['s=0.75, r=1', 's=0.5, r=1', 's=0.25, r=1', 's=0.75, r=2',\n", " 's=0.75, r=0.5'])" ] }, { "cell_type": "markdown", "id": "da87f940", "metadata": { "origin_pos": 19 }, "source": [ "## [**交并比(IoU)**]\n", "\n", "我们刚刚提到某个锚框“较好地”覆盖了图像中的狗。\n", "如果已知目标的真实边界框,那么这里的“好”该如何如何量化呢?\n", "直观地说,可以衡量锚框和真实边界框之间的相似性。\n", "*杰卡德系数*(Jaccard)可以衡量两组之间的相似性。\n", "给定集合$\\mathcal{A}$和$\\mathcal{B}$,他们的杰卡德系数是他们交集的大小除以他们并集的大小:\n", "\n", "$$J(\\mathcal{A},\\mathcal{B}) = \\frac{\\left|\\mathcal{A} \\cap \\mathcal{B}\\right|}{\\left| \\mathcal{A} \\cup \\mathcal{B}\\right|}.$$\n", "\n", "事实上,我们可以将任何边界框的像素区域视为一组像素。通\n", "过这种方式,我们可以通过其像素集的杰卡德系数来测量两个边界框的相似性。\n", "对于两个边界框,它们的杰卡德系数通常称为*交并比*(intersection over union,IoU),即两个边界框相交面积与相并面积之比,如 :numref:`fig_iou`所示。\n", "交并比的取值范围在0和1之间:0表示两个边界框无重合像素,1表示两个边界框完全重合。\n", "\n", "![交并比是两个边界框相交面积与相并面积之比。](../img/iou.svg)\n", ":label:`fig_iou`\n", "\n", "接下来部分将使用交并比来衡量锚框和真实边界框之间、以及不同锚框之间的相似度。\n", "给定两个锚框或边界框的列表,以下`box_iou`函数将在这两个列表中计算它们成对的交并比。\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "feab924c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:45.639984Z", "iopub.status.busy": "2023-08-18T07:00:45.639243Z", "iopub.status.idle": "2023-08-18T07:00:45.649165Z", "shell.execute_reply": "2023-08-18T07:00:45.648025Z" }, "origin_pos": 21, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def box_iou(boxes1, boxes2):\n", " \"\"\"计算两个锚框或边界框列表中成对的交并比\"\"\"\n", " box_area = lambda boxes: ((boxes[:, 2] - boxes[:, 0]) *\n", " (boxes[:, 3] - boxes[:, 1]))\n", " # boxes1,boxes2,areas1,areas2的形状:\n", " # boxes1:(boxes1的数量,4),\n", " # boxes2:(boxes2的数量,4),\n", " # areas1:(boxes1的数量,),\n", " # areas2:(boxes2的数量,)\n", " areas1 = box_area(boxes1)\n", " areas2 = box_area(boxes2)\n", " # inter_upperlefts,inter_lowerrights,inters的形状:\n", " # (boxes1的数量,boxes2的数量,2)\n", " inter_upperlefts = torch.max(boxes1[:, None, :2], boxes2[:, :2])\n", " inter_lowerrights = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])\n", " inters = (inter_lowerrights - inter_upperlefts).clamp(min=0)\n", " # inter_areasandunion_areas的形状:(boxes1的数量,boxes2的数量)\n", " inter_areas = inters[:, :, 0] * inters[:, :, 1]\n", " union_areas = areas1[:, None] + areas2 - inter_areas\n", " return inter_areas / union_areas" ] }, { "cell_type": "markdown", "id": "aa0f6d1e", "metadata": { "origin_pos": 23 }, "source": [ "## 在训练数据中标注锚框\n", ":label:`subsec_labeling-anchor-boxes`\n", "\n", "在训练集中,我们将每个锚框视为一个训练样本。\n", "为了训练目标检测模型,我们需要每个锚框的*类别*(class)和*偏移量*(offset)标签,其中前者是与锚框相关的对象的类别,后者是真实边界框相对于锚框的偏移量。\n", "在预测时,我们为每个图像生成多个锚框,预测所有锚框的类别和偏移量,根据预测的偏移量调整它们的位置以获得预测的边界框,最后只输出符合特定条件的预测边界框。\n", "\n", "目标检测训练集带有*真实边界框*的位置及其包围物体类别的标签。\n", "要标记任何生成的锚框,我们可以参考分配到的最接近此锚框的真实边界框的位置和类别标签。\n", "下文将介绍一个算法,它能够把最接近的真实边界框分配给锚框。\n", "\n", "### [**将真实边界框分配给锚框**]\n", "\n", "给定图像,假设锚框是$A_1, A_2, \\ldots, A_{n_a}$,真实边界框是$B_1, B_2, \\ldots, B_{n_b}$,其中$n_a \\geq n_b$。\n", "让我们定义一个矩阵$\\mathbf{X} \\in \\mathbb{R}^{n_a \\times n_b}$,其中第$i$行、第$j$列的元素$x_{ij}$是锚框$A_i$和真实边界框$B_j$的IoU。\n", "该算法包含以下步骤。\n", "\n", "1. 在矩阵$\\mathbf{X}$中找到最大的元素,并将它的行索引和列索引分别表示为$i_1$和$j_1$。然后将真实边界框$B_{j_1}$分配给锚框$A_{i_1}$。这很直观,因为$A_{i_1}$和$B_{j_1}$是所有锚框和真实边界框配对中最相近的。在第一个分配完成后,丢弃矩阵中${i_1}^\\mathrm{th}$行和${j_1}^\\mathrm{th}$列中的所有元素。\n", "1. 在矩阵$\\mathbf{X}$中找到剩余元素中最大的元素,并将它的行索引和列索引分别表示为$i_2$和$j_2$。我们将真实边界框$B_{j_2}$分配给锚框$A_{i_2}$,并丢弃矩阵中${i_2}^\\mathrm{th}$行和${j_2}^\\mathrm{th}$列中的所有元素。\n", "1. 此时,矩阵$\\mathbf{X}$中两行和两列中的元素已被丢弃。我们继续,直到丢弃掉矩阵$\\mathbf{X}$中$n_b$列中的所有元素。此时已经为这$n_b$个锚框各自分配了一个真实边界框。\n", "1. 只遍历剩下的$n_a - n_b$个锚框。例如,给定任何锚框$A_i$,在矩阵$\\mathbf{X}$的第$i^\\mathrm{th}$行中找到与$A_i$的IoU最大的真实边界框$B_j$,只有当此IoU大于预定义的阈值时,才将$B_j$分配给$A_i$。\n", "\n", "下面用一个具体的例子来说明上述算法。\n", "如 :numref:`fig_anchor_label`(左)所示,假设矩阵$\\mathbf{X}$中的最大值为$x_{23}$,我们将真实边界框$B_3$分配给锚框$A_2$。\n", "然后,我们丢弃矩阵第2行和第3列中的所有元素,在剩余元素(阴影区域)中找到最大的$x_{71}$,然后将真实边界框$B_1$分配给锚框$A_7$。\n", "接下来,如 :numref:`fig_anchor_label`(中)所示,丢弃矩阵第7行和第1列中的所有元素,在剩余元素(阴影区域)中找到最大的$x_{54}$,然后将真实边界框$B_4$分配给锚框$A_5$。\n", "最后,如 :numref:`fig_anchor_label`(右)所示,丢弃矩阵第5行和第4列中的所有元素,在剩余元素(阴影区域)中找到最大的$x_{92}$,然后将真实边界框$B_2$分配给锚框$A_9$。\n", "之后,我们只需要遍历剩余的锚框$A_1, A_3, A_4, A_6, A_8$,然后根据阈值确定是否为它们分配真实边界框。\n", "\n", "![将真实边界框分配给锚框。](../img/anchor-label.svg)\n", ":label:`fig_anchor_label`\n", "\n", "此算法在下面的`assign_anchor_to_bbox`函数中实现。\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "8ac4113a", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:45.654100Z", "iopub.status.busy": "2023-08-18T07:00:45.653303Z", "iopub.status.idle": "2023-08-18T07:00:45.664045Z", "shell.execute_reply": "2023-08-18T07:00:45.663020Z" }, "origin_pos": 25, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def assign_anchor_to_bbox(ground_truth, anchors, device, iou_threshold=0.5):\n", " \"\"\"将最接近的真实边界框分配给锚框\"\"\"\n", " num_anchors, num_gt_boxes = anchors.shape[0], ground_truth.shape[0]\n", " # 位于第i行和第j列的元素x_ij是锚框i和真实边界框j的IoU\n", " jaccard = box_iou(anchors, ground_truth)\n", " # 对于每个锚框,分配的真实边界框的张量\n", " anchors_bbox_map = torch.full((num_anchors,), -1, dtype=torch.long,\n", " device=device)\n", " # 根据阈值,决定是否分配真实边界框\n", " max_ious, indices = torch.max(jaccard, dim=1)\n", " anc_i = torch.nonzero(max_ious >= iou_threshold).reshape(-1)\n", " box_j = indices[max_ious >= iou_threshold]\n", " anchors_bbox_map[anc_i] = box_j\n", " col_discard = torch.full((num_anchors,), -1)\n", " row_discard = torch.full((num_gt_boxes,), -1)\n", " for _ in range(num_gt_boxes):\n", " max_idx = torch.argmax(jaccard)\n", " box_idx = (max_idx % num_gt_boxes).long()\n", " anc_idx = (max_idx / num_gt_boxes).long()\n", " anchors_bbox_map[anc_idx] = box_idx\n", " jaccard[:, box_idx] = col_discard\n", " jaccard[anc_idx, :] = row_discard\n", " return anchors_bbox_map" ] }, { "cell_type": "markdown", "id": "f57512db", "metadata": { "origin_pos": 27 }, "source": [ "### 标记类别和偏移量\n", "\n", "现在我们可以为每个锚框标记类别和偏移量了。\n", "假设一个锚框$A$被分配了一个真实边界框$B$。\n", "一方面,锚框$A$的类别将被标记为与$B$相同。\n", "另一方面,锚框$A$的偏移量将根据$B$和$A$中心坐标的相对位置以及这两个框的相对大小进行标记。\n", "鉴于数据集内不同的框的位置和大小不同,我们可以对那些相对位置和大小应用变换,使其获得分布更均匀且易于拟合的偏移量。\n", "这里介绍一种常见的变换。\n", "[**给定框$A$和$B$,中心坐标分别为$(x_a, y_a)$和$(x_b, y_b)$,宽度分别为$w_a$和$w_b$,高度分别为$h_a$和$h_b$,可以将$A$的偏移量标记为:\n", "\n", "$$\\left( \\frac{ \\frac{x_b - x_a}{w_a} - \\mu_x }{\\sigma_x},\n", "\\frac{ \\frac{y_b - y_a}{h_a} - \\mu_y }{\\sigma_y},\n", "\\frac{ \\log \\frac{w_b}{w_a} - \\mu_w }{\\sigma_w},\n", "\\frac{ \\log \\frac{h_b}{h_a} - \\mu_h }{\\sigma_h}\\right),$$\n", "**]\n", "其中常量的默认值为 $\\mu_x = \\mu_y = \\mu_w = \\mu_h = 0, \\sigma_x=\\sigma_y=0.1$ , $\\sigma_w=\\sigma_h=0.2$。\n", "这种转换在下面的 `offset_boxes` 函数中实现。\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "a5b6bbdc", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:45.668909Z", "iopub.status.busy": "2023-08-18T07:00:45.668107Z", "iopub.status.idle": "2023-08-18T07:00:45.675569Z", "shell.execute_reply": "2023-08-18T07:00:45.674509Z" }, "origin_pos": 28, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def offset_boxes(anchors, assigned_bb, eps=1e-6):\n", " \"\"\"对锚框偏移量的转换\"\"\"\n", " c_anc = d2l.box_corner_to_center(anchors)\n", " c_assigned_bb = d2l.box_corner_to_center(assigned_bb)\n", " offset_xy = 10 * (c_assigned_bb[:, :2] - c_anc[:, :2]) / c_anc[:, 2:]\n", " offset_wh = 5 * torch.log(eps + c_assigned_bb[:, 2:] / c_anc[:, 2:])\n", " offset = torch.cat([offset_xy, offset_wh], axis=1)\n", " return offset" ] }, { "cell_type": "markdown", "id": "1e309f7d", "metadata": { "origin_pos": 29 }, "source": [ "如果一个锚框没有被分配真实边界框,我们只需将锚框的类别标记为*背景*(background)。\n", "背景类别的锚框通常被称为*负类*锚框,其余的被称为*正类*锚框。\n", "我们使用真实边界框(`labels`参数)实现以下`multibox_target`函数,来[**标记锚框的类别和偏移量**](`anchors`参数)。\n", "此函数将背景类别的索引设置为零,然后将新类别的整数索引递增一。\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "291738a2", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:45.680339Z", "iopub.status.busy": "2023-08-18T07:00:45.679651Z", "iopub.status.idle": "2023-08-18T07:00:45.692690Z", "shell.execute_reply": "2023-08-18T07:00:45.691647Z" }, "origin_pos": 31, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def multibox_target(anchors, labels):\n", " \"\"\"使用真实边界框标记锚框\"\"\"\n", " batch_size, anchors = labels.shape[0], anchors.squeeze(0)\n", " batch_offset, batch_mask, batch_class_labels = [], [], []\n", " device, num_anchors = anchors.device, anchors.shape[0]\n", " for i in range(batch_size):\n", " label = labels[i, :, :]\n", " anchors_bbox_map = assign_anchor_to_bbox(\n", " label[:, 1:], anchors, device)\n", " bbox_mask = ((anchors_bbox_map >= 0).float().unsqueeze(-1)).repeat(\n", " 1, 4)\n", " # 将类标签和分配的边界框坐标初始化为零\n", " class_labels = torch.zeros(num_anchors, dtype=torch.long,\n", " device=device)\n", " assigned_bb = torch.zeros((num_anchors, 4), dtype=torch.float32,\n", " device=device)\n", " # 使用真实边界框来标记锚框的类别。\n", " # 如果一个锚框没有被分配,标记其为背景(值为零)\n", " indices_true = torch.nonzero(anchors_bbox_map >= 0)\n", " bb_idx = anchors_bbox_map[indices_true]\n", " class_labels[indices_true] = label[bb_idx, 0].long() + 1\n", " assigned_bb[indices_true] = label[bb_idx, 1:]\n", " # 偏移量转换\n", " offset = offset_boxes(anchors, assigned_bb) * bbox_mask\n", " batch_offset.append(offset.reshape(-1))\n", " batch_mask.append(bbox_mask.reshape(-1))\n", " batch_class_labels.append(class_labels)\n", " bbox_offset = torch.stack(batch_offset)\n", " bbox_mask = torch.stack(batch_mask)\n", " class_labels = torch.stack(batch_class_labels)\n", " return (bbox_offset, bbox_mask, class_labels)" ] }, { "cell_type": "markdown", "id": "a78ead98", "metadata": { "origin_pos": 33 }, "source": [ "### 一个例子\n", "\n", "下面通过一个具体的例子来说明锚框标签。\n", "我们已经为加载图像中的狗和猫定义了真实边界框,其中第一个元素是类别(0代表狗,1代表猫),其余四个元素是左上角和右下角的$(x, y)$轴坐标(范围介于0和1之间)。\n", "我们还构建了五个锚框,用左上角和右下角的坐标进行标记:$A_0, \\ldots, A_4$(索引从0开始)。\n", "然后我们[**在图像中绘制这些真实边界框和锚框**]。\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "e22f46a6", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:45.697286Z", "iopub.status.busy": "2023-08-18T07:00:45.696846Z", "iopub.status.idle": "2023-08-18T07:00:46.054321Z", "shell.execute_reply": "2023-08-18T07:00:46.053210Z" }, "origin_pos": 34, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:00:45.953248\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "ground_truth = torch.tensor([[0, 0.1, 0.08, 0.52, 0.92],\n", " [1, 0.55, 0.2, 0.9, 0.88]])\n", "anchors = torch.tensor([[0, 0.1, 0.2, 0.3], [0.15, 0.2, 0.4, 0.4],\n", " [0.63, 0.05, 0.88, 0.98], [0.66, 0.45, 0.8, 0.8],\n", " [0.57, 0.3, 0.92, 0.9]])\n", "\n", "fig = d2l.plt.imshow(img)\n", "show_bboxes(fig.axes, ground_truth[:, 1:] * bbox_scale, ['dog', 'cat'], 'k')\n", "show_bboxes(fig.axes, anchors * bbox_scale, ['0', '1', '2', '3', '4']);" ] }, { "cell_type": "markdown", "id": "414c8154", "metadata": { "origin_pos": 35 }, "source": [ "使用上面定义的`multibox_target`函数,我们可以[**根据狗和猫的真实边界框,标注这些锚框的分类和偏移量**]。\n", "在这个例子中,背景、狗和猫的类索引分别为0、1和2。\n", "下面我们为锚框和真实边界框样本添加一个维度。\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "511fc2cf", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:46.059898Z", "iopub.status.busy": "2023-08-18T07:00:46.059146Z", "iopub.status.idle": "2023-08-18T07:00:46.067556Z", "shell.execute_reply": "2023-08-18T07:00:46.066157Z" }, "origin_pos": 37, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "labels = multibox_target(anchors.unsqueeze(dim=0),\n", " ground_truth.unsqueeze(dim=0))" ] }, { "cell_type": "markdown", "id": "5c2b454d", "metadata": { "origin_pos": 39 }, "source": [ "返回的结果中有三个元素,都是张量格式。第三个元素包含标记的输入锚框的类别。\n", "\n", "让我们根据图像中的锚框和真实边界框的位置来分析下面返回的类别标签。\n", "首先,在所有的锚框和真实边界框配对中,锚框$A_4$与猫的真实边界框的IoU是最大的。\n", "因此,$A_4$的类别被标记为猫。\n", "去除包含$A_4$或猫的真实边界框的配对,在剩下的配对中,锚框$A_1$和狗的真实边界框有最大的IoU。\n", "因此,$A_1$的类别被标记为狗。\n", "接下来,我们需要遍历剩下的三个未标记的锚框:$A_0$、$A_2$和$A_3$。\n", "对于$A_0$,与其拥有最大IoU的真实边界框的类别是狗,但IoU低于预定义的阈值(0.5),因此该类别被标记为背景;\n", "对于$A_2$,与其拥有最大IoU的真实边界框的类别是猫,IoU超过阈值,所以类别被标记为猫;\n", "对于$A_3$,与其拥有最大IoU的真实边界框的类别是猫,但值低于阈值,因此该类别被标记为背景。\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "78a2e091", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:46.075902Z", "iopub.status.busy": "2023-08-18T07:00:46.075000Z", "iopub.status.idle": "2023-08-18T07:00:46.084256Z", "shell.execute_reply": "2023-08-18T07:00:46.083114Z" }, "origin_pos": 40, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([[0, 1, 2, 0, 2]])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labels[2]" ] }, { "cell_type": "markdown", "id": "cffb366b", "metadata": { "origin_pos": 41 }, "source": [ "返回的第二个元素是掩码(mask)变量,形状为(批量大小,锚框数的四倍)。\n", "掩码变量中的元素与每个锚框的4个偏移量一一对应。\n", "由于我们不关心对背景的检测,负类的偏移量不应影响目标函数。\n", "通过元素乘法,掩码变量中的零将在计算目标函数之前过滤掉负类偏移量。\n" ] }, { "cell_type": "code", "execution_count": 14, "id": "94a1597f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:46.097601Z", "iopub.status.busy": "2023-08-18T07:00:46.097228Z", "iopub.status.idle": "2023-08-18T07:00:46.104883Z", "shell.execute_reply": "2023-08-18T07:00:46.103744Z" }, "origin_pos": 42, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([[0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 1., 1.,\n", " 1., 1.]])" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labels[1]" ] }, { "cell_type": "markdown", "id": "1cfc7bb3", "metadata": { "origin_pos": 43 }, "source": [ "返回的第一个元素包含了为每个锚框标记的四个偏移值。\n", "请注意,负类锚框的偏移量被标记为零。\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "25e7f69b", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:46.112301Z", "iopub.status.busy": "2023-08-18T07:00:46.111934Z", "iopub.status.idle": "2023-08-18T07:00:46.118802Z", "shell.execute_reply": "2023-08-18T07:00:46.117910Z" }, "origin_pos": 44, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.00e+00, -0.00e+00, -0.00e+00, -0.00e+00, 1.40e+00, 1.00e+01,\n", " 2.59e+00, 7.18e+00, -1.20e+00, 2.69e-01, 1.68e+00, -1.57e+00,\n", " -0.00e+00, -0.00e+00, -0.00e+00, -0.00e+00, -5.71e-01, -1.00e+00,\n", " 4.17e-06, 6.26e-01]])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labels[0]" ] }, { "cell_type": "markdown", "id": "ada70fbf", "metadata": { "origin_pos": 45 }, "source": [ "## 使用非极大值抑制预测边界框\n", ":label:`subsec_predicting-bounding-boxes-nms`\n", "\n", "在预测时,我们先为图像生成多个锚框,再为这些锚框一一预测类别和偏移量。\n", "一个*预测好的边界框*则根据其中某个带有预测偏移量的锚框而生成。\n", "下面我们实现了`offset_inverse`函数,该函数将锚框和偏移量预测作为输入,并[**应用逆偏移变换来返回预测的边界框坐标**]。\n" ] }, { "cell_type": "code", "execution_count": 16, "id": "227ae1a8", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:46.127370Z", "iopub.status.busy": "2023-08-18T07:00:46.127014Z", "iopub.status.idle": "2023-08-18T07:00:46.133873Z", "shell.execute_reply": "2023-08-18T07:00:46.132925Z" }, "origin_pos": 46, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def offset_inverse(anchors, offset_preds):\n", " \"\"\"根据带有预测偏移量的锚框来预测边界框\"\"\"\n", " anc = d2l.box_corner_to_center(anchors)\n", " pred_bbox_xy = (offset_preds[:, :2] * anc[:, 2:] / 10) + anc[:, :2]\n", " pred_bbox_wh = torch.exp(offset_preds[:, 2:] / 5) * anc[:, 2:]\n", " pred_bbox = torch.cat((pred_bbox_xy, pred_bbox_wh), axis=1)\n", " predicted_bbox = d2l.box_center_to_corner(pred_bbox)\n", " return predicted_bbox" ] }, { "cell_type": "markdown", "id": "c8a5c307", "metadata": { "origin_pos": 47 }, "source": [ "当有许多锚框时,可能会输出许多相似的具有明显重叠的预测边界框,都围绕着同一目标。\n", "为了简化输出,我们可以使用*非极大值抑制*(non-maximum suppression,NMS)合并属于同一目标的类似的预测边界框。\n", "\n", "以下是非极大值抑制的工作原理。\n", "对于一个预测边界框$B$,目标检测模型会计算每个类别的预测概率。\n", "假设最大的预测概率为$p$,则该概率所对应的类别$B$即为预测的类别。\n", "具体来说,我们将$p$称为预测边界框$B$的*置信度*(confidence)。\n", "在同一张图像中,所有预测的非背景边界框都按置信度降序排序,以生成列表$L$。然后我们通过以下步骤操作排序列表$L$。\n", "\n", "1. 从$L$中选取置信度最高的预测边界框$B_1$作为基准,然后将所有与$B_1$的IoU超过预定阈值$\\epsilon$的非基准预测边界框从$L$中移除。这时,$L$保留了置信度最高的预测边界框,去除了与其太过相似的其他预测边界框。简而言之,那些具有*非极大值*置信度的边界框被*抑制*了。\n", "1. 从$L$中选取置信度第二高的预测边界框$B_2$作为又一个基准,然后将所有与$B_2$的IoU大于$\\epsilon$的非基准预测边界框从$L$中移除。\n", "1. 重复上述过程,直到$L$中的所有预测边界框都曾被用作基准。此时,$L$中任意一对预测边界框的IoU都小于阈值$\\epsilon$;因此,没有一对边界框过于相似。\n", "1. 输出列表$L$中的所有预测边界框。\n", "\n", "[**以下`nms`函数按降序对置信度进行排序并返回其索引**]。\n" ] }, { "cell_type": "code", "execution_count": 17, "id": "ac5c4e3c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:46.138652Z", "iopub.status.busy": "2023-08-18T07:00:46.138060Z", "iopub.status.idle": "2023-08-18T07:00:46.151360Z", "shell.execute_reply": "2023-08-18T07:00:46.150447Z" }, "origin_pos": 49, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def nms(boxes, scores, iou_threshold):\n", " \"\"\"对预测边界框的置信度进行排序\"\"\"\n", " B = torch.argsort(scores, dim=-1, descending=True)\n", " keep = [] # 保留预测边界框的指标\n", " while B.numel() > 0:\n", " i = B[0]\n", " keep.append(i)\n", " if B.numel() == 1: break\n", " iou = box_iou(boxes[i, :].reshape(-1, 4),\n", " boxes[B[1:], :].reshape(-1, 4)).reshape(-1)\n", " inds = torch.nonzero(iou <= iou_threshold).reshape(-1)\n", " B = B[inds + 1]\n", " return torch.tensor(keep, device=boxes.device)" ] }, { "cell_type": "markdown", "id": "a21e7812", "metadata": { "origin_pos": 51 }, "source": [ "我们定义以下`multibox_detection`函数来[**将非极大值抑制应用于预测边界框**]。\n", "这里的实现有点复杂,请不要担心。我们将在实现之后,马上用一个具体的例子来展示它是如何工作的。\n" ] }, { "cell_type": "code", "execution_count": 18, "id": "baa9f34f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:46.157117Z", "iopub.status.busy": "2023-08-18T07:00:46.156701Z", "iopub.status.idle": "2023-08-18T07:00:46.175104Z", "shell.execute_reply": "2023-08-18T07:00:46.174174Z" }, "origin_pos": 53, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def multibox_detection(cls_probs, offset_preds, anchors, nms_threshold=0.5,\n", " pos_threshold=0.009999999):\n", " \"\"\"使用非极大值抑制来预测边界框\"\"\"\n", " device, batch_size = cls_probs.device, cls_probs.shape[0]\n", " anchors = anchors.squeeze(0)\n", " num_classes, num_anchors = cls_probs.shape[1], cls_probs.shape[2]\n", " out = []\n", " for i in range(batch_size):\n", " cls_prob, offset_pred = cls_probs[i], offset_preds[i].reshape(-1, 4)\n", " conf, class_id = torch.max(cls_prob[1:], 0)\n", " predicted_bb = offset_inverse(anchors, offset_pred)\n", " keep = nms(predicted_bb, conf, nms_threshold)\n", "\n", " # 找到所有的non_keep索引,并将类设置为背景\n", " all_idx = torch.arange(num_anchors, dtype=torch.long, device=device)\n", " combined = torch.cat((keep, all_idx))\n", " uniques, counts = combined.unique(return_counts=True)\n", " non_keep = uniques[counts == 1]\n", " all_id_sorted = torch.cat((keep, non_keep))\n", " class_id[non_keep] = -1\n", " class_id = class_id[all_id_sorted]\n", " conf, predicted_bb = conf[all_id_sorted], predicted_bb[all_id_sorted]\n", " # pos_threshold是一个用于非背景预测的阈值\n", " below_min_idx = (conf < pos_threshold)\n", " class_id[below_min_idx] = -1\n", " conf[below_min_idx] = 1 - conf[below_min_idx]\n", " pred_info = torch.cat((class_id.unsqueeze(1),\n", " conf.unsqueeze(1),\n", " predicted_bb), dim=1)\n", " out.append(pred_info)\n", " return torch.stack(out)" ] }, { "cell_type": "markdown", "id": "3952b7dd", "metadata": { "origin_pos": 55 }, "source": [ "现在让我们[**将上述算法应用到一个带有四个锚框的具体示例中**]。\n", "为简单起见,我们假设预测的偏移量都是零,这意味着预测的边界框即是锚框。\n", "对于背景、狗和猫其中的每个类,我们还定义了它的预测概率。\n" ] }, { "cell_type": "code", "execution_count": 19, "id": "4654e2f7", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:46.179666Z", "iopub.status.busy": "2023-08-18T07:00:46.179298Z", "iopub.status.idle": "2023-08-18T07:00:46.188634Z", "shell.execute_reply": "2023-08-18T07:00:46.187737Z" }, "origin_pos": 56, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "anchors = torch.tensor([[0.1, 0.08, 0.52, 0.92], [0.08, 0.2, 0.56, 0.95],\n", " [0.15, 0.3, 0.62, 0.91], [0.55, 0.2, 0.9, 0.88]])\n", "offset_preds = torch.tensor([0] * anchors.numel())\n", "cls_probs = torch.tensor([[0] * 4, # 背景的预测概率\n", " [0.9, 0.8, 0.7, 0.1], # 狗的预测概率\n", " [0.1, 0.2, 0.3, 0.9]]) # 猫的预测概率" ] }, { "cell_type": "markdown", "id": "96915571", "metadata": { "origin_pos": 58 }, "source": [ "我们可以[**在图像上绘制这些预测边界框和置信度**]。\n" ] }, { "cell_type": "code", "execution_count": 20, "id": "b9619cba", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:46.193833Z", "iopub.status.busy": "2023-08-18T07:00:46.193198Z", "iopub.status.idle": "2023-08-18T07:00:46.435994Z", "shell.execute_reply": "2023-08-18T07:00:46.434923Z" }, "origin_pos": 59, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:00:46.369699\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig = d2l.plt.imshow(img)\n", "show_bboxes(fig.axes, anchors * bbox_scale,\n", " ['dog=0.9', 'dog=0.8', 'dog=0.7', 'cat=0.9'])" ] }, { "cell_type": "markdown", "id": "25b63f89", "metadata": { "origin_pos": 60 }, "source": [ "现在我们可以调用`multibox_detection`函数来执行非极大值抑制,其中阈值设置为0.5。\n", "请注意,我们在示例的张量输入中添加了维度。\n", "\n", "我们可以看到[**返回结果的形状是(批量大小,锚框的数量,6)**]。\n", "最内层维度中的六个元素提供了同一预测边界框的输出信息。\n", "第一个元素是预测的类索引,从0开始(0代表狗,1代表猫),值-1表示背景或在非极大值抑制中被移除了。\n", "第二个元素是预测的边界框的置信度。\n", "其余四个元素分别是预测边界框左上角和右下角的$(x, y)$轴坐标(范围介于0和1之间)。\n" ] }, { "cell_type": "code", "execution_count": 21, "id": "ab9c180c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:46.440560Z", "iopub.status.busy": "2023-08-18T07:00:46.439738Z", "iopub.status.idle": "2023-08-18T07:00:46.458973Z", "shell.execute_reply": "2023-08-18T07:00:46.457996Z" }, "origin_pos": 62, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([[[ 0.00, 0.90, 0.10, 0.08, 0.52, 0.92],\n", " [ 1.00, 0.90, 0.55, 0.20, 0.90, 0.88],\n", " [-1.00, 0.80, 0.08, 0.20, 0.56, 0.95],\n", " [-1.00, 0.70, 0.15, 0.30, 0.62, 0.91]]])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output = multibox_detection(cls_probs.unsqueeze(dim=0),\n", " offset_preds.unsqueeze(dim=0),\n", " anchors.unsqueeze(dim=0),\n", " nms_threshold=0.5)\n", "output" ] }, { "cell_type": "markdown", "id": "79cb8b53", "metadata": { "origin_pos": 64 }, "source": [ "删除-1类别(背景)的预测边界框后,我们可以[**输出由非极大值抑制保存的最终预测边界框**]。\n" ] }, { "cell_type": "code", "execution_count": 22, "id": "f1e04b3f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:46.496750Z", "iopub.status.busy": "2023-08-18T07:00:46.495866Z", "iopub.status.idle": "2023-08-18T07:00:46.753536Z", "shell.execute_reply": "2023-08-18T07:00:46.752302Z" }, "origin_pos": 65, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:00:46.686810\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig = d2l.plt.imshow(img)\n", "for i in output[0].detach().numpy():\n", " if i[0] == -1:\n", " continue\n", " label = ('dog=', 'cat=')[int(i[0])] + str(i[1])\n", " show_bboxes(fig.axes, [torch.tensor(i[2:]) * bbox_scale], label)" ] }, { "cell_type": "markdown", "id": "222c6994", "metadata": { "origin_pos": 66 }, "source": [ "实践中,在执行非极大值抑制前,我们甚至可以将置信度较低的预测边界框移除,从而减少此算法中的计算量。\n", "我们也可以对非极大值抑制的输出结果进行后处理。例如,只保留置信度更高的结果作为最终输出。\n", "\n", "## 小结\n", "\n", "* 我们以图像的每个像素为中心生成不同形状的锚框。\n", "* 交并比(IoU)也被称为杰卡德系数,用于衡量两个边界框的相似性。它是相交面积与相并面积的比率。\n", "* 在训练集中,我们需要给每个锚框两种类型的标签。一个是与锚框中目标检测的类别,另一个是锚框真实相对于边界框的偏移量。\n", "* 预测期间可以使用非极大值抑制(NMS)来移除类似的预测边界框,从而简化输出。\n", "\n", "## 练习\n", "\n", "1. 在`multibox_prior`函数中更改`sizes`和`ratios`的值。生成的锚框有什么变化?\n", "1. 构建并可视化两个IoU为0.5的边界框。它们是怎样重叠的?\n", "1. 在 :numref:`subsec_labeling-anchor-boxes`和 :numref:`subsec_predicting-bounding-boxes-nms`中修改变量`anchors`,结果如何变化?\n", "1. 非极大值抑制是一种贪心算法,它通过*移除*来抑制预测的边界框。是否存在一种可能,被移除的一些框实际上是有用的?如何修改这个算法来柔和地抑制?可以参考Soft-NMS :cite:`Bodla.Singh.Chellappa.ea.2017`。\n", "1. 如果非手动,非最大限度的抑制可以被学习吗?\n" ] }, { "cell_type": "markdown", "id": "e48b2d07", "metadata": { "origin_pos": 68, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/2946)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }