refactor(agent): 重命名 train_step 为 step_count 以提高可读性
- 将 agent.py 中的 train_step 变量重命名为 step_count,使其含义更清晰 - 更新所有相关引用,包括 epsilon 衰减和目标网络更新逻辑 - 同步修改模型保存和加载时的键名 - 修复多个源文件末尾的换行符问题
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -63,7 +63,7 @@ class DQNAgent:
|
|||||||
self.optimizer = torch.optim.Adam(q_network.parameters(), lr=lr)
|
self.optimizer = torch.optim.Adam(q_network.parameters(), lr=lr)
|
||||||
|
|
||||||
# 训练步数
|
# 训练步数
|
||||||
self.train_step = 0
|
self.step_count = 0
|
||||||
|
|
||||||
# 训练统计
|
# 训练统计
|
||||||
self.loss_history = []
|
self.loss_history = []
|
||||||
@@ -98,9 +98,9 @@ class DQNAgent:
|
|||||||
|
|
||||||
def update_epsilon(self):
|
def update_epsilon(self):
|
||||||
"""更新ε值(线性衰减)"""
|
"""更新ε值(线性衰减)"""
|
||||||
if self.train_step < self.epsilon_decay_steps:
|
if self.step_count < self.epsilon_decay_steps:
|
||||||
self.epsilon = self.epsilon_start - (self.epsilon_start - self.epsilon_end) * \
|
self.epsilon = self.epsilon_start - (self.epsilon_start - self.epsilon_end) * \
|
||||||
(self.train_step / self.epsilon_decay_steps)
|
(self.step_count / self.epsilon_decay_steps)
|
||||||
else:
|
else:
|
||||||
self.epsilon = self.epsilon_end
|
self.epsilon = self.epsilon_end
|
||||||
|
|
||||||
@@ -147,8 +147,8 @@ class DQNAgent:
|
|||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
|
|
||||||
# 更新目标网络
|
# 更新目标网络
|
||||||
self.train_step += 1
|
self.step_count += 1
|
||||||
if self.train_step % self.target_update_freq == 0:
|
if self.step_count % self.target_update_freq == 0:
|
||||||
self.target_network.load_state_dict(self.q_network.state_dict())
|
self.target_network.load_state_dict(self.q_network.state_dict())
|
||||||
|
|
||||||
# 更新ε
|
# 更新ε
|
||||||
@@ -168,7 +168,7 @@ class DQNAgent:
|
|||||||
'q_network': self.q_network.state_dict(),
|
'q_network': self.q_network.state_dict(),
|
||||||
'target_network': self.target_network.state_dict(),
|
'target_network': self.target_network.state_dict(),
|
||||||
'optimizer': self.optimizer.state_dict(),
|
'optimizer': self.optimizer.state_dict(),
|
||||||
'train_step': self.train_step,
|
'step_count': self.step_count,
|
||||||
'epsilon': self.epsilon,
|
'epsilon': self.epsilon,
|
||||||
}, path)
|
}, path)
|
||||||
print(f"模型已保存到: {path}")
|
print(f"模型已保存到: {path}")
|
||||||
@@ -179,6 +179,6 @@ class DQNAgent:
|
|||||||
self.q_network.load_state_dict(checkpoint['q_network'])
|
self.q_network.load_state_dict(checkpoint['q_network'])
|
||||||
self.target_network.load_state_dict(checkpoint['target_network'])
|
self.target_network.load_state_dict(checkpoint['target_network'])
|
||||||
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
||||||
self.train_step = checkpoint['train_step']
|
self.step_count = checkpoint['step_count']
|
||||||
self.epsilon = checkpoint['epsilon']
|
self.epsilon = checkpoint['epsilon']
|
||||||
print(f"模型已从 {path} 加载")
|
print(f"模型已从 {path} 加载")
|
||||||
|
|||||||
@@ -132,4 +132,4 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -147,4 +147,4 @@ class DuelingQNetwork(nn.Module):
|
|||||||
# Q = V(s) + A(s,a) - mean(A(s,a))
|
# Q = V(s) + A(s,a) - mean(A(s,a))
|
||||||
q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
|
q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
|
||||||
|
|
||||||
return q_values
|
return q_values
|
||||||
|
|||||||
@@ -159,4 +159,4 @@ class PrioritizedReplayBuffer:
|
|||||||
self.max_priority = max(self.max_priority, priorities.max())
|
self.max_priority = max(self.max_priority, priorities.max())
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.size
|
return self.size
|
||||||
|
|||||||
@@ -163,4 +163,4 @@ class DQNTrainer:
|
|||||||
|
|
||||||
rewards.append(episode_reward)
|
rewards.append(episode_reward)
|
||||||
|
|
||||||
return np.mean(rewards)
|
return np.mean(rewards)
|
||||||
|
|||||||
@@ -197,4 +197,4 @@ def preprocess_obs(obs):
|
|||||||
"""确保观测格式正确"""
|
"""确保观测格式正确"""
|
||||||
if len(obs.shape) == 2:
|
if len(obs.shape) == 2:
|
||||||
obs = np.expand_dims(obs, axis=0)
|
obs = np.expand_dims(obs, axis=0)
|
||||||
return obs
|
return obs
|
||||||
|
|||||||
Reference in New Issue
Block a user