From 23f3f9e4fb2069523a6c815c3ef6036f147ba9c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com> Date: Thu, 28 May 2026 12:23:42 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=B8=A9=E5=92=8C=E9=87=87=E6=A0=B7?= =?UTF-8?q?=E5=99=A8=20+=20=E7=BA=AF=20FocalLoss,=20batch=5Fsize=2032?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/models/train.py | 30 +++++++++++++++++------------- src/utils/config.py | 2 +- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/models/train.py b/src/models/train.py index a807929..0378e77 100644 --- a/src/models/train.py +++ b/src/models/train.py @@ -29,18 +29,15 @@ from src.utils.config import ( class FocalLoss(nn.Module): - """Focal Loss with class weights — 解决极度不平衡""" + """Focal Loss — 聚焦困难样本""" - def __init__(self, alpha: float = 0.75, gamma: float = 3.0, - class_weight: torch.Tensor | None = None): + def __init__(self, alpha: float = 0.5, gamma: float = 2.0): super().__init__() self.alpha = alpha self.gamma = gamma - self.class_weight = class_weight def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: - ce = F.cross_entropy(logits, targets, reduction="none", - weight=self.class_weight) + ce = F.cross_entropy(logits, targets, reduction="none") pt = torch.exp(-ce) focal = self.alpha * (1 - pt) ** self.gamma * ce return focal.mean() @@ -150,11 +147,22 @@ def train() -> HeatRiskPredictor: X_test_t = torch.tensor(X_test_np, dtype=torch.float32) y_test_t = torch.tensor(y_test_np, dtype=torch.long) - # -------------------- DataLoader -------------------- + # -------------------- DataLoader (加权采样, 温和过采样) -------------------- + from torch.utils.data import WeightedRandomSampler + y_short = y_train_np[:, 0] + class_counts = np.bincount(y_short) + # 温和权重: 比 sqrt 更弱, 避免 overcorrect + weights = 1.0 / np.sqrt(class_counts) + sample_w = weights[y_short] + sampler = WeightedRandomSampler( + torch.from_numpy(sample_w).float(), + num_samples=len(y_short), + replacement=True, + ) train_loader = DataLoader( TensorDataset(X_train_t, y_train_t), batch_size=BATCH_SIZE, - shuffle=True, + sampler=sampler, ) val_loader = DataLoader( TensorDataset(X_val_t, y_val_t), @@ -168,11 +176,7 @@ def train() -> HeatRiskPredictor: print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}") # -------------------- 损失、优化器、调度器 -------------------- - # 手动温和类别权重: [低, 中, 高, 严重] - class_weights_tensor = torch.tensor( - [1.0, 3.0, 5.0, 8.0], dtype=torch.float32 - ).to(device) - focal_loss = FocalLoss(class_weight=class_weights_tensor) + focal_loss = FocalLoss() optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4) scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=5) diff --git a/src/utils/config.py b/src/utils/config.py index aa3c600..080f02b 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -34,7 +34,7 @@ ERA5_VARIABLES = [ # 模型配置 LOOKBACK_DAYS = 14 -BATCH_SIZE = 64 +BATCH_SIZE = 32 LEARNING_RATE = 1e-3 MAX_EPOCHS = 50 EARLY_STOP_PATIENCE = 15