{
"cells": [
{
"cell_type": "markdown",
"id": "937f9e37",
"metadata": {
"origin_pos": 0
},
"source": [
"# 实战Kaggle比赛:预测房价\n",
":label:`sec_kaggle_house`\n",
"\n",
"之前几节我们学习了一些训练深度网络的基本工具和网络正则化的技术(如权重衰减、暂退法等)。\n",
"本节我们将通过Kaggle比赛,将所学知识付诸实践。\n",
"Kaggle的房价预测比赛是一个很好的起点。\n",
"此数据集由Bart de Cock于2011年收集 :cite:`De-Cock.2011`\n",
"涵盖了2006-2010年期间亚利桑那州埃姆斯市的房价。\n",
"这个数据集是相当通用的,不会需要使用复杂模型架构。\n",
"它比哈里森和鲁宾菲尔德的[波士顿房价](https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.names)\n",
"数据集要大得多,也有更多的特征。\n",
"\n",
"本节我们将详细介绍数据预处理、模型设计和超参数选择。\n",
"通过亲身实践,你将获得一手经验,这些经验将有益数据科学家的职业成长。\n",
"\n",
"## 下载和缓存数据集\n",
"\n",
"在整本书中,我们将下载不同的数据集,并训练和测试模型。\n",
"这里我们(**实现几个函数来方便下载数据**)。\n",
"首先,我们建立字典`DATA_HUB`\n",
"它可以将数据集名称的字符串映射到数据集相关的二元组上,\n",
"这个二元组包含数据集的url和验证文件完整性的sha-1密钥。\n",
"所有类似的数据集都托管在地址为`DATA_URL`的站点上。\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "734593b0",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:30.249795Z",
"iopub.status.busy": "2023-08-18T06:58:30.249006Z",
"iopub.status.idle": "2023-08-18T06:58:30.344738Z",
"shell.execute_reply": "2023-08-18T06:58:30.343588Z"
},
"origin_pos": 1,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import hashlib\n",
"import os\n",
"import tarfile\n",
"import zipfile\n",
"import requests\n",
"\n",
"#@save\n",
"DATA_HUB = dict()\n",
"DATA_URL = 'http://d2l-data.s3-accelerate.amazonaws.com/'"
]
},
{
"cell_type": "markdown",
"id": "1ea53b4f",
"metadata": {
"origin_pos": 2
},
"source": [
"下面的`download`函数用来下载数据集,\n",
"将数据集缓存在本地目录(默认情况下为`../data`)中,\n",
"并返回下载文件的名称。\n",
"如果缓存目录中已经存在此数据集文件,并且其sha-1与存储在`DATA_HUB`中的相匹配,\n",
"我们将使用缓存的文件,以避免重复的下载。\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "276702a6",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:30.350368Z",
"iopub.status.busy": "2023-08-18T06:58:30.349724Z",
"iopub.status.idle": "2023-08-18T06:58:30.361205Z",
"shell.execute_reply": "2023-08-18T06:58:30.360058Z"
},
"origin_pos": 3,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def download(name, cache_dir=os.path.join('..', 'data')): #@save\n",
" \"\"\"下载一个DATA_HUB中的文件,返回本地文件名\"\"\"\n",
" assert name in DATA_HUB, f\"{name} 不存在于 {DATA_HUB}\"\n",
" url, sha1_hash = DATA_HUB[name]\n",
" os.makedirs(cache_dir, exist_ok=True)\n",
" fname = os.path.join(cache_dir, url.split('/')[-1])\n",
" if os.path.exists(fname):\n",
" sha1 = hashlib.sha1()\n",
" with open(fname, 'rb') as f:\n",
" while True:\n",
" data = f.read(1048576)\n",
" if not data:\n",
" break\n",
" sha1.update(data)\n",
" if sha1.hexdigest() == sha1_hash:\n",
" return fname # 命中缓存\n",
" print(f'正在从{url}下载{fname}...')\n",
" r = requests.get(url, stream=True, verify=True)\n",
" with open(fname, 'wb') as f:\n",
" f.write(r.content)\n",
" return fname"
]
},
{
"cell_type": "markdown",
"id": "8ee59036",
"metadata": {
"origin_pos": 4
},
"source": [
"我们还需实现两个实用函数:\n",
"一个将下载并解压缩一个zip或tar文件,\n",
"另一个是将本书中使用的所有数据集从`DATA_HUB`下载到缓存目录中。\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "42ad8efa",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:30.366317Z",
"iopub.status.busy": "2023-08-18T06:58:30.365422Z",
"iopub.status.idle": "2023-08-18T06:58:30.374280Z",
"shell.execute_reply": "2023-08-18T06:58:30.373220Z"
},
"origin_pos": 5,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def download_extract(name, folder=None): #@save\n",
" \"\"\"下载并解压zip/tar文件\"\"\"\n",
" fname = download(name)\n",
" base_dir = os.path.dirname(fname)\n",
" data_dir, ext = os.path.splitext(fname)\n",
" if ext == '.zip':\n",
" fp = zipfile.ZipFile(fname, 'r')\n",
" elif ext in ('.tar', '.gz'):\n",
" fp = tarfile.open(fname, 'r')\n",
" else:\n",
" assert False, '只有zip/tar文件可以被解压缩'\n",
" fp.extractall(base_dir)\n",
" return os.path.join(base_dir, folder) if folder else data_dir\n",
"\n",
"def download_all(): #@save\n",
" \"\"\"下载DATA_HUB中的所有文件\"\"\"\n",
" for name in DATA_HUB:\n",
" download(name)"
]
},
{
"cell_type": "markdown",
"id": "9ff53967",
"metadata": {
"origin_pos": 6
},
"source": [
"## Kaggle\n",
"\n",
"[Kaggle](https://www.kaggle.com)是一个当今流行举办机器学习比赛的平台,\n",
"每场比赛都以至少一个数据集为中心。\n",
"许多比赛有赞助方,他们为获胜的解决方案提供奖金。\n",
"该平台帮助用户通过论坛和共享代码进行互动,促进协作和竞争。\n",
"虽然排行榜的追逐往往令人失去理智:\n",
"有些研究人员短视地专注于预处理步骤,而不是考虑基础性问题。\n",
"但一个客观的平台有巨大的价值:该平台促进了竞争方法之间的直接定量比较,以及代码共享。\n",
"这便于每个人都可以学习哪些方法起作用,哪些没有起作用。\n",
"如果我们想参加Kaggle比赛,首先需要注册一个账户(见 :numref:`fig_kaggle`)。\n",
"\n",
"![Kaggle网站](../img/kaggle.png)\n",
":width:`400px`\n",
":label:`fig_kaggle`\n",
"\n",
"在房价预测比赛页面(如 :numref:`fig_house_pricing` 所示)的\"Data\"选项卡下可以找到数据集。我们可以通过下面的网址提交预测,并查看排名:\n",
"\n",
">https://www.kaggle.com/c/house-prices-advanced-regression-techniques\n",
"\n",
"![房价预测比赛页面](../img/house-pricing.png)\n",
":width:`400px`\n",
":label:`fig_house_pricing`\n",
"\n",
"## 访问和读取数据集\n",
"\n",
"注意,竞赛数据分为训练集和测试集。\n",
"每条记录都包括房屋的属性值和属性,如街道类型、施工年份、屋顶类型、地下室状况等。\n",
"这些特征由各种数据类型组成。\n",
"例如,建筑年份由整数表示,屋顶类型由离散类别表示,其他特征由浮点数表示。\n",
"这就是现实让事情变得复杂的地方:例如,一些数据完全丢失了,缺失值被简单地标记为“NA”。\n",
"每套房子的价格只出现在训练集中(毕竟这是一场比赛)。\n",
"我们将希望划分训练集以创建验证集,但是在将预测结果上传到Kaggle之后,\n",
"我们只能在官方测试集中评估我们的模型。\n",
"在 :numref:`fig_house_pricing` 中,\"Data\"选项卡有下载数据的链接。\n",
"\n",
"开始之前,我们将[**使用`pandas`读入并处理数据**]\n",
"这是我们在 :numref:`sec_pandas`中引入的。\n",
"因此,在继续操作之前,我们需要确保已安装`pandas`。\n",
"幸运的是,如果我们正在用Jupyter阅读该书,可以在不离开笔记本的情况下安装`pandas`。\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "66e7e040",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:30.379485Z",
"iopub.status.busy": "2023-08-18T06:58:30.378590Z",
"iopub.status.idle": "2023-08-18T06:58:33.390405Z",
"shell.execute_reply": "2023-08-18T06:58:33.389064Z"
},
"origin_pos": 8,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"# 如果没有安装pandas,请取消下一行的注释\n",
"# !pip install pandas\n",
"\n",
"%matplotlib inline\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "144df29b",
"metadata": {
"origin_pos": 11
},
"source": [
"为方便起见,我们可以使用上面定义的脚本下载并缓存Kaggle房屋数据集。\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ea733544",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:33.396895Z",
"iopub.status.busy": "2023-08-18T06:58:33.395535Z",
"iopub.status.idle": "2023-08-18T06:58:33.402172Z",
"shell.execute_reply": "2023-08-18T06:58:33.400982Z"
},
"origin_pos": 12,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"DATA_HUB['kaggle_house_train'] = ( #@save\n",
" DATA_URL + 'kaggle_house_pred_train.csv',\n",
" '585e9cc93e70b39160e7921475f9bcd7d31219ce')\n",
"\n",
"DATA_HUB['kaggle_house_test'] = ( #@save\n",
" DATA_URL + 'kaggle_house_pred_test.csv',\n",
" 'fa19780a7b011d9b009e8bff8e99922a8ee2eb90')"
]
},
{
"cell_type": "markdown",
"id": "d7ed71e0",
"metadata": {
"origin_pos": 13
},
"source": [
"我们使用`pandas`分别加载包含训练数据和测试数据的两个CSV文件。\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "4928df7b",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:33.407201Z",
"iopub.status.busy": "2023-08-18T06:58:33.406476Z",
"iopub.status.idle": "2023-08-18T06:58:33.710870Z",
"shell.execute_reply": "2023-08-18T06:58:33.709609Z"
},
"origin_pos": 14,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"正在从http://d2l-data.s3-accelerate.amazonaws.com/kaggle_house_pred_train.csv下载../data/kaggle_house_pred_train.csv...\n",
"正在从http://d2l-data.s3-accelerate.amazonaws.com/kaggle_house_pred_test.csv下载../data/kaggle_house_pred_test.csv...\n"
]
}
],
"source": [
"train_data = pd.read_csv(download('kaggle_house_train'))\n",
"test_data = pd.read_csv(download('kaggle_house_test'))"
]
},
{
"cell_type": "markdown",
"id": "ac22b4c8",
"metadata": {
"origin_pos": 15
},
"source": [
"训练数据集包括1460个样本,每个样本80个特征和1个标签,\n",
"而测试数据集包含1459个样本,每个样本80个特征。\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "55aee9f2",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:33.716674Z",
"iopub.status.busy": "2023-08-18T06:58:33.715846Z",
"iopub.status.idle": "2023-08-18T06:58:33.722539Z",
"shell.execute_reply": "2023-08-18T06:58:33.721369Z"
},
"origin_pos": 16,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1460, 81)\n",
"(1459, 80)\n"
]
}
],
"source": [
"print(train_data.shape)\n",
"print(test_data.shape)"
]
},
{
"cell_type": "markdown",
"id": "6a2b927b",
"metadata": {
"origin_pos": 17
},
"source": [
"让我们看看[**前四个和最后两个特征,以及相应标签**](房价)。\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "cb459c3d",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:33.727643Z",
"iopub.status.busy": "2023-08-18T06:58:33.726910Z",
"iopub.status.idle": "2023-08-18T06:58:33.741457Z",
"shell.execute_reply": "2023-08-18T06:58:33.740293Z"
},
"origin_pos": 18,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Id MSSubClass MSZoning LotFrontage SaleType SaleCondition SalePrice\n",
"0 1 60 RL 65.0 WD Normal 208500\n",
"1 2 20 RL 80.0 WD Normal 181500\n",
"2 3 60 RL 68.0 WD Normal 223500\n",
"3 4 70 RL 60.0 WD Abnorml 140000\n"
]
}
],
"source": [
"print(train_data.iloc[0:4, [0, 1, 2, 3, -3, -2, -1]])"
]
},