{
"cells": [
{
"cell_type": "markdown",
"id": "fec74aec",
"metadata": {
"origin_pos": 0
},
"source": [
"# 实战 Kaggle 比赛:图像分类 (CIFAR-10)\n",
":label:`sec_kaggle_cifar10`\n",
"\n",
"之前几节中,我们一直在使用深度学习框架的高级API直接获取张量格式的图像数据集。\n",
"但是在实践中,图像数据集通常以图像文件的形式出现。\n",
"本节将从原始图像文件开始,然后逐步组织、读取并将它们转换为张量格式。\n",
"\n",
"我们在 :numref:`sec_image_augmentation`中对CIFAR-10数据集做了一个实验。CIFAR-10是计算机视觉领域中的一个重要的数据集。\n",
"本节将运用我们在前几节中学到的知识来参加CIFAR-10图像分类问题的Kaggle竞赛,(**比赛的网址是https://www.kaggle.com/c/cifar-10**)。\n",
"\n",
" :numref:`fig_kaggle_cifar10`显示了竞赛网站页面上的信息。\n",
"为了能提交结果,首先需要注册一个Kaggle账户。\n",
"\n",
"\n",
":width:`600px`\n",
":label:`fig_kaggle_cifar10`\n",
"\n",
"首先,导入竞赛所需的包和模块。\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f7b4fa3c",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:02:12.123905Z",
"iopub.status.busy": "2023-08-18T07:02:12.123323Z",
"iopub.status.idle": "2023-08-18T07:02:14.203247Z",
"shell.execute_reply": "2023-08-18T07:02:14.202358Z"
},
"origin_pos": 2,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import collections\n",
"import math\n",
"import os\n",
"import shutil\n",
"import pandas as pd\n",
"import torch\n",
"import torchvision\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "11f6b9d7",
"metadata": {
"origin_pos": 4
},
"source": [
"## 获取并组织数据集\n",
"\n",
"比赛数据集分为训练集和测试集,其中训练集包含50000张、测试集包含300000张图像。\n",
"在测试集中,10000张图像将被用于评估,而剩下的290000张图像将不会被进行评估,包含它们只是为了防止手动标记测试集并提交标记结果。\n",
"两个数据集中的图像都是png格式,高度和宽度均为32像素并有三个颜色通道(RGB)。\n",
"这些图片共涵盖10个类别:飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。\n",
" :numref:`fig_kaggle_cifar10`的左上角显示了数据集中飞机、汽车和鸟类的一些图像。\n",
"\n",
"### 下载数据集\n",
"\n",
"登录Kaggle后,我们可以点击 :numref:`fig_kaggle_cifar10`中显示的CIFAR-10图像分类竞赛网页上的“Data”选项卡,然后单击“Download All”按钮下载数据集。\n",
"在`../data`中解压下载的文件并在其中解压缩`train.7z`和`test.7z`后,在以下路径中可以找到整个数据集:\n",
"\n",
"* `../data/cifar-10/train/[1-50000].png`\n",
"* `../data/cifar-10/test/[1-300000].png`\n",
"* `../data/cifar-10/trainLabels.csv`\n",
"* `../data/cifar-10/sampleSubmission.csv`\n",
"\n",
"`train`和`test`文件夹分别包含训练和测试图像,`trainLabels.csv`含有训练图像的标签,\n",
"`sample_submission.csv`是提交文件的范例。\n",
"\n",
"为了便于入门,[**我们提供包含前1000个训练图像和5个随机测试图像的数据集的小规模样本**]。\n",
"要使用Kaggle竞赛的完整数据集,需要将以下`demo`变量设置为`False`。\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7ae59ae9",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:02:14.207394Z",
"iopub.status.busy": "2023-08-18T07:02:14.207019Z",
"iopub.status.idle": "2023-08-18T07:02:14.623037Z",
"shell.execute_reply": "2023-08-18T07:02:14.622201Z"
},
"origin_pos": 5,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading ../data/kaggle_cifar10_tiny.zip from http://d2l-data.s3-accelerate.amazonaws.com/kaggle_cifar10_tiny.zip...\n"
]
}
],
"source": [
"#@save\n",
"d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',\n",
" '2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')\n",
"\n",
"# 如果使用完整的Kaggle竞赛的数据集,设置demo为False\n",
"demo = True\n",
"\n",
"if demo:\n",
" data_dir = d2l.download_extract('cifar10_tiny')\n",
"else:\n",
" data_dir = '../data/cifar-10/'"
]
},