fix: 统一 NPZ 命名格式,添加 NetCDF ZIP 解压工具和 h5netcdf/h5py 依赖
This commit is contained in:
@@ -10,6 +10,8 @@ dependencies = [
|
|||||||
"netcdf4>=1.6",
|
"netcdf4>=1.6",
|
||||||
"cdsapi>=0.7",
|
"cdsapi>=0.7",
|
||||||
"torch>=2.1",
|
"torch>=2.1",
|
||||||
|
"torchvision>=0.19",
|
||||||
|
"torchaudio>=2.5",
|
||||||
"pytorch-lightning>=2.1",
|
"pytorch-lightning>=2.1",
|
||||||
"xgboost>=2.0",
|
"xgboost>=2.0",
|
||||||
"scikit-learn>=1.3",
|
"scikit-learn>=1.3",
|
||||||
@@ -19,4 +21,16 @@ dependencies = [
|
|||||||
"jupyter>=1.0",
|
"jupyter>=1.0",
|
||||||
"tqdm>=4.66",
|
"tqdm>=4.66",
|
||||||
"scipy>=1.11",
|
"scipy>=1.11",
|
||||||
|
"h5netcdf>=1.8.1",
|
||||||
|
"h5py>=3.16.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.uv.sources]
|
||||||
|
torch = { index = "pytorch-cu126" }
|
||||||
|
torchvision = { index = "pytorch-cu126" }
|
||||||
|
torchaudio = { index = "pytorch-cu126" }
|
||||||
|
|
||||||
|
[[tool.uv.index]]
|
||||||
|
name = "pytorch-cu126"
|
||||||
|
url = "https://download.pytorch.org/whl/cu126"
|
||||||
|
explicit = true
|
||||||
|
|||||||
@@ -0,0 +1,50 @@
|
|||||||
|
"""批量解压 CDS 下载的 ZIP 格式 NetCDF 文件"""
|
||||||
|
import logging
|
||||||
|
import zipfile
|
||||||
|
from pathlib import Path
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
from src.utils.config import DATA_RAW, CITIES
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_one(nc_path: Path) -> bool:
    """Replace a ZIP archive saved under a ``.nc`` name with the NetCDF file inside it.

    The file is inspected by its leading magic bytes:

    * classic NetCDF (``CDF``) or NetCDF-4/HDF5 (``\\x89HDF``): nothing to do.
    * ZIP (``PK``): the first ``*.nc`` member is extracted and written over
      ``nc_path``.
    * anything else, or an empty file: left untouched.

    Args:
        nc_path: Path of the candidate file. It must exist; ``stat`` and
            ``open`` errors propagate to the caller.

    Returns:
        True if ``nc_path`` is a real NetCDF file when the call returns (it
        already was one, or it was extracted successfully). False if the file
        is empty, is neither NetCDF nor ZIP, holds no ``.nc`` member, or
        extraction failed.
    """
    if nc_path.stat().st_size == 0:
        return False

    with open(nc_path, "rb") as f:
        header = f.read(4)

    # Classic NetCDF starts with b"CDF" plus a version byte; NetCDF-4 files
    # are HDF5 containers and start with b"\x89HDF". Both are already usable.
    if header[:3] == b"CDF" or header == b"\x89HDF":
        return True
    if header[:2] != b"PK":  # not a ZIP archive either, so skip it
        return False

    # Stage the payload next to the target so a failed write can never
    # destroy the only copy of the download. The ".tmp" suffix keeps the
    # staging file out of "*.nc" globs.
    tmp_path = nc_path.with_name(nc_path.name + ".tmp")
    try:
        with zipfile.ZipFile(nc_path) as zf:
            nc_names = [n for n in zf.namelist() if n.endswith(".nc")]
            if not nc_names:
                return False
            data = zf.read(nc_names[0])
        # The archive handle is closed here, so the source can be replaced.
        tmp_path.write_bytes(data)
        tmp_path.replace(nc_path)
        return True
    except Exception:
        # Best effort: one bad file must not abort a batch run, but the
        # failure is recorded instead of being silently dropped.
        logging.getLogger(__name__).exception("解压失败: %s", nc_path)
        tmp_path.unlink(missing_ok=True)
        return False
|
||||||
|
|
||||||
|
|
||||||
|
def extract_all():
    """Run ``extract_one`` on every ``*.nc`` file of every configured city.

    For each key in ``CITIES`` the files under ``DATA_RAW/era5/<city>`` are
    visited in sorted order, and one INFO line reports how many of them
    ``extract_one`` returned True for, out of the total found.
    """
    for city in CITIES:
        nc_files = sorted((Path(DATA_RAW) / "era5" / city).glob("*.nc"))
        # Tally the files for which extract_one reported success.
        ok_count = sum(1 for nc_file in nc_files if extract_one(nc_file))
        logger.info("%s: %d/%d 已解压", city, ok_count, len(nc_files))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: set up INFO-level logging, then process all cities.
    _log_format = "%(asctime)s [%(levelname)s] %(message)s"
    logging.basicConfig(format=_log_format, level=logging.INFO)
    extract_all()
|
||||||
@@ -534,7 +534,7 @@ def preprocess_all() -> None:
|
|||||||
if not saved_feature_cols:
|
if not saved_feature_cols:
|
||||||
saved_feature_cols = feature_cols
|
saved_feature_cols = feature_cols
|
||||||
|
|
||||||
npz_path = DATA_PROCESSED / f"sequences_{city_key}.npz"
|
npz_path = DATA_PROCESSED / f"{city_key}_sequences.npz"
|
||||||
np.savez_compressed(
|
np.savez_compressed(
|
||||||
npz_path,
|
npz_path,
|
||||||
X=X,
|
X=X,
|
||||||
@@ -561,7 +561,7 @@ def preprocess_all() -> None:
|
|||||||
# 合并 NPZ
|
# 合并 NPZ
|
||||||
all_X, all_y = [], []
|
all_X, all_y = [], []
|
||||||
for city_key in CITIES:
|
for city_key in CITIES:
|
||||||
npz_path = DATA_PROCESSED / f"sequences_{city_key}.npz"
|
npz_path = DATA_PROCESSED / f"{city_key}_sequences.npz"
|
||||||
if npz_path.exists():
|
if npz_path.exists():
|
||||||
data = np.load(npz_path, allow_pickle=True)
|
data = np.load(npz_path, allow_pickle=True)
|
||||||
all_X.append(data["X"])
|
all_X.append(data["X"])
|
||||||
|
|||||||
Reference in New Issue
Block a user