fix: sqrt 逆频率权重 + 每 epoch 打印

This commit is contained in:
2026-05-28 11:37:43 +08:00
parent 305f70b9de
commit 1a6f5b07aa
+3 -3
View File
@@ -168,10 +168,10 @@ def train() -> HeatRiskPredictor:
print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")
# -------------------- 损失、优化器、调度器 --------------------
# 基于训练集类别分布计算权重
# 基于训练集类别分布计算权重 (sqrt inverse freq, 温和加权)
class_counts = np.bincount(y_train_np[:, 0])
class_weights_tensor = torch.tensor(
1.0 / (class_counts + 1), dtype=torch.float32
1.0 / np.sqrt(class_counts + 1), dtype=torch.float32
).to(device)
class_weights_tensor = class_weights_tensor / class_weights_tensor.sum() * 4
focal_loss = FocalLoss(class_weight=class_weights_tensor)
@@ -256,7 +256,7 @@ def train() -> HeatRiskPredictor:
scheduler.step(avg_val_loss)
# ---- 打印进度 ----
if epoch % 5 == 0 or epoch == 1:
if True: # print every epoch
lr_now = optimizer.param_groups[0]["lr"]
print(
f"Epoch {epoch:3d}/{MAX_EPOCHS} | "