diff --git a/src/models/train.py b/src/models/train.py index 18f9744..b2adfc0 100644 --- a/src/models/train.py +++ b/src/models/train.py @@ -168,10 +168,10 @@ def train() -> HeatRiskPredictor: print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}") # -------------------- 损失、优化器、调度器 -------------------- - # 基于训练集类别分布计算权重 + # 基于训练集类别分布计算权重 (sqrt inverse freq, 温和加权) class_counts = np.bincount(y_train_np[:, 0]) class_weights_tensor = torch.tensor( - 1.0 / (class_counts + 1), dtype=torch.float32 + 1.0 / np.sqrt(class_counts + 1), dtype=torch.float32 ).to(device) class_weights_tensor = class_weights_tensor / class_weights_tensor.sum() * 4 focal_loss = FocalLoss(class_weight=class_weights_tensor) @@ -256,7 +256,7 @@ def train() -> HeatRiskPredictor: scheduler.step(avg_val_loss) # ---- 打印进度 ---- - if epoch % 5 == 0 or epoch == 1: + if True: # print every epoch lr_now = optimizer.param_groups[0]["lr"] print( f"Epoch {epoch:3d}/{MAX_EPOCHS} | "