fix: 统一 NPZ 命名格式,添加 NetCDF ZIP 解压工具和 h5netcdf/h5py 依赖

This commit is contained in:
2026-05-28 08:26:18 +08:00
parent 07468266b4
commit a2c9ba4863
3 changed files with 66 additions and 2 deletions
+14
View File
@@ -10,6 +10,8 @@ dependencies = [
"netcdf4>=1.6", "netcdf4>=1.6",
"cdsapi>=0.7", "cdsapi>=0.7",
"torch>=2.1", "torch>=2.1",
"torchvision>=0.19",
"torchaudio>=2.5",
"pytorch-lightning>=2.1", "pytorch-lightning>=2.1",
"xgboost>=2.0", "xgboost>=2.0",
"scikit-learn>=1.3", "scikit-learn>=1.3",
@@ -19,4 +21,16 @@ dependencies = [
"jupyter>=1.0", "jupyter>=1.0",
"tqdm>=4.66", "tqdm>=4.66",
"scipy>=1.11", "scipy>=1.11",
"h5netcdf>=1.8.1",
"h5py>=3.16.0",
] ]
[tool.uv.sources]
torch = { index = "pytorch-cu126" }
torchvision = { index = "pytorch-cu126" }
torchaudio = { index = "pytorch-cu126" }
[[tool.uv.index]]
name = "pytorch-cu126"
url = "https://download.pytorch.org/whl/cu126"
explicit = true
+50
View File
@@ -0,0 +1,50 @@
"""批量解压 CDS 下载的 ZIP 格式 NetCDF 文件"""
import logging
import zipfile
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
from src.utils.config import DATA_RAW, CITIES
logger = logging.getLogger(__name__)
def extract_one(nc_path: Path) -> bool:
    """Replace a ZIP archive stored under a ``.nc`` name with the NetCDF inside it.

    The file type is decided from its first four bytes, not from its name:

    * classic NetCDF (``CDF`` magic) or NetCDF-4/HDF5 (``\\x89HDF`` magic):
      the file is already usable and is left untouched;
    * ZIP (``PK`` magic): the first member whose name ends in ``.nc`` is
      extracted and written over ``nc_path``;
    * anything else: the file is left untouched.

    Args:
        nc_path: Path of the file to inspect and, if needed, unpack in place.

    Returns:
        True if ``nc_path`` holds NetCDF data when the function returns
        (either it already did, or extraction succeeded). False if the file
        is empty, is neither NetCDF nor ZIP, contains no ``.nc`` member, or
        extraction failed.

    Raises:
        OSError: If ``nc_path`` cannot be stat'ed or opened for reading.
    """
    if nc_path.stat().st_size == 0:
        return False

    with open(nc_path, "rb") as f:
        header = f.read(4)

    # Classic NetCDF starts with the 3-byte magic b"CDF"; NetCDF-4 files are
    # HDF5 containers whose signature starts with b"\x89HDF". The slice must
    # be 3 bytes long: a 2-byte slice can never equal a 3-byte literal.
    if header[:3] == b"CDF" or header == b"\x89HDF":
        return True
    if header[:2] != b"PK":  # not a ZIP archive, nothing to unpack
        return False

    # Same logger object the module-level ``logger`` refers to.
    log = logging.getLogger(__name__)

    # Sibling temp file: its ".tmp" suffix keeps it out of "*.nc" globs, and
    # being in the same directory lets the final rename be a single step.
    tmp_path = nc_path.with_name(nc_path.name + ".tmp")
    try:
        with zipfile.ZipFile(nc_path) as zf:
            nc_names = [n for n in zf.namelist() if n.endswith(".nc")]
            if not nc_names:
                return False
            if len(nc_names) > 1:
                # Only the first member is kept; make the data loss visible.
                log.warning(
                    "%s contains %d .nc members; keeping only %s",
                    nc_path, len(nc_names), nc_names[0],
                )
            data = zf.read(nc_names[0])
        # Write next to the original and rename over it, so an interrupted
        # write cannot leave a truncated file in place of the only copy.
        tmp_path.write_bytes(data)
        tmp_path.replace(nc_path)
        return True
    except Exception:
        # Best-effort by design: report the failure and let the caller move
        # on to the next file instead of aborting the whole batch.
        log.exception("failed to extract %s", nc_path)
        try:
            tmp_path.unlink()
        except OSError:
            pass  # temp file was never created or is already gone
        return False
def extract_all() -> None:
    """Run ``extract_one`` over every ``*.nc`` file of every configured city.

    For each key in ``CITIES`` the directory ``DATA_RAW/era5/<city>`` is
    scanned (non-recursively) for ``*.nc`` files in sorted order, and one INFO
    line is logged per city with the number of files for which
    ``extract_one`` returned True out of the total found.
    """
    for city in CITIES:
        nc_files = sorted((Path(DATA_RAW) / "era5" / city).glob("*.nc"))
        # Count the files extract_one reports as successful.
        succeeded = sum(1 for nc_file in nc_files if extract_one(nc_file))
        logger.info("%s: %d/%d 已解压", city, succeeded, len(nc_files))
if __name__ == "__main__":
    # Script entry point: send INFO-level records to the root handler with a
    # timestamped format, then process every configured city.
    log_format = "%(asctime)s [%(levelname)s] %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)
    extract_all()
+2 -2
View File
@@ -534,7 +534,7 @@ def preprocess_all() -> None:
if not saved_feature_cols: if not saved_feature_cols:
saved_feature_cols = feature_cols saved_feature_cols = feature_cols
npz_path = DATA_PROCESSED / f"sequences_{city_key}.npz" npz_path = DATA_PROCESSED / f"{city_key}_sequences.npz"
np.savez_compressed( np.savez_compressed(
npz_path, npz_path,
X=X, X=X,
@@ -561,7 +561,7 @@ def preprocess_all() -> None:
# 合并 NPZ # 合并 NPZ
all_X, all_y = [], [] all_X, all_y = [], []
for city_key in CITIES: for city_key in CITIES:
npz_path = DATA_PROCESSED / f"sequences_{city_key}.npz" npz_path = DATA_PROCESSED / f"{city_key}_sequences.npz"
if npz_path.exists(): if npz_path.exists():
data = np.load(npz_path, allow_pickle=True) data = np.load(npz_path, allow_pickle=True)
all_X.append(data["X"]) all_X.append(data["X"])