"""Main training script for PPO on CarRacing-v3.""" import os import time import argparse import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from src.network import Actor, Critic from src.replay_buffer import RolloutBuffer from src.trainer import PPOTrainer from src.utils import make_env, get_device def collect_rollout(actor, critic, env, buffer, device, rollout_steps): """Collect rollout data.""" obs, _ = env.reset() # Convert to (C, H, W) format for storage obs = np.transpose(obs, (1, 2, 0)) for step in range(rollout_steps): with torch.no_grad(): # Convert to (B, C, H, W) obs_t = torch.from_numpy(obs).float().unsqueeze(0).permute(0, 3, 1, 2).to(device) mu, std = actor(obs_t) dist = torch.distributions.Normal(mu, std) action = dist.sample() action = torch.clamp(action, -1, 1) log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True) value = critic(obs_t).squeeze(0).item() action_np = action.squeeze(0).cpu().numpy() log_prob_np = log_prob.squeeze(0).cpu().numpy().sum() next_obs, reward, terminated, truncated, _ = env.step(action_np) done = terminated or truncated # Convert next_obs to (C, H, W) for storage next_obs_stored = np.transpose(next_obs, (1, 2, 0)) buffer.add(obs.copy(), action_np, reward, done, value, log_prob_np) obs = next_obs_stored if done: obs, _ = env.reset() obs = np.transpose(obs, (1, 2, 0)) return obs def train( total_steps=500000, rollout_steps=2048, eval_interval=10, save_interval=50, device=None, ): """Main training loop.""" if device is None: device = get_device() env = make_env() eval_env = make_env() state_shape = (84, 84, 4) action_dim = 3 actor = Actor(state_shape=state_shape, action_dim=action_dim).to(device) critic = Critic(state_shape=state_shape).to(device) buffer = RolloutBuffer( buffer_size=rollout_steps, state_shape=state_shape, action_dim=action_dim, ) trainer = PPOTrainer( actor=actor, critic=critic, rollout_buffer=buffer, device=device, clip_eps=0.2, gamma=0.99, gae_lambda=0.95, lr=3e-4, ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, ppo_epochs=4, mini_batch_size=64, ) # TensorBoard log_dir = os.path.join("logs", "tensorboard", f"run_{int(time.time())}") writer = SummaryWriter(log_dir) print(f"Training on {device}") print(f"Log directory: {log_dir}") episode = 0 total_timesteps = 0 episode_rewards = [] recent_rewards = [] while total_timesteps < total_steps: obs = collect_rollout(actor, critic, env, buffer, device, rollout_steps) with torch.no_grad(): obs_t = torch.from_numpy(obs).float().unsqueeze(0).permute(0, 3, 1, 2).to(device) last_value = critic(obs_t).squeeze(0).item() # PPO update actor_loss, critic_loss, entropy = trainer.update(last_value) # Logging writer.add_scalar("Loss/Actor", actor_loss, total_timesteps) writer.add_scalar("Loss/Critic", critic_loss, total_timesteps) writer.add_scalar("Loss/Entropy", entropy, total_timesteps) total_timesteps += rollout_steps episode += 1 # Estimate episode reward from buffer ep_reward = buffer.rewards[:buffer.size].sum() episode_rewards.append(ep_reward) recent_rewards.append(ep_reward) # Running average of last 10 episodes avg_reward = np.mean(recent_rewards[-10:]) if len(recent_rewards) >= 10 else np.mean(recent_rewards) writer.add_scalar("Reward/Episode", ep_reward, total_timesteps) writer.add_scalar("Reward/AvgLast10", avg_reward, total_timesteps) print(f"Episode {episode}, steps {total_timesteps}, ep_reward={ep_reward:.1f}, avg_10={avg_reward:.1f}") # Evaluation if episode % eval_interval == 0: eval_returns = [] for _ in range(5): eval_obs, _ = eval_env.reset() eval_obs 
                eval_obs = np.transpose(eval_obs, (1, 2, 0))
                eval_reward = 0
                done = False
                while not done:
                    with torch.no_grad():
                        eval_obs_t = torch.from_numpy(eval_obs).float().unsqueeze(0).permute(0, 3, 1, 2).to(device)
                        mu, std = actor(eval_obs_t)
                        # Deterministic evaluation: act with the policy mean.
                        action = torch.clamp(mu, -1, 1).squeeze(0).cpu().numpy()
                    eval_obs, reward, terminated, truncated, _ = eval_env.step(action)
                    eval_obs = np.transpose(eval_obs, (1, 2, 0))
                    eval_reward += reward
                    done = terminated or truncated
                eval_returns.append(eval_reward)

            mean_eval = np.mean(eval_returns)
            writer.add_scalar("Eval/MeanReturn", mean_eval, episode)
            print(f"  Eval: mean_return={mean_eval:.2f}")

        # Save model
        if episode % save_interval == 0:
            os.makedirs("models", exist_ok=True)
            torch.save({
                "actor": actor.state_dict(),
                "critic": critic.state_dict(),
                "episode": episode,
                "timesteps": total_timesteps,
            }, os.path.join("models", f"ppo_carracing_ep{episode}.pt"))
            print(f"  Saved model at episode {episode}")

    # Save final model
    os.makedirs("models", exist_ok=True)
    torch.save({
        "actor": actor.state_dict(),
        "critic": critic.state_dict(),
        "episode": episode,
        "timesteps": total_timesteps,
    }, os.path.join("models", "ppo_carracing_final.pt"))

    writer.close()
    print(f"Training complete! Total episodes: {episode}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=500000, help="Total training steps")
    parser.add_argument("--rollout", type=int, default=2048, help="Rollout buffer size")
    args = parser.parse_args()

    device = get_device()
    train(total_steps=args.steps, rollout_steps=args.rollout, device=device)


if __name__ == "__main__":
    main()
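
# Example invocations (the filename "train.py" is an assumption; substitute
# this script's actual name in the repo). Run from the repo root so the
# `src` package is importable:
#   python train.py                          # defaults: 500k steps, 2048-step rollouts
#   python train.py --steps 1000000 --rollout 4096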