diff --git a/强化学习个人项目报告(Atari 游戏方向)/models/dqn_best.pt b/强化学习个人项目报告(Atari 游戏方向)/models/dqn_best.pt
new file mode 100644
index 0000000..eb409d1
Binary files /dev/null and b/强化学习个人项目报告(Atari 游戏方向)/models/dqn_best.pt differ
diff --git a/强化学习个人项目报告(Atari 游戏方向)/models/dqn_final.pt b/强化学习个人项目报告(Atari 游戏方向)/models/dqn_final.pt
new file mode 100644
index 0000000..fc6c1bb
Binary files /dev/null and b/强化学习个人项目报告(Atari 游戏方向)/models/dqn_final.pt differ
diff --git a/强化学习个人项目报告(Atari 游戏方向)/models/dqn_step_100000.pt b/强化学习个人项目报告(Atari 游戏方向)/models/dqn_step_100000.pt
new file mode 100644
index 0000000..6a37ce0
Binary files /dev/null and b/强化学习个人项目报告(Atari 游戏方向)/models/dqn_step_100000.pt differ
diff --git a/强化学习个人项目报告(Atari 游戏方向)/models/dqn_step_50000.pt b/强化学习个人项目报告(Atari 游戏方向)/models/dqn_step_50000.pt
new file mode 100644
index 0000000..2a8a6b4
Binary files /dev/null and b/强化学习个人项目报告(Atari 游戏方向)/models/dqn_step_50000.pt differ
diff --git a/强化学习个人项目报告(Atari 游戏方向)/src/agent.py b/强化学习个人项目报告(Atari 游戏方向)/src/agent.py
index 0ea23e6..a276a76 100644
--- a/强化学习个人项目报告(Atari 游戏方向)/src/agent.py
+++ b/强化学习个人项目报告(Atari 游戏方向)/src/agent.py
@@ -63,7 +63,7 @@ class DQNAgent:
         self.optimizer = torch.optim.Adam(q_network.parameters(), lr=lr)
 
         # Training step counter
-        self.train_step = 0
+        self.step_count = 0
 
         # Training statistics
         self.loss_history = []
@@ -98,9 +98,9 @@ class DQNAgent:
 
     def update_epsilon(self):
         """Update the ε value (linear decay)"""
-        if self.train_step < self.epsilon_decay_steps:
+        if self.step_count < self.epsilon_decay_steps:
             self.epsilon = self.epsilon_start - (self.epsilon_start - self.epsilon_end) * \
-                (self.train_step / self.epsilon_decay_steps)
+                (self.step_count / self.epsilon_decay_steps)
         else:
             self.epsilon = self.epsilon_end
 
@@ -147,8 +147,8 @@ class DQNAgent:
         self.optimizer.step()
 
         # Update the target network
-        self.train_step += 1
-        if self.train_step % self.target_update_freq == 0:
+        self.step_count += 1
+        if self.step_count % self.target_update_freq == 0:
             self.target_network.load_state_dict(self.q_network.state_dict())
 
         # Update ε
@@ -168,7 +168,7 @@ class DQNAgent:
             'q_network': self.q_network.state_dict(),
             'target_network': self.target_network.state_dict(),
             'optimizer': self.optimizer.state_dict(),
-            'train_step': self.train_step,
+            'step_count': self.step_count,
             'epsilon': self.epsilon,
         }, path)
         print(f"Model saved to: {path}")
@@ -179,6 +179,6 @@ class DQNAgent:
         self.q_network.load_state_dict(checkpoint['q_network'])
         self.target_network.load_state_dict(checkpoint['target_network'])
         self.optimizer.load_state_dict(checkpoint['optimizer'])
-        self.train_step = checkpoint['train_step']
+        self.step_count = checkpoint['step_count']
         self.epsilon = checkpoint['epsilon']
-        print(f"Model loaded from: {path}")
\ No newline at end of file
+        print(f"Model loaded from: {path}")
diff --git a/强化学习个人项目报告(Atari 游戏方向)/src/evaluate.py b/强化学习个人项目报告(Atari 游戏方向)/src/evaluate.py
index 2ec8357..1ff01b1 100644
--- a/强化学习个人项目报告(Atari 游戏方向)/src/evaluate.py
+++ b/强化学习个人项目报告(Atari 游戏方向)/src/evaluate.py
@@ -132,4 +132,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/强化学习个人项目报告(Atari 游戏方向)/src/network.py b/强化学习个人项目报告(Atari 游戏方向)/src/network.py
index 2ad6b33..f59ec28 100644
--- a/强化学习个人项目报告(Atari 游戏方向)/src/network.py
+++ b/强化学习个人项目报告(Atari 游戏方向)/src/network.py
@@ -147,4 +147,4 @@ class DuelingQNetwork(nn.Module):
 
         # Q = V(s) + A(s,a) - mean(A(s,a))
         q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
-        return q_values
\ No newline at end of file
+        return q_values
diff --git a/强化学习个人项目报告(Atari 游戏方向)/src/replay_buffer.py b/强化学习个人项目报告(Atari 游戏方向)/src/replay_buffer.py
index 8f28f6a..3d03dc0 100644
--- a/强化学习个人项目报告(Atari 游戏方向)/src/replay_buffer.py
+++ b/强化学习个人项目报告(Atari 游戏方向)/src/replay_buffer.py
@@ -159,4 +159,4 @@ class PrioritizedReplayBuffer:
         self.max_priority = max(self.max_priority, priorities.max())
 
     def __len__(self):
-        return self.size
\ No newline at end of file
+        return self.size
diff --git a/强化学习个人项目报告(Atari 游戏方向)/src/trainer.py b/强化学习个人项目报告(Atari 游戏方向)/src/trainer.py
index 80bcac5..4704695 100644
--- a/强化学习个人项目报告(Atari 游戏方向)/src/trainer.py
+++ b/强化学习个人项目报告(Atari 游戏方向)/src/trainer.py
@@ -163,4 +163,4 @@ class DQNTrainer:
 
         rewards.append(episode_reward)
 
-        return np.mean(rewards)
\ No newline at end of file
+        return np.mean(rewards)
diff --git a/强化学习个人项目报告(Atari 游戏方向)/src/utils.py b/强化学习个人项目报告(Atari 游戏方向)/src/utils.py
index d1dc9d4..bb79ee6 100644
--- a/强化学习个人项目报告(Atari 游戏方向)/src/utils.py
+++ b/强化学习个人项目报告(Atari 游戏方向)/src/utils.py
@@ -197,4 +197,4 @@ def preprocess_obs(obs):
     """Ensure the observation format is correct"""
     if len(obs.shape) == 2:
         obs = np.expand_dims(obs, axis=0)
-    return obs
\ No newline at end of file
+    return obs
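Note on the `train_step` → `step_count` rename: this patch changes the checkpoint dictionary key in both `save` and `load`, so any checkpoint written before this change (which stores the counter under `'train_step'`) would raise a `KeyError` when loaded by the new `load`. A minimal backward-compatible loading sketch; the fallback helper below is an assumption for illustration and is not part of this patch:

```python
import torch

def load_compatible(agent, path):
    """Hypothetical helper: load a DQN checkpoint saved under either the
    old 'train_step' key or the new 'step_count' key."""
    checkpoint = torch.load(path, map_location="cpu")
    agent.q_network.load_state_dict(checkpoint['q_network'])
    agent.target_network.load_state_dict(checkpoint['target_network'])
    agent.optimizer.load_state_dict(checkpoint['optimizer'])
    # Fall back to the pre-rename key for checkpoints written by older code.
    agent.step_count = checkpoint.get('step_count',
                                      checkpoint.get('train_step', 0))
    agent.epsilon = checkpoint['epsilon']
    return agent
```

If the four `.pt` files added under `models/` were produced before the rename, loading them through a shim like this avoids having to rewrite them on disk.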
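For reference, `update_epsilon` interpolates ε linearly from `epsilon_start` to `epsilon_end` over `epsilon_decay_steps` steps: at step 0 it yields `epsilon_start`, at the halfway point the midpoint of the two values, and `epsilon_end` from then on. A self-contained sketch of the same schedule, with illustrative parameter values (the defaults below are assumptions, not taken from this repository's config):

```python
def linear_epsilon(step, eps_start=1.0, eps_end=0.05, decay_steps=100_000):
    """Linearly anneal epsilon from eps_start to eps_end over decay_steps,
    mirroring DQNAgent.update_epsilon (default values are illustrative)."""
    if step < decay_steps:
        return eps_start - (eps_start - eps_end) * (step / decay_steps)
    return eps_end

# Sanity check: endpoints and midpoint of the schedule.
assert linear_epsilon(0) == 1.0
assert abs(linear_epsilon(50_000) - 0.525) < 1e-9  # midpoint of 1.0 and 0.05
assert linear_epsilon(100_000) == 0.05
```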