fix: 手动温和权重 [1,3,5,8] + gamma 3.0
This commit is contained in:
+3
-5
@@ -31,7 +31,7 @@ from src.utils.config import (
|
|||||||
class FocalLoss(nn.Module):
|
class FocalLoss(nn.Module):
|
||||||
"""Focal Loss with class weights — 解决极度不平衡"""
|
"""Focal Loss with class weights — 解决极度不平衡"""
|
||||||
|
|
||||||
def __init__(self, alpha: float = 0.5, gamma: float = 2.0,
|
def __init__(self, alpha: float = 0.75, gamma: float = 3.0,
|
||||||
class_weight: torch.Tensor | None = None):
|
class_weight: torch.Tensor | None = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
@@ -168,12 +168,10 @@ def train() -> HeatRiskPredictor:
|
|||||||
print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")
|
print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")
|
||||||
|
|
||||||
# -------------------- 损失、优化器、调度器 --------------------
|
# -------------------- 损失、优化器、调度器 --------------------
|
||||||
# 基于训练集类别分布计算权重 (sqrt inverse freq, 温和加权)
|
# 手动温和类别权重: [低, 中, 高, 严重]
|
||||||
class_counts = np.bincount(y_train_np[:, 0])
|
|
||||||
class_weights_tensor = torch.tensor(
|
class_weights_tensor = torch.tensor(
|
||||||
1.0 / np.sqrt(class_counts + 1), dtype=torch.float32
|
[1.0, 3.0, 5.0, 8.0], dtype=torch.float32
|
||||||
).to(device)
|
).to(device)
|
||||||
class_weights_tensor = class_weights_tensor / class_weights_tensor.sum() * 4
|
|
||||||
focal_loss = FocalLoss(class_weight=class_weights_tensor)
|
focal_loss = FocalLoss(class_weight=class_weights_tensor)
|
||||||
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
|
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
|
||||||
scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=5)
|
scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=5)
|
||||||
|
|||||||
Reference in New Issue
Block a user