
{
"cell_type": "markdown",
"id": "9e8244c9",
"metadata": {
"origin_pos": 19
},
"source": [
"我们可以看到,(**在每个样本中,第一个特征是ID**)\n",
"这有助于模型识别每个训练样本。\n",
"虽然这很方便,但它不携带任何用于预测的信息。\n",
"因此,在将数据提供给模型之前,(**我们将其从数据集中删除**)。\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "fe5338aa",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:33.746380Z",
"iopub.status.busy": "2023-08-18T06:58:33.745604Z",
"iopub.status.idle": "2023-08-18T06:58:33.773972Z",
"shell.execute_reply": "2023-08-18T06:58:33.772656Z"
},
"origin_pos": 20,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"all_features = pd.concat((train_data.iloc[:, 1:-1], test_data.iloc[:, 1:]))"
]
},
{
"cell_type": "markdown",
"id": "51dbfdc0",
"metadata": {
"origin_pos": 21
},
"source": [
"## 数据预处理\n",
"\n",
"如上所述,我们有各种各样的数据类型。\n",
"在开始建模之前,我们需要对数据进行预处理。\n",
"首先,我们[**将所有缺失的值替换为相应特征的平均值。**]然后,为了将所有特征放在一个共同的尺度上,\n",
"我们(**通过将特征重新缩放到零均值和单位方差来标准化数据**):\n",
"\n",
"$$x \\leftarrow \\frac{x - \\mu}{\\sigma},$$\n",
"\n",
"其中$\\mu$和$\\sigma$分别表示均值和标准差。\n",
"现在,这些特征具有零均值和单位方差,即 $E[\\frac{x-\\mu}{\\sigma}] = \\frac{\\mu - \\mu}{\\sigma} = 0$和$E[(x-\\mu)^2] = (\\sigma^2 + \\mu^2) - 2\\mu^2+\\mu^2 = \\sigma^2$。\n",
"直观地说,我们标准化数据有两个原因:\n",
"首先,它方便优化。\n",
"其次,因为我们不知道哪些特征是相关的,\n",
"所以我们不想让惩罚分配给一个特征的系数比分配给其他任何特征的系数更大。\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "ae337076",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:33.779169Z",
"iopub.status.busy": "2023-08-18T06:58:33.778411Z",
"iopub.status.idle": "2023-08-18T06:58:33.856298Z",
"shell.execute_reply": "2023-08-18T06:58:33.855062Z"
},
"origin_pos": 22,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"# 若无法获得测试数据,则可根据训练数据计算均值和标准差\n",
"numeric_features = all_features.dtypes[all_features.dtypes != 'object'].index\n",
"all_features[numeric_features] = all_features[numeric_features].apply(\n",
" lambda x: (x - x.mean()) / (x.std()))\n",
"# 在标准化数据之后,所有均值消失,因此我们可以将缺失值设置为0\n",
"all_features[numeric_features] = all_features[numeric_features].fillna(0)"
]
},
{
"cell_type": "markdown",
"id": "149f5aa7",
"metadata": {
"origin_pos": 23
},
"source": [
"接下来,我们[**处理离散值。**]\n",
"这包括诸如“MSZoning”之类的特征。\n",
"(**我们用独热编码替换它们**)\n",
"方法与前面将多类别标签转换为向量的方式相同\n",
"(请参见 :numref:`subsec_classification-problem`)。\n",
"例如,“MSZoning”包含值“RL”和“Rm”。\n",
"我们将创建两个新的指示器特征“MSZoning_RL”和“MSZoning_RM”,其值为0或1。\n",
"根据独热编码,如果“MSZoning”的原始值为“RL”,\n",
"则:“MSZoning_RL”为1,“MSZoning_RM”为0。\n",
"`pandas`软件包会自动为我们实现这一点。\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "73804c29",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:33.861948Z",
"iopub.status.busy": "2023-08-18T06:58:33.861162Z",
"iopub.status.idle": "2023-08-18T06:58:33.936809Z",
"shell.execute_reply": "2023-08-18T06:58:33.935956Z"
},
"origin_pos": 24,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"(2919, 331)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# “Dummy_na=True”将“na”(缺失值)视为有效的特征值,并为其创建指示符特征\n",
"all_features = pd.get_dummies(all_features, dummy_na=True)\n",
"all_features.shape"
]
},
{
"cell_type": "markdown",
"id": "c2df3949",
"metadata": {
"origin_pos": 25
},
"source": [
"可以看到此转换会将特征的总数量从79个增加到331个。\n",
"最后,通过`values`属性,我们可以\n",
"[**从`pandas`格式中提取NumPy格式,并将其转换为张量表示**]用于训练。\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "2e73c9b7",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:33.941143Z",
"iopub.status.busy": "2023-08-18T06:58:33.940251Z",
"iopub.status.idle": "2023-08-18T06:58:33.968351Z",
"shell.execute_reply": "2023-08-18T06:58:33.967159Z"
},
"origin_pos": 26,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"n_train = train_data.shape[0]\n",
"train_features = torch.tensor(all_features[:n_train].values, dtype=torch.float32)\n",
"test_features = torch.tensor(all_features[n_train:].values, dtype=torch.float32)\n",
"train_labels = torch.tensor(\n",
" train_data.SalePrice.values.reshape(-1, 1), dtype=torch.float32)"
]
},
{
"cell_type": "markdown",
"id": "2b949329",
"metadata": {
"origin_pos": 27
},
"source": [
"## [**训练**]\n",
"\n",
"首先,我们训练一个带有损失平方的线性模型。\n",
"显然线性模型很难让我们在竞赛中获胜,但线性模型提供了一种健全性检查,\n",
"以查看数据中是否存在有意义的信息。\n",
"如果我们在这里不能做得比随机猜测更好,那么我们很可能存在数据处理错误。\n",
"如果一切顺利,线性模型将作为*基线*(baseline)模型,\n",
"让我们直观地知道最好的模型有超出简单的模型多少。\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "4e16c1dc",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:33.976812Z",
"iopub.status.busy": "2023-08-18T06:58:33.974995Z",
"iopub.status.idle": "2023-08-18T06:58:33.984092Z",
"shell.execute_reply": "2023-08-18T06:58:33.983132Z"
},
"origin_pos": 29,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"loss = nn.MSELoss()\n",
"in_features = train_features.shape[1]\n",
"\n",
"def get_net():\n",
" net = nn.Sequential(nn.Linear(in_features,1))\n",
" return net"
]
},
{
"cell_type": "markdown",
"id": "c3e5f9ca",
"metadata": {
"origin_pos": 31
},
"source": [
"房价就像股票价格一样,我们关心的是相对数量,而不是绝对数量。\n",
"因此,[**我们更关心相对误差$\\frac{y - \\hat{y}}{y}$**]\n",
"而不是绝对误差$y - \\hat{y}$。\n",
"例如,如果我们在俄亥俄州农村地区估计一栋房子的价格时,\n",
"假设我们的预测偏差了10万美元,\n",
"然而那里一栋典型的房子的价值是12.5万美元,\n",
"那么模型可能做得很糟糕。\n",
"另一方面,如果我们在加州豪宅区的预测出现同样的10万美元的偏差,\n",
"(在那里,房价中位数超过400万美元)\n",
"这可能是一个不错的预测。\n",
"\n",
"(**解决这个问题的一种方法是用价格预测的对数来衡量差异**)。\n",
"事实上,这也是比赛中官方用来评价提交质量的误差指标。\n",
"即将$\\delta$ for $|\\log y - \\log \\hat{y}| \\leq \\delta$\n",
"转换为$e^{-\\delta} \\leq \\frac{\\hat{y}}{y} \\leq e^\\delta$。\n",
"这使得预测价格的对数与真实标签价格的对数之间出现以下均方根误差:\n",
"\n",
"$$\\sqrt{\\frac{1}{n}\\sum_{i=1}^n\\left(\\log y_i -\\log \\hat{y}_i\\right)^2}.$$\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "ffbf5478",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:33.991464Z",
"iopub.status.busy": "2023-08-18T06:58:33.989673Z",
"iopub.status.idle": "2023-08-18T06:58:33.999109Z",
"shell.execute_reply": "2023-08-18T06:58:33.998094Z"
},
"origin_pos": 33,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def log_rmse(net, features, labels):\n",
" # 为了在取对数时进一步稳定该值,将小于1的值设置为1\n",
" clipped_preds = torch.clamp(net(features), 1, float('inf'))\n",
" rmse = torch.sqrt(loss(torch.log(clipped_preds),\n",
" torch.log(labels)))\n",
" return rmse.item()"
]
},
{
"cell_type": "markdown",
"id": "ff00e6e9",
"metadata": {
"origin_pos": 36
},
"source": [
"与前面的部分不同,[**我们的训练函数将借助Adam优化器**]\n",
"(我们将在后面章节更详细地描述它)。\n",
"Adam优化器的主要吸引力在于它对初始学习率不那么敏感。\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "e2761591",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:34.007489Z",
"iopub.status.busy": "2023-08-18T06:58:34.004586Z",
"iopub.status.idle": "2023-08-18T06:58:34.017214Z",
"shell.execute_reply": "2023-08-18T06:58:34.016158Z"
},
"origin_pos": 38,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def train(net, train_features, train_labels, test_features, test_labels,\n",
" num_epochs, learning_rate, weight_decay, batch_size):\n",
" train_ls, test_ls = [], []\n",
" train_iter = d2l.load_array((train_features, train_labels), batch_size)\n",
" # 这里使用的是Adam优化算法\n",
" optimizer = torch.optim.Adam(net.parameters(),\n",
" lr = learning_rate,\n",
" weight_decay = weight_decay)\n",
" for epoch in range(num_epochs):\n",
" for X, y in train_iter:\n",
" optimizer.zero_grad()\n",
" l = loss(net(X), y)\n",
" l.backward()\n",
" optimizer.step()\n",
" train_ls.append(log_rmse(net, train_features, train_labels))\n",
" if test_labels is not None:\n",
" test_ls.append(log_rmse(net, test_features, test_labels))\n",
" return train_ls, test_ls"
]
},
{
"cell_type": "markdown",
"id": "b81580ed",
"metadata": {
"origin_pos": 41
},
"source": [
"## $K$折交叉验证\n",
"\n",
"本书在讨论模型选择的部分( :numref:`sec_model_selection`\n",
"中介绍了[**K折交叉验证**]\n",
"它有助于模型选择和超参数调整。\n",
"我们首先需要定义一个函数,在$K$折交叉验证过程中返回第$i$折的数据。\n",
"具体地说,它选择第$i$个切片作为验证数据,其余部分作为训练数据。\n",
"注意,这并不是处理数据的最有效方法,如果我们的数据集大得多,会有其他解决办法。\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "93fbda31",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:34.025514Z",
"iopub.status.busy": "2023-08-18T06:58:34.023645Z",
"iopub.status.idle": "2023-08-18T06:58:34.035607Z",
"shell.execute_reply": "2023-08-18T06:58:34.034655Z"
},
"origin_pos": 42,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def get_k_fold_data(k, i, X, y):\n",
" assert k > 1\n",
" fold_size = X.shape[0] // k\n",
" X_train, y_train = None, None\n",
" for j in range(k):\n",
" idx = slice(j * fold_size, (j + 1) * fold_size)\n",
" X_part, y_part = X[idx, :], y[idx]\n",
" if j == i:\n",
" X_valid, y_valid = X_part, y_part\n",
" elif X_train is None:\n",
" X_train, y_train = X_part, y_part\n",
" else:\n",
" X_train = torch.cat([X_train, X_part], 0)\n",
" y_train = torch.cat([y_train, y_part], 0)\n",
" return X_train, y_train, X_valid, y_valid"
]
},