{
"cells": [
{
"cell_type": "markdown",
"id": "8ae526de",
"metadata": {
"origin_pos": 0
},
"source": [
"# 图像分类数据集\n",
":label:`sec_fashion_mnist`\n",
"\n",
"(**MNIST数据集**) :cite:`LeCun.Bottou.Bengio.ea.1998`\n",
"(**是图像分类中广泛使用的数据集之一,但作为基准数据集过于简单。\n",
"我们将使用类似但更复杂的Fashion-MNIST数据集**) :cite:`Xiao.Rasul.Vollgraf.2017`。\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "716c9e45",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:32.130767Z",
"iopub.status.busy": "2023-08-18T07:00:32.129861Z",
"iopub.status.idle": "2023-08-18T07:00:34.258162Z",
"shell.execute_reply": "2023-08-18T07:00:34.257055Z"
},
"origin_pos": 2,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import torch\n",
"import torchvision\n",
"from torch.utils import data\n",
"from torchvision import transforms\n",
"from d2l import torch as d2l\n",
"\n",
"d2l.use_svg_display()"
]
},
{
"cell_type": "markdown",
"id": "601c08d4",
"metadata": {
"origin_pos": 5
},
"source": [
"## 读取数据集\n",
"\n",
"我们可以[**通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中**]。\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d8593555",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:34.264466Z",
"iopub.status.busy": "2023-08-18T07:00:34.263710Z",
"iopub.status.idle": "2023-08-18T07:00:34.378988Z",
"shell.execute_reply": "2023-08-18T07:00:34.377831Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,\n",
"# 并除以255使得所有像素的数值均在0~1之间\n",
"trans = transforms.ToTensor()\n",
"mnist_train = torchvision.datasets.FashionMNIST(\n",
" root=\"../data\", train=True, transform=trans, download=True)\n",
"mnist_test = torchvision.datasets.FashionMNIST(\n",
" root=\"../data\", train=False, transform=trans, download=True)"
]
},
{
"cell_type": "markdown",
"id": "3d25caa7",
"metadata": {
"origin_pos": 10
},
"source": [
"Fashion-MNIST由10个类别的图像组成,\n",
"每个类别由*训练数据集*(train dataset)中的6000张图像\n",
"和*测试数据集*(test dataset)中的1000张图像组成。\n",
"因此,训练集和测试集分别包含60000和10000张图像。\n",
"测试数据集不会用于训练,只用于评估模型性能。\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6db7fb8c",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:34.384171Z",
"iopub.status.busy": "2023-08-18T07:00:34.383782Z",
"iopub.status.idle": "2023-08-18T07:00:34.391174Z",
"shell.execute_reply": "2023-08-18T07:00:34.390176Z"
},
"origin_pos": 11,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"(60000, 10000)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(mnist_train), len(mnist_test)"
]
},
{
"cell_type": "markdown",
"id": "534d543c",
"metadata": {
"origin_pos": 13
},
"source": [
"每个输入图像的高度和宽度均为28像素。\n",
"数据集由灰度图像组成,其通道数为1。\n",
"为了简洁起见,本书将高度$h$像素、宽度$w$像素图像的形状记为$h \\times w$或($h$,$w$)。\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3c69c2c8",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:34.396338Z",
"iopub.status.busy": "2023-08-18T07:00:34.395813Z",
"iopub.status.idle": "2023-08-18T07:00:34.403276Z",
"shell.execute_reply": "2023-08-18T07:00:34.402307Z"
},
"origin_pos": 14,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 28, 28])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mnist_train[0][0].shape"
]
},