fix: sqrt 逆频率权重 + 每 epoch 打印
This commit is contained in:
+3
-3
@@ -168,10 +168,10 @@ def train() -> HeatRiskPredictor:
|
||||
print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")
|
||||
|
||||
# -------------------- 损失、优化器、调度器 --------------------
|
||||
# 基于训练集类别分布计算权重
|
||||
# 基于训练集类别分布计算权重 (sqrt inverse freq, 温和加权)
|
||||
class_counts = np.bincount(y_train_np[:, 0])
|
||||
class_weights_tensor = torch.tensor(
|
||||
1.0 / (class_counts + 1), dtype=torch.float32
|
||||
1.0 / np.sqrt(class_counts + 1), dtype=torch.float32
|
||||
).to(device)
|
||||
class_weights_tensor = class_weights_tensor / class_weights_tensor.sum() * 4
|
||||
focal_loss = FocalLoss(class_weight=class_weights_tensor)
|
||||
@@ -256,7 +256,7 @@ def train() -> HeatRiskPredictor:
|
||||
scheduler.step(avg_val_loss)
|
||||
|
||||
# ---- 打印进度 ----
|
||||
if epoch % 5 == 0 or epoch == 1:
|
||||
if True: # print every epoch
|
||||
lr_now = optimizer.param_groups[0]["lr"]
|
||||
print(
|
||||
f"Epoch {epoch:3d}/{MAX_EPOCHS} | "
|
||||
|
||||
Reference in New Issue
Block a user