
{
"cell_type": "markdown",
"id": "56ef995f",
"metadata": {
"origin_pos": 6
},
"source": [
"### [**整理数据集**]\n",
"\n",
"我们需要整理数据集来训练和测试模型。\n",
"首先,我们用以下函数读取CSV文件中的标签,它返回一个字典,该字典将文件名中不带扩展名的部分映射到其标签。\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "24b4fdfb",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:02:14.627215Z",
"iopub.status.busy": "2023-08-18T07:02:14.626653Z",
"iopub.status.idle": "2023-08-18T07:02:14.634237Z",
"shell.execute_reply": "2023-08-18T07:02:14.633299Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"# 训练样本 : 1000\n",
"# 类别 : 10\n"
]
}
],
"source": [
"#@save\n",
"def read_csv_labels(fname):\n",
" \"\"\"读取fname来给标签字典返回一个文件名\"\"\"\n",
" with open(fname, 'r') as f:\n",
" # 跳过文件头行(列名)\n",
" lines = f.readlines()[1:]\n",
" tokens = [l.rstrip().split(',') for l in lines]\n",
" return dict(((name, label) for name, label in tokens))\n",
"\n",
"labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))\n",
"print('# 训练样本 :', len(labels))\n",
"print('# 类别 :', len(set(labels.values())))"
]
},
{
"cell_type": "markdown",
"id": "3359e04e",
"metadata": {
"origin_pos": 8
},
"source": [
"接下来,我们定义`reorg_train_valid`函数来[**将验证集从原始的训练集中拆分出来**]。\n",
"此函数中的参数`valid_ratio`是验证集中的样本数与原始训练集中的样本数之比。\n",
"更具体地说,令$n$等于样本最少的类别中的图像数量,而$r$是比率。\n",
"验证集将为每个类别拆分出$\\max(\\lfloor nr\\rfloor,1)$张图像。\n",
"让我们以`valid_ratio=0.1`为例,由于原始的训练集有50000张图像,因此`train_valid_test/train`路径中将有45000张图像用于训练,而剩下5000张图像将作为路径`train_valid_test/valid`中的验证集。\n",
"组织数据集后,同类别的图像将被放置在同一文件夹下。\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "fbfbac4f",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:02:14.637882Z",
"iopub.status.busy": "2023-08-18T07:02:14.637608Z",
"iopub.status.idle": "2023-08-18T07:02:14.644987Z",
"shell.execute_reply": "2023-08-18T07:02:14.644218Z"
},
"origin_pos": 9,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"#@save\n",
"def copyfile(filename, target_dir):\n",
" \"\"\"将文件复制到目标目录\"\"\"\n",
" os.makedirs(target_dir, exist_ok=True)\n",
" shutil.copy(filename, target_dir)\n",
"\n",
"#@save\n",
"def reorg_train_valid(data_dir, labels, valid_ratio):\n",
" \"\"\"将验证集从原始的训练集中拆分出来\"\"\"\n",
" # 训练数据集中样本最少的类别中的样本数\n",
" n = collections.Counter(labels.values()).most_common()[-1][1]\n",
" # 验证集中每个类别的样本数\n",
" n_valid_per_label = max(1, math.floor(n * valid_ratio))\n",
" label_count = {}\n",
" for train_file in os.listdir(os.path.join(data_dir, 'train')):\n",
" label = labels[train_file.split('.')[0]]\n",
" fname = os.path.join(data_dir, 'train', train_file)\n",
" copyfile(fname, os.path.join(data_dir, 'train_valid_test',\n",
" 'train_valid', label))\n",
" if label not in label_count or label_count[label] < n_valid_per_label:\n",
" copyfile(fname, os.path.join(data_dir, 'train_valid_test',\n",
" 'valid', label))\n",
" label_count[label] = label_count.get(label, 0) + 1\n",
" else:\n",
" copyfile(fname, os.path.join(data_dir, 'train_valid_test',\n",
" 'train', label))\n",
" return n_valid_per_label"
]
},
{
"cell_type": "markdown",
"id": "4d89ac6f",
"metadata": {
"origin_pos": 10
},
"source": [
"下面的`reorg_test`函数用来[**在预测期间整理测试集,以方便读取**]。\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9ad6d005",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:02:14.648388Z",
"iopub.status.busy": "2023-08-18T07:02:14.648058Z",
"iopub.status.idle": "2023-08-18T07:02:14.653344Z",
"shell.execute_reply": "2023-08-18T07:02:14.652542Z"
},
"origin_pos": 11,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"#@save\n",
"def reorg_test(data_dir):\n",
" \"\"\"在预测期间整理测试集,以方便读取\"\"\"\n",
" for test_file in os.listdir(os.path.join(data_dir, 'test')):\n",
" copyfile(os.path.join(data_dir, 'test', test_file),\n",
" os.path.join(data_dir, 'train_valid_test', 'test',\n",
" 'unknown'))"
]
},