feat: 添加 WeightedRandomSampler 解决类别不平衡

This commit is contained in:
2026-05-28 10:24:33 +08:00
parent 0d178b0d57
commit 106ce423d3
+13 -3
View File
@@ -14,7 +14,7 @@ import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from src.models.lstm_attention import HeatRiskPredictor
from src.utils.config import (
@@ -147,11 +147,21 @@ def train() -> HeatRiskPredictor:
X_test_t = torch.tensor(X_test_np, dtype=torch.float32)
y_test_t = torch.tensor(y_test_np, dtype=torch.long)
# -------------------- DataLoader --------------------
# -------------------- DataLoader (加权采样) --------------------
# 基于 y_short 的类别权重,解决极度不平衡问题
y_short_labels = y_train_np[:, 0]
class_counts = np.bincount(y_short_labels)
class_weights = 1.0 / class_counts
sample_weights = class_weights[y_short_labels]
sampler = WeightedRandomSampler(
weights=torch.from_numpy(sample_weights).float(),
num_samples=len(y_short_labels),
replacement=True,
)
train_loader = DataLoader(
TensorDataset(X_train_t, y_train_t),
batch_size=BATCH_SIZE,
shuffle=True,
sampler=sampler,
)
val_loader = DataLoader(
TensorDataset(X_val_t, y_val_t),