fix: 改用类别加权 FocalLoss 替代 WeightedSampler,更快更稳定

This commit is contained in:
2026-05-28 11:16:18 +08:00
parent 106ce423d3
commit c0e2bdca72
+16 -17
View File
@@ -14,7 +14,7 @@ import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from torch.utils.data import DataLoader, TensorDataset
from src.models.lstm_attention import HeatRiskPredictor
from src.utils.config import (
@@ -29,15 +29,18 @@ from src.utils.config import (
class FocalLoss(nn.Module):
"""Focal Loss — 聚焦困难样本,缓解类别不平衡"""
"""Focal Loss with class weights — 解决极度不平衡"""
def __init__(self, alpha: float = 0.5, gamma: float = 2.0):
def __init__(self, alpha: float = 0.5, gamma: float = 2.0,
class_weight: torch.Tensor | None = None):
super().__init__()
self.alpha = alpha
self.gamma = gamma
self.class_weight = class_weight
def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
ce = F.cross_entropy(logits, targets, reduction="none")
ce = F.cross_entropy(logits, targets, reduction="none",
weight=self.class_weight)
pt = torch.exp(-ce)
focal = self.alpha * (1 - pt) ** self.gamma * ce
return focal.mean()
@@ -147,21 +150,11 @@ def train() -> HeatRiskPredictor:
X_test_t = torch.tensor(X_test_np, dtype=torch.float32)
y_test_t = torch.tensor(y_test_np, dtype=torch.long)
# -------------------- DataLoader (加权采样) --------------------
# 基于 y_short 的类别权重,解决极度不平衡问题
y_short_labels = y_train_np[:, 0]
class_counts = np.bincount(y_short_labels)
class_weights = 1.0 / class_counts
sample_weights = class_weights[y_short_labels]
sampler = WeightedRandomSampler(
weights=torch.from_numpy(sample_weights).float(),
num_samples=len(y_short_labels),
replacement=True,
)
# -------------------- DataLoader --------------------
train_loader = DataLoader(
TensorDataset(X_train_t, y_train_t),
batch_size=BATCH_SIZE,
sampler=sampler,
shuffle=True,
)
val_loader = DataLoader(
TensorDataset(X_val_t, y_val_t),
@@ -175,7 +168,13 @@ def train() -> HeatRiskPredictor:
print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")
# -------------------- 损失、优化器、调度器 --------------------
focal_loss = FocalLoss()
# 基于训练集类别分布计算权重
class_counts = np.bincount(y_train_np[:, 0])
class_weights_tensor = torch.tensor(
1.0 / (class_counts + 1), dtype=torch.float32
).to(device)
class_weights_tensor = class_weights_tensor / class_weights_tensor.sum() * 4
focal_loss = FocalLoss(class_weight=class_weights_tensor)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=5)