{ "cells": [ { "cell_type": "markdown", "id": "937f9e37", "metadata": { "origin_pos": 0 }, "source": [ "# 实战Kaggle比赛:预测房价\n", ":label:`sec_kaggle_house`\n", "\n", "之前几节我们学习了一些训练深度网络的基本工具和网络正则化的技术(如权重衰减、暂退法等)。\n", "本节我们将通过Kaggle比赛,将所学知识付诸实践。\n", "Kaggle的房价预测比赛是一个很好的起点。\n", "此数据集由Bart de Cock于2011年收集 :cite:`De-Cock.2011`,\n", "涵盖了2006-2010年期间亚利桑那州埃姆斯市的房价。\n", "这个数据集是相当通用的,不会需要使用复杂模型架构。\n", "它比哈里森和鲁宾菲尔德的[波士顿房价](https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.names)\n", "数据集要大得多,也有更多的特征。\n", "\n", "本节我们将详细介绍数据预处理、模型设计和超参数选择。\n", "通过亲身实践,你将获得一手经验,这些经验将有益数据科学家的职业成长。\n", "\n", "## 下载和缓存数据集\n", "\n", "在整本书中,我们将下载不同的数据集,并训练和测试模型。\n", "这里我们(**实现几个函数来方便下载数据**)。\n", "首先,我们建立字典`DATA_HUB`,\n", "它可以将数据集名称的字符串映射到数据集相关的二元组上,\n", "这个二元组包含数据集的url和验证文件完整性的sha-1密钥。\n", "所有类似的数据集都托管在地址为`DATA_URL`的站点上。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "734593b0", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:58:30.249795Z", "iopub.status.busy": "2023-08-18T06:58:30.249006Z", "iopub.status.idle": "2023-08-18T06:58:30.344738Z", "shell.execute_reply": "2023-08-18T06:58:30.343588Z" }, "origin_pos": 1, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import hashlib\n", "import os\n", "import tarfile\n", "import zipfile\n", "import requests\n", "\n", "#@save\n", "DATA_HUB = dict()\n", "DATA_URL = 'http://d2l-data.s3-accelerate.amazonaws.com/'" ] }, { "cell_type": "markdown", "id": "1ea53b4f", "metadata": { "origin_pos": 2 }, "source": [ "下面的`download`函数用来下载数据集,\n", "将数据集缓存在本地目录(默认情况下为`../data`)中,\n", "并返回下载文件的名称。\n", "如果缓存目录中已经存在此数据集文件,并且其sha-1与存储在`DATA_HUB`中的相匹配,\n", "我们将使用缓存的文件,以避免重复的下载。\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "276702a6", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:58:30.350368Z", "iopub.status.busy": "2023-08-18T06:58:30.349724Z", "iopub.status.idle": "2023-08-18T06:58:30.361205Z", "shell.execute_reply": "2023-08-18T06:58:30.360058Z" }, "origin_pos": 3, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def download(name, cache_dir=os.path.join('..', 'data')): #@save\n", " \"\"\"下载一个DATA_HUB中的文件,返回本地文件名\"\"\"\n", " assert name in DATA_HUB, f\"{name} 不存在于 {DATA_HUB}\"\n", " url, sha1_hash = DATA_HUB[name]\n", " os.makedirs(cache_dir, exist_ok=True)\n", " fname = os.path.join(cache_dir, url.split('/')[-1])\n", " if os.path.exists(fname):\n", " sha1 = hashlib.sha1()\n", " with open(fname, 'rb') as f:\n", " while True:\n", " data = f.read(1048576)\n", " if not data:\n", " break\n", " sha1.update(data)\n", " if sha1.hexdigest() == sha1_hash:\n", " return fname # 命中缓存\n", " print(f'正在从{url}下载{fname}...')\n", " r = requests.get(url, stream=True, verify=True)\n", " with open(fname, 'wb') as f:\n", " f.write(r.content)\n", " return fname" ] }, { "cell_type": "markdown", "id": "8ee59036", "metadata": { "origin_pos": 4 }, "source": [ "我们还需实现两个实用函数:\n", "一个将下载并解压缩一个zip或tar文件,\n", "另一个是将本书中使用的所有数据集从`DATA_HUB`下载到缓存目录中。\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "42ad8efa", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:58:30.366317Z", "iopub.status.busy": "2023-08-18T06:58:30.365422Z", "iopub.status.idle": "2023-08-18T06:58:30.374280Z", "shell.execute_reply": "2023-08-18T06:58:30.373220Z" }, "origin_pos": 5, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def download_extract(name, folder=None): #@save\n", " \"\"\"下载并解压zip/tar文件\"\"\"\n", " fname = download(name)\n", " base_dir = os.path.dirname(fname)\n", " data_dir, ext = os.path.splitext(fname)\n", " if ext == '.zip':\n", " fp = zipfile.ZipFile(fname, 'r')\n", " elif ext in ('.tar', '.gz'):\n", " fp = tarfile.open(fname, 'r')\n", " else:\n", " assert False, '只有zip/tar文件可以被解压缩'\n", " fp.extractall(base_dir)\n", " return os.path.join(base_dir, folder) if folder else data_dir\n", "\n", "def download_all(): #@save\n", " \"\"\"下载DATA_HUB中的所有文件\"\"\"\n", " for name in DATA_HUB:\n", " download(name)" ] }, { "cell_type": "markdown", "id": "9ff53967", "metadata": { "origin_pos": 6 }, "source": [ "## Kaggle\n", "\n", "[Kaggle](https://www.kaggle.com)是一个当今流行举办机器学习比赛的平台,\n", "每场比赛都以至少一个数据集为中心。\n", "许多比赛有赞助方,他们为获胜的解决方案提供奖金。\n", "该平台帮助用户通过论坛和共享代码进行互动,促进协作和竞争。\n", "虽然排行榜的追逐往往令人失去理智:\n", "有些研究人员短视地专注于预处理步骤,而不是考虑基础性问题。\n", "但一个客观的平台有巨大的价值:该平台促进了竞争方法之间的直接定量比较,以及代码共享。\n", "这便于每个人都可以学习哪些方法起作用,哪些没有起作用。\n", "如果我们想参加Kaggle比赛,首先需要注册一个账户(见 :numref:`fig_kaggle`)。\n", "\n", "![Kaggle网站](../img/kaggle.png)\n", ":width:`400px`\n", ":label:`fig_kaggle`\n", "\n", "在房价预测比赛页面(如 :numref:`fig_house_pricing` 所示)的\"Data\"选项卡下可以找到数据集。我们可以通过下面的网址提交预测,并查看排名:\n", "\n", ">https://www.kaggle.com/c/house-prices-advanced-regression-techniques\n", "\n", "![房价预测比赛页面](../img/house-pricing.png)\n", ":width:`400px`\n", ":label:`fig_house_pricing`\n", "\n", "## 访问和读取数据集\n", "\n", "注意,竞赛数据分为训练集和测试集。\n", "每条记录都包括房屋的属性值和属性,如街道类型、施工年份、屋顶类型、地下室状况等。\n", "这些特征由各种数据类型组成。\n", "例如,建筑年份由整数表示,屋顶类型由离散类别表示,其他特征由浮点数表示。\n", "这就是现实让事情变得复杂的地方:例如,一些数据完全丢失了,缺失值被简单地标记为“NA”。\n", "每套房子的价格只出现在训练集中(毕竟这是一场比赛)。\n", "我们将希望划分训练集以创建验证集,但是在将预测结果上传到Kaggle之后,\n", "我们只能在官方测试集中评估我们的模型。\n", "在 :numref:`fig_house_pricing` 中,\"Data\"选项卡有下载数据的链接。\n", "\n", "开始之前,我们将[**使用`pandas`读入并处理数据**],\n", "这是我们在 :numref:`sec_pandas`中引入的。\n", "因此,在继续操作之前,我们需要确保已安装`pandas`。\n", "幸运的是,如果我们正在用Jupyter阅读该书,可以在不离开笔记本的情况下安装`pandas`。\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "66e7e040", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:58:30.379485Z", "iopub.status.busy": "2023-08-18T06:58:30.378590Z", "iopub.status.idle": "2023-08-18T06:58:33.390405Z", "shell.execute_reply": "2023-08-18T06:58:33.389064Z" }, "origin_pos": 8, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "# 如果没有安装pandas,请取消下一行的注释\n", "# !pip install pandas\n", "\n", "%matplotlib inline\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "144df29b", "metadata": { "origin_pos": 11 }, "source": [ "为方便起见,我们可以使用上面定义的脚本下载并缓存Kaggle房屋数据集。\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "ea733544", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:58:33.396895Z", "iopub.status.busy": "2023-08-18T06:58:33.395535Z", "iopub.status.idle": "2023-08-18T06:58:33.402172Z", "shell.execute_reply": "2023-08-18T06:58:33.400982Z" }, "origin_pos": 12, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "DATA_HUB['kaggle_house_train'] = ( #@save\n", " DATA_URL + 'kaggle_house_pred_train.csv',\n", " '585e9cc93e70b39160e7921475f9bcd7d31219ce')\n", "\n", "DATA_HUB['kaggle_house_test'] = ( #@save\n", " DATA_URL + 'kaggle_house_pred_test.csv',\n", " 'fa19780a7b011d9b009e8bff8e99922a8ee2eb90')" ] }, { "cell_type": "markdown", "id": "d7ed71e0", "metadata": { "origin_pos": 13 }, "source": [ "我们使用`pandas`分别加载包含训练数据和测试数据的两个CSV文件。\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "4928df7b", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:58:33.407201Z", "iopub.status.busy": "2023-08-18T06:58:33.406476Z", "iopub.status.idle": "2023-08-18T06:58:33.710870Z", "shell.execute_reply": "2023-08-18T06:58:33.709609Z" }, "origin_pos": 14, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "正在从http://d2l-data.s3-accelerate.amazonaws.com/kaggle_house_pred_train.csv下载../data/kaggle_house_pred_train.csv...\n", "正在从http://d2l-data.s3-accelerate.amazonaws.com/kaggle_house_pred_test.csv下载../data/kaggle_house_pred_test.csv...\n" ] } ], "source": [ "train_data = pd.read_csv(download('kaggle_house_train'))\n", "test_data = pd.read_csv(download('kaggle_house_test'))" ] }, { "cell_type": "markdown", "id": "ac22b4c8", "metadata": { "origin_pos": 15 }, "source": [ "训练数据集包括1460个样本,每个样本80个特征和1个标签,\n", "而测试数据集包含1459个样本,每个样本80个特征。\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "55aee9f2", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:58:33.716674Z", "iopub.status.busy": "2023-08-18T06:58:33.715846Z", "iopub.status.idle": "2023-08-18T06:58:33.722539Z", "shell.execute_reply": "2023-08-18T06:58:33.721369Z" }, "origin_pos": 16, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1460, 81)\n", "(1459, 80)\n" ] } ], "source": [ "print(train_data.shape)\n", "print(test_data.shape)" ] }, { "cell_type": "markdown", "id": "6a2b927b", "metadata": { "origin_pos": 17 }, "source": [ "让我们看看[**前四个和最后两个特征,以及相应标签**](房价)。\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "cb459c3d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:58:33.727643Z", "iopub.status.busy": "2023-08-18T06:58:33.726910Z", "iopub.status.idle": "2023-08-18T06:58:33.741457Z", "shell.execute_reply": "2023-08-18T06:58:33.740293Z" }, "origin_pos": 18, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Id MSSubClass MSZoning LotFrontage SaleType SaleCondition SalePrice\n", "0 1 60 RL 65.0 WD Normal 208500\n", "1 2 20 RL 80.0 WD Normal 181500\n", "2 3 60 RL 68.0 WD Normal 223500\n", "3 4 70 RL 60.0 WD Abnorml 140000\n" ] } ], "source": [ "print(train_data.iloc[0:4, [0, 1, 2, 3, -3, -2, -1]])" ] }, { "cell_type": "markdown", "id": "9e8244c9", "metadata": { "origin_pos": 19 }, "source": [ "我们可以看到,(**在每个样本中,第一个特征是ID,**)\n", "这有助于模型识别每个训练样本。\n", "虽然这很方便,但它不携带任何用于预测的信息。\n", "因此,在将数据提供给模型之前,(**我们将其从数据集中删除**)。\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "fe5338aa", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:58:33.746380Z", "iopub.status.busy": "2023-08-18T06:58:33.745604Z", "iopub.status.idle": "2023-08-18T06:58:33.773972Z", "shell.execute_reply": "2023-08-18T06:58:33.772656Z" }, "origin_pos": 20, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "all_features = pd.concat((train_data.iloc[:, 1:-1], test_data.iloc[:, 1:]))" ] }, { "cell_type": "markdown", "id": "51dbfdc0", "metadata": { "origin_pos": 21 }, "source": [ "## 数据预处理\n", "\n", "如上所述,我们有各种各样的数据类型。\n", "在开始建模之前,我们需要对数据进行预处理。\n", "首先,我们[**将所有缺失的值替换为相应特征的平均值。**]然后,为了将所有特征放在一个共同的尺度上,\n", "我们(**通过将特征重新缩放到零均值和单位方差来标准化数据**):\n", "\n", "$$x \\leftarrow \\frac{x - \\mu}{\\sigma},$$\n", "\n", "其中$\\mu$和$\\sigma$分别表示均值和标准差。\n", "现在,这些特征具有零均值和单位方差,即 $E[\\frac{x-\\mu}{\\sigma}] = \\frac{\\mu - \\mu}{\\sigma} = 0$和$E[(x-\\mu)^2] = (\\sigma^2 + \\mu^2) - 2\\mu^2+\\mu^2 = \\sigma^2$。\n", "直观地说,我们标准化数据有两个原因:\n", "首先,它方便优化。\n", "其次,因为我们不知道哪些特征是相关的,\n", "所以我们不想让惩罚分配给一个特征的系数比分配给其他任何特征的系数更大。\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "ae337076", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:58:33.779169Z", "iopub.status.busy": "2023-08-18T06:58:33.778411Z", "iopub.status.idle": "2023-08-18T06:58:33.856298Z", "shell.execute_reply": "2023-08-18T06:58:33.855062Z" }, "origin_pos": 22, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "# 若无法获得测试数据,则可根据训练数据计算均值和标准差\n", "numeric_features = all_features.dtypes[all_features.dtypes != 'object'].index\n", "all_features[numeric_features] = all_features[numeric_features].apply(\n", " lambda x: (x - x.mean()) / (x.std()))\n", "# 在标准化数据之后,所有均值消失,因此我们可以将缺失值设置为0\n", "all_features[numeric_features] = all_features[numeric_features].fillna(0)" ] }, { "cell_type": "markdown", "id": "149f5aa7", "metadata": { "origin_pos": 23 }, "source": [ "接下来,我们[**处理离散值。**]\n", "这包括诸如“MSZoning”之类的特征。\n", "(**我们用独热编码替换它们**),\n", "方法与前面将多类别标签转换为向量的方式相同\n", "(请参见 :numref:`subsec_classification-problem`)。\n", "例如,“MSZoning”包含值“RL”和“Rm”。\n", "我们将创建两个新的指示器特征“MSZoning_RL”和“MSZoning_RM”,其值为0或1。\n", "根据独热编码,如果“MSZoning”的原始值为“RL”,\n", "则:“MSZoning_RL”为1,“MSZoning_RM”为0。\n", "`pandas`软件包会自动为我们实现这一点。\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "73804c29", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:58:33.861948Z", "iopub.status.busy": "2023-08-18T06:58:33.861162Z", "iopub.status.idle": "2023-08-18T06:58:33.936809Z", "shell.execute_reply": "2023-08-18T06:58:33.935956Z" }, "origin_pos": 24, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "(2919, 331)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# “Dummy_na=True”将“na”(缺失值)视为有效的特征值,并为其创建指示符特征\n", "all_features = pd.get_dummies(all_features, dummy_na=True)\n", "all_features.shape" ] }, { "cell_type": "markdown", "id": "c2df3949", "metadata": { "origin_pos": 25 }, "source": [ "可以看到此转换会将特征的总数量从79个增加到331个。\n", "最后,通过`values`属性,我们可以\n", "[**从`pandas`格式中提取NumPy格式,并将其转换为张量表示**]用于训练。\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "2e73c9b7", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:58:33.941143Z", "iopub.status.busy": "2023-08-18T06:58:33.940251Z", "iopub.status.idle": "2023-08-18T06:58:33.968351Z", "shell.execute_reply": "2023-08-18T06:58:33.967159Z" }, "origin_pos": 26, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "n_train = train_data.shape[0]\n", "train_features = torch.tensor(all_features[:n_train].values, dtype=torch.float32)\n", "test_features = torch.tensor(all_features[n_train:].values, dtype=torch.float32)\n", "train_labels = torch.tensor(\n", " train_data.SalePrice.values.reshape(-1, 1), dtype=torch.float32)" ] }, { "cell_type": "markdown", "id": "2b949329", "metadata": { "origin_pos": 27 }, "source": [ "## [**训练**]\n", "\n", "首先,我们训练一个带有损失平方的线性模型。\n", "显然线性模型很难让我们在竞赛中获胜,但线性模型提供了一种健全性检查,\n", "以查看数据中是否存在有意义的信息。\n", "如果我们在这里不能做得比随机猜测更好,那么我们很可能存在数据处理错误。\n", "如果一切顺利,线性模型将作为*基线*(baseline)模型,\n", "让我们直观地知道最好的模型有超出简单的模型多少。\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "4e16c1dc", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:58:33.976812Z", "iopub.status.busy": "2023-08-18T06:58:33.974995Z", "iopub.status.idle": "2023-08-18T06:58:33.984092Z", "shell.execute_reply": "2023-08-18T06:58:33.983132Z" }, "origin_pos": 29, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "loss = nn.MSELoss()\n", "in_features = train_features.shape[1]\n", "\n", "def get_net():\n", " net = nn.Sequential(nn.Linear(in_features,1))\n", " return net" ] }, { "cell_type": "markdown", "id": "c3e5f9ca", "metadata": { "origin_pos": 31 }, "source": [ "房价就像股票价格一样,我们关心的是相对数量,而不是绝对数量。\n", "因此,[**我们更关心相对误差$\\frac{y - \\hat{y}}{y}$,**]\n", "而不是绝对误差$y - \\hat{y}$。\n", "例如,如果我们在俄亥俄州农村地区估计一栋房子的价格时,\n", "假设我们的预测偏差了10万美元,\n", "然而那里一栋典型的房子的价值是12.5万美元,\n", "那么模型可能做得很糟糕。\n", "另一方面,如果我们在加州豪宅区的预测出现同样的10万美元的偏差,\n", "(在那里,房价中位数超过400万美元)\n", "这可能是一个不错的预测。\n", "\n", "(**解决这个问题的一种方法是用价格预测的对数来衡量差异**)。\n", "事实上,这也是比赛中官方用来评价提交质量的误差指标。\n", "即将$\\delta$ for $|\\log y - \\log \\hat{y}| \\leq \\delta$\n", "转换为$e^{-\\delta} \\leq \\frac{\\hat{y}}{y} \\leq e^\\delta$。\n", "这使得预测价格的对数与真实标签价格的对数之间出现以下均方根误差:\n", "\n", "$$\\sqrt{\\frac{1}{n}\\sum_{i=1}^n\\left(\\log y_i -\\log \\hat{y}_i\\right)^2}.$$\n" ] }, { "cell_type": "code", "execution_count": 14, "id": "ffbf5478", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:58:33.991464Z", "iopub.status.busy": "2023-08-18T06:58:33.989673Z", "iopub.status.idle": "2023-08-18T06:58:33.999109Z", "shell.execute_reply": "2023-08-18T06:58:33.998094Z" }, "origin_pos": 33, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def log_rmse(net, features, labels):\n", " # 为了在取对数时进一步稳定该值,将小于1的值设置为1\n", " clipped_preds = torch.clamp(net(features), 1, float('inf'))\n", " rmse = torch.sqrt(loss(torch.log(clipped_preds),\n", " torch.log(labels)))\n", " return rmse.item()" ] }, { "cell_type": "markdown", "id": "ff00e6e9", "metadata": { "origin_pos": 36 }, "source": [ "与前面的部分不同,[**我们的训练函数将借助Adam优化器**]\n", "(我们将在后面章节更详细地描述它)。\n", "Adam优化器的主要吸引力在于它对初始学习率不那么敏感。\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "e2761591", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:58:34.007489Z", "iopub.status.busy": "2023-08-18T06:58:34.004586Z", "iopub.status.idle": "2023-08-18T06:58:34.017214Z", "shell.execute_reply": "2023-08-18T06:58:34.016158Z" }, "origin_pos": 38, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def train(net, train_features, train_labels, test_features, test_labels,\n", " num_epochs, learning_rate, weight_decay, batch_size):\n", " train_ls, test_ls = [], []\n", " train_iter = d2l.load_array((train_features, train_labels), batch_size)\n", " # 这里使用的是Adam优化算法\n", " optimizer = torch.optim.Adam(net.parameters(),\n", " lr = learning_rate,\n", " weight_decay = weight_decay)\n", " for epoch in range(num_epochs):\n", " for X, y in train_iter:\n", " optimizer.zero_grad()\n", " l = loss(net(X), y)\n", " l.backward()\n", " optimizer.step()\n", " train_ls.append(log_rmse(net, train_features, train_labels))\n", " if test_labels is not None:\n", " test_ls.append(log_rmse(net, test_features, test_labels))\n", " return train_ls, test_ls" ] }, { "cell_type": "markdown", "id": "b81580ed", "metadata": { "origin_pos": 41 }, "source": [ "## $K$折交叉验证\n", "\n", "本书在讨论模型选择的部分( :numref:`sec_model_selection`)\n", "中介绍了[**K折交叉验证**],\n", "它有助于模型选择和超参数调整。\n", "我们首先需要定义一个函数,在$K$折交叉验证过程中返回第$i$折的数据。\n", "具体地说,它选择第$i$个切片作为验证数据,其余部分作为训练数据。\n", "注意,这并不是处理数据的最有效方法,如果我们的数据集大得多,会有其他解决办法。\n" ] }, { "cell_type": "code", "execution_count": 16, "id": "93fbda31", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:58:34.025514Z", "iopub.status.busy": "2023-08-18T06:58:34.023645Z", "iopub.status.idle": "2023-08-18T06:58:34.035607Z", "shell.execute_reply": "2023-08-18T06:58:34.034655Z" }, "origin_pos": 42, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def get_k_fold_data(k, i, X, y):\n", " assert k > 1\n", " fold_size = X.shape[0] // k\n", " X_train, y_train = None, None\n", " for j in range(k):\n", " idx = slice(j * fold_size, (j + 1) * fold_size)\n", " X_part, y_part = X[idx, :], y[idx]\n", " if j == i:\n", " X_valid, y_valid = X_part, y_part\n", " elif X_train is None:\n", " X_train, y_train = X_part, y_part\n", " else:\n", " X_train = torch.cat([X_train, X_part], 0)\n", " y_train = torch.cat([y_train, y_part], 0)\n", " return X_train, y_train, X_valid, y_valid" ] }, { "cell_type": "markdown", "id": "fbc65cf8", "metadata": { "origin_pos": 43 }, "source": [ "当我们在$K$折交叉验证中训练$K$次后,[**返回训练和验证误差的平均值**]。\n" ] }, { "cell_type": "code", "execution_count": 17, "id": "8da46520", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:58:34.043961Z", "iopub.status.busy": "2023-08-18T06:58:34.041184Z", "iopub.status.idle": "2023-08-18T06:58:34.055008Z", "shell.execute_reply": "2023-08-18T06:58:34.054026Z" }, "origin_pos": 44, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def k_fold(k, X_train, y_train, num_epochs, learning_rate, weight_decay,\n", " batch_size):\n", " train_l_sum, valid_l_sum = 0, 0\n", " for i in range(k):\n", " data = get_k_fold_data(k, i, X_train, y_train)\n", " net = get_net()\n", " train_ls, valid_ls = train(net, *data, num_epochs, learning_rate,\n", " weight_decay, batch_size)\n", " train_l_sum += train_ls[-1]\n", " valid_l_sum += valid_ls[-1]\n", " if i == 0:\n", " d2l.plot(list(range(1, num_epochs + 1)), [train_ls, valid_ls],\n", " xlabel='epoch', ylabel='rmse', xlim=[1, num_epochs],\n", " legend=['train', 'valid'], yscale='log')\n", " print(f'折{i + 1},训练log rmse{float(train_ls[-1]):f}, '\n", " f'验证log rmse{float(valid_ls[-1]):f}')\n", " return train_l_sum / k, valid_l_sum / k" ] }, { "cell_type": "markdown", "id": "6bd41791", "metadata": { "origin_pos": 45 }, "source": [ "## [**模型选择**]\n", "\n", "在本例中,我们选择了一组未调优的超参数,并将其留给读者来改进模型。\n", "找到一组调优的超参数可能需要时间,这取决于一个人优化了多少变量。\n", "有了足够大的数据集和合理设置的超参数,$K$折交叉验证往往对多次测试具有相当的稳定性。\n", "然而,如果我们尝试了不合理的超参数,我们可能会发现验证效果不再代表真正的误差。\n" ] }, { "cell_type": "code", "execution_count": 18, "id": "0ceb952f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:58:34.062374Z", "iopub.status.busy": "2023-08-18T06:58:34.060507Z", "iopub.status.idle": "2023-08-18T06:58:48.772917Z", "shell.execute_reply": "2023-08-18T06:58:48.771939Z" }, "origin_pos": 46, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "折1,训练log rmse0.170212, 验证log rmse0.156864\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "折2,训练log rmse0.162003, 验证log rmse0.188812\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "折3,训练log rmse0.163810, 验证log rmse0.168171\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "折4,训练log rmse0.167946, 验证log rmse0.154694\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "折5,训练log rmse0.163320, 验证log rmse0.182928\n", "5-折验证: 平均训练log rmse: 0.165458, 平均验证log rmse: 0.170293\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T06:58:48.714952\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "k, num_epochs, lr, weight_decay, batch_size = 5, 100, 5, 0, 64\n", "train_l, valid_l = k_fold(k, train_features, train_labels, num_epochs, lr,\n", " weight_decay, batch_size)\n", "print(f'{k}-折验证: 平均训练log rmse: {float(train_l):f}, '\n", " f'平均验证log rmse: {float(valid_l):f}')" ] }, { "cell_type": "markdown", "id": "4e418fd3", "metadata": { "origin_pos": 47 }, "source": [ "请注意,有时一组超参数的训练误差可能非常低,但$K$折交叉验证的误差要高得多,\n", "这表明模型过拟合了。\n", "在整个训练过程中,我们希望监控训练误差和验证误差这两个数字。\n", "较少的过拟合可能表明现有数据可以支撑一个更强大的模型,\n", "较大的过拟合可能意味着我们可以通过正则化技术来获益。\n", "\n", "## [**提交Kaggle预测**]\n", "\n", "既然我们知道应该选择什么样的超参数,\n", "我们不妨使用所有数据对其进行训练\n", "(而不是仅使用交叉验证中使用的$1-1/K$的数据)。\n", "然后,我们通过这种方式获得的模型可以应用于测试集。\n", "将预测保存在CSV文件中可以简化将结果上传到Kaggle的过程。\n" ] }, { "cell_type": "code", "execution_count": 19, "id": "568e9ca5", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:58:48.777939Z", "iopub.status.busy": "2023-08-18T06:58:48.777525Z", "iopub.status.idle": "2023-08-18T06:58:48.787742Z", "shell.execute_reply": "2023-08-18T06:58:48.786661Z" }, "origin_pos": 48, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def train_and_pred(train_features, test_features, train_labels, test_data,\n", " num_epochs, lr, weight_decay, batch_size):\n", " net = get_net()\n", " train_ls, _ = train(net, train_features, train_labels, None, None,\n", " num_epochs, lr, weight_decay, batch_size)\n", " d2l.plot(np.arange(1, num_epochs + 1), [train_ls], xlabel='epoch',\n", " ylabel='log rmse', xlim=[1, num_epochs], yscale='log')\n", " print(f'训练log rmse:{float(train_ls[-1]):f}')\n", " # 将网络应用于测试集。\n", " preds = net(test_features).detach().numpy()\n", " # 将其重新格式化以导出到Kaggle\n", " test_data['SalePrice'] = pd.Series(preds.reshape(1, -1)[0])\n", " submission = pd.concat([test_data['Id'], test_data['SalePrice']], axis=1)\n", " submission.to_csv('submission.csv', index=False)" ] }, { "cell_type": "markdown", "id": "f311cb6c", "metadata": { "origin_pos": 49 }, "source": [ "如果测试集上的预测与$K$倍交叉验证过程中的预测相似,\n", "那就是时候把它们上传到Kaggle了。\n", "下面的代码将生成一个名为`submission.csv`的文件。\n" ] }, { "cell_type": "code", "execution_count": 20, "id": "7fd14d5d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T06:58:48.792405Z", "iopub.status.busy": "2023-08-18T06:58:48.792013Z", "iopub.status.idle": "2023-08-18T06:58:52.795733Z", "shell.execute_reply": "2023-08-18T06:58:52.794625Z" }, "origin_pos": 50, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "训练log rmse:0.162354\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T06:58:52.719276\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "train_and_pred(train_features, test_features, train_labels, test_data,\n", " num_epochs, lr, weight_decay, batch_size)" ] }, { "cell_type": "markdown", "id": "5efcf208", "metadata": { "origin_pos": 51 }, "source": [ "接下来,如 :numref:`fig_kaggle_submit2`中所示,\n", "我们可以提交预测到Kaggle上,并查看在测试集上的预测与实际房价(标签)的比较情况。\n", "步骤非常简单。\n", "\n", "* 登录Kaggle网站,访问房价预测竞赛页面。\n", "* 点击“Submit Predictions”或“Late Submission”按钮(在撰写本文时,该按钮位于右侧)。\n", "* 点击页面底部虚线框中的“Upload Submission File”按钮,选择要上传的预测文件。\n", "* 点击页面底部的“Make Submission”按钮,即可查看结果。\n", "\n", "![向Kaggle提交数据](../img/kaggle-submit2.png)\n", ":width:`400px`\n", ":label:`fig_kaggle_submit2`\n", "\n", "## 小结\n", "\n", "* 真实数据通常混合了不同的数据类型,需要进行预处理。\n", "* 常用的预处理方法:将实值数据重新缩放为零均值和单位方法;用均值替换缺失值。\n", "* 将类别特征转化为指标特征,可以使我们把这个特征当作一个独热向量来对待。\n", "* 我们可以使用$K$折交叉验证来选择模型并调整超参数。\n", "* 对数对于相对误差很有用。\n", "\n", "## 练习\n", "\n", "1. 把预测提交给Kaggle,它有多好?\n", "1. 能通过直接最小化价格的对数来改进模型吗?如果试图预测价格的对数而不是价格,会发生什么?\n", "1. 用平均值替换缺失值总是好主意吗?提示:能构造一个不随机丢失值的情况吗?\n", "1. 通过$K$折交叉验证调整超参数,从而提高Kaggle的得分。\n", "1. 通过改进模型(例如,层、权重衰减和dropout)来提高分数。\n", "1. 如果我们没有像本节所做的那样标准化连续的数值特征,会发生什么?\n" ] }, { "cell_type": "markdown", "id": "5198d6f0", "metadata": { "origin_pos": 53, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/1824)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }