fix: 温和采样器 + 纯 FocalLoss, batch_size 32
This commit is contained in:
+17
-13
@@ -29,18 +29,15 @@ from src.utils.config import (
|
|||||||
|
|
||||||
|
|
||||||
class FocalLoss(nn.Module):
|
class FocalLoss(nn.Module):
|
||||||
"""Focal Loss with class weights — 解决极度不平衡"""
|
"""Focal Loss — 聚焦困难样本"""
|
||||||
|
|
||||||
def __init__(self, alpha: float = 0.75, gamma: float = 3.0,
|
def __init__(self, alpha: float = 0.5, gamma: float = 2.0):
|
||||||
class_weight: torch.Tensor | None = None):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
self.class_weight = class_weight
|
|
||||||
|
|
||||||
def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
|
def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
|
||||||
ce = F.cross_entropy(logits, targets, reduction="none",
|
ce = F.cross_entropy(logits, targets, reduction="none")
|
||||||
weight=self.class_weight)
|
|
||||||
pt = torch.exp(-ce)
|
pt = torch.exp(-ce)
|
||||||
focal = self.alpha * (1 - pt) ** self.gamma * ce
|
focal = self.alpha * (1 - pt) ** self.gamma * ce
|
||||||
return focal.mean()
|
return focal.mean()
|
||||||
@@ -150,11 +147,22 @@ def train() -> HeatRiskPredictor:
|
|||||||
X_test_t = torch.tensor(X_test_np, dtype=torch.float32)
|
X_test_t = torch.tensor(X_test_np, dtype=torch.float32)
|
||||||
y_test_t = torch.tensor(y_test_np, dtype=torch.long)
|
y_test_t = torch.tensor(y_test_np, dtype=torch.long)
|
||||||
|
|
||||||
# -------------------- DataLoader --------------------
|
# -------------------- DataLoader (加权采样, 温和过采样) --------------------
|
||||||
|
from torch.utils.data import WeightedRandomSampler
|
||||||
|
y_short = y_train_np[:, 0]
|
||||||
|
class_counts = np.bincount(y_short)
|
||||||
|
# Mild weights: inverse square root of class counts (gentler than full inverse-frequency weighting), to avoid overcorrecting
|
||||||
|
weights = 1.0 / np.sqrt(class_counts)
|
||||||
|
sample_w = weights[y_short]
|
||||||
|
sampler = WeightedRandomSampler(
|
||||||
|
torch.from_numpy(sample_w).float(),
|
||||||
|
num_samples=len(y_short),
|
||||||
|
replacement=True,
|
||||||
|
)
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
TensorDataset(X_train_t, y_train_t),
|
TensorDataset(X_train_t, y_train_t),
|
||||||
batch_size=BATCH_SIZE,
|
batch_size=BATCH_SIZE,
|
||||||
shuffle=True,
|
sampler=sampler,
|
||||||
)
|
)
|
||||||
val_loader = DataLoader(
|
val_loader = DataLoader(
|
||||||
TensorDataset(X_val_t, y_val_t),
|
TensorDataset(X_val_t, y_val_t),
|
||||||
@@ -168,11 +176,7 @@ def train() -> HeatRiskPredictor:
|
|||||||
print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")
|
print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")
|
||||||
|
|
||||||
# -------------------- 损失、优化器、调度器 --------------------
|
# -------------------- 损失、优化器、调度器 --------------------
|
||||||
# 手动温和类别权重: [低, 中, 高, 严重]
|
focal_loss = FocalLoss()
|
||||||
class_weights_tensor = torch.tensor(
|
|
||||||
[1.0, 3.0, 5.0, 8.0], dtype=torch.float32
|
|
||||||
).to(device)
|
|
||||||
focal_loss = FocalLoss(class_weight=class_weights_tensor)
|
|
||||||
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
|
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
|
||||||
scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=5)
|
scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=5)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -34,7 +34,7 @@ ERA5_VARIABLES = [
|
|||||||
|
|
||||||
# 模型配置
|
# 模型配置
|
||||||
LOOKBACK_DAYS = 14
|
LOOKBACK_DAYS = 14
|
||||||
BATCH_SIZE = 64
|
BATCH_SIZE = 32
|
||||||
LEARNING_RATE = 1e-3
|
LEARNING_RATE = 1e-3
|
||||||
MAX_EPOCHS = 50
|
MAX_EPOCHS = 50
|
||||||
EARLY_STOP_PATIENCE = 15
|
EARLY_STOP_PATIENCE = 15
|
||||||
|
|||||||
Reference in New Issue
Block a user