fix: 温和采样器 + 纯 FocalLoss, batch_size 32
This commit is contained in:
+17
-13
@@ -29,18 +29,15 @@ from src.utils.config import (
|
||||
|
||||
|
||||
class FocalLoss(nn.Module):
|
||||
"""Focal Loss with class weights — 解决极度不平衡"""
|
||||
"""Focal Loss — 聚焦困难样本"""
|
||||
|
||||
def __init__(self, alpha: float = 0.75, gamma: float = 3.0,
|
||||
class_weight: torch.Tensor | None = None):
|
||||
def __init__(self, alpha: float = 0.5, gamma: float = 2.0):
|
||||
super().__init__()
|
||||
self.alpha = alpha
|
||||
self.gamma = gamma
|
||||
self.class_weight = class_weight
|
||||
|
||||
def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
|
||||
ce = F.cross_entropy(logits, targets, reduction="none",
|
||||
weight=self.class_weight)
|
||||
ce = F.cross_entropy(logits, targets, reduction="none")
|
||||
pt = torch.exp(-ce)
|
||||
focal = self.alpha * (1 - pt) ** self.gamma * ce
|
||||
return focal.mean()
|
||||
@@ -150,11 +147,22 @@ def train() -> HeatRiskPredictor:
|
||||
X_test_t = torch.tensor(X_test_np, dtype=torch.float32)
|
||||
y_test_t = torch.tensor(y_test_np, dtype=torch.long)
|
||||
|
||||
# -------------------- DataLoader --------------------
|
||||
# -------------------- DataLoader (加权采样, 温和过采样) --------------------
|
||||
from torch.utils.data import WeightedRandomSampler
|
||||
y_short = y_train_np[:, 0]
|
||||
class_counts = np.bincount(y_short)
|
||||
# 温和权重: 比 sqrt 更弱, 避免 overcorrect
|
||||
weights = 1.0 / np.sqrt(class_counts)
|
||||
sample_w = weights[y_short]
|
||||
sampler = WeightedRandomSampler(
|
||||
torch.from_numpy(sample_w).float(),
|
||||
num_samples=len(y_short),
|
||||
replacement=True,
|
||||
)
|
||||
train_loader = DataLoader(
|
||||
TensorDataset(X_train_t, y_train_t),
|
||||
batch_size=BATCH_SIZE,
|
||||
shuffle=True,
|
||||
sampler=sampler,
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
TensorDataset(X_val_t, y_val_t),
|
||||
@@ -168,11 +176,7 @@ def train() -> HeatRiskPredictor:
|
||||
print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")
|
||||
|
||||
# -------------------- 损失、优化器、调度器 --------------------
|
||||
# 手动温和类别权重: [低, 中, 高, 严重]
|
||||
class_weights_tensor = torch.tensor(
|
||||
[1.0, 3.0, 5.0, 8.0], dtype=torch.float32
|
||||
).to(device)
|
||||
focal_loss = FocalLoss(class_weight=class_weights_tensor)
|
||||
focal_loss = FocalLoss()
|
||||
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
|
||||
scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=5)
|
||||
|
||||
|
||||
+1
-1
@@ -34,7 +34,7 @@ ERA5_VARIABLES = [
|
||||
|
||||
# 模型配置
|
||||
LOOKBACK_DAYS = 14
|
||||
BATCH_SIZE = 64
|
||||
BATCH_SIZE = 32
|
||||
LEARNING_RATE = 1e-3
|
||||
MAX_EPOCHS = 50
|
||||
EARLY_STOP_PATIENCE = 15
|
||||
|
||||
Reference in New Issue
Block a user