diff --git a/src/models/train.py b/src/models/train.py index 028bc39..c4a1731 100644 --- a/src/models/train.py +++ b/src/models/train.py @@ -31,7 +31,7 @@ from src.utils.config import ( class FocalLoss(nn.Module): """Focal Loss — 聚焦困难样本,缓解类别不平衡""" - def __init__(self, alpha: float = 0.25, gamma: float = 2.0): + def __init__(self, alpha: float = 0.5, gamma: float = 2.0): super().__init__() self.alpha = alpha self.gamma = gamma @@ -247,7 +247,7 @@ def train() -> HeatRiskPredictor: scheduler.step(avg_val_loss) # ---- 打印进度 ---- - if epoch % 10 == 0 or epoch == 1: + if epoch % 5 == 0 or epoch == 1: lr_now = optimizer.param_groups[0]["lr"] print( f"Epoch {epoch:3d}/{MAX_EPOCHS} | " @@ -362,4 +362,13 @@ def train() -> HeatRiskPredictor: if __name__ == "__main__": + import logging + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[ + logging.FileHandler("outputs/logs/training.log"), + logging.StreamHandler(), + ], + ) train() diff --git a/src/utils/config.py b/src/utils/config.py index cf7a79b..1b7e4f4 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -34,9 +34,9 @@ ERA5_VARIABLES = [ # 模型配置 LOOKBACK_DAYS = 14 -BATCH_SIZE = 32 +BATCH_SIZE = 16 LEARNING_RATE = 1e-3 -MAX_EPOCHS = 100 +MAX_EPOCHS = 200 EARLY_STOP_PATIENCE = 15 HIDDEN_DIM = 128 LSTM_LAYERS = 2