
{
"cell_type": "markdown",
"id": "da9e6d5a",
"metadata": {
"origin_pos": 12
},
"source": [
"最后,我们使用一个函数来[**调用前面定义的函数**]`read_csv_labels`、`reorg_train_valid`和`reorg_test`。\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "37a42208",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:02:14.657456Z",
"iopub.status.busy": "2023-08-18T07:02:14.656944Z",
"iopub.status.idle": "2023-08-18T07:02:14.661187Z",
"shell.execute_reply": "2023-08-18T07:02:14.660400Z"
},
"origin_pos": 13,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def reorg_cifar10_data(data_dir, valid_ratio):\n",
" labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))\n",
" reorg_train_valid(data_dir, labels, valid_ratio)\n",
" reorg_test(data_dir)"
]
},
{
"cell_type": "markdown",
"id": "b62b33fa",
"metadata": {
"origin_pos": 14
},
"source": [
"在这里,我们只将样本数据集的批量大小设置为32。\n",
"在实际训练和测试中,应该使用Kaggle竞赛的完整数据集,并将`batch_size`设置为更大的整数,例如128。\n",
"我们将10%的训练样本作为调整超参数的验证集。\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "868eb62b",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:02:14.664878Z",
"iopub.status.busy": "2023-08-18T07:02:14.664240Z",
"iopub.status.idle": "2023-08-18T07:02:14.931508Z",
"shell.execute_reply": "2023-08-18T07:02:14.930669Z"
},
"origin_pos": 15,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"batch_size = 32 if demo else 128\n",
"valid_ratio = 0.1\n",
"reorg_cifar10_data(data_dir, valid_ratio)"
]
},
{
"cell_type": "markdown",
"id": "31458131",
"metadata": {
"origin_pos": 16
},
"source": [
"## [**图像增广**]\n",
"\n",
"我们使用图像增广来解决过拟合的问题。例如在训练中,我们可以随机水平翻转图像。\n",
"我们还可以对彩色图像的三个RGB通道执行标准化。\n",
"下面,我们列出了其中一些可以调整的操作。\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "300ef249",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:02:14.935568Z",
"iopub.status.busy": "2023-08-18T07:02:14.934993Z",
"iopub.status.idle": "2023-08-18T07:02:14.940662Z",
"shell.execute_reply": "2023-08-18T07:02:14.939875Z"
},
"origin_pos": 18,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"transform_train = torchvision.transforms.Compose([\n",
" # 在高度和宽度上将图像放大到40像素的正方形\n",
" torchvision.transforms.Resize(40),\n",
" # 随机裁剪出一个高度和宽度均为40像素的正方形图像,\n",
" # 生成一个面积为原始图像面积0.64~1倍的小正方形,\n",
" # 然后将其缩放为高度和宽度均为32像素的正方形\n",
" torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0),\n",
" ratio=(1.0, 1.0)),\n",
" torchvision.transforms.RandomHorizontalFlip(),\n",
" torchvision.transforms.ToTensor(),\n",
" # 标准化图像的每个通道\n",
" torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],\n",
" [0.2023, 0.1994, 0.2010])])"
]
},
{
"cell_type": "markdown",
"id": "694c31d1",
"metadata": {
"origin_pos": 20
},
"source": [
"在测试期间,我们只对图像执行标准化,以消除评估结果中的随机性。\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "6bd19592",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:02:14.945514Z",
"iopub.status.busy": "2023-08-18T07:02:14.944858Z",
"iopub.status.idle": "2023-08-18T07:02:14.949240Z",
"shell.execute_reply": "2023-08-18T07:02:14.948438Z"
},
"origin_pos": 22,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"transform_test = torchvision.transforms.Compose([\n",
" torchvision.transforms.ToTensor(),\n",
" torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],\n",
" [0.2023, 0.1994, 0.2010])])"
]
},