
{
"cell_type": "markdown",
"id": "fbc65cf8",
"metadata": {
"origin_pos": 43
},
"source": [
"当我们在$K$折交叉验证中训练$K$次后,[**返回训练和验证误差的平均值**]。\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "8da46520",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:34.043961Z",
"iopub.status.busy": "2023-08-18T06:58:34.041184Z",
"iopub.status.idle": "2023-08-18T06:58:34.055008Z",
"shell.execute_reply": "2023-08-18T06:58:34.054026Z"
},
"origin_pos": 44,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def k_fold(k, X_train, y_train, num_epochs, learning_rate, weight_decay,\n",
" batch_size):\n",
" train_l_sum, valid_l_sum = 0, 0\n",
" for i in range(k):\n",
" data = get_k_fold_data(k, i, X_train, y_train)\n",
" net = get_net()\n",
" train_ls, valid_ls = train(net, *data, num_epochs, learning_rate,\n",
" weight_decay, batch_size)\n",
" train_l_sum += train_ls[-1]\n",
" valid_l_sum += valid_ls[-1]\n",
" if i == 0:\n",
" d2l.plot(list(range(1, num_epochs + 1)), [train_ls, valid_ls],\n",
" xlabel='epoch', ylabel='rmse', xlim=[1, num_epochs],\n",
" legend=['train', 'valid'], yscale='log')\n",
" print(f'折{i + 1},训练log rmse{float(train_ls[-1]):f}, '\n",
" f'验证log rmse{float(valid_ls[-1]):f}')\n",
" return train_l_sum / k, valid_l_sum / k"
]
},