perf: 优化训练配置 — batch_size 16, alpha 0.5, max_epochs 200, 日志文件输出
This commit is contained in:
+11
-2
@@ -31,7 +31,7 @@ from src.utils.config import (
|
|||||||
class FocalLoss(nn.Module):
|
class FocalLoss(nn.Module):
|
||||||
"""Focal Loss — 聚焦困难样本,缓解类别不平衡"""
|
"""Focal Loss — 聚焦困难样本,缓解类别不平衡"""
|
||||||
|
|
||||||
def __init__(self, alpha: float = 0.25, gamma: float = 2.0):
|
def __init__(self, alpha: float = 0.5, gamma: float = 2.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
@@ -247,7 +247,7 @@ def train() -> HeatRiskPredictor:
|
|||||||
scheduler.step(avg_val_loss)
|
scheduler.step(avg_val_loss)
|
||||||
|
|
||||||
# ---- 打印进度 ----
|
# ---- 打印进度 ----
|
||||||
if epoch % 10 == 0 or epoch == 1:
|
if epoch % 5 == 0 or epoch == 1:
|
||||||
lr_now = optimizer.param_groups[0]["lr"]
|
lr_now = optimizer.param_groups[0]["lr"]
|
||||||
print(
|
print(
|
||||||
f"Epoch {epoch:3d}/{MAX_EPOCHS} | "
|
f"Epoch {epoch:3d}/{MAX_EPOCHS} | "
|
||||||
@@ -362,4 +362,13 @@ def train() -> HeatRiskPredictor:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
import logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||||
|
handlers=[
|
||||||
|
logging.FileHandler("outputs/logs/training.log"),
|
||||||
|
logging.StreamHandler(),
|
||||||
|
],
|
||||||
|
)
|
||||||
train()
|
train()
|
||||||
|
|||||||
+2
-2
@@ -34,9 +34,9 @@ ERA5_VARIABLES = [
|
|||||||
|
|
||||||
# 模型配置
|
# 模型配置
|
||||||
LOOKBACK_DAYS = 14
|
LOOKBACK_DAYS = 14
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 16
|
||||||
LEARNING_RATE = 1e-3
|
LEARNING_RATE = 1e-3
|
||||||
MAX_EPOCHS = 100
|
MAX_EPOCHS = 200
|
||||||
EARLY_STOP_PATIENCE = 15
|
EARLY_STOP_PATIENCE = 15
|
||||||
HIDDEN_DIM = 128
|
HIDDEN_DIM = 128
|
||||||
LSTM_LAYERS = 2
|
LSTM_LAYERS = 2
|
||||||
|
|||||||
Reference in New Issue
Block a user