
{
"cell_type": "markdown",
"id": "4eb34556",
"metadata": {
"origin_pos": 15
},
"source": [
"[~~两个可视化数据集的函数~~]\n",
"\n",
"Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。\n",
"以下函数用于在数字标签索引及其文本名称之间进行转换。\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "fe9f8cfe",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:34.407798Z",
"iopub.status.busy": "2023-08-18T07:00:34.407292Z",
"iopub.status.idle": "2023-08-18T07:00:34.413948Z",
"shell.execute_reply": "2023-08-18T07:00:34.412905Z"
},
"origin_pos": 16,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def get_fashion_mnist_labels(labels): #@save\n",
" \"\"\"返回Fashion-MNIST数据集的文本标签\"\"\"\n",
" text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',\n",
" 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']\n",
" return [text_labels[int(i)] for i in labels]"
]
},
{
"cell_type": "markdown",
"id": "1af6b85c",
"metadata": {
"origin_pos": 17
},
"source": [
"我们现在可以创建一个函数来可视化这些样本。\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "12d8707e",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:34.421351Z",
"iopub.status.busy": "2023-08-18T07:00:34.420405Z",
"iopub.status.idle": "2023-08-18T07:00:34.429911Z",
"shell.execute_reply": "2023-08-18T07:00:34.428770Z"
},
"origin_pos": 19,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save\n",
" \"\"\"绘制图像列表\"\"\"\n",
" figsize = (num_cols * scale, num_rows * scale)\n",
" _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)\n",
" axes = axes.flatten()\n",
" for i, (ax, img) in enumerate(zip(axes, imgs)):\n",
" if torch.is_tensor(img):\n",
" # 图片张量\n",
" ax.imshow(img.numpy())\n",
" else:\n",
" # PIL图片\n",
" ax.imshow(img)\n",
" ax.axes.get_xaxis().set_visible(False)\n",
" ax.axes.get_yaxis().set_visible(False)\n",
" if titles:\n",
" ax.set_title(titles[i])\n",
" return axes"
]
},
{
"cell_type": "markdown",
"id": "aea8d92e",
"metadata": {
"origin_pos": 21
},
"source": [
"以下是训练数据集中前[**几个样本的图像及其相应的标签**]。\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e7d37edd",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:34.435295Z",
"iopub.status.busy": "2023-08-18T07:00:34.434562Z",
"iopub.status.idle": "2023-08-18T07:00:35.484726Z",
"shell.execute_reply": "2023-08-18T07:00:35.483779Z"
},
"origin_pos": 23,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))\n",
"show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));"
]
},
{
"cell_type": "markdown",
"id": "8ffe4da3",
"metadata": {
"origin_pos": 26
},
"source": [
"## 读取小批量\n",
"\n",
"为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建。\n",
"回顾一下,在每次迭代中,数据加载器每次都会[**读取一小批量数据,大小为`batch_size`**]。\n",
"通过内置数据迭代器,我们可以随机打乱了所有样本,从而无偏见地读取小批量。\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "dcf11f71",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:35.493448Z",
"iopub.status.busy": "2023-08-18T07:00:35.492606Z",
"iopub.status.idle": "2023-08-18T07:00:35.498328Z",
"shell.execute_reply": "2023-08-18T07:00:35.497372Z"
},
"origin_pos": 28,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"batch_size = 256\n",
"\n",
"def get_dataloader_workers(): #@save\n",
" \"\"\"使用4个进程来读取数据\"\"\"\n",
" return 4\n",
"\n",
"train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,\n",
" num_workers=get_dataloader_workers())"
]
},
{
"cell_type": "markdown",
"id": "f878f635",
"metadata": {
"origin_pos": 31
},
"source": [
"我们看一下读取训练数据所需的时间。\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "8dc12e48",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:35.502439Z",
"iopub.status.busy": "2023-08-18T07:00:35.501591Z",
"iopub.status.idle": "2023-08-18T07:00:38.879964Z",
"shell.execute_reply": "2023-08-18T07:00:38.878822Z"
},
"origin_pos": 32,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"'3.37 sec'"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"timer = d2l.Timer()\n",
"for X, y in train_iter:\n",
" continue\n",
"f'{timer.stop():.2f} sec'"
]
},