refactor(agent): 重命名 train_step 为 step_count 以提高可读性

- 将 agent.py 中的 train_step 变量重命名为 step_count,使其含义更清晰
- 更新所有相关引用,包括 epsilon 衰减和目标网络更新逻辑
- 同步修改模型保存和加载时的键名('train_step' → 'step_count');注意:此前以 'train_step' 键保存的检查点无法被新的加载逻辑直接读取,加载时会触发 KeyError
- 修复多个源文件末尾的换行符问题
This commit is contained in:
2026-05-01 10:19:14 +08:00
parent e8b51240f9
commit 85f912483a
10 changed files with 13 additions and 13 deletions
@@ -63,7 +63,7 @@ class DQNAgent:
self.optimizer = torch.optim.Adam(q_network.parameters(), lr=lr) self.optimizer = torch.optim.Adam(q_network.parameters(), lr=lr)
# 训练步数 # 训练步数
self.train_step = 0 self.step_count = 0
# 训练统计 # 训练统计
self.loss_history = [] self.loss_history = []
@@ -98,9 +98,9 @@ class DQNAgent:
def update_epsilon(self): def update_epsilon(self):
"""更新ε值(线性衰减)""" """更新ε值(线性衰减)"""
if self.train_step < self.epsilon_decay_steps: if self.step_count < self.epsilon_decay_steps:
self.epsilon = self.epsilon_start - (self.epsilon_start - self.epsilon_end) * \ self.epsilon = self.epsilon_start - (self.epsilon_start - self.epsilon_end) * \
(self.train_step / self.epsilon_decay_steps) (self.step_count / self.epsilon_decay_steps)
else: else:
self.epsilon = self.epsilon_end self.epsilon = self.epsilon_end
@@ -147,8 +147,8 @@ class DQNAgent:
self.optimizer.step() self.optimizer.step()
# 更新目标网络 # 更新目标网络
self.train_step += 1 self.step_count += 1
if self.train_step % self.target_update_freq == 0: if self.step_count % self.target_update_freq == 0:
self.target_network.load_state_dict(self.q_network.state_dict()) self.target_network.load_state_dict(self.q_network.state_dict())
# 更新ε # 更新ε
@@ -168,7 +168,7 @@ class DQNAgent:
'q_network': self.q_network.state_dict(), 'q_network': self.q_network.state_dict(),
'target_network': self.target_network.state_dict(), 'target_network': self.target_network.state_dict(),
'optimizer': self.optimizer.state_dict(), 'optimizer': self.optimizer.state_dict(),
'train_step': self.train_step, 'step_count': self.step_count,
'epsilon': self.epsilon, 'epsilon': self.epsilon,
}, path) }, path)
print(f"模型已保存到: {path}") print(f"模型已保存到: {path}")
@@ -179,6 +179,6 @@ class DQNAgent:
self.q_network.load_state_dict(checkpoint['q_network']) self.q_network.load_state_dict(checkpoint['q_network'])
self.target_network.load_state_dict(checkpoint['target_network']) self.target_network.load_state_dict(checkpoint['target_network'])
self.optimizer.load_state_dict(checkpoint['optimizer']) self.optimizer.load_state_dict(checkpoint['optimizer'])
self.train_step = checkpoint['train_step'] self.step_count = checkpoint['step_count']
self.epsilon = checkpoint['epsilon'] self.epsilon = checkpoint['epsilon']
print(f"模型已从 {path} 加载") print(f"模型已从 {path} 加载")