
{
"cell_type": "markdown",
"id": "6bd41791",
"metadata": {
"origin_pos": 45
},
"source": [
"## [**模型选择**]\n",
"\n",
"在本例中,我们选择了一组未调优的超参数,并将其留给读者来改进模型。\n",
"找到一组调优的超参数可能需要时间,这取决于一个人优化了多少变量。\n",
"有了足够大的数据集和合理设置的超参数,$K$折交叉验证往往对多次测试具有相当的稳定性。\n",
"然而,如果我们尝试了不合理的超参数,我们可能会发现验证效果不再代表真正的误差。\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "0ceb952f",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:34.062374Z",
"iopub.status.busy": "2023-08-18T06:58:34.060507Z",
"iopub.status.idle": "2023-08-18T06:58:48.772917Z",
"shell.execute_reply": "2023-08-18T06:58:48.771939Z"
},
"origin_pos": 46,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"折1,训练log rmse0.170212, 验证log rmse0.156864\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"折2,训练log rmse0.162003, 验证log rmse0.188812\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"折3,训练log rmse0.163810, 验证log rmse0.168171\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"折4,训练log rmse0.167946, 验证log rmse0.154694\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"折5,训练log rmse0.163320, 验证log rmse0.182928\n",
"5-折验证: 平均训练log rmse: 0.165458, 平均验证log rmse: 0.170293\n"
]
},
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"257.521875pt\" height=\"180.65625pt\" viewBox=\"0 0 257.521875 180.65625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
" <metadata>\n",
" <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
" <cc:Work>\n",
" <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
" <dc:date>2023-08-18T06:58:48.714952</dc:date>\n",
" <dc:format>image/svg+xml</dc:format>\n",
" <dc:creator>\n",
" <cc:Agent>\n",
" <dc:title>Matplotlib v3.5.1, https://matplotlib.org/</dc:title>\n",
" </cc:Agent>\n",
" </dc:creator>\n",
" </cc:Work>\n",
" </rdf:RDF>\n",
" </metadata>\n",
" <defs>\n",
" <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
" </defs>\n",
" <g id=\"figure_1\">\n",
" <g id=\"patch_1\">\n",
" <path d=\"M 0 180.65625 \n",
"L 257.521875 180.65625 \n",
"L 257.521875 0 \n",
"L 0 0 \n",
"L 0 180.65625 \n",
"z\n",
"\" style=\"fill: none\"/>\n",
" </g>\n",
" <g id=\"axes_1\">\n",
" <g id=\"patch_2\">\n",
" <path d=\"M 45.478125 143.1 \n",
"L 240.778125 143.1 \n",
"L 240.778125 7.2 \n",
"L 45.478125 7.2 \n",
"z\n",
"\" style=\"fill: #ffffff\"/>\n",
" </g>\n",
" <g id=\"matplotlib.axis_1\">\n",
" <g id=\"xtick_1\">\n",
" <g id=\"line2d_1\">\n",
" <path d=\"M 82.959943 143.1 \n",
"L 82.959943 7.2 \n",
"\" clip-path=\"url(#pb1f4014886)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_2\">\n",
" <defs>\n",
" <path id=\"m847b688c52\" d=\"M 0 0 \n",
"L 0 3.5 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#m847b688c52\" x=\"82.959943\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_1\">\n",
" <!-- 20 -->\n",
" <g transform=\"translate(76.597443 157.698438)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
"L 3431 531 \n",
"L 3431 0 \n",
"L 469 0 \n",
"L 469 531 \n",
"Q 828 903 1448 1529 \n",
"Q 2069 2156 2228 2338 \n",
"Q 2531 2678 2651 2914 \n",
"Q 2772 3150 2772 3378 \n",
"Q 2772 3750 2511 3984 \n",
"Q 2250 4219 1831 4219 \n",
"Q 1534 4219 1204 4116 \n",
"Q 875 4013 500 3803 \n",
"L 500 4441 \n",
"Q 881 4594 1212 4672 \n",
"Q 1544 4750 1819 4750 \n",
"Q 2544 4750 2975 4387 \n",
"Q 3406 4025 3406 3419 \n",
"Q 3406 3131 3298 2873 \n",
"Q 3191 2616 2906 2266 \n",
"Q 2828 2175 2409 1742 \n",
"Q 1991 1309 1228 531 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
"Q 1547 4250 1301 3770 \n",
"Q 1056 3291 1056 2328 \n",
"Q 1056 1369 1301 889 \n",
"Q 1547 409 2034 409 \n",
"Q 2525 409 2770 889 \n",
"Q 3016 1369 3016 2328 \n",
"Q 3016 3291 2770 3770 \n",
"Q 2525 4250 2034 4250 \n",
"z\n",
"M 2034 4750 \n",
"Q 2819 4750 3233 4129 \n",
"Q 3647 3509 3647 2328 \n",
"Q 3647 1150 3233 529 \n",
"Q 2819 -91 2034 -91 \n",
"Q 1250 -91 836 529 \n",
"Q 422 1150 422 2328 \n",
"Q 422 3509 836 4129 \n",
"Q 1250 4750 2034 4750 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_2\">\n",
" <g id=\"line2d_3\">\n",
" <path d=\"M 122.414489 143.1 \n",
"L 122.414489 7.2 \n",
"\" clip-path=\"url(#pb1f4014886)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_4\">\n",
" <g>\n",
" <use xlink:href=\"#m847b688c52\" x=\"122.414489\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_2\">\n",
" <!-- 40 -->\n",
" <g transform=\"translate(116.051989 157.698438)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \n",
"L 825 1625 \n",
"L 2419 1625 \n",
"L 2419 4116 \n",
"z\n",
"M 2253 4666 \n",
"L 3047 4666 \n",
"L 3047 1625 \n",
"L 3713 1625 \n",
"L 3713 1100 \n",
"L 3047 1100 \n",
"L 3047 0 \n",
"L 2419 0 \n",
"L 2419 1100 \n",
"L 313 1100 \n",
"L 313 1709 \n",
"L 2253 4666 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-34\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_3\">\n",
" <g id=\"line2d_5\">\n",
" <path d=\"M 161.869034 143.1 \n",
"L 161.869034 7.2 \n",
"\" clip-path=\"url(#pb1f4014886)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_6\">\n",
" <g>\n",
" <use xlink:href=\"#m847b688c52\" x=\"161.869034\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_3\">\n",
" <!-- 60 -->\n",
" <g transform=\"translate(155.506534 157.698438)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-36\" d=\"M 2113 2584 \n",
"Q 1688 2584 1439 2293 \n",
"Q 1191 2003 1191 1497 \n",
"Q 1191 994 1439 701 \n",
"Q 1688 409 2113 409 \n",
"Q 2538 409 2786 701 \n",
"Q 3034 994 3034 1497 \n",
"Q 3034 2003 2786 2293 \n",
"Q 2538 2584 2113 2584 \n",
"z\n",
"M 3366 4563 \n",
"L 3366 3988 \n",
"Q 3128 4100 2886 4159 \n",
"Q 2644 4219 2406 4219 \n",
"Q 1781 4219 1451 3797 \n",
"Q 1122 3375 1075 2522 \n",
"Q 1259 2794 1537 2939 \n",
"Q 1816 3084 2150 3084 \n",
"Q 2853 3084 3261 2657 \n",
"Q 3669 2231 3669 1497 \n",
"Q 3669 778 3244 343 \n",
"Q 2819 -91 2113 -91 \n",
"Q 1303 -91 875 529 \n",
"Q 447 1150 447 2328 \n",
"Q 447 3434 972 4092 \n",
"Q 1497 4750 2381 4750 \n",
"Q 2619 4750 2861 4703 \n",
"Q 3103 4656 3366 4563 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-36\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_4\">\n",
" <g id=\"line2d_7\">\n",
" <path d=\"M 201.32358 143.1 \n",
"L 201.32358 7.2 \n",
"\" clip-path=\"url(#pb1f4014886)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_8\">\n",
" <g>\n",
" <use xlink:href=\"#m847b688c52\" x=\"201.32358\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_4\">\n",
" <!-- 80 -->\n",
" <g transform=\"translate(194.96108 157.698438)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-38\" d=\"M 2034 2216 \n",
"Q 1584 2216 1326 1975 \n",
"Q 1069 1734 1069 1313 \n",
"Q 1069 891 1326 650 \n",
"Q 1584 409 2034 409 \n",
"Q 2484 409 2743 651 \n",
"Q 3003 894 3003 1313 \n",
"Q 3003 1734 2745 1975 \n",
"Q 2488 2216 2034 2216 \n",
"z\n",
"M 1403 2484 \n",
"Q 997 2584 770 2862 \n",
"Q 544 3141 544 3541 \n",
"Q 544 4100 942 4425 \n",
"Q 1341 4750 2034 4750 \n",
"Q 2731 4750 3128 4425 \n",
"Q 3525 4100 3525 3541 \n",
"Q 3525 3141 3298 2862 \n",
"Q 3072 2584 2669 2484 \n",
"Q 3125 2378 3379 2068 \n",
"Q 3634 1759 3634 1313 \n",
"Q 3634 634 3220 271 \n",
"Q 2806 -91 2034 -91 \n",
"Q 1263 -91 848 271 \n",
"Q 434 634 434 1313 \n",
"Q 434 1759 690 2068 \n",
"Q 947 2378 1403 2484 \n",
"z\n",
"M 1172 3481 \n",
"Q 1172 3119 1398 2916 \n",
"Q 1625 2713 2034 2713 \n",
"Q 2441 2713 2670 2916 \n",
"Q 2900 3119 2900 3481 \n",
"Q 2900 3844 2670 4047 \n",
"Q 2441 4250 2034 4250 \n",
"Q 1625 4250 1398 4047 \n",
"Q 1172 3844 1172 3481 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-38\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_5\">\n",
" <g id=\"line2d_9\">\n",
" <path d=\"M 240.778125 143.1 \n",
"L 240.778125 7.2 \n",
"\" clip-path=\"url(#pb1f4014886)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_10\">\n",
" <g>\n",
" <use xlink:href=\"#m847b688c52\" x=\"240.778125\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_5\">\n",
" <!-- 100 -->\n",
" <g transform=\"translate(231.234375 157.698438)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
"L 1825 531 \n",
"L 1825 4091 \n",
"L 703 3866 \n",
"L 703 4441 \n",
"L 1819 4666 \n",
"L 2450 4666 \n",
"L 2450 531 \n",
"L 3481 531 \n",
"L 3481 0 \n",
"L 794 0 \n",
"L 794 531 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_6\">\n",
" <!-- epoch -->\n",
" <g transform=\"translate(127.9 171.376563)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-65\" d=\"M 3597 1894 \n",
"L 3597 1613 \n",
"L 953 1613 \n",
"Q 991 1019 1311 708 \n",
"Q 1631 397 2203 397 \n",
"Q 2534 397 2845 478 \n",
"Q 3156 559 3463 722 \n",
"L 3463 178 \n",
"Q 3153 47 2828 -22 \n",
"Q 2503 -91 2169 -91 \n",
"Q 1331 -91 842 396 \n",
"Q 353 884 353 1716 \n",
"Q 353 2575 817 3079 \n",
"Q 1281 3584 2069 3584 \n",
"Q 2775 3584 3186 3129 \n",
"Q 3597 2675 3597 1894 \n",
"z\n",
"M 3022 2063 \n",
"Q 3016 2534 2758 2815 \n",
"Q 2500 3097 2075 3097 \n",
"Q 1594 3097 1305 2825 \n",
"Q 1016 2553 972 2059 \n",
"L 3022 2063 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-70\" d=\"M 1159 525 \n",
"L 1159 -1331 \n",
"L 581 -1331 \n",
"L 581 3500 \n",
"L 1159 3500 \n",
"L 1159 2969 \n",
"Q 1341 3281 1617 3432 \n",
"Q 1894 3584 2278 3584 \n",
"Q 2916 3584 3314 3078 \n",
"Q 3713 2572 3713 1747 \n",
"Q 3713 922 3314 415 \n",
"Q 2916 -91 2278 -91 \n",
"Q 1894 -91 1617 61 \n",
"Q 1341 213 1159 525 \n",
"z\n",
"M 3116 1747 \n",
"Q 3116 2381 2855 2742 \n",
"Q 2594 3103 2138 3103 \n",
"Q 1681 3103 1420 2742 \n",
"Q 1159 2381 1159 1747 \n",
"Q 1159 1113 1420 752 \n",
"Q 1681 391 2138 391 \n",
"Q 2594 391 2855 752 \n",
"Q 3116 1113 3116 1747 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-6f\" d=\"M 1959 3097 \n",
"Q 1497 3097 1228 2736 \n",
"Q 959 2375 959 1747 \n",
"Q 959 1119 1226 758 \n",
"Q 1494 397 1959 397 \n",
"Q 2419 397 2687 759 \n",
"Q 2956 1122 2956 1747 \n",
"Q 2956 2369 2687 2733 \n",
"Q 2419 3097 1959 3097 \n",
"z\n",
"M 1959 3584 \n",
"Q 2709 3584 3137 3096 \n",
"Q 3566 2609 3566 1747 \n",
"Q 3566 888 3137 398 \n",
"Q 2709 -91 1959 -91 \n",
"Q 1206 -91 779 398 \n",
"Q 353 888 353 1747 \n",
"Q 353 2609 779 3096 \n",
"Q 1206 3584 1959 3584 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-63\" d=\"M 3122 3366 \n",
"L 3122 2828 \n",
"Q 2878 2963 2633 3030 \n",
"Q 2388 3097 2138 3097 \n",
"Q 1578 3097 1268 2742 \n",
"Q 959 2388 959 1747 \n",
"Q 959 1106 1268 751 \n",
"Q 1578 397 2138 397 \n",
"Q 2388 397 2633 464 \n",
"Q 2878 531 3122 666 \n",
"L 3122 134 \n",
"Q 2881 22 2623 -34 \n",
"Q 2366 -91 2075 -91 \n",
"Q 1284 -91 818 406 \n",
"Q 353 903 353 1747 \n",
"Q 353 2603 823 3093 \n",
"Q 1294 3584 2113 3584 \n",
"Q 2378 3584 2631 3529 \n",
"Q 2884 3475 3122 3366 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-68\" d=\"M 3513 2113 \n",
"L 3513 0 \n",
"L 2938 0 \n",
"L 2938 2094 \n",
"Q 2938 2591 2744 2837 \n",
"Q 2550 3084 2163 3084 \n",
"Q 1697 3084 1428 2787 \n",
"Q 1159 2491 1159 1978 \n",
"L 1159 0 \n",
"L 581 0 \n",
"L 581 4863 \n",
"L 1159 4863 \n",
"L 1159 2956 \n",
"Q 1366 3272 1645 3428 \n",
"Q 1925 3584 2291 3584 \n",
"Q 2894 3584 3203 3211 \n",
"Q 3513 2838 3513 2113 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-65\"/>\n",
" <use xlink:href=\"#DejaVuSans-70\" x=\"61.523438\"/>\n",
" <use xlink:href=\"#DejaVuSans-6f\" x=\"125\"/>\n",
" <use xlink:href=\"#DejaVuSans-63\" x=\"186.181641\"/>\n",
" <use xlink:href=\"#DejaVuSans-68\" x=\"241.162109\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"matplotlib.axis_2\">\n",
" <g id=\"ytick_1\">\n",
" <g id=\"line2d_11\">\n",
" <path d=\"M 45.478125 64.841953 \n",
"L 240.778125 64.841953 \n",
"\" clip-path=\"url(#pb1f4014886)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_12\">\n",
" <defs>\n",
" <path id=\"m45cb16a538\" d=\"M 0 0 \n",
"L -3.5 0 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#m45cb16a538\" x=\"45.478125\" y=\"64.841953\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_7\">\n",
" <!-- $\\mathdefault{10^{0}}$ -->\n",
" <g transform=\"translate(20.878125 68.641172)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-31\" transform=\"translate(0 0.765625)\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" transform=\"translate(63.623047 0.765625)\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" transform=\"translate(128.203125 39.046875)scale(0.7)\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_2\">\n",
" <g id=\"line2d_13\">\n",
" <defs>\n",
" <path id=\"mbdc64d3efd\" d=\"M 0 0 \n",
"L -2 0 \n",
"\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#mbdc64d3efd\" x=\"45.478125\" y=\"127.457042\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_3\">\n",
" <g id=\"line2d_14\">\n",
" <g>\n",
" <use xlink:href=\"#mbdc64d3efd\" x=\"45.478125\" y=\"111.682445\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_4\">\n",
" <g id=\"line2d_15\">\n",
" <g>\n",
" <use xlink:href=\"#mbdc64d3efd\" x=\"45.478125\" y=\"100.490191\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_5\">\n",
" <g id=\"line2d_16\">\n",
" <g>\n",
" <use xlink:href=\"#mbdc64d3efd\" x=\"45.478125\" y=\"91.808804\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_6\">\n",
" <g id=\"line2d_17\">\n",
" <g>\n",
" <use xlink:href=\"#mbdc64d3efd\" x=\"45.478125\" y=\"84.715594\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_7\">\n",
" <g id=\"line2d_18\">\n",
" <g>\n",
" <use xlink:href=\"#mbdc64d3efd\" x=\"45.478125\" y=\"78.718371\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_8\">\n",
" <g id=\"line2d_19\">\n",
" <g>\n",
" <use xlink:href=\"#mbdc64d3efd\" x=\"45.478125\" y=\"73.52334\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_9\">\n",
" <g id=\"line2d_20\">\n",
" <g>\n",
" <use xlink:href=\"#mbdc64d3efd\" x=\"45.478125\" y=\"68.940998\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_10\">\n",
" <g id=\"line2d_21\">\n",
" <g>\n",
" <use xlink:href=\"#mbdc64d3efd\" x=\"45.478125\" y=\"37.875103\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_11\">\n",
" <g id=\"line2d_22\">\n",
" <g>\n",
" <use xlink:href=\"#mbdc64d3efd\" x=\"45.478125\" y=\"22.100506\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_12\">\n",
" <g id=\"line2d_23\">\n",
" <g>\n",
" <use xlink:href=\"#mbdc64d3efd\" x=\"45.478125\" y=\"10.908252\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_8\">\n",
" <!-- rmse -->\n",
" <g transform=\"translate(14.798437 87.669531)rotate(-90)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-72\" d=\"M 2631 2963 \n",
"Q 2534 3019 2420 3045 \n",
"Q 2306 3072 2169 3072 \n",
"Q 1681 3072 1420 2755 \n",
"Q 1159 2438 1159 1844 \n",
"L 1159 0 \n",
"L 581 0 \n",
"L 581 3500 \n",
"L 1159 3500 \n",
"L 1159 2956 \n",
"Q 1341 3275 1631 3429 \n",
"Q 1922 3584 2338 3584 \n",
"Q 2397 3584 2469 3576 \n",
"Q 2541 3569 2628 3553 \n",
"L 2631 2963 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-6d\" d=\"M 3328 2828 \n",
"Q 3544 3216 3844 3400 \n",
"Q 4144 3584 4550 3584 \n",
"Q 5097 3584 5394 3201 \n",
"Q 5691 2819 5691 2113 \n",
"L 5691 0 \n",
"L 5113 0 \n",
"L 5113 2094 \n",
"Q 5113 2597 4934 2840 \n",
"Q 4756 3084 4391 3084 \n",
"Q 3944 3084 3684 2787 \n",
"Q 3425 2491 3425 1978 \n",
"L 3425 0 \n",
"L 2847 0 \n",
"L 2847 2094 \n",
"Q 2847 2600 2669 2842 \n",
"Q 2491 3084 2119 3084 \n",
"Q 1678 3084 1418 2786 \n",
"Q 1159 2488 1159 1978 \n",
"L 1159 0 \n",
"L 581 0 \n",
"L 581 3500 \n",
"L 1159 3500 \n",
"L 1159 2956 \n",
"Q 1356 3278 1631 3431 \n",
"Q 1906 3584 2284 3584 \n",
"Q 2666 3584 2933 3390 \n",
"Q 3200 3197 3328 2828 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-73\" d=\"M 2834 3397 \n",
"L 2834 2853 \n",
"Q 2591 2978 2328 3040 \n",
"Q 2066 3103 1784 3103 \n",
"Q 1356 3103 1142 2972 \n",
"Q 928 2841 928 2578 \n",
"Q 928 2378 1081 2264 \n",
"Q 1234 2150 1697 2047 \n",
"L 1894 2003 \n",
"Q 2506 1872 2764 1633 \n",
"Q 3022 1394 3022 966 \n",
"Q 3022 478 2636 193 \n",
"Q 2250 -91 1575 -91 \n",
"Q 1294 -91 989 -36 \n",
"Q 684 19 347 128 \n",
"L 347 722 \n",
"Q 666 556 975 473 \n",
"Q 1284 391 1588 391 \n",
"Q 1994 391 2212 530 \n",
"Q 2431 669 2431 922 \n",
"Q 2431 1156 2273 1281 \n",
"Q 2116 1406 1581 1522 \n",
"L 1381 1569 \n",
"Q 847 1681 609 1914 \n",
"Q 372 2147 372 2553 \n",
"Q 372 3047 722 3315 \n",
"Q 1072 3584 1716 3584 \n",
"Q 2034 3584 2315 3537 \n",
"Q 2597 3491 2834 3397 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-72\"/>\n",
" <use xlink:href=\"#DejaVuSans-6d\" x=\"39.363281\"/>\n",
" <use xlink:href=\"#DejaVuSans-73\" x=\"136.775391\"/>\n",
" <use xlink:href=\"#DejaVuSans-65\" x=\"188.875\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"line2d_24\">\n",
" <path d=\"M 45.478125 13.532241 \n",
"L 47.450852 21.43701 \n",
"L 49.42358 26.90266 \n",
"L 51.396307 31.271056 \n",
"L 53.369034 35.002306 \n",
"L 55.341761 38.313618 \n",
"L 57.314489 41.30559 \n",
"L 59.287216 44.088306 \n",
"L 61.259943 46.694206 \n",
"L 63.23267 49.162645 \n",
"L 65.205398 51.510203 \n",
"L 67.178125 53.772155 \n",
"L 69.150852 55.948901 \n",
"L 71.12358 58.053385 \n",
"L 73.096307 60.124212 \n",
"L 75.069034 62.123519 \n",
"L 77.041761 64.090804 \n",
"L 79.014489 66.010622 \n",
"L 80.987216 67.887486 \n",
"L 82.959943 69.748016 \n",
"L 84.93267 71.597371 \n",
"L 86.905398 73.386553 \n",
"L 88.878125 75.179985 \n",
"L 90.850852 76.936149 \n",
"L 92.82358 78.69032 \n",
"L 94.796307 80.424523 \n",
"L 96.769034 82.106046 \n",
"L 98.741761 83.781672 \n",
"L 100.714489 85.473676 \n",
"L 102.687216 87.136061 \n",
"L 104.659943 88.776845 \n",
"L 106.63267 90.402677 \n",
"L 108.605398 92.011823 \n",
"L 110.578125 93.627281 \n",
"L 112.550852 95.217228 \n",
"L 114.52358 96.777037 \n",
"L 116.496307 98.340012 \n",
"L 118.469034 99.860291 \n",
"L 120.441761 101.350827 \n",
"L 122.414489 102.851174 \n",
"L 124.387216 104.313602 \n",
"L 126.359943 105.747871 \n",
"L 128.33267 107.155366 \n",
"L 130.305398 108.560916 \n",
"L 132.278125 109.926078 \n",
"L 134.250852 111.270731 \n",
"L 136.22358 112.56023 \n",
"L 138.196307 113.818544 \n",
"L 140.169034 115.030561 \n",
"L 142.141761 116.21707 \n",
"L 144.114489 117.369411 \n",
"L 146.087216 118.452745 \n",
"L 148.059943 119.50921 \n",
"L 150.03267 120.579803 \n",
"L 152.005398 121.562723 \n",
"L 153.978125 122.482033 \n",
"L 155.950852 123.387798 \n",
"L 157.92358 124.211812 \n",
"L 159.896307 125.00577 \n",
"L 161.869034 125.734448 \n",
"L 163.841761 126.43801 \n",
"L 165.814489 127.085453 \n",
"L 167.787216 127.705297 \n",
"L 169.759943 128.270264 \n",
"L 171.73267 128.808448 \n",
"L 173.705398 129.294496 \n",
"L 175.678125 129.755459 \n",
"L 177.650852 130.157327 \n",
"L 179.62358 130.54595 \n",
"L 181.596307 130.87324 \n",
"L 183.569034 131.195329 \n",
"L 185.541761 131.508978 \n",
"L 187.514489 131.78612 \n",
"L 189.487216 132.010136 \n",
"L 191.459943 132.214879 \n",
"L 193.43267 132.398635 \n",
"L 195.405398 132.576011 \n",
"L 197.378125 132.727123 \n",
"L 199.350852 132.873056 \n",
"L 201.32358 132.994403 \n",
"L 203.296307 133.108921 \n",
"L 205.269034 133.19652 \n",
"L 207.241761 133.283782 \n",
"L 209.214489 133.35436 \n",
"L 211.187216 133.418961 \n",
"L 213.159943 133.459415 \n",
"L 215.13267 133.509866 \n",
"L 217.105398 133.553984 \n",
"L 219.078125 133.579105 \n",
"L 221.050852 133.613121 \n",
"L 223.02358 133.632131 \n",
"L 224.996307 133.651275 \n",
"L 226.969034 133.66435 \n",
"L 228.941761 133.687324 \n",
"L 230.914489 133.707316 \n",
"L 232.887216 133.711387 \n",
"L 234.859943 133.72296 \n",
"L 236.83267 133.74001 \n",
"L 238.805398 133.738276 \n",
"L 240.778125 133.731304 \n",
"\" clip-path=\"url(#pb1f4014886)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_25\">\n",
" <path d=\"M 45.478125 13.377273 \n",
"L 47.450852 21.251753 \n",
"L 49.42358 26.698583 \n",
"L 51.396307 31.044913 \n",
"L 53.369034 34.75598 \n",
"L 55.341761 38.05278 \n",
"L 57.314489 41.023066 \n",
"L 59.287216 43.790829 \n",
"L 61.259943 46.381902 \n",
"L 63.23267 48.831338 \n",
"L 65.205398 51.1637 \n",
"L 67.178125 53.408115 \n",
"L 69.150852 55.570585 \n",
"L 71.12358 57.660569 \n",
"L 73.096307 59.723133 \n",
"L 75.069034 61.711181 \n",
"L 77.041761 63.667637 \n",
"L 79.014489 65.575993 \n",
"L 80.987216 67.443238 \n",
"L 82.959943 69.291152 \n",
"L 84.93267 71.133531 \n",
"L 86.905398 72.916472 \n",
"L 88.878125 74.699266 \n",
"L 90.850852 76.446786 \n",
"L 92.82358 78.19455 \n",
"L 94.796307 79.922274 \n",
"L 96.769034 81.597433 \n",
"L 98.741761 83.270477 \n",
"L 100.714489 84.958479 \n",
"L 102.687216 86.620857 \n",
"L 104.659943 88.263874 \n",
"L 106.63267 89.895608 \n",
"L 108.605398 91.509575 \n",
"L 110.578125 93.135235 \n",
"L 112.550852 94.739259 \n",
"L 114.52358 96.313268 \n",
"L 116.496307 97.894025 \n",
"L 118.469034 99.438356 \n",
"L 120.441761 100.959013 \n",
"L 122.414489 102.495324 \n",
"L 124.387216 103.993173 \n",
"L 126.359943 105.46836 \n",
"L 128.33267 106.924742 \n",
"L 130.305398 108.381042 \n",
"L 132.278125 109.809264 \n",
"L 134.250852 111.215186 \n",
"L 136.22358 112.573345 \n",
"L 138.196307 113.906923 \n",
"L 140.169034 115.198379 \n",
"L 142.141761 116.477175 \n",
"L 144.114489 117.720336 \n",
"L 146.087216 118.901939 \n",
"L 148.059943 120.058975 \n",
"L 150.03267 121.239848 \n",
"L 152.005398 122.332063 \n",
"L 153.978125 123.366982 \n",
"L 155.950852 124.386301 \n",
"L 157.92358 125.325814 \n",
"L 159.896307 126.242798 \n",
"L 161.869034 127.092516 \n",
"L 163.841761 127.918204 \n",
"L 165.814489 128.685412 \n",
"L 167.787216 129.428898 \n",
"L 169.759943 130.107908 \n",
"L 171.73267 130.753493 \n",
"L 173.705398 131.353316 \n",
"L 175.678125 131.917724 \n",
"L 177.650852 132.420356 \n",
"L 179.62358 132.902439 \n",
"L 181.596307 133.310255 \n",
"L 183.569034 133.710097 \n",
"L 185.541761 134.109393 \n",
"L 187.514489 134.449865 \n",
"L 189.487216 134.73695 \n",
"L 191.459943 135.00287 \n",
"L 193.43267 135.233736 \n",
"L 195.405398 135.473172 \n",
"L 197.378125 135.651445 \n",
"L 199.350852 135.851344 \n",
"L 201.32358 136.004171 \n",
"L 203.296307 136.161622 \n",
"L 205.269034 136.275763 \n",
"L 207.241761 136.381605 \n",
"L 209.214489 136.471532 \n",
"L 211.187216 136.549373 \n",
"L 213.159943 136.604075 \n",
"L 215.13267 136.672038 \n",
"L 217.105398 136.721991 \n",
"L 219.078125 136.760291 \n",
"L 221.050852 136.791514 \n",
"L 223.02358 136.823347 \n",
"L 224.996307 136.833999 \n",
"L 226.969034 136.8478 \n",
"L 228.941761 136.866727 \n",
"L 230.914489 136.882767 \n",
"L 232.887216 136.887927 \n",
"L 234.859943 136.8883 \n",
"L 236.83267 136.908064 \n",
"L 238.805398 136.922727 \n",
"L 240.778125 136.908618 \n",
"\" clip-path=\"url(#pb1f4014886)\" style=\"fill: none; stroke-dasharray: 5.55,2.4; stroke-dashoffset: 0; stroke: #bf00bf; stroke-width: 1.5\"/>\n",
" </g>\n",
" <g id=\"patch_3\">\n",
" <path d=\"M 45.478125 143.1 \n",
"L 45.478125 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_4\">\n",
" <path d=\"M 240.778125 143.1 \n",
"L 240.778125 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_5\">\n",
" <path d=\"M 45.478125 143.1 \n",
"L 240.778125 143.1 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_6\">\n",
" <path d=\"M 45.478125 7.2 \n",
"L 240.778125 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"legend_1\">\n",
" <g id=\"patch_7\">\n",
" <path d=\"M 177.826562 44.55625 \n",
"L 233.778125 44.55625 \n",
"Q 235.778125 44.55625 235.778125 42.55625 \n",
"L 235.778125 14.2 \n",
"Q 235.778125 12.2 233.778125 12.2 \n",
"L 177.826562 12.2 \n",
"Q 175.826562 12.2 175.826562 14.2 \n",
"L 175.826562 42.55625 \n",
"Q 175.826562 44.55625 177.826562 44.55625 \n",
"z\n",
"\" style=\"fill: #ffffff; opacity: 0.8; stroke: #cccccc; stroke-linejoin: miter\"/>\n",
" </g>\n",
" <g id=\"line2d_26\">\n",
" <path d=\"M 179.826562 20.298437 \n",
"L 189.826562 20.298437 \n",
"L 199.826562 20.298437 \n",
"\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"text_9\">\n",
" <!-- train -->\n",
" <g transform=\"translate(207.826562 23.798437)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-74\" d=\"M 1172 4494 \n",
"L 1172 3500 \n",
"L 2356 3500 \n",
"L 2356 3053 \n",
"L 1172 3053 \n",
"L 1172 1153 \n",
"Q 1172 725 1289 603 \n",
"Q 1406 481 1766 481 \n",
"L 2356 481 \n",
"L 2356 0 \n",
"L 1766 0 \n",
"Q 1100 0 847 248 \n",
"Q 594 497 594 1153 \n",
"L 594 3053 \n",
"L 172 3053 \n",
"L 172 3500 \n",
"L 594 3500 \n",
"L 594 4494 \n",
"L 1172 4494 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-61\" d=\"M 2194 1759 \n",
"Q 1497 1759 1228 1600 \n",
"Q 959 1441 959 1056 \n",
"Q 959 750 1161 570 \n",
"Q 1363 391 1709 391 \n",
"Q 2188 391 2477 730 \n",
"Q 2766 1069 2766 1631 \n",
"L 2766 1759 \n",
"L 2194 1759 \n",
"z\n",
"M 3341 1997 \n",
"L 3341 0 \n",
"L 2766 0 \n",
"L 2766 531 \n",
"Q 2569 213 2275 61 \n",
"Q 1981 -91 1556 -91 \n",
"Q 1019 -91 701 211 \n",
"Q 384 513 384 1019 \n",
"Q 384 1609 779 1909 \n",
"Q 1175 2209 1959 2209 \n",
"L 2766 2209 \n",
"L 2766 2266 \n",
"Q 2766 2663 2505 2880 \n",
"Q 2244 3097 1772 3097 \n",
"Q 1472 3097 1187 3025 \n",
"Q 903 2953 641 2809 \n",
"L 641 3341 \n",
"Q 956 3463 1253 3523 \n",
"Q 1550 3584 1831 3584 \n",
"Q 2591 3584 2966 3190 \n",
"Q 3341 2797 3341 1997 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-69\" d=\"M 603 3500 \n",
"L 1178 3500 \n",
"L 1178 0 \n",
"L 603 0 \n",
"L 603 3500 \n",
"z\n",
"M 603 4863 \n",
"L 1178 4863 \n",
"L 1178 4134 \n",
"L 603 4134 \n",
"L 603 4863 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-6e\" d=\"M 3513 2113 \n",
"L 3513 0 \n",
"L 2938 0 \n",
"L 2938 2094 \n",
"Q 2938 2591 2744 2837 \n",
"Q 2550 3084 2163 3084 \n",
"Q 1697 3084 1428 2787 \n",
"Q 1159 2491 1159 1978 \n",
"L 1159 0 \n",
"L 581 0 \n",
"L 581 3500 \n",
"L 1159 3500 \n",
"L 1159 2956 \n",
"Q 1366 3272 1645 3428 \n",
"Q 1925 3584 2291 3584 \n",
"Q 2894 3584 3203 3211 \n",
"Q 3513 2838 3513 2113 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-74\"/>\n",
" <use xlink:href=\"#DejaVuSans-72\" x=\"39.208984\"/>\n",
" <use xlink:href=\"#DejaVuSans-61\" x=\"80.322266\"/>\n",
" <use xlink:href=\"#DejaVuSans-69\" x=\"141.601562\"/>\n",
" <use xlink:href=\"#DejaVuSans-6e\" x=\"169.384766\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"line2d_27\">\n",
" <path d=\"M 179.826562 34.976562 \n",
"L 189.826562 34.976562 \n",
"L 199.826562 34.976562 \n",
"\" style=\"fill: none; stroke-dasharray: 5.55,2.4; stroke-dashoffset: 0; stroke: #bf00bf; stroke-width: 1.5\"/>\n",
" </g>\n",
" <g id=\"text_10\">\n",
" <!-- valid -->\n",
" <g transform=\"translate(207.826562 38.476562)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-76\" d=\"M 191 3500 \n",
"L 800 3500 \n",
"L 1894 563 \n",
"L 2988 3500 \n",
"L 3597 3500 \n",
"L 2284 0 \n",
"L 1503 0 \n",
"L 191 3500 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-6c\" d=\"M 603 4863 \n",
"L 1178 4863 \n",
"L 1178 0 \n",
"L 603 0 \n",
"L 603 4863 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-64\" d=\"M 2906 2969 \n",
"L 2906 4863 \n",
"L 3481 4863 \n",
"L 3481 0 \n",
"L 2906 0 \n",
"L 2906 525 \n",
"Q 2725 213 2448 61 \n",
"Q 2172 -91 1784 -91 \n",
"Q 1150 -91 751 415 \n",
"Q 353 922 353 1747 \n",
"Q 353 2572 751 3078 \n",
"Q 1150 3584 1784 3584 \n",
"Q 2172 3584 2448 3432 \n",
"Q 2725 3281 2906 2969 \n",
"z\n",
"M 947 1747 \n",
"Q 947 1113 1208 752 \n",
"Q 1469 391 1925 391 \n",
"Q 2381 391 2643 752 \n",
"Q 2906 1113 2906 1747 \n",
"Q 2906 2381 2643 2742 \n",
"Q 2381 3103 1925 3103 \n",
"Q 1469 3103 1208 2742 \n",
"Q 947 2381 947 1747 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-76\"/>\n",
" <use xlink:href=\"#DejaVuSans-61\" x=\"59.179688\"/>\n",
" <use xlink:href=\"#DejaVuSans-6c\" x=\"120.458984\"/>\n",
" <use xlink:href=\"#DejaVuSans-69\" x=\"148.242188\"/>\n",
" <use xlink:href=\"#DejaVuSans-64\" x=\"176.025391\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <defs>\n",
" <clipPath id=\"pb1f4014886\">\n",
" <rect x=\"45.478125\" y=\"7.2\" width=\"195.3\" height=\"135.9\"/>\n",
" </clipPath>\n",
" </defs>\n",
"</svg>\n"
],
"text/plain": [
"<Figure size 252x180 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"k, num_epochs, lr, weight_decay, batch_size = 5, 100, 5, 0, 64\n",
"train_l, valid_l = k_fold(k, train_features, train_labels, num_epochs, lr,\n",
" weight_decay, batch_size)\n",
"print(f'{k}-折验证: 平均训练log rmse: {float(train_l):f}, '\n",
" f'平均验证log rmse: {float(valid_l):f}')"
]
},
{
"cell_type": "markdown",
"id": "4e418fd3",
"metadata": {
"origin_pos": 47
},
"source": [
"请注意,有时一组超参数的训练误差可能非常低,但$K$折交叉验证的误差要高得多,\n",
"这表明模型过拟合了。\n",
"在整个训练过程中,我们希望监控训练误差和验证误差这两个数字。\n",
"较少的过拟合可能表明现有数据可以支撑一个更强大的模型,\n",
"较大的过拟合可能意味着我们可以通过正则化技术来获益。\n",
"\n",
"## [**提交Kaggle预测**]\n",
"\n",
"既然我们知道应该选择什么样的超参数,\n",
"我们不妨使用所有数据对其进行训练\n",
"(而不是仅使用交叉验证中使用的$1-1/K$的数据)。\n",
"然后,我们通过这种方式获得的模型可以应用于测试集。\n",
"将预测保存在CSV文件中可以简化将结果上传到Kaggle的过程。\n"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "568e9ca5",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:48.777939Z",
"iopub.status.busy": "2023-08-18T06:58:48.777525Z",
"iopub.status.idle": "2023-08-18T06:58:48.787742Z",
"shell.execute_reply": "2023-08-18T06:58:48.786661Z"
},
"origin_pos": 48,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def train_and_pred(train_features, test_features, train_labels, test_data,\n",
" num_epochs, lr, weight_decay, batch_size):\n",
" net = get_net()\n",
" train_ls, _ = train(net, train_features, train_labels, None, None,\n",
" num_epochs, lr, weight_decay, batch_size)\n",
" d2l.plot(np.arange(1, num_epochs + 1), [train_ls], xlabel='epoch',\n",
" ylabel='log rmse', xlim=[1, num_epochs], yscale='log')\n",
" print(f'训练log rmse{float(train_ls[-1]):f}')\n",
" # 将网络应用于测试集。\n",
" preds = net(test_features).detach().numpy()\n",
" # 将其重新格式化以导出到Kaggle\n",
" test_data['SalePrice'] = pd.Series(preds.reshape(1, -1)[0])\n",
" submission = pd.concat([test_data['Id'], test_data['SalePrice']], axis=1)\n",
" submission.to_csv('submission.csv', index=False)"
]
},
{
"cell_type": "markdown",
"id": "f311cb6c",
"metadata": {
"origin_pos": 49
},
"source": [
"如果测试集上的预测与$K$倍交叉验证过程中的预测相似,\n",
"那就是时候把它们上传到Kaggle了。\n",
"下面的代码将生成一个名为`submission.csv`的文件。\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "7fd14d5d",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:48.792405Z",
"iopub.status.busy": "2023-08-18T06:58:48.792013Z",
"iopub.status.idle": "2023-08-18T06:58:52.795733Z",
"shell.execute_reply": "2023-08-18T06:58:52.794625Z"
},
"origin_pos": 50,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"训练log rmse0.162354\n"
]
},
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"257.521875pt\" height=\"180.65625pt\" viewBox=\"0 0 257.521875 180.65625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
" <metadata>\n",
" <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
" <cc:Work>\n",
" <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
" <dc:date>2023-08-18T06:58:52.719276</dc:date>\n",
" <dc:format>image/svg+xml</dc:format>\n",
" <dc:creator>\n",
" <cc:Agent>\n",
" <dc:title>Matplotlib v3.5.1, https://matplotlib.org/</dc:title>\n",
" </cc:Agent>\n",
" </dc:creator>\n",
" </cc:Work>\n",
" </rdf:RDF>\n",
" </metadata>\n",
" <defs>\n",
" <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
" </defs>\n",
" <g id=\"figure_1\">\n",
" <g id=\"patch_1\">\n",
" <path d=\"M 0 180.65625 \n",
"L 257.521875 180.65625 \n",
"L 257.521875 0 \n",
"L 0 0 \n",
"L 0 180.65625 \n",
"z\n",
"\" style=\"fill: none\"/>\n",
" </g>\n",
" <g id=\"axes_1\">\n",
" <g id=\"patch_2\">\n",
" <path d=\"M 45.478125 143.1 \n",
"L 240.778125 143.1 \n",
"L 240.778125 7.2 \n",
"L 45.478125 7.2 \n",
"z\n",
"\" style=\"fill: #ffffff\"/>\n",
" </g>\n",
" <g id=\"matplotlib.axis_1\">\n",
" <g id=\"xtick_1\">\n",
" <g id=\"line2d_1\">\n",
" <path d=\"M 82.959943 143.1 \n",
"L 82.959943 7.2 \n",
"\" clip-path=\"url(#p6b7c6bf10b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_2\">\n",
" <defs>\n",
" <path id=\"mfd1703beb3\" d=\"M 0 0 \n",
"L 0 3.5 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#mfd1703beb3\" x=\"82.959943\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_1\">\n",
" <!-- 20 -->\n",
" <g transform=\"translate(76.597443 157.698438)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
"L 3431 531 \n",
"L 3431 0 \n",
"L 469 0 \n",
"L 469 531 \n",
"Q 828 903 1448 1529 \n",
"Q 2069 2156 2228 2338 \n",
"Q 2531 2678 2651 2914 \n",
"Q 2772 3150 2772 3378 \n",
"Q 2772 3750 2511 3984 \n",
"Q 2250 4219 1831 4219 \n",
"Q 1534 4219 1204 4116 \n",
"Q 875 4013 500 3803 \n",
"L 500 4441 \n",
"Q 881 4594 1212 4672 \n",
"Q 1544 4750 1819 4750 \n",
"Q 2544 4750 2975 4387 \n",
"Q 3406 4025 3406 3419 \n",
"Q 3406 3131 3298 2873 \n",
"Q 3191 2616 2906 2266 \n",
"Q 2828 2175 2409 1742 \n",
"Q 1991 1309 1228 531 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
"Q 1547 4250 1301 3770 \n",
"Q 1056 3291 1056 2328 \n",
"Q 1056 1369 1301 889 \n",
"Q 1547 409 2034 409 \n",
"Q 2525 409 2770 889 \n",
"Q 3016 1369 3016 2328 \n",
"Q 3016 3291 2770 3770 \n",
"Q 2525 4250 2034 4250 \n",
"z\n",
"M 2034 4750 \n",
"Q 2819 4750 3233 4129 \n",
"Q 3647 3509 3647 2328 \n",
"Q 3647 1150 3233 529 \n",
"Q 2819 -91 2034 -91 \n",
"Q 1250 -91 836 529 \n",
"Q 422 1150 422 2328 \n",
"Q 422 3509 836 4129 \n",
"Q 1250 4750 2034 4750 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_2\">\n",
" <g id=\"line2d_3\">\n",
" <path d=\"M 122.414489 143.1 \n",
"L 122.414489 7.2 \n",
"\" clip-path=\"url(#p6b7c6bf10b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_4\">\n",
" <g>\n",
" <use xlink:href=\"#mfd1703beb3\" x=\"122.414489\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_2\">\n",
" <!-- 40 -->\n",
" <g transform=\"translate(116.051989 157.698438)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \n",
"L 825 1625 \n",
"L 2419 1625 \n",
"L 2419 4116 \n",
"z\n",
"M 2253 4666 \n",
"L 3047 4666 \n",
"L 3047 1625 \n",
"L 3713 1625 \n",
"L 3713 1100 \n",
"L 3047 1100 \n",
"L 3047 0 \n",
"L 2419 0 \n",
"L 2419 1100 \n",
"L 313 1100 \n",
"L 313 1709 \n",
"L 2253 4666 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-34\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_3\">\n",
" <g id=\"line2d_5\">\n",
" <path d=\"M 161.869034 143.1 \n",
"L 161.869034 7.2 \n",
"\" clip-path=\"url(#p6b7c6bf10b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_6\">\n",
" <g>\n",
" <use xlink:href=\"#mfd1703beb3\" x=\"161.869034\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_3\">\n",
" <!-- 60 -->\n",
" <g transform=\"translate(155.506534 157.698438)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-36\" d=\"M 2113 2584 \n",
"Q 1688 2584 1439 2293 \n",
"Q 1191 2003 1191 1497 \n",
"Q 1191 994 1439 701 \n",
"Q 1688 409 2113 409 \n",
"Q 2538 409 2786 701 \n",
"Q 3034 994 3034 1497 \n",
"Q 3034 2003 2786 2293 \n",
"Q 2538 2584 2113 2584 \n",
"z\n",
"M 3366 4563 \n",
"L 3366 3988 \n",
"Q 3128 4100 2886 4159 \n",
"Q 2644 4219 2406 4219 \n",
"Q 1781 4219 1451 3797 \n",
"Q 1122 3375 1075 2522 \n",
"Q 1259 2794 1537 2939 \n",
"Q 1816 3084 2150 3084 \n",
"Q 2853 3084 3261 2657 \n",
"Q 3669 2231 3669 1497 \n",
"Q 3669 778 3244 343 \n",
"Q 2819 -91 2113 -91 \n",
"Q 1303 -91 875 529 \n",
"Q 447 1150 447 2328 \n",
"Q 447 3434 972 4092 \n",
"Q 1497 4750 2381 4750 \n",
"Q 2619 4750 2861 4703 \n",
"Q 3103 4656 3366 4563 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-36\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_4\">\n",
" <g id=\"line2d_7\">\n",
" <path d=\"M 201.32358 143.1 \n",
"L 201.32358 7.2 \n",
"\" clip-path=\"url(#p6b7c6bf10b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_8\">\n",
" <g>\n",
" <use xlink:href=\"#mfd1703beb3\" x=\"201.32358\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_4\">\n",
" <!-- 80 -->\n",
" <g transform=\"translate(194.96108 157.698438)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-38\" d=\"M 2034 2216 \n",
"Q 1584 2216 1326 1975 \n",
"Q 1069 1734 1069 1313 \n",
"Q 1069 891 1326 650 \n",
"Q 1584 409 2034 409 \n",
"Q 2484 409 2743 651 \n",
"Q 3003 894 3003 1313 \n",
"Q 3003 1734 2745 1975 \n",
"Q 2488 2216 2034 2216 \n",
"z\n",
"M 1403 2484 \n",
"Q 997 2584 770 2862 \n",
"Q 544 3141 544 3541 \n",
"Q 544 4100 942 4425 \n",
"Q 1341 4750 2034 4750 \n",
"Q 2731 4750 3128 4425 \n",
"Q 3525 4100 3525 3541 \n",
"Q 3525 3141 3298 2862 \n",
"Q 3072 2584 2669 2484 \n",
"Q 3125 2378 3379 2068 \n",
"Q 3634 1759 3634 1313 \n",
"Q 3634 634 3220 271 \n",
"Q 2806 -91 2034 -91 \n",
"Q 1263 -91 848 271 \n",
"Q 434 634 434 1313 \n",
"Q 434 1759 690 2068 \n",
"Q 947 2378 1403 2484 \n",
"z\n",
"M 1172 3481 \n",
"Q 1172 3119 1398 2916 \n",
"Q 1625 2713 2034 2713 \n",
"Q 2441 2713 2670 2916 \n",
"Q 2900 3119 2900 3481 \n",
"Q 2900 3844 2670 4047 \n",
"Q 2441 4250 2034 4250 \n",
"Q 1625 4250 1398 4047 \n",
"Q 1172 3844 1172 3481 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-38\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_5\">\n",
" <g id=\"line2d_9\">\n",
" <path d=\"M 240.778125 143.1 \n",
"L 240.778125 7.2 \n",
"\" clip-path=\"url(#p6b7c6bf10b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_10\">\n",
" <g>\n",
" <use xlink:href=\"#mfd1703beb3\" x=\"240.778125\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_5\">\n",
" <!-- 100 -->\n",
" <g transform=\"translate(231.234375 157.698438)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
"L 1825 531 \n",
"L 1825 4091 \n",
"L 703 3866 \n",
"L 703 4441 \n",
"L 1819 4666 \n",
"L 2450 4666 \n",
"L 2450 531 \n",
"L 3481 531 \n",
"L 3481 0 \n",
"L 794 0 \n",
"L 794 531 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_6\">\n",
" <!-- epoch -->\n",
" <g transform=\"translate(127.9 171.376563)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-65\" d=\"M 3597 1894 \n",
"L 3597 1613 \n",
"L 953 1613 \n",
"Q 991 1019 1311 708 \n",
"Q 1631 397 2203 397 \n",
"Q 2534 397 2845 478 \n",
"Q 3156 559 3463 722 \n",
"L 3463 178 \n",
"Q 3153 47 2828 -22 \n",
"Q 2503 -91 2169 -91 \n",
"Q 1331 -91 842 396 \n",
"Q 353 884 353 1716 \n",
"Q 353 2575 817 3079 \n",
"Q 1281 3584 2069 3584 \n",
"Q 2775 3584 3186 3129 \n",
"Q 3597 2675 3597 1894 \n",
"z\n",
"M 3022 2063 \n",
"Q 3016 2534 2758 2815 \n",
"Q 2500 3097 2075 3097 \n",
"Q 1594 3097 1305 2825 \n",
"Q 1016 2553 972 2059 \n",
"L 3022 2063 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-70\" d=\"M 1159 525 \n",
"L 1159 -1331 \n",
"L 581 -1331 \n",
"L 581 3500 \n",
"L 1159 3500 \n",
"L 1159 2969 \n",
"Q 1341 3281 1617 3432 \n",
"Q 1894 3584 2278 3584 \n",
"Q 2916 3584 3314 3078 \n",
"Q 3713 2572 3713 1747 \n",
"Q 3713 922 3314 415 \n",
"Q 2916 -91 2278 -91 \n",
"Q 1894 -91 1617 61 \n",
"Q 1341 213 1159 525 \n",
"z\n",
"M 3116 1747 \n",
"Q 3116 2381 2855 2742 \n",
"Q 2594 3103 2138 3103 \n",
"Q 1681 3103 1420 2742 \n",
"Q 1159 2381 1159 1747 \n",
"Q 1159 1113 1420 752 \n",
"Q 1681 391 2138 391 \n",
"Q 2594 391 2855 752 \n",
"Q 3116 1113 3116 1747 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-6f\" d=\"M 1959 3097 \n",
"Q 1497 3097 1228 2736 \n",
"Q 959 2375 959 1747 \n",
"Q 959 1119 1226 758 \n",
"Q 1494 397 1959 397 \n",
"Q 2419 397 2687 759 \n",
"Q 2956 1122 2956 1747 \n",
"Q 2956 2369 2687 2733 \n",
"Q 2419 3097 1959 3097 \n",
"z\n",
"M 1959 3584 \n",
"Q 2709 3584 3137 3096 \n",
"Q 3566 2609 3566 1747 \n",
"Q 3566 888 3137 398 \n",
"Q 2709 -91 1959 -91 \n",
"Q 1206 -91 779 398 \n",
"Q 353 888 353 1747 \n",
"Q 353 2609 779 3096 \n",
"Q 1206 3584 1959 3584 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-63\" d=\"M 3122 3366 \n",
"L 3122 2828 \n",
"Q 2878 2963 2633 3030 \n",
"Q 2388 3097 2138 3097 \n",
"Q 1578 3097 1268 2742 \n",
"Q 959 2388 959 1747 \n",
"Q 959 1106 1268 751 \n",
"Q 1578 397 2138 397 \n",
"Q 2388 397 2633 464 \n",
"Q 2878 531 3122 666 \n",
"L 3122 134 \n",
"Q 2881 22 2623 -34 \n",
"Q 2366 -91 2075 -91 \n",
"Q 1284 -91 818 406 \n",
"Q 353 903 353 1747 \n",
"Q 353 2603 823 3093 \n",
"Q 1294 3584 2113 3584 \n",
"Q 2378 3584 2631 3529 \n",
"Q 2884 3475 3122 3366 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-68\" d=\"M 3513 2113 \n",
"L 3513 0 \n",
"L 2938 0 \n",
"L 2938 2094 \n",
"Q 2938 2591 2744 2837 \n",
"Q 2550 3084 2163 3084 \n",
"Q 1697 3084 1428 2787 \n",
"Q 1159 2491 1159 1978 \n",
"L 1159 0 \n",
"L 581 0 \n",
"L 581 4863 \n",
"L 1159 4863 \n",
"L 1159 2956 \n",
"Q 1366 3272 1645 3428 \n",
"Q 1925 3584 2291 3584 \n",
"Q 2894 3584 3203 3211 \n",
"Q 3513 2838 3513 2113 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-65\"/>\n",
" <use xlink:href=\"#DejaVuSans-70\" x=\"61.523438\"/>\n",
" <use xlink:href=\"#DejaVuSans-6f\" x=\"125\"/>\n",
" <use xlink:href=\"#DejaVuSans-63\" x=\"186.181641\"/>\n",
" <use xlink:href=\"#DejaVuSans-68\" x=\"241.162109\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"matplotlib.axis_2\">\n",
" <g id=\"ytick_1\">\n",
" <g id=\"line2d_11\">\n",
" <path d=\"M 45.478125 64.080243 \n",
"L 240.778125 64.080243 \n",
"\" clip-path=\"url(#p6b7c6bf10b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"line2d_12\">\n",
" <defs>\n",
" <path id=\"m9893522615\" d=\"M 0 0 \n",
"L -3.5 0 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#m9893522615\" x=\"45.478125\" y=\"64.080243\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_7\">\n",
" <!-- $\\mathdefault{10^{0}}$ -->\n",
" <g transform=\"translate(20.878125 67.879462)scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-31\" transform=\"translate(0 0.765625)\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" transform=\"translate(63.623047 0.765625)\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" transform=\"translate(128.203125 39.046875)scale(0.7)\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_2\">\n",
" <g id=\"line2d_13\">\n",
" <defs>\n",
" <path id=\"m17faa2ae29\" d=\"M 0 0 \n",
"L -2 0 \n",
"\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#m17faa2ae29\" x=\"45.478125\" y=\"128.567109\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_3\">\n",
" <g id=\"line2d_14\">\n",
" <g>\n",
" <use xlink:href=\"#m17faa2ae29\" x=\"45.478125\" y=\"112.320956\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_4\">\n",
" <g id=\"line2d_15\">\n",
" <g>\n",
" <use xlink:href=\"#m17faa2ae29\" x=\"45.478125\" y=\"100.794127\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_5\">\n",
" <g id=\"line2d_16\">\n",
" <g>\n",
" <use xlink:href=\"#m17faa2ae29\" x=\"45.478125\" y=\"91.853224\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_6\">\n",
" <g id=\"line2d_17\">\n",
" <g>\n",
" <use xlink:href=\"#m17faa2ae29\" x=\"45.478125\" y=\"84.547975\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_7\">\n",
" <g id=\"line2d_18\">\n",
" <g>\n",
" <use xlink:href=\"#m17faa2ae29\" x=\"45.478125\" y=\"78.371474\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_8\">\n",
" <g id=\"line2d_19\">\n",
" <g>\n",
" <use xlink:href=\"#m17faa2ae29\" x=\"45.478125\" y=\"73.021146\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_9\">\n",
" <g id=\"line2d_20\">\n",
" <g>\n",
" <use xlink:href=\"#m17faa2ae29\" x=\"45.478125\" y=\"68.301822\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_10\">\n",
" <g id=\"line2d_21\">\n",
" <g>\n",
" <use xlink:href=\"#m17faa2ae29\" x=\"45.478125\" y=\"36.307261\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_11\">\n",
" <g id=\"line2d_22\">\n",
" <g>\n",
" <use xlink:href=\"#m17faa2ae29\" x=\"45.478125\" y=\"20.061109\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_12\">\n",
" <g id=\"line2d_23\">\n",
" <g>\n",
" <use xlink:href=\"#m17faa2ae29\" x=\"45.478125\" y=\"8.53428\" style=\"stroke: #000000; stroke-width: 0.6\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_8\">\n",
" <!-- log rmse -->\n",
" <g transform=\"translate(14.798438 96.88125)rotate(-90)scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-6c\" d=\"M 603 4863 \n",
"L 1178 4863 \n",
"L 1178 0 \n",
"L 603 0 \n",
"L 603 4863 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-67\" d=\"M 2906 1791 \n",
"Q 2906 2416 2648 2759 \n",
"Q 2391 3103 1925 3103 \n",
"Q 1463 3103 1205 2759 \n",
"Q 947 2416 947 1791 \n",
"Q 947 1169 1205 825 \n",
"Q 1463 481 1925 481 \n",
"Q 2391 481 2648 825 \n",
"Q 2906 1169 2906 1791 \n",
"z\n",
"M 3481 434 \n",
"Q 3481 -459 3084 -895 \n",
"Q 2688 -1331 1869 -1331 \n",
"Q 1566 -1331 1297 -1286 \n",
"Q 1028 -1241 775 -1147 \n",
"L 775 -588 \n",
"Q 1028 -725 1275 -790 \n",
"Q 1522 -856 1778 -856 \n",
"Q 2344 -856 2625 -561 \n",
"Q 2906 -266 2906 331 \n",
"L 2906 616 \n",
"Q 2728 306 2450 153 \n",
"Q 2172 0 1784 0 \n",
"Q 1141 0 747 490 \n",
"Q 353 981 353 1791 \n",
"Q 353 2603 747 3093 \n",
"Q 1141 3584 1784 3584 \n",
"Q 2172 3584 2450 3431 \n",
"Q 2728 3278 2906 2969 \n",
"L 2906 3500 \n",
"L 3481 3500 \n",
"L 3481 434 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-20\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-72\" d=\"M 2631 2963 \n",
"Q 2534 3019 2420 3045 \n",
"Q 2306 3072 2169 3072 \n",
"Q 1681 3072 1420 2755 \n",
"Q 1159 2438 1159 1844 \n",
"L 1159 0 \n",
"L 581 0 \n",
"L 581 3500 \n",
"L 1159 3500 \n",
"L 1159 2956 \n",
"Q 1341 3275 1631 3429 \n",
"Q 1922 3584 2338 3584 \n",
"Q 2397 3584 2469 3576 \n",
"Q 2541 3569 2628 3553 \n",
"L 2631 2963 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-6d\" d=\"M 3328 2828 \n",
"Q 3544 3216 3844 3400 \n",
"Q 4144 3584 4550 3584 \n",
"Q 5097 3584 5394 3201 \n",
"Q 5691 2819 5691 2113 \n",
"L 5691 0 \n",
"L 5113 0 \n",
"L 5113 2094 \n",
"Q 5113 2597 4934 2840 \n",
"Q 4756 3084 4391 3084 \n",
"Q 3944 3084 3684 2787 \n",
"Q 3425 2491 3425 1978 \n",
"L 3425 0 \n",
"L 2847 0 \n",
"L 2847 2094 \n",
"Q 2847 2600 2669 2842 \n",
"Q 2491 3084 2119 3084 \n",
"Q 1678 3084 1418 2786 \n",
"Q 1159 2488 1159 1978 \n",
"L 1159 0 \n",
"L 581 0 \n",
"L 581 3500 \n",
"L 1159 3500 \n",
"L 1159 2956 \n",
"Q 1356 3278 1631 3431 \n",
"Q 1906 3584 2284 3584 \n",
"Q 2666 3584 2933 3390 \n",
"Q 3200 3197 3328 2828 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-73\" d=\"M 2834 3397 \n",
"L 2834 2853 \n",
"Q 2591 2978 2328 3040 \n",
"Q 2066 3103 1784 3103 \n",
"Q 1356 3103 1142 2972 \n",
"Q 928 2841 928 2578 \n",
"Q 928 2378 1081 2264 \n",
"Q 1234 2150 1697 2047 \n",
"L 1894 2003 \n",
"Q 2506 1872 2764 1633 \n",
"Q 3022 1394 3022 966 \n",
"Q 3022 478 2636 193 \n",
"Q 2250 -91 1575 -91 \n",
"Q 1294 -91 989 -36 \n",
"Q 684 19 347 128 \n",
"L 347 722 \n",
"Q 666 556 975 473 \n",
"Q 1284 391 1588 391 \n",
"Q 1994 391 2212 530 \n",
"Q 2431 669 2431 922 \n",
"Q 2431 1156 2273 1281 \n",
"Q 2116 1406 1581 1522 \n",
"L 1381 1569 \n",
"Q 847 1681 609 1914 \n",
"Q 372 2147 372 2553 \n",
"Q 372 3047 722 3315 \n",
"Q 1072 3584 1716 3584 \n",
"Q 2034 3584 2315 3537 \n",
"Q 2597 3491 2834 3397 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-6c\"/>\n",
" <use xlink:href=\"#DejaVuSans-6f\" x=\"27.783203\"/>\n",
" <use xlink:href=\"#DejaVuSans-67\" x=\"88.964844\"/>\n",
" <use xlink:href=\"#DejaVuSans-20\" x=\"152.441406\"/>\n",
" <use xlink:href=\"#DejaVuSans-72\" x=\"184.228516\"/>\n",
" <use xlink:href=\"#DejaVuSans-6d\" x=\"223.591797\"/>\n",
" <use xlink:href=\"#DejaVuSans-73\" x=\"321.003906\"/>\n",
" <use xlink:href=\"#DejaVuSans-65\" x=\"373.103516\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"line2d_24\">\n",
" <path d=\"M 45.478125 13.377273 \n",
"L 47.450852 21.96878 \n",
"L 49.42358 27.969951 \n",
"L 51.396307 32.78897 \n",
"L 53.369034 36.93004 \n",
"L 55.341761 40.632663 \n",
"L 57.314489 44.011315 \n",
"L 59.287216 47.158803 \n",
"L 61.259943 50.12599 \n",
"L 63.23267 52.933235 \n",
"L 65.205398 55.653071 \n",
"L 67.178125 58.255767 \n",
"L 69.150852 60.783924 \n",
"L 71.12358 63.234436 \n",
"L 73.096307 65.632446 \n",
"L 75.069034 67.981188 \n",
"L 77.041761 70.293433 \n",
"L 79.014489 72.54506 \n",
"L 80.987216 74.783504 \n",
"L 82.959943 76.9851 \n",
"L 84.93267 79.157084 \n",
"L 86.905398 81.298943 \n",
"L 88.878125 83.417442 \n",
"L 90.850852 85.52311 \n",
"L 92.82358 87.59844 \n",
"L 94.796307 89.659276 \n",
"L 96.769034 91.6786 \n",
"L 98.741761 93.693225 \n",
"L 100.714489 95.686095 \n",
"L 102.687216 97.637429 \n",
"L 104.659943 99.576667 \n",
"L 106.63267 101.47981 \n",
"L 108.605398 103.361015 \n",
"L 110.578125 105.219814 \n",
"L 112.550852 107.01625 \n",
"L 114.52358 108.820812 \n",
"L 116.496307 110.533904 \n",
"L 118.469034 112.208428 \n",
"L 120.441761 113.859549 \n",
"L 122.414489 115.44452 \n",
"L 124.387216 116.999057 \n",
"L 126.359943 118.452532 \n",
"L 128.33267 119.882089 \n",
"L 130.305398 121.227645 \n",
"L 132.278125 122.49977 \n",
"L 134.250852 123.742423 \n",
"L 136.22358 124.856799 \n",
"L 138.196307 125.932782 \n",
"L 140.169034 126.963676 \n",
"L 142.141761 127.884423 \n",
"L 144.114489 128.751806 \n",
"L 146.087216 129.533372 \n",
"L 148.059943 130.245346 \n",
"L 150.03267 130.918259 \n",
"L 152.005398 131.524266 \n",
"L 153.978125 132.07588 \n",
"L 155.950852 132.589024 \n",
"L 157.92358 133.021698 \n",
"L 159.896307 133.429464 \n",
"L 161.869034 133.766261 \n",
"L 163.841761 134.085871 \n",
"L 165.814489 134.358165 \n",
"L 167.787216 134.607122 \n",
"L 169.759943 134.831235 \n",
"L 171.73267 135.016223 \n",
"L 173.705398 135.182673 \n",
"L 175.678125 135.327551 \n",
"L 177.650852 135.451414 \n",
"L 179.62358 135.566072 \n",
"L 181.596307 135.658179 \n",
"L 183.569034 135.738529 \n",
"L 185.541761 135.806532 \n",
"L 187.514489 135.868957 \n",
"L 189.487216 135.927649 \n",
"L 191.459943 135.955127 \n",
"L 193.43267 136.006141 \n",
"L 195.405398 136.036597 \n",
"L 197.378125 136.075129 \n",
"L 199.350852 136.087808 \n",
"L 201.32358 136.110148 \n",
"L 203.296307 136.144095 \n",
"L 205.269034 136.156641 \n",
"L 207.241761 136.177392 \n",
"L 209.214489 136.210492 \n",
"L 211.187216 136.208791 \n",
"L 213.159943 136.272053 \n",
"L 215.13267 136.260874 \n",
"L 217.105398 136.310248 \n",
"L 219.078125 136.343643 \n",
"L 221.050852 136.358736 \n",
"L 223.02358 136.392941 \n",
"L 224.996307 136.468161 \n",
"L 226.969034 136.495415 \n",
"L 228.941761 136.541541 \n",
"L 230.914489 136.590595 \n",
"L 232.887216 136.682678 \n",
"L 234.859943 136.757496 \n",
"L 236.83267 136.769804 \n",
"L 238.805398 136.866148 \n",
"L 240.778125 136.922727 \n",
"\" clip-path=\"url(#p6b7c6bf10b)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_3\">\n",
" <path d=\"M 45.478125 143.1 \n",
"L 45.478125 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_4\">\n",
" <path d=\"M 240.778125 143.1 \n",
"L 240.778125 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_5\">\n",
" <path d=\"M 45.478125 143.1 \n",
"L 240.778125 143.1 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_6\">\n",
" <path d=\"M 45.478125 7.2 \n",
"L 240.778125 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <defs>\n",
" <clipPath id=\"p6b7c6bf10b\">\n",
" <rect x=\"45.478125\" y=\"7.2\" width=\"195.3\" height=\"135.9\"/>\n",
" </clipPath>\n",
" </defs>\n",
"</svg>\n"
],
"text/plain": [
"<Figure size 252x180 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"train_and_pred(train_features, test_features, train_labels, test_data,\n",
" num_epochs, lr, weight_decay, batch_size)"
]
},
{
"cell_type": "markdown",
"id": "5efcf208",
"metadata": {
"origin_pos": 51
},
"source": [
"接下来,如 :numref:`fig_kaggle_submit2`中所示,\n",
"我们可以提交预测到Kaggle上,并查看在测试集上的预测与实际房价(标签)的比较情况。\n",
"步骤非常简单。\n",
"\n",
"* 登录Kaggle网站,访问房价预测竞赛页面。\n",
"* 点击“Submit Predictions”或“Late Submission”按钮(在撰写本文时,该按钮位于右侧)。\n",
"* 点击页面底部虚线框中的“Upload Submission File”按钮,选择要上传的预测文件。\n",
"* 点击页面底部的“Make Submission”按钮,即可查看结果。\n",
"\n",
"![向Kaggle提交数据](../img/kaggle-submit2.png)\n",
":width:`400px`\n",
":label:`fig_kaggle_submit2`\n",
"\n",
"## 小结\n",
"\n",
"* 真实数据通常混合了不同的数据类型,需要进行预处理。\n",
"* 常用的预处理方法:将实值数据重新缩放为零均值和单位方法;用均值替换缺失值。\n",
"* 将类别特征转化为指标特征,可以使我们把这个特征当作一个独热向量来对待。\n",
"* 我们可以使用$K$折交叉验证来选择模型并调整超参数。\n",
"* 对数对于相对误差很有用。\n",
"\n",
"## 练习\n",
"\n",
"1. 把预测提交给Kaggle,它有多好?\n",
"1. 能通过直接最小化价格的对数来改进模型吗?如果试图预测价格的对数而不是价格,会发生什么?\n",
"1. 用平均值替换缺失值总是好主意吗?提示:能构造一个不随机丢失值的情况吗?\n",
"1. 通过$K$折交叉验证调整超参数,从而提高Kaggle的得分。\n",
"1. 通过改进模型(例如,层、权重衰减和dropout)来提高分数。\n",
"1. 如果我们没有像本节所做的那样标准化连续的数值特征,会发生什么?\n"
]
},
{
"cell_type": "markdown",
"id": "5198d6f0",
"metadata": {
"origin_pos": 53,
"tab": [
"pytorch"
]
},
"source": [
"[Discussions](https://discuss.d2l.ai/t/1824)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"required_libs": []
},
"nbformat": 4,
"nbformat_minor": 5
}