diff --git a/pyproject.toml b/pyproject.toml
index 336cb0f..a997dcc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,6 +10,8 @@ dependencies = [
     "netcdf4>=1.6",
     "cdsapi>=0.7",
     "torch>=2.1",
+    "torchvision>=0.19",
+    "torchaudio>=2.5",
     "pytorch-lightning>=2.1",
     "xgboost>=2.0",
     "scikit-learn>=1.3",
@@ -19,4 +21,16 @@ dependencies = [
     "jupyter>=1.0",
     "tqdm>=4.66",
     "scipy>=1.11",
+    "h5netcdf>=1.8.1",
+    "h5py>=3.16.0",
 ]
+
+[tool.uv.sources]
+torch = { index = "pytorch-cu126" }
+torchvision = { index = "pytorch-cu126" }
+torchaudio = { index = "pytorch-cu126" }
+
+[[tool.uv.index]]
+name = "pytorch-cu126"
+url = "https://download.pytorch.org/whl/cu126"
+explicit = true
diff --git a/src/data/extract_zips.py b/src/data/extract_zips.py
new file mode 100644
index 0000000..03698e2
--- /dev/null
+++ b/src/data/extract_zips.py
@@ -0,0 +1,73 @@
+"""Batch-extract NetCDF files that CDS delivered wrapped in ZIP archives."""
+import logging
+import shutil
+import zipfile
+from pathlib import Path
+from concurrent.futures import ThreadPoolExecutor
+
+from src.utils.config import DATA_RAW, CITIES
+
+logger = logging.getLogger(__name__)
+
+# Leading bytes of files that already are real NetCDF: classic NetCDF-3
+# starts with b"CDF"; NetCDF-4 is an HDF5 container with its own signature.
+NETCDF_MAGICS = (b"CDF", b"\x89HDF")
+ZIP_MAGIC = b"PK"
+
+
+def extract_one(nc_path: Path) -> bool:
+    """Replace a ZIP archive disguised as ``.nc`` with its NetCDF member.
+
+    Returns True when ``nc_path`` holds real NetCDF data afterwards: it
+    either already did, or the first ``.nc`` member of the archive was
+    extracted over it. Returns False for empty files, unrecognised formats,
+    archives without a ``.nc`` member, and extraction failures.
+    """
+    if nc_path.stat().st_size == 0:
+        return False
+
+    with open(nc_path, "rb") as f:
+        header = f.read(4)
+
+    # Prefix match, because the accepted magic numbers differ in length.
+    if header.startswith(NETCDF_MAGICS):  # already a real NetCDF file
+        return True
+    if not header.startswith(ZIP_MAGIC):  # not a ZIP archive, skip
+        return False
+
+    # Extract next to the target and swap it in with a rename, so an
+    # interrupted run never leaves a truncated file behind.
+    tmp_path = nc_path.with_name(nc_path.name + ".extracting")
+    try:
+        with zipfile.ZipFile(nc_path) as zf:
+            nc_names = [n for n in zf.namelist() if n.endswith(".nc")]
+            if not nc_names:
+                return False
+            with zf.open(nc_names[0]) as src, open(tmp_path, "wb") as dst:
+                shutil.copyfileobj(src, dst)
+        tmp_path.replace(nc_path)
+        return True
+    except Exception:
+        # Best effort: record why this file failed and move on.
+        logger.exception("Failed to extract %s", nc_path)
+        return False
+    finally:
+        tmp_path.unlink(missing_ok=True)
+
+
+def extract_all():
+    """Run ``extract_one`` on every ``*.nc`` file of every configured city.
+
+    Logs, per city, how many of the files found under
+    ``DATA_RAW/era5/<city>`` hold real NetCDF data afterwards.
+    """
+    for city in CITIES:
+        city_dir = Path(DATA_RAW) / "era5" / city
+        files = sorted(city_dir.glob("*.nc"))
+        done = sum(1 for f in files if extract_one(f))
+        logger.info("%s: %d/%d 已解压", city, done, len(files))
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+    extract_all()
diff --git a/src/data/preprocess.py b/src/data/preprocess.py
index 683bc65..01b8e76 100644
--- a/src/data/preprocess.py
+++ b/src/data/preprocess.py
@@ -534,7 +534,7 @@ def preprocess_all() -> None:
         if not saved_feature_cols:
             saved_feature_cols = feature_cols
 
-        npz_path = DATA_PROCESSED / f"sequences_{city_key}.npz"
+        npz_path = DATA_PROCESSED / f"{city_key}_sequences.npz"
         np.savez_compressed(
             npz_path,
             X=X,
@@ -561,7 +561,7 @@ def preprocess_all() -> None:
     # 合并 NPZ
     all_X, all_y = [], []
     for city_key in CITIES:
-        npz_path = DATA_PROCESSED / f"sequences_{city_key}.npz"
+        npz_path = DATA_PROCESSED / f"{city_key}_sequences.npz"
         if npz_path.exists():
             data = np.load(npz_path, allow_pickle=True)
             all_X.append(data["X"])