
{
"cell_type": "markdown",
"id": "c06d3e36",
"metadata": {
"origin_pos": 24
},
"source": [
"## 读取数据集\n",
"\n",
"接下来,我们[**读取由原始图像组成的数据集**],每个样本都包括一张图片和一个标签。\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "c1815173",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:02:14.952392Z",
"iopub.status.busy": "2023-08-18T07:02:14.952109Z",
"iopub.status.idle": "2023-08-18T07:02:14.966044Z",
"shell.execute_reply": "2023-08-18T07:02:14.965283Z"
},
"origin_pos": 26,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(\n",
" os.path.join(data_dir, 'train_valid_test', folder),\n",
" transform=transform_train) for folder in ['train', 'train_valid']]\n",
"\n",
"valid_ds, test_ds = [torchvision.datasets.ImageFolder(\n",
" os.path.join(data_dir, 'train_valid_test', folder),\n",
" transform=transform_test) for folder in ['valid', 'test']]"
]
},
{
"cell_type": "markdown",
"id": "f22934d1",
"metadata": {
"origin_pos": 28
},
"source": [
"在训练期间,我们需要[**指定上面定义的所有图像增广操作**]。\n",
"当验证集在超参数调整过程中用于模型评估时,不应引入图像增广的随机性。\n",
"在最终预测之前,我们根据训练集和验证集组合而成的训练模型进行训练,以充分利用所有标记的数据。\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "d9528fff",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:02:14.969294Z",
"iopub.status.busy": "2023-08-18T07:02:14.968975Z",
"iopub.status.idle": "2023-08-18T07:02:14.974247Z",
"shell.execute_reply": "2023-08-18T07:02:14.973489Z"
},
"origin_pos": 30,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"train_iter, train_valid_iter = [torch.utils.data.DataLoader(\n",
" dataset, batch_size, shuffle=True, drop_last=True)\n",
" for dataset in (train_ds, train_valid_ds)]\n",
"\n",
"valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False,\n",
" drop_last=True)\n",
"\n",
"test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False,\n",
" drop_last=False)"
]
},
{
"cell_type": "markdown",
"id": "a40cd7e3",
"metadata": {
"origin_pos": 32
},
"source": [
"## 定义[**模型**]\n"
]
},
{
"cell_type": "markdown",
"id": "8a4e0897",
"metadata": {
"origin_pos": 38,
"tab": [
"pytorch"
]
},
"source": [
"我们定义了 :numref:`sec_resnet`中描述的Resnet-18模型。\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "c35c8dca",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:02:14.977846Z",
"iopub.status.busy": "2023-08-18T07:02:14.977285Z",
"iopub.status.idle": "2023-08-18T07:02:14.981584Z",
"shell.execute_reply": "2023-08-18T07:02:14.980732Z"
},
"origin_pos": 41,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def get_net():\n",
" num_classes = 10\n",
" net = d2l.resnet18(num_classes, 3)\n",
" return net\n",
"\n",
"loss = nn.CrossEntropyLoss(reduction=\"none\")"
]
},
{
"cell_type": "markdown",
"id": "e38ab14a",
"metadata": {
"origin_pos": 43
},
"source": [
"## 定义[**训练函数**]\n",
"\n",
"我们将根据模型在验证集上的表现来选择模型并调整超参数。\n",
"下面我们定义了模型训练函数`train`。\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "e082315c",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:02:14.984795Z",
"iopub.status.busy": "2023-08-18T07:02:14.984512Z",
"iopub.status.idle": "2023-08-18T07:02:14.994288Z",
"shell.execute_reply": "2023-08-18T07:02:14.993512Z"
},
"origin_pos": 45,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,\n",
" lr_decay):\n",
" trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9,\n",
" weight_decay=wd)\n",
" scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)\n",
" num_batches, timer = len(train_iter), d2l.Timer()\n",
" legend = ['train loss', 'train acc']\n",
" if valid_iter is not None:\n",
" legend.append('valid acc')\n",
" animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],\n",
" legend=legend)\n",
" net = nn.DataParallel(net, device_ids=devices).to(devices[0])\n",
" for epoch in range(num_epochs):\n",
" net.train()\n",
" metric = d2l.Accumulator(3)\n",
" for i, (features, labels) in enumerate(train_iter):\n",
" timer.start()\n",
" l, acc = d2l.train_batch_ch13(net, features, labels,\n",
" loss, trainer, devices)\n",
" metric.add(l, acc, labels.shape[0])\n",
" timer.stop()\n",
" if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:\n",
" animator.add(epoch + (i + 1) / num_batches,\n",
" (metric[0] / metric[2], metric[1] / metric[2],\n",
" None))\n",
" if valid_iter is not None:\n",
" valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)\n",
" animator.add(epoch + 1, (None, None, valid_acc))\n",
" scheduler.step()\n",
" measures = (f'train loss {metric[0] / metric[2]:.3f}, '\n",
" f'train acc {metric[1] / metric[2]:.3f}')\n",
" if valid_iter is not None:\n",
" measures += f', valid acc {valid_acc:.3f}'\n",
" print(measures + f'\\n{metric[2] * num_epochs / timer.sum():.1f}'\n",
" f' examples/sec on {str(devices)}')"
]
},