From 1a6f5b07aab65fac42714b5bd9734aeb9e315e01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com> Date: Thu, 28 May 2026 11:37:43 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20sqrt=20=E9=80=86=E9=A2=91=E7=8E=87?= =?UTF-8?q?=E6=9D=83=E9=87=8D=20+=20=E6=AF=8F=20epoch=20=E6=89=93=E5=8D=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/models/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/models/train.py b/src/models/train.py index 18f9744..b2adfc0 100644 --- a/src/models/train.py +++ b/src/models/train.py @@ -168,10 +168,10 @@ def train() -> HeatRiskPredictor: print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}") # -------------------- 损失、优化器、调度器 -------------------- - # 基于训练集类别分布计算权重 + # 基于训练集类别分布计算权重 (sqrt inverse freq, 温和加权) class_counts = np.bincount(y_train_np[:, 0]) class_weights_tensor = torch.tensor( - 1.0 / (class_counts + 1), dtype=torch.float32 + 1.0 / np.sqrt(class_counts + 1), dtype=torch.float32 ).to(device) class_weights_tensor = class_weights_tensor / class_weights_tensor.sum() * 4 focal_loss = FocalLoss(class_weight=class_weights_tensor) @@ -256,7 +256,7 @@ def train() -> HeatRiskPredictor: scheduler.step(avg_val_loss) # ---- 打印进度 ---- - if epoch % 5 == 0 or epoch == 1: + if True: # print every epoch lr_now = optimizer.param_groups[0]["lr"] print( f"Epoch {epoch:3d}/{MAX_EPOCHS} | "