perf: 优化训练配置 — batch_size 16, alpha 0.5, max_epochs 200, 日志文件输出
This commit is contained in:
+11
-2
@@ -31,7 +31,7 @@ from src.utils.config import (
|
||||
class FocalLoss(nn.Module):
|
||||
"""Focal Loss — 聚焦困难样本,缓解类别不平衡"""
|
||||
|
||||
def __init__(self, alpha: float = 0.25, gamma: float = 2.0):
|
||||
def __init__(self, alpha: float = 0.5, gamma: float = 2.0):
|
||||
super().__init__()
|
||||
self.alpha = alpha
|
||||
self.gamma = gamma
|
||||
@@ -247,7 +247,7 @@ def train() -> HeatRiskPredictor:
|
||||
scheduler.step(avg_val_loss)
|
||||
|
||||
# ---- 打印进度 ----
|
||||
if epoch % 10 == 0 or epoch == 1:
|
||||
if epoch % 5 == 0 or epoch == 1:
|
||||
lr_now = optimizer.param_groups[0]["lr"]
|
||||
print(
|
||||
f"Epoch {epoch:3d}/{MAX_EPOCHS} | "
|
||||
@@ -362,4 +362,13 @@ def train() -> HeatRiskPredictor:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler("outputs/logs/training.log"),
|
||||
logging.StreamHandler(),
|
||||
],
|
||||
)
|
||||
train()
|
||||
|
||||
+2
-2
@@ -34,9 +34,9 @@ ERA5_VARIABLES = [
|
||||
|
||||
# 模型配置
|
||||
LOOKBACK_DAYS = 14
|
||||
BATCH_SIZE = 32
|
||||
BATCH_SIZE = 16
|
||||
LEARNING_RATE = 1e-3
|
||||
MAX_EPOCHS = 100
|
||||
MAX_EPOCHS = 200
|
||||
EARLY_STOP_PATIENCE = 15
|
||||
HIDDEN_DIM = 128
|
||||
LSTM_LAYERS = 2
|
||||
|
||||
Reference in New Issue
Block a user