diff --git a/src/models/train.py b/src/models/train.py index b2adfc0..a807929 100644 --- a/src/models/train.py +++ b/src/models/train.py @@ -31,7 +31,7 @@ from src.utils.config import ( class FocalLoss(nn.Module): """Focal Loss with class weights — 解决极度不平衡""" - def __init__(self, alpha: float = 0.5, gamma: float = 2.0, + def __init__(self, alpha: float = 0.75, gamma: float = 3.0, class_weight: torch.Tensor | None = None): super().__init__() self.alpha = alpha @@ -168,12 +168,10 @@ def train() -> HeatRiskPredictor: print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}") # -------------------- 损失、优化器、调度器 -------------------- - # 基于训练集类别分布计算权重 (sqrt inverse freq, 温和加权) - class_counts = np.bincount(y_train_np[:, 0]) + # 手动温和类别权重: [低, 中, 高, 严重] class_weights_tensor = torch.tensor( - 1.0 / np.sqrt(class_counts + 1), dtype=torch.float32 + [1.0, 3.0, 5.0, 8.0], dtype=torch.float32 ).to(device) - class_weights_tensor = class_weights_tensor / class_weights_tensor.sum() * 4 focal_loss = FocalLoss(class_weight=class_weights_tensor) optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4) scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=5)