
{
"cell_type": "markdown",
"id": "41c3007d",
"metadata": {
"origin_pos": 47
},
"source": [
"## [**训练和验证模型**]\n",
"\n",
"现在,我们可以训练和验证模型了,而以下所有超参数都可以调整。\n",
"例如,我们可以增加周期的数量。当`lr_period`和`lr_decay`分别设置为4和0.9时,优化算法的学习速率将在每4个周期乘以0.9。\n",
"为便于演示,我们在这里只训练20个周期。\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "267a469f",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:02:14.997959Z",
"iopub.status.busy": "2023-08-18T07:02:14.997526Z",
"iopub.status.idle": "2023-08-18T07:03:21.092598Z",
"shell.execute_reply": "2023-08-18T07:03:21.091331Z"
},
"origin_pos": 49,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train loss 0.668, train acc 0.761, valid acc 0.406\n",
"758.4 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"devices, num_epochs, lr, wd = d2l.try_all_gpus(), 20, 2e-4, 5e-4\n",
"lr_period, lr_decay, net = 4, 0.9, get_net()\n",
"train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,\n",
" lr_decay)"
]
},
{
"cell_type": "markdown",
"id": "7ebad777",
"metadata": {
"origin_pos": 51
},
"source": [
"## 在 Kaggle 上[**对测试集进行分类并提交结果**]\n",
"\n",
"在获得具有超参数的满意的模型后,我们使用所有标记的数据(包括验证集)来重新训练模型并对测试集进行分类。\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "92b85006",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:03:21.098390Z",
"iopub.status.busy": "2023-08-18T07:03:21.097395Z",
"iopub.status.idle": "2023-08-18T07:04:21.878943Z",
"shell.execute_reply": "2023-08-18T07:04:21.878089Z"
},
"origin_pos": 53,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train loss 0.745, train acc 0.734\n",
"883.3 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"net, preds = get_net(), []\n",
"train(net, train_valid_iter, None, num_epochs, lr, wd, devices, lr_period,\n",
" lr_decay)\n",
"\n",
"for X, _ in test_iter:\n",
" y_hat = net(X.to(devices[0]))\n",
" preds.extend(y_hat.argmax(dim=1).type(torch.int32).cpu().numpy())\n",
"sorted_ids = list(range(1, len(test_ds) + 1))\n",
"sorted_ids.sort(key=lambda x: str(x))\n",
"df = pd.DataFrame({'id': sorted_ids, 'label': preds})\n",
"df['label'] = df['label'].apply(lambda x: train_valid_ds.classes[x])\n",
"df.to_csv('submission.csv', index=False)"
]
},
{
"cell_type": "markdown",
"id": "b6c0bf49",
"metadata": {
"origin_pos": 55
},
"source": [
"向Kaggle提交结果的方法与 :numref:`sec_kaggle_house`中的方法类似,上面的代码将生成一个\n",
"`submission.csv`文件,其格式符合Kaggle竞赛的要求。\n",
"\n",
"## 小结\n",
"\n",
"* 将包含原始图像文件的数据集组织为所需格式后,我们可以读取它们。\n"
]
},
{
"cell_type": "markdown",
"id": "c9aff1dc",
"metadata": {
"origin_pos": 57,
"tab": [
"pytorch"
]
},
"source": [
"* 我们可以在图像分类竞赛中使用卷积神经网络和图像增广。\n"
]
},
{
"cell_type": "markdown",
"id": "725588ee",
"metadata": {
"origin_pos": 59
},
"source": [
"## 练习\n",
"\n",
"1. 在这场Kaggle竞赛中使用完整的CIFAR-10数据集。将超参数设为`batch_size = 128`,`num_epochs = 100`,`lr = 0.1`,`lr_period = 50`,`lr_decay = 0.1`。看看在这场比赛中能达到什么准确度和排名。能进一步改进吗?\n",
"1. 不使用图像增广时,能获得怎样的准确度?\n"
]
},
{
"cell_type": "markdown",
"id": "ae7f6d9e",
"metadata": {
"origin_pos": 61,
"tab": [
"pytorch"
]
},
"source": [
"[Discussions](https://discuss.d2l.ai/t/2831)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"required_libs": []
},
"nbformat": 4,
"nbformat_minor": 5
}