diff --git a/强化学习个人项目报告/logs/tensorboard/run_improved_1777558107/events.out.tfevents.1777558107.LHY.48412.0 b/强化学习个人项目报告/logs/tensorboard/run_improved_1777558107/events.out.tfevents.1777558107.LHY.48412.0
deleted file mode 100644
index fbaf3ae..0000000
Binary files a/强化学习个人项目报告/logs/tensorboard/run_improved_1777558107/events.out.tfevents.1777558107.LHY.48412.0 and /dev/null differ
diff --git a/强化学习个人项目报告/logs/tensorboard/run_improved_1777564565/events.out.tfevents.1777564565.LHY.35368.0 b/强化学习个人项目报告/logs/tensorboard/run_improved_1777564565/events.out.tfevents.1777564565.LHY.35368.0
deleted file mode 100644
index 7151853..0000000
Binary files a/强化学习个人项目报告/logs/tensorboard/run_improved_1777564565/events.out.tfevents.1777564565.LHY.35368.0 and /dev/null differ
diff --git a/强化学习个人项目报告/logs/tensorboard/run_improved_1777564802/events.out.tfevents.1777564802.LHY.39456.0 b/强化学习个人项目报告/logs/tensorboard/run_improved_1777564802/events.out.tfevents.1777564802.LHY.39456.0
deleted file mode 100644
index b7002a1..0000000
Binary files a/强化学习个人项目报告/logs/tensorboard/run_improved_1777564802/events.out.tfevents.1777564802.LHY.39456.0 and /dev/null differ
diff --git a/强化学习个人项目报告/logs/tensorboard/run_improved_1777568942/events.out.tfevents.1777568942.LHY.41354.0 b/强化学习个人项目报告/logs/tensorboard/run_improved_1777568942/events.out.tfevents.1777568942.LHY.41354.0
new file mode 100644
index 0000000..6739c81
Binary files /dev/null and b/强化学习个人项目报告/logs/tensorboard/run_improved_1777568942/events.out.tfevents.1777568942.LHY.41354.0 differ
diff --git a/强化学习个人项目报告/logs/tensorboard/run_parallel_1777561439/events.out.tfevents.1777561439.LHY.16748.0 b/强化学习个人项目报告/logs/tensorboard/run_parallel_1777561439/events.out.tfevents.1777561439.LHY.16748.0
deleted file mode 100644
index d449f8f..0000000
Binary files a/强化学习个人项目报告/logs/tensorboard/run_parallel_1777561439/events.out.tfevents.1777561439.LHY.16748.0 and /dev/null differ
diff --git a/强化学习个人项目报告/logs/tensorboard/run_parallel_1777566661/events.out.tfevents.1777566661.LHY.3190.0 b/强化学习个人项目报告/logs/tensorboard/run_parallel_1777566661/events.out.tfevents.1777566661.LHY.3190.0
deleted file mode 100644
index f7f00ac..0000000
Binary files a/强化学习个人项目报告/logs/tensorboard/run_parallel_1777566661/events.out.tfevents.1777566661.LHY.3190.0 and /dev/null differ
diff --git a/强化学习个人项目报告/logs/tensorboard/run_parallel_improved_1777570209/events.out.tfevents.1777570209.LHY.74809.0 b/强化学习个人项目报告/logs/tensorboard/run_parallel_improved_1777570209/events.out.tfevents.1777570209.LHY.74809.0
new file mode 100644
index 0000000..7d2c869
Binary files /dev/null and b/强化学习个人项目报告/logs/tensorboard/run_parallel_improved_1777570209/events.out.tfevents.1777570209.LHY.74809.0 differ
diff --git a/强化学习个人项目报告/logs/tensorboard/run_parallel_improved_1777570331/events.out.tfevents.1777570331.LHY.79036.0 b/强化学习个人项目报告/logs/tensorboard/run_parallel_improved_1777570331/events.out.tfevents.1777570331.LHY.79036.0
new file mode 100644
index 0000000..96f2025
Binary files /dev/null and b/强化学习个人项目报告/logs/tensorboard/run_parallel_improved_1777570331/events.out.tfevents.1777570331.LHY.79036.0 differ
diff --git a/强化学习个人项目报告/models/ppo_parallel_improved_best.pt b/强化学习个人项目报告/models/ppo_parallel_improved_best.pt
new file mode 100644
index 0000000..4bc6273
Binary files /dev/null and b/强化学习个人项目报告/models/ppo_parallel_improved_best.pt differ
diff --git a/强化学习个人项目报告/models/ppo_parallel_improved_ep100.pt b/强化学习个人项目报告/models/ppo_parallel_improved_ep100.pt
new file mode 100644
index 0000000..7591de4
Binary files /dev/null and b/强化学习个人项目报告/models/ppo_parallel_improved_ep100.pt differ
diff --git a/强化学习个人项目报告/models/ppo_parallel_improved_ep50.pt b/强化学习个人项目报告/models/ppo_parallel_improved_ep50.pt
new file mode 100644
index 0000000..1c84c0d
Binary files /dev/null and b/强化学习个人项目报告/models/ppo_parallel_improved_ep50.pt differ
diff --git a/强化学习个人项目报告/train_parallel_improved.py b/强化学习个人项目报告/train_parallel_improved.py
new file mode 100644
index 0000000..143c219
--- /dev/null
+++ b/强化学习个人项目报告/train_parallel_improved.py
@@ -0,0 +1,662 @@
+#!/usr/bin/env python3
+"""Parallel training with reward shaping for CarRacing-v3 PPO."""
+import os
+import sys
+import time
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.utils.tensorboard import SummaryWriter
+from collections import deque
+from multiprocessing import Process, Queue, SimpleQueue, set_start_method
+import gymnasium as gym
+import cv2
+
+
+class RewardShapingWrapper(gym.Wrapper):
+    def __init__(self, env):
+        super().__init__(env)
+        self.steps_on_track = 0
+
+    def reset(self, **kwargs):
+        obs, info = self.env.reset(**kwargs)
+        self.steps_on_track = 0
+        return obs, info
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        done = terminated or truncated
+
+        shaped_reward = reward
+
+        if info.get("speed", 0) > 0.1:
+            shaped_reward += info["speed"] * 0.1
+
+        if not info.get("offtrack", False):
+            shaped_reward += 0.1
+            self.steps_on_track += 1
+        else:
+            shaped_reward -= 0.5
+            self.steps_on_track = 0
+
+        if info.get("lap_complete", False):
+            shaped_reward += 100
+
+        return obs, shaped_reward, terminated, truncated, info
+
+
+class GrayScaleWrapper(gym.ObservationWrapper):
+    def __init__(self, env):
+        super().__init__(env)
+
+    def observation(self, obs):
+        gray = 0.299 * obs[:, :, 0] + 0.587 * obs[:, :, 1] + 0.114 * obs[:, :, 2]
+        return gray.astype(np.uint8)
+
+
+class ResizeWrapper(gym.ObservationWrapper):
+    def __init__(self, env, size=(84, 84)):
+        super().__init__(env)
+        self.size = size
+
+    def observation(self, obs):
+        return cv2.resize(obs, self.size, interpolation=cv2.INTER_AREA)
+
+
+class FrameStackWrapper(gym.ObservationWrapper):
+    def __init__(self, env, num_stack=4):
+        super().__init__(env)
+        self.num_stack = num_stack
+        self.frames = deque(maxlen=num_stack)
+        obs_shape = env.observation_space.shape
+        self.observation_space = gym.spaces.Box(
+            low=0, high=255, shape=(num_stack, *obs_shape[-2:]), dtype=np.uint8
+        )
+
+    def reset(self, **kwargs):
+        obs, info = self.env.reset(**kwargs)
+        for _ in range(self.num_stack):
+            self.frames.append(obs)
+        return self._get_observation(), info
+
+    def observation(self, obs):
+        self.frames.append(obs)
+        return self._get_observation()
+
+    def _get_observation(self):
+        return np.stack(list(self.frames), axis=0)
+
+
+def make_env(env_id="CarRacing-v3", gray_scale=True, resize=True, frame_stack=4):
+    env = gym.make(env_id, render_mode="rgb_array")
+    if resize:
+        env = ResizeWrapper(env, size=(84, 84))
+    if gray_scale:
+        env = GrayScaleWrapper(env)
+    if frame_stack > 1:
+        env = FrameStackWrapper(env, num_stack=frame_stack)
+    env = RewardShapingWrapper(env)
+    return env
+
+
+def worker_loop(worker_id, action_queue, result_queue):
+    env = make_env()
+    obs, _ = env.reset()
+    obs = np.transpose(obs, (1, 2, 0))
+
+    while True:
+        try:
+            cmd, data = action_queue.get()
+
+            if cmd == 'step':
+                action = data
+                next_obs, reward, terminated, truncated, _ = env.step(action)
+                done = terminated or truncated
+                next_obs_t = np.transpose(next_obs, (1, 2, 0))
+
+                if done:
+                    next_obs_t, _ = env.reset()
+                    next_obs_t = np.transpose(next_obs_t, (1, 2, 0))
+
+                result_queue.put((worker_id, obs, action, reward, done, next_obs_t))
+                obs = next_obs_t
+
+            elif cmd == 'reset':
+                obs, _ = env.reset()
+                obs = np.transpose(obs, (1, 2, 0))
+                result_queue.put((worker_id, obs))
+
+            elif cmd == 'close':
+                env.close()
+                break
+
+        except Exception:
+            break
+
+
+class ParallelEnv:
+    def __init__(self, num_envs=4):
+        self.num_envs = num_envs
+        self.action_queues = []
+        self.result_queue = SimpleQueue()
+        self.processes = []
+
+        for i in range(num_envs):
+            action_queue = Queue(maxsize=1)
+            self.action_queues.append(action_queue)
+            p = Process(target=worker_loop, args=(i, action_queue, self.result_queue))
+            p.start()
+            self.processes.append(p)
+
+        for i in range(num_envs):
+            self.action_queues[i].put(('reset', None))
+
+        for _ in range(num_envs):
+            self.result_queue.get()
+
+    def reset(self):
+        for i in range(self.num_envs):
+            self.action_queues[i].put(('reset', None))
+
+        obs_list = []
+        for _ in range(self.num_envs):
+            worker_id, obs = self.result_queue.get()
+            obs_list.append(obs)
+
+        return np.array(obs_list)
+
+    def step(self, actions):
+        for i in range(self.num_envs):
+            self.action_queues[i].put(('step', actions[i]))
+
+        results = {}
+        for _ in range(self.num_envs):
+            item = self.result_queue.get()
+            results[item[0]] = item[1:]
+
+        obs_list = []
+        reward_list = []
+        done_list = []
+        next_obs_list = []
+
+        for i in range(self.num_envs):
+            data = results[i]
+            obs, action, reward, done = data[:4]
+            next_obs = data[4] if len(data) > 4 else obs
+            obs_list.append(obs)
+            reward_list.append(reward)
+            done_list.append(done)
+            next_obs_list.append(next_obs)
+
+        return np.array(next_obs_list), np.array(reward_list), np.array(done_list)
+
+    def close(self):
+        for i in range(self.num_envs):
+            try:
+                self.action_queues[i].put(('close', None))
+            except:
+                pass
+
+        time.sleep(0.5)
+        for p in self.processes:
+            if p.is_alive():
+                p.terminate()
+            p.join(timeout=1)
+
+
+def get_device():
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
+    else:
+        device = torch.device("cpu")
+        print("Using CPU")
+    return device
+
+
+class Actor(nn.Module):
+    def __init__(self, state_shape=(84, 84, 4), action_dim=3):
+        super().__init__()
+        c, h, w = state_shape[2], state_shape[0], state_shape[1]
+
+        self.conv = nn.Sequential(
+            nn.Conv2d(c, 32, kernel_size=8, stride=4),
+            nn.LeakyReLU(0.2),
+            nn.BatchNorm2d(32),
+            nn.Conv2d(32, 64, kernel_size=4, stride=2),
+            nn.LeakyReLU(0.2),
+            nn.BatchNorm2d(64),
+            nn.Conv2d(64, 64, kernel_size=3, stride=1),
+            nn.LeakyReLU(0.2),
+        )
+
+        out_h = (h - 8) // 4 + 1
+        out_h = (out_h - 4) // 2 + 1
+        out_h = (out_h - 3) // 1 + 1
+        feat_size = 64 * out_h * out_h
+
+        self.fc = nn.Sequential(
+            nn.Linear(feat_size, 512),
+            nn.LeakyReLU(0.2),
+        )
+        self.mu_head = nn.Linear(512, action_dim)
+        self.log_std_head = nn.Linear(512, action_dim)
+
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.Linear)):
+                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+        nn.init.orthogonal_(self.mu_head.weight, gain=0.01)
+        nn.init.orthogonal_(self.log_std_head.weight, gain=0.01)
+
+    def forward(self, x):
+        x = x / 255.0
+        x = self.conv(x)
+        x = x.flatten(1)
+        x = self.fc(x)
+        mu = torch.tanh(self.mu_head(x))
+        log_std = torch.clamp(self.log_std_head(x), -20, 2)
+        return mu, log_std.exp()
+
+
+class Critic(nn.Module):
+    def __init__(self, state_shape=(84, 84, 4)):
+        super().__init__()
+        c, h, w = state_shape[2], state_shape[0], state_shape[1]
+
+        self.conv = nn.Sequential(
+            nn.Conv2d(c, 32, kernel_size=8, stride=4),
+            nn.LeakyReLU(0.2),
+            nn.BatchNorm2d(32),
+            nn.Conv2d(32, 64, kernel_size=4, stride=2),
+            nn.LeakyReLU(0.2),
+            nn.BatchNorm2d(64),
+            nn.Conv2d(64, 64, kernel_size=3, stride=1),
+            nn.LeakyReLU(0.2),
+        )
+
+        out_h = (h - 8) // 4 + 1
+        out_h = (out_h - 4) // 2 + 1
+        out_h = (out_h - 3) // 1 + 1
+        feat_size = 64 * out_h * out_h
+
+        self.fc = nn.Sequential(nn.Linear(feat_size, 512), nn.LeakyReLU(0.2), nn.Linear(512, 1))
+
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.Linear)):
+                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = x / 255.0
+        x = self.conv(x)
+        x = x.flatten(1)
+        return self.fc(x)
+
+
+class RolloutBuffer:
+    def __init__(self, buffer_size, state_shape, action_dim):
+        self.buffer_size = buffer_size
+        self.ptr = 0
+        self.size = 0
+
+        self.states = np.zeros((buffer_size, *state_shape), dtype=np.uint8)
+        self.actions = np.zeros((buffer_size, action_dim), dtype=np.float32)
+        self.rewards = np.zeros(buffer_size, dtype=np.float32)
+        self.dones = np.zeros(buffer_size, dtype=np.bool_)
+        self.values = np.zeros(buffer_size, dtype=np.float32)
+        self.log_probs = np.zeros(buffer_size, dtype=np.float32)
+
+    def add(self, state, action, reward, done, value, log_prob):
+        self.states[self.ptr] = state
+        self.actions[self.ptr] = action
+        self.rewards[self.ptr] = reward
+        self.dones[self.ptr] = done
+        self.values[self.ptr] = value
+        self.log_probs[self.ptr] = log_prob
+        self.ptr = (self.ptr + 1) % self.buffer_size
+        self.size = min(self.size + 1, self.buffer_size)
+
+    def compute_returns(self, last_value, gamma=0.99, gae_lambda=0.98):
+        advantages = np.zeros(self.size, dtype=np.float32)
+        last_gae = 0
+
+        for t in reversed(range(self.size)):
+            if t == self.size - 1:
+                next_value = last_value
+            else:
+                next_value = self.values[t + 1]
+
+            delta = self.rewards[t] + gamma * next_value * (1 - self.dones[t]) - self.values[t]
+            last_gae = delta + gamma * gae_lambda * (1 - self.dones[t]) * last_gae
+            advantages[t] = last_gae
+
+        returns = advantages + self.values[: self.size]
+        return returns, advantages
+
+    def get(self):
+        return (
+            self.states[: self.size],
+            self.actions[: self.size],
+            self.rewards[: self.size],
+            self.dones[: self.size],
+            self.values[: self.size],
+            self.log_probs[: self.size],
+        )
+
+    def reset(self):
+        self.ptr = 0
+        self.size = 0
+
+
+class PPOTrainer:
+    def __init__(
+        self,
+        actor,
+        critic,
+        rollout_buffer,
+        device,
+        clip_eps=0.1,
+        gamma=0.99,
+        gae_lambda=0.98,
+        lr=3e-4,
+        ent_coef=0.005,
+        vf_coef=0.75,
+        max_grad_norm=0.5,
+        ppo_epochs=10,
+        mini_batch_size=128,
+    ):
+        self.actor = actor
+        self.critic = critic
+        self.buffer = rollout_buffer
+        self.device = device
+        self.clip_eps = clip_eps
+        self.gamma = gamma
+        self.gae_lambda = gae_lambda
+        self.ent_coef = ent_coef
+        self.vf_coef = vf_coef
+        self.max_grad_norm = max_grad_norm
+        self.ppo_epochs = ppo_epochs
+        self.mini_batch_size = mini_batch_size
+
+        self.actor_optim = torch.optim.Adam(actor.parameters(), lr=lr, eps=1e-5)
+        self.critic_optim = torch.optim.Adam(critic.parameters(), lr=lr, eps=1e-5)
+
+    def update(self, last_value):
+        states, actions, rewards, dones, values, log_probs_old = self.buffer.get()
+
+        returns, advantages = self.buffer.compute_returns(last_value, self.gamma, self.gae_lambda)
+        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
+
+        states_t = torch.from_numpy(states).float().permute(0, 3, 1, 2).to(self.device)
+        actions_t = torch.from_numpy(actions).float().to(self.device)
+        log_probs_old_t = torch.from_numpy(log_probs_old).float().to(self.device)
+        returns_t = torch.from_numpy(returns).float().to(self.device)
+        advantages_t = torch.from_numpy(advantages).float().to(self.device)
+
+        dataset = torch.utils.data.TensorDataset(
+            states_t, actions_t, log_probs_old_t, returns_t, advantages_t
+        )
+        loader = torch.utils.data.DataLoader(dataset, batch_size=self.mini_batch_size, shuffle=True)
+
+        total_actor_loss = 0
+        total_critic_loss = 0
+        total_entropy = 0
+        count = 0
+
+        for _ in range(self.ppo_epochs):
+            for batch in loader:
+                s, a, log_pi_old, ret, adv = batch
+
+                mu, std = self.actor(s)
+                dist = torch.distributions.Normal(mu, std)
+                log_pi = dist.log_prob(a).sum(dim=-1)
+                entropy = dist.entropy().sum(dim=-1)
+
+                ratio = torch.exp(log_pi - log_pi_old)
+                surr1 = ratio * adv
+                surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * adv
+                actor_loss = -torch.min(surr1, surr2).mean()
+
+                value = self.critic(s)
+                critic_loss = nn.MSELoss()(value.squeeze(), ret)
+
+                loss = actor_loss + self.vf_coef * critic_loss - self.ent_coef * entropy.mean()
+
+                self.actor_optim.zero_grad()
+                self.critic_optim.zero_grad()
+                loss.backward()
+                nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
+                nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
+                self.actor_optim.step()
+                self.critic_optim.step()
+
+                total_actor_loss += actor_loss.item()
+                total_critic_loss += critic_loss.item()
+                total_entropy += entropy.mean().item()
+                count += 1
+
+        avg_actor = total_actor_loss / count
+        avg_critic = total_critic_loss / count
+        avg_entropy = total_entropy / count
+
+        self.buffer.reset()
+        return avg_actor, avg_critic, avg_entropy
+
+
+def collect_parallel_rollout(actor, critic, parallel_env, buffer, device, rollout_steps, num_envs):
+    obs_batch = parallel_env.reset()
+    episode_rewards = []
+    current_ep_rewards = [0.0] * num_envs
+
+    steps_per_env = rollout_steps // num_envs
+
+    for step in range(steps_per_env):
+        obs_t = torch.from_numpy(obs_batch).float().permute(0, 3, 1, 2).to(device)
+
+        with torch.no_grad():
+            mu, std = actor(obs_t)
+            dist = torch.distributions.Normal(mu, std)
+            action = dist.sample()
+            action = torch.clamp(action, -1, 1)
+            log_prob = dist.log_prob(action).sum(dim=-1)
+            value = critic(obs_t).squeeze(-1)
+
+        action_np = action.cpu().numpy()
+        log_prob_np = log_prob.cpu().numpy()
+        value_np = value.cpu().numpy()
+
+        next_obs_batch, reward_batch, done_batch = parallel_env.step(action_np)
+
+        for i in range(num_envs):
+            buffer.add(obs_batch[i], action_np[i], reward_batch[i],
+                       done_batch[i], value_np[i], log_prob_np[i])
+            current_ep_rewards[i] += reward_batch[i]
+
+            if done_batch[i]:
+                episode_rewards.append(current_ep_rewards[i])
+                current_ep_rewards[i] = 0.0
+
+        obs_batch = next_obs_batch
+
+    return obs_batch, episode_rewards if episode_rewards else [sum(current_ep_rewards) / num_envs]
+
+
+def train(
+    total_steps=2000000,
+    num_envs=4,
+    rollout_steps=4096,
+    eval_interval=10,
+    save_interval=50,
+    device=None,
+):
+    if device is None:
+        device = get_device()
+
+    print(f"Creating {num_envs} parallel environments with reward shaping...")
+    parallel_env = ParallelEnv(num_envs=num_envs)
+
+    state_shape = (84, 84, 4)
+    action_dim = 3
+
+    actor = Actor(state_shape=state_shape, action_dim=action_dim).to(device)
+    critic = Critic(state_shape=state_shape).to(device)
+
+    buffer = RolloutBuffer(rollout_steps, state_shape, action_dim)
+
+    trainer = PPOTrainer(
+        actor=actor,
+        critic=critic,
+        rollout_buffer=buffer,
+        device=device,
+        clip_eps=0.1,
+        gamma=0.99,
+        gae_lambda=0.98,
+        lr=3e-4,
+        ent_coef=0.005,
+        vf_coef=0.75,
+        max_grad_norm=0.5,
+        ppo_epochs=10,
+        mini_batch_size=256,
+    )
+
+    log_dir = os.path.join("logs", "tensorboard", f"run_parallel_improved_{int(time.time())}")
+    writer = SummaryWriter(log_dir)
+
+    print(f"Training on {device} with {num_envs} parallel envs")
+    print(f"Log directory: {log_dir}")
+    print("Improvements: Parallel + Reward Shaping + Larger Rollout")
+
+    episode = 0
+    total_timesteps = 0
+    episode_rewards = []
+    best_eval = -float("inf")
+
+    while total_timesteps < total_steps:
+        obs_batch, batch_rewards = collect_parallel_rollout(
+            actor, critic, parallel_env, buffer, device, rollout_steps, num_envs
+        )
+
+        with torch.no_grad():
+            obs_t = torch.from_numpy(obs_batch).float().permute(0, 3, 1, 2).to(device)
+            last_value = critic(obs_t).mean().item()
+
+        actor_loss, critic_loss, entropy = trainer.update(last_value)
+
+        writer.add_scalar("Loss/Actor", actor_loss, total_timesteps)
+        writer.add_scalar("Loss/Critic", critic_loss, total_timesteps)
+        writer.add_scalar("Loss/Entropy", entropy, total_timesteps)
+
+        total_timesteps += rollout_steps
+        episode += 1
+
+        episode_rewards.extend(batch_rewards)
+        recent_rewards = episode_rewards[-50:] if len(episode_rewards) > 50 else episode_rewards
+        avg_reward = np.mean(recent_rewards)
+        mean_batch_reward = np.mean(batch_rewards)
+
+        writer.add_scalar("Reward/EpisodeMean", mean_batch_reward, total_timesteps)
+        writer.add_scalar("Reward/AvgLast50", avg_reward, total_timesteps)
+
+        print(f"Episode {episode}, steps {total_timesteps}, mean_reward={mean_batch_reward:.1f}, avg_50={avg_reward:.1f}")
+
+        if episode % eval_interval == 0:
+            eval_returns = []
+            for _ in range(5):
+                eval_env = make_env()
+                eval_obs, _ = eval_env.reset()
+                eval_obs = np.transpose(eval_obs, (1, 2, 0))
+                eval_reward = 0
+                done = False
+
+                while not done:
+                    with torch.no_grad():
+                        eval_obs_t = (
+                            torch.from_numpy(eval_obs)
+                            .float()
+                            .unsqueeze(0)
+                            .permute(0, 3, 1, 2)
+                            .to(device)
+                        )
+                        mu, std = actor(eval_obs_t)
+                        action = torch.clamp(mu, -1, 1).squeeze(0).cpu().numpy()
+                    eval_obs, reward, terminated, truncated, _ = eval_env.step(action)
+                    eval_obs = np.transpose(eval_obs, (1, 2, 0))
+                    eval_reward += reward
+                    done = terminated or truncated
+
+                eval_returns.append(eval_reward)
+                eval_env.close()
+
+            mean_eval = np.mean(eval_returns)
+            writer.add_scalar("Eval/MeanReturn", mean_eval, episode)
+            print(f"  Eval: mean_return={mean_eval:.2f}")
+
+            if mean_eval > best_eval:
+                best_eval = mean_eval
+                os.makedirs("models", exist_ok=True)
+                torch.save(
+                    {
+                        "actor": actor.state_dict(),
+                        "critic": critic.state_dict(),
+                        "episode": episode,
+                        "timesteps": total_timesteps,
+                        "best_eval": best_eval,
+                    },
+                    os.path.join("models", "ppo_parallel_improved_best.pt"),
+                )
+                print(f"  New best model saved! eval={best_eval:.2f}")
+
+        if episode % save_interval == 0:
+            os.makedirs("models", exist_ok=True)
+            torch.save(
+                {
+                    "actor": actor.state_dict(),
+                    "critic": critic.state_dict(),
+                    "episode": episode,
+                    "timesteps": total_timesteps,
+                },
+                os.path.join("models", f"ppo_parallel_improved_ep{episode}.pt"),
+            )
+            print(f"  Saved model at episode {episode}")
+
+    os.makedirs("models", exist_ok=True)
+    torch.save(
+        {
+            "actor": actor.state_dict(),
+            "critic": critic.state_dict(),
+            "episode": episode,
+            "timesteps": total_timesteps,
+            "best_eval": best_eval,
+        },
+        os.path.join("models", "ppo_parallel_improved_final.pt"),
+    )
+
+    writer.close()
+    parallel_env.close()
+    print(f"Training complete! Total episodes: {episode}, Best eval: {best_eval:.2f}")
+
+
+if __name__ == "__main__":
+    try:
+        set_start_method('fork')
+    except RuntimeError:
+        pass
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--steps", type=int, default=2000000, help="Total training steps")
+    parser.add_argument("--num_envs", type=int, default=4, help="Number of parallel environments")
+    parser.add_argument("--rollout", type=int, default=4096, help="Rollout buffer size")
+    args = parser.parse_args()
+
+    device = get_device()
+    train(
+        total_steps=args.steps,
+        num_envs=args.num_envs,
+        rollout_steps=args.rollout,
+        device=device,
+    )