{ "cells": [ { "cell_type": "markdown", "id": "fec74aec", "metadata": { "origin_pos": 0 }, "source": [ "# 实战 Kaggle 比赛:图像分类 (CIFAR-10)\n", ":label:`sec_kaggle_cifar10`\n", "\n", "之前几节中,我们一直在使用深度学习框架的高级API直接获取张量格式的图像数据集。\n", "但是在实践中,图像数据集通常以图像文件的形式出现。\n", "本节将从原始图像文件开始,然后逐步组织、读取并将它们转换为张量格式。\n", "\n", "我们在 :numref:`sec_image_augmentation`中对CIFAR-10数据集做了一个实验。CIFAR-10是计算机视觉领域中的一个重要的数据集。\n", "本节将运用我们在前几节中学到的知识来参加CIFAR-10图像分类问题的Kaggle竞赛,(**比赛的网址是https://www.kaggle.com/c/cifar-10**)。\n", "\n", " :numref:`fig_kaggle_cifar10`显示了竞赛网站页面上的信息。\n", "为了能提交结果,首先需要注册一个Kaggle账户。\n", "\n", "![CIFAR-10 图像分类竞赛页面上的信息。竞赛用的数据集可通过点击“Data”选项卡获取。](../img/kaggle-cifar10.png)\n", ":width:`600px`\n", ":label:`fig_kaggle_cifar10`\n", "\n", "首先,导入竞赛所需的包和模块。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "f7b4fa3c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:02:12.123905Z", "iopub.status.busy": "2023-08-18T07:02:12.123323Z", "iopub.status.idle": "2023-08-18T07:02:14.203247Z", "shell.execute_reply": "2023-08-18T07:02:14.202358Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import collections\n", "import math\n", "import os\n", "import shutil\n", "import pandas as pd\n", "import torch\n", "import torchvision\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "11f6b9d7", "metadata": { "origin_pos": 4 }, "source": [ "## 获取并组织数据集\n", "\n", "比赛数据集分为训练集和测试集,其中训练集包含50000张、测试集包含300000张图像。\n", "在测试集中,10000张图像将被用于评估,而剩下的290000张图像将不会被进行评估,包含它们只是为了防止手动标记测试集并提交标记结果。\n", "两个数据集中的图像都是png格式,高度和宽度均为32像素并有三个颜色通道(RGB)。\n", "这些图片共涵盖10个类别:飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。\n", " :numref:`fig_kaggle_cifar10`的左上角显示了数据集中飞机、汽车和鸟类的一些图像。\n", "\n", "### 下载数据集\n", "\n", "登录Kaggle后,我们可以点击 :numref:`fig_kaggle_cifar10`中显示的CIFAR-10图像分类竞赛网页上的“Data”选项卡,然后单击“Download All”按钮下载数据集。\n", "在`../data`中解压下载的文件并在其中解压缩`train.7z`和`test.7z`后,在以下路径中可以找到整个数据集:\n", "\n", "* `../data/cifar-10/train/[1-50000].png`\n", "* `../data/cifar-10/test/[1-300000].png`\n", "* `../data/cifar-10/trainLabels.csv`\n", "* `../data/cifar-10/sampleSubmission.csv`\n", "\n", "`train`和`test`文件夹分别包含训练和测试图像,`trainLabels.csv`含有训练图像的标签,\n", "`sample_submission.csv`是提交文件的范例。\n", "\n", "为了便于入门,[**我们提供包含前1000个训练图像和5个随机测试图像的数据集的小规模样本**]。\n", "要使用Kaggle竞赛的完整数据集,需要将以下`demo`变量设置为`False`。\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "7ae59ae9", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:02:14.207394Z", "iopub.status.busy": "2023-08-18T07:02:14.207019Z", "iopub.status.idle": "2023-08-18T07:02:14.623037Z", "shell.execute_reply": "2023-08-18T07:02:14.622201Z" }, "origin_pos": 5, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading ../data/kaggle_cifar10_tiny.zip from http://d2l-data.s3-accelerate.amazonaws.com/kaggle_cifar10_tiny.zip...\n" ] } ], "source": [ "#@save\n", "d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',\n", " '2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')\n", "\n", "# 如果使用完整的Kaggle竞赛的数据集,设置demo为False\n", "demo = True\n", "\n", "if demo:\n", " data_dir = d2l.download_extract('cifar10_tiny')\n", "else:\n", " data_dir = '../data/cifar-10/'" ] }, { "cell_type": "markdown", "id": "56ef995f", "metadata": { "origin_pos": 6 }, "source": [ "### [**整理数据集**]\n", "\n", "我们需要整理数据集来训练和测试模型。\n", "首先,我们用以下函数读取CSV文件中的标签,它返回一个字典,该字典将文件名中不带扩展名的部分映射到其标签。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "24b4fdfb", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:02:14.627215Z", "iopub.status.busy": "2023-08-18T07:02:14.626653Z", "iopub.status.idle": "2023-08-18T07:02:14.634237Z", "shell.execute_reply": "2023-08-18T07:02:14.633299Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "# 训练样本 : 1000\n", "# 类别 : 10\n" ] } ], "source": [ "#@save\n", "def read_csv_labels(fname):\n", " \"\"\"读取fname来给标签字典返回一个文件名\"\"\"\n", " with open(fname, 'r') as f:\n", " # 跳过文件头行(列名)\n", " lines = f.readlines()[1:]\n", " tokens = [l.rstrip().split(',') for l in lines]\n", " return dict(((name, label) for name, label in tokens))\n", "\n", "labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))\n", "print('# 训练样本 :', len(labels))\n", "print('# 类别 :', len(set(labels.values())))" ] }, { "cell_type": "markdown", "id": "3359e04e", "metadata": { "origin_pos": 8 }, "source": [ "接下来,我们定义`reorg_train_valid`函数来[**将验证集从原始的训练集中拆分出来**]。\n", "此函数中的参数`valid_ratio`是验证集中的样本数与原始训练集中的样本数之比。\n", "更具体地说,令$n$等于样本最少的类别中的图像数量,而$r$是比率。\n", "验证集将为每个类别拆分出$\\max(\\lfloor nr\\rfloor,1)$张图像。\n", "让我们以`valid_ratio=0.1`为例,由于原始的训练集有50000张图像,因此`train_valid_test/train`路径中将有45000张图像用于训练,而剩下5000张图像将作为路径`train_valid_test/valid`中的验证集。\n", "组织数据集后,同类别的图像将被放置在同一文件夹下。\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "fbfbac4f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:02:14.637882Z", "iopub.status.busy": "2023-08-18T07:02:14.637608Z", "iopub.status.idle": "2023-08-18T07:02:14.644987Z", "shell.execute_reply": "2023-08-18T07:02:14.644218Z" }, "origin_pos": 9, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def copyfile(filename, target_dir):\n", " \"\"\"将文件复制到目标目录\"\"\"\n", " os.makedirs(target_dir, exist_ok=True)\n", " shutil.copy(filename, target_dir)\n", "\n", "#@save\n", "def reorg_train_valid(data_dir, labels, valid_ratio):\n", " \"\"\"将验证集从原始的训练集中拆分出来\"\"\"\n", " # 训练数据集中样本最少的类别中的样本数\n", " n = collections.Counter(labels.values()).most_common()[-1][1]\n", " # 验证集中每个类别的样本数\n", " n_valid_per_label = max(1, math.floor(n * valid_ratio))\n", " label_count = {}\n", " for train_file in os.listdir(os.path.join(data_dir, 'train')):\n", " label = labels[train_file.split('.')[0]]\n", " fname = os.path.join(data_dir, 'train', train_file)\n", " copyfile(fname, os.path.join(data_dir, 'train_valid_test',\n", " 'train_valid', label))\n", " if label not in label_count or label_count[label] < n_valid_per_label:\n", " copyfile(fname, os.path.join(data_dir, 'train_valid_test',\n", " 'valid', label))\n", " label_count[label] = label_count.get(label, 0) + 1\n", " else:\n", " copyfile(fname, os.path.join(data_dir, 'train_valid_test',\n", " 'train', label))\n", " return n_valid_per_label" ] }, { "cell_type": "markdown", "id": "4d89ac6f", "metadata": { "origin_pos": 10 }, "source": [ "下面的`reorg_test`函数用来[**在预测期间整理测试集,以方便读取**]。\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "9ad6d005", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:02:14.648388Z", "iopub.status.busy": "2023-08-18T07:02:14.648058Z", "iopub.status.idle": "2023-08-18T07:02:14.653344Z", "shell.execute_reply": "2023-08-18T07:02:14.652542Z" }, "origin_pos": 11, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def reorg_test(data_dir):\n", " \"\"\"在预测期间整理测试集,以方便读取\"\"\"\n", " for test_file in os.listdir(os.path.join(data_dir, 'test')):\n", " copyfile(os.path.join(data_dir, 'test', test_file),\n", " os.path.join(data_dir, 'train_valid_test', 'test',\n", " 'unknown'))" ] }, { "cell_type": "markdown", "id": "da9e6d5a", "metadata": { "origin_pos": 12 }, "source": [ "最后,我们使用一个函数来[**调用前面定义的函数**]`read_csv_labels`、`reorg_train_valid`和`reorg_test`。\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "37a42208", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:02:14.657456Z", "iopub.status.busy": "2023-08-18T07:02:14.656944Z", "iopub.status.idle": "2023-08-18T07:02:14.661187Z", "shell.execute_reply": "2023-08-18T07:02:14.660400Z" }, "origin_pos": 13, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def reorg_cifar10_data(data_dir, valid_ratio):\n", " labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))\n", " reorg_train_valid(data_dir, labels, valid_ratio)\n", " reorg_test(data_dir)" ] }, { "cell_type": "markdown", "id": "b62b33fa", "metadata": { "origin_pos": 14 }, "source": [ "在这里,我们只将样本数据集的批量大小设置为32。\n", "在实际训练和测试中,应该使用Kaggle竞赛的完整数据集,并将`batch_size`设置为更大的整数,例如128。\n", "我们将10%的训练样本作为调整超参数的验证集。\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "868eb62b", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:02:14.664878Z", "iopub.status.busy": "2023-08-18T07:02:14.664240Z", "iopub.status.idle": "2023-08-18T07:02:14.931508Z", "shell.execute_reply": "2023-08-18T07:02:14.930669Z" }, "origin_pos": 15, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "batch_size = 32 if demo else 128\n", "valid_ratio = 0.1\n", "reorg_cifar10_data(data_dir, valid_ratio)" ] }, { "cell_type": "markdown", "id": "31458131", "metadata": { "origin_pos": 16 }, "source": [ "## [**图像增广**]\n", "\n", "我们使用图像增广来解决过拟合的问题。例如在训练中,我们可以随机水平翻转图像。\n", "我们还可以对彩色图像的三个RGB通道执行标准化。\n", "下面,我们列出了其中一些可以调整的操作。\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "300ef249", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:02:14.935568Z", "iopub.status.busy": "2023-08-18T07:02:14.934993Z", "iopub.status.idle": "2023-08-18T07:02:14.940662Z", "shell.execute_reply": "2023-08-18T07:02:14.939875Z" }, "origin_pos": 18, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "transform_train = torchvision.transforms.Compose([\n", " # 在高度和宽度上将图像放大到40像素的正方形\n", " torchvision.transforms.Resize(40),\n", " # 随机裁剪出一个高度和宽度均为40像素的正方形图像,\n", " # 生成一个面积为原始图像面积0.64~1倍的小正方形,\n", " # 然后将其缩放为高度和宽度均为32像素的正方形\n", " torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0),\n", " ratio=(1.0, 1.0)),\n", " torchvision.transforms.RandomHorizontalFlip(),\n", " torchvision.transforms.ToTensor(),\n", " # 标准化图像的每个通道\n", " torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],\n", " [0.2023, 0.1994, 0.2010])])" ] }, { "cell_type": "markdown", "id": "694c31d1", "metadata": { "origin_pos": 20 }, "source": [ "在测试期间,我们只对图像执行标准化,以消除评估结果中的随机性。\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "6bd19592", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:02:14.945514Z", "iopub.status.busy": "2023-08-18T07:02:14.944858Z", "iopub.status.idle": "2023-08-18T07:02:14.949240Z", "shell.execute_reply": "2023-08-18T07:02:14.948438Z" }, "origin_pos": 22, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "transform_test = torchvision.transforms.Compose([\n", " torchvision.transforms.ToTensor(),\n", " torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],\n", " [0.2023, 0.1994, 0.2010])])" ] }, { "cell_type": "markdown", "id": "c06d3e36", "metadata": { "origin_pos": 24 }, "source": [ "## 读取数据集\n", "\n", "接下来,我们[**读取由原始图像组成的数据集**],每个样本都包括一张图片和一个标签。\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "c1815173", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:02:14.952392Z", "iopub.status.busy": "2023-08-18T07:02:14.952109Z", "iopub.status.idle": "2023-08-18T07:02:14.966044Z", "shell.execute_reply": "2023-08-18T07:02:14.965283Z" }, "origin_pos": 26, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(\n", " os.path.join(data_dir, 'train_valid_test', folder),\n", " transform=transform_train) for folder in ['train', 'train_valid']]\n", "\n", "valid_ds, test_ds = [torchvision.datasets.ImageFolder(\n", " os.path.join(data_dir, 'train_valid_test', folder),\n", " transform=transform_test) for folder in ['valid', 'test']]" ] }, { "cell_type": "markdown", "id": "f22934d1", "metadata": { "origin_pos": 28 }, "source": [ "在训练期间,我们需要[**指定上面定义的所有图像增广操作**]。\n", "当验证集在超参数调整过程中用于模型评估时,不应引入图像增广的随机性。\n", "在最终预测之前,我们根据训练集和验证集组合而成的训练模型进行训练,以充分利用所有标记的数据。\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "d9528fff", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:02:14.969294Z", "iopub.status.busy": "2023-08-18T07:02:14.968975Z", "iopub.status.idle": "2023-08-18T07:02:14.974247Z", "shell.execute_reply": "2023-08-18T07:02:14.973489Z" }, "origin_pos": 30, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "train_iter, train_valid_iter = [torch.utils.data.DataLoader(\n", " dataset, batch_size, shuffle=True, drop_last=True)\n", " for dataset in (train_ds, train_valid_ds)]\n", "\n", "valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False,\n", " drop_last=True)\n", "\n", "test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False,\n", " drop_last=False)" ] }, { "cell_type": "markdown", "id": "a40cd7e3", "metadata": { "origin_pos": 32 }, "source": [ "## 定义[**模型**]\n" ] }, { "cell_type": "markdown", "id": "8a4e0897", "metadata": { "origin_pos": 38, "tab": [ "pytorch" ] }, "source": [ "我们定义了 :numref:`sec_resnet`中描述的Resnet-18模型。\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "c35c8dca", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:02:14.977846Z", "iopub.status.busy": "2023-08-18T07:02:14.977285Z", "iopub.status.idle": "2023-08-18T07:02:14.981584Z", "shell.execute_reply": "2023-08-18T07:02:14.980732Z" }, "origin_pos": 41, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def get_net():\n", " num_classes = 10\n", " net = d2l.resnet18(num_classes, 3)\n", " return net\n", "\n", "loss = nn.CrossEntropyLoss(reduction=\"none\")" ] }, { "cell_type": "markdown", "id": "e38ab14a", "metadata": { "origin_pos": 43 }, "source": [ "## 定义[**训练函数**]\n", "\n", "我们将根据模型在验证集上的表现来选择模型并调整超参数。\n", "下面我们定义了模型训练函数`train`。\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "e082315c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:02:14.984795Z", "iopub.status.busy": "2023-08-18T07:02:14.984512Z", "iopub.status.idle": "2023-08-18T07:02:14.994288Z", "shell.execute_reply": "2023-08-18T07:02:14.993512Z" }, "origin_pos": 45, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,\n", " lr_decay):\n", " trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9,\n", " weight_decay=wd)\n", " scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)\n", " num_batches, timer = len(train_iter), d2l.Timer()\n", " legend = ['train loss', 'train acc']\n", " if valid_iter is not None:\n", " legend.append('valid acc')\n", " animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],\n", " legend=legend)\n", " net = nn.DataParallel(net, device_ids=devices).to(devices[0])\n", " for epoch in range(num_epochs):\n", " net.train()\n", " metric = d2l.Accumulator(3)\n", " for i, (features, labels) in enumerate(train_iter):\n", " timer.start()\n", " l, acc = d2l.train_batch_ch13(net, features, labels,\n", " loss, trainer, devices)\n", " metric.add(l, acc, labels.shape[0])\n", " timer.stop()\n", " if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:\n", " animator.add(epoch + (i + 1) / num_batches,\n", " (metric[0] / metric[2], metric[1] / metric[2],\n", " None))\n", " if valid_iter is not None:\n", " valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)\n", " animator.add(epoch + 1, (None, None, valid_acc))\n", " scheduler.step()\n", " measures = (f'train loss {metric[0] / metric[2]:.3f}, '\n", " f'train acc {metric[1] / metric[2]:.3f}')\n", " if valid_iter is not None:\n", " measures += f', valid acc {valid_acc:.3f}'\n", " print(measures + f'\\n{metric[2] * num_epochs / timer.sum():.1f}'\n", " f' examples/sec on {str(devices)}')" ] }, { "cell_type": "markdown", "id": "41c3007d", "metadata": { "origin_pos": 47 }, "source": [ "## [**训练和验证模型**]\n", "\n", "现在,我们可以训练和验证模型了,而以下所有超参数都可以调整。\n", "例如,我们可以增加周期的数量。当`lr_period`和`lr_decay`分别设置为4和0.9时,优化算法的学习速率将在每4个周期乘以0.9。\n", "为便于演示,我们在这里只训练20个周期。\n" ] }, { "cell_type": "code", "execution_count": 14, "id": "267a469f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:02:14.997959Z", "iopub.status.busy": "2023-08-18T07:02:14.997526Z", "iopub.status.idle": "2023-08-18T07:03:21.092598Z", "shell.execute_reply": "2023-08-18T07:03:21.091331Z" }, "origin_pos": 49, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train loss 0.668, train acc 0.761, valid acc 0.406\n", "758.4 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:03:21.035440\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "devices, num_epochs, lr, wd = d2l.try_all_gpus(), 20, 2e-4, 5e-4\n", "lr_period, lr_decay, net = 4, 0.9, get_net()\n", "train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,\n", " lr_decay)" ] }, { "cell_type": "markdown", "id": "7ebad777", "metadata": { "origin_pos": 51 }, "source": [ "## 在 Kaggle 上[**对测试集进行分类并提交结果**]\n", "\n", "在获得具有超参数的满意的模型后,我们使用所有标记的数据(包括验证集)来重新训练模型并对测试集进行分类。\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "92b85006", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:03:21.098390Z", "iopub.status.busy": "2023-08-18T07:03:21.097395Z", "iopub.status.idle": "2023-08-18T07:04:21.878943Z", "shell.execute_reply": "2023-08-18T07:04:21.878089Z" }, "origin_pos": 53, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train loss 0.745, train acc 0.734\n", "883.3 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:04:21.843396\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "net, preds = get_net(), []\n", "train(net, train_valid_iter, None, num_epochs, lr, wd, devices, lr_period,\n", " lr_decay)\n", "\n", "for X, _ in test_iter:\n", " y_hat = net(X.to(devices[0]))\n", " preds.extend(y_hat.argmax(dim=1).type(torch.int32).cpu().numpy())\n", "sorted_ids = list(range(1, len(test_ds) + 1))\n", "sorted_ids.sort(key=lambda x: str(x))\n", "df = pd.DataFrame({'id': sorted_ids, 'label': preds})\n", "df['label'] = df['label'].apply(lambda x: train_valid_ds.classes[x])\n", "df.to_csv('submission.csv', index=False)" ] }, { "cell_type": "markdown", "id": "b6c0bf49", "metadata": { "origin_pos": 55 }, "source": [ "向Kaggle提交结果的方法与 :numref:`sec_kaggle_house`中的方法类似,上面的代码将生成一个\n", "`submission.csv`文件,其格式符合Kaggle竞赛的要求。\n", "\n", "## 小结\n", "\n", "* 将包含原始图像文件的数据集组织为所需格式后,我们可以读取它们。\n" ] }, { "cell_type": "markdown", "id": "c9aff1dc", "metadata": { "origin_pos": 57, "tab": [ "pytorch" ] }, "source": [ "* 我们可以在图像分类竞赛中使用卷积神经网络和图像增广。\n" ] }, { "cell_type": "markdown", "id": "725588ee", "metadata": { "origin_pos": 59 }, "source": [ "## 练习\n", "\n", "1. 在这场Kaggle竞赛中使用完整的CIFAR-10数据集。将超参数设为`batch_size = 128`,`num_epochs = 100`,`lr = 0.1`,`lr_period = 50`,`lr_decay = 0.1`。看看在这场比赛中能达到什么准确度和排名。能进一步改进吗?\n", "1. 不使用图像增广时,能获得怎样的准确度?\n" ] }, { "cell_type": "markdown", "id": "ae7f6d9e", "metadata": { "origin_pos": 61, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/2831)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }