{ "cells": [ { "cell_type": "markdown", "id": "8ae526de", "metadata": { "origin_pos": 0 }, "source": [ "# The Image Classification Dataset\n", ":label:`sec_fashion_mnist`\n", "\n", "(**The MNIST dataset**) :cite:`LeCun.Bottou.Bengio.ea.1998`\n", "(**is one of the most widely used datasets in image classification, but it is too simple to serve as a benchmark.\n", "We will use the similar but more complex Fashion-MNIST dataset instead**) :cite:`Xiao.Rasul.Vollgraf.2017`.\n" ] },
{ "cell_type": "code", "execution_count": 1, "id": "716c9e45", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:32.130767Z", "iopub.status.busy": "2023-08-18T07:00:32.129861Z", "iopub.status.idle": "2023-08-18T07:00:34.258162Z", "shell.execute_reply": "2023-08-18T07:00:34.257055Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "import torch\n", "import torchvision\n", "from torch.utils import data\n", "from torchvision import transforms\n", "from d2l import torch as d2l\n", "\n", "d2l.use_svg_display()" ] },
{ "cell_type": "markdown", "id": "601c08d4", "metadata": { "origin_pos": 5 }, "source": [ "## Reading the Dataset\n", "\n", "We can [**download the Fashion-MNIST dataset and read it into memory using the framework's built-in functions**].\n" ] },
{ "cell_type": "code", "execution_count": 2, "id": "d8593555", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:34.264466Z", "iopub.status.busy": "2023-08-18T07:00:34.263710Z", "iopub.status.idle": "2023-08-18T07:00:34.378988Z", "shell.execute_reply": "2023-08-18T07:00:34.377831Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "# ToTensor converts the image data from PIL type to 32-bit floating point\n", "# and divides by 255 so that all pixel values lie between 0 and 1\n", "trans = transforms.ToTensor()\n", "mnist_train = torchvision.datasets.FashionMNIST(\n", "    root=\"../data\", train=True, transform=trans, download=True)\n", "mnist_test = torchvision.datasets.FashionMNIST(\n", "    root=\"../data\", train=False, transform=trans, download=True)" ] },
{ "cell_type": "markdown", "id": "3d25caa7", "metadata": { "origin_pos": 10 }, "source": [ "Fashion-MNIST consists of images from 10 categories,\n", "each represented by 6000 images in the *training dataset*\n", "and 1000 images in the *test dataset*.\n", "Consequently, the training set and the test set contain 60000 and 10000 images, respectively.\n", "The test dataset is not used for training; it is used only to evaluate model performance.\n" ] },
{ "cell_type": "code", "execution_count": 3, "id": "6db7fb8c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:34.384171Z", "iopub.status.busy": "2023-08-18T07:00:34.383782Z", "iopub.status.idle": "2023-08-18T07:00:34.391174Z", "shell.execute_reply": "2023-08-18T07:00:34.390176Z" }, "origin_pos": 11, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "(60000, 10000)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(mnist_train), len(mnist_test)" ] },
{ "cell_type": "markdown", "id": "534d543c", "metadata": { "origin_pos": 13 }, "source": [ "The height and width of each input image are both 28 pixels.\n", "The dataset consists of grayscale images, so the number of channels is 1.\n", "For brevity, throughout this book we write the shape of an image with a height of $h$ pixels and a width of $w$ pixels as $h \\times w$ or ($h$, $w$).\n" ] },
{ "cell_type": "code", "execution_count": 4, "id": "3c69c2c8", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:34.396338Z", "iopub.status.busy": "2023-08-18T07:00:34.395813Z", "iopub.status.idle": "2023-08-18T07:00:34.403276Z", "shell.execute_reply": "2023-08-18T07:00:34.402307Z" }, "origin_pos": 14, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 28, 28])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mnist_train[0][0].shape" ] },
{ "cell_type": "markdown", "id": "4eb34556", "metadata": { "origin_pos": 15 }, "source": [ "[~~Two utility functions to visualize the dataset~~]\n", "\n", "The 10 categories in Fashion-MNIST are t-shirt, trouser, pullover, dress, coat, sandal, shirt, sneaker, bag, and ankle boot.\n", "The following function converts numeric label indices into their text names.\n" ] },
{ "cell_type": "code", "execution_count": 5, "id": "fe9f8cfe", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:34.407798Z", "iopub.status.busy": "2023-08-18T07:00:34.407292Z", "iopub.status.idle": "2023-08-18T07:00:34.413948Z", "shell.execute_reply": "2023-08-18T07:00:34.412905Z" }, "origin_pos": 16, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def get_fashion_mnist_labels(labels):  #@save\n", "    \"\"\"Return text labels for the Fashion-MNIST dataset\"\"\"\n", "    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',\n", "                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']\n", "    return [text_labels[int(i)] for i in labels]" ] },
{ "cell_type": "markdown", "id": "1af6b85c", "metadata": { "origin_pos": 17 }, "source": [ "We can now create a function to visualize these samples.\n" ] },
{ "cell_type": "code", "execution_count": 6, "id": "12d8707e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:34.421351Z", "iopub.status.busy": "2023-08-18T07:00:34.420405Z", "iopub.status.idle": "2023-08-18T07:00:34.429911Z", "shell.execute_reply": "2023-08-18T07:00:34.428770Z" }, "origin_pos": 19, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save\n", "    \"\"\"Plot a list of images\"\"\"\n", "    figsize = (num_cols * scale, num_rows * scale)\n", "    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)\n", "    axes = axes.flatten()\n", "    for i, (ax, img) in enumerate(zip(axes, imgs)):\n", "        if torch.is_tensor(img):\n", "            # Image tensor\n", "            ax.imshow(img.numpy())\n", "        else:\n", "            # PIL image\n", "            ax.imshow(img)\n", "        ax.axes.get_xaxis().set_visible(False)\n", "        ax.axes.get_yaxis().set_visible(False)\n", "        if titles:\n", "            ax.set_title(titles[i])\n", "    return axes" ] },
{ "cell_type": "markdown", "id": "aea8d92e", "metadata": { "origin_pos": 21 }, "source": [ "Here are [**the images and their corresponding labels**] for the first few samples in the training dataset.\n" ] },
{ "cell_type": "code", "execution_count": 7, "id": "e7d37edd", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:34.435295Z", "iopub.status.busy": "2023-08-18T07:00:34.434562Z", "iopub.status.idle": "2023-08-18T07:00:35.484726Z", "shell.execute_reply": "2023-08-18T07:00:35.483779Z" }, "origin_pos": 23, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 
\n" ], "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))\n", "show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));" ] },
{ "cell_type": "markdown", "id": "8ffe4da3", "metadata": { "origin_pos": 26 }, "source": [ "## Reading a Minibatch\n", "\n", "To make reading the training and test sets easier, we use the built-in data iterator rather than creating one from scratch.\n", "Recall that at each iteration, the data loader [**reads a minibatch of data of size `batch_size`**].\n", "With the built-in data iterator we can also shuffle the samples randomly, so that minibatches are drawn without bias.\n" ] },
{ "cell_type": "code", "execution_count": 8, "id": "dcf11f71", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:35.493448Z", "iopub.status.busy": "2023-08-18T07:00:35.492606Z", "iopub.status.idle": "2023-08-18T07:00:35.498328Z", "shell.execute_reply": "2023-08-18T07:00:35.497372Z" }, "origin_pos": 28, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "batch_size = 256\n", "\n", "def get_dataloader_workers():  #@save\n", "    \"\"\"Use 4 processes to read the data\"\"\"\n", "    return 4\n", "\n", "train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,\n", "                             num_workers=get_dataloader_workers())" ] },
{ "cell_type": "markdown", "id": "f878f635", "metadata": { "origin_pos": 31 }, "source": [ "Let us look at the time it takes to read the training data.\n" ] },
{ "cell_type": "code", "execution_count": 9, "id": "8dc12e48", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:35.502439Z", "iopub.status.busy": "2023-08-18T07:00:35.501591Z", "iopub.status.idle": "2023-08-18T07:00:38.879964Z", "shell.execute_reply": "2023-08-18T07:00:38.878822Z" }, "origin_pos": 32, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "'3.37 sec'" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "timer = d2l.Timer()\n", "for X, y in train_iter:\n", "    continue\n", "f'{timer.stop():.2f} sec'" ] },
{ "cell_type": "markdown", "id": "0bd9a185", "metadata": { "origin_pos": 33 }, "source": [ "## Putting All Things Together\n", "\n", "Now we [**define the `load_data_fashion_mnist` function**], which obtains and reads the Fashion-MNIST dataset.\n", "It returns data iterators for the training set and the test set.\n", "In addition, it accepts an optional argument `resize` to resize the images to another shape.\n" ] },
{ "cell_type": "code", "execution_count": 10, "id": "423baf20", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:38.885979Z", "iopub.status.busy": "2023-08-18T07:00:38.885569Z", "iopub.status.idle": "2023-08-18T07:00:38.895158Z", "shell.execute_reply": "2023-08-18T07:00:38.894185Z" }, "origin_pos": 35, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def load_data_fashion_mnist(batch_size, resize=None):  #@save\n", "    \"\"\"Download the Fashion-MNIST dataset and then load it into memory\"\"\"\n", "    trans = [transforms.ToTensor()]\n", "    if resize:\n", "        trans.insert(0, transforms.Resize(resize))\n", "    trans = transforms.Compose(trans)\n", "    mnist_train = torchvision.datasets.FashionMNIST(\n", "        root=\"../data\", train=True, transform=trans, download=True)\n", "    mnist_test = torchvision.datasets.FashionMNIST(\n", "        root=\"../data\", train=False, transform=trans, download=True)\n", "    return (data.DataLoader(mnist_train, batch_size, shuffle=True,\n", "                            num_workers=get_dataloader_workers()),\n", "            data.DataLoader(mnist_test, batch_size, shuffle=False,\n", "                            num_workers=get_dataloader_workers()))" ] },
{ "cell_type": "markdown", "id": "79c2b84b", "metadata": { "origin_pos": 38 }, "source": [ "Below, we test the image resizing feature of the `load_data_fashion_mnist` function by specifying the `resize` argument.\n" ] },
{ "cell_type": "code", "execution_count": 11, "id": "0807e2a3", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:38.902559Z", "iopub.status.busy": "2023-08-18T07:00:38.900441Z", "iopub.status.idle": "2023-08-18T07:00:39.372670Z", "shell.execute_reply": "2023-08-18T07:00:39.371373Z" }, "origin_pos": 39, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64\n" ] } ], "source": [ "train_iter, test_iter = load_data_fashion_mnist(32, resize=64)\n", "for X, y in train_iter:\n", "    print(X.shape, X.dtype, y.shape, y.dtype)\n", "    break" ] },
{ "cell_type": "markdown", "id": "f435b06f", "metadata": { "origin_pos": 40 }, "source": [ "We are now ready to work with the Fashion-MNIST dataset in the sections that follow, where it is used to evaluate various classification algorithms.\n", "\n", "## Summary\n", "\n", "* Fashion-MNIST is an apparel classification dataset consisting of images from 10 categories. We will use this dataset in subsequent sections to evaluate various classification algorithms.\n", "* We write the shape of an image with a height of $h$ pixels and a width of $w$ pixels as $h \\times w$ or ($h$, $w$).\n", "* Data iterators are a key component for efficient performance. Rely on well-implemented data iterators that exploit high-performance computing to avoid slowing down the training loop.\n", "\n", "## Exercises\n", "\n", "1. Does reducing `batch_size` (for instance, to 1) affect the reading performance?\n", "1. The performance of the data iterator is important. Is the current implementation fast enough? Explore various options to improve it.\n", "1. Check out the framework's online API documentation. Which other datasets are available?\n" ] },
{ "cell_type": "markdown", "id": "a83d0dc0", "metadata": { "origin_pos": 42, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/1787)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }