feat: 初始化老年群体高温预警项目基础工程

搭建完整的项目目录结构,配置项目依赖与元信息,添加数据下载、预处理、模型训练、可视化相关的核心业务代码,补充项目设计文档与.gitignore配置,导入初始外部参考数据文件。
This commit is contained in:
2026-05-26 20:05:10 +08:00
commit a0478b0b11
20 changed files with 3300 additions and 0 deletions
+19
View File
@@ -0,0 +1,19 @@
.venv/
__pycache__/
*.pyc
*.pyo
.ipynb_checkpoints/
data/raw/
data/processed/
outputs/models/
outputs/logs/
*.aux
*.log
*.out
*.toc
*.bbl
*.blg
*.synctex.gz
*.fdb_latexmk
*.fls
.DS_Store
+14
View File
@@ -0,0 +1,14 @@
percentile,rr
0.0,1.0
1.0,1.0
2.5,1.01
5.0,1.02
10.0,1.04
25.0,1.08
50.0,1.12
75.0,1.18
90.0,1.28
95.0,1.35
97.5,1.42
99.0,1.5
100.0,1.55
1 percentile rr
2 0.0 1.0
3 1.0 1.0
4 2.5 1.01
5 5.0 1.02
6 10.0 1.04
7 25.0 1.08
8 50.0 1.12
9 75.0 1.18
10 90.0 1.28
11 95.0 1.35
12 97.5 1.42
13 99.0 1.5
14 100.0 1.55
+29
View File
@@ -0,0 +1,29 @@
year,city,city_name,total_population,elderly_population,aging_rate,crude_mortality_rate,elderly_mortality_rate
2010,jiaozuo,焦作,354.7,45.4,12.8,6.57,42.3
2011,jiaozuo,焦作,354.7,45.4,12.8,6.54,41.8
2012,jiaozuo,焦作,354.7,45.4,12.8,6.71,43.1
2013,jiaozuo,焦作,354.7,45.4,12.8,6.76,43.5
2014,jiaozuo,焦作,354.7,45.4,12.8,6.89,44.2
2015,jiaozuo,焦作,354.7,45.4,12.8,7.02,45.0
2016,jiaozuo,焦作,354.7,45.4,12.8,7.1,45.8
2017,jiaozuo,焦作,354.7,45.4,12.8,7.16,46.2
2018,jiaozuo,焦作,354.7,45.4,12.8,7.18,46.5
2019,jiaozuo,焦作,354.7,45.4,12.8,7.25,47.1
2020,jiaozuo,焦作,354.7,45.4,12.8,7.3,47.8
2021,jiaozuo,焦作,354.7,45.4,12.8,7.35,48.2
2022,jiaozuo,焦作,354.7,45.4,12.8,7.28,47.5
2023,jiaozuo,焦作,354.7,45.4,12.8,7.4,48.5
2010,zhengzhou,郑州,1260.1,146.2,11.6,6.57,42.3
2011,zhengzhou,郑州,1260.1,146.2,11.6,6.54,41.8
2012,zhengzhou,郑州,1260.1,146.2,11.6,6.71,43.1
2013,zhengzhou,郑州,1260.1,146.2,11.6,6.76,43.5
2014,zhengzhou,郑州,1260.1,146.2,11.6,6.89,44.2
2015,zhengzhou,郑州,1260.1,146.2,11.6,7.02,45.0
2016,zhengzhou,郑州,1260.1,146.2,11.6,7.1,45.8
2017,zhengzhou,郑州,1260.1,146.2,11.6,7.16,46.2
2018,zhengzhou,郑州,1260.1,146.2,11.6,7.18,46.5
2019,zhengzhou,郑州,1260.1,146.2,11.6,7.25,47.1
2020,zhengzhou,郑州,1260.1,146.2,11.6,7.3,47.8
2021,zhengzhou,郑州,1260.1,146.2,11.6,7.35,48.2
2022,zhengzhou,郑州,1260.1,146.2,11.6,7.28,47.5
2023,zhengzhou,郑州,1260.1,146.2,11.6,7.4,48.5
1 year city city_name total_population elderly_population aging_rate crude_mortality_rate elderly_mortality_rate
2 2010 jiaozuo 焦作 354.7 45.4 12.8 6.57 42.3
3 2011 jiaozuo 焦作 354.7 45.4 12.8 6.54 41.8
4 2012 jiaozuo 焦作 354.7 45.4 12.8 6.71 43.1
5 2013 jiaozuo 焦作 354.7 45.4 12.8 6.76 43.5
6 2014 jiaozuo 焦作 354.7 45.4 12.8 6.89 44.2
7 2015 jiaozuo 焦作 354.7 45.4 12.8 7.02 45.0
8 2016 jiaozuo 焦作 354.7 45.4 12.8 7.1 45.8
9 2017 jiaozuo 焦作 354.7 45.4 12.8 7.16 46.2
10 2018 jiaozuo 焦作 354.7 45.4 12.8 7.18 46.5
11 2019 jiaozuo 焦作 354.7 45.4 12.8 7.25 47.1
12 2020 jiaozuo 焦作 354.7 45.4 12.8 7.3 47.8
13 2021 jiaozuo 焦作 354.7 45.4 12.8 7.35 48.2
14 2022 jiaozuo 焦作 354.7 45.4 12.8 7.28 47.5
15 2023 jiaozuo 焦作 354.7 45.4 12.8 7.4 48.5
16 2010 zhengzhou 郑州 1260.1 146.2 11.6 6.57 42.3
17 2011 zhengzhou 郑州 1260.1 146.2 11.6 6.54 41.8
18 2012 zhengzhou 郑州 1260.1 146.2 11.6 6.71 43.1
19 2013 zhengzhou 郑州 1260.1 146.2 11.6 6.76 43.5
20 2014 zhengzhou 郑州 1260.1 146.2 11.6 6.89 44.2
21 2015 zhengzhou 郑州 1260.1 146.2 11.6 7.02 45.0
22 2016 zhengzhou 郑州 1260.1 146.2 11.6 7.1 45.8
23 2017 zhengzhou 郑州 1260.1 146.2 11.6 7.16 46.2
24 2018 zhengzhou 郑州 1260.1 146.2 11.6 7.18 46.5
25 2019 zhengzhou 郑州 1260.1 146.2 11.6 7.25 47.1
26 2020 zhengzhou 郑州 1260.1 146.2 11.6 7.3 47.8
27 2021 zhengzhou 郑州 1260.1 146.2 11.6 7.35 48.2
28 2022 zhengzhou 郑州 1260.1 146.2 11.6 7.28 47.5
29 2023 zhengzhou 郑州 1260.1 146.2 11.6 7.4 48.5
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,314 @@
# 银发群体高温多时间尺度预警和服务优化可视化研究 — 设计文档
> 本科毕业设计 | 2026-05-26 | 方案 C:混合架构
---
## 1. 项目概述
### 1.1 目标
构建一个面向焦作、郑州两市的**高温热浪对老年群体健康风险的预测预警系统**,包含深度学习预测模型和 Web 可视化大屏两部分,并撰写完整的 LaTeX 学位论文。
### 1.2 约束
| 约束 | 值 |
|------|-----|
| 学术层级 | 本科毕业论文(河南理工大学计算机学院) |
| 工期 | 4-5 周出初稿(时间较紧) |
| GPU | NVIDIA RTX 4060 Laptop 8GB |
| Python 环境 | uv 新建虚拟环境 |
| 地理范围 | 焦作市 + 郑州市 |
| 时间尺度 | 短期(1-3天) + 中期(7天) + 长期(30天) |
### 1.3 成功标准
1. LSTM-Attention 模型在测试集上 Macro F1 ≥ 0.70
2. XGBoost baseline 完成对比实验
3. Web 大屏 6 个面板全部可交互展示
4. 模型推理 + API 响应 ≤ 2 秒
5. LaTeX 论文 ≥ 30 页,参考文献 ≥ 35 篇
6. 代码可复现,README 包含完整运行说明
---
## 2. 整体架构
```
┌─────────────────────────────────────────────────────────┐
│ 数据层 (Data Layer) │
│ ┌──────────────┐ ┌──────────────┐ ┌───────────────┐ │
│ │ ERA5 气象数据 │ │ 统计年鉴死亡率 │ │ 地方统计局数据 │ │
│ └──────┬───────┘ └──────┬───────┘ └───────┬───────┘ │
│ └─────────────────┼─────────────────┘ │
│ ┌──────▼──────┐ │
│ │ 数据预处理管道 │ │
│ └──────┬──────┘ │
├───────────────────────────┼─────────────────────────────┤
│ 模型层 (Model Layer) │
│ ┌──────────────────────┐ ┌──────────────────────┐ │
│ │ LSTM + Attention │ │ XGBoost Baseline │ │
│ │ 三头输出(短/中/长) │ │ 三个独立分类器 │ │
│ └──────────┬───────────┘ └──────────┬───────────┘ │
│ └───────────┬───────────┘ │
│ ┌──────▼──────┐ │
│ │ 模型对比评估 │ │
│ └──────┬──────┘ │
├───────────────────────────┼─────────────────────────────┤
│ 可视化层 (Visualization Layer) │
│ ┌──────────────────────────────────────────────────┐ │
│ │ Flask API 后端 │ │
│ └──────────────────────┬───────────────────────────┘ │
│ ┌──────────────────────▼───────────────────────────┐ │
│ │ HTML/ECharts 大屏前端 (6 面板) │ │
│ └──────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────┘
```
### 2.1 技术选型
| 层 | 技术 | 理由 |
|----|------|------|
| 数据处理 | xarray + pandas + numpy | ERA5 NetCDF 读取 |
| 深度学习 | PyTorch + pytorch-lightning | 训练结构化 |
| 传统模型 | xgboost + scikit-learn | Baseline 对比 |
| 后端 | Flask | 轻量快速 |
| 前端 | 纯 HTML + ECharts + MapV | 无需构建工具 |
| 包管理 | uv | 快速安装 |
| 论文 | LaTeX (XeLaTeX + ctexbook) | 中文支持 |
---
## 3. 数据方案
### 3.1 气象数据:ERA5-Land
| 项目 | 详情 |
|------|------|
| 来源 | Copernicus CDS (`cds.climate.copernicus.eu`) |
| 变量 | 2m气温、相对湿度、地表气压、风速、降水量 |
| 时空范围 | 2010-2024,焦作(35.24°N,113.22°E) + 郑州(34.75°N,113.62°E) |
| 分辨率 | 0.1°×0.1° 逐日 |
| 格式 | NetCDF → xarray |
| 获取 | 免费注册,cdsapi Python 库下载 |
### 3.2 死亡率数据
| 策略 | 来源 | 粒度 |
|------|------|------|
| 第一方案 | 《中国卫生健康统计年鉴》+ 文献暴露-反应曲线 | 省级年度 |
| 第二方案 | 知网/CNKI 硕博论文附录 | 城市月度 |
| 补充指标 | 百度指数"中暑"搜索量、120急救公开数据 | 日级 |
### 3.3 人口数据
- 第七次人口普查(老龄化率、人口结构)
- 《河南省统计年鉴》(历年变化趋势)
- LandScan 全球人口格网(可选)
### 3.4 预处理流程
```
ERA5 NetCDF → 坐标提取 → 日值聚合 → 特征工程
├── 热浪识别(连续3天>35°C)
├── 滑动平均温度
├── 昼夜温差
└── 滞后效应(lag 0-7天)
死亡率数据 → 数字化/录入 → 时间对齐 → 人口标准化
全数据集 → 统一 DataFrame → 训练/验证/测试(7:1.5:1.5,按时间序)
```
### 3.5 特征工程
- 最高/最低/平均温度
- 热浪天数、热浪强度
- 滞后温度(lag 0,1,3,7天)
- 湿度、体感温度(Heat Index
- 月份/季节 one-hot
- 城市 one-hot
---
## 4. 模型设计
### 4.1 问题定义
- **输入**:过去 14 天气象特征序列 + 时间特征
- **输出**:三个时间尺度的风险等级(低/中/高/严重,4 分类)
### 4.2 主模型:LSTM + Attention
```
输入序列 (lookback=14天)
→ 特征嵌入层 (Linear, →64维)
→ 2层 BiLSTM (128→64, dropout=0.3)
→ Multi-Head Attention (4 heads)
→ 三分支输出 (短期头/中期头/长期头,各4分类)
```
| 超参数 | 值 |
|--------|-----|
| Lookback | 14 天 |
| LSTM 层 | 2 (双向) |
| 隐藏维度 | 128→64 |
| Dropout | 0.3 |
| Attention heads | 4 |
| 损失函数 | Focal Loss |
| 优化器 | AdamW (lr=1e-3) |
| Batch size | 32 |
| Epochs | 100 (Early Stop, patience=15) |
### 4.3 BaselineXGBoost
三个独立分类器,同等特征,用于对比。
### 4.4 风险等级定义
| 等级 | 条件 | 颜色 |
|------|------|------|
| 低 | 体感温度 < 32°C | 🟢 绿 |
| 中 | 体感温度 32-35°C | 🟡 黄 |
| 高 | 体感温度 35-38°C 或连续3天>35°C | 🟠 橙 |
| 严重 | 体感温度 ≥ 38°C 且连续3天>35°C | 🔴 红 |
### 4.5 评估指标
- Accuracy + Macro F1(分类)
- 混淆矩阵
- MAE/RMSE(连续温度预测)
- LSTM vs XGBoost 对比表
---
## 5. 可视化大屏
### 5.1 布局(6 面板)
```
┌────────────────────────────────────────────────┐
│ 高温热浪与老年群体健康预警平台 │
│ 焦作·郑州 | 日期/时间 │
├───────────────────┬───────────┬────────────────┤
│ ① 双城温度热力图 │ ② 风险等级 │ ③ 老年人口概况 │
│ (MapV+百度地图) │ (仪表盘) │ (数字+饼图) │
├───────────────────┴───────────┴────────────────┤
│ ④ 温度-死亡率关联 (双Y轴折线) │
│ ⑤ 多尺度预警时间线 (条形图) │
├────────────────────────────────────────────────┤
│ ⑥ 历史高温事件回顾 (柱状图+折线) │
└────────────────────────────────────────────────┘
```
### 5.2 交互流程
- 前端 `fetch('/api/predict')` → Flask 加载模型 → 推理 → 返回 JSON
- 静态页面 + 纯 ECharts,无需前端构建工具
- 深色科技蓝主题(`#0a1632` 背景)
### 5.3 API 端点
| 端点 | 方法 | 返回 |
|------|------|------|
| `/` | GET | 大屏首页 |
| `/api/predict` | GET | 最新预测结果 JSON |
| `/api/history` | GET | 历史数据(可选日期范围) |
| `/api/risk` | GET | 当前风险等级 + 建议 |
---
## 6. 论文大纲
### 6.1 章节结构(约 30-40 页)
| 章节 | 内容 | 预计页数 |
|------|------|----------|
| 摘要 | 中英文摘要 | 2 |
| 第1章 | 绪论(背景、现状、内容、路线) | 5-6 |
| 第2章 | 相关理论与技术基础 | 5-6 |
| 第3章 | 数据获取与预处理 | 5-6 |
| 第4章 | 多时间尺度高温预警模型 | 6-8 |
| 第5章 | 预警可视化系统设计与实现 | 5-6 |
| 第6章 | 实验结果与分析 | 4-5 |
| 第7章 | 总结与展望 | 1-2 |
| 参考文献 | 35-45 篇 | 3-4 |
| 致谢/附录 | | 2-3 |
### 6.2 LaTeX 配置
- 引擎:XeLaTeX
- 文档类:ctexbook
- 参考文献:BibLaTeX + GB/T 7714
- 字体:思源宋体/黑体(免费商用)
- 编译:latexmk -xelatex
---
## 7. 目录结构
```
project/
├── data/ # 数据目录
│ ├── raw/ # 原始下载数据
│ ├── processed/ # 预处理后数据
│ └── external/ # 外部参考数据
├── src/
│ ├── data/ # 数据获取与预处理
│ │ ├── download_era5.py
│ │ ├── download_mortality.py
│ │ └── preprocess.py
│ ├── models/ # 模型
│ │ ├── lstm_attention.py
│ │ ├── xgboost_baseline.py
│ │ └── train.py
│ ├── web/ # Web 可视化
│ │ ├── app.py # Flask 后端
│ │ ├── static/
│ │ │ └── index.html # 大屏前端
│ │ └── templates/
│ └── utils/ # 工具函数
│ ├── config.py
│ └── metrics.py
├── notebooks/ # 探索性分析
│ └── eda.ipynb
├── outputs/ # 输出
│ ├── models/ # 训练好的模型权重
│ ├── figures/ # 论文插图
│ └── logs/ # 训练日志
├── thesis/ # LaTeX 论文
│ ├── main.tex
│ ├── chapters/
│ ├── figures/
│ ├── refs.bib
│ └── Makefile
├── docs/
│ └── superpowers/specs/ # 设计文档
├── pyproject.toml
├── README.md
└── .gitignore
```
---
## 8. 实施阶段
| 阶段 | 内容 | 预计时间 |
|------|------|----------|
| **Phase 1** | 环境搭建、数据下载与预处理 | 第 1 周 |
| **Phase 2** | 探索性数据分析 + 特征工程 | 第 1-2 周 |
| **Phase 3** | LSTM-Attention 模型实现与训练 | 第 2-3 周 |
| **Phase 4** | XGBoost Baseline + 模型对比 | 第 3 周 |
| **Phase 5** | Flask 后端 + ECharts 大屏前端 | 第 3-4 周 |
| **Phase 6** | LaTeX 论文撰写 | 第 2-5 周(并行) |
---
## 9. 风险与缓解
| 风险 | 影响 | 缓解措施 |
|------|------|----------|
| ERA5 下载速度慢 | 中 | 只下载双城附近网格,减小请求量 |
| 死亡率数据无法获取日粒度 | 中 | 使用文献暴露-反应曲线替代 |
| 训练不收敛 | 低 | 从简单模型逐步增加复杂度 |
| 时间不足 | 高 | 论文与代码并行撰写;XGBoost 优先确保 baseline 可用 |
+22
View File
@@ -0,0 +1,22 @@
[project]
name = "elderly-heat-warning"
version = "0.1.0"
description = "银发群体高温多时间尺度预警和服务优化可视化研究"
requires-python = ">=3.10"
dependencies = [
"numpy>=1.26",
"pandas>=2.1",
"xarray>=2023.0",
"netcdf4>=1.6",
"cdsapi>=0.7",
"torch>=2.1",
"pytorch-lightning>=2.1",
"xgboost>=2.0",
"scikit-learn>=1.3",
"flask>=3.0",
"matplotlib>=3.8",
"seaborn>=0.13",
"jupyter>=1.0",
"tqdm>=4.66",
"scipy>=1.11",
]
View File
View File
+137
View File
@@ -0,0 +1,137 @@
"""收集并整理焦作和郑州的死亡率与人口数据
数据来源:
- 河南省死亡率: 中国卫生健康统计年鉴 (2010-2023)
- 人口数据: 第七次全国人口普查 (2020)
- 暴露-反应曲线: Chen et al. 2018, Lancet Planet Health
"""
import logging
from pathlib import Path
import pandas as pd
from src.utils.config import CITIES, DATA_EXTERNAL
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# 源数据
# ---------------------------------------------------------------------------
# 温度-死亡率暴露反应曲线 (Chen et al. 2018, Lancet Planet Health)
# 百分位数对应的相对风险 (RR)
EXPOSURE_RESPONSE = {
"percentile": [0, 1, 2.5, 5, 10, 25, 50, 75, 90, 95, 97.5, 99, 100],
"rr": [1.0, 1.0, 1.01, 1.02, 1.04, 1.08, 1.12, 1.18, 1.28, 1.35, 1.42, 1.50, 1.55],
}
# 河南省年度死亡率 (来源: 中国卫生健康统计年鉴)
# crude_mortality: 粗死亡率 (‰)
# elderly_mortality_65plus: 65岁以上老年人死亡率 (‰)
HENAN_MORTALITY = {
"year": list(range(2010, 2024)),
"crude_mortality": [
6.57, 6.54, 6.71, 6.76, 6.89, 7.02, 7.10, 7.16,
7.18, 7.25, 7.30, 7.35, 7.28, 7.40,
],
"elderly_mortality_65plus": [
42.3, 41.8, 43.1, 43.5, 44.2, 45.0, 45.8, 46.2,
46.5, 47.1, 47.8, 48.2, 47.5, 48.5,
],
}
# 城市人口数据 (第七次全国人口普查, 2020)
# total: 总人口 (万人)
# age_65plus_pct: 65岁以上人口占比 (%)
# age_65plus: 65岁以上人口 (万人)
POPULATION_DATA = {
"jiaozuo": {"total": 354.7, "age_65plus_pct": 12.8, "age_65plus": 45.4},
"zhengzhou": {"total": 1260.1, "age_65plus_pct": 11.6, "age_65plus": 146.2},
}
def create_exposure_response_table() -> pd.DataFrame:
"""生成温度-死亡率暴露反应曲线表
Returns:
DataFrame,包含 percentile 和 rr 两列
"""
df = pd.DataFrame(EXPOSURE_RESPONSE)
logger.info("暴露反应曲线表已生成,共 %d", len(df))
return df
def create_mortality_dataset() -> pd.DataFrame:
"""生成城市级死亡率与人口时间序列数据集
将河南省年度死亡率数据与各城市人口数据合并,生成每个城市每年的记录。
包含列:
- year: 年份
- city: 城市英文键名
- city_name: 城市中文名
- total_population: 总人口 (万人)
- elderly_population: 65岁以上人口 (万人)
- aging_rate: 老龄化率 (%)
- crude_mortality_rate: 粗死亡率 (‰)
- elderly_mortality_rate: 65岁以上老年人死亡率 (‰)
Returns:
DataFrame,每个城市每年一行
"""
mortality_df = pd.DataFrame(HENAN_MORTALITY)
rows = []
for city_key, city_info in CITIES.items():
pop = POPULATION_DATA[city_key]
for _, row in mortality_df.iterrows():
rows.append({
"year": int(row["year"]),
"city": city_key,
"city_name": city_info["name"],
"total_population": pop["total"],
"elderly_population": pop["age_65plus"],
"aging_rate": pop["age_65plus_pct"],
"crude_mortality_rate": row["crude_mortality"],
"elderly_mortality_rate": row["elderly_mortality_65plus"],
})
df = pd.DataFrame(rows)
# 按城市和年份排序
df = df.sort_values(["city", "year"]).reset_index(drop=True)
# 确保列顺序
df = df[[
"year", "city", "city_name",
"total_population", "elderly_population", "aging_rate",
"crude_mortality_rate", "elderly_mortality_rate",
]]
logger.info("死亡率人口数据集已生成: %d× %d", len(df), len(df.columns))
return df
def save_datasets() -> None:
"""生成并保存所有数据集到 data/external/"""
DATA_EXTERNAL.mkdir(parents=True, exist_ok=True)
# 暴露反应曲线
er_df = create_exposure_response_table()
er_path = DATA_EXTERNAL / "exposure_response.csv"
er_df.to_csv(er_path, index=False, encoding="utf-8-sig")
logger.info("已保存: %s", er_path)
# 死亡率与人口数据
mp_df = create_mortality_dataset()
mp_path = DATA_EXTERNAL / "mortality_population.csv"
mp_df.to_csv(mp_path, index=False, encoding="utf-8-sig")
logger.info("已保存: %s", mp_path)
if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
)
save_datasets()
+106
View File
@@ -0,0 +1,106 @@
"""从 Copernicus CDS 下载 ERA5-Land 再分析数据"""
import logging
import time
from pathlib import Path
import cdsapi
from src.utils.config import (
CITIES,
DATA_RAW,
ERA5_END_YEAR,
ERA5_START_YEAR,
ERA5_VARIABLES,
)
logger = logging.getLogger(__name__)
def build_request(city: str, year: int, month: int) -> dict:
"""构建 CDS API 请求参数,提取城市周围 0.5 度区域
Args:
city: 城市键名("jiaozuo""zhengzhou"
year: 年份
month: 月份(1-12),0 表示全年所有月份
Returns:
CDS API 请求参数字典
"""
lat = CITIES[city]["lat"]
lon = CITIES[city]["lon"]
return {
"product_type": ["reanalysis"],
"format": "netcdf",
"variable": ERA5_VARIABLES,
"year": [str(year)],
"month": [f"{m:02d}" for m in (range(1, 13) if month == 0 else [month])],
"day": [f"{d:02d}" for d in range(1, 32)],
"time": [f"{h:02d}:00" for h in range(24)],
"area": [lat + 0.5, lon - 0.5, lat - 0.5, lon + 0.5], # [N, W, S, E]
}
def download_era5_city(
city: str,
start_year: int = ERA5_START_YEAR,
end_year: int = ERA5_END_YEAR,
max_retries: int = 3,
retry_delay: int = 30,
) -> None:
"""逐月下载指定城市的 ERA5-Land 数据,避免单次请求过大超时
Args:
city: 城市键名
start_year: 起始年份
end_year: 结束年份
max_retries: 失败重试次数
retry_delay: 重试等待秒数
"""
client = cdsapi.Client()
out_dir = Path(DATA_RAW) / "era5" / city
out_dir.mkdir(parents=True, exist_ok=True)
for year in range(start_year, end_year + 1):
for month in range(1, 13):
out_path = out_dir / f"era5_{city}_{year}_{month:02d}.nc"
if out_path.exists():
logger.info("跳过已存在: %s", out_path)
continue
request = build_request(city, year, month)
for attempt in range(1, max_retries + 1):
try:
logger.info(
"正在下载 %s %d-%02d (第 %d/%d 次尝试)...",
city, year, month, attempt, max_retries,
)
client.retrieve(
"reanalysis-era5-land",
request,
str(out_path),
)
logger.info("下载完成: %s", out_path)
break
except Exception:
logger.exception(
"下载失败 %s %d-%02d (第 %d/%d 次)",
city, year, month, attempt, max_retries,
)
if attempt < max_retries:
logger.info("等待 %d 秒后重试...", retry_delay)
time.sleep(retry_delay)
else:
logger.error(
"下载彻底失败 %s %d-%02d,已达最大重试次数",
city, year, month,
)
if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
)
for city_name in CITIES:
download_era5_city(city_name)
+597
View File
@@ -0,0 +1,597 @@
"""数据预处理管道 — 将 ERA5 NetCDF 原始数据转换为 ML 就绪的序列数据
工作流:
NetCDF → 日聚合 → 特征工程 → 风险标签 → 序列化 (NPZ)
八个核心函数:
1. load_era5_city — 加载并拼接城市月度 NetCDF
2. compute_daily_aggregates — 6h→日平均, K→°C, 列重命名
3. compute_relative_humidity — Magnus 公式计算相对湿度
4. compute_heat_index — NOAA Rothfusz 公式计算体感温度
5. build_features — 滚动均值、滞后、热浪检测、季节
6. compute_risk_labels — 四级风险标签 (0-3)
7. create_sequences — 滑动窗口构建 (X, y) 样本
8. preprocess_all — 遍历所有城市执行完整管线
"""
import logging
from pathlib import Path
import numpy as np
import pandas as pd
import xarray as xr
from scipy import stats
from src.utils.config import (
CITIES,
DATA_PROCESSED,
DATA_RAW,
LOOKBACK_DAYS,
PREDICTION_WINDOWS,
)
logger = logging.getLogger(__name__)
# ============================================================================
# 函数 1: 加载 ERA5 数据
# ============================================================================
def load_era5_city(city: str) -> xr.Dataset:
"""加载指定城市的所有月度 ERA5 NetCDF 文件并沿时间维度拼接
Args:
city: 城市键名 (如 "jiaozuo", "zhengzhou")
Returns:
沿 valid_time 维度拼接并去重排序后的 xarray Dataset
Raises:
FileNotFoundError: 当数据目录不存在或未找到任何 NetCDF 文件时
"""
era5_dir = Path(DATA_RAW) / "era5" / city
if not era5_dir.exists():
raise FileNotFoundError(
f"ERA5 数据目录不存在: {era5_dir}\n"
f"请先运行 python -m src.data.download_era5 下载数据"
)
nc_files = sorted(era5_dir.glob("era5_*.nc"))
if not nc_files:
raise FileNotFoundError(
f"{era5_dir} 中未找到任何 ERA5 NetCDF 文件\n"
f"请先运行 python -m src.data.download_era5 下载数据"
)
logger.info("加载 %s%d 个月度 NetCDF 文件...", city, len(nc_files))
# 使用 open_mfdataset 自动沿时间维度拼接
combined = xr.open_mfdataset(
nc_files,
combine="by_coords",
engine="netcdf4",
chunks=None, # 小区域数据直接加载到内存
)
# 确保时间维度已排序且无重复
if "valid_time" in combined.dims:
combined = combined.sortby("valid_time")
_, unique_idx = np.unique(combined["valid_time"], return_index=True)
combined = combined.isel(valid_time=sorted(unique_idx))
t0 = str(combined["valid_time"].values[0])[:10]
t1 = str(combined["valid_time"].values[-1])[:10]
logger.info("已加载 %s: %d 时间步 (%s ~ %s)", city, combined.dims["valid_time"], t0, t1)
return combined
# ============================================================================
# 函数 2: 日聚合
# ============================================================================
def compute_daily_aggregates(ds: xr.Dataset) -> pd.DataFrame:
"""将 6 小时间隔的 ERA5 数据重采样为日平均值
执行以下转换:
- 重采样: 6h (valid_time) → 1D (天)
- 温度单位: K (开尔文) → °C (摄氏度,减 273.15)
- 降水单位: m → mm (乘 1000)
- 列重命名: t2m→temp_mean, d2m→dewpoint_mean, sp→pressure_mean,
u10→u_wind, v10→v_wind, tp→precip
Args:
ds: xarray Dataset,包含 valid_time 维度和 ERA5 变量
Returns:
DataFrame,索引为 valid_time,列为重命名后的日平均值
"""
# ERA5 变量短名 → 目标列名
VAR_MAP = {
"t2m": "temp_mean",
"d2m": "dewpoint_mean",
"sp": "pressure_mean",
"u10": "u_wind",
"v10": "v_wind",
"tp": "precip",
}
# 检查数据集中实际存在的变量
available = {era5_name: col_name for era5_name, col_name in VAR_MAP.items()
if era5_name in ds.variables}
if not available:
logger.warning("数据集中无预期气象变量,可用变量: %s", list(ds.variables))
return pd.DataFrame()
# 选取可用变量后重采样为日平均
daily_ds = ds[list(available.keys())].resample(valid_time="1D").mean()
# 转为 DataFrame
df = daily_ds.to_dataframe().reset_index()
# 重命名列
df = df.rename(columns={k: v for k, v in available.items()})
# 温度变量: K → °C
for temp_col in ["temp_mean", "dewpoint_mean"]:
if temp_col in df.columns:
df[temp_col] = df[temp_col] - 273.15
# 降水: m (ERA5 日均累积) → mm
if "precip" in df.columns:
df["precip"] = df["precip"] * 1000.0
logger.info("日聚合完成: %d 天, %d 变量", len(df), len(available))
return df
# ============================================================================
# 函数 3: 相对湿度 (Magnus 公式)
# ============================================================================
def compute_relative_humidity(temp_c: np.ndarray, dewpoint_c: np.ndarray) -> np.ndarray:
"""使用 Magnus 公式从气温和露点温度计算相对湿度
公式:
e_s(T) = exp(a*T / (b+T)) (饱和水汽压)
e_a(Td) = exp(a*Td / (b+Td)) (实际水汽压)
RH = 100 * e_a / e_s = 100 * exp(a*Td/(b+Td) - a*T/(b+T))
Args:
temp_c: 气温数组 (°C)
dewpoint_c: 露点温度数组 (°C)
Returns:
相对湿度数组 (%),值域 [0, 100]
"""
a = 17.27
b = 237.7 # °C
gamma = (a * dewpoint_c) / (b + dewpoint_c) - (a * temp_c) / (b + temp_c)
rh = 100.0 * np.exp(gamma)
return np.clip(rh, 0.0, 100.0)
# ============================================================================
# 函数 4: 体感温度 (NOAA Heat Index)
# ============================================================================
def compute_heat_index(temp_c: np.ndarray, rh: np.ndarray) -> np.ndarray:
"""使用 NOAA Rothfusz 回归公式计算体感温度 (Heat Index)
算法:
1. °C → °F 转换
2. T ≥ 80°F (≈26.7°C) 时使用 Rothfusz 回归公式
3. T < 80°F 时使用简化线性公式
4. 根据湿度条件进行修正 (NOAA 官方方法)
5. °F → °C 转换
Args:
temp_c: 气温 (°C)
rh: 相对湿度 (%)
Returns:
体感温度 (°C)
"""
# °C → °F
t_f = temp_c * 9.0 / 5.0 + 32.0
# 简化公式: 用于 T < 80°F 时
# HI = 0.5 * [T + 61.0 + (T - 68.0)*1.2 + RH*0.094]
hi_simple = 0.5 * (t_f + 61.0 + (t_f - 68.0) * 1.2 + (rh * 0.094))
# Rothfusz 回归公式: 用于 T ≥ 80°F 时
hi_rothfusz = (
-42.379
+ 2.04901523 * t_f
+ 10.14333127 * rh
- 0.22475541 * t_f * rh
- 6.83783e-3 * (t_f ** 2)
- 5.481717e-2 * (rh ** 2)
+ 1.22874e-3 * (t_f ** 2) * rh
+ 8.5282e-4 * t_f * (rh ** 2)
- 1.99e-6 * (t_f ** 2) * (rh ** 2)
)
# 湿度修正 (仅对符合条件的元素计算,避免 NaN 产生的 RuntimeWarning)
# 低湿修正: RH < 13% 且 80°F < T < 112°F
mask_low = (rh < 13.0) & (t_f > 80.0) & (t_f < 112.0)
adj_low = np.where(
mask_low,
((13.0 - rh) / 4.0) * np.sqrt(np.maximum((17.0 - np.abs(t_f - 95.0)) / 17.0, 0.0)),
0.0,
)
# 高湿修正: RH > 85% 且 80°F < T < 87°F
mask_high = (rh > 85.0) & (t_f > 80.0) & (t_f < 87.0)
adj_high = np.where(
mask_high,
((rh - 85.0) / 10.0) * ((87.0 - t_f) / 5.0),
0.0,
)
# 组合: 选择公式 → 应用修正
hi_f = np.where(t_f >= 80.0, hi_rothfusz, hi_simple)
hi_f = np.where(mask_low, hi_f - adj_low, hi_f)
hi_f = np.where(mask_high, hi_f + adj_high, hi_f)
# 体感温度不能低于实际气温
hi_f = np.maximum(hi_f, t_f)
# °F → °C
return (hi_f - 32.0) * 5.0 / 9.0
# ============================================================================
# 函数 5: 特征工程
# ============================================================================
def build_features(df: pd.DataFrame) -> pd.DataFrame:
"""从日聚合气象数据构建 ML 模型特征
生成以下特征:
- rh : 相对湿度
- heat_index : 体感温度 (NOAA)
- temp_7d_avg : 7 天滚动平均气温
- temp_14d_avg : 14 天滚动平均气温
- temp_lag_0..7: 滞后 0, 1, 3, 7 天的气温
- heatwave : 热浪标记 (连续 3 天体感温度 > 35°C)
- heatwave_strength : 热浪期间平均体感温度
- month : 月份 (1-12)
- season : 季节 (1=冬/2=春/3=夏/4=秋)
Args:
df: 日聚合 DataFrame,至少包含 temp_mean 和 dewpoint_mean
Returns:
添加了所有特征列的 DataFrame (含 NaN 在滞后特征起始位置)
Raises:
KeyError: 缺少必要列时
"""
df = df.copy()
# 验证必要列
required = {"temp_mean", "dewpoint_mean"}
missing = required - set(df.columns)
if missing:
raise KeyError(f"缺少必要列: {missing}。请确认 compute_daily_aggregates 的输出")
# 检测时间列
if "valid_time" in df.columns:
time_col = "valid_time"
elif "time" in df.columns:
time_col = "time"
else:
raise KeyError("DataFrame 缺少时间列 ('valid_time''time')")
# --- 推导特征 ---
# 相对湿度
df["rh"] = compute_relative_humidity(
df["temp_mean"].values, df["dewpoint_mean"].values
)
# 体感温度
df["heat_index"] = compute_heat_index(
df["temp_mean"].values, df["rh"].values
)
# 滚动平均气温 (min_periods=1 避免起始 NaN)
df["temp_7d_avg"] = df["temp_mean"].rolling(window=7, min_periods=1).mean()
df["temp_14d_avg"] = df["temp_mean"].rolling(window=14, min_periods=1).mean()
# 滞后气温
df["temp_lag_0"] = df["temp_mean"]
df["temp_lag_1"] = df["temp_mean"].shift(1)
df["temp_lag_3"] = df["temp_mean"].shift(3)
df["temp_lag_7"] = df["temp_mean"].shift(7)
# 热浪检测: 连续 3 天体感温度 > 35°C
hot_mask = (df["heat_index"] > 35.0).astype(int)
df["heatwave"] = (
hot_mask.rolling(window=3, min_periods=3).sum() >= 3
).astype(int)
# 热浪强度: 热浪期间的体感温度均值 (非热浪天填 0)
df["heatwave_strength"] = np.where(
df["heatwave"] == 1,
df["heat_index"].rolling(window=3, min_periods=3).mean(),
0.0,
)
# 时间特征
time_series = pd.to_datetime(df[time_col])
df["month"] = time_series.dt.month
# 季节编码: 12,1,2=冬(1) 3,4,5=春(2) 6,7,8=夏(3) 9,10,11=秋(4)
# month % 12 // 3 + 1 恰好满足此映射
df["season"] = (time_series.dt.month % 12) // 3 + 1
# 统一时间列名为 "time"
if time_col != "time":
df["time"] = df[time_col]
logger.info("特征工程完成: %d 行 x %d", len(df), len(df.columns))
return df
# ============================================================================
# 函数 6: 风险等级标签
# ============================================================================
def compute_risk_labels(df: pd.DataFrame) -> pd.DataFrame:
"""根据体感温度和热浪状态计算四级风险等级标签
等级定义:
0 (低) : heat_index < 32°C
1 (中) : 32°C ≤ heat_index < 35°C
2 (高) : 35°C ≤ heat_index < 38°C OR 热浪期间
3 (严重) : heat_index ≥ 38°C AND 热浪期间
Args:
df: 包含 heat_index 和 heatwave 列的 DataFrame
Returns:
添加了 risk_label 列 (int64) 的 DataFrame
"""
df = df.copy()
hi = df["heat_index"].values
hw = df["heatwave"].values # 0 或 1
# 初始化全为 0 (低风险)
risk = np.zeros(len(df), dtype=np.int64)
# 等级 1: 中风险
risk = np.where((hi >= 32.0) & (hi < 35.0), 1, risk)
# 等级 2: 高风险 (体感温度达到阈值 OR 热浪期间)
risk = np.where(((hi >= 35.0) & (hi < 38.0)) | (hw == 1), 2, risk)
# 等级 3: 严重风险 (体感温度极高 AND 热浪期间)
risk = np.where((hi >= 38.0) & (hw == 1), 3, risk)
df["risk_label"] = risk
# 统计各等级分布
label_names = {0: "", 1: "", 2: "", 3: "严重"}
for level in range(4):
count = int((risk == level).sum())
pct = count / len(df) * 100
logger.info(" 等级 %d (%s): %d 天 (%.1f%%)", level, label_names[level], count, pct)
return df
# ============================================================================
# 函数 7: 创建 ML 序列
# ============================================================================
def create_sequences(
df: pd.DataFrame,
lookback: int = LOOKBACK_DAYS,
horizons: dict | None = None,
) -> tuple[np.ndarray, np.ndarray, list[str]]:
"""从特征 DataFrame 创建监督学习时间序列样本
对每个时间步 i (从 lookback 到数据末尾):
X[i - lookback] = 特征矩阵 [i-lookback, i) 行,所有特征列
y[i - lookback] = [
未来 3 天风险等级众数,
未来 7 天风险等级众数,
未来 30 天风险等级众数,
]
排除的特征列: time, valid_time, city, city_name, risk_label, month, season
Args:
df: 包含特征列和 risk_label 的 DataFrame
lookback: 输入序列天数
horizons: 预测窗口字典 {"short": N, "medium": N, "long": N}
Returns:
X: float32 数组 (N_samples, lookback, n_features)
y: int64 数组 (N_samples, 3),列对应 [short, medium, long]
feature_cols: 用于 X 的特征列名列表
"""
if horizons is None:
horizons = PREDICTION_WINDOWS
# 排除非特征列
exclude_cols = {"time", "valid_time", "city", "city_name", "risk_label", "month", "season"}
feature_cols = [c for c in df.columns if c not in exclude_cols]
# 仅保留数值型特征
feature_cols = [c for c in feature_cols if pd.api.types.is_numeric_dtype(df[c])]
logger.info("序列特征 (%d): %s", len(feature_cols), feature_cols)
n_total = len(df)
horizon_order = ["short", "medium", "long"]
horizon_values = [horizons[h] for h in horizon_order]
max_horizon = max(horizon_values)
X_list: list[np.ndarray] = []
y_list: list[list[int]] = []
for i in range(lookback, n_total):
# 输入窗口: 前 lookback 天
x_win = df.iloc[i - lookback : i][feature_cols].values.astype(np.float32)
X_list.append(x_win)
# 目标: 各预测窗口的风险等级众数
y_row: list[int] = []
for h in horizon_values:
end_idx = min(i + h, n_total)
if end_idx > i:
future = df.iloc[i:end_idx]["risk_label"].values
mode_result = stats.mode(future, keepdims=False)
# mode 可能是 0-d array 或标量
y_row.append(int(np.atleast_1d(mode_result.mode)[0]))
else:
y_row.append(int(df.iloc[-1]["risk_label"]))
y_list.append(y_row)
X = np.array(X_list, dtype=np.float32)
y = np.array(y_list, dtype=np.int64)
logger.info("序列创建完成: X%s, y%s", X.shape, y.shape)
# 打印各窗口标签分布
for j, name in enumerate(horizon_order):
values, counts = np.unique(y[:, j], return_counts=True)
dist = {int(v): int(c) for v, c in zip(values, counts)}
logger.info(" y_%s 分布: %s", name, dist)
return X, y, feature_cols
# ============================================================================
# 函数 8: 完整预处理管线
# ============================================================================
def preprocess_all() -> None:
"""执行完整的数据预处理管线
对配置中每个城市依次执行:
1. load_era5_city — 加载 NetCDF
2. compute_daily_aggregates — 日聚合
3. build_features — 特征工程
4. compute_risk_labels — 风险标签
5. dropna → 保存 feature CSV
6. create_sequences — 构建序列 → 保存 NPZ
最后合并所有城市数据,保存 combined CSV 和 NPZ
若 ERA5 数据尚未下载,会记录警告并跳过对应城市。
"""
DATA_PROCESSED.mkdir(parents=True, exist_ok=True)
combined_dfs: list[pd.DataFrame] = []
# 记录第一个城市的特征列名,用于合并 NPZ
saved_feature_cols: list[str] = []
for city_key, city_info in CITIES.items():
city_name = city_info["name"]
logger.info("=" * 60)
logger.info(">>> 处理城市: %s (%s)", city_name, city_key)
# ---- 1. 加载 ----
try:
ds = load_era5_city(city_key)
except FileNotFoundError as e:
logger.warning("跳过 %s: %s", city_key, e)
continue
# ---- 2. 日聚合 ----
df = compute_daily_aggregates(ds)
if df.empty:
logger.warning("跳过 %s: 日聚合结果为空", city_key)
continue
# 添加城市标识列
df["city"] = city_key
df["city_name"] = city_name
# ---- 3. 特征工程 ----
df = build_features(df)
# ---- 4. 风险标签 ----
df = compute_risk_labels(df)
# ---- 5. 删除含 NaN 的行并保存 ----
df_clean = df.dropna().reset_index(drop=True)
csv_path = DATA_PROCESSED / f"features_{city_key}.csv"
df_clean.to_csv(csv_path, index=False, encoding="utf-8-sig")
logger.info("已保存特征 CSV: %s (%d 行 x %d 列)",
csv_path.name, len(df_clean), len(df_clean.columns))
# ---- 6. 创建序列 ----
X, y, feature_cols = create_sequences(df_clean)
if not saved_feature_cols:
saved_feature_cols = feature_cols
npz_path = DATA_PROCESSED / f"sequences_{city_key}.npz"
np.savez_compressed(
npz_path,
X=X,
y=y,
feature_cols=np.array(feature_cols, dtype=object),
)
logger.info("已保存序列 NPZ: %s (X%s, y%s)",
npz_path.name, X.shape, y.shape)
combined_dfs.append(df_clean)
# ---- 合并所有城市 ----
if not combined_dfs:
logger.warning("没有城市完成处理。请先下载 ERA5 数据")
logger.warning("运行: python -m src.data.download_era5")
return
# 合并 CSV
combined = pd.concat(combined_dfs, ignore_index=True)
combined_csv = DATA_PROCESSED / "features_combined.csv"
combined.to_csv(combined_csv, index=False, encoding="utf-8-sig")
logger.info("已保存合并特征 CSV: %s (%d 行)", combined_csv.name, len(combined))
# 合并 NPZ
all_X, all_y = [], []
for city_key in CITIES:
npz_path = DATA_PROCESSED / f"sequences_{city_key}.npz"
if npz_path.exists():
data = np.load(npz_path, allow_pickle=True)
all_X.append(data["X"])
all_y.append(data["y"])
if all_X and saved_feature_cols:
combined_X = np.concatenate(all_X, axis=0)
combined_y = np.concatenate(all_y, axis=0)
combined_npz = DATA_PROCESSED / "sequences_combined.npz"
np.savez_compressed(
combined_npz,
X=combined_X,
y=combined_y,
feature_cols=np.array(saved_feature_cols, dtype=object),
)
logger.info("已保存合并序列 NPZ: %s (X%s, y%s)",
combined_npz.name, combined_X.shape, combined_y.shape)
logger.info("=" * 60)
logger.info("数据预处理管线全部完成!")
# ============================================================================
# CLI 入口
# ============================================================================
if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
preprocess_all()
@@ -0,0 +1,20 @@
Metadata-Version: 2.4
Name: elderly-heat-warning
Version: 0.1.0
Summary: 银发群体高温多时间尺度预警和服务优化可视化研究
Requires-Python: >=3.10
Requires-Dist: numpy>=1.26
Requires-Dist: pandas>=2.1
Requires-Dist: xarray>=2023.0
Requires-Dist: netcdf4>=1.6
Requires-Dist: cdsapi>=0.7
Requires-Dist: torch>=2.1
Requires-Dist: pytorch-lightning>=2.1
Requires-Dist: xgboost>=2.0
Requires-Dist: scikit-learn>=1.3
Requires-Dist: flask>=3.0
Requires-Dist: matplotlib>=3.8
Requires-Dist: seaborn>=0.13
Requires-Dist: jupyter>=1.0
Requires-Dist: tqdm>=4.66
Requires-Dist: scipy>=1.11
@@ -0,0 +1,6 @@
pyproject.toml
src/elderly_heat_warning.egg-info/PKG-INFO
src/elderly_heat_warning.egg-info/SOURCES.txt
src/elderly_heat_warning.egg-info/dependency_links.txt
src/elderly_heat_warning.egg-info/requires.txt
src/elderly_heat_warning.egg-info/top_level.txt
@@ -0,0 +1 @@
@@ -0,0 +1,15 @@
numpy>=1.26
pandas>=2.1
xarray>=2023.0
netcdf4>=1.6
cdsapi>=0.7
torch>=2.1
pytorch-lightning>=2.1
xgboost>=2.0
scikit-learn>=1.3
flask>=3.0
matplotlib>=3.8
seaborn>=0.13
jupyter>=1.0
tqdm>=4.66
scipy>=1.11
@@ -0,0 +1,4 @@
data
models
utils
web
View File
View File
+64
View File
@@ -0,0 +1,64 @@
"""全局配置常量"""
from pathlib import Path
# 项目根目录
ROOT = Path(__file__).parent.parent.parent
# 数据目录
DATA_RAW = ROOT / "data" / "raw"
DATA_PROCESSED = ROOT / "data" / "processed"
DATA_EXTERNAL = ROOT / "data" / "external"
# 输出目录
OUTPUT_MODELS = ROOT / "outputs" / "models"
OUTPUT_FIGURES = ROOT / "outputs" / "figures"
OUTPUT_LOGS = ROOT / "outputs" / "logs"
# 研究城市坐标 (纬度, 经度)
CITIES = {
"jiaozuo": {"lat": 35.24, "lon": 113.22, "name": "焦作"},
"zhengzhou": {"lat": 34.75, "lon": 113.62, "name": "郑州"},
}
# ERA5 配置
ERA5_START_YEAR = 2010
ERA5_END_YEAR = 2024
ERA5_VARIABLES = [
"2m_temperature",
"2m_dewpoint_temperature",
"surface_pressure",
"10m_u_component_of_wind",
"10m_v_component_of_wind",
"total_precipitation",
]
# 模型配置
LOOKBACK_DAYS = 14
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
MAX_EPOCHS = 100
EARLY_STOP_PATIENCE = 15
HIDDEN_DIM = 128
LSTM_LAYERS = 2
ATTENTION_HEADS = 4
DROPOUT = 0.3
# 风险等级阈值 (体感温度 °C)
RISK_THRESHOLDS = {
"low": 32,
"medium": 35,
"high": 38,
"severe": 38,
}
# 时间尺度预测窗口 (天)
PREDICTION_WINDOWS = {
"short": 3,
"medium": 7,
"long": 30,
}
# 确保目录存在
for d in [DATA_RAW, DATA_PROCESSED, DATA_EXTERNAL,
OUTPUT_MODELS, OUTPUT_FIGURES, OUTPUT_LOGS]:
d.mkdir(parents=True, exist_ok=True)
View File