"""Main training script for PPO on CarRacing-v3."""
import os
import time
import argparse
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from src.network import Actor, Critic
from src.replay_buffer import RolloutBuffer
from src.trainer import PPOTrainer
from src.utils import make_env, get_device
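
# Assumed interfaces for the local src/ modules, as exercised below:
#   Actor(state_shape, action_dim)(obs) -> (mu, std) for a diagonal Gaussian policy
#   Critic(state_shape)(obs)            -> state-value estimate
#   RolloutBuffer.add(obs, action, reward, done, value, log_prob)
#   PPOTrainer.update(last_value)       -> (actor_loss, critic_loss, entropy)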


def collect_rollout(actor, critic, env, buffer, device, rollout_steps):
    """Collect one rollout of experience and return the final observation."""
    obs, _ = env.reset()
    # The env emits channels-first frames (C, H, W); store as (H, W, C) to match state_shape
    obs = np.transpose(obs, (1, 2, 0))
    for step in range(rollout_steps):
        with torch.no_grad():
            # Back to (B, C, H, W) for the networks
            obs_t = torch.from_numpy(obs).float().unsqueeze(0).permute(0, 3, 1, 2).to(device)
            mu, std = actor(obs_t)
            dist = torch.distributions.Normal(mu, std)
            action = dist.sample()
            action = torch.clamp(action, -1, 1)
            log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True)
            value = critic(obs_t).squeeze(0).item()
        action_np = action.squeeze(0).cpu().numpy()
        log_prob_np = log_prob.squeeze(0).cpu().numpy()
        next_obs, reward, terminated, truncated, _ = env.step(action_np)
        done = terminated or truncated
        # Convert next_obs to (H, W, C) for storage
        next_obs_stored = np.transpose(next_obs, (1, 2, 0))
        buffer.add(obs.copy(), action_np, reward, done, value, log_prob_np)
        obs = next_obs_stored
        if done:
            obs, _ = env.reset()
            obs = np.transpose(obs, (1, 2, 0))
    # Return the last observation so the caller can bootstrap the final value for GAE
    return obs


def train(
    total_steps=500000,
    rollout_steps=2048,
    eval_interval=10,
    save_interval=50,
    device=None,
):
    """Main training loop."""
    if device is None:
        device = get_device()
    env = make_env()
    eval_env = make_env()
    state_shape = (84, 84, 4)
    action_dim = 3
    actor = Actor(state_shape=state_shape, action_dim=action_dim).to(device)
    critic = Critic(state_shape=state_shape).to(device)
    buffer = RolloutBuffer(
        buffer_size=rollout_steps,
        state_shape=state_shape,
        action_dim=action_dim,
    )
    trainer = PPOTrainer(
        actor=actor,
        critic=critic,
        rollout_buffer=buffer,
        device=device,
        clip_eps=0.2,
        gamma=0.99,
        gae_lambda=0.95,
        lr=3e-4,
        ent_coef=0.01,
        vf_coef=0.5,
        max_grad_norm=0.5,
        ppo_epochs=4,
        mini_batch_size=64,
    )
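    # For reference, the clipped surrogate objective that PPO optimizes (implemented
    # in src/trainer.py; sketched here under that assumption):
    #   ratio_t = exp(log_prob_new(a_t|s_t) - log_prob_old(a_t|s_t))
    #   loss    = -mean(min(ratio_t * A_t, clip(ratio_t, 1 - clip_eps, 1 + clip_eps) * A_t))
    #             + vf_coef * value_loss - ent_coef * entropy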

    # TensorBoard
    log_dir = os.path.join("logs", "tensorboard", f"run_{int(time.time())}")
    writer = SummaryWriter(log_dir)
    print(f"Training on {device}")
    print(f"Log directory: {log_dir}")
    episode = 0
    total_timesteps = 0
    episode_rewards = []
    recent_rewards = []
    while total_timesteps < total_steps:
        # Collect rollout; keep the final observation for bootstrapping
        obs = collect_rollout(actor, critic, env, buffer, device, rollout_steps)
        # Get last value for GAE
        with torch.no_grad():
            obs_t = torch.from_numpy(obs).float().unsqueeze(0).permute(0, 3, 1, 2).to(device)
            last_value = critic(obs_t).squeeze(0).item()
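        # For reference, the GAE recursion presumably computed inside trainer.update
        # (a sketch under that assumption, bootstrapping from last_value above):
        #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
        #   A_t     = delta_t + gamma * gae_lambda * (1 - done_t) * A_{t+1}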
        # PPO update
        actor_loss, critic_loss, entropy = trainer.update(last_value)
        # Logging
        writer.add_scalar("Loss/Actor", actor_loss, total_timesteps)
        writer.add_scalar("Loss/Critic", critic_loss, total_timesteps)
        writer.add_scalar("Loss/Entropy", entropy, total_timesteps)
        total_timesteps += rollout_steps
        episode += 1
        # Total reward collected over this rollout; a rollout may span several
        # episodes, so this is only a proxy for per-episode reward
        ep_reward = buffer.rewards[:buffer.size].sum()
        episode_rewards.append(ep_reward)
        recent_rewards.append(ep_reward)
        # Running average over the last 10 rollouts (the slice handles warm-up)
        avg_reward = np.mean(recent_rewards[-10:])
        writer.add_scalar("Reward/Episode", ep_reward, total_timesteps)
        writer.add_scalar("Reward/AvgLast10", avg_reward, total_timesteps)
        print(f"Episode {episode}, steps {total_timesteps}, ep_reward={ep_reward:.1f}, avg_10={avg_reward:.1f}")
        # Evaluation
        if episode % eval_interval == 0:
            eval_returns = []
            for _ in range(5):
                eval_obs, _ = eval_env.reset()
                eval_obs = np.transpose(eval_obs, (1, 2, 0))
                eval_reward = 0.0
                done = False
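                # Deterministic evaluation: act with the clamped mean action (no
                # sampling), so returns reflect the policy rather than exploration noise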
                while not done:
                    with torch.no_grad():
                        eval_obs_t = torch.from_numpy(eval_obs).float().unsqueeze(0).permute(0, 3, 1, 2).to(device)
                        mu, _ = actor(eval_obs_t)
                        action = torch.clamp(mu, -1, 1).squeeze(0).cpu().numpy()
                    eval_obs, reward, terminated, truncated, _ = eval_env.step(action)
                    eval_obs = np.transpose(eval_obs, (1, 2, 0))
                    eval_reward += reward
                    done = terminated or truncated
                eval_returns.append(eval_reward)
            mean_eval = np.mean(eval_returns)
            writer.add_scalar("Eval/MeanReturn", mean_eval, episode)
            print(f" Eval: mean_return={mean_eval:.2f}")
        # Save model checkpoint
        if episode % save_interval == 0:
            os.makedirs("models", exist_ok=True)
            torch.save({
                "actor": actor.state_dict(),
                "critic": critic.state_dict(),
                "episode": episode,
                "timesteps": total_timesteps,
            }, os.path.join("models", f"ppo_carracing_ep{episode}.pt"))
            print(f" Saved model at episode {episode}")
    # Save final model
    os.makedirs("models", exist_ok=True)
    torch.save({
        "actor": actor.state_dict(),
        "critic": critic.state_dict(),
        "episode": episode,
        "timesteps": total_timesteps,
    }, os.path.join("models", "ppo_carracing_final.pt"))
    writer.close()
    print(f"Training complete! Total episodes: {episode}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=500000, help="Total training steps")
    parser.add_argument("--rollout", type=int, default=2048, help="Rollout buffer size")
    args = parser.parse_args()
    device = get_device()
    train(total_steps=args.steps, rollout_steps=args.rollout, device=device)
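
# Example invocations (flags defined above; values are illustrative):
#   python train.py
#   python train.py --steps 1000000 --rollout 4096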