fix: 统一 NPZ 命名格式,添加 NetCDF ZIP 解压工具和 h5netcdf/h5py 依赖
This commit is contained in:
@@ -10,6 +10,8 @@ dependencies = [
|
||||
"netcdf4>=1.6",
|
||||
"cdsapi>=0.7",
|
||||
"torch>=2.1",
|
||||
"torchvision>=0.19",
|
||||
"torchaudio>=2.5",
|
||||
"pytorch-lightning>=2.1",
|
||||
"xgboost>=2.0",
|
||||
"scikit-learn>=1.3",
|
||||
@@ -19,4 +21,16 @@ dependencies = [
|
||||
"jupyter>=1.0",
|
||||
"tqdm>=4.66",
|
||||
"scipy>=1.11",
|
||||
"h5netcdf>=1.8.1",
|
||||
"h5py>=3.16.0",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = { index = "pytorch-cu126" }
|
||||
torchvision = { index = "pytorch-cu126" }
|
||||
torchaudio = { index = "pytorch-cu126" }
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu126"
|
||||
url = "https://download.pytorch.org/whl/cu126"
|
||||
explicit = true
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
"""批量解压 CDS 下载的 ZIP 格式 NetCDF 文件"""
|
||||
import logging
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from src.utils.config import DATA_RAW, CITIES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_one(nc_path: Path) -> bool:
|
||||
"""解压单个 ZIP 伪装的 .nc 文件,原地替换为真实 NetCDF"""
|
||||
if nc_path.stat().st_size == 0:
|
||||
return False
|
||||
|
||||
with open(nc_path, "rb") as f:
|
||||
header = f.read(4)
|
||||
|
||||
if header[:2] == b"CDF": # 已经是真实 NetCDF
|
||||
return True
|
||||
if header[:2] != b"PK": # 不是 ZIP,跳过
|
||||
return False
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(nc_path) as zf:
|
||||
nc_names = [n for n in zf.namelist() if n.endswith(".nc")]
|
||||
if not nc_names:
|
||||
return False
|
||||
data = zf.read(nc_names[0])
|
||||
nc_path.write_bytes(data)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def extract_all():
|
||||
for city in CITIES:
|
||||
city_dir = Path(DATA_RAW) / "era5" / city
|
||||
files = sorted(city_dir.glob("*.nc"))
|
||||
done = 0
|
||||
for f in files:
|
||||
if extract_one(f):
|
||||
done += 1
|
||||
logger.info("%s: %d/%d 已解压", city, done, len(files))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
||||
extract_all()
|
||||
@@ -534,7 +534,7 @@ def preprocess_all() -> None:
|
||||
if not saved_feature_cols:
|
||||
saved_feature_cols = feature_cols
|
||||
|
||||
npz_path = DATA_PROCESSED / f"sequences_{city_key}.npz"
|
||||
npz_path = DATA_PROCESSED / f"{city_key}_sequences.npz"
|
||||
np.savez_compressed(
|
||||
npz_path,
|
||||
X=X,
|
||||
@@ -561,7 +561,7 @@ def preprocess_all() -> None:
|
||||
# 合并 NPZ
|
||||
all_X, all_y = [], []
|
||||
for city_key in CITIES:
|
||||
npz_path = DATA_PROCESSED / f"sequences_{city_key}.npz"
|
||||
npz_path = DATA_PROCESSED / f"{city_key}_sequences.npz"
|
||||
if npz_path.exists():
|
||||
data = np.load(npz_path, allow_pickle=True)
|
||||
all_X.append(data["X"])
|
||||
|
||||
Reference in New Issue
Block a user