chore: LSTM 训练调优总结 — 类别加权/采样器/权重多项尝试,XGBoost 仍为最佳
This commit is contained in:
@@ -0,0 +1,230 @@
|
||||
# 管线执行计划
|
||||
|
||||
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
||||
|
||||
**Goal:** 从 ERA5 NetCDF 原始数据运行完整管线到 LaTeX 论文
|
||||
|
||||
**Architecture:** 5 阶段流水线 — 预处理(NPZ) → 训练(模型) → 评估(图表) → Web(验证) → 论文(LaTeX)
|
||||
|
||||
**Tech Stack:** PyTorch 2.12+cu126, xarray+h5netcdf, XGBoost, Flask+ECharts, XeLaTeX+ctexbook
|
||||
|
||||
**前置项已就绪:**
|
||||
- ERA5 数据: 焦作 180 + 郑州 180 (NetCDF4, 已解压)
|
||||
- GPU: RTX 4060 Laptop (8GB), CUDA 12.6
|
||||
- h5netcdf/h5py: 已安装
|
||||
- 外部数据: mortality_population.csv, exposure_response.csv
|
||||
|
||||
---
|
||||
|
||||
### Task 1: 修复文件命名一致性
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/data/preprocess.py:537`
|
||||
|
||||
preprocess 保存 `sequences_{city}.npz`,train 加载 `{city}_sequences.npz`,统一为 `{city}_sequences.npz`。
|
||||
|
||||
- [ ] **Step 1: 修改 preprocess 的命名**
|
||||
|
||||
```python
|
||||
# 第537行: sequences_{city_key}.npz → {city_key}_sequences.npz
|
||||
npz_path = DATA_PROCESSED / f"{city_key}_sequences.npz"
|
||||
```
|
||||
|
||||
所有 `sequences_` 开头的引用都要改(第537、564、573行):
|
||||
|
||||
```python
|
||||
# 第564行
|
||||
npz_path = DATA_PROCESSED / f"{city_key}_sequences.npz"
|
||||
# 第573行
|
||||
combined_npz = DATA_PROCESSED / "sequences_combined.npz" # 合并文件保持原名
|
||||
```
|
||||
|
||||
- [ ] **Step 2: 提交**
|
||||
|
||||
```bash
|
||||
git add src/data/preprocess.py
|
||||
git commit -m "fix: 统一 NPZ 命名格式为 {city}_sequences.npz"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 2: 运行预处理管线
|
||||
|
||||
**Files:** `src/data/preprocess.py` (无需修改,已改命名)
|
||||
|
||||
- [ ] **Step 1: 清理旧数据并运行预处理**
|
||||
|
||||
```bash
|
||||
cd D:/Code/doing_exercises/programs/银发群体高温多时间尺度预警和服务优化可视化研究
|
||||
rm -f data/processed/*.npz data/processed/*.csv
|
||||
uv run python -m src.data.preprocess
|
||||
```
|
||||
|
||||
**预期输出:**
|
||||
- 加载焦作 180 NC → 日聚合 → 特征工程 → 序列 14×N_feat
|
||||
- 加载郑州 180 NC → 同上
|
||||
- 保存: `jiaozuo_sequences.npz`, `zhengzhou_sequences.npz`, `sequences_combined.npz`, `features_combined.csv`
|
||||
- 日志显示每个城市的 X/y shape 和标签分布
|
||||
|
||||
- [ ] **Step 2: 验证产出**
|
||||
|
||||
```bash
|
||||
uv run python -c "
|
||||
import numpy as np
|
||||
for f in ['jiaozuo_sequences.npz', 'zhengzhou_sequences.npz', 'sequences_combined.npz']:
|
||||
d = np.load(f'data/processed/{f}')
|
||||
print(f'{f}: X{d[\"X\"].shape} y{d[\"y\"].shape}')
|
||||
print(f' y unique counts: {[len(set(d[\"y\"][:,i])) for i in range(3)]}')
|
||||
"
|
||||
```
|
||||
|
||||
**预期:** 两个城市共约 10000+ 样本,y 三列各有 4 类
|
||||
|
||||
- [ ] **Step 3: 提交**
|
||||
|
||||
```bash
|
||||
git add data/processed/
|
||||
git commit -m "feat: ERA5 预处理完成,生成序列 NPZ 和特征 CSV"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 3: 训练 LSTM-Attention 模型
|
||||
|
||||
**Files:** `src/models/train.py` (无需修改)
|
||||
|
||||
- [ ] **Step 1: 运行训练**
|
||||
|
||||
```bash
|
||||
cd D:/Code/doing_exercises/programs/银发群体高温多时间尺度预警和服务优化可视化研究
|
||||
uv run python -m src.models.train
|
||||
```
|
||||
|
||||
**预期输出:**
|
||||
- "使用设备: cuda"
|
||||
- 数据加载: X (N, 14, F), y (N, 3)
|
||||
- 划分: 训练 ~70%, 验证 ~15%, 测试 ~15%
|
||||
- 每 epoch 打印 loss/acc/f1
|
||||
- 早停后保存 `outputs/models/best_model.pt`
|
||||
|
||||
- [ ] **Step 2: 验证产出**
|
||||
|
||||
```bash
|
||||
ls -lh outputs/models/best_model.pt
|
||||
ls -lh outputs/logs/training_history.json
|
||||
```
|
||||
|
||||
- [ ] **Step 3: 提交**
|
||||
|
||||
```bash
|
||||
git add outputs/models/best_model.pt outputs/logs/training_history.json
|
||||
git commit -m "feat: LSTM-Attention 模型训练完成"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 4: 训练 XGBoost 基线并评估
|
||||
|
||||
**Files:** `src/models/evaluate.py` (无需修改)
|
||||
|
||||
- [ ] **Step 1: 运行评估**
|
||||
|
||||
```bash
|
||||
cd D:/Code/doing_exercises/programs/银发群体高温多时间尺度预警和服务优化可视化研究
|
||||
uv run python -m src.models.evaluate
|
||||
```
|
||||
|
||||
**预期输出:**
|
||||
- 混淆矩阵 × 3 时间尺度 (LSTM + XGBoost 对比)
|
||||
- F1/Accuracy 对比柱状图
|
||||
- 保存至 `outputs/figures/`
|
||||
|
||||
- [ ] **Step 2: 验证产出**
|
||||
|
||||
```bash
|
||||
ls -lh outputs/figures/confusion_matrix.png outputs/figures/model_comparison.png
|
||||
```
|
||||
|
||||
- [ ] **Step 3: 提交**
|
||||
|
||||
```bash
|
||||
git add outputs/figures/
|
||||
git commit -m "feat: 模型评估完成 — LSTM vs XGBoost 对比图表"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 5: 启动 Web 大屏并验证
|
||||
|
||||
**Files:** `src/web/app.py`, `src/web/static/index.html` (无需修改)
|
||||
|
||||
- [ ] **Step 1: 启动 Flask**
|
||||
|
||||
```bash
|
||||
cd D:/Code/doing_exercises/programs/银发群体高温多时间尺度预警和服务优化可视化研究
|
||||
uv run python -m src.web.app
|
||||
```
|
||||
|
||||
- [ ] **Step 2: 浏览器验证**
|
||||
|
||||
打开 http://localhost:5000,检查:
|
||||
- [ ] 6 面板均渲染(温度趋势/风险展示/人口饼图/时间柱状/暴露反应/历史回顾)
|
||||
- [ ] API `/api/predict` 返回正确 JSON
|
||||
- [ ] API `/api/history` 返回 90 天数据
|
||||
- [ ] API `/api/stats` 返回统计摘要
|
||||
|
||||
- [ ] **Step 3: 截图保存**
|
||||
|
||||
```bash
|
||||
# 用 Playwright 截取大屏截图
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 6: 编译 LaTeX 论文
|
||||
|
||||
**Files:** `thesis/main.tex`, `thesis/chapters/*.tex`
|
||||
|
||||
- [ ] **Step 1: 填充论文内容**
|
||||
|
||||
更新以下章节:
|
||||
- `ch2-data-methods.tex`: 填入 ERA5 变量表、NOAA 体感温度公式、模型架构描述
|
||||
- `ch3-model-design.tex`: LSTM-Attention 架构详述 (983K 参数)
|
||||
- `ch4-experiments.tex`: 插入 `outputs/figures/` 中的评估图表
|
||||
- `ch5-visualization.tex`: Web 大屏 6 面板截图与架构说明
|
||||
|
||||
- [ ] **Step 2: 编译论文**
|
||||
|
||||
```bash
|
||||
cd thesis
|
||||
make # xelatex + biber + xelatex + xelatex
|
||||
```
|
||||
|
||||
- [ ] **Step 3: 验证 PDF**
|
||||
|
||||
```bash
|
||||
ls -lh thesis/main.pdf
|
||||
```
|
||||
|
||||
用 PDF 阅读器打开,检查: 中文渲染、图表清晰度、引用编号、页眉页脚
|
||||
|
||||
- [ ] **Step 4: 提交**
|
||||
|
||||
```bash
|
||||
git add thesis/ thesis/main.pdf
|
||||
git commit -m "feat: LaTeX 论文编译完成"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 7: 最终推送
|
||||
|
||||
- [ ] **Step 1: 推送代码**
|
||||
|
||||
```bash
|
||||
git push origin main
|
||||
```
|
||||
|
||||
- [ ] **Step 2: 推送模型和图表 (如需要)**
|
||||
|
||||
较大文件可考虑 git-lfs 或单独存放
|
||||
@@ -0,0 +1,88 @@
|
||||
# 数据处理→模型训练→论文 全流程实现设计
|
||||
|
||||
**日期**: 2026-05-28
|
||||
**状态**: 已批准
|
||||
|
||||
## 概述
|
||||
|
||||
ERA5 数据下载完毕后(焦作 180 + 郑州 180),执行从数据预处理到 LaTeX 论文填充的完整管线。
|
||||
|
||||
## 阶段 1:预处理
|
||||
|
||||
**入口**: `python -m src.data.preprocess`(无参数,遍历 CITIES)
|
||||
|
||||
**管线**:
|
||||
1. `load_era5_city` — 拼接 180 个 NetCDF → xarray Dataset
|
||||
2. `compute_daily_aggregates` — 6h→日平均,K→°C,列重命名
|
||||
3. `compute_relative_humidity` — Magnus 公式
|
||||
4. `compute_heat_index` — NOAA Rothfusz 公式
|
||||
5. `build_features` — 滚动均值(3/7/14天)、滞后(1/2/3/7天)、热浪检测(≥3天)、季节 sin/cos 编码
|
||||
6. `compute_risk_labels` — 基于体感温度阈值的 0-3 风险标签
|
||||
7. `create_sequences` — LOOKBACK=14, 3 预测窗口(3/7/30天) → 单次滑动窗口
|
||||
8. `preprocess_all` — 遍历城市,合并保存
|
||||
|
||||
**产出**: `data/processed/sequences.npz`
|
||||
- X: (N, 14, input_dim) float32
|
||||
- y_short: (N,) int64 (4类)
|
||||
- y_medium: (N,) int64
|
||||
- y_long: (N,) int64
|
||||
|
||||
## 阶段 2:模型训练
|
||||
|
||||
**入口**: `python -m src.models.train`
|
||||
|
||||
**LSTM-Attention**:
|
||||
- 架构: Input Proj → 2-layer BiLSTM(128) → 4-head MHA → 3 独立 head
|
||||
- 损失: Focal Loss (alpha=0.25, gamma=2.0)
|
||||
- 优化器: AdamW (lr=1e-3)
|
||||
- 调度器: ReduceLROnPlateau (patience=8)
|
||||
- 早停: 15 epochs
|
||||
- 设备: CUDA (RTX 4060)
|
||||
|
||||
**XGBoost 基线**:
|
||||
- 输入: X.reshape(N, 14*D) 展平
|
||||
- 3 个独立 XGBClassifier (n_estimators=200, max_depth=6, lr=0.05)
|
||||
|
||||
**分割**: 时间顺序 70/15/15(约 2010-2020 / 2021-2022 / 2023-2024)
|
||||
|
||||
**产出**:
|
||||
- `outputs/models/best_model.pt`
|
||||
- `outputs/logs/training_history.json`
|
||||
- `outputs/models/test_predictions.npz`
|
||||
|
||||
## 阶段 3:评估
|
||||
|
||||
**入口**: `python -m src.models.evaluate`
|
||||
|
||||
**产出图表**(中文标注, 300dpi):
|
||||
- `outputs/figures/confusion_matrix.png` — 3×2 子图(LSTM/XGBoost × 3时间尺度)
|
||||
- `outputs/figures/model_comparison.png` — F1 + Accuracy 柱状对比图
|
||||
- `outputs/figures/training_curves.png` — loss/acc 曲线
|
||||
|
||||
## 阶段 4:Web 大屏
|
||||
|
||||
**入口**: `python -m src.web.app`
|
||||
|
||||
**验证项**:
|
||||
- 6 面板正常渲染(温度趋势/风险展示/人口饼图/时间柱状/暴露-反应/历史回顾)
|
||||
- 4 API 端点返回正确格式
|
||||
- 模型预测在 Web 中正常展示(或 fallback 降级)
|
||||
|
||||
## 阶段 5:LaTeX 论文
|
||||
|
||||
**入口**: `cd thesis && make`(xelatex + biber)
|
||||
|
||||
**填充内容**:
|
||||
- 第 1 章:研究背景(已有框架)
|
||||
- 第 2 章:数据与方法 → 填入 ERA5 变量表、NOAA 公式、模型架构
|
||||
- 第 3 章:模型设计 → LSTM-Attention + XGBoost 架构图
|
||||
- 第 4 章:实验与结果 → 插入评估图表、分类报告
|
||||
- 第 5 章:可视化系统 → Web 大屏截图
|
||||
- 第 6-7 章:讨论与结论
|
||||
|
||||
## 依赖与前置条件
|
||||
|
||||
- Python 3.13 + CUDA PyTorch 2.12.0+cu126
|
||||
- GPU: RTX 4060 Laptop (8GB VRAM)
|
||||
- ERA5 数据: 焦作 180 + 郑州 180 NetCDF
|
||||
- 外部数据: mortality_population.csv, exposure_response.csv
|
||||
@@ -39,18 +39,25 @@ def download_one_month(city: str, year: int, month: int) -> bool:
|
||||
return True # 已存在,跳过
|
||||
|
||||
request = build_request(city, year, month)
|
||||
for attempt in range(1, 4):
|
||||
for attempt in range(1, 6):
|
||||
try:
|
||||
logger.info("请求 %s %d-%02d (第 %d/5 次)", city, year, month, attempt)
|
||||
client.retrieve("reanalysis-era5-land", request, str(out_path))
|
||||
return True
|
||||
except Exception:
|
||||
if attempt < 3:
|
||||
time.sleep(30)
|
||||
if out_path.exists() and out_path.stat().st_size > 0:
|
||||
return True
|
||||
else:
|
||||
logger.warning("文件为空 %s %d-%02d,重试", city, year, month)
|
||||
except Exception as e:
|
||||
delay = 60 * attempt
|
||||
logger.warning("失败 %s %d-%02d (第 %d/5 次): %s,%ds 后重试",
|
||||
city, year, month, attempt, str(e)[:100], delay)
|
||||
if attempt < 5:
|
||||
time.sleep(delay)
|
||||
return False
|
||||
|
||||
|
||||
def download_city(city: str, start_year: int = ERA5_START_YEAR,
|
||||
end_year: int = ERA5_END_YEAR, max_workers: int = 3):
|
||||
end_year: int = ERA5_END_YEAR, max_workers: int = 1):
|
||||
"""并行下载(3线程),兼顾速度和 CDS 限流"""
|
||||
name = CITIES[city]["name"]
|
||||
tasks = [(city, y, m) for y in range(start_year, end_year + 1) for m in range(1, 13)]
|
||||
|
||||
Reference in New Issue
Block a user