From a2c9ba4863f164c115eddf6646c3383f42373a9e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com>
Date: Thu, 28 May 2026 08:26:18 +0800
Subject: [PATCH] =?UTF-8?q?fix:=20=E7=BB=9F=E4=B8=80=20NPZ=20=E5=91=BD?=
 =?UTF-8?q?=E5=90=8D=E6=A0=BC=E5=BC=8F=EF=BC=8C=E6=B7=BB=E5=8A=A0=20NetCDF?=
 =?UTF-8?q?=20ZIP=20=E8=A7=A3=E5=8E=8B=E5=B7=A5=E5=85=B7=E5=92=8C=20h5netc?=
 =?UTF-8?q?df/h5py=20=E4=BE=9D=E8=B5=96?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
NOTE(review): reconstructed from a whitespace-collapsed copy of this patch.
Context-line indentation in pyproject.toml and src/data/preprocess.py is a
best guess; confirm against the tree or apply with --ignore-whitespace.
The extract_zips.py blob id below predates the review fixes to that file.

 pyproject.toml           | 14 +++++++
 src/data/extract_zips.py | 79 +++++++++++++++++++++++++++++++++++++++
 src/data/preprocess.py   |  4 +--
 3 files changed, 95 insertions(+), 2 deletions(-)
 create mode 100644 src/data/extract_zips.py

diff --git a/pyproject.toml b/pyproject.toml
index 336cb0f..a997dcc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,6 +10,8 @@ dependencies = [
     "netcdf4>=1.6",
     "cdsapi>=0.7",
     "torch>=2.1",
+    "torchvision>=0.19",
+    "torchaudio>=2.5",
     "pytorch-lightning>=2.1",
     "xgboost>=2.0",
     "scikit-learn>=1.3",
@@ -19,4 +21,16 @@ dependencies = [
     "jupyter>=1.0",
     "tqdm>=4.66",
     "scipy>=1.11",
+    "h5netcdf>=1.8.1",
+    "h5py>=3.16.0",
 ]
+
+[tool.uv.sources]
+torch = { index = "pytorch-cu126" }
+torchvision = { index = "pytorch-cu126" }
+torchaudio = { index = "pytorch-cu126" }
+
+[[tool.uv.index]]
+name = "pytorch-cu126"
+url = "https://download.pytorch.org/whl/cu126"
+explicit = true
diff --git a/src/data/extract_zips.py b/src/data/extract_zips.py
new file mode 100644
index 0000000..03698e2
--- /dev/null
+++ b/src/data/extract_zips.py
@@ -0,0 +1,79 @@
+"""Batch-extract NetCDF files that CDS delivered wrapped in ZIP archives."""
+import logging
+import zipfile
+from pathlib import Path
+from concurrent.futures import ThreadPoolExecutor
+
+from src.utils.config import DATA_RAW, CITIES
+
+logger = logging.getLogger(__name__)
+
+# Leading bytes of a real NetCDF file: classic NetCDF-3 starts with "CDF",
+# NetCDF-4 is an HDF5 container and starts with "\x89HDF".
+_NETCDF_MAGICS = (b"CDF", b"\x89HDF")
+_ZIP_MAGIC = b"PK"
+
+
+def extract_one(nc_path: Path) -> bool:
+    """Replace a ZIP archive posing as a ``.nc`` file with the NetCDF inside.
+
+    Args:
+        nc_path: File to inspect and, if it is a ZIP archive, overwrite
+            with the first ``.nc`` member of that archive.
+
+    Returns:
+        True if the file already is NetCDF or was extracted successfully;
+        False if it is empty, is neither NetCDF nor ZIP, holds no ``.nc``
+        member, or extraction failed.
+    """
+    if nc_path.stat().st_size == 0:
+        return False
+
+    with open(nc_path, "rb") as f:
+        header = f.read(4)
+
+    # startswith() checks each signature at its own length (3 or 4 bytes).
+    if header.startswith(_NETCDF_MAGICS):
+        return True
+    if not header.startswith(_ZIP_MAGIC):
+        return False
+
+    try:
+        with zipfile.ZipFile(nc_path) as zf:
+            nc_names = [n for n in zf.namelist() if n.endswith(".nc")]
+            if not nc_names:
+                return False
+            if len(nc_names) > 1:
+                logger.warning(
+                    "%s holds %d .nc members; keeping only %s",
+                    nc_path, len(nc_names), nc_names[0],
+                )
+            data = zf.read(nc_names[0])
+        # Write beside the target and rename over it only after the archive
+        # is closed, so a failed write cannot destroy the original download.
+        tmp_path = nc_path.with_name(nc_path.name + ".tmp")
+        tmp_path.write_bytes(data)
+        tmp_path.replace(nc_path)
+        return True
+    except Exception:
+        # Best effort: report the failure and let the batch continue.
+        logger.exception("Failed to extract %s", nc_path)
+        return False
+
+
+def extract_all() -> None:
+    """Run ``extract_one`` on every ``*.nc`` file of every configured city.
+
+    Logs, per city, how many files ``extract_one`` accepted out of the
+    total found under ``DATA_RAW/era5/<city>``.
+    """
+    for city in CITIES:
+        city_dir = Path(DATA_RAW) / "era5" / city
+        files = sorted(city_dir.glob("*.nc"))
+        done = sum(1 for f in files if extract_one(f))
+        logger.info("%s: %d/%d 已解压", city, done, len(files))
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+    extract_all()
diff --git a/src/data/preprocess.py b/src/data/preprocess.py
index 683bc65..01b8e76 100644
--- a/src/data/preprocess.py
+++ b/src/data/preprocess.py
@@ -534,7 +534,7 @@ def preprocess_all() -> None:
         if not saved_feature_cols:
             saved_feature_cols = feature_cols

-        npz_path = DATA_PROCESSED / f"sequences_{city_key}.npz"
+        npz_path = DATA_PROCESSED / f"{city_key}_sequences.npz"
         np.savez_compressed(
             npz_path,
             X=X,
@@ -561,7 +561,7 @@ def preprocess_all() -> None:
     # 合并 NPZ
     all_X, all_y = [], []
     for city_key in CITIES:
-        npz_path = DATA_PROCESSED / f"sequences_{city_key}.npz"
+        npz_path = DATA_PROCESSED / f"{city_key}_sequences.npz"
         if npz_path.exists():
             data = np.load(npz_path, allow_pickle=True)
             all_X.append(data["X"])