feat: 完成模型训练/评估/Web大屏/LaTeX论文框架

- LSTM-Attention模型(983K参数) + XGBoost基线
- Flask API后端(4端点) + ECharts可视化大屏(6面板)
- LaTeX学位论文完整框架(7章+参考文献)
- ERA5下载脚本(CDS逐月并行下载)
- README项目文档

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-26 21:01:42 +08:00
parent eeab4d1330
commit 07468266b4
19 changed files with 2730 additions and 69 deletions
+51 -69
View File
@@ -1,32 +1,19 @@
"""从 Copernicus CDS 下载 ERA5-Land 再分析数据"""
"""从 Copernicus CDS 下载 ERA5-Land 再分析数据(逐月,支持并行)"""
import logging
import time
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
import cdsapi
from src.utils.config import (
CITIES,
DATA_RAW,
ERA5_END_YEAR,
ERA5_START_YEAR,
ERA5_VARIABLES,
CITIES, DATA_RAW, ERA5_START_YEAR, ERA5_END_YEAR, ERA5_VARIABLES,
)
logger = logging.getLogger(__name__)
def build_request(city: str, year: int, month: int) -> dict:
"""构建 CDS API 请求参数,提取城市周围 0.5 度区域
Args:
city: 城市键名("jiaozuo""zhengzhou"
year: 年份
month: 月份(1-12),0 表示全年所有月份
Returns:
CDS API 请求参数字典
"""
lat = CITIES[city]["lat"]
lon = CITIES[city]["lon"]
return {
@@ -34,67 +21,62 @@ def build_request(city: str, year: int, month: int) -> dict:
"format": "netcdf",
"variable": ERA5_VARIABLES,
"year": [str(year)],
"month": [f"{m:02d}" for m in (range(1, 13) if month == 0 else [month])],
"month": [f"{month:02d}"],
"day": [f"{d:02d}" for d in range(1, 32)],
"time": [f"{h:02d}:00" for h in range(24)],
"area": [lat + 0.5, lon - 0.5, lat - 0.5, lon + 0.5], # [N, W, S, E]
"time": [f"{h:02d}:00" for h in [0, 6, 12, 18]],
"area": [lat + 0.5, lon - 0.5, lat - 0.5, lon + 0.5],
}
def download_era5_city(
city: str,
start_year: int = ERA5_START_YEAR,
end_year: int = ERA5_END_YEAR,
max_retries: int = 3,
retry_delay: int = 30,
) -> None:
"""逐月下载指定城市的 ERA5-Land 数据,避免单次请求过大超时
Args:
city: 城市键名
start_year: 起始年份
end_year: 结束年份
max_retries: 失败重试次数
retry_delay: 重试等待秒数
"""
def download_one_month(city: str, year: int, month: int) -> bool:
"""下载单月数据,返回 True 表示成功"""
client = cdsapi.Client()
out_dir = Path(DATA_RAW) / "era5" / city
out_dir.mkdir(parents=True, exist_ok=True)
out_path = out_dir / f"era5_{city}_{year}_{month:02d}.nc"
for year in range(start_year, end_year + 1):
for month in range(1, 13):
out_path = out_dir / f"era5_{city}_{year}_{month:02d}.nc"
if out_path.exists():
logger.info("跳过已存在: %s", out_path)
continue
if out_path.exists():
return True # 已存在,跳过
request = build_request(city, year, month)
for attempt in range(1, max_retries + 1):
try:
logger.info(
"正在下载 %s %d-%02d (第 %d/%d 次尝试)...",
city, year, month, attempt, max_retries,
)
client.retrieve(
"reanalysis-era5-land",
request,
str(out_path),
)
logger.info("下载完成: %s", out_path)
break
except Exception:
logger.exception(
"下载失败 %s %d-%02d (第 %d/%d 次)",
city, year, month, attempt, max_retries,
)
if attempt < max_retries:
logger.info("等待 %d 秒后重试...", retry_delay)
time.sleep(retry_delay)
else:
logger.error(
"下载彻底失败 %s %d-%02d,已达最大重试次数",
city, year, month,
)
request = build_request(city, year, month)
for attempt in range(1, 4):
try:
client.retrieve("reanalysis-era5-land", request, str(out_path))
return True
except Exception:
if attempt < 3:
time.sleep(30)
return False
def download_city(city: str, start_year: int = ERA5_START_YEAR,
end_year: int = ERA5_END_YEAR, max_workers: int = 3):
"""并行下载(3线程),兼顾速度和 CDS 限流"""
name = CITIES[city]["name"]
tasks = [(city, y, m) for y in range(start_year, end_year + 1) for m in range(1, 13)]
total = len(tasks)
done = 0
fail = 0
# 先统计已存在的
existed = sum(1 for _, y, m in tasks
if (Path(DATA_RAW) / "era5" / city / f"era5_{city}_{y}_{m:02d}.nc").exists())
if existed > 0:
logger.info("%s: %d/%d 已存在,跳过", name, existed, total)
done = existed
with ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = {pool.submit(download_one_month, c, y, m): (y, m)
for c, y, m in tasks if not (Path(DATA_RAW) / "era5" / city
/ f"era5_{c}_{y}_{m:02d}.nc").exists()}
for f in as_completed(futures):
y, m = futures[f]
if f.result():
done += 1
else:
fail += 1
if (done + fail) % 10 == 0 or (done + fail) == (total - existed):
logger.info("%s: %d/%d 完成 (%d 失败)", name, done + existed, total, fail)
if __name__ == "__main__":
@@ -103,4 +85,4 @@ if __name__ == "__main__":
format="%(asctime)s [%(levelname)s] %(message)s",
)
for city_name in CITIES:
download_era5_city(city_name)
download_city(city_name)
+453
View File
@@ -0,0 +1,453 @@
"""模型评估与对比 — LSTM vs XGBoost 多时间尺度高温风险预测
功能:
1. 加载测试数据 (与训练相同的 70/15/15 时间序划分)
2. 加载已训练的 LSTM 模型,获取/计算预测结果
3. 训练 XGBoost 基线并评估
4. 生成混淆矩阵对比图和指标对比柱状图
5. 打印格式化的模型对比表
"""
import sys
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from src.models.lstm_attention import HeatRiskPredictor
from src.models.xgboost_baseline import train_xgboost_baseline
from src.utils.config import DATA_PROCESSED, OUTPUT_MODELS, OUTPUT_FIGURES
# ============================================================================
# 全局常量
# ============================================================================
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False
RISK_LABELS = ["", "", "", "严重"]
HORIZON_KEYS = ["short", "medium", "long"]
HORIZON_CN = ["短期", "中期", "长期"]
# ============================================================================
# 数据加载
# ============================================================================
def _load_all_data() -> tuple[np.ndarray, np.ndarray] | tuple[None, None]:
"""加载焦作和郑州的序列 NPZ 文件并沿样本轴拼接。
Returns:
(X, y) 或 (None, None) 当数据文件不存在时。
X: (N, T, D) float32 特征数组
y: (N, 3) int64 标签数组,列顺序: short, medium, long
"""
cities = ["jiaozuo", "zhengzhou"]
X_parts: list[np.ndarray] = []
y_parts: list[np.ndarray] = []
for city in cities:
path = DATA_PROCESSED / f"{city}_sequences.npz"
if path.exists():
data = np.load(path)
X_parts.append(data["X"])
y_parts.append(data["y"])
print(f"已加载 {city}: X {data['X'].shape}, y {data['y'].shape}")
else:
print(f"警告: {path} 不存在,跳过该城市")
if not X_parts:
print("错误: 未找到任何序列数据文件,无法进行评估。", file=sys.stderr)
print(f"请确保 {DATA_PROCESSED / 'jiaozuo_sequences.npz'} "
f"{DATA_PROCESSED / 'zhengzhou_sequences.npz'} 存在。", file=sys.stderr)
return None, None
X = np.concatenate(X_parts, axis=0)
y = np.concatenate(y_parts, axis=0)
print(f"合并后数据: X {X.shape}, y {y.shape}")
return X, y
def load_test_data() -> tuple[np.ndarray, np.ndarray] | tuple[None, None]:
"""加载测试集数据 (与训练相同的 70/15/15 时间序划分)。
Returns:
(X_test, y_test) 或 (None, None) 当数据不可用时。
X_test: (N_test, T, D) float32
y_test: (N_test, 3) int64,列顺序: short, medium, long
"""
X, y = _load_all_data()
if X is None:
return None, None
n_total = len(X)
n_train = int(n_total * 0.70)
n_val = int(n_total * 0.15)
X_test = X[n_train + n_val:]
y_test = y[n_train + n_val:]
print(f"测试集: X_test {X_test.shape}, y_test {y_test.shape}")
return X_test, y_test
# ============================================================================
# LSTM 评估
# ============================================================================
def evaluate_lstm(
X_test: np.ndarray,
y_test: np.ndarray,
) -> tuple[dict[str, np.ndarray], np.ndarray] | tuple[None, None]:
"""加载已训练的 HeatRiskPredictor 并获取测试集预测。
优先从 test_predictions.npz 加载已保存的预测结果;
若文件不存在,则加载模型权重并重新推理。
Args:
X_test: (N, T, D) 测试集特征
y_test: (N, 3) 测试集标签
Returns:
(predictions_dict, labels_array) 或 (None, None) 当模型不可用时。
predictions_dict: {"short": (N,) int, "medium": (N,) int, "long": (N,) int}
labels_array: (N, 3) int64,与 y_test 相同
"""
pred_path = OUTPUT_MODELS / "test_predictions.npz"
model_path = OUTPUT_MODELS / "best_model.pt"
# ---- 路径 1: 从已保存的预测文件加载 ----
if pred_path.exists():
print(f"{pred_path} 加载 LSTM 预测结果...")
data = np.load(pred_path)
# 兼容 train.py 的 "y_true" 和任务约定的 "labels" 两种键名
labels_key = "labels" if "labels" in data else "y_true"
predictions = {
"short": data["short"],
"medium": data["medium"],
"long": data["long"],
}
labels = data[labels_key]
# 验证样本数匹配
if len(predictions["short"]) != len(y_test):
print(f"警告: 已保存预测样本数 ({len(predictions['short'])}) "
f"与测试集样本数 ({len(y_test)}) 不一致,将重新推理。")
else:
print(f"已加载 LSTM 预测: {len(predictions['short'])} 样本")
return predictions, labels
# ---- 路径 2: 重新推理 ----
if not model_path.exists():
print(f"错误: 模型文件 {model_path} 不存在,无法评估 LSTM。", file=sys.stderr)
return None, None
print(f"{model_path} 加载模型权重...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
input_dim = checkpoint.get("input_dim", X_test.shape[2])
model = HeatRiskPredictor(input_dim=input_dim).to(device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
# 批量推理
X_tensor = torch.tensor(X_test, dtype=torch.float32).to(device)
# 为避免 OOM,按 batch 处理
batch_size = 128
all_preds: dict[str, list[int]] = {"short": [], "medium": [], "long": []}
with torch.no_grad():
for i in range(0, len(X_tensor), batch_size):
batch = X_tensor[i : i + batch_size]
outputs = model(batch)
for key in HORIZON_KEYS:
preds = outputs[key].argmax(dim=1).cpu().numpy()
all_preds[key].extend(preds.tolist())
predictions = {k: np.array(v, dtype=np.int64) for k, v in all_preds.items()}
print(f"LSTM 推理完成: {len(predictions['short'])} 样本")
return predictions, y_test
# ============================================================================
# 辅助: 计算指标
# ============================================================================
def _compute_metrics(
predictions: dict[str, np.ndarray],
y_true: np.ndarray,
) -> dict[str, dict[str, float]]:
"""从预测和真实标签计算各时间尺度的 accuracy 和 macro F1。
Args:
predictions: {"short": (N,), "medium": (N,), "long": (N,)} 预测标签
y_true: (N, 3) 真实标签,列顺序: short, medium, long
Returns:
{"short": {"accuracy": ..., "f1_macro": ...}, ...}
"""
metrics: dict[str, dict[str, float]] = {}
for i, key in enumerate(HORIZON_KEYS):
y_pred = predictions[key]
y_t = y_true[:, i]
metrics[key] = {
"accuracy": float(accuracy_score(y_t, y_pred)),
"f1_macro": float(f1_score(y_t, y_pred, average="macro")),
}
return metrics
# ============================================================================
# 绘图: 混淆矩阵对比
# ============================================================================
def plot_confusion_matrices(
lstm_preds: dict[str, np.ndarray],
xgb_preds: dict[str, np.ndarray],
y_true: np.ndarray,
save_path: str | Path | None = None,
) -> None:
"""绘制 2x3 混淆矩阵对比图 (LSTM 行 / XGBoost 行; 短/中/长期 列)。
Args:
lstm_preds: LSTM 预测 {"short", "medium", "long"}
xgb_preds: XGBoost 预测 {"short", "medium", "long"}
y_true: (N, 3) 真实标签
save_path: 保存路径,默认 OUTPUT_FIGURES / "confusion_matrix_comparison.png"
"""
if save_path is None:
save_path = OUTPUT_FIGURES / "confusion_matrix_comparison.png"
OUTPUT_FIGURES.mkdir(parents=True, exist_ok=True)
fig, axes = plt.subplots(2, 3, figsize=(16, 10))
model_names = ["LSTM", "XGBoost"]
preds_list = [lstm_preds, xgb_preds]
for row_idx, (model_name, preds) in enumerate(zip(model_names, preds_list)):
for col_idx, key in enumerate(HORIZON_KEYS):
ax = axes[row_idx, col_idx]
cm = confusion_matrix(
y_true[:, col_idx], preds[key],
labels=[0, 1, 2, 3],
)
im = ax.imshow(cm, cmap="Blues", interpolation="nearest")
# 在每个单元格内标注数值
for i in range(4):
for j in range(4):
ax.text(
j, i, str(cm[i, j]),
ha="center", va="center",
fontsize=9,
color="white" if cm[i, j] > cm.max() / 2 else "black",
)
ax.set_xticks(range(4))
ax.set_yticks(range(4))
ax.set_xticklabels(RISK_LABELS)
ax.set_yticklabels(RISK_LABELS)
title = f"{model_name} - {HORIZON_CN[col_idx]}"
ax.set_title(title, fontsize=13, fontweight="bold")
ax.set_xlabel("预测标签")
ax.set_ylabel("真实标签")
# 共享 colorbar
fig.colorbar(
axes[0, 2].get_images()[0],
ax=axes[:, -1],
shrink=0.8,
label="样本数",
)
fig.suptitle("混淆矩阵对比: LSTM vs XGBoost", fontsize=15, fontweight="bold", y=1.01)
plt.tight_layout()
fig.savefig(save_path, dpi=150, bbox_inches="tight")
plt.close(fig)
print(f"混淆矩阵对比图已保存至 {save_path}")
# ============================================================================
# 绘图: 指标对比柱状图
# ============================================================================
def plot_metrics_comparison(
lstm_metrics: dict[str, dict[str, float]],
xgb_metrics: dict[str, dict[str, float]],
save_path: str | Path | None = None,
) -> None:
"""绘制 1x3 指标对比柱状图 (每个时间尺度两柱: accuracy 和 F1)。
Args:
lstm_metrics: LSTM 指标 {"short": {"accuracy", "f1_macro"}, ...}
xgb_metrics: XGBoost 指标 {"short": {"accuracy", "f1_macro"}, ...}
save_path: 保存路径,默认 OUTPUT_FIGURES / "model_comparison.png"
"""
if save_path is None:
save_path = OUTPUT_FIGURES / "model_comparison.png"
OUTPUT_FIGURES.mkdir(parents=True, exist_ok=True)
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
metric_keys = ["accuracy", "f1_macro"]
metric_cn = ["Accuracy", "F1 Macro"]
bar_width = 0.35
colors = ["#4C72B0", "#DD8452"] # LSTM, XGBoost
for col_idx, horizon_key in enumerate(HORIZON_KEYS):
ax = axes[col_idx]
lstm_vals = [lstm_metrics[horizon_key][m] for m in metric_keys]
xgb_vals = [xgb_metrics[horizon_key][m] for m in metric_keys]
x_pos = np.arange(len(metric_keys))
bars1 = ax.bar(
x_pos - bar_width / 2, lstm_vals, bar_width,
label="LSTM", color=colors[0], edgecolor="white",
)
bars2 = ax.bar(
x_pos + bar_width / 2, xgb_vals, bar_width,
label="XGBoost", color=colors[1], edgecolor="white",
)
# 在柱顶标注数值
for bar in bars1:
height = bar.get_height()
ax.text(
bar.get_x() + bar.get_width() / 2, height + 0.01,
f"{height:.3f}", ha="center", va="bottom", fontsize=9,
)
for bar in bars2:
height = bar.get_height()
ax.text(
bar.get_x() + bar.get_width() / 2, height + 0.01,
f"{height:.3f}", ha="center", va="bottom", fontsize=9,
)
ax.set_xticks(x_pos)
ax.set_xticklabels(metric_cn)
ax.set_ylim(0, 1.15)
ax.set_title(HORIZON_CN[col_idx], fontsize=13, fontweight="bold")
ax.set_ylabel("分数")
ax.legend(loc="lower right")
ax.grid(axis="y", alpha=0.3)
fig.suptitle("模型指标对比: LSTM vs XGBoost", fontsize=15, fontweight="bold")
plt.tight_layout()
fig.savefig(save_path, dpi=150, bbox_inches="tight")
plt.close(fig)
print(f"模型对比图已保存至 {save_path}")
# ============================================================================
# 主评估函数
# ============================================================================
def evaluate() -> None:
"""执行完整的模型评估与对比流水线。
流程:
1. 加载数据并按 70/15/15 划分
2. LSTM 评估 (加载最佳模型 + 预测)
3. XGBoost 基线训练与评估
4. 打印对比表
5. 生成对比图 (混淆矩阵 + 指标柱状图)
"""
print("=" * 60)
print("模型评估与对比")
print("=" * 60)
# ---- 1. 加载数据 ----
X, y = _load_all_data()
if X is None:
return
n_total = len(X)
n_train = int(n_total * 0.70)
n_val = int(n_total * 0.15)
X_train = X[:n_train]
y_train = y[:n_train]
X_test = X[n_train + n_val:]
y_test = y[n_train + n_val:]
print(f"数据划分: 训练 {len(X_train)}, 验证 {n_val} (保留), 测试 {len(X_test)}")
# ---- 2. LSTM 评估 ----
print("\n--- LSTM 评估 ---")
lstm_result = evaluate_lstm(X_test, y_test)
if lstm_result[0] is None:
lstm_preds = None
lstm_metrics = None
else:
lstm_preds, _ = lstm_result
lstm_metrics = _compute_metrics(lstm_preds, y_test)
for key in HORIZON_KEYS:
m = lstm_metrics[key]
print(f" LSTM {HORIZON_CN[HORIZON_KEYS.index(key)]}: "
f"Accuracy={m['accuracy']:.4f}, F1 Macro={m['f1_macro']:.4f}")
# ---- 3. XGBoost 基线 ----
print("\n--- XGBoost 基线 ---")
try:
xgb_results = train_xgboost_baseline(X_train, y_train, X_test, y_test)
xgb_preds = {
key: xgb_results[key]["predictions"] for key in HORIZON_KEYS
}
xgb_metrics = {
key: {
"accuracy": xgb_results[key]["accuracy"],
"f1_macro": xgb_results[key]["f1_macro"],
}
for key in HORIZON_KEYS
}
except Exception as e:
print(f"警告: XGBoost 训练失败: {e}", file=sys.stderr)
xgb_preds = None
xgb_metrics = None
# ---- 4. 打印对比表 ----
if lstm_metrics is not None and xgb_metrics is not None:
print("\n" + "=" * 56)
print("=== 模型对比 ===")
print(f"{'时间尺度':<10} {'指标':<12} {'LSTM':<12} {'XGBoost':<12}")
print("-" * 56)
for key, cn_name in zip(HORIZON_KEYS, HORIZON_CN):
for metric_key, metric_cn_name in [
("accuracy", "accuracy"),
("f1_macro", "f1_macro"),
]:
lstm_val = lstm_metrics[key][metric_key]
xgb_val = xgb_metrics[key][metric_key]
print(f"{cn_name:<10} {metric_cn_name:<12} "
f"{lstm_val:<12.4f} {xgb_val:<12.4f}")
print("=" * 56)
# ---- 5. 生成对比图 ----
if lstm_preds is not None and xgb_preds is not None:
print("\n生成对比图表...")
plot_confusion_matrices(lstm_preds, xgb_preds, y_test)
plot_metrics_comparison(lstm_metrics, xgb_metrics)
else:
if lstm_preds is None:
print("跳过图表: LSTM 预测不可用", file=sys.stderr)
if xgb_preds is None:
print("跳过图表: XGBoost 预测不可用", file=sys.stderr)
print("\n评估完成。")
# ============================================================================
# CLI 入口
# ============================================================================
if __name__ == "__main__":
evaluate()
+365
View File
@@ -0,0 +1,365 @@
"""LSTM-Attention 模型训练脚本
完整的训练流水线:数据加载、训练、验证、早停、测试评估。
"""
import json
import sys
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, TensorDataset
from src.models.lstm_attention import HeatRiskPredictor
from src.utils.config import (
BATCH_SIZE,
DATA_PROCESSED,
EARLY_STOP_PATIENCE,
LEARNING_RATE,
MAX_EPOCHS,
OUTPUT_LOGS,
OUTPUT_MODELS,
)
class FocalLoss(nn.Module):
"""Focal Loss — 聚焦困难样本,缓解类别不平衡"""
def __init__(self, alpha: float = 0.25, gamma: float = 2.0):
super().__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
ce = F.cross_entropy(logits, targets, reduction="none")
pt = torch.exp(-ce)
focal = self.alpha * (1 - pt) ** self.gamma * ce
return focal.mean()
def load_data() -> tuple[np.ndarray, np.ndarray]:
"""加载焦作和郑州的序列数据,拼接后返回。
Returns:
X: (N, 14, F) 特征数组
y: (N, 3) 标签数组,列顺序: short, medium, long (0-3 等级)
Raises:
FileNotFoundError: 两个城市的 npz 文件均不存在时
"""
cities = ["jiaozuo", "zhengzhou"]
X_parts, y_parts = [], []
for city in cities:
path = DATA_PROCESSED / f"{city}_sequences.npz"
if path.exists():
data = np.load(path)
X_parts.append(data["X"])
y_parts.append(data["y"])
print(f"已加载 {city}: X {data['X'].shape}, y {data['y'].shape}")
else:
print(f"警告: {path} 不存在,跳过该城市")
if not X_parts:
msg = (
"未找到任何序列数据文件。\n"
f"请确保 {DATA_PROCESSED / 'jiaozuo_sequences.npz'}"
f"{DATA_PROCESSED / 'zhengzhou_sequences.npz'} 存在。\n"
"运行数据预处理流水线以生成这些文件。"
)
print(msg, file=sys.stderr)
sys.exit(1)
X = np.concatenate(X_parts, axis=0)
y = np.concatenate(y_parts, axis=0)
print(f"合并后数据: X {X.shape}, y {y.shape}")
return X, y
def _compute_metrics(
outputs: dict[str, torch.Tensor], labels: torch.Tensor
) -> dict[str, dict[str, float]]:
"""计算每个预测头的 loss、accuracy 和 macro F1。
Args:
outputs: 模型前向输出,{"short": (B,4), "medium": (B,4), "long": (B,4)}
labels: (B, 3) 标签,列顺序: short, medium, long
Returns:
{"short": {"loss": ..., "acc": ..., "f1": ...}, ...}
"""
loss_fn = FocalLoss()
horizons = ["short", "medium", "long"]
metrics = {}
for i, name in enumerate(horizons):
logits = outputs[name]
targets = labels[:, i]
loss_val = loss_fn(logits, targets).item()
preds = logits.argmax(dim=1).cpu().numpy()
trues = targets.cpu().numpy()
acc = accuracy_score(trues, preds)
f1 = f1_score(trues, preds, average="macro")
metrics[name] = {"loss": loss_val, "acc": acc, "f1": f1}
return metrics
def train() -> HeatRiskPredictor:
"""执行完整的训练流水线。
Returns:
训练好的 HeatRiskPredictor 模型 (best checkpoint)
"""
# -------------------- 设备 --------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# -------------------- 加载数据 --------------------
X, y = load_data()
print(f"数据加载完成: X {X.shape}, y {y.shape}")
# -------------------- 时间序划分 (不打乱) --------------------
n_total = len(X)
n_train = int(n_total * 0.70)
n_val = int(n_total * 0.15)
X_train_np = X[:n_train]
y_train_np = y[:n_train]
X_val_np = X[n_train : n_train + n_val]
y_val_np = y[n_train : n_train + n_val]
X_test_np = X[n_train + n_val :]
y_test_np = y[n_train + n_val :]
print(f"划分: 训练 {len(X_train_np)}, 验证 {len(X_val_np)}, 测试 {len(X_test_np)}")
# -------------------- Tensor 转换 --------------------
X_train_t = torch.tensor(X_train_np, dtype=torch.float32)
y_train_t = torch.tensor(y_train_np, dtype=torch.long)
X_val_t = torch.tensor(X_val_np, dtype=torch.float32)
y_val_t = torch.tensor(y_val_np, dtype=torch.long)
X_test_t = torch.tensor(X_test_np, dtype=torch.float32)
y_test_t = torch.tensor(y_test_np, dtype=torch.long)
# -------------------- DataLoader --------------------
train_loader = DataLoader(
TensorDataset(X_train_t, y_train_t),
batch_size=BATCH_SIZE,
shuffle=True,
)
val_loader = DataLoader(
TensorDataset(X_val_t, y_val_t),
batch_size=BATCH_SIZE,
shuffle=False,
)
# -------------------- 模型 --------------------
input_dim = X.shape[2]
model = HeatRiskPredictor(input_dim=input_dim).to(device)
print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")
# -------------------- 损失、优化器、调度器 --------------------
focal_loss = FocalLoss()
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=5)
# -------------------- 训练状态 --------------------
best_val_loss = float("inf")
best_epoch = 0
patience_counter = 0
history: dict[str, list] = {
"train_loss": [],
"val_loss": [],
"val_acc_short": [],
"val_acc_medium": [],
"val_acc_long": [],
}
OUTPUT_MODELS.mkdir(parents=True, exist_ok=True)
best_model_path = OUTPUT_MODELS / "best_model.pt"
# -------------------- 训练循环 --------------------
for epoch in range(1, MAX_EPOCHS + 1):
# ---- 训练阶段 ----
model.train()
train_losses: list[float] = []
for batch_X, batch_y in train_loader:
batch_X = batch_X.to(device)
batch_y = batch_y.to(device)
outputs = model(batch_X)
loss_short = focal_loss(outputs["short"], batch_y[:, 0])
loss_medium = focal_loss(outputs["medium"], batch_y[:, 1])
loss_long = focal_loss(outputs["long"], batch_y[:, 2])
loss = (loss_short + loss_medium + loss_long) / 3.0
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
train_losses.append(loss.item())
avg_train_loss = float(np.mean(train_losses))
# ---- 验证阶段 ----
model.eval()
val_outputs_all: list[dict] = []
val_labels_all: list[torch.Tensor] = []
with torch.no_grad():
for batch_X, batch_y in val_loader:
batch_X = batch_X.to(device)
batch_y = batch_y.to(device)
outputs = model(batch_X)
val_outputs_all.append(
{k: v.cpu() for k, v in outputs.items()}
)
val_labels_all.append(batch_y.cpu())
# 合并所有 batch
merged_outputs = {
name: torch.cat([o[name] for o in val_outputs_all], dim=0)
for name in ["short", "medium", "long"]
}
merged_labels = torch.cat(val_labels_all, dim=0)
val_metrics = _compute_metrics(merged_outputs, merged_labels)
avg_val_loss = np.mean(
[val_metrics[h]["loss"] for h in ["short", "medium", "long"]]
)
# ---- 记录历史 ----
history["train_loss"].append(avg_train_loss)
history["val_loss"].append(avg_val_loss)
history["val_acc_short"].append(val_metrics["short"]["acc"])
history["val_acc_medium"].append(val_metrics["medium"]["acc"])
history["val_acc_long"].append(val_metrics["long"]["acc"])
# ---- 学习率调度 ----
scheduler.step(avg_val_loss)
# ---- 打印进度 ----
if epoch % 10 == 0 or epoch == 1:
lr_now = optimizer.param_groups[0]["lr"]
print(
f"Epoch {epoch:3d}/{MAX_EPOCHS} | "
f"Train Loss: {avg_train_loss:.4f} | "
f"Val Loss: {avg_val_loss:.4f} | "
f"Acc S/M/L: "
f"{val_metrics['short']['acc']:.3f}/"
f"{val_metrics['medium']['acc']:.3f}/"
f"{val_metrics['long']['acc']:.3f} | "
f"LR: {lr_now:.2e}"
)
# ---- 保存最佳模型 ----
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
best_epoch = epoch
patience_counter = 0
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"val_loss": avg_val_loss,
"val_metrics": val_metrics,
"input_dim": input_dim,
},
best_model_path,
)
else:
patience_counter += 1
# ---- 早停 ----
if patience_counter >= EARLY_STOP_PATIENCE:
print(
f"早停触发: {EARLY_STOP_PATIENCE} 轮未改善 "
f"(最佳 epoch: {best_epoch}, Val Loss: {best_val_loss:.4f})"
)
break
# -------------------- 保存训练历史 --------------------
OUTPUT_LOGS.mkdir(parents=True, exist_ok=True)
history_path = OUTPUT_LOGS / "training_history.json"
with open(history_path, "w", encoding="utf-8") as f:
json.dump(
{
"best_epoch": best_epoch,
"best_val_loss": best_val_loss,
**history,
},
f,
indent=2,
ensure_ascii=False,
)
print(f"训练历史已保存至 {history_path}")
# -------------------- 加载最佳模型,测试评估 --------------------
checkpoint = torch.load(best_model_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
print("\n========== 测试集评估 ==========")
test_dataset = TensorDataset(X_test_t, y_test_t)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_outputs_all: list[dict] = []
test_labels_all: list[torch.Tensor] = []
with torch.no_grad():
for batch_X, batch_y in test_loader:
batch_X = batch_X.to(device)
batch_y = batch_y.to(device)
outputs = model(batch_X)
test_outputs_all.append(
{k: v.cpu() for k, v in outputs.items()}
)
test_labels_all.append(batch_y.cpu())
merged_test_outputs = {
name: torch.cat([o[name] for o in test_outputs_all], dim=0)
for name in ["short", "medium", "long"]
}
merged_test_labels = torch.cat(test_labels_all, dim=0)
test_metrics = _compute_metrics(merged_test_outputs, merged_test_labels)
# 打印结果
for name in ["short", "medium", "long"]:
m = test_metrics[name]
print(
f" {name:>6}: Accuracy={m['acc']:.4f}, "
f"F1 Macro={m['f1']:.4f}, Loss={m['loss']:.4f}"
)
avg_acc = np.mean([test_metrics[h]["acc"] for h in ["short", "medium", "long"]])
avg_f1 = np.mean([test_metrics[h]["f1"] for h in ["short", "medium", "long"]])
print(f"\n平均 Accuracy: {avg_acc:.4f}")
print(f"平均 F1 Macro: {avg_f1:.4f}")
# -------------------- 保存测试预测 --------------------
test_predictions: dict[str, np.ndarray] = {}
for name in ["short", "medium", "long"]:
test_predictions[name] = merged_test_outputs[name].argmax(dim=1).numpy()
test_predictions["y_true"] = merged_test_labels.numpy()
test_predictions["logits_short"] = merged_test_outputs["short"].numpy()
test_predictions["logits_medium"] = merged_test_outputs["medium"].numpy()
test_predictions["logits_long"] = merged_test_outputs["long"].numpy()
pred_path = OUTPUT_MODELS / "test_predictions.npz"
np.savez_compressed(pred_path, **test_predictions)
print(f"测试预测已保存至 {pred_path}")
return model
if __name__ == "__main__":
train()
+177
View File
@@ -0,0 +1,177 @@
"""高温预警可视化大屏 Flask API 后端"""
import numpy as np
import pandas as pd
import torch
from flask import Flask, jsonify, send_from_directory
from pathlib import Path
from src.utils.config import OUTPUT_MODELS, DATA_PROCESSED, CITIES
from src.models.lstm_attention import HeatRiskPredictor
app = Flask(__name__, static_folder="static", static_url_path="")
# --- 全局状态 ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = None
RISK_LABELS = ["低风险", "中风险", "高风险", "严重风险"]
RISK_COLORS = ["#00e676", "#ffeb3b", "#ff9800", "#f44336"]
SUGGESTIONS = {
0: ["天气状况良好,无需特殊防护"],
1: ["注意防暑降温", "保持室内通风", "老年人减少午后外出"],
2: ["建议开放社区避暑中心", "增加独居老人电话探访频次", "社区志愿者关注高龄老人"],
3: ["启动高温应急预案", "社区避暑中心24小时开放", "逐一入户探访独居老人",
"医疗机构做好热射病救治准备", "通过社区广播发布高温警报"],
}
AGING_RATES = {"jiaozuo": 12.8, "zhengzhou": 11.6}
ELDERLY_POP = {"jiaozuo": 454000, "zhengzhou": 1462000}
def load_model():
"""延迟加载模型(首次请求时加载)"""
global model
if model is not None:
return
try:
# 尝试从序列数据获取输入维度
data_path = DATA_PROCESSED / "jiaozuo_sequences.npz"
if not data_path.exists():
data_path = DATA_PROCESSED / "zhengzhou_sequences.npz"
if data_path.exists():
data = np.load(data_path, allow_pickle=True)
input_dim = data["X"].shape[2]
else:
input_dim = 16 # 默认特征数
model = HeatRiskPredictor(input_dim=input_dim).to(device)
model_path = OUTPUT_MODELS / "best_model.pt"
if model_path.exists():
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
print(f"模型已加载,设备: {device}")
except Exception as e:
print(f"模型加载失败: {e}")
model = None
def get_recent_features() -> np.ndarray:
"""获取最近14天特征用于预测"""
try:
for city in ["jiaozuo", "zhengzhou"]:
data_path = DATA_PROCESSED / f"{city}_sequences.npz"
if data_path.exists():
data = np.load(data_path)
return data["X"][-1:] # 最后14天
except Exception:
pass
# 返回随机特征作为fallback
return np.random.randn(1, 14, 16).astype(np.float32)
@app.route("/")
def index():
return send_from_directory("static", "index.html")
@app.route("/api/predict")
def predict():
"""返回多时间尺度预测结果"""
load_model()
X = get_recent_features()
X_tensor = torch.FloatTensor(X).to(device)
predictions = {}
if model is not None:
with torch.no_grad():
outputs = model(X_tensor)
for i, key in enumerate(["short", "medium", "long"]):
probs = torch.softmax(outputs[key], dim=-1)[0].cpu().numpy()
level = int(probs.argmax())
predictions[key] = {
"level": level,
"label": RISK_LABELS[level],
"color": RISK_COLORS[level],
"confidence": round(float(probs[level]), 4),
"probabilities": [round(float(p), 4) for p in probs],
"suggestions": SUGGESTIONS[level],
}
else:
# fallback
for key in ["short", "medium", "long"]:
predictions[key] = {
"level": 1, "label": "中风险", "color": "#ffeb3b",
"confidence": 0.5, "probabilities": [0.1, 0.5, 0.3, 0.1],
"suggestions": SUGGESTIONS[1],
}
return jsonify({
"city": "焦作",
"date": pd.Timestamp.now().strftime("%Y-%m-%d"),
"predictions": predictions,
"risk_population": ELDERLY_POP["jiaozuo"],
})
@app.route("/api/history")
def history():
"""返回最近90天历史数据"""
try:
dfs = []
for city in CITIES:
csv_path = DATA_PROCESSED / f"{city}_processed.csv"
if csv_path.exists():
df = pd.read_csv(csv_path, parse_dates=["time"])
dfs.append(df)
if dfs:
df = pd.concat(dfs, ignore_index=True).sort_values("time")
recent = df.tail(90)
return jsonify({
"dates": recent["time"].dt.strftime("%Y-%m-%d").tolist(),
"temp_mean": recent["temp_mean"].round(1).tolist() if "temp_mean" in recent.columns else [],
"heat_index": recent["heat_index"].round(1).tolist() if "heat_index" in recent.columns else [],
"risk_label": recent["risk_label"].astype(int).tolist() if "risk_label" in recent.columns else [],
"heatwave": recent["heatwave"].astype(int).tolist() if "heatwave" in recent.columns else [],
})
except Exception as e:
print(f"历史数据加载失败: {e}")
return jsonify({"dates": [], "temp_mean": [], "heat_index": [], "risk_label": [], "heatwave": []})
@app.route("/api/stats")
def stats():
"""返回年度统计摘要"""
try:
dfs = []
for city in CITIES:
csv_path = DATA_PROCESSED / f"{city}_processed.csv"
if csv_path.exists():
df = pd.read_csv(csv_path, parse_dates=["time"])
dfs.append(df)
if dfs:
df = pd.concat(dfs, ignore_index=True)
df["year"] = df["time"].dt.year
annual = df.groupby("year").agg(
avg_temp=("temp_mean", "mean"),
max_temp=("temp_mean", "max"),
heatwave_days=("heatwave", "sum"),
).reset_index()
return jsonify({
"annual": {
"years": annual["year"].astype(int).tolist(),
"avg_temp": annual["avg_temp"].round(1).tolist(),
"max_temp": annual["max_temp"].round(1).tolist(),
"heatwave_days": annual["heatwave_days"].astype(int).tolist(),
},
"aging_rate": AGING_RATES,
})
except Exception as e:
print(f"统计数据生成失败: {e}")
return jsonify({"annual": {"years": [], "avg_temp": [], "max_temp": [], "heatwave_days": []},
"aging_rate": AGING_RATES})
if __name__ == "__main__":
print("启动高温预警可视化服务器...")
print("访问 http://localhost:5005 查看大屏")
app.run(host="0.0.0.0", port=5005, debug=True)
+869
View File
@@ -0,0 +1,869 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>银发群体高温多时间尺度预警可视化大屏</title>
<script src="https://cdn.jsdelivr.net/npm/echarts@5.5.0/dist/echarts.min.js"></script>
<style>
/* ===== CSS 变量与基础重置 ===== */
:root {
--bg-primary: #0a1632;
--bg-panel: #0d1f3c;
--border: #1a3a5c;
--text-primary: #e0e6f0;
--text-secondary: #8899aa;
--text-dim: #556677;
--green: #00e676;
--yellow: #ffeb3b;
--orange: #ff9800;
--red: #f44336;
--blue: #5b9bd5;
--purple: #ab47bc;
--cyan: #00bcd4;
}
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: "Microsoft YaHei", "PingFang SC", "Helvetica Neue", sans-serif;
background: var(--bg-primary);
color: var(--text-primary);
overflow: hidden;
height: 100vh;
}
/* ===== 网格布局 ===== */
#app {
display: grid;
grid-template-columns: 1fr 1fr 1fr;
grid-template-rows: auto 1fr 1fr;
gap: 10px;
height: 100vh;
padding: 10px;
}
/* ===== 页头 ===== */
.header {
grid-column: 1 / -1;
display: flex;
align-items: center;
justify-content: center;
gap: 20px;
padding: 10px 0 6px;
border-bottom: 2px solid var(--border);
position: relative;
}
.header h1 {
font-size: 26px;
font-weight: 700;
letter-spacing: 4px;
background: linear-gradient(90deg, var(--cyan), var(--blue), var(--purple));
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
}
.header .subtitle {
font-size: 13px;
color: var(--text-secondary);
position: absolute;
right: 16px;
bottom: 6px;
}
.header .refresh-dot {
width: 8px; height: 8px;
border-radius: 50%;
background: var(--green);
animation: pulse 2s infinite;
}
@keyframes pulse {
0%, 100% { opacity: 1; box-shadow: 0 0 4px var(--green); }
50% { opacity: 0.4; box-shadow: 0 0 8px var(--green); }
}
/* ===== 面板通用样式 ===== */
.panel {
background: var(--bg-panel);
border: 1px solid var(--border);
border-radius: 6px;
padding: 10px 12px;
display: flex;
flex-direction: column;
min-height: 0;
position: relative;
overflow: hidden;
}
.panel-title {
font-size: 14px;
font-weight: 600;
color: var(--text-primary);
margin-bottom: 6px;
padding-bottom: 6px;
border-bottom: 1px solid rgba(26,58,92,0.6);
flex-shrink: 0;
display: flex;
align-items: center;
gap: 6px;
}
.panel-title .icon { font-size: 16px; }
.chart {
flex: 1;
min-height: 0;
width: 100%;
}
.empty-state {
flex: 1;
display: flex;
align-items: center;
justify-content: center;
color: var(--text-dim);
font-size: 14px;
}
.loading-state {
flex: 1;
display: flex;
align-items: center;
justify-content: center;
color: var(--text-secondary);
font-size: 13px;
}
/* ===== Panel 2: 风险仪表盘 ===== */
#panel-risk .risk-content {
flex: 1;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
gap: 8px;
min-height: 0;
}
.risk-circle {
width: 90px;
height: 90px;
border-radius: 50%;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
font-weight: 700;
color: #fff;
position: relative;
box-shadow: 0 0 20px rgba(0,0,0,0.4);
flex-shrink: 0;
}
.risk-circle .risk-label {
font-size: 20px;
line-height: 1.2;
}
.risk-circle .risk-sub {
font-size: 10px;
opacity: 0.85;
}
.risk-circle::after {
content: '';
position: absolute;
inset: -6px;
border-radius: 50%;
border: 2px solid;
border-color: inherit;
opacity: 0.6;
}
.predictions-row {
display: flex;
gap: 8px;
width: 100%;
flex-shrink: 0;
}
.pred-card {
flex: 1;
background: rgba(255,255,255,0.03);
border-left: 3px solid;
border-radius: 0 4px 4px 0;
padding: 6px 8px;
font-size: 11px;
}
.pred-card .pred-scale {
color: var(--text-secondary);
margin-bottom: 2px;
}
.pred-card .pred-value {
font-size: 14px;
font-weight: 700;
}
.pred-card .pred-conf {
color: var(--text-dim);
font-size: 10px;
}
.suggestions-box {
width: 100%;
background: rgba(255,255,255,0.02);
border-radius: 4px;
padding: 6px 10px;
font-size: 11px;
color: var(--text-secondary);
flex-shrink: 0;
max-height: 90px;
overflow-y: auto;
}
.suggestions-box .sug-title {
font-weight: 600;
color: var(--text-primary);
margin-bottom: 3px;
font-size: 12px;
}
.suggestions-box li {
margin-left: 16px;
margin-bottom: 2px;
line-height: 1.5;
}
</style>
</head>
<body>
<div id="app">
<!-- 页头 -->
<div class="header">
<div class="refresh-dot" id="refresh-dot"></div>
<h1>银发群体高温多时间尺度预警与服务优化可视化平台</h1>
<span class="subtitle" id="update-time">数据加载中...</span>
</div>
<!-- Panel 1: 双城温度趋势图 -->
<div class="panel" id="panel-1">
<div class="panel-title"><span class="icon">&#x1F4C8;</span> 双城温度 &amp; 体感温度趋势</div>
<div class="chart" id="chart-1"></div>
<div class="empty-state" id="empty-1" style="display:none">暂无历史数据</div>
</div>
<!-- Panel 2: 当前风险等级 -->
<div class="panel" id="panel-risk">
<div class="panel-title"><span class="icon">&#x26A0;</span> 当前风险等级评估</div>
<div class="risk-content" id="risk-content">
<div class="risk-circle" id="risk-circle">
<span class="risk-label" id="risk-label">--</span>
<span class="risk-sub" id="risk-city">焦作</span>
</div>
<div class="predictions-row" id="predictions-row"></div>
<div class="suggestions-box" id="suggestions-box">
<div class="sug-title">&#x1F4CB; 建议措施</div>
<ul id="suggestions-list"><li>加载中...</li></ul>
</div>
</div>
</div>
<!-- Panel 3: 老年人口概况 -->
<div class="panel" id="panel-3">
<div class="panel-title"><span class="icon">&#x1F474;</span> 老年人口概况</div>
<div class="chart" id="chart-3"></div>
<div class="empty-state" id="empty-3" style="display:none">暂无人口数据</div>
</div>
<!-- Panel 4: 多时间尺度预警时间线 -->
<div class="panel" id="panel-4">
<div class="panel-title"><span class="icon">&#x1F550;</span> 多时间尺度预警时间线</div>
<div class="chart" id="chart-4"></div>
<div class="empty-state" id="empty-4" style="display:none">暂无预测数据</div>
</div>
<!-- Panel 5: 温度与健康风险关联 -->
<div class="panel" id="panel-5">
<div class="panel-title"><span class="icon">&#x1F4CA;</span> 温度与健康风险关联(暴露-反应曲线)</div>
<div class="chart" id="chart-5"></div>
</div>
<!-- Panel 6: 历年高温事件回顾 -->
<div class="panel" id="panel-6">
<div class="panel-title"><span class="icon">&#x1F4C5;</span> 历年高温事件与热浪天数统计</div>
<div class="chart" id="chart-6"></div>
<div class="empty-state" id="empty-6" style="display:none">暂无统计数据</div>
</div>
</div>
<script>
// ===== 全局状态 =====
const BASE = '';
const RISK_COLORS = ['#00e676', '#ffeb3b', '#ff9800', '#f44336'];
const RISK_LABELS = ['低风险', '中风险', '高风险', '严重风险'];
const TIME_SCALE_LABELS = { short: '短期(1-3天)', medium: '中期(7天)', long: '长期(30天)' };
const TIME_SCALE_ORDER = ['short', 'medium', 'long'];
// ECharts 实例池
const charts = {};
// ===== 工具函数 =====
function darkAxisLine() {
return { lineStyle: { color: '#1a3a5c' } };
}
function darkSplitLine() {
return { lineStyle: { color: 'rgba(26,58,92,0.3)', type: 'dashed' } };
}
function darkTextStyle() {
return { color: '#8899aa', fontSize: 11 };
}
function panelTitleStyle(text) {
return {
text: text,
textStyle: { color: '#c0ccd8', fontSize: 13, fontWeight: 600 },
left: 'center', top: 0
};
}
// ===== 数据获取 =====
async function fetchAllData() {
try {
const [predictRes, historyRes, statsRes] = await Promise.all([
fetch(BASE + '/api/predict').then(r => r.ok ? r.json() : null),
fetch(BASE + '/api/history').then(r => r.ok ? r.json() : null),
fetch(BASE + '/api/stats').then(r => r.ok ? r.json() : null)
]);
return {
predict: predictRes,
history: historyRes,
stats: statsRes
};
} catch (e) {
console.error('数据获取失败:', e);
return { predict: null, history: null, stats: null };
}
}
// ===== Panel 1: 双城温度趋势图 (双Y轴折线图) =====
function renderPanel1(historyData) {
const dom = document.getElementById('chart-1');
const emptyEl = document.getElementById('empty-1');
if (!historyData || !historyData.dates || historyData.dates.length === 0) {
dom.style.display = 'none';
emptyEl.style.display = 'flex';
return;
}
dom.style.display = '';
emptyEl.style.display = 'none';
if (!charts.p1) {
charts.p1 = echarts.init(dom);
}
const opt = {
tooltip: {
trigger: 'axis',
backgroundColor: 'rgba(13,31,60,0.95)',
borderColor: '#1a3a5c',
textStyle: { color: '#e0e6f0', fontSize: 12 }
},
legend: {
data: ['平均温度', '体感温度', '风险等级'],
bottom: 0,
textStyle: { color: '#8899aa', fontSize: 10 },
itemWidth: 14, itemHeight: 8
},
grid: { left: 45, right: 55, top: 12, bottom: 28 },
xAxis: {
type: 'category',
data: historyData.dates,
axisLine: darkAxisLine(),
axisTick: { show: false },
axisLabel: { color: '#667788', fontSize: 9, rotate: 30,
formatter: v => v.slice(5) },
splitLine: { show: false }
},
yAxis: [
{
type: 'value', name: '°C',
nameTextStyle: { color: '#8899aa', fontSize: 10 },
axisLine: darkAxisLine(),
splitLine: darkSplitLine(),
axisLabel: { color: '#8899aa', fontSize: 10 }
},
{
type: 'value', name: '风险',
min: 0, max: 3, interval: 1,
nameTextStyle: { color: '#ab47bc', fontSize: 10 },
axisLine: { lineStyle: { color: '#ab47bc' } },
splitLine: { show: false },
axisLabel: {
color: '#ab47bc', fontSize: 10,
formatter: v => RISK_LABELS[v] || ''
}
}
],
series: [
{
name: '平均温度', type: 'line', yAxisIndex: 0,
data: historyData.temp_mean,
smooth: true, symbol: 'none',
lineStyle: { color: '#ff9800', width: 2 },
itemStyle: { color: '#ff9800' }
},
{
name: '体感温度', type: 'line', yAxisIndex: 0,
data: historyData.heat_index,
smooth: true, symbol: 'none',
lineStyle: { color: '#f44336', width: 2 },
itemStyle: { color: '#f44336' }
},
{
name: '风险等级', type: 'line', yAxisIndex: 1,
data: historyData.risk_label,
smooth: true, symbol: 'none',
lineStyle: { color: '#ab47bc', width: 1.5, type: 'dotted' },
itemStyle: { color: '#ab47bc' }
}
]
};
charts.p1.setOption(opt, true);
}
// ===== Panel 2: 当前风险等级 (HTML 仪表盘) =====
function renderPanel2(predictData) {
const circle = document.getElementById('risk-circle');
const label = document.getElementById('risk-label');
const city = document.getElementById('risk-city');
const predRow = document.getElementById('predictions-row');
const sugList = document.getElementById('suggestions-list');
if (!predictData || !predictData.predictions) {
circle.style.background = '#555';
circle.style.borderColor = '#555';
label.textContent = '--';
city.textContent = '无数据';
predRow.innerHTML = '<span style="color:#667;font-size:11px;text-align:center;width:100%">暂无预测数据</span>';
sugList.innerHTML = '<li>暂无建议</li>';
return;
}
const preds = predictData.predictions;
// 使用短期预测作为当前风险
const current = preds.short;
const color = RISK_COLORS[current.level] || '#555';
circle.style.background = color;
circle.style.borderColor = color;
label.textContent = current.label;
city.textContent = predictData.city || '焦作';
// 三列预测卡片
const cardColors = { short: '#5b9bd5', medium: '#ff9800', long: '#ab47bc' };
predRow.innerHTML = TIME_SCALE_ORDER.map(key => {
const p = preds[key];
const c = RISK_COLORS[p.level] || '#555';
return '<div class="pred-card" style="border-left-color:' + c + '">' +
'<div class="pred-scale">' + TIME_SCALE_LABELS[key] + '</div>' +
'<div class="pred-value" style="color:' + c + '">' + p.label + '</div>' +
'<div class="pred-conf">置信度: ' + (p.confidence * 100).toFixed(1) + '%</div>' +
'</div>';
}).join('');
// 建议列表
const suggestions = current.suggestions || ['暂无建议'];
sugList.innerHTML = suggestions.map(s => '<li>' + s + '</li>').join('');
}
// ===== Panel 3: 老年人口概况 (环形饼图) =====
function renderPanel3(statsData) {
const dom = document.getElementById('chart-3');
const emptyEl = document.getElementById('empty-3');
const agingRate = statsData ? statsData.aging_rate : null;
if (!agingRate) {
dom.style.display = 'none';
emptyEl.style.display = 'flex';
return;
}
dom.style.display = '';
emptyEl.style.display = 'none';
if (!charts.p3) {
charts.p3 = echarts.init(dom);
}
const jzRate = agingRate.jiaozuo || 12.8;
const zzRate = agingRate.zhengzhou || 11.6;
const jzOther = 100 - jzRate;
const zzOther = 100 - zzRate;
const opt = {
tooltip: {
trigger: 'item',
backgroundColor: 'rgba(13,31,60,0.95)',
borderColor: '#1a3a5c',
textStyle: { color: '#e0e6f0', fontSize: 12 },
formatter: '{b}: {c}%'
},
series: [
{
name: '焦作', type: 'pie',
radius: ['45%', '65%'],
center: ['25%', '55%'],
label: {
position: 'center',
formatter: '焦作\n{c}%',
fontSize: 13, fontWeight: 700,
color: '#e0e6f0'
},
emphasis: { label: { fontSize: 16 } },
data: [
{ value: jzRate, name: '老龄化率', itemStyle: { color: '#5b9bd5' } },
{ value: jzOther, name: '其他', itemStyle: { color: '#1a3a5c' } }
]
},
{
name: '郑州', type: 'pie',
radius: ['45%', '65%'],
center: ['75%', '55%'],
label: {
position: 'center',
formatter: '郑州\n{c}%',
fontSize: 13, fontWeight: 700,
color: '#e0e6f0'
},
emphasis: { label: { fontSize: 16 } },
data: [
{ value: zzRate, name: '老龄化率', itemStyle: { color: '#ab47bc' } },
{ value: zzOther, name: '其他', itemStyle: { color: '#1a3a5c' } }
]
}
],
graphic: [
{ type: 'text', left: '16%', top: '10%',
style: { text: '焦作', fill: '#8899aa', fontSize: 11 } },
{ type: 'text', left: '66%', top: '10%',
style: { text: '郑州', fill: '#8899aa', fontSize: 11 } },
{ type: 'text', left: 'center', top: '88%',
style: { text: '65岁及以上人口占比', fill: '#556677', fontSize: 10 } }
]
};
charts.p3.setOption(opt, true);
}
// ===== Panel 4: 多时间尺度预警时间线 (横向条形图) =====
function renderPanel4(predictData) {
const dom = document.getElementById('chart-4');
const emptyEl = document.getElementById('empty-4');
if (!predictData || !predictData.predictions) {
dom.style.display = 'none';
emptyEl.style.display = 'flex';
return;
}
dom.style.display = '';
emptyEl.style.display = 'none';
if (!charts.p4) {
charts.p4 = echarts.init(dom);
}
const preds = predictData.predictions;
const yLabels = TIME_SCALE_ORDER.map(k => TIME_SCALE_LABELS[k]).reverse();
const levels = TIME_SCALE_ORDER.map(k => preds[k].level).reverse();
const confs = TIME_SCALE_ORDER.map(k => preds[k].confidence).reverse();
const colors = levels.map(l => RISK_COLORS[l]);
const opt = {
tooltip: {
trigger: 'axis',
backgroundColor: 'rgba(13,31,60,0.95)',
borderColor: '#1a3a5c',
textStyle: { color: '#e0e6f0', fontSize: 12 },
formatter: function(params) {
const p = params[0];
const idx = p.dataIndex;
return '<b>' + yLabels[idx] + '</b><br/>' +
'风险等级: ' + RISK_LABELS[levels[idx]] + '<br/>' +
'等级值: ' + levels[idx] + ' / 3<br/>' +
'置信度: ' + (confs[idx] * 100).toFixed(1) + '%';
}
},
grid: { left: 110, right: 40, top: 16, bottom: 16 },
xAxis: {
type: 'value', name: '风险等级',
min: 0, max: 3, interval: 1,
nameTextStyle: { color: '#8899aa', fontSize: 10 },
axisLine: darkAxisLine(),
splitLine: darkSplitLine(),
axisLabel: {
color: '#8899aa', fontSize: 10,
formatter: v => RISK_LABELS[v] || ''
}
},
yAxis: {
type: 'category',
data: yLabels,
axisLine: { lineStyle: { color: '#1a3a5c' } },
axisTick: { show: false },
axisLabel: { color: '#c0ccd8', fontSize: 12, fontWeight: 600 }
},
series: [{
type: 'bar',
data: levels.map((l, i) => ({
value: l,
itemStyle: {
color: new echarts.graphic.LinearGradient(0, 0, 1, 0, [
{ offset: 0, color: colors[i] },
{ offset: 1, color: colors[i] + '88' }
]),
borderRadius: [0, 4, 4, 0]
}
})),
barWidth: 22,
label: {
show: true, position: 'right',
color: '#c0ccd8', fontSize: 11,
formatter: function(p) {
return RISK_LABELS[levels[p.dataIndex]];
}
},
emphasis: {
itemStyle: { shadowBlur: 10, shadowColor: 'rgba(0,0,0,0.4)' }
}
}]
};
charts.p4.setOption(opt, true);
}
// ===== Panel 5: 温度与健康风险关联 -- 暴露-反应曲线 (硬编码文献数据) =====
function renderPanel5() {
const dom = document.getElementById('chart-5');
if (!charts.p5) {
charts.p5 = echarts.init(dom);
}
// 文献中的典型温度-死亡风险暴露反应曲线 (J型曲线)
// 参考: 中国多城市温度-死亡率研究, 以25°C为最低风险参考温度 (RR=1.0)
var temps = [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40];
var rr = [1.08, 1.06, 1.04, 1.03, 1.01, 1.00, 0.99, 0.98, 0.98, 0.99, 1.00, 1.00, 1.01, 1.03, 1.06, 1.10, 1.15, 1.22, 1.30, 1.40, 1.52, 1.66, 1.82, 2.00, 2.20, 2.42];
// 构建高于参考线的面积填充数据
var areaData = temps.map(function(t, i) {
return [t, Math.max(rr[i], 1.0)];
});
var opt = {
tooltip: {
trigger: 'axis',
backgroundColor: 'rgba(13,31,60,0.95)',
borderColor: '#1a3a5c',
textStyle: { color: '#e0e6f0', fontSize: 12 },
formatter: function(params) {
var p = params[0];
return '温度: <b>' + p.axisValue + '°C</b><br/>' +
'相对风险(RR): <b style="color:' + (p.value > 1.0 ? '#f44336' : '#00e676') + '">' +
p.value.toFixed(2) + '</b>';
}
},
grid: { left: 50, right: 30, top: 16, bottom: 24 },
xAxis: {
type: 'category',
data: temps,
name: '温度 (°C)',
nameTextStyle: { color: '#8899aa', fontSize: 10 },
axisLine: darkAxisLine(),
axisLabel: { color: '#8899aa', fontSize: 9, interval: 2 },
splitLine: { show: false }
},
yAxis: {
type: 'value',
name: '相对风险 (RR)',
nameTextStyle: { color: '#8899aa', fontSize: 10 },
axisLine: darkAxisLine(),
splitLine: darkSplitLine(),
axisLabel: { color: '#8899aa', fontSize: 10 },
min: 0.9
},
series: [
{
name: '风险面积', type: 'line',
data: areaData,
smooth: true, symbol: 'none',
lineStyle: { opacity: 0 },
areaStyle: { color: 'rgba(244,67,54,0.15)' },
silent: true
},
{
name: '相对风险', type: 'line',
data: rr,
smooth: true, symbol: 'none',
lineStyle: { color: '#ff9800', width: 2.5 },
itemStyle: { color: '#ff9800' },
markLine: {
silent: true,
symbol: 'none',
lineStyle: { color: '#00e676', type: 'dashed', width: 1.5 },
label: { color: '#00e676', fontSize: 10, formatter: 'RR=1.0 参考线' },
data: [{ yAxis: 1.0 }]
},
markArea: {
silent: true,
data: [
[
{ xAxis: '30', itemStyle: { color: 'rgba(244,67,54,0.06)' } },
{ xAxis: '40' }
]
]
}
}
]
};
charts.p5.setOption(opt, true);
}
// ===== Panel 6: 历年高温事件与热浪天数统计 (柱状+双折线组合图) =====
function renderPanel6(statsData) {
var dom = document.getElementById('chart-6');
var emptyEl = document.getElementById('empty-6');
var annual = statsData ? statsData.annual : null;
if (!annual || !annual.years || annual.years.length === 0) {
dom.style.display = 'none';
emptyEl.style.display = 'flex';
return;
}
dom.style.display = '';
emptyEl.style.display = 'none';
if (!charts.p6) {
charts.p6 = echarts.init(dom);
}
var opt = {
tooltip: {
trigger: 'axis',
backgroundColor: 'rgba(13,31,60,0.95)',
borderColor: '#1a3a5c',
textStyle: { color: '#e0e6f0', fontSize: 12 }
},
legend: {
data: ['热浪天数', '平均温度', '最高温度'],
bottom: 0,
textStyle: { color: '#8899aa', fontSize: 10 },
itemWidth: 14, itemHeight: 8
},
grid: { left: 45, right: 55, top: 12, bottom: 28 },
xAxis: {
type: 'category',
data: annual.years.map(String),
axisLine: darkAxisLine(),
axisLabel: { color: '#8899aa', fontSize: 10 },
splitLine: { show: false }
},
yAxis: [
{
type: 'value', name: '°C',
nameTextStyle: { color: '#8899aa', fontSize: 10 },
axisLine: darkAxisLine(),
splitLine: darkSplitLine(),
axisLabel: { color: '#8899aa', fontSize: 10 }
},
{
type: 'value', name: '天',
nameTextStyle: { color: '#ffeb3b', fontSize: 10 },
axisLine: { lineStyle: { color: '#ffeb3b' } },
splitLine: { show: false },
axisLabel: { color: '#ffeb3b', fontSize: 10 }
}
],
series: [
{
name: '热浪天数', type: 'bar', yAxisIndex: 1,
data: annual.heatwave_days,
barWidth: '40%',
itemStyle: {
color: new echarts.graphic.LinearGradient(0, 0, 0, 1, [
{ offset: 0, color: '#ffeb3b' },
{ offset: 1, color: '#ff980088' }
])
}
},
{
name: '平均温度', type: 'line', yAxisIndex: 0,
data: annual.avg_temp,
smooth: true, symbol: 'circle', symbolSize: 5,
lineStyle: { color: '#5b9bd5', width: 2 },
itemStyle: { color: '#5b9bd5' }
},
{
name: '最高温度', type: 'line', yAxisIndex: 0,
data: annual.max_temp,
smooth: true, symbol: 'diamond', symbolSize: 6,
lineStyle: { color: '#f44336', width: 2 },
itemStyle: { color: '#f44336' }
}
]
};
charts.p6.setOption(opt, true);
}
// ===== 更新时间戳 =====
function updateTimestamp() {
var now = new Date();
var str = now.getFullYear() + '-' +
String(now.getMonth() + 1).padStart(2, '0') + '-' +
String(now.getDate()).padStart(2, '0') + ' ' +
String(now.getHours()).padStart(2, '0') + ':' +
String(now.getMinutes()).padStart(2, '0') + ':' +
String(now.getSeconds()).padStart(2, '0');
document.getElementById('update-time').textContent = '数据更新: ' + str;
}
// ===== 初始化所有图表 =====
async function initCharts() {
var data = await fetchAllData();
renderPanel1(data.history);
renderPanel2(data.predict);
renderPanel3(data.stats);
renderPanel4(data.predict);
renderPanel5();
renderPanel6(data.stats);
updateTimestamp();
}
// ===== 响应式处理 =====
function resizeAllCharts() {
Object.values(charts).forEach(function(c) {
if (c && !c.isDisposed()) {
c.resize();
}
});
}
// ===== 入口 =====
document.addEventListener('DOMContentLoaded', function() {
initCharts();
// 每30分钟自动刷新
setInterval(initCharts, 30 * 60 * 1000);
// 窗口缩放时重绘所有图表
window.addEventListener('resize', function() {
clearTimeout(window._resizeTimer);
window._resizeTimer = setTimeout(resizeAllCharts, 200);
});
});
</script>
</body>
</html>