From 0d178b0d574c72a80fae966a5cc5aede627cca11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com> Date: Thu, 28 May 2026 09:51:03 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=20=E2=80=94=20batch=5Fsize=2016,=20alpha=200?= =?UTF-8?q?.5,=20max=5Fepochs=20200,=20=E6=97=A5=E5=BF=97=E6=96=87?= =?UTF-8?q?=E4=BB=B6=E8=BE=93=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/models/train.py | 13 +++++++++++-- src/utils/config.py | 4 ++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/models/train.py b/src/models/train.py index 028bc39..c4a1731 100644 --- a/src/models/train.py +++ b/src/models/train.py @@ -31,7 +31,7 @@ from src.utils.config import ( class FocalLoss(nn.Module): """Focal Loss — 聚焦困难样本,缓解类别不平衡""" - def __init__(self, alpha: float = 0.25, gamma: float = 2.0): + def __init__(self, alpha: float = 0.5, gamma: float = 2.0): super().__init__() self.alpha = alpha self.gamma = gamma @@ -247,7 +247,7 @@ def train() -> HeatRiskPredictor: scheduler.step(avg_val_loss) # ---- 打印进度 ---- - if epoch % 10 == 0 or epoch == 1: + if epoch % 5 == 0 or epoch == 1: lr_now = optimizer.param_groups[0]["lr"] print( f"Epoch {epoch:3d}/{MAX_EPOCHS} | " @@ -362,4 +362,13 @@ def train() -> HeatRiskPredictor: if __name__ == "__main__": + import logging + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[ + logging.FileHandler("outputs/logs/training.log"), + logging.StreamHandler(), + ], + ) train() diff --git a/src/utils/config.py b/src/utils/config.py index cf7a79b..1b7e4f4 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -34,9 +34,9 @@ ERA5_VARIABLES = [ # 模型配置 LOOKBACK_DAYS = 14 -BATCH_SIZE = 32 +BATCH_SIZE = 16 LEARNING_RATE = 1e-3 -MAX_EPOCHS = 100 +MAX_EPOCHS = 200 EARLY_STOP_PATIENCE = 15 HIDDEN_DIM = 128 LSTM_LAYERS = 2