fix: 温和采样器 + 纯 FocalLoss, batch_size 32

This commit is contained in:
2026-05-28 12:23:42 +08:00
parent 13f553da5d
commit 23f3f9e4fb
2 changed files with 18 additions and 14 deletions
+17 -13
View File
@@ -29,18 +29,15 @@ from src.utils.config import (
class FocalLoss(nn.Module): class FocalLoss(nn.Module):
"""Focal Loss with class weights — 解决极度不平衡""" """Focal Loss — 聚焦困难样本"""
def __init__(self, alpha: float = 0.75, gamma: float = 3.0, def __init__(self, alpha: float = 0.5, gamma: float = 2.0):
class_weight: torch.Tensor | None = None):
super().__init__() super().__init__()
self.alpha = alpha self.alpha = alpha
self.gamma = gamma self.gamma = gamma
self.class_weight = class_weight
def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
ce = F.cross_entropy(logits, targets, reduction="none", ce = F.cross_entropy(logits, targets, reduction="none")
weight=self.class_weight)
pt = torch.exp(-ce) pt = torch.exp(-ce)
focal = self.alpha * (1 - pt) ** self.gamma * ce focal = self.alpha * (1 - pt) ** self.gamma * ce
return focal.mean() return focal.mean()
@@ -150,11 +147,22 @@ def train() -> HeatRiskPredictor:
X_test_t = torch.tensor(X_test_np, dtype=torch.float32) X_test_t = torch.tensor(X_test_np, dtype=torch.float32)
y_test_t = torch.tensor(y_test_np, dtype=torch.long) y_test_t = torch.tensor(y_test_np, dtype=torch.long)
# -------------------- DataLoader -------------------- # -------------------- DataLoader (加权采样, 温和过采样) --------------------
from torch.utils.data import WeightedRandomSampler
y_short = y_train_np[:, 0]
class_counts = np.bincount(y_short)
# 温和权重: 1/sqrt(count), 比 1/count 更弱, 避免 overcorrect
weights = 1.0 / np.sqrt(class_counts)
sample_w = weights[y_short]
sampler = WeightedRandomSampler(
torch.from_numpy(sample_w).float(),
num_samples=len(y_short),
replacement=True,
)
train_loader = DataLoader( train_loader = DataLoader(
TensorDataset(X_train_t, y_train_t), TensorDataset(X_train_t, y_train_t),
batch_size=BATCH_SIZE, batch_size=BATCH_SIZE,
shuffle=True, sampler=sampler,
) )
val_loader = DataLoader( val_loader = DataLoader(
TensorDataset(X_val_t, y_val_t), TensorDataset(X_val_t, y_val_t),
@@ -168,11 +176,7 @@ def train() -> HeatRiskPredictor:
print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}") print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")
# -------------------- 损失、优化器、调度器 -------------------- # -------------------- 损失、优化器、调度器 --------------------
# 手动温和类别权重: [低, 中, 高, 严重] focal_loss = FocalLoss()
class_weights_tensor = torch.tensor(
[1.0, 3.0, 5.0, 8.0], dtype=torch.float32
).to(device)
focal_loss = FocalLoss(class_weight=class_weights_tensor)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4) optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=5) scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=5)
+1 -1
View File
@@ -34,7 +34,7 @@ ERA5_VARIABLES = [
# 模型配置 # 模型配置
LOOKBACK_DAYS = 14 LOOKBACK_DAYS = 14
BATCH_SIZE = 64 BATCH_SIZE = 32
LEARNING_RATE = 1e-3 LEARNING_RATE = 1e-3
MAX_EPOCHS = 50 MAX_EPOCHS = 50
EARLY_STOP_PATIENCE = 15 EARLY_STOP_PATIENCE = 15