
{
"cell_type": "markdown",
"id": "0bd9a185",
"metadata": {
"origin_pos": 33
},
"source": [
"## 整合所有组件\n",
"\n",
"现在我们[**定义`load_data_fashion_mnist`函数**],用于获取和读取Fashion-MNIST数据集。\n",
"这个函数返回训练集和验证集的数据迭代器。\n",
"此外,这个函数还接受一个可选参数`resize`,用来将图像大小调整为另一种形状。\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "423baf20",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:38.885979Z",
"iopub.status.busy": "2023-08-18T07:00:38.885569Z",
"iopub.status.idle": "2023-08-18T07:00:38.895158Z",
"shell.execute_reply": "2023-08-18T07:00:38.894185Z"
},
"origin_pos": 35,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def load_data_fashion_mnist(batch_size, resize=None): #@save\n",
" \"\"\"下载Fashion-MNIST数据集,然后将其加载到内存中\"\"\"\n",
" trans = [transforms.ToTensor()]\n",
" if resize:\n",
" trans.insert(0, transforms.Resize(resize))\n",
" trans = transforms.Compose(trans)\n",
" mnist_train = torchvision.datasets.FashionMNIST(\n",
" root=\"../data\", train=True, transform=trans, download=True)\n",
" mnist_test = torchvision.datasets.FashionMNIST(\n",
" root=\"../data\", train=False, transform=trans, download=True)\n",
" return (data.DataLoader(mnist_train, batch_size, shuffle=True,\n",
" num_workers=get_dataloader_workers()),\n",
" data.DataLoader(mnist_test, batch_size, shuffle=False,\n",
" num_workers=get_dataloader_workers()))"
]
},
{
"cell_type": "markdown",
"id": "79c2b84b",
"metadata": {
"origin_pos": 38
},
"source": [
"下面,我们通过指定`resize`参数来测试`load_data_fashion_mnist`函数的图像大小调整功能。\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "0807e2a3",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:38.902559Z",
"iopub.status.busy": "2023-08-18T07:00:38.900441Z",
"iopub.status.idle": "2023-08-18T07:00:39.372670Z",
"shell.execute_reply": "2023-08-18T07:00:39.371373Z"
},
"origin_pos": 39,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64\n"
]
}
],
"source": [
"train_iter, test_iter = load_data_fashion_mnist(32, resize=64)\n",
"for X, y in train_iter:\n",
" print(X.shape, X.dtype, y.shape, y.dtype)\n",
" break"
]
},
{
"cell_type": "markdown",
"id": "f435b06f",
"metadata": {
"origin_pos": 40
},
"source": [
"我们现在已经准备好使用Fashion-MNIST数据集,便于下面的章节调用来评估各种分类算法。\n",
"\n",
"## 小结\n",
"\n",
"* Fashion-MNIST是一个服装分类数据集,由10个类别的图像组成。我们将在后续章节中使用此数据集来评估各种分类算法。\n",
"* 我们将高度$h$像素,宽度$w$像素图像的形状记为$h \\times w$或($h$,$w$)。\n",
"* 数据迭代器是获得更高性能的关键组件。依靠实现良好的数据迭代器,利用高性能计算来避免减慢训练过程。\n",
"\n",
"## 练习\n",
"\n",
"1. 减少`batch_size`(如减少到1)是否会影响读取性能?\n",
"1. 数据迭代器的性能非常重要。当前的实现足够快吗?探索各种选择来改进它。\n",
"1. 查阅框架的在线API文档。还有哪些其他数据集可用?\n"
]
},
{
"cell_type": "markdown",
"id": "a83d0dc0",
"metadata": {
"origin_pos": 42,
"tab": [
"pytorch"
]
},
"source": [
"[Discussions](https://discuss.d2l.ai/t/1787)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"required_libs": []
},
"nbformat": 4,
"nbformat_minor": 5
}