From 106ce423d3881b1040e6e5614c846c5dfdd97457 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com> Date: Thu, 28 May 2026 10:24:33 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20WeightedRandomSamp?= =?UTF-8?q?ler=20=E8=A7=A3=E5=86=B3=E7=B1=BB=E5=88=AB=E4=B8=8D=E5=B9=B3?= =?UTF-8?q?=E8=A1=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/models/train.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/models/train.py b/src/models/train.py index c4a1731..68793b7 100644 --- a/src/models/train.py +++ b/src/models/train.py @@ -14,7 +14,7 @@ import torch.nn.functional as F from sklearn.metrics import accuracy_score, f1_score from torch.optim import AdamW from torch.optim.lr_scheduler import ReduceLROnPlateau -from torch.utils.data import DataLoader, TensorDataset +from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler from src.models.lstm_attention import HeatRiskPredictor from src.utils.config import ( @@ -147,11 +147,21 @@ def train() -> HeatRiskPredictor: X_test_t = torch.tensor(X_test_np, dtype=torch.float32) y_test_t = torch.tensor(y_test_np, dtype=torch.long) - # -------------------- DataLoader -------------------- + # -------------------- DataLoader (加权采样) -------------------- + # 基于 y_short 的类别权重,解决极度不平衡问题 + y_short_labels = y_train_np[:, 0] + class_counts = np.bincount(y_short_labels) + class_weights = 1.0 / class_counts + sample_weights = class_weights[y_short_labels] + sampler = WeightedRandomSampler( + weights=torch.from_numpy(sample_weights).float(), + num_samples=len(y_short_labels), + replacement=True, + ) train_loader = DataLoader( TensorDataset(X_train_t, y_train_t), batch_size=BATCH_SIZE, - shuffle=True, + sampler=sampler, ) val_loader = DataLoader( TensorDataset(X_val_t, y_val_t),