feat: add parallel training script and reward shaping to improve PPO performance
Introduce the parallel-environment training script train_parallel_improved.py, which collects rollouts across multiple worker processes.
Add a reward-shaping wrapper that adjusts the reward signal based on speed, track position, and lap completion.
Tune the network architecture and training hyperparameters, including a larger rollout buffer.
Remove the old TensorBoard log files and start a new training run record.
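A minimal launch sketch, assuming the script sits at the repository root and is importable as train_parallel_improved (the same settings are also exposed on the command line via --steps, --num_envs and --rollout); the step counts and env count below are illustrative values, not the defaults:

# Illustrative only: kick off a short smoke-test run with 8 worker processes.
from train_parallel_improved import get_device, train

if __name__ == "__main__":
    train(
        total_steps=200_000,   # cut down from the 2M default for a quick check
        num_envs=8,
        rollout_steps=8192,
        device=get_device(),
    )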
@@ -0,0 +1,662 @@
#!/usr/bin/env python3
"""Parallel training with reward shaping for CarRacing-v3 PPO."""
import os
import sys
import time
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from collections import deque
from multiprocessing import Process, Queue, SimpleQueue, set_start_method
import gymnasium as gym
import cv2

class RewardShapingWrapper(gym.Wrapper):
    """Adds shaping bonuses/penalties on top of the environment reward."""

    def __init__(self, env):
        super().__init__(env)
        self.steps_on_track = 0

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self.steps_on_track = 0
        return obs, info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        done = terminated or truncated

        shaped_reward = reward

        # Bonus proportional to speed, if the env reports it in `info`.
        if info.get("speed", 0) > 0.1:
            shaped_reward += info["speed"] * 0.1

        # Small bonus for staying on track, penalty for leaving it.
        if not info.get("offtrack", False):
            shaped_reward += 0.1
            self.steps_on_track += 1
        else:
            shaped_reward -= 0.5
            self.steps_on_track = 0

        # Large bonus for completing a lap.
        if info.get("lap_complete", False):
            shaped_reward += 100

        return obs, shaped_reward, terminated, truncated, info

class GrayScaleWrapper(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        # Declare the single-channel (H, W) observation space.
        h, w = env.observation_space.shape[:2]
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(h, w), dtype=np.uint8)

    def observation(self, obs):
        # Standard luminance conversion.
        gray = 0.299 * obs[:, :, 0] + 0.587 * obs[:, :, 1] + 0.114 * obs[:, :, 2]
        return gray.astype(np.uint8)


class ResizeWrapper(gym.ObservationWrapper):
    def __init__(self, env, size=(84, 84)):
        super().__init__(env)
        self.size = size
        # Declare the resized observation space (keeps the channel dim if present).
        channels = env.observation_space.shape[2:]
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(*size, *channels), dtype=np.uint8)

    def observation(self, obs):
        return cv2.resize(obs, self.size, interpolation=cv2.INTER_AREA)

class FrameStackWrapper(gym.ObservationWrapper):
    def __init__(self, env, num_stack=4):
        super().__init__(env)
        self.num_stack = num_stack
        self.frames = deque(maxlen=num_stack)
        obs_shape = env.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(num_stack, *obs_shape[-2:]), dtype=np.uint8
        )

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        for _ in range(self.num_stack):
            self.frames.append(obs)
        return self._get_observation(), info

    def observation(self, obs):
        self.frames.append(obs)
        return self._get_observation()

    def _get_observation(self):
        return np.stack(list(self.frames), axis=0)

def make_env(env_id="CarRacing-v3", gray_scale=True, resize=True, frame_stack=4):
    env = gym.make(env_id, render_mode="rgb_array")
    if resize:
        env = ResizeWrapper(env, size=(84, 84))
    if gray_scale:
        env = GrayScaleWrapper(env)
    if frame_stack > 1:
        env = FrameStackWrapper(env, num_stack=frame_stack)
    env = RewardShapingWrapper(env)
    return env

def worker_loop(worker_id, action_queue, result_queue):
    """Worker process: owns one env and reacts to ('step' | 'reset' | 'close') commands."""
    env = make_env()
    obs, _ = env.reset()
    # Stacked frames arrive as (C, H, W); transpose to (H, W, C) for the rollout buffer.
    obs = np.transpose(obs, (1, 2, 0))

    while True:
        try:
            cmd, data = action_queue.get()

            if cmd == 'step':
                action = data
                next_obs, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                next_obs_t = np.transpose(next_obs, (1, 2, 0))

                if done:
                    next_obs_t, _ = env.reset()
                    next_obs_t = np.transpose(next_obs_t, (1, 2, 0))

                result_queue.put((worker_id, obs, action, reward, done, next_obs_t))
                obs = next_obs_t

            elif cmd == 'reset':
                obs, _ = env.reset()
                obs = np.transpose(obs, (1, 2, 0))
                result_queue.put((worker_id, obs))

            elif cmd == 'close':
                env.close()
                break

        except Exception:
            break

class ParallelEnv:
    """Vectorized environment: one worker process per env, driven through queues."""

    def __init__(self, num_envs=4):
        self.num_envs = num_envs
        self.action_queues = []
        self.result_queue = SimpleQueue()
        self.processes = []

        for i in range(num_envs):
            action_queue = Queue(maxsize=1)
            self.action_queues.append(action_queue)
            p = Process(target=worker_loop, args=(i, action_queue, self.result_queue))
            p.start()
            self.processes.append(p)

        # Warm up: make every worker reset once and drain the acknowledgements.
        for i in range(num_envs):
            self.action_queues[i].put(('reset', None))

        for _ in range(num_envs):
            self.result_queue.get()

    def reset(self):
        for i in range(self.num_envs):
            self.action_queues[i].put(('reset', None))

        # Results arrive in completion order; re-order them by worker id so row i is env i.
        obs_by_id = {}
        for _ in range(self.num_envs):
            worker_id, obs = self.result_queue.get()
            obs_by_id[worker_id] = obs

        return np.array([obs_by_id[i] for i in range(self.num_envs)])

    def step(self, actions):
        for i in range(self.num_envs):
            self.action_queues[i].put(('step', actions[i]))

        results = {}
        for _ in range(self.num_envs):
            item = self.result_queue.get()
            results[item[0]] = item[1:]

        obs_list = []
        reward_list = []
        done_list = []
        next_obs_list = []

        for i in range(self.num_envs):
            data = results[i]
            obs, action, reward, done = data[:4]
            next_obs = data[4] if len(data) > 4 else obs
            obs_list.append(obs)
            reward_list.append(reward)
            done_list.append(done)
            next_obs_list.append(next_obs)

        return np.array(next_obs_list), np.array(reward_list), np.array(done_list)

    def close(self):
        for i in range(self.num_envs):
            try:
                self.action_queues[i].put(('close', None))
            except Exception:
                pass

        time.sleep(0.5)
        for p in self.processes:
            if p.is_alive():
                p.terminate()
            p.join(timeout=1)

def get_device():
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        device = torch.device("cpu")
        print("Using CPU")
    return device

class Actor(nn.Module):
    """Gaussian policy over the 3-dimensional continuous CarRacing action space."""

    def __init__(self, state_shape=(84, 84, 4), action_dim=3):
        super().__init__()
        c, h, w = state_shape[2], state_shape[0], state_shape[1]

        self.conv = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.LeakyReLU(0.2),
        )

        # Spatial size after the three conv layers (84 -> 20 -> 9 -> 7).
        out_h = (h - 8) // 4 + 1
        out_h = (out_h - 4) // 2 + 1
        out_h = (out_h - 3) // 1 + 1
        feat_size = 64 * out_h * out_h

        self.fc = nn.Sequential(
            nn.Linear(feat_size, 512),
            nn.LeakyReLU(0.2),
        )
        self.mu_head = nn.Linear(512, action_dim)
        self.log_std_head = nn.Linear(512, action_dim)

        # Orthogonal init everywhere, with a small gain on the policy heads.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        nn.init.orthogonal_(self.mu_head.weight, gain=0.01)
        nn.init.orthogonal_(self.log_std_head.weight, gain=0.01)

    def forward(self, x):
        x = x / 255.0
        x = self.conv(x)
        x = x.flatten(1)
        x = self.fc(x)
        mu = torch.tanh(self.mu_head(x))
        log_std = torch.clamp(self.log_std_head(x), -20, 2)
        return mu, log_std.exp()

class Critic(nn.Module):
    def __init__(self, state_shape=(84, 84, 4)):
        super().__init__()
        c, h, w = state_shape[2], state_shape[0], state_shape[1]

        self.conv = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.LeakyReLU(0.2),
        )

        out_h = (h - 8) // 4 + 1
        out_h = (out_h - 4) // 2 + 1
        out_h = (out_h - 3) // 1 + 1
        feat_size = 64 * out_h * out_h

        self.fc = nn.Sequential(nn.Linear(feat_size, 512), nn.LeakyReLU(0.2), nn.Linear(512, 1))

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = x / 255.0
        x = self.conv(x)
        x = x.flatten(1)
        return self.fc(x)

class RolloutBuffer:
    def __init__(self, buffer_size, state_shape, action_dim):
        self.buffer_size = buffer_size
        self.ptr = 0
        self.size = 0

        self.states = np.zeros((buffer_size, *state_shape), dtype=np.uint8)
        self.actions = np.zeros((buffer_size, action_dim), dtype=np.float32)
        self.rewards = np.zeros(buffer_size, dtype=np.float32)
        self.dones = np.zeros(buffer_size, dtype=np.bool_)
        self.values = np.zeros(buffer_size, dtype=np.float32)
        self.log_probs = np.zeros(buffer_size, dtype=np.float32)

    def add(self, state, action, reward, done, value, log_prob):
        self.states[self.ptr] = state
        self.actions[self.ptr] = action
        self.rewards[self.ptr] = reward
        self.dones[self.ptr] = done
        self.values[self.ptr] = value
        self.log_probs[self.ptr] = log_prob
        self.ptr = (self.ptr + 1) % self.buffer_size
        self.size = min(self.size + 1, self.buffer_size)

    def compute_returns(self, last_value, gamma=0.99, gae_lambda=0.98):
        # Generalized Advantage Estimation, iterating backwards over the rollout.
        advantages = np.zeros(self.size, dtype=np.float32)
        last_gae = 0

        for t in reversed(range(self.size)):
            if t == self.size - 1:
                next_value = last_value
            else:
                next_value = self.values[t + 1]

            delta = self.rewards[t] + gamma * next_value * (1 - self.dones[t]) - self.values[t]
            last_gae = delta + gamma * gae_lambda * (1 - self.dones[t]) * last_gae
            advantages[t] = last_gae

        returns = advantages + self.values[: self.size]
        return returns, advantages

    def get(self):
        return (
            self.states[: self.size],
            self.actions[: self.size],
            self.rewards[: self.size],
            self.dones[: self.size],
            self.values[: self.size],
            self.log_probs[: self.size],
        )

    def reset(self):
        self.ptr = 0
        self.size = 0

class PPOTrainer:
    def __init__(
        self,
        actor,
        critic,
        rollout_buffer,
        device,
        clip_eps=0.1,
        gamma=0.99,
        gae_lambda=0.98,
        lr=3e-4,
        ent_coef=0.005,
        vf_coef=0.75,
        max_grad_norm=0.5,
        ppo_epochs=10,
        mini_batch_size=128,
    ):
        self.actor = actor
        self.critic = critic
        self.buffer = rollout_buffer
        self.device = device
        self.clip_eps = clip_eps
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm
        self.ppo_epochs = ppo_epochs
        self.mini_batch_size = mini_batch_size

        self.actor_optim = torch.optim.Adam(actor.parameters(), lr=lr, eps=1e-5)
        self.critic_optim = torch.optim.Adam(critic.parameters(), lr=lr, eps=1e-5)

    def update(self, last_value):
        states, actions, rewards, dones, values, log_probs_old = self.buffer.get()

        # GAE returns/advantages, with advantages normalized per update.
        returns, advantages = self.buffer.compute_returns(last_value, self.gamma, self.gae_lambda)
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # (N, H, W, C) uint8 states -> (N, C, H, W) float tensors.
        states_t = torch.from_numpy(states).float().permute(0, 3, 1, 2).to(self.device)
        actions_t = torch.from_numpy(actions).float().to(self.device)
        log_probs_old_t = torch.from_numpy(log_probs_old).float().to(self.device)
        returns_t = torch.from_numpy(returns).float().to(self.device)
        advantages_t = torch.from_numpy(advantages).float().to(self.device)

        dataset = torch.utils.data.TensorDataset(
            states_t, actions_t, log_probs_old_t, returns_t, advantages_t
        )
        loader = torch.utils.data.DataLoader(dataset, batch_size=self.mini_batch_size, shuffle=True)

        total_actor_loss = 0
        total_critic_loss = 0
        total_entropy = 0
        count = 0

        for _ in range(self.ppo_epochs):
            for batch in loader:
                s, a, log_pi_old, ret, adv = batch

                mu, std = self.actor(s)
                dist = torch.distributions.Normal(mu, std)
                log_pi = dist.log_prob(a).sum(dim=-1)
                entropy = dist.entropy().sum(dim=-1)

                # Clipped surrogate objective.
                ratio = torch.exp(log_pi - log_pi_old)
                surr1 = ratio * adv
                surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * adv
                actor_loss = -torch.min(surr1, surr2).mean()

                value = self.critic(s)
                critic_loss = nn.MSELoss()(value.squeeze(), ret)

                loss = actor_loss + self.vf_coef * critic_loss - self.ent_coef * entropy.mean()

                self.actor_optim.zero_grad()
                self.critic_optim.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
                nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
                self.actor_optim.step()
                self.critic_optim.step()

                total_actor_loss += actor_loss.item()
                total_critic_loss += critic_loss.item()
                total_entropy += entropy.mean().item()
                count += 1

        avg_actor = total_actor_loss / count
        avg_critic = total_critic_loss / count
        avg_entropy = total_entropy / count

        self.buffer.reset()
        return avg_actor, avg_critic, avg_entropy

def collect_parallel_rollout(actor, critic, parallel_env, buffer, device, rollout_steps, num_envs):
    obs_batch = parallel_env.reset()
    episode_rewards = []
    current_ep_rewards = [0.0] * num_envs

    # Each parallel step adds num_envs transitions, so divide the budget across envs.
    steps_per_env = rollout_steps // num_envs

    for step in range(steps_per_env):
        obs_t = torch.from_numpy(obs_batch).float().permute(0, 3, 1, 2).to(device)

        with torch.no_grad():
            mu, std = actor(obs_t)
            dist = torch.distributions.Normal(mu, std)
            action = dist.sample()
            action = torch.clamp(action, -1, 1)
            log_prob = dist.log_prob(action).sum(dim=-1)
            value = critic(obs_t).squeeze(-1)

        action_np = action.cpu().numpy()
        log_prob_np = log_prob.cpu().numpy()
        value_np = value.cpu().numpy()

        next_obs_batch, reward_batch, done_batch = parallel_env.step(action_np)

        for i in range(num_envs):
            buffer.add(obs_batch[i], action_np[i], reward_batch[i],
                       done_batch[i], value_np[i], log_prob_np[i])
            current_ep_rewards[i] += reward_batch[i]

            if done_batch[i]:
                episode_rewards.append(current_ep_rewards[i])
                current_ep_rewards[i] = 0.0

        obs_batch = next_obs_batch

    # Fall back to the mean partial return if no episode finished during this rollout.
    return obs_batch, (episode_rewards if episode_rewards else [sum(current_ep_rewards) / num_envs])

def train(
    total_steps=2000000,
    num_envs=4,
    rollout_steps=4096,
    eval_interval=10,
    save_interval=50,
    device=None,
):
    if device is None:
        device = get_device()

    print(f"Creating {num_envs} parallel environments with reward shaping...")
    parallel_env = ParallelEnv(num_envs=num_envs)

    state_shape = (84, 84, 4)
    action_dim = 3

    actor = Actor(state_shape=state_shape, action_dim=action_dim).to(device)
    critic = Critic(state_shape=state_shape).to(device)

    buffer = RolloutBuffer(rollout_steps, state_shape, action_dim)

    trainer = PPOTrainer(
        actor=actor,
        critic=critic,
        rollout_buffer=buffer,
        device=device,
        clip_eps=0.1,
        gamma=0.99,
        gae_lambda=0.98,
        lr=3e-4,
        ent_coef=0.005,
        vf_coef=0.75,
        max_grad_norm=0.5,
        ppo_epochs=10,
        mini_batch_size=256,
    )

    log_dir = os.path.join("logs", "tensorboard", f"run_parallel_improved_{int(time.time())}")
    writer = SummaryWriter(log_dir)

    print(f"Training on {device} with {num_envs} parallel envs")
    print(f"Log directory: {log_dir}")
    print("Improvements: Parallel + Reward Shaping + Larger Rollout")

    episode = 0
    total_timesteps = 0
    episode_rewards = []
    best_eval = -float("inf")

    while total_timesteps < total_steps:
        obs_batch, batch_rewards = collect_parallel_rollout(
            actor, critic, parallel_env, buffer, device, rollout_steps, num_envs
        )

        # Bootstrap value for the unfinished trajectories (averaged over envs).
        with torch.no_grad():
            obs_t = torch.from_numpy(obs_batch).float().permute(0, 3, 1, 2).to(device)
            last_value = critic(obs_t).mean().item()

        actor_loss, critic_loss, entropy = trainer.update(last_value)

        writer.add_scalar("Loss/Actor", actor_loss, total_timesteps)
        writer.add_scalar("Loss/Critic", critic_loss, total_timesteps)
        writer.add_scalar("Loss/Entropy", entropy, total_timesteps)

        total_timesteps += rollout_steps
        episode += 1

        episode_rewards.extend(batch_rewards)
        recent_rewards = episode_rewards[-50:] if len(episode_rewards) > 50 else episode_rewards
        avg_reward = np.mean(recent_rewards)
        mean_batch_reward = np.mean(batch_rewards)

        writer.add_scalar("Reward/EpisodeMean", mean_batch_reward, total_timesteps)
        writer.add_scalar("Reward/AvgLast50", avg_reward, total_timesteps)

        print(f"Episode {episode}, steps {total_timesteps}, mean_reward={mean_batch_reward:.1f}, avg_50={avg_reward:.1f}")

        if episode % eval_interval == 0:
            # Deterministic evaluation: act with the policy mean in a fresh env.
            eval_returns = []
            for _ in range(5):
                eval_env = make_env()
                eval_obs, _ = eval_env.reset()
                eval_obs = np.transpose(eval_obs, (1, 2, 0))
                eval_reward = 0
                done = False

                while not done:
                    with torch.no_grad():
                        eval_obs_t = (
                            torch.from_numpy(eval_obs)
                            .float()
                            .unsqueeze(0)
                            .permute(0, 3, 1, 2)
                            .to(device)
                        )
                        mu, std = actor(eval_obs_t)
                        action = torch.clamp(mu, -1, 1).squeeze(0).cpu().numpy()
                    eval_obs, reward, terminated, truncated, _ = eval_env.step(action)
                    eval_obs = np.transpose(eval_obs, (1, 2, 0))
                    eval_reward += reward
                    done = terminated or truncated

                eval_returns.append(eval_reward)
                eval_env.close()

            mean_eval = np.mean(eval_returns)
            writer.add_scalar("Eval/MeanReturn", mean_eval, episode)
            print(f" Eval: mean_return={mean_eval:.2f}")

            if mean_eval > best_eval:
                best_eval = mean_eval
                os.makedirs("models", exist_ok=True)
                torch.save(
                    {
                        "actor": actor.state_dict(),
                        "critic": critic.state_dict(),
                        "episode": episode,
                        "timesteps": total_timesteps,
                        "best_eval": best_eval,
                    },
                    os.path.join("models", "ppo_parallel_improved_best.pt"),
                )
                print(f" New best model saved! eval={best_eval:.2f}")

        if episode % save_interval == 0:
            os.makedirs("models", exist_ok=True)
            torch.save(
                {
                    "actor": actor.state_dict(),
                    "critic": critic.state_dict(),
                    "episode": episode,
                    "timesteps": total_timesteps,
                },
                os.path.join("models", f"ppo_parallel_improved_ep{episode}.pt"),
            )
            print(f" Saved model at episode {episode}")

    os.makedirs("models", exist_ok=True)
    torch.save(
        {
            "actor": actor.state_dict(),
            "critic": critic.state_dict(),
            "episode": episode,
            "timesteps": total_timesteps,
            "best_eval": best_eval,
        },
        os.path.join("models", "ppo_parallel_improved_final.pt"),
    )

    writer.close()
    parallel_env.close()
    print(f"Training complete! Total episodes: {episode}, Best eval: {best_eval:.2f}")

if __name__ == "__main__":
    # Use the 'fork' start method where available (ignore if the context is already set).
    try:
        set_start_method('fork')
    except RuntimeError:
        pass

    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=2000000, help="Total training steps")
    parser.add_argument("--num_envs", type=int, default=4, help="Number of parallel environments")
    parser.add_argument("--rollout", type=int, default=4096, help="Rollout buffer size")
    args = parser.parse_args()

    device = get_device()
    train(
        total_steps=args.steps,
        num_envs=args.num_envs,
        rollout_steps=args.rollout,
        device=device,
    )