diff --git a/src/models/train.py b/src/models/train.py index 68793b7..18f9744 100644 --- a/src/models/train.py +++ b/src/models/train.py @@ -14,7 +14,7 @@ import torch.nn.functional as F from sklearn.metrics import accuracy_score, f1_score from torch.optim import AdamW from torch.optim.lr_scheduler import ReduceLROnPlateau -from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler +from torch.utils.data import DataLoader, TensorDataset from src.models.lstm_attention import HeatRiskPredictor from src.utils.config import ( @@ -29,15 +29,18 @@ from src.utils.config import ( class FocalLoss(nn.Module): - """Focal Loss — 聚焦困难样本,缓解类别不平衡""" + """Focal Loss with class weights — 解决极度不平衡""" - def __init__(self, alpha: float = 0.5, gamma: float = 2.0): + def __init__(self, alpha: float = 0.5, gamma: float = 2.0, + class_weight: torch.Tensor | None = None): super().__init__() self.alpha = alpha self.gamma = gamma + self.class_weight = class_weight def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: - ce = F.cross_entropy(logits, targets, reduction="none") + ce = F.cross_entropy(logits, targets, reduction="none", + weight=self.class_weight) pt = torch.exp(-ce) focal = self.alpha * (1 - pt) ** self.gamma * ce return focal.mean() @@ -147,21 +150,11 @@ def train() -> HeatRiskPredictor: X_test_t = torch.tensor(X_test_np, dtype=torch.float32) y_test_t = torch.tensor(y_test_np, dtype=torch.long) - # -------------------- DataLoader (加权采样) -------------------- - # 基于 y_short 的类别权重,解决极度不平衡问题 - y_short_labels = y_train_np[:, 0] - class_counts = np.bincount(y_short_labels) - class_weights = 1.0 / class_counts - sample_weights = class_weights[y_short_labels] - sampler = WeightedRandomSampler( - weights=torch.from_numpy(sample_weights).float(), - num_samples=len(y_short_labels), - replacement=True, - ) + # -------------------- DataLoader -------------------- train_loader = DataLoader( TensorDataset(X_train_t, y_train_t), batch_size=BATCH_SIZE, - sampler=sampler, + shuffle=True, ) val_loader = DataLoader( TensorDataset(X_val_t, y_val_t), @@ -175,7 +168,13 @@ def train() -> HeatRiskPredictor: print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}") # -------------------- 损失、优化器、调度器 -------------------- - focal_loss = FocalLoss() + # 基于训练集类别分布计算权重 + class_counts = np.bincount(y_train_np[:, 0]) + class_weights_tensor = torch.tensor( + 1.0 / (class_counts + 1), dtype=torch.float32 + ).to(device) + class_weights_tensor = class_weights_tensor / class_weights_tensor.sum() * 4 + focal_loss = FocalLoss(class_weight=class_weights_tensor) optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4) scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=5)