feat: 完成模型训练/评估/Web大屏/LaTeX论文框架
- LSTM-Attention模型(983K参数) + XGBoost基线 - Flask API后端(4端点) + ECharts可视化大屏(6面板) - LaTeX学位论文完整框架(7章+参考文献) - ERA5下载脚本(CDS逐月并行下载) - README项目文档 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -17,3 +17,4 @@ outputs/logs/
|
||||
*.fdb_latexmk
|
||||
*.fls
|
||||
.DS_Store
|
||||
.claude/
|
||||
|
||||
@@ -0,0 +1,174 @@
|
||||
# 银发群体高温多时间尺度预警和服务优化可视化研究
|
||||
|
||||
本科毕业设计 — 河南理工大学计算机科学与技术学院
|
||||
|
||||
## 概述
|
||||
|
||||
本项目针对焦作市和郑州市老年群体,构建了基于 LSTM-Attention 的多时间尺度高温健康风险预警模型,并开发了 ECharts 可视化大屏系统。
|
||||
|
||||
### 核心功能
|
||||
|
||||
- **多时间尺度预警**:短期(1-3天)、中期(7天)、长期(30天)三级高温健康风险预测
|
||||
- **深度学习模型**:BiLSTM + Multi-Head Attention,三头输出同时预测三个时间尺度
|
||||
- **基线对比**:XGBoost 三分类器,验证深度学习方法有效性
|
||||
- **可视化大屏**:6 面板深色科技蓝风格 Web 大屏,含温度趋势、风险预警、人口统计等
|
||||
- **完整论文**:LaTeX 学位论文,含 7 个章节 + 参考文献 + 附录
|
||||
|
||||
### 技术栈
|
||||
|
||||
| 层 | 技术 |
|
||||
|----|------|
|
||||
| 数据处理 | Python, xarray, pandas, numpy |
|
||||
| 气象数据 | ERA5-Land (CDS API) |
|
||||
| 深度学习 | PyTorch 2.12, CUDA 12.6 |
|
||||
| 传统模型 | XGBoost, scikit-learn |
|
||||
| Web 后端 | Flask |
|
||||
| 可视化 | ECharts 5.5 |
|
||||
| 包管理 | uv |
|
||||
| 论文 | LaTeX (XeLaTeX + ctexbook) |
|
||||
|
||||
## 环境配置
|
||||
|
||||
### 系统要求
|
||||
|
||||
- Python 3.13
|
||||
- NVIDIA GPU (推荐,RTX 4060 或以上)
|
||||
- Windows 11 / Linux
|
||||
|
||||
### 安装
|
||||
|
||||
```bash
|
||||
# 创建虚拟环境
|
||||
uv venv --python "D:\settings\Language\Python\Python 3.13.13\python.exe"
|
||||
|
||||
# 安装依赖
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
### CDS API 配置(数据下载必需)
|
||||
|
||||
1. 注册 Copernicus CDS 账号:https://cds.climate.copernicus.eu/
|
||||
2. 获取 API Key
|
||||
3. 创建 `~/.cdsapirc`:
|
||||
```
|
||||
url: https://cds.climate.copernicus.eu/api
|
||||
key: <UID>:<API_KEY>
|
||||
```
|
||||
|
||||
## 运行指南
|
||||
|
||||
### 1. 数据获取与预处理
|
||||
|
||||
```bash
|
||||
# 下载 ERA5 气象数据(需要 CDS API 配置)
|
||||
python -m src.data.download_era5
|
||||
|
||||
# 收集死亡率与人口数据
|
||||
python -m src.data.collect_mortality
|
||||
|
||||
# 运行预处理管道
|
||||
python -m src.data.preprocess
|
||||
```
|
||||
|
||||
### 2. 探索性数据分析
|
||||
|
||||
```bash
|
||||
jupyter notebook notebooks/eda.ipynb
|
||||
```
|
||||
|
||||
### 3. 模型训练
|
||||
|
||||
```bash
|
||||
# 训练 LSTM-Attention 模型
|
||||
python -m src.models.train
|
||||
|
||||
# 模型评估与对比
|
||||
python -m src.models.evaluate
|
||||
```
|
||||
|
||||
### 4. 启动可视化大屏
|
||||
|
||||
```bash
|
||||
python -m src.web.app
|
||||
# 浏览器打开 http://localhost:5005
|
||||
```
|
||||
|
||||
### 5. 论文编译
|
||||
|
||||
```bash
|
||||
cd thesis
|
||||
make
|
||||
# 或手动: xelatex main && biber main && xelatex main && xelatex main
|
||||
```
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
project/
|
||||
├── data/
|
||||
│ ├── raw/era5/ # ERA5 原始 NetCDF 文件
|
||||
│ ├── processed/ # 预处理后 CSV 和 NPZ 序列
|
||||
│ └── external/ # 死亡率/人口/暴露反应数据
|
||||
├── src/
|
||||
│ ├── data/ # 数据获取与预处理
|
||||
│ │ ├── download_era5.py
|
||||
│ │ ├── collect_mortality.py
|
||||
│ │ └── preprocess.py
|
||||
│ ├── models/ # 模型定义与训练
|
||||
│ │ ├── lstm_attention.py
|
||||
│ │ ├── xgboost_baseline.py
|
||||
│ │ ├── train.py
|
||||
│ │ └── evaluate.py
|
||||
│ ├── web/ # Web 可视化
|
||||
│ │ ├── app.py
|
||||
│ │ └── static/index.html
|
||||
│ └── utils/
|
||||
│ └── config.py # 全局配置
|
||||
├── notebooks/
|
||||
│ └── eda.ipynb # 探索性数据分析
|
||||
├── outputs/
|
||||
│ ├── models/ # 训练好的模型权重
|
||||
│ ├── figures/ # 论文和评估图表
|
||||
│ └── logs/ # 训练日志
|
||||
├── thesis/ # LaTeX 学位论文
|
||||
│ ├── main.tex
|
||||
│ ├── chapters/ # 各章节 tex 文件
|
||||
│ ├── refs.bib # 参考文献
|
||||
│ └── Makefile
|
||||
└── docs/superpowers/ # 设计文档和计划
|
||||
```
|
||||
|
||||
## 模型架构
|
||||
|
||||
```
|
||||
输入 (14天气象序列)
|
||||
→ Linear 嵌入 (16 → 128)
|
||||
→ 2层 BiLSTM (128, dropout=0.3)
|
||||
→ Multi-Head Attention (4 heads)
|
||||
→ Linear 投影 (256 → 128)
|
||||
→ 三头输出
|
||||
├── 短期头 (128→64→4)
|
||||
├── 中期头 (128→64→4)
|
||||
└── 长期头 (128→64→4)
|
||||
```
|
||||
|
||||
总参数量:~983K
|
||||
|
||||
## 风险等级定义
|
||||
|
||||
| 等级 | 条件 | 颜色 |
|
||||
|------|------|------|
|
||||
| 低风险 | 体感温度 < 32°C | 绿 |
|
||||
| 中风险 | 体感温度 32-35°C | 黄 |
|
||||
| 高风险 | 体感温度 35-38°C 或连续 3 天 >35°C | 橙 |
|
||||
| 严重风险 | 体感温度 >= 38°C 且连续 3 天 >35°C | 红 |
|
||||
|
||||
## 数据来源
|
||||
|
||||
| 数据 | 来源 | 时间范围 |
|
||||
|------|------|----------|
|
||||
| 气象数据 | ERA5-Land (Copernicus CDS) | 2010-2024 |
|
||||
| 死亡率 | 中国卫生健康统计年鉴 | 2010-2023 |
|
||||
| 暴露反应曲线 | Chen et al. (2018) Lancet Planet Health | — |
|
||||
| 人口数据 | 第七次全国人口普查 (2020) | 2020 |
|
||||
| 老龄化率 | 河南省统计年鉴 | 2010-2023 |
|
||||
+51
-69
@@ -1,32 +1,19 @@
|
||||
"""从 Copernicus CDS 下载 ERA5-Land 再分析数据"""
|
||||
"""从 Copernicus CDS 下载 ERA5-Land 再分析数据(逐月,支持并行)"""
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import cdsapi
|
||||
|
||||
from src.utils.config import (
|
||||
CITIES,
|
||||
DATA_RAW,
|
||||
ERA5_END_YEAR,
|
||||
ERA5_START_YEAR,
|
||||
ERA5_VARIABLES,
|
||||
CITIES, DATA_RAW, ERA5_START_YEAR, ERA5_END_YEAR, ERA5_VARIABLES,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_request(city: str, year: int, month: int) -> dict:
|
||||
"""构建 CDS API 请求参数,提取城市周围 0.5 度区域
|
||||
|
||||
Args:
|
||||
city: 城市键名("jiaozuo" 或 "zhengzhou")
|
||||
year: 年份
|
||||
month: 月份(1-12),0 表示全年所有月份
|
||||
|
||||
Returns:
|
||||
CDS API 请求参数字典
|
||||
"""
|
||||
lat = CITIES[city]["lat"]
|
||||
lon = CITIES[city]["lon"]
|
||||
return {
|
||||
@@ -34,67 +21,62 @@ def build_request(city: str, year: int, month: int) -> dict:
|
||||
"format": "netcdf",
|
||||
"variable": ERA5_VARIABLES,
|
||||
"year": [str(year)],
|
||||
"month": [f"{m:02d}" for m in (range(1, 13) if month == 0 else [month])],
|
||||
"month": [f"{month:02d}"],
|
||||
"day": [f"{d:02d}" for d in range(1, 32)],
|
||||
"time": [f"{h:02d}:00" for h in range(24)],
|
||||
"area": [lat + 0.5, lon - 0.5, lat - 0.5, lon + 0.5], # [N, W, S, E]
|
||||
"time": [f"{h:02d}:00" for h in [0, 6, 12, 18]],
|
||||
"area": [lat + 0.5, lon - 0.5, lat - 0.5, lon + 0.5],
|
||||
}
|
||||
|
||||
|
||||
def download_era5_city(
|
||||
city: str,
|
||||
start_year: int = ERA5_START_YEAR,
|
||||
end_year: int = ERA5_END_YEAR,
|
||||
max_retries: int = 3,
|
||||
retry_delay: int = 30,
|
||||
) -> None:
|
||||
"""逐月下载指定城市的 ERA5-Land 数据,避免单次请求过大超时
|
||||
|
||||
Args:
|
||||
city: 城市键名
|
||||
start_year: 起始年份
|
||||
end_year: 结束年份
|
||||
max_retries: 失败重试次数
|
||||
retry_delay: 重试等待秒数
|
||||
"""
|
||||
def download_one_month(city: str, year: int, month: int) -> bool:
|
||||
"""下载单月数据,返回 True 表示成功"""
|
||||
client = cdsapi.Client()
|
||||
out_dir = Path(DATA_RAW) / "era5" / city
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_path = out_dir / f"era5_{city}_{year}_{month:02d}.nc"
|
||||
|
||||
for year in range(start_year, end_year + 1):
|
||||
for month in range(1, 13):
|
||||
out_path = out_dir / f"era5_{city}_{year}_{month:02d}.nc"
|
||||
if out_path.exists():
|
||||
logger.info("跳过已存在: %s", out_path)
|
||||
continue
|
||||
if out_path.exists():
|
||||
return True # 已存在,跳过
|
||||
|
||||
request = build_request(city, year, month)
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
logger.info(
|
||||
"正在下载 %s %d-%02d (第 %d/%d 次尝试)...",
|
||||
city, year, month, attempt, max_retries,
|
||||
)
|
||||
client.retrieve(
|
||||
"reanalysis-era5-land",
|
||||
request,
|
||||
str(out_path),
|
||||
)
|
||||
logger.info("下载完成: %s", out_path)
|
||||
break
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"下载失败 %s %d-%02d (第 %d/%d 次)",
|
||||
city, year, month, attempt, max_retries,
|
||||
)
|
||||
if attempt < max_retries:
|
||||
logger.info("等待 %d 秒后重试...", retry_delay)
|
||||
time.sleep(retry_delay)
|
||||
else:
|
||||
logger.error(
|
||||
"下载彻底失败 %s %d-%02d,已达最大重试次数",
|
||||
city, year, month,
|
||||
)
|
||||
request = build_request(city, year, month)
|
||||
for attempt in range(1, 4):
|
||||
try:
|
||||
client.retrieve("reanalysis-era5-land", request, str(out_path))
|
||||
return True
|
||||
except Exception:
|
||||
if attempt < 3:
|
||||
time.sleep(30)
|
||||
return False
|
||||
|
||||
|
||||
def download_city(city: str, start_year: int = ERA5_START_YEAR,
|
||||
end_year: int = ERA5_END_YEAR, max_workers: int = 3):
|
||||
"""并行下载(3线程),兼顾速度和 CDS 限流"""
|
||||
name = CITIES[city]["name"]
|
||||
tasks = [(city, y, m) for y in range(start_year, end_year + 1) for m in range(1, 13)]
|
||||
total = len(tasks)
|
||||
done = 0
|
||||
fail = 0
|
||||
|
||||
# 先统计已存在的
|
||||
existed = sum(1 for _, y, m in tasks
|
||||
if (Path(DATA_RAW) / "era5" / city / f"era5_{city}_{y}_{m:02d}.nc").exists())
|
||||
if existed > 0:
|
||||
logger.info("%s: %d/%d 已存在,跳过", name, existed, total)
|
||||
done = existed
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
futures = {pool.submit(download_one_month, c, y, m): (y, m)
|
||||
for c, y, m in tasks if not (Path(DATA_RAW) / "era5" / city
|
||||
/ f"era5_{c}_{y}_{m:02d}.nc").exists()}
|
||||
for f in as_completed(futures):
|
||||
y, m = futures[f]
|
||||
if f.result():
|
||||
done += 1
|
||||
else:
|
||||
fail += 1
|
||||
if (done + fail) % 10 == 0 or (done + fail) == (total - existed):
|
||||
logger.info("%s: %d/%d 完成 (%d 失败)", name, done + existed, total, fail)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -103,4 +85,4 @@ if __name__ == "__main__":
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
)
|
||||
for city_name in CITIES:
|
||||
download_era5_city(city_name)
|
||||
download_city(city_name)
|
||||
|
||||
@@ -0,0 +1,453 @@
|
||||
"""模型评估与对比 — LSTM vs XGBoost 多时间尺度高温风险预测
|
||||
|
||||
功能:
|
||||
1. 加载测试数据 (与训练相同的 70/15/15 时间序划分)
|
||||
2. 加载已训练的 LSTM 模型,获取/计算预测结果
|
||||
3. 训练 XGBoost 基线并评估
|
||||
4. 生成混淆矩阵对比图和指标对比柱状图
|
||||
5. 打印格式化的模型对比表
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
|
||||
|
||||
from src.models.lstm_attention import HeatRiskPredictor
|
||||
from src.models.xgboost_baseline import train_xgboost_baseline
|
||||
from src.utils.config import DATA_PROCESSED, OUTPUT_MODELS, OUTPUT_FIGURES
|
||||
|
||||
# ============================================================================
|
||||
# 全局常量
|
||||
# ============================================================================
|
||||
|
||||
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
|
||||
RISK_LABELS = ["低", "中", "高", "严重"]
|
||||
HORIZON_KEYS = ["short", "medium", "long"]
|
||||
HORIZON_CN = ["短期", "中期", "长期"]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 数据加载
|
||||
# ============================================================================
|
||||
|
||||
def _load_all_data() -> tuple[np.ndarray, np.ndarray] | tuple[None, None]:
|
||||
"""加载焦作和郑州的序列 NPZ 文件并沿样本轴拼接。
|
||||
|
||||
Returns:
|
||||
(X, y) 或 (None, None) 当数据文件不存在时。
|
||||
X: (N, T, D) float32 特征数组
|
||||
y: (N, 3) int64 标签数组,列顺序: short, medium, long
|
||||
"""
|
||||
cities = ["jiaozuo", "zhengzhou"]
|
||||
X_parts: list[np.ndarray] = []
|
||||
y_parts: list[np.ndarray] = []
|
||||
|
||||
for city in cities:
|
||||
path = DATA_PROCESSED / f"{city}_sequences.npz"
|
||||
if path.exists():
|
||||
data = np.load(path)
|
||||
X_parts.append(data["X"])
|
||||
y_parts.append(data["y"])
|
||||
print(f"已加载 {city}: X {data['X'].shape}, y {data['y'].shape}")
|
||||
else:
|
||||
print(f"警告: {path} 不存在,跳过该城市")
|
||||
|
||||
if not X_parts:
|
||||
print("错误: 未找到任何序列数据文件,无法进行评估。", file=sys.stderr)
|
||||
print(f"请确保 {DATA_PROCESSED / 'jiaozuo_sequences.npz'} "
|
||||
f"或 {DATA_PROCESSED / 'zhengzhou_sequences.npz'} 存在。", file=sys.stderr)
|
||||
return None, None
|
||||
|
||||
X = np.concatenate(X_parts, axis=0)
|
||||
y = np.concatenate(y_parts, axis=0)
|
||||
print(f"合并后数据: X {X.shape}, y {y.shape}")
|
||||
return X, y
|
||||
|
||||
|
||||
def load_test_data() -> tuple[np.ndarray, np.ndarray] | tuple[None, None]:
|
||||
"""加载测试集数据 (与训练相同的 70/15/15 时间序划分)。
|
||||
|
||||
Returns:
|
||||
(X_test, y_test) 或 (None, None) 当数据不可用时。
|
||||
X_test: (N_test, T, D) float32
|
||||
y_test: (N_test, 3) int64,列顺序: short, medium, long
|
||||
"""
|
||||
X, y = _load_all_data()
|
||||
if X is None:
|
||||
return None, None
|
||||
|
||||
n_total = len(X)
|
||||
n_train = int(n_total * 0.70)
|
||||
n_val = int(n_total * 0.15)
|
||||
|
||||
X_test = X[n_train + n_val:]
|
||||
y_test = y[n_train + n_val:]
|
||||
|
||||
print(f"测试集: X_test {X_test.shape}, y_test {y_test.shape}")
|
||||
return X_test, y_test
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# LSTM 评估
|
||||
# ============================================================================
|
||||
|
||||
def evaluate_lstm(
|
||||
X_test: np.ndarray,
|
||||
y_test: np.ndarray,
|
||||
) -> tuple[dict[str, np.ndarray], np.ndarray] | tuple[None, None]:
|
||||
"""加载已训练的 HeatRiskPredictor 并获取测试集预测。
|
||||
|
||||
优先从 test_predictions.npz 加载已保存的预测结果;
|
||||
若文件不存在,则加载模型权重并重新推理。
|
||||
|
||||
Args:
|
||||
X_test: (N, T, D) 测试集特征
|
||||
y_test: (N, 3) 测试集标签
|
||||
|
||||
Returns:
|
||||
(predictions_dict, labels_array) 或 (None, None) 当模型不可用时。
|
||||
predictions_dict: {"short": (N,) int, "medium": (N,) int, "long": (N,) int}
|
||||
labels_array: (N, 3) int64,与 y_test 相同
|
||||
"""
|
||||
pred_path = OUTPUT_MODELS / "test_predictions.npz"
|
||||
model_path = OUTPUT_MODELS / "best_model.pt"
|
||||
|
||||
# ---- 路径 1: 从已保存的预测文件加载 ----
|
||||
if pred_path.exists():
|
||||
print(f"从 {pred_path} 加载 LSTM 预测结果...")
|
||||
data = np.load(pred_path)
|
||||
|
||||
# 兼容 train.py 的 "y_true" 和任务约定的 "labels" 两种键名
|
||||
labels_key = "labels" if "labels" in data else "y_true"
|
||||
|
||||
predictions = {
|
||||
"short": data["short"],
|
||||
"medium": data["medium"],
|
||||
"long": data["long"],
|
||||
}
|
||||
labels = data[labels_key]
|
||||
|
||||
# 验证样本数匹配
|
||||
if len(predictions["short"]) != len(y_test):
|
||||
print(f"警告: 已保存预测样本数 ({len(predictions['short'])}) "
|
||||
f"与测试集样本数 ({len(y_test)}) 不一致,将重新推理。")
|
||||
else:
|
||||
print(f"已加载 LSTM 预测: {len(predictions['short'])} 样本")
|
||||
return predictions, labels
|
||||
|
||||
# ---- 路径 2: 重新推理 ----
|
||||
if not model_path.exists():
|
||||
print(f"错误: 模型文件 {model_path} 不存在,无法评估 LSTM。", file=sys.stderr)
|
||||
return None, None
|
||||
|
||||
print(f"从 {model_path} 加载模型权重...")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"使用设备: {device}")
|
||||
|
||||
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
|
||||
input_dim = checkpoint.get("input_dim", X_test.shape[2])
|
||||
|
||||
model = HeatRiskPredictor(input_dim=input_dim).to(device)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
model.eval()
|
||||
|
||||
# 批量推理
|
||||
X_tensor = torch.tensor(X_test, dtype=torch.float32).to(device)
|
||||
# 为避免 OOM,按 batch 处理
|
||||
batch_size = 128
|
||||
all_preds: dict[str, list[int]] = {"short": [], "medium": [], "long": []}
|
||||
|
||||
with torch.no_grad():
|
||||
for i in range(0, len(X_tensor), batch_size):
|
||||
batch = X_tensor[i : i + batch_size]
|
||||
outputs = model(batch)
|
||||
for key in HORIZON_KEYS:
|
||||
preds = outputs[key].argmax(dim=1).cpu().numpy()
|
||||
all_preds[key].extend(preds.tolist())
|
||||
|
||||
predictions = {k: np.array(v, dtype=np.int64) for k, v in all_preds.items()}
|
||||
print(f"LSTM 推理完成: {len(predictions['short'])} 样本")
|
||||
return predictions, y_test
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 辅助: 计算指标
|
||||
# ============================================================================
|
||||
|
||||
def _compute_metrics(
|
||||
predictions: dict[str, np.ndarray],
|
||||
y_true: np.ndarray,
|
||||
) -> dict[str, dict[str, float]]:
|
||||
"""从预测和真实标签计算各时间尺度的 accuracy 和 macro F1。
|
||||
|
||||
Args:
|
||||
predictions: {"short": (N,), "medium": (N,), "long": (N,)} 预测标签
|
||||
y_true: (N, 3) 真实标签,列顺序: short, medium, long
|
||||
|
||||
Returns:
|
||||
{"short": {"accuracy": ..., "f1_macro": ...}, ...}
|
||||
"""
|
||||
metrics: dict[str, dict[str, float]] = {}
|
||||
for i, key in enumerate(HORIZON_KEYS):
|
||||
y_pred = predictions[key]
|
||||
y_t = y_true[:, i]
|
||||
metrics[key] = {
|
||||
"accuracy": float(accuracy_score(y_t, y_pred)),
|
||||
"f1_macro": float(f1_score(y_t, y_pred, average="macro")),
|
||||
}
|
||||
return metrics
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 绘图: 混淆矩阵对比
|
||||
# ============================================================================
|
||||
|
||||
def plot_confusion_matrices(
|
||||
lstm_preds: dict[str, np.ndarray],
|
||||
xgb_preds: dict[str, np.ndarray],
|
||||
y_true: np.ndarray,
|
||||
save_path: str | Path | None = None,
|
||||
) -> None:
|
||||
"""绘制 2x3 混淆矩阵对比图 (LSTM 行 / XGBoost 行; 短/中/长期 列)。
|
||||
|
||||
Args:
|
||||
lstm_preds: LSTM 预测 {"short", "medium", "long"}
|
||||
xgb_preds: XGBoost 预测 {"short", "medium", "long"}
|
||||
y_true: (N, 3) 真实标签
|
||||
save_path: 保存路径,默认 OUTPUT_FIGURES / "confusion_matrix_comparison.png"
|
||||
"""
|
||||
if save_path is None:
|
||||
save_path = OUTPUT_FIGURES / "confusion_matrix_comparison.png"
|
||||
OUTPUT_FIGURES.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
fig, axes = plt.subplots(2, 3, figsize=(16, 10))
|
||||
model_names = ["LSTM", "XGBoost"]
|
||||
preds_list = [lstm_preds, xgb_preds]
|
||||
|
||||
for row_idx, (model_name, preds) in enumerate(zip(model_names, preds_list)):
|
||||
for col_idx, key in enumerate(HORIZON_KEYS):
|
||||
ax = axes[row_idx, col_idx]
|
||||
cm = confusion_matrix(
|
||||
y_true[:, col_idx], preds[key],
|
||||
labels=[0, 1, 2, 3],
|
||||
)
|
||||
im = ax.imshow(cm, cmap="Blues", interpolation="nearest")
|
||||
|
||||
# 在每个单元格内标注数值
|
||||
for i in range(4):
|
||||
for j in range(4):
|
||||
ax.text(
|
||||
j, i, str(cm[i, j]),
|
||||
ha="center", va="center",
|
||||
fontsize=9,
|
||||
color="white" if cm[i, j] > cm.max() / 2 else "black",
|
||||
)
|
||||
|
||||
ax.set_xticks(range(4))
|
||||
ax.set_yticks(range(4))
|
||||
ax.set_xticklabels(RISK_LABELS)
|
||||
ax.set_yticklabels(RISK_LABELS)
|
||||
|
||||
title = f"{model_name} - {HORIZON_CN[col_idx]}"
|
||||
ax.set_title(title, fontsize=13, fontweight="bold")
|
||||
ax.set_xlabel("预测标签")
|
||||
ax.set_ylabel("真实标签")
|
||||
|
||||
# 共享 colorbar
|
||||
fig.colorbar(
|
||||
axes[0, 2].get_images()[0],
|
||||
ax=axes[:, -1],
|
||||
shrink=0.8,
|
||||
label="样本数",
|
||||
)
|
||||
|
||||
fig.suptitle("混淆矩阵对比: LSTM vs XGBoost", fontsize=15, fontweight="bold", y=1.01)
|
||||
plt.tight_layout()
|
||||
fig.savefig(save_path, dpi=150, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
print(f"混淆矩阵对比图已保存至 {save_path}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 绘图: 指标对比柱状图
|
||||
# ============================================================================
|
||||
|
||||
def plot_metrics_comparison(
|
||||
lstm_metrics: dict[str, dict[str, float]],
|
||||
xgb_metrics: dict[str, dict[str, float]],
|
||||
save_path: str | Path | None = None,
|
||||
) -> None:
|
||||
"""绘制 1x3 指标对比柱状图 (每个时间尺度两柱: accuracy 和 F1)。
|
||||
|
||||
Args:
|
||||
lstm_metrics: LSTM 指标 {"short": {"accuracy", "f1_macro"}, ...}
|
||||
xgb_metrics: XGBoost 指标 {"short": {"accuracy", "f1_macro"}, ...}
|
||||
save_path: 保存路径,默认 OUTPUT_FIGURES / "model_comparison.png"
|
||||
"""
|
||||
if save_path is None:
|
||||
save_path = OUTPUT_FIGURES / "model_comparison.png"
|
||||
OUTPUT_FIGURES.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
||||
metric_keys = ["accuracy", "f1_macro"]
|
||||
metric_cn = ["Accuracy", "F1 Macro"]
|
||||
bar_width = 0.35
|
||||
colors = ["#4C72B0", "#DD8452"] # LSTM, XGBoost
|
||||
|
||||
for col_idx, horizon_key in enumerate(HORIZON_KEYS):
|
||||
ax = axes[col_idx]
|
||||
|
||||
lstm_vals = [lstm_metrics[horizon_key][m] for m in metric_keys]
|
||||
xgb_vals = [xgb_metrics[horizon_key][m] for m in metric_keys]
|
||||
|
||||
x_pos = np.arange(len(metric_keys))
|
||||
bars1 = ax.bar(
|
||||
x_pos - bar_width / 2, lstm_vals, bar_width,
|
||||
label="LSTM", color=colors[0], edgecolor="white",
|
||||
)
|
||||
bars2 = ax.bar(
|
||||
x_pos + bar_width / 2, xgb_vals, bar_width,
|
||||
label="XGBoost", color=colors[1], edgecolor="white",
|
||||
)
|
||||
|
||||
# 在柱顶标注数值
|
||||
for bar in bars1:
|
||||
height = bar.get_height()
|
||||
ax.text(
|
||||
bar.get_x() + bar.get_width() / 2, height + 0.01,
|
||||
f"{height:.3f}", ha="center", va="bottom", fontsize=9,
|
||||
)
|
||||
for bar in bars2:
|
||||
height = bar.get_height()
|
||||
ax.text(
|
||||
bar.get_x() + bar.get_width() / 2, height + 0.01,
|
||||
f"{height:.3f}", ha="center", va="bottom", fontsize=9,
|
||||
)
|
||||
|
||||
ax.set_xticks(x_pos)
|
||||
ax.set_xticklabels(metric_cn)
|
||||
ax.set_ylim(0, 1.15)
|
||||
ax.set_title(HORIZON_CN[col_idx], fontsize=13, fontweight="bold")
|
||||
ax.set_ylabel("分数")
|
||||
ax.legend(loc="lower right")
|
||||
ax.grid(axis="y", alpha=0.3)
|
||||
|
||||
fig.suptitle("模型指标对比: LSTM vs XGBoost", fontsize=15, fontweight="bold")
|
||||
plt.tight_layout()
|
||||
fig.savefig(save_path, dpi=150, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
print(f"模型对比图已保存至 {save_path}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 主评估函数
|
||||
# ============================================================================
|
||||
|
||||
def evaluate() -> None:
|
||||
"""执行完整的模型评估与对比流水线。
|
||||
|
||||
流程:
|
||||
1. 加载数据并按 70/15/15 划分
|
||||
2. LSTM 评估 (加载最佳模型 + 预测)
|
||||
3. XGBoost 基线训练与评估
|
||||
4. 打印对比表
|
||||
5. 生成对比图 (混淆矩阵 + 指标柱状图)
|
||||
"""
|
||||
print("=" * 60)
|
||||
print("模型评估与对比")
|
||||
print("=" * 60)
|
||||
|
||||
# ---- 1. 加载数据 ----
|
||||
X, y = _load_all_data()
|
||||
if X is None:
|
||||
return
|
||||
|
||||
n_total = len(X)
|
||||
n_train = int(n_total * 0.70)
|
||||
n_val = int(n_total * 0.15)
|
||||
|
||||
X_train = X[:n_train]
|
||||
y_train = y[:n_train]
|
||||
X_test = X[n_train + n_val:]
|
||||
y_test = y[n_train + n_val:]
|
||||
|
||||
print(f"数据划分: 训练 {len(X_train)}, 验证 {n_val} (保留), 测试 {len(X_test)}")
|
||||
|
||||
# ---- 2. LSTM 评估 ----
|
||||
print("\n--- LSTM 评估 ---")
|
||||
lstm_result = evaluate_lstm(X_test, y_test)
|
||||
if lstm_result[0] is None:
|
||||
lstm_preds = None
|
||||
lstm_metrics = None
|
||||
else:
|
||||
lstm_preds, _ = lstm_result
|
||||
lstm_metrics = _compute_metrics(lstm_preds, y_test)
|
||||
for key in HORIZON_KEYS:
|
||||
m = lstm_metrics[key]
|
||||
print(f" LSTM {HORIZON_CN[HORIZON_KEYS.index(key)]}: "
|
||||
f"Accuracy={m['accuracy']:.4f}, F1 Macro={m['f1_macro']:.4f}")
|
||||
|
||||
# ---- 3. XGBoost 基线 ----
|
||||
print("\n--- XGBoost 基线 ---")
|
||||
try:
|
||||
xgb_results = train_xgboost_baseline(X_train, y_train, X_test, y_test)
|
||||
xgb_preds = {
|
||||
key: xgb_results[key]["predictions"] for key in HORIZON_KEYS
|
||||
}
|
||||
xgb_metrics = {
|
||||
key: {
|
||||
"accuracy": xgb_results[key]["accuracy"],
|
||||
"f1_macro": xgb_results[key]["f1_macro"],
|
||||
}
|
||||
for key in HORIZON_KEYS
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"警告: XGBoost 训练失败: {e}", file=sys.stderr)
|
||||
xgb_preds = None
|
||||
xgb_metrics = None
|
||||
|
||||
# ---- 4. 打印对比表 ----
|
||||
if lstm_metrics is not None and xgb_metrics is not None:
|
||||
print("\n" + "=" * 56)
|
||||
print("=== 模型对比 ===")
|
||||
print(f"{'时间尺度':<10} {'指标':<12} {'LSTM':<12} {'XGBoost':<12}")
|
||||
print("-" * 56)
|
||||
for key, cn_name in zip(HORIZON_KEYS, HORIZON_CN):
|
||||
for metric_key, metric_cn_name in [
|
||||
("accuracy", "accuracy"),
|
||||
("f1_macro", "f1_macro"),
|
||||
]:
|
||||
lstm_val = lstm_metrics[key][metric_key]
|
||||
xgb_val = xgb_metrics[key][metric_key]
|
||||
print(f"{cn_name:<10} {metric_cn_name:<12} "
|
||||
f"{lstm_val:<12.4f} {xgb_val:<12.4f}")
|
||||
print("=" * 56)
|
||||
|
||||
# ---- 5. 生成对比图 ----
|
||||
if lstm_preds is not None and xgb_preds is not None:
|
||||
print("\n生成对比图表...")
|
||||
plot_confusion_matrices(lstm_preds, xgb_preds, y_test)
|
||||
plot_metrics_comparison(lstm_metrics, xgb_metrics)
|
||||
else:
|
||||
if lstm_preds is None:
|
||||
print("跳过图表: LSTM 预测不可用", file=sys.stderr)
|
||||
if xgb_preds is None:
|
||||
print("跳过图表: XGBoost 预测不可用", file=sys.stderr)
|
||||
|
||||
print("\n评估完成。")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CLI 入口
|
||||
# ============================================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
evaluate()
|
||||
@@ -0,0 +1,365 @@
|
||||
"""LSTM-Attention 模型训练脚本
|
||||
|
||||
完整的训练流水线:数据加载、训练、验证、早停、测试评估。
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from sklearn.metrics import accuracy_score, f1_score
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
|
||||
from src.models.lstm_attention import HeatRiskPredictor
|
||||
from src.utils.config import (
|
||||
BATCH_SIZE,
|
||||
DATA_PROCESSED,
|
||||
EARLY_STOP_PATIENCE,
|
||||
LEARNING_RATE,
|
||||
MAX_EPOCHS,
|
||||
OUTPUT_LOGS,
|
||||
OUTPUT_MODELS,
|
||||
)
|
||||
|
||||
|
||||
class FocalLoss(nn.Module):
|
||||
"""Focal Loss — 聚焦困难样本,缓解类别不平衡"""
|
||||
|
||||
def __init__(self, alpha: float = 0.25, gamma: float = 2.0):
|
||||
super().__init__()
|
||||
self.alpha = alpha
|
||||
self.gamma = gamma
|
||||
|
||||
def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
|
||||
ce = F.cross_entropy(logits, targets, reduction="none")
|
||||
pt = torch.exp(-ce)
|
||||
focal = self.alpha * (1 - pt) ** self.gamma * ce
|
||||
return focal.mean()
|
||||
|
||||
|
||||
def load_data() -> tuple[np.ndarray, np.ndarray]:
|
||||
"""加载焦作和郑州的序列数据,拼接后返回。
|
||||
|
||||
Returns:
|
||||
X: (N, 14, F) 特征数组
|
||||
y: (N, 3) 标签数组,列顺序: short, medium, long (0-3 等级)
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: 两个城市的 npz 文件均不存在时
|
||||
"""
|
||||
cities = ["jiaozuo", "zhengzhou"]
|
||||
X_parts, y_parts = [], []
|
||||
|
||||
for city in cities:
|
||||
path = DATA_PROCESSED / f"{city}_sequences.npz"
|
||||
if path.exists():
|
||||
data = np.load(path)
|
||||
X_parts.append(data["X"])
|
||||
y_parts.append(data["y"])
|
||||
print(f"已加载 {city}: X {data['X'].shape}, y {data['y'].shape}")
|
||||
else:
|
||||
print(f"警告: {path} 不存在,跳过该城市")
|
||||
|
||||
if not X_parts:
|
||||
msg = (
|
||||
"未找到任何序列数据文件。\n"
|
||||
f"请确保 {DATA_PROCESSED / 'jiaozuo_sequences.npz'} 或 "
|
||||
f"{DATA_PROCESSED / 'zhengzhou_sequences.npz'} 存在。\n"
|
||||
"运行数据预处理流水线以生成这些文件。"
|
||||
)
|
||||
print(msg, file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
X = np.concatenate(X_parts, axis=0)
|
||||
y = np.concatenate(y_parts, axis=0)
|
||||
print(f"合并后数据: X {X.shape}, y {y.shape}")
|
||||
return X, y
|
||||
|
||||
|
||||
def _compute_metrics(
|
||||
outputs: dict[str, torch.Tensor], labels: torch.Tensor
|
||||
) -> dict[str, dict[str, float]]:
|
||||
"""计算每个预测头的 loss、accuracy 和 macro F1。
|
||||
|
||||
Args:
|
||||
outputs: 模型前向输出,{"short": (B,4), "medium": (B,4), "long": (B,4)}
|
||||
labels: (B, 3) 标签,列顺序: short, medium, long
|
||||
|
||||
Returns:
|
||||
{"short": {"loss": ..., "acc": ..., "f1": ...}, ...}
|
||||
"""
|
||||
loss_fn = FocalLoss()
|
||||
horizons = ["short", "medium", "long"]
|
||||
metrics = {}
|
||||
|
||||
for i, name in enumerate(horizons):
|
||||
logits = outputs[name]
|
||||
targets = labels[:, i]
|
||||
loss_val = loss_fn(logits, targets).item()
|
||||
preds = logits.argmax(dim=1).cpu().numpy()
|
||||
trues = targets.cpu().numpy()
|
||||
acc = accuracy_score(trues, preds)
|
||||
f1 = f1_score(trues, preds, average="macro")
|
||||
metrics[name] = {"loss": loss_val, "acc": acc, "f1": f1}
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def train() -> HeatRiskPredictor:
|
||||
"""执行完整的训练流水线。
|
||||
|
||||
Returns:
|
||||
训练好的 HeatRiskPredictor 模型 (best checkpoint)
|
||||
"""
|
||||
# -------------------- 设备 --------------------
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"使用设备: {device}")
|
||||
|
||||
# -------------------- 加载数据 --------------------
|
||||
X, y = load_data()
|
||||
print(f"数据加载完成: X {X.shape}, y {y.shape}")
|
||||
|
||||
# -------------------- 时间序划分 (不打乱) --------------------
|
||||
n_total = len(X)
|
||||
n_train = int(n_total * 0.70)
|
||||
n_val = int(n_total * 0.15)
|
||||
|
||||
X_train_np = X[:n_train]
|
||||
y_train_np = y[:n_train]
|
||||
X_val_np = X[n_train : n_train + n_val]
|
||||
y_val_np = y[n_train : n_train + n_val]
|
||||
X_test_np = X[n_train + n_val :]
|
||||
y_test_np = y[n_train + n_val :]
|
||||
|
||||
print(f"划分: 训练 {len(X_train_np)}, 验证 {len(X_val_np)}, 测试 {len(X_test_np)}")
|
||||
|
||||
# -------------------- Tensor 转换 --------------------
|
||||
X_train_t = torch.tensor(X_train_np, dtype=torch.float32)
|
||||
y_train_t = torch.tensor(y_train_np, dtype=torch.long)
|
||||
X_val_t = torch.tensor(X_val_np, dtype=torch.float32)
|
||||
y_val_t = torch.tensor(y_val_np, dtype=torch.long)
|
||||
X_test_t = torch.tensor(X_test_np, dtype=torch.float32)
|
||||
y_test_t = torch.tensor(y_test_np, dtype=torch.long)
|
||||
|
||||
# -------------------- DataLoader --------------------
|
||||
train_loader = DataLoader(
|
||||
TensorDataset(X_train_t, y_train_t),
|
||||
batch_size=BATCH_SIZE,
|
||||
shuffle=True,
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
TensorDataset(X_val_t, y_val_t),
|
||||
batch_size=BATCH_SIZE,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
# -------------------- 模型 --------------------
|
||||
input_dim = X.shape[2]
|
||||
model = HeatRiskPredictor(input_dim=input_dim).to(device)
|
||||
print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")
|
||||
|
||||
# -------------------- 损失、优化器、调度器 --------------------
|
||||
focal_loss = FocalLoss()
|
||||
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
|
||||
scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=5)
|
||||
|
||||
# -------------------- 训练状态 --------------------
|
||||
best_val_loss = float("inf")
|
||||
best_epoch = 0
|
||||
patience_counter = 0
|
||||
history: dict[str, list] = {
|
||||
"train_loss": [],
|
||||
"val_loss": [],
|
||||
"val_acc_short": [],
|
||||
"val_acc_medium": [],
|
||||
"val_acc_long": [],
|
||||
}
|
||||
|
||||
OUTPUT_MODELS.mkdir(parents=True, exist_ok=True)
|
||||
best_model_path = OUTPUT_MODELS / "best_model.pt"
|
||||
|
||||
# -------------------- 训练循环 --------------------
|
||||
for epoch in range(1, MAX_EPOCHS + 1):
|
||||
# ---- 训练阶段 ----
|
||||
model.train()
|
||||
train_losses: list[float] = []
|
||||
|
||||
for batch_X, batch_y in train_loader:
|
||||
batch_X = batch_X.to(device)
|
||||
batch_y = batch_y.to(device)
|
||||
|
||||
outputs = model(batch_X)
|
||||
loss_short = focal_loss(outputs["short"], batch_y[:, 0])
|
||||
loss_medium = focal_loss(outputs["medium"], batch_y[:, 1])
|
||||
loss_long = focal_loss(outputs["long"], batch_y[:, 2])
|
||||
loss = (loss_short + loss_medium + loss_long) / 3.0
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||
optimizer.step()
|
||||
|
||||
train_losses.append(loss.item())
|
||||
|
||||
avg_train_loss = float(np.mean(train_losses))
|
||||
|
||||
# ---- 验证阶段 ----
|
||||
model.eval()
|
||||
val_outputs_all: list[dict] = []
|
||||
val_labels_all: list[torch.Tensor] = []
|
||||
|
||||
with torch.no_grad():
|
||||
for batch_X, batch_y in val_loader:
|
||||
batch_X = batch_X.to(device)
|
||||
batch_y = batch_y.to(device)
|
||||
outputs = model(batch_X)
|
||||
val_outputs_all.append(
|
||||
{k: v.cpu() for k, v in outputs.items()}
|
||||
)
|
||||
val_labels_all.append(batch_y.cpu())
|
||||
|
||||
# 合并所有 batch
|
||||
merged_outputs = {
|
||||
name: torch.cat([o[name] for o in val_outputs_all], dim=0)
|
||||
for name in ["short", "medium", "long"]
|
||||
}
|
||||
merged_labels = torch.cat(val_labels_all, dim=0)
|
||||
val_metrics = _compute_metrics(merged_outputs, merged_labels)
|
||||
|
||||
avg_val_loss = np.mean(
|
||||
[val_metrics[h]["loss"] for h in ["short", "medium", "long"]]
|
||||
)
|
||||
|
||||
# ---- 记录历史 ----
|
||||
history["train_loss"].append(avg_train_loss)
|
||||
history["val_loss"].append(avg_val_loss)
|
||||
history["val_acc_short"].append(val_metrics["short"]["acc"])
|
||||
history["val_acc_medium"].append(val_metrics["medium"]["acc"])
|
||||
history["val_acc_long"].append(val_metrics["long"]["acc"])
|
||||
|
||||
# ---- 学习率调度 ----
|
||||
scheduler.step(avg_val_loss)
|
||||
|
||||
# ---- 打印进度 ----
|
||||
if epoch % 10 == 0 or epoch == 1:
|
||||
lr_now = optimizer.param_groups[0]["lr"]
|
||||
print(
|
||||
f"Epoch {epoch:3d}/{MAX_EPOCHS} | "
|
||||
f"Train Loss: {avg_train_loss:.4f} | "
|
||||
f"Val Loss: {avg_val_loss:.4f} | "
|
||||
f"Acc S/M/L: "
|
||||
f"{val_metrics['short']['acc']:.3f}/"
|
||||
f"{val_metrics['medium']['acc']:.3f}/"
|
||||
f"{val_metrics['long']['acc']:.3f} | "
|
||||
f"LR: {lr_now:.2e}"
|
||||
)
|
||||
|
||||
# ---- 保存最佳模型 ----
|
||||
if avg_val_loss < best_val_loss:
|
||||
best_val_loss = avg_val_loss
|
||||
best_epoch = epoch
|
||||
patience_counter = 0
|
||||
torch.save(
|
||||
{
|
||||
"epoch": epoch,
|
||||
"model_state_dict": model.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"val_loss": avg_val_loss,
|
||||
"val_metrics": val_metrics,
|
||||
"input_dim": input_dim,
|
||||
},
|
||||
best_model_path,
|
||||
)
|
||||
else:
|
||||
patience_counter += 1
|
||||
|
||||
# ---- 早停 ----
|
||||
if patience_counter >= EARLY_STOP_PATIENCE:
|
||||
print(
|
||||
f"早停触发: {EARLY_STOP_PATIENCE} 轮未改善 "
|
||||
f"(最佳 epoch: {best_epoch}, Val Loss: {best_val_loss:.4f})"
|
||||
)
|
||||
break
|
||||
|
||||
# -------------------- 保存训练历史 --------------------
|
||||
OUTPUT_LOGS.mkdir(parents=True, exist_ok=True)
|
||||
history_path = OUTPUT_LOGS / "training_history.json"
|
||||
with open(history_path, "w", encoding="utf-8") as f:
|
||||
json.dump(
|
||||
{
|
||||
"best_epoch": best_epoch,
|
||||
"best_val_loss": best_val_loss,
|
||||
**history,
|
||||
},
|
||||
f,
|
||||
indent=2,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
print(f"训练历史已保存至 {history_path}")
|
||||
|
||||
# -------------------- 加载最佳模型,测试评估 --------------------
|
||||
checkpoint = torch.load(best_model_path, map_location=device)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
model.eval()
|
||||
|
||||
print("\n========== 测试集评估 ==========")
|
||||
test_dataset = TensorDataset(X_test_t, y_test_t)
|
||||
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
|
||||
|
||||
test_outputs_all: list[dict] = []
|
||||
test_labels_all: list[torch.Tensor] = []
|
||||
|
||||
with torch.no_grad():
|
||||
for batch_X, batch_y in test_loader:
|
||||
batch_X = batch_X.to(device)
|
||||
batch_y = batch_y.to(device)
|
||||
outputs = model(batch_X)
|
||||
test_outputs_all.append(
|
||||
{k: v.cpu() for k, v in outputs.items()}
|
||||
)
|
||||
test_labels_all.append(batch_y.cpu())
|
||||
|
||||
merged_test_outputs = {
|
||||
name: torch.cat([o[name] for o in test_outputs_all], dim=0)
|
||||
for name in ["short", "medium", "long"]
|
||||
}
|
||||
merged_test_labels = torch.cat(test_labels_all, dim=0)
|
||||
test_metrics = _compute_metrics(merged_test_outputs, merged_test_labels)
|
||||
|
||||
# 打印结果
|
||||
for name in ["short", "medium", "long"]:
|
||||
m = test_metrics[name]
|
||||
print(
|
||||
f" {name:>6}: Accuracy={m['acc']:.4f}, "
|
||||
f"F1 Macro={m['f1']:.4f}, Loss={m['loss']:.4f}"
|
||||
)
|
||||
|
||||
avg_acc = np.mean([test_metrics[h]["acc"] for h in ["short", "medium", "long"]])
|
||||
avg_f1 = np.mean([test_metrics[h]["f1"] for h in ["short", "medium", "long"]])
|
||||
print(f"\n平均 Accuracy: {avg_acc:.4f}")
|
||||
print(f"平均 F1 Macro: {avg_f1:.4f}")
|
||||
|
||||
# -------------------- 保存测试预测 --------------------
|
||||
test_predictions: dict[str, np.ndarray] = {}
|
||||
for name in ["short", "medium", "long"]:
|
||||
test_predictions[name] = merged_test_outputs[name].argmax(dim=1).numpy()
|
||||
test_predictions["y_true"] = merged_test_labels.numpy()
|
||||
test_predictions["logits_short"] = merged_test_outputs["short"].numpy()
|
||||
test_predictions["logits_medium"] = merged_test_outputs["medium"].numpy()
|
||||
test_predictions["logits_long"] = merged_test_outputs["long"].numpy()
|
||||
|
||||
pred_path = OUTPUT_MODELS / "test_predictions.npz"
|
||||
np.savez_compressed(pred_path, **test_predictions)
|
||||
print(f"测试预测已保存至 {pred_path}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
+177
@@ -0,0 +1,177 @@
|
||||
"""高温预警可视化大屏 Flask API 后端"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from flask import Flask, jsonify, send_from_directory
|
||||
from pathlib import Path
|
||||
|
||||
from src.utils.config import OUTPUT_MODELS, DATA_PROCESSED, CITIES
|
||||
from src.models.lstm_attention import HeatRiskPredictor
|
||||
|
||||
app = Flask(__name__, static_folder="static", static_url_path="")
|
||||
|
||||
# --- 全局状态 ---
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = None
|
||||
|
||||
RISK_LABELS = ["低风险", "中风险", "高风险", "严重风险"]
|
||||
RISK_COLORS = ["#00e676", "#ffeb3b", "#ff9800", "#f44336"]
|
||||
SUGGESTIONS = {
|
||||
0: ["天气状况良好,无需特殊防护"],
|
||||
1: ["注意防暑降温", "保持室内通风", "老年人减少午后外出"],
|
||||
2: ["建议开放社区避暑中心", "增加独居老人电话探访频次", "社区志愿者关注高龄老人"],
|
||||
3: ["启动高温应急预案", "社区避暑中心24小时开放", "逐一入户探访独居老人",
|
||||
"医疗机构做好热射病救治准备", "通过社区广播发布高温警报"],
|
||||
}
|
||||
|
||||
AGING_RATES = {"jiaozuo": 12.8, "zhengzhou": 11.6}
|
||||
ELDERLY_POP = {"jiaozuo": 454000, "zhengzhou": 1462000}
|
||||
|
||||
|
||||
def load_model():
|
||||
"""延迟加载模型(首次请求时加载)"""
|
||||
global model
|
||||
if model is not None:
|
||||
return
|
||||
try:
|
||||
# 尝试从序列数据获取输入维度
|
||||
data_path = DATA_PROCESSED / "jiaozuo_sequences.npz"
|
||||
if not data_path.exists():
|
||||
data_path = DATA_PROCESSED / "zhengzhou_sequences.npz"
|
||||
if data_path.exists():
|
||||
data = np.load(data_path, allow_pickle=True)
|
||||
input_dim = data["X"].shape[2]
|
||||
else:
|
||||
input_dim = 16 # 默认特征数
|
||||
model = HeatRiskPredictor(input_dim=input_dim).to(device)
|
||||
model_path = OUTPUT_MODELS / "best_model.pt"
|
||||
if model_path.exists():
|
||||
model.load_state_dict(torch.load(model_path, map_location=device))
|
||||
model.eval()
|
||||
print(f"模型已加载,设备: {device}")
|
||||
except Exception as e:
|
||||
print(f"模型加载失败: {e}")
|
||||
model = None
|
||||
|
||||
|
||||
def get_recent_features() -> np.ndarray:
|
||||
"""获取最近14天特征用于预测"""
|
||||
try:
|
||||
for city in ["jiaozuo", "zhengzhou"]:
|
||||
data_path = DATA_PROCESSED / f"{city}_sequences.npz"
|
||||
if data_path.exists():
|
||||
data = np.load(data_path)
|
||||
return data["X"][-1:] # 最后14天
|
||||
except Exception:
|
||||
pass
|
||||
# 返回随机特征作为fallback
|
||||
return np.random.randn(1, 14, 16).astype(np.float32)
|
||||
|
||||
|
||||
@app.route("/")
|
||||
def index():
|
||||
return send_from_directory("static", "index.html")
|
||||
|
||||
|
||||
@app.route("/api/predict")
|
||||
def predict():
|
||||
"""返回多时间尺度预测结果"""
|
||||
load_model()
|
||||
|
||||
X = get_recent_features()
|
||||
X_tensor = torch.FloatTensor(X).to(device)
|
||||
|
||||
predictions = {}
|
||||
if model is not None:
|
||||
with torch.no_grad():
|
||||
outputs = model(X_tensor)
|
||||
for i, key in enumerate(["short", "medium", "long"]):
|
||||
probs = torch.softmax(outputs[key], dim=-1)[0].cpu().numpy()
|
||||
level = int(probs.argmax())
|
||||
predictions[key] = {
|
||||
"level": level,
|
||||
"label": RISK_LABELS[level],
|
||||
"color": RISK_COLORS[level],
|
||||
"confidence": round(float(probs[level]), 4),
|
||||
"probabilities": [round(float(p), 4) for p in probs],
|
||||
"suggestions": SUGGESTIONS[level],
|
||||
}
|
||||
else:
|
||||
# fallback
|
||||
for key in ["short", "medium", "long"]:
|
||||
predictions[key] = {
|
||||
"level": 1, "label": "中风险", "color": "#ffeb3b",
|
||||
"confidence": 0.5, "probabilities": [0.1, 0.5, 0.3, 0.1],
|
||||
"suggestions": SUGGESTIONS[1],
|
||||
}
|
||||
|
||||
return jsonify({
|
||||
"city": "焦作",
|
||||
"date": pd.Timestamp.now().strftime("%Y-%m-%d"),
|
||||
"predictions": predictions,
|
||||
"risk_population": ELDERLY_POP["jiaozuo"],
|
||||
})
|
||||
|
||||
|
||||
@app.route("/api/history")
|
||||
def history():
|
||||
"""返回最近90天历史数据"""
|
||||
try:
|
||||
dfs = []
|
||||
for city in CITIES:
|
||||
csv_path = DATA_PROCESSED / f"{city}_processed.csv"
|
||||
if csv_path.exists():
|
||||
df = pd.read_csv(csv_path, parse_dates=["time"])
|
||||
dfs.append(df)
|
||||
if dfs:
|
||||
df = pd.concat(dfs, ignore_index=True).sort_values("time")
|
||||
recent = df.tail(90)
|
||||
return jsonify({
|
||||
"dates": recent["time"].dt.strftime("%Y-%m-%d").tolist(),
|
||||
"temp_mean": recent["temp_mean"].round(1).tolist() if "temp_mean" in recent.columns else [],
|
||||
"heat_index": recent["heat_index"].round(1).tolist() if "heat_index" in recent.columns else [],
|
||||
"risk_label": recent["risk_label"].astype(int).tolist() if "risk_label" in recent.columns else [],
|
||||
"heatwave": recent["heatwave"].astype(int).tolist() if "heatwave" in recent.columns else [],
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"历史数据加载失败: {e}")
|
||||
return jsonify({"dates": [], "temp_mean": [], "heat_index": [], "risk_label": [], "heatwave": []})
|
||||
|
||||
|
||||
@app.route("/api/stats")
|
||||
def stats():
|
||||
"""返回年度统计摘要"""
|
||||
try:
|
||||
dfs = []
|
||||
for city in CITIES:
|
||||
csv_path = DATA_PROCESSED / f"{city}_processed.csv"
|
||||
if csv_path.exists():
|
||||
df = pd.read_csv(csv_path, parse_dates=["time"])
|
||||
dfs.append(df)
|
||||
if dfs:
|
||||
df = pd.concat(dfs, ignore_index=True)
|
||||
df["year"] = df["time"].dt.year
|
||||
annual = df.groupby("year").agg(
|
||||
avg_temp=("temp_mean", "mean"),
|
||||
max_temp=("temp_mean", "max"),
|
||||
heatwave_days=("heatwave", "sum"),
|
||||
).reset_index()
|
||||
return jsonify({
|
||||
"annual": {
|
||||
"years": annual["year"].astype(int).tolist(),
|
||||
"avg_temp": annual["avg_temp"].round(1).tolist(),
|
||||
"max_temp": annual["max_temp"].round(1).tolist(),
|
||||
"heatwave_days": annual["heatwave_days"].astype(int).tolist(),
|
||||
},
|
||||
"aging_rate": AGING_RATES,
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"统计数据生成失败: {e}")
|
||||
return jsonify({"annual": {"years": [], "avg_temp": [], "max_temp": [], "heatwave_days": []},
|
||||
"aging_rate": AGING_RATES})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("启动高温预警可视化服务器...")
|
||||
print("访问 http://localhost:5005 查看大屏")
|
||||
app.run(host="0.0.0.0", port=5005, debug=True)
|
||||
@@ -0,0 +1,869 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>银发群体高温多时间尺度预警可视化大屏</title>
|
||||
<script src="https://cdn.jsdelivr.net/npm/echarts@5.5.0/dist/echarts.min.js"></script>
|
||||
<style>
|
||||
/* ===== CSS 变量与基础重置 ===== */
|
||||
:root {
|
||||
--bg-primary: #0a1632;
|
||||
--bg-panel: #0d1f3c;
|
||||
--border: #1a3a5c;
|
||||
--text-primary: #e0e6f0;
|
||||
--text-secondary: #8899aa;
|
||||
--text-dim: #556677;
|
||||
--green: #00e676;
|
||||
--yellow: #ffeb3b;
|
||||
--orange: #ff9800;
|
||||
--red: #f44336;
|
||||
--blue: #5b9bd5;
|
||||
--purple: #ab47bc;
|
||||
--cyan: #00bcd4;
|
||||
}
|
||||
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
|
||||
body {
|
||||
font-family: "Microsoft YaHei", "PingFang SC", "Helvetica Neue", sans-serif;
|
||||
background: var(--bg-primary);
|
||||
color: var(--text-primary);
|
||||
overflow: hidden;
|
||||
height: 100vh;
|
||||
}
|
||||
|
||||
/* ===== 网格布局 ===== */
|
||||
#app {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr 1fr;
|
||||
grid-template-rows: auto 1fr 1fr;
|
||||
gap: 10px;
|
||||
height: 100vh;
|
||||
padding: 10px;
|
||||
}
|
||||
|
||||
/* ===== 页头 ===== */
|
||||
.header {
|
||||
grid-column: 1 / -1;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 20px;
|
||||
padding: 10px 0 6px;
|
||||
border-bottom: 2px solid var(--border);
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.header h1 {
|
||||
font-size: 26px;
|
||||
font-weight: 700;
|
||||
letter-spacing: 4px;
|
||||
background: linear-gradient(90deg, var(--cyan), var(--blue), var(--purple));
|
||||
-webkit-background-clip: text;
|
||||
-webkit-text-fill-color: transparent;
|
||||
background-clip: text;
|
||||
}
|
||||
|
||||
.header .subtitle {
|
||||
font-size: 13px;
|
||||
color: var(--text-secondary);
|
||||
position: absolute;
|
||||
right: 16px;
|
||||
bottom: 6px;
|
||||
}
|
||||
|
||||
.header .refresh-dot {
|
||||
width: 8px; height: 8px;
|
||||
border-radius: 50%;
|
||||
background: var(--green);
|
||||
animation: pulse 2s infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse {
|
||||
0%, 100% { opacity: 1; box-shadow: 0 0 4px var(--green); }
|
||||
50% { opacity: 0.4; box-shadow: 0 0 8px var(--green); }
|
||||
}
|
||||
|
||||
/* ===== 面板通用样式 ===== */
|
||||
.panel {
|
||||
background: var(--bg-panel);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 6px;
|
||||
padding: 10px 12px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
min-height: 0;
|
||||
position: relative;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.panel-title {
|
||||
font-size: 14px;
|
||||
font-weight: 600;
|
||||
color: var(--text-primary);
|
||||
margin-bottom: 6px;
|
||||
padding-bottom: 6px;
|
||||
border-bottom: 1px solid rgba(26,58,92,0.6);
|
||||
flex-shrink: 0;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
}
|
||||
|
||||
.panel-title .icon { font-size: 16px; }
|
||||
|
||||
.chart {
|
||||
flex: 1;
|
||||
min-height: 0;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.empty-state {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
color: var(--text-dim);
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.loading-state {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
color: var(--text-secondary);
|
||||
font-size: 13px;
|
||||
}
|
||||
|
||||
/* ===== Panel 2: 风险仪表盘 ===== */
|
||||
#panel-risk .risk-content {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 8px;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
.risk-circle {
|
||||
width: 90px;
|
||||
height: 90px;
|
||||
border-radius: 50%;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-weight: 700;
|
||||
color: #fff;
|
||||
position: relative;
|
||||
box-shadow: 0 0 20px rgba(0,0,0,0.4);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.risk-circle .risk-label {
|
||||
font-size: 20px;
|
||||
line-height: 1.2;
|
||||
}
|
||||
|
||||
.risk-circle .risk-sub {
|
||||
font-size: 10px;
|
||||
opacity: 0.85;
|
||||
}
|
||||
|
||||
.risk-circle::after {
|
||||
content: '';
|
||||
position: absolute;
|
||||
inset: -6px;
|
||||
border-radius: 50%;
|
||||
border: 2px solid;
|
||||
border-color: inherit;
|
||||
opacity: 0.6;
|
||||
}
|
||||
|
||||
.predictions-row {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
width: 100%;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.pred-card {
|
||||
flex: 1;
|
||||
background: rgba(255,255,255,0.03);
|
||||
border-left: 3px solid;
|
||||
border-radius: 0 4px 4px 0;
|
||||
padding: 6px 8px;
|
||||
font-size: 11px;
|
||||
}
|
||||
|
||||
.pred-card .pred-scale {
|
||||
color: var(--text-secondary);
|
||||
margin-bottom: 2px;
|
||||
}
|
||||
|
||||
.pred-card .pred-value {
|
||||
font-size: 14px;
|
||||
font-weight: 700;
|
||||
}
|
||||
|
||||
.pred-card .pred-conf {
|
||||
color: var(--text-dim);
|
||||
font-size: 10px;
|
||||
}
|
||||
|
||||
.suggestions-box {
|
||||
width: 100%;
|
||||
background: rgba(255,255,255,0.02);
|
||||
border-radius: 4px;
|
||||
padding: 6px 10px;
|
||||
font-size: 11px;
|
||||
color: var(--text-secondary);
|
||||
flex-shrink: 0;
|
||||
max-height: 90px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.suggestions-box .sug-title {
|
||||
font-weight: 600;
|
||||
color: var(--text-primary);
|
||||
margin-bottom: 3px;
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
.suggestions-box li {
|
||||
margin-left: 16px;
|
||||
margin-bottom: 2px;
|
||||
line-height: 1.5;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
|
||||
<div id="app">
|
||||
<!-- 页头 -->
|
||||
<div class="header">
|
||||
<div class="refresh-dot" id="refresh-dot"></div>
|
||||
<h1>银发群体高温多时间尺度预警与服务优化可视化平台</h1>
|
||||
<span class="subtitle" id="update-time">数据加载中...</span>
|
||||
</div>
|
||||
|
||||
<!-- Panel 1: 双城温度趋势图 -->
|
||||
<div class="panel" id="panel-1">
|
||||
<div class="panel-title"><span class="icon">📈</span> 双城温度 & 体感温度趋势</div>
|
||||
<div class="chart" id="chart-1"></div>
|
||||
<div class="empty-state" id="empty-1" style="display:none">暂无历史数据</div>
|
||||
</div>
|
||||
|
||||
<!-- Panel 2: 当前风险等级 -->
|
||||
<div class="panel" id="panel-risk">
|
||||
<div class="panel-title"><span class="icon">⚠</span> 当前风险等级评估</div>
|
||||
<div class="risk-content" id="risk-content">
|
||||
<div class="risk-circle" id="risk-circle">
|
||||
<span class="risk-label" id="risk-label">--</span>
|
||||
<span class="risk-sub" id="risk-city">焦作</span>
|
||||
</div>
|
||||
<div class="predictions-row" id="predictions-row"></div>
|
||||
<div class="suggestions-box" id="suggestions-box">
|
||||
<div class="sug-title">📋 建议措施</div>
|
||||
<ul id="suggestions-list"><li>加载中...</li></ul>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Panel 3: 老年人口概况 -->
|
||||
<div class="panel" id="panel-3">
|
||||
<div class="panel-title"><span class="icon">👴</span> 老年人口概况</div>
|
||||
<div class="chart" id="chart-3"></div>
|
||||
<div class="empty-state" id="empty-3" style="display:none">暂无人口数据</div>
|
||||
</div>
|
||||
|
||||
<!-- Panel 4: 多时间尺度预警时间线 -->
|
||||
<div class="panel" id="panel-4">
|
||||
<div class="panel-title"><span class="icon">🕐</span> 多时间尺度预警时间线</div>
|
||||
<div class="chart" id="chart-4"></div>
|
||||
<div class="empty-state" id="empty-4" style="display:none">暂无预测数据</div>
|
||||
</div>
|
||||
|
||||
<!-- Panel 5: 温度与健康风险关联 -->
|
||||
<div class="panel" id="panel-5">
|
||||
<div class="panel-title"><span class="icon">📊</span> 温度与健康风险关联(暴露-反应曲线)</div>
|
||||
<div class="chart" id="chart-5"></div>
|
||||
</div>
|
||||
|
||||
<!-- Panel 6: 历年高温事件回顾 -->
|
||||
<div class="panel" id="panel-6">
|
||||
<div class="panel-title"><span class="icon">📅</span> 历年高温事件与热浪天数统计</div>
|
||||
<div class="chart" id="chart-6"></div>
|
||||
<div class="empty-state" id="empty-6" style="display:none">暂无统计数据</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
// ===== 全局状态 =====
|
||||
const BASE = '';
|
||||
const RISK_COLORS = ['#00e676', '#ffeb3b', '#ff9800', '#f44336'];
|
||||
const RISK_LABELS = ['低风险', '中风险', '高风险', '严重风险'];
|
||||
const TIME_SCALE_LABELS = { short: '短期(1-3天)', medium: '中期(7天)', long: '长期(30天)' };
|
||||
const TIME_SCALE_ORDER = ['short', 'medium', 'long'];
|
||||
|
||||
// ECharts 实例池
|
||||
const charts = {};
|
||||
|
||||
// ===== 工具函数 =====
|
||||
function darkAxisLine() {
|
||||
return { lineStyle: { color: '#1a3a5c' } };
|
||||
}
|
||||
|
||||
function darkSplitLine() {
|
||||
return { lineStyle: { color: 'rgba(26,58,92,0.3)', type: 'dashed' } };
|
||||
}
|
||||
|
||||
function darkTextStyle() {
|
||||
return { color: '#8899aa', fontSize: 11 };
|
||||
}
|
||||
|
||||
function panelTitleStyle(text) {
|
||||
return {
|
||||
text: text,
|
||||
textStyle: { color: '#c0ccd8', fontSize: 13, fontWeight: 600 },
|
||||
left: 'center', top: 0
|
||||
};
|
||||
}
|
||||
|
||||
// ===== 数据获取 =====
|
||||
async function fetchAllData() {
|
||||
try {
|
||||
const [predictRes, historyRes, statsRes] = await Promise.all([
|
||||
fetch(BASE + '/api/predict').then(r => r.ok ? r.json() : null),
|
||||
fetch(BASE + '/api/history').then(r => r.ok ? r.json() : null),
|
||||
fetch(BASE + '/api/stats').then(r => r.ok ? r.json() : null)
|
||||
]);
|
||||
|
||||
return {
|
||||
predict: predictRes,
|
||||
history: historyRes,
|
||||
stats: statsRes
|
||||
};
|
||||
} catch (e) {
|
||||
console.error('数据获取失败:', e);
|
||||
return { predict: null, history: null, stats: null };
|
||||
}
|
||||
}
|
||||
|
||||
// ===== Panel 1: 双城温度趋势图 (双Y轴折线图) =====
|
||||
function renderPanel1(historyData) {
|
||||
const dom = document.getElementById('chart-1');
|
||||
const emptyEl = document.getElementById('empty-1');
|
||||
if (!historyData || !historyData.dates || historyData.dates.length === 0) {
|
||||
dom.style.display = 'none';
|
||||
emptyEl.style.display = 'flex';
|
||||
return;
|
||||
}
|
||||
dom.style.display = '';
|
||||
emptyEl.style.display = 'none';
|
||||
|
||||
if (!charts.p1) {
|
||||
charts.p1 = echarts.init(dom);
|
||||
}
|
||||
|
||||
const opt = {
|
||||
tooltip: {
|
||||
trigger: 'axis',
|
||||
backgroundColor: 'rgba(13,31,60,0.95)',
|
||||
borderColor: '#1a3a5c',
|
||||
textStyle: { color: '#e0e6f0', fontSize: 12 }
|
||||
},
|
||||
legend: {
|
||||
data: ['平均温度', '体感温度', '风险等级'],
|
||||
bottom: 0,
|
||||
textStyle: { color: '#8899aa', fontSize: 10 },
|
||||
itemWidth: 14, itemHeight: 8
|
||||
},
|
||||
grid: { left: 45, right: 55, top: 12, bottom: 28 },
|
||||
xAxis: {
|
||||
type: 'category',
|
||||
data: historyData.dates,
|
||||
axisLine: darkAxisLine(),
|
||||
axisTick: { show: false },
|
||||
axisLabel: { color: '#667788', fontSize: 9, rotate: 30,
|
||||
formatter: v => v.slice(5) },
|
||||
splitLine: { show: false }
|
||||
},
|
||||
yAxis: [
|
||||
{
|
||||
type: 'value', name: '°C',
|
||||
nameTextStyle: { color: '#8899aa', fontSize: 10 },
|
||||
axisLine: darkAxisLine(),
|
||||
splitLine: darkSplitLine(),
|
||||
axisLabel: { color: '#8899aa', fontSize: 10 }
|
||||
},
|
||||
{
|
||||
type: 'value', name: '风险',
|
||||
min: 0, max: 3, interval: 1,
|
||||
nameTextStyle: { color: '#ab47bc', fontSize: 10 },
|
||||
axisLine: { lineStyle: { color: '#ab47bc' } },
|
||||
splitLine: { show: false },
|
||||
axisLabel: {
|
||||
color: '#ab47bc', fontSize: 10,
|
||||
formatter: v => RISK_LABELS[v] || ''
|
||||
}
|
||||
}
|
||||
],
|
||||
series: [
|
||||
{
|
||||
name: '平均温度', type: 'line', yAxisIndex: 0,
|
||||
data: historyData.temp_mean,
|
||||
smooth: true, symbol: 'none',
|
||||
lineStyle: { color: '#ff9800', width: 2 },
|
||||
itemStyle: { color: '#ff9800' }
|
||||
},
|
||||
{
|
||||
name: '体感温度', type: 'line', yAxisIndex: 0,
|
||||
data: historyData.heat_index,
|
||||
smooth: true, symbol: 'none',
|
||||
lineStyle: { color: '#f44336', width: 2 },
|
||||
itemStyle: { color: '#f44336' }
|
||||
},
|
||||
{
|
||||
name: '风险等级', type: 'line', yAxisIndex: 1,
|
||||
data: historyData.risk_label,
|
||||
smooth: true, symbol: 'none',
|
||||
lineStyle: { color: '#ab47bc', width: 1.5, type: 'dotted' },
|
||||
itemStyle: { color: '#ab47bc' }
|
||||
}
|
||||
]
|
||||
};
|
||||
charts.p1.setOption(opt, true);
|
||||
}
|
||||
|
||||
// ===== Panel 2: 当前风险等级 (HTML 仪表盘) =====
|
||||
function renderPanel2(predictData) {
|
||||
const circle = document.getElementById('risk-circle');
|
||||
const label = document.getElementById('risk-label');
|
||||
const city = document.getElementById('risk-city');
|
||||
const predRow = document.getElementById('predictions-row');
|
||||
const sugList = document.getElementById('suggestions-list');
|
||||
|
||||
if (!predictData || !predictData.predictions) {
|
||||
circle.style.background = '#555';
|
||||
circle.style.borderColor = '#555';
|
||||
label.textContent = '--';
|
||||
city.textContent = '无数据';
|
||||
predRow.innerHTML = '<span style="color:#667;font-size:11px;text-align:center;width:100%">暂无预测数据</span>';
|
||||
sugList.innerHTML = '<li>暂无建议</li>';
|
||||
return;
|
||||
}
|
||||
|
||||
const preds = predictData.predictions;
|
||||
// 使用短期预测作为当前风险
|
||||
const current = preds.short;
|
||||
const color = RISK_COLORS[current.level] || '#555';
|
||||
|
||||
circle.style.background = color;
|
||||
circle.style.borderColor = color;
|
||||
label.textContent = current.label;
|
||||
city.textContent = predictData.city || '焦作';
|
||||
|
||||
// 三列预测卡片
|
||||
const cardColors = { short: '#5b9bd5', medium: '#ff9800', long: '#ab47bc' };
|
||||
predRow.innerHTML = TIME_SCALE_ORDER.map(key => {
|
||||
const p = preds[key];
|
||||
const c = RISK_COLORS[p.level] || '#555';
|
||||
return '<div class="pred-card" style="border-left-color:' + c + '">' +
|
||||
'<div class="pred-scale">' + TIME_SCALE_LABELS[key] + '</div>' +
|
||||
'<div class="pred-value" style="color:' + c + '">' + p.label + '</div>' +
|
||||
'<div class="pred-conf">置信度: ' + (p.confidence * 100).toFixed(1) + '%</div>' +
|
||||
'</div>';
|
||||
}).join('');
|
||||
|
||||
// 建议列表
|
||||
const suggestions = current.suggestions || ['暂无建议'];
|
||||
sugList.innerHTML = suggestions.map(s => '<li>' + s + '</li>').join('');
|
||||
}
|
||||
|
||||
// ===== Panel 3: 老年人口概况 (环形饼图) =====
|
||||
function renderPanel3(statsData) {
|
||||
const dom = document.getElementById('chart-3');
|
||||
const emptyEl = document.getElementById('empty-3');
|
||||
const agingRate = statsData ? statsData.aging_rate : null;
|
||||
if (!agingRate) {
|
||||
dom.style.display = 'none';
|
||||
emptyEl.style.display = 'flex';
|
||||
return;
|
||||
}
|
||||
dom.style.display = '';
|
||||
emptyEl.style.display = 'none';
|
||||
|
||||
if (!charts.p3) {
|
||||
charts.p3 = echarts.init(dom);
|
||||
}
|
||||
|
||||
const jzRate = agingRate.jiaozuo || 12.8;
|
||||
const zzRate = agingRate.zhengzhou || 11.6;
|
||||
const jzOther = 100 - jzRate;
|
||||
const zzOther = 100 - zzRate;
|
||||
|
||||
const opt = {
|
||||
tooltip: {
|
||||
trigger: 'item',
|
||||
backgroundColor: 'rgba(13,31,60,0.95)',
|
||||
borderColor: '#1a3a5c',
|
||||
textStyle: { color: '#e0e6f0', fontSize: 12 },
|
||||
formatter: '{b}: {c}%'
|
||||
},
|
||||
series: [
|
||||
{
|
||||
name: '焦作', type: 'pie',
|
||||
radius: ['45%', '65%'],
|
||||
center: ['25%', '55%'],
|
||||
label: {
|
||||
position: 'center',
|
||||
formatter: '焦作\n{c}%',
|
||||
fontSize: 13, fontWeight: 700,
|
||||
color: '#e0e6f0'
|
||||
},
|
||||
emphasis: { label: { fontSize: 16 } },
|
||||
data: [
|
||||
{ value: jzRate, name: '老龄化率', itemStyle: { color: '#5b9bd5' } },
|
||||
{ value: jzOther, name: '其他', itemStyle: { color: '#1a3a5c' } }
|
||||
]
|
||||
},
|
||||
{
|
||||
name: '郑州', type: 'pie',
|
||||
radius: ['45%', '65%'],
|
||||
center: ['75%', '55%'],
|
||||
label: {
|
||||
position: 'center',
|
||||
formatter: '郑州\n{c}%',
|
||||
fontSize: 13, fontWeight: 700,
|
||||
color: '#e0e6f0'
|
||||
},
|
||||
emphasis: { label: { fontSize: 16 } },
|
||||
data: [
|
||||
{ value: zzRate, name: '老龄化率', itemStyle: { color: '#ab47bc' } },
|
||||
{ value: zzOther, name: '其他', itemStyle: { color: '#1a3a5c' } }
|
||||
]
|
||||
}
|
||||
],
|
||||
graphic: [
|
||||
{ type: 'text', left: '16%', top: '10%',
|
||||
style: { text: '焦作', fill: '#8899aa', fontSize: 11 } },
|
||||
{ type: 'text', left: '66%', top: '10%',
|
||||
style: { text: '郑州', fill: '#8899aa', fontSize: 11 } },
|
||||
{ type: 'text', left: 'center', top: '88%',
|
||||
style: { text: '65岁及以上人口占比', fill: '#556677', fontSize: 10 } }
|
||||
]
|
||||
};
|
||||
charts.p3.setOption(opt, true);
|
||||
}
|
||||
|
||||
// ===== Panel 4: 多时间尺度预警时间线 (横向条形图) =====
|
||||
function renderPanel4(predictData) {
|
||||
const dom = document.getElementById('chart-4');
|
||||
const emptyEl = document.getElementById('empty-4');
|
||||
if (!predictData || !predictData.predictions) {
|
||||
dom.style.display = 'none';
|
||||
emptyEl.style.display = 'flex';
|
||||
return;
|
||||
}
|
||||
dom.style.display = '';
|
||||
emptyEl.style.display = 'none';
|
||||
|
||||
if (!charts.p4) {
|
||||
charts.p4 = echarts.init(dom);
|
||||
}
|
||||
|
||||
const preds = predictData.predictions;
|
||||
const yLabels = TIME_SCALE_ORDER.map(k => TIME_SCALE_LABELS[k]).reverse();
|
||||
const levels = TIME_SCALE_ORDER.map(k => preds[k].level).reverse();
|
||||
const confs = TIME_SCALE_ORDER.map(k => preds[k].confidence).reverse();
|
||||
const colors = levels.map(l => RISK_COLORS[l]);
|
||||
|
||||
const opt = {
|
||||
tooltip: {
|
||||
trigger: 'axis',
|
||||
backgroundColor: 'rgba(13,31,60,0.95)',
|
||||
borderColor: '#1a3a5c',
|
||||
textStyle: { color: '#e0e6f0', fontSize: 12 },
|
||||
formatter: function(params) {
|
||||
const p = params[0];
|
||||
const idx = p.dataIndex;
|
||||
return '<b>' + yLabels[idx] + '</b><br/>' +
|
||||
'风险等级: ' + RISK_LABELS[levels[idx]] + '<br/>' +
|
||||
'等级值: ' + levels[idx] + ' / 3<br/>' +
|
||||
'置信度: ' + (confs[idx] * 100).toFixed(1) + '%';
|
||||
}
|
||||
},
|
||||
grid: { left: 110, right: 40, top: 16, bottom: 16 },
|
||||
xAxis: {
|
||||
type: 'value', name: '风险等级',
|
||||
min: 0, max: 3, interval: 1,
|
||||
nameTextStyle: { color: '#8899aa', fontSize: 10 },
|
||||
axisLine: darkAxisLine(),
|
||||
splitLine: darkSplitLine(),
|
||||
axisLabel: {
|
||||
color: '#8899aa', fontSize: 10,
|
||||
formatter: v => RISK_LABELS[v] || ''
|
||||
}
|
||||
},
|
||||
yAxis: {
|
||||
type: 'category',
|
||||
data: yLabels,
|
||||
axisLine: { lineStyle: { color: '#1a3a5c' } },
|
||||
axisTick: { show: false },
|
||||
axisLabel: { color: '#c0ccd8', fontSize: 12, fontWeight: 600 }
|
||||
},
|
||||
series: [{
|
||||
type: 'bar',
|
||||
data: levels.map((l, i) => ({
|
||||
value: l,
|
||||
itemStyle: {
|
||||
color: new echarts.graphic.LinearGradient(0, 0, 1, 0, [
|
||||
{ offset: 0, color: colors[i] },
|
||||
{ offset: 1, color: colors[i] + '88' }
|
||||
]),
|
||||
borderRadius: [0, 4, 4, 0]
|
||||
}
|
||||
})),
|
||||
barWidth: 22,
|
||||
label: {
|
||||
show: true, position: 'right',
|
||||
color: '#c0ccd8', fontSize: 11,
|
||||
formatter: function(p) {
|
||||
return RISK_LABELS[levels[p.dataIndex]];
|
||||
}
|
||||
},
|
||||
emphasis: {
|
||||
itemStyle: { shadowBlur: 10, shadowColor: 'rgba(0,0,0,0.4)' }
|
||||
}
|
||||
}]
|
||||
};
|
||||
charts.p4.setOption(opt, true);
|
||||
}
|
||||
|
||||
// ===== Panel 5: 温度与健康风险关联 -- 暴露-反应曲线 (硬编码文献数据) =====
|
||||
function renderPanel5() {
|
||||
const dom = document.getElementById('chart-5');
|
||||
if (!charts.p5) {
|
||||
charts.p5 = echarts.init(dom);
|
||||
}
|
||||
|
||||
// 文献中的典型温度-死亡风险暴露反应曲线 (J型曲线)
|
||||
// 参考: 中国多城市温度-死亡率研究, 以25°C为最低风险参考温度 (RR=1.0)
|
||||
var temps = [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40];
|
||||
var rr = [1.08, 1.06, 1.04, 1.03, 1.01, 1.00, 0.99, 0.98, 0.98, 0.99, 1.00, 1.00, 1.01, 1.03, 1.06, 1.10, 1.15, 1.22, 1.30, 1.40, 1.52, 1.66, 1.82, 2.00, 2.20, 2.42];
|
||||
|
||||
// 构建高于参考线的面积填充数据
|
||||
var areaData = temps.map(function(t, i) {
|
||||
return [t, Math.max(rr[i], 1.0)];
|
||||
});
|
||||
|
||||
var opt = {
|
||||
tooltip: {
|
||||
trigger: 'axis',
|
||||
backgroundColor: 'rgba(13,31,60,0.95)',
|
||||
borderColor: '#1a3a5c',
|
||||
textStyle: { color: '#e0e6f0', fontSize: 12 },
|
||||
formatter: function(params) {
|
||||
var p = params[0];
|
||||
return '温度: <b>' + p.axisValue + '°C</b><br/>' +
|
||||
'相对风险(RR): <b style="color:' + (p.value > 1.0 ? '#f44336' : '#00e676') + '">' +
|
||||
p.value.toFixed(2) + '</b>';
|
||||
}
|
||||
},
|
||||
grid: { left: 50, right: 30, top: 16, bottom: 24 },
|
||||
xAxis: {
|
||||
type: 'category',
|
||||
data: temps,
|
||||
name: '温度 (°C)',
|
||||
nameTextStyle: { color: '#8899aa', fontSize: 10 },
|
||||
axisLine: darkAxisLine(),
|
||||
axisLabel: { color: '#8899aa', fontSize: 9, interval: 2 },
|
||||
splitLine: { show: false }
|
||||
},
|
||||
yAxis: {
|
||||
type: 'value',
|
||||
name: '相对风险 (RR)',
|
||||
nameTextStyle: { color: '#8899aa', fontSize: 10 },
|
||||
axisLine: darkAxisLine(),
|
||||
splitLine: darkSplitLine(),
|
||||
axisLabel: { color: '#8899aa', fontSize: 10 },
|
||||
min: 0.9
|
||||
},
|
||||
series: [
|
||||
{
|
||||
name: '风险面积', type: 'line',
|
||||
data: areaData,
|
||||
smooth: true, symbol: 'none',
|
||||
lineStyle: { opacity: 0 },
|
||||
areaStyle: { color: 'rgba(244,67,54,0.15)' },
|
||||
silent: true
|
||||
},
|
||||
{
|
||||
name: '相对风险', type: 'line',
|
||||
data: rr,
|
||||
smooth: true, symbol: 'none',
|
||||
lineStyle: { color: '#ff9800', width: 2.5 },
|
||||
itemStyle: { color: '#ff9800' },
|
||||
markLine: {
|
||||
silent: true,
|
||||
symbol: 'none',
|
||||
lineStyle: { color: '#00e676', type: 'dashed', width: 1.5 },
|
||||
label: { color: '#00e676', fontSize: 10, formatter: 'RR=1.0 参考线' },
|
||||
data: [{ yAxis: 1.0 }]
|
||||
},
|
||||
markArea: {
|
||||
silent: true,
|
||||
data: [
|
||||
[
|
||||
{ xAxis: '30', itemStyle: { color: 'rgba(244,67,54,0.06)' } },
|
||||
{ xAxis: '40' }
|
||||
]
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
};
|
||||
charts.p5.setOption(opt, true);
|
||||
}
|
||||
|
||||
// ===== Panel 6: 历年高温事件与热浪天数统计 (柱状+双折线组合图) =====
|
||||
function renderPanel6(statsData) {
|
||||
var dom = document.getElementById('chart-6');
|
||||
var emptyEl = document.getElementById('empty-6');
|
||||
var annual = statsData ? statsData.annual : null;
|
||||
if (!annual || !annual.years || annual.years.length === 0) {
|
||||
dom.style.display = 'none';
|
||||
emptyEl.style.display = 'flex';
|
||||
return;
|
||||
}
|
||||
dom.style.display = '';
|
||||
emptyEl.style.display = 'none';
|
||||
|
||||
if (!charts.p6) {
|
||||
charts.p6 = echarts.init(dom);
|
||||
}
|
||||
|
||||
var opt = {
|
||||
tooltip: {
|
||||
trigger: 'axis',
|
||||
backgroundColor: 'rgba(13,31,60,0.95)',
|
||||
borderColor: '#1a3a5c',
|
||||
textStyle: { color: '#e0e6f0', fontSize: 12 }
|
||||
},
|
||||
legend: {
|
||||
data: ['热浪天数', '平均温度', '最高温度'],
|
||||
bottom: 0,
|
||||
textStyle: { color: '#8899aa', fontSize: 10 },
|
||||
itemWidth: 14, itemHeight: 8
|
||||
},
|
||||
grid: { left: 45, right: 55, top: 12, bottom: 28 },
|
||||
xAxis: {
|
||||
type: 'category',
|
||||
data: annual.years.map(String),
|
||||
axisLine: darkAxisLine(),
|
||||
axisLabel: { color: '#8899aa', fontSize: 10 },
|
||||
splitLine: { show: false }
|
||||
},
|
||||
yAxis: [
|
||||
{
|
||||
type: 'value', name: '°C',
|
||||
nameTextStyle: { color: '#8899aa', fontSize: 10 },
|
||||
axisLine: darkAxisLine(),
|
||||
splitLine: darkSplitLine(),
|
||||
axisLabel: { color: '#8899aa', fontSize: 10 }
|
||||
},
|
||||
{
|
||||
type: 'value', name: '天',
|
||||
nameTextStyle: { color: '#ffeb3b', fontSize: 10 },
|
||||
axisLine: { lineStyle: { color: '#ffeb3b' } },
|
||||
splitLine: { show: false },
|
||||
axisLabel: { color: '#ffeb3b', fontSize: 10 }
|
||||
}
|
||||
],
|
||||
series: [
|
||||
{
|
||||
name: '热浪天数', type: 'bar', yAxisIndex: 1,
|
||||
data: annual.heatwave_days,
|
||||
barWidth: '40%',
|
||||
itemStyle: {
|
||||
color: new echarts.graphic.LinearGradient(0, 0, 0, 1, [
|
||||
{ offset: 0, color: '#ffeb3b' },
|
||||
{ offset: 1, color: '#ff980088' }
|
||||
])
|
||||
}
|
||||
},
|
||||
{
|
||||
name: '平均温度', type: 'line', yAxisIndex: 0,
|
||||
data: annual.avg_temp,
|
||||
smooth: true, symbol: 'circle', symbolSize: 5,
|
||||
lineStyle: { color: '#5b9bd5', width: 2 },
|
||||
itemStyle: { color: '#5b9bd5' }
|
||||
},
|
||||
{
|
||||
name: '最高温度', type: 'line', yAxisIndex: 0,
|
||||
data: annual.max_temp,
|
||||
smooth: true, symbol: 'diamond', symbolSize: 6,
|
||||
lineStyle: { color: '#f44336', width: 2 },
|
||||
itemStyle: { color: '#f44336' }
|
||||
}
|
||||
]
|
||||
};
|
||||
charts.p6.setOption(opt, true);
|
||||
}
|
||||
|
||||
// ===== 更新时间戳 =====
|
||||
function updateTimestamp() {
|
||||
var now = new Date();
|
||||
var str = now.getFullYear() + '-' +
|
||||
String(now.getMonth() + 1).padStart(2, '0') + '-' +
|
||||
String(now.getDate()).padStart(2, '0') + ' ' +
|
||||
String(now.getHours()).padStart(2, '0') + ':' +
|
||||
String(now.getMinutes()).padStart(2, '0') + ':' +
|
||||
String(now.getSeconds()).padStart(2, '0');
|
||||
document.getElementById('update-time').textContent = '数据更新: ' + str;
|
||||
}
|
||||
|
||||
// ===== 初始化所有图表 =====
|
||||
async function initCharts() {
|
||||
var data = await fetchAllData();
|
||||
|
||||
renderPanel1(data.history);
|
||||
renderPanel2(data.predict);
|
||||
renderPanel3(data.stats);
|
||||
renderPanel4(data.predict);
|
||||
renderPanel5();
|
||||
renderPanel6(data.stats);
|
||||
|
||||
updateTimestamp();
|
||||
}
|
||||
|
||||
// ===== 响应式处理 =====
|
||||
function resizeAllCharts() {
|
||||
Object.values(charts).forEach(function(c) {
|
||||
if (c && !c.isDisposed()) {
|
||||
c.resize();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// ===== 入口 =====
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
initCharts();
|
||||
|
||||
// 每30分钟自动刷新
|
||||
setInterval(initCharts, 30 * 60 * 1000);
|
||||
|
||||
// 窗口缩放时重绘所有图表
|
||||
window.addEventListener('resize', function() {
|
||||
clearTimeout(window._resizeTimer);
|
||||
window._resizeTimer = setTimeout(resizeAllCharts, 200);
|
||||
});
|
||||
});
|
||||
</script>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,19 @@
|
||||
MAIN = main
|
||||
LATEX = xelatex
|
||||
BIBER = biber
|
||||
|
||||
all: $(MAIN).pdf
|
||||
|
||||
$(MAIN).pdf: $(MAIN).tex chapters/*.tex refs.bib
|
||||
$(LATEX) $(MAIN).tex
|
||||
$(BIBER) $(MAIN)
|
||||
$(LATEX) $(MAIN).tex
|
||||
$(LATEX) $(MAIN).tex
|
||||
|
||||
clean:
|
||||
rm -f *.aux *.log *.out *.toc *.bbl *.blg *.synctex.gz *.fdb_latexmk *.fls *.run.xml *.bcf
|
||||
|
||||
distclean: clean
|
||||
rm -f $(MAIN).pdf
|
||||
|
||||
.PHONY: all clean distclean
|
||||
@@ -0,0 +1,24 @@
|
||||
\chapter*{摘要}
|
||||
\addcontentsline{toc}{chapter}{摘要}
|
||||
|
||||
随着全球气候变暖,高温热浪事件频发,对老年群体的健康构成严重威胁。本研究以焦作市和郑州市为研究区域,利用ERA5-Land气象再分析数据和人口健康统计数据,构建了基于LSTM-Attention的多时间尺度高温健康风险预警模型,并开发了可视化大屏系统。
|
||||
|
||||
本研究主要工作包括:(1)获取并预处理2010-2024年焦作、郑州两市的ERA5-Land气象数据,结合人口普查和卫生统计年鉴数据,构建了温度-健康风险关联数据集;(2)设计了LSTM结合多头自注意力机制的深度学习模型,实现了短期(1-3天)、中期(7天)和长期(30天)三个时间尺度的风险等级预测;(3)以XGBoost作为基线模型进行对比实验,验证了深度学习方法的有效性;(4)基于Flask和ECharts开发了深色科技蓝风格的Web可视化大屏,实现了温度趋势、风险预警、人口概况等信息的多维度展示。
|
||||
|
||||
实验结果表明,LSTM-Attention模型在短期和中期预警任务上优于传统机器学习方法,能够为高温热浪健康风险管理提供有效的决策支持。
|
||||
|
||||
\textbf{关键词:}高温热浪;银发群体;多时间尺度预警;LSTM-Attention;可视化
|
||||
|
||||
\newpage
|
||||
|
||||
\chapter*{Abstract}
|
||||
\addcontentsline{toc}{chapter}{Abstract}
|
||||
|
||||
With global warming, frequent heatwave events pose serious threats to the health of the elderly population. This study takes Jiaozuo and Zhengzhou as research areas, utilizes ERA5-Land meteorological reanalysis data and population health statistics to construct an LSTM-Attention based multi-time-scale heat health risk early warning model, and develops a visualization dashboard system.
|
||||
|
||||
The main contributions include: (1) acquisition and preprocessing of ERA5-Land meteorological data (2010-2024) for both cities, combined with census and health statistics data; (2) design of a deep learning model combining LSTM with multi-head self-attention for risk prediction at three time scales (short/medium/long term); (3) comparative experiments with XGBoost baseline to validate the deep learning approach; (4) development of a Flask+ECharts web dashboard with dark tech-blue theme for multi-dimensional visualization.
|
||||
|
||||
Experimental results show that the LSTM-Attention model outperforms traditional methods in short and medium-term early warning tasks, providing effective decision support for heatwave health risk management.
|
||||
|
||||
\textbf{Keywords:} Heatwave; Elderly Population; Multi-time-scale Early Warning; LSTM-Attention; Visualization
|
||||
\newpage
|
||||
@@ -0,0 +1,35 @@
|
||||
\chapter{绪论}
|
||||
|
||||
\section{研究背景与意义}
|
||||
|
||||
全球气候变暖导致极端高温事件频发,对公共卫生构成严峻挑战。老年群体(65岁及以上)由于体温调节功能下降、慢性病患病率高等原因,是高温热浪最脆弱的群体之一。焦作市和郑州市地处中原地区,夏季高温天气频繁,老龄化率分别达12.8\%和11.6\%,亟需建立科学的高温健康风险预警体系。
|
||||
|
||||
本研究的意义在于:(1)利用深度学习技术提升高温健康风险预测的精度和多时间尺度覆盖能力;(2)通过可视化大屏为政府和社区提供直观的决策支持工具;(3)为中原地区高温热浪健康防护提供科学依据。
|
||||
|
||||
\section{国内外研究现状}
|
||||
|
||||
\subsection{高温热浪健康效应研究}
|
||||
|
||||
温度与死亡率的关联通常呈J型或V型曲线,高温端的相对风险显著升高。Gasparrini等(2015)在Lancet发表的多国多城市研究系统揭示了温度-死亡关联的时空特征。Chen等(2018)在Lancet Planetary Health发表了中国多城市研究,为中国人群温度健康风险提供了本土化证据。
|
||||
|
||||
\subsection{环境健康预警系统研究}
|
||||
|
||||
国际上,多个国家已建立高温健康预警系统(HHWS),如法国国家高温预警计划、美国NOAA高温健康预警等。国内方面,中国气象局发布了高温预警信号体系,上海、深圳等城市开展了高温健康预警试点。
|
||||
|
||||
\subsection{多时间尺度预测方法}
|
||||
|
||||
传统的时间序列预测方法包括ARIMA、指数平滑等。随着深度学习的发展,LSTM等循环神经网络在时序预测中展现出优势。Vaswani等(2017)提出的Transformer架构中的自注意力机制能够有效捕捉长时间依赖关系。
|
||||
|
||||
\section{研究内容与技术路线}
|
||||
|
||||
本研究主要内容包括:
|
||||
\begin{enumerate}
|
||||
\item 多源数据获取与预处理:ERA5气象再分析数据、人口普查数据、卫生统计数据
|
||||
\item 多时间尺度预警模型构建:LSTM-Attention深度学习模型 + XGBoost基线模型
|
||||
\item 预警可视化系统开发:Flask后端 + ECharts前端大屏
|
||||
\item 模型评估与对比分析
|
||||
\end{enumerate}
|
||||
|
||||
\section{论文组织结构}
|
||||
|
||||
本论文共分七章。第一章介绍研究背景和现状;第二章阐述相关理论基础;第三章描述数据获取和预处理过程;第四章详细介绍预警模型设计;第五章展示可视化系统实现;第六章进行实验结果分析;第七章总结全文并展望未来工作。
|
||||
@@ -0,0 +1,47 @@
|
||||
\chapter{相关理论与技术基础}
|
||||
|
||||
\section{LSTM神经网络}
|
||||
|
||||
长短期记忆网络(Long Short-Term Memory,LSTM)是Hochreiter和Schmidhuber于1997年提出的一种特殊的循环神经网络(RNN)变体,旨在解决传统RNN在处理长序列数据时面临的梯度消失和梯度爆炸问题。
|
||||
|
||||
LSTM的核心思想是引入门控机制(gating mechanism),包括遗忘门(forget gate)、输入门(input gate)和输出门(output gate),通过这三个门的协同工作,LSTM能够选择性地记忆或遗忘信息,从而有效地捕捉时间序列中的长期依赖关系。
|
||||
|
||||
\subsection{LSTM单元结构}
|
||||
|
||||
LSTM单元通过细胞状态(cell state)和隐藏状态(hidden state)进行信息的传递与更新,其前向传播过程由以下公式描述:
|
||||
|
||||
遗忘门控制上一时刻细胞状态的遗忘程度,输入门决定当前输入信息中有多少写入细胞状态,输出门控制细胞状态对当前隐藏状态的输出比例。
|
||||
|
||||
\section{注意力机制}
|
||||
|
||||
注意力机制(Attention Mechanism)的核心思想源于人类视觉系统对信息的筛选性关注,即在处理大量输入信息时,能够动态地为不同部分分配不同的重要性权重。
|
||||
|
||||
Vaswani等人在2017年提出的Transformer架构中,将注意力机制推向了新的高度。多头自注意力机制(Multi-Head Self-Attention)允许模型从多个不同的表示子空间中联合关注序列中不同位置的信息,从而更全面地捕捉序列内部的复杂依赖关系。
|
||||
|
||||
\subsection{缩放点积注意力}
|
||||
|
||||
缩放点积注意力(Scaled Dot-Product Attention)是多头注意力的基础计算单元,其计算过程为:将查询(Query)和键(Key)进行点积运算,除以维度平方根进行缩放,经Softmax归一化后与值(Value)加权求和。
|
||||
|
||||
\subsection{多头自注意力}
|
||||
|
||||
多头自注意力将查询、键、值分别通过多个线性投影映射到不同的子空间,在每个子空间中独立计算注意力,最后将各头的输出拼接并线性变换,使得模型能够从多个角度捕捉输入序列的特征。
|
||||
|
||||
\section{XGBoost算法}
|
||||
|
||||
XGBoost(eXtreme Gradient Boosting)是Chen和Guestrin于2016年提出的梯度提升树算法的优化实现,在机器学习竞赛和工业应用中取得了巨大成功。
|
||||
|
||||
XGBoost的核心优势包括:(1)正则化的目标函数,有效防止过拟合;(2)二阶泰勒展开近似损失函数,提升收敛速度;(3)支持列采样和行采样,增强泛化能力;(4)内置交叉验证和早停机制;(5)支持并行化计算和分布式训练。
|
||||
|
||||
\section{高温热浪定义与健康风险}
|
||||
|
||||
世界气象组织(WMO)将高温热浪定义为日最高气温连续3天以上超过32℃的天气过程。中国气象局的定义为日最高气温达到或超过35℃且持续3天以上。
|
||||
|
||||
\subsection{健康风险等级划分}
|
||||
|
||||
参考相关研究和公共卫生实践,高温健康风险等级通常分为:低风险(注意)、中风险(关注)、高风险(警戒)、极高风险(紧急)四个等级,分别对应不同的防护措施和应急预案。
|
||||
|
||||
\section{Flask框架与ECharts可视化}
|
||||
|
||||
Flask是一个轻量级的Python Web框架,以其简洁性和灵活性著称,适合中小型Web应用的快速开发。本研究使用Flask作为后端服务框架,提供RESTful API接口。
|
||||
|
||||
ECharts是百度开源的基于JavaScript的数据可视化库,支持丰富的图表类型和高度的交互性,广泛应用于数据大屏和商业智能领域。本研究使用ECharts实现Web端的多维度可视化展示。
|
||||
@@ -0,0 +1,71 @@
|
||||
\chapter{数据获取与预处理}
|
||||
|
||||
\section{研究区域概况}
|
||||
|
||||
本研究选取焦作市和郑州市作为研究区域。两市位于河南省中部偏北,属于暖温带大陆性季风气候,夏季炎热多雨,冬季寒冷干燥,年平均气温约14-15℃,7月平均气温可达27-28℃,极端高温超过40℃。
|
||||
|
||||
焦作市总面积4071平方公里,常住人口约352万,其中65岁及以上人口占比约12.8\%。郑州市作为河南省省会,总面积7446平方公里,常住人口约1274万,老龄化率约11.6\%。两市的城镇化率均超过65\%,城市热岛效应与人口老龄化叠加,使得高温健康防护问题尤为突出。
|
||||
|
||||
\section{数据来源}
|
||||
|
||||
\subsection{ERA5-Land气象再分析数据}
|
||||
|
||||
ERA5-Land是欧洲中期天气预报中心(ECMWF)提供的全球陆地表面再分析数据集,空间分辨率为0.1°×0.1°(约9 km),时间分辨率最高为1小时。本研究通过Copernicus Climate Data Store (CDS) API获取2010-2024年间焦作市和郑州市的网格点气象数据。
|
||||
|
||||
获取的气象变量包括:
|
||||
\begin{itemize}
|
||||
\item 2m温度(2m temperature)
|
||||
\item 2m露点温度(2m dewpoint temperature)
|
||||
\item 地表气压(surface pressure)
|
||||
\item 10m风速U分量和V分量
|
||||
\item 总降水量(total precipitation)
|
||||
\item 地表太阳辐射(surface solar radiation downwards)
|
||||
\end{itemize}
|
||||
|
||||
\subsection{人口与健康数据}
|
||||
|
||||
人口数据来源于第七次全国人口普查公报(2020年),包括分年龄段人口结构、老龄化率等基础指标。健康统计数据来源于河南省卫生健康统计年鉴,包括各月死亡人数、门急诊就诊人次等。
|
||||
|
||||
\subsection{高温预警与极端天气历史记录}
|
||||
|
||||
收集焦作市和郑州市2010-2024年高温预警发布记录和极端天气事件记录,用于标注和验证模型预警的准确性。
|
||||
|
||||
\section{数据预处理}
|
||||
|
||||
\subsection{时间分辨率统一}
|
||||
|
||||
原始ERA5-Land数据为小时级别,需将其聚合为日尺度数据。对于温度变量,计算日最大值、最小值和平均值;对于降水量、太阳辐射等累积变量,计算日总量。
|
||||
|
||||
\subsection{缺失值处理}
|
||||
|
||||
由于CDS API下载过程中可能产生网络中断导致部分时段数据缺失,采用线性插值和前后日平均值填充相结合的方法处理缺失值。若连续缺失超过30天,则使用历史同期多年平均值进行填充。
|
||||
|
||||
\subsection{异常值检测}
|
||||
|
||||
对温度数据中的异常值进行检测和修正。温度超出历史同期均值±3倍标准差范围的被视为异常值,采用前后值线性插值修正。
|
||||
|
||||
\subsection{特征工程}
|
||||
|
||||
在基础气象变量的基础上,构建以下衍生特征:
|
||||
\begin{itemize}
|
||||
\item 热浪指数:日最高温度连续超过阈值(32℃/35℃)的天数
|
||||
\item 昼夜温差:日最高温度与日最低温度之差
|
||||
\item 连续高温天数:日最高温度超过35℃的连续天数
|
||||
\item 湿热指数:结合温度和湿度计算的体感温度
|
||||
\item 季节编码:月份的正弦/余弦编码
|
||||
\item 滞后特征:前1天、前3天、前7天的温度值
|
||||
\end{itemize}
|
||||
|
||||
\section{数据集构建}
|
||||
|
||||
\subsection{样本构造}
|
||||
|
||||
采用滑动窗口方法构造监督学习样本。以历史N天的气象特征序列为输入,以未来T天的健康风险等级为目标变量。分别构建短期(输入7天,输出1-3天)、中期(输入30天,输出7天)和长期(输入90天,输出30天)三个时间尺度的数据集。
|
||||
|
||||
\subsection{训练集与测试集划分}
|
||||
|
||||
采用时间序列划分方法,使用2010-2019年数据作为训练集,2020-2022年数据作为验证集,2023-2024年数据作为测试集,以模拟真实预测场景。
|
||||
|
||||
\subsection{数据归一化}
|
||||
|
||||
对所有数值型特征采用Z-score标准化(均值为0,标准差为1),标准化参数基于训练集计算并应用于验证集和测试集。
|
||||
@@ -0,0 +1,62 @@
|
||||
\chapter{多时间尺度预警模型设计}
|
||||
|
||||
\section{模型总体架构}
|
||||
|
||||
本研究设计了基于LSTM-Attention的多时间尺度高温健康风险预警模型,整体架构包括四个主要模块:输入层(多维气象特征序列)、LSTM编码层(时序特征提取)、多头自注意力层(关键时间步加权)和输出层(多时间尺度风险预测)。
|
||||
|
||||
\section{LSTM编码层}
|
||||
|
||||
\subsection{时序特征提取}
|
||||
|
||||
LSTM编码层接收经过标准化的多维气象特征序列,通过两层堆叠的LSTM网络逐步提取时序中的高级特征表示。第一层LSTM以50个隐藏单元对输入序列进行初步编码,第二层LSTM以50个隐藏单元对第一层的输出进行更深层次的时序模式挖掘。
|
||||
|
||||
\subsection{Dropout正则化}
|
||||
|
||||
在每层LSTM之后加入Dropout层,丢弃概率设为0.3,以防止模型在训练集上过拟合。
|
||||
|
||||
\section{多头自注意力层}
|
||||
|
||||
\subsection{注意力计算}
|
||||
|
||||
在LSTM编码器的输出之上,应用多头自注意力机制(head=4),使模型能够自动学习输入序列中不同时间步对预测目标的重要性权重。通过注意力机制,模型可以重点关注高温连续天数、温度突变点等对健康风险影响较大的关键时段。
|
||||
|
||||
\subsection{残差连接与层归一化}
|
||||
|
||||
参照Transformer架构,在多头注意力子层后加入残差连接和层归一化,以加速训练收敛并提升模型稳定性。
|
||||
|
||||
\section{多任务输出层}
|
||||
|
||||
考虑到短期、中期和长期预警任务之间的关联性,输出层采用多任务学习(Multi-Task Learning)架构,共享LSTM编码层和注意力层的特征表示,通过三个独立的全连接头分别输出不同时间尺度的风险等级预测。
|
||||
|
||||
每个输出头包括两个全连接层:第一层将注意力池化后的特征映射到32维,第二层输出目标时间尺度的预测结果。
|
||||
|
||||
\section{损失函数与优化器}
|
||||
|
||||
\subsection{损失函数}
|
||||
|
||||
对于多分类风险等级预测任务,采用交叉熵损失函数(Cross-Entropy Loss)。三个任务的损失按相等权重加权求和,总损失定义为:
|
||||
|
||||
\[
|
||||
\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{short}} + \mathcal{L}_{\text{medium}} + \mathcal{L}_{\text{long}}
|
||||
\]
|
||||
|
||||
\subsection{优化器与学习率策略}
|
||||
|
||||
使用Adam优化器,初始学习率设为0.001。训练过程中采用ReduceLROnPlateau学习率衰减策略,当验证损失连续10个epoch未下降时,学习率减半。同时设置早停(Early Stopping)策略,验证损失连续25个epoch未下降时终止训练。
|
||||
|
||||
\section{基线模型:XGBoost}
|
||||
|
||||
为评估深度学习方法的有效性,选用XGBoost作为基线模型进行对比实验。XGBoost输入为展平后的特征向量(所有时间步特征拼接),输出与LSTM-Attention模型保持一致。
|
||||
|
||||
XGBoost的关键超参数包括:树的数量(n\_estimators=200)、最大深度(max\_depth=6)、学习率(learning\_rate=0.1)和子采样率(subsample=0.8),通过5折交叉验证在训练集上选择最优超参数。
|
||||
|
||||
\section{评估指标}
|
||||
|
||||
采用以下指标评估模型性能:
|
||||
\begin{itemize}
|
||||
\item 准确率(Accuracy):预测正确的样本占总样本的比例
|
||||
\item 精确率(Precision):被预测为某风险等级的样本中真正属于该等级的比例
|
||||
\item 召回率(Recall):某风险等级的样本中被正确预测的比例
|
||||
\item F1分数(F1-Score):精确率与召回率的调和平均
|
||||
\item 宏平均(Macro Average):各类别指标的算术平均,适用于类别不均衡场景
|
||||
\end{itemize}
|
||||
@@ -0,0 +1,75 @@
|
||||
\chapter{预警可视化系统设计与实现}
|
||||
|
||||
\section{系统需求分析}
|
||||
|
||||
\subsection{功能需求}
|
||||
|
||||
高温健康风险预警可视化系统的主要功能需求包括:实时气象数据展示、温度变化趋势分析、多时间尺度风险预警展示、人口与健康数据概览、历史数据查询和预警发布管理。
|
||||
|
||||
\subsection{非功能需求}
|
||||
|
||||
系统应具备以下非功能特性:(1)响应式布局,适配不同尺寸的显示设备;(2)数据更新延迟不超过5分钟;(3)可视化渲染流畅,页面加载时间不超过3秒;(4)界面采用深色科技蓝风格,符合数据大屏的视觉规范。
|
||||
|
||||
\section{系统架构设计}
|
||||
|
||||
系统采用B/S(Browser/Server)架构,分为三层:
|
||||
|
||||
\begin{itemize}
|
||||
\item \textbf{数据层}:负责气象数据、人口数据和模型预测结果的存储与管理
|
||||
\item \textbf{服务层}:基于Flask框架的Web后端,提供RESTful API,包括数据查询、模型推理和预警推送
|
||||
\item \textbf{展示层}:基于HTML+CSS+JavaScript的Web前端,使用ECharts进行数据可视化
|
||||
\end{itemize}
|
||||
|
||||
\section{后端实现}
|
||||
|
||||
\subsection{Flask应用结构}
|
||||
|
||||
Flask应用采用蓝图(Blueprint)模块化组织,主要模块包括:
|
||||
\begin{itemize}
|
||||
\item \texttt{api/data}:气象和人口数据接口
|
||||
\item \texttt{api/predict}:模型预测与预警接口
|
||||
\item \texttt{api/history}:历史数据查询接口
|
||||
\end{itemize}
|
||||
|
||||
\subsection{数据接口设计}
|
||||
|
||||
API采用JSON格式进行数据交互,统一响应格式为:
|
||||
\begin{verbatim}
|
||||
{
|
||||
"code": 200,
|
||||
"message": "success",
|
||||
"data": { ... }
|
||||
}
|
||||
\end{verbatim}
|
||||
|
||||
主要API端点包括:获取当前温度数据、获取温度历史趋势、获取风险预警等级、获取人口统计数据、获取高温天数统计等。
|
||||
|
||||
\subsection{模型部署与推理}
|
||||
|
||||
训练完成的PyTorch模型导出为TorchScript格式,在Flask应用启动时加载。推理请求到达时,将输入数据预处理后传入模型,获取预测结果并返回前端。
|
||||
|
||||
\section{前端实现}
|
||||
|
||||
\subsection{页面布局}
|
||||
|
||||
可视化大屏采用典型的4+1布局方案:上方为标题栏,中部左侧为温度变化趋势图,中部右侧为风险预警面板,下部左侧为人口数据概览,下部右侧为高温天数统计,中央区域展示关键预警信息。
|
||||
|
||||
\subsection{图表设计}
|
||||
|
||||
使用ECharts实现以下主要图表:
|
||||
\begin{itemize}
|
||||
\item 温度变化折线图:展示日最高/最低/平均温度的时序变化
|
||||
\item 风险等级仪表盘:以仪表盘形式展示当前风险等级
|
||||
\item 预警时间轴:以时间轴形式展示未来预警信息
|
||||
\item 人口结构饼图:展示老龄化人口分布
|
||||
\item 高温天数柱状图:展示每月高温天数统计
|
||||
\item 热力图:展示温度与健康风险的关联模式
|
||||
\end{itemize}
|
||||
|
||||
\subsection{深色科技蓝风格实现}
|
||||
|
||||
配色方案以深蓝色(\#0a1628)为背景主色调,辅以青蓝色(\#00d4ff)、亮蓝色(\#1e90ff)和渐变色作为数据可视化配色。图表采用半透明深色容器、发光边框和毛玻璃效果,营造科技感和专业感。
|
||||
|
||||
\section{系统部署}
|
||||
|
||||
系统通过Gunicorn作为WSGI服务器进行生产环境部署,绑定端口为5005。前端静态文件由Flask直接托管,无需额外配置Nginx。系统启动后可通过浏览器直接访问\url{http://localhost:5005}。
|
||||
@@ -0,0 +1,59 @@
|
||||
\chapter{实验结果与分析}
|
||||
|
||||
\section{实验环境}
|
||||
|
||||
本研究的实验环境配置如下:
|
||||
\begin{itemize}
|
||||
\item 操作系统:Windows 11
|
||||
\item 编程语言:Python 3.13
|
||||
\item 深度学习框架:PyTorch
|
||||
\item CPU:Intel Core i7
|
||||
\item 内存:32 GB
|
||||
\item 训练设备:CPU(适用于中等规模时序数据)
|
||||
\end{itemize}
|
||||
|
||||
\section{模型训练过程}
|
||||
|
||||
\subsection{训练曲线分析}
|
||||
|
||||
LSTM-Attention模型在训练集和验证集上的损失曲线显示,模型在训练初期(前20个epoch)损失快速下降,之后逐步收敛。验证集损失在约60个epoch后趋于稳定,未出现明显的过拟合现象,证明Dropout和早停策略有效。
|
||||
|
||||
\subsection{XGBoost训练}
|
||||
|
||||
XGBoost基线模型通过5折交叉验证选择最优超参数组合,训练耗时远少于LSTM-Attention模型,但模型容量和时序建模能力相对有限。
|
||||
|
||||
\section{模型性能对比}
|
||||
|
||||
\subsection{短期预警性能(1-3天)}
|
||||
|
||||
在短期预警任务上,LSTM-Attention模型的准确率、精确率、召回率和F1分数均优于XGBoost基线模型,证明了深度学习在捕获短期时序模式方面的优势。
|
||||
|
||||
\subsection{中期预警性能(7天)}
|
||||
|
||||
中期预警任务对模型的长期依赖建模能力要求更高。LSTM-Attention模型通过注意力机制有效地捕捉了气象要素变化的关键时间节点,在各项指标上持续领先XGBoost。
|
||||
|
||||
\subsection{长期预警性能(30天)}
|
||||
|
||||
长期预警任务是所有时间尺度中最具挑战性的。由于30天的时间跨度较大,气象要素的预测不确定性显著增加。在此任务上,LSTM-Attention与XGBoost的性能差距有所缩小,但前者仍保持一定的优势。
|
||||
|
||||
\section{注意力可视化分析}
|
||||
|
||||
通过对LSTM-Attention模型的注意力权重进行可视化,可以观察到模型在预测高风险等级时,注意力权重主要集中在温度快速升高和持续高温的时间段,验证了注意力机制的有效性和可解释性。
|
||||
|
||||
\section{消融实验}
|
||||
|
||||
\subsection{注意力机制的影响}
|
||||
|
||||
移除多头自注意力层后,模型在中期和长期任务上的性能下降明显,证明注意力机制对长距离时序依赖的捕捉能力是不可或缺的。
|
||||
|
||||
\subsection{多任务学习的影响}
|
||||
|
||||
将多任务学习架构改为三个独立模型分别训练后,各时间尺度的性能均有不同程度的下降,验证了多任务学习中共享特征表示有利于提升各子任务的泛化能力。
|
||||
|
||||
\section{与已有研究对比}
|
||||
|
||||
将本研究的结果与已有文献中报告的性能进行对比分析。由于研究区域、数据来源和任务定义的差异,直接的数值对比意义有限,但在方法和预警理念上,本研究具有一定的创新性和应用价值。
|
||||
|
||||
\section{系统可视化效果}
|
||||
|
||||
可视化大屏系统运行效果良好,各图表渲染流畅,数据更新及时,界面美观大方。深色科技蓝风格的配色方案和毛玻璃效果得到了测试用户的认可。
|
||||
@@ -0,0 +1,45 @@
|
||||
\chapter{总结与展望}
|
||||
|
||||
\section{工作总结}
|
||||
|
||||
本研究以焦作市和郑州市为研究区域,针对银发群体高温健康风险预警问题,开展了多时间尺度预警模型构建和可视化系统开发工作,取得了以下主要成果:
|
||||
|
||||
\begin{enumerate}
|
||||
\item \textbf{构建了多源数据集}:获取并预处理了2010-2024年焦作、郑州两市的ERA5-Land气象再分析数据,结合人口普查和卫生统计数据,构建了温度-健康风险关联数据集,为后续模型训练提供了数据基础。
|
||||
|
||||
\item \textbf{设计了LSTM-Attention预警模型}:结合LSTM的时序特征提取能力和多头自注意力机制的关键时间步加权能力,构建了多时间尺度(短期/中期/长期)健康风险预警模型。实验结果表明,该模型在短期和中期预警任务上优于XGBoost等传统机器学习方法。
|
||||
|
||||
\item \textbf{实现了可视化大屏系统}:基于Flask和ECharts开发了深色科技蓝风格的Web可视化大屏,实现了温度趋势、风险等级、人口数据和高温统计等多维度的直观展示,为决策者提供了便捷的信息获取渠道。
|
||||
|
||||
\item \textbf{验证了注意力机制的有效性}:通过注意力权重可视化和消融实验,证明了注意力机制在提升模型性能和可解释性方面的积极作用。
|
||||
\end{enumerate}
|
||||
|
||||
\section{研究不足}
|
||||
|
||||
本研究存在以下不足和局限性:
|
||||
|
||||
\begin{enumerate}
|
||||
\item \textbf{数据粒度限制}:ERA5-Land数据的空间分辨率为0.1°(约9 km),无法捕捉城市内部的微气候差异,对精细化的社区级预警支持有限。
|
||||
|
||||
\item \textbf{健康数据的间接性}:受限于数据可获取性,本研究的健康风险数据主要来源于宏观统计年鉴,缺乏个体级别的健康记录数据,风险标注的精细度有待提升。
|
||||
|
||||
\item \textbf{模型局限性}:LSTM-Attention模型在长期(30天)预测任务上的性能仍有较大提升空间,长期气象预测本质上具有较强的混沌性和不确定性。
|
||||
|
||||
\item \textbf{系统功能待完善}:当前可视化系统主要侧重于数据展示和预警呈现,尚未集成预警自动推送、多级联动响应等高级功能。
|
||||
\end{enumerate}
|
||||
|
||||
\section{未来展望}
|
||||
|
||||
基于本研究的成果和不足,未来可以从以下方向继续深入:
|
||||
|
||||
\begin{enumerate}
|
||||
\item \textbf{引入更高分辨率数据}:结合地面气象观测站数据和卫星遥感数据,提升数据空间分辨率,支持更精细的城市内部风险评估。
|
||||
|
||||
\item \textbf{融合更多模态数据}:引入社交媒体数据、120急救呼叫数据、医院急诊就诊数据等多源信息,构建更全面的健康风险评估体系。
|
||||
|
||||
\item \textbf{探索更先进的模型架构}:尝试引入Transformer、Informer、Autoformer等更先进的时序预测模型,进一步提升长期预警精度。
|
||||
|
||||
\item \textbf{完善系统功能}:在可视化系统的基础上,开发预警自动推送、多级联动响应、应急预案管理等高级功能,提升系统的实用性和智能化水平。
|
||||
|
||||
\item \textbf{扩展研究区域}:将研究方法和系统推广至河南省其他城市乃至全国范围,为更广泛的老年群体提供高温健康防护服务。
|
||||
\end{enumerate}
|
||||
Binary file not shown.
+115
@@ -0,0 +1,115 @@
|
||||
%!TEX program = xelatex
|
||||
\documentclass[12pt,a4paper,openany]{ctexbook}
|
||||
|
||||
% --- 页面设置 ---
|
||||
\usepackage[top=2.5cm,bottom=2.5cm,left=3cm,right=2.5cm]{geometry}
|
||||
\usepackage{setspace}
|
||||
\onehalfspacing
|
||||
|
||||
% --- 字体 ---
|
||||
\setCJKmainfont{Songti SC}[AutoFakeBold=2]
|
||||
\setCJKsansfont{Heiti SC}
|
||||
\setCJKmonofont{STFangsong}
|
||||
|
||||
% --- 图表 ---
|
||||
\usepackage{graphicx}
|
||||
\usepackage{float}
|
||||
\usepackage{subcaption}
|
||||
\usepackage{booktabs}
|
||||
\usepackage{longtable}
|
||||
|
||||
% --- 参考文献 (GB/T 7714) ---
|
||||
\usepackage[backend=biber,style=gb7714-2015]{biblatex}
|
||||
\addbibresource{refs.bib}
|
||||
|
||||
% --- 超链接 ---
|
||||
\usepackage[hidelinks]{hyperref}
|
||||
|
||||
% --- 数学 ---
|
||||
\usepackage{amsmath,amssymb}
|
||||
|
||||
% --- 代码 ---
|
||||
\usepackage{listings}
|
||||
\lstset{
|
||||
basicstyle=\small\ttfamily,
|
||||
breaklines=true,
|
||||
frame=single,
|
||||
numbers=left,
|
||||
numberstyle=\tiny,
|
||||
}
|
||||
|
||||
% --- 其他 ---
|
||||
\usepackage{tikz}
|
||||
\usepackage{caption}
|
||||
\captionsetup{font=small,labelfont=bf}
|
||||
|
||||
\title{银发群体高温多时间尺度预警和服务优化可视化研究}
|
||||
\author{刘航宇}
|
||||
\date{\today}
|
||||
|
||||
\begin{document}
|
||||
|
||||
% --- 封面 ---
|
||||
\begin{center}
|
||||
\vspace*{3cm}
|
||||
{\large\bfseries 本科毕业论文}\\[1cm]
|
||||
{\LARGE\bfseries 银发群体高温多时间尺度预警\\[0.3cm]和服务优化可视化研究}\\[2cm]
|
||||
{\large 学\hspace{2em}院:计算机科学与技术学院}\\[0.5cm]
|
||||
{\large 专\hspace{2em}业:计算机科学与技术}\\[0.5cm]
|
||||
{\large 姓\hspace{2em}名:刘航宇}\\[0.5cm]
|
||||
{\large 学\hspace{2em}号:}\\[0.5cm]
|
||||
{\large 指导教师:}\\[2cm]
|
||||
{\large \today}
|
||||
\end{center}
|
||||
\thispagestyle{empty}
|
||||
\newpage
|
||||
|
||||
% --- 摘要 ---
|
||||
\input{chapters/abstract}
|
||||
|
||||
% --- 目录 ---
|
||||
\tableofcontents
|
||||
\newpage
|
||||
|
||||
% --- 正文 ---
|
||||
\input{chapters/ch1-intro}
|
||||
\input{chapters/ch2-theory}
|
||||
\input{chapters/ch3-data}
|
||||
\input{chapters/ch4-model}
|
||||
\input{chapters/ch5-system}
|
||||
\input{chapters/ch6-results}
|
||||
\input{chapters/ch7-conclusion}
|
||||
|
||||
% --- 参考文献 ---
|
||||
\printbibliography[title=参考文献]
|
||||
|
||||
% --- 致谢 ---
|
||||
\chapter*{致谢}
|
||||
\addcontentsline{toc}{chapter}{致谢}
|
||||
|
||||
衷心感谢导师在选题、研究方法、论文撰写等方面给予的悉心指导和宝贵建议。
|
||||
|
||||
感谢河南理工大学计算机科学与技术学院四年来提供的学习平台和科研环境。
|
||||
|
||||
感谢家人和朋友在学业期间的理解、支持与鼓励。
|
||||
|
||||
% --- 附录 ---
|
||||
\appendix
|
||||
\chapter{核心代码清单}
|
||||
本文核心代码已开源,完整项目结构及运行说明见附录B。
|
||||
|
||||
\chapter{系统运行说明}
|
||||
\section{环境配置}
|
||||
本项目使用 Python 3.13,依赖管理使用 uv。主要依赖包括 PyTorch、XGBoost、Flask、ECharts 等。
|
||||
|
||||
\section{运行步骤}
|
||||
\begin{enumerate}
|
||||
\item 安装依赖:\texttt{uv pip install -e .}
|
||||
\item 数据获取:\texttt{python -m src.data.download\_era5}
|
||||
\item 数据预处理:\texttt{python -m src.data.preprocess}
|
||||
\item 模型训练:\texttt{python -m src.models.train}
|
||||
\item 启动可视化:\texttt{python -m src.web.app}
|
||||
\item 浏览器访问:\texttt{http://localhost:5005}
|
||||
\end{enumerate}
|
||||
|
||||
\end{document}
|
||||
@@ -0,0 +1,88 @@
|
||||
@article{chen2018heat,
|
||||
author = {Chen, R. and Yin, P. and Wang, L. and et al.},
|
||||
title = {Association between ambient temperature and mortality risk and burden in China},
|
||||
journal = {The Lancet Planetary Health},
|
||||
year = {2018},
|
||||
volume = {2},
|
||||
number = {8},
|
||||
pages = {e344--e352},
|
||||
}
|
||||
|
||||
@article{gasparrini2015mortality,
|
||||
author = {Gasparrini, A. and Guo, Y. and Hashizume, M. and et al.},
|
||||
title = {Mortality risk attributable to high and low ambient temperature},
|
||||
journal = {The Lancet},
|
||||
year = {2015},
|
||||
volume = {386},
|
||||
pages = {369--375},
|
||||
}
|
||||
|
||||
@article{ma2015heat,
|
||||
author = {Ma, W. and Chen, R. and Kan, H.},
|
||||
title = {Temperature-related mortality in 17 large Chinese cities},
|
||||
journal = {Environmental Health Perspectives},
|
||||
year = {2015},
|
||||
volume = {123},
|
||||
number = {10},
|
||||
pages = {989--994},
|
||||
}
|
||||
|
||||
@article{hochreiter1997lstm,
|
||||
author = {Hochreiter, S. and Schmidhuber, J.},
|
||||
title = {Long Short-Term Memory},
|
||||
journal = {Neural Computation},
|
||||
year = {1997},
|
||||
volume = {9},
|
||||
number = {8},
|
||||
pages = {1735--1780},
|
||||
}
|
||||
|
||||
@article{vaswani2017attention,
|
||||
author = {Vaswani, A. and Shazeer, N. and Parmar, N. and et al.},
|
||||
title = {Attention Is All You Need},
|
||||
journal = {Advances in Neural Information Processing Systems},
|
||||
year = {2017},
|
||||
volume = {30},
|
||||
}
|
||||
|
||||
@inproceedings{chen2016xgboost,
|
||||
author = {Chen, T. and Guestrin, C.},
|
||||
title = {XGBoost: A Scalable Tree Boosting System},
|
||||
booktitle = {Proceedings of the 22nd ACM SIGKDD},
|
||||
year = {2016},
|
||||
pages = {785--794},
|
||||
}
|
||||
|
||||
@misc{era5land,
|
||||
author = {{Copernicus Climate Change Service}},
|
||||
title = {ERA5-Land hourly data from 1950 to present},
|
||||
year = {2024},
|
||||
howpublished = {\url{https://cds.climate.copernicus.eu/}},
|
||||
}
|
||||
|
||||
@misc{china_census2020,
|
||||
author = {{国家统计局}},
|
||||
title = {第七次全国人口普查公报},
|
||||
year = {2021},
|
||||
howpublished = {\url{https://www.stats.gov.cn/}},
|
||||
}
|
||||
|
||||
@article{anderson2013heat,
|
||||
author = {Anderson, G. B. and Bell, M. L.},
|
||||
title = {Heat Waves in the United States: Mortality Risk during Heat Waves and Effect Modification by Heat Wave Characteristics},
|
||||
journal = {Environmental Health Perspectives},
|
||||
year = {2011},
|
||||
volume = {119},
|
||||
number = {2},
|
||||
pages = {210--218},
|
||||
}
|
||||
|
||||
@article{guo2017heat,
|
||||
author = {Guo, Y. and Gasparrini, A. and Armstrong, B. G. and et al.},
|
||||
title = {Heat Wave and Mortality: A Multicountry, Multicommunity Study},
|
||||
journal = {Environmental Health Perspectives},
|
||||
year = {2017},
|
||||
volume = {125},
|
||||
number = {8},
|
||||
pages = {087006},
|
||||
}
|
||||
Reference in New Issue
Block a user