feat: add parallel training script and reward shaping to improve PPO performance
- Introduce the parallel-environment training script train_parallel_improved.py with multi-process data collection
- Add a reward-shaping wrapper that adjusts the reward signal based on speed, track position, and lap completion
- Tune the network architecture and training hyperparameters, including a larger rollout buffer
- Delete stale TensorBoard log files and start a fresh training run record
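For orientation, a minimal programmatic launch of the new script could look like the sketch below. The values simply mirror the script's own argparse defaults; importing train_parallel_improved this way is an illustrative assumption, since the script is normally run directly with --steps/--num_envs/--rollout.

    # Hypothetical usage sketch; assumes train_parallel_improved.py is importable
    # from the working directory.
    from train_parallel_improved import get_device, train

    if __name__ == "__main__":
        train(
            total_steps=2_000_000,  # same default as --steps
            num_envs=4,             # same default as --num_envs
            rollout_steps=4096,     # same default as --rollout
            device=get_device(),
        )

Reward shaping needs no extra setup in either case, because make_env() wraps every worker environment in RewardShapingWrapper.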
train_parallel_improved.py
@@ -0,0 +1,662 @@
#!/usr/bin/env python3
"""Parallel training with reward shaping for CarRacing-v3 PPO."""

import argparse
import os
import sys
import time
from collections import deque
from multiprocessing import Process, Queue, SimpleQueue, set_start_method

import cv2
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter


class RewardShapingWrapper(gym.Wrapper):
    """Augments the environment reward with speed, on-track, and lap-completion bonuses."""

    def __init__(self, env):
        super().__init__(env)
        self.steps_on_track = 0

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self.steps_on_track = 0
        return obs, info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)

        shaped_reward = reward

        # Bonus proportional to speed, if the environment reports it.
        if info.get("speed", 0) > 0.1:
            shaped_reward += info["speed"] * 0.1

        # Small bonus for staying on track, penalty for leaving it.
        if not info.get("offtrack", False):
            shaped_reward += 0.1
            self.steps_on_track += 1
        else:
            shaped_reward -= 0.5
            self.steps_on_track = 0

        # Large bonus for completing a lap.
        if info.get("lap_complete", False):
            shaped_reward += 100

        return obs, shaped_reward, terminated, truncated, info


class GrayScaleWrapper(gym.ObservationWrapper):
    """Converts RGB observations to single-channel grayscale (ITU-R 601 weights)."""

    def __init__(self, env):
        super().__init__(env)
        h, w = env.observation_space.shape[:2]
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(h, w), dtype=np.uint8)

    def observation(self, obs):
        gray = 0.299 * obs[:, :, 0] + 0.587 * obs[:, :, 1] + 0.114 * obs[:, :, 2]
        return gray.astype(np.uint8)


class ResizeWrapper(gym.ObservationWrapper):
    """Resizes observations to a fixed spatial size."""

    def __init__(self, env, size=(84, 84)):
        super().__init__(env)
        self.size = size
        channels = env.observation_space.shape[2]
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(size[1], size[0], channels), dtype=np.uint8
        )

    def observation(self, obs):
        return cv2.resize(obs, self.size, interpolation=cv2.INTER_AREA)


class FrameStackWrapper(gym.ObservationWrapper):
    """Stacks the last `num_stack` frames along the leading (channel) axis."""

    def __init__(self, env, num_stack=4):
        super().__init__(env)
        self.num_stack = num_stack
        self.frames = deque(maxlen=num_stack)
        obs_shape = env.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(num_stack, *obs_shape[-2:]), dtype=np.uint8
        )

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        for _ in range(self.num_stack):
            self.frames.append(obs)
        return self._get_observation(), info

    def observation(self, obs):
        self.frames.append(obs)
        return self._get_observation()

    def _get_observation(self):
        return np.stack(list(self.frames), axis=0)


def make_env(env_id="CarRacing-v3", gray_scale=True, resize=True, frame_stack=4):
    env = gym.make(env_id, render_mode="rgb_array")
    if resize:
        env = ResizeWrapper(env, size=(84, 84))
    if gray_scale:
        env = GrayScaleWrapper(env)
    if frame_stack > 1:
        env = FrameStackWrapper(env, num_stack=frame_stack)
    env = RewardShapingWrapper(env)
    return env


def worker_loop(worker_id, action_queue, result_queue):
    """Worker process: owns one environment and executes 'step'/'reset'/'close' commands."""
    env = make_env()
    obs, _ = env.reset()
    obs = np.transpose(obs, (1, 2, 0))  # (stack, H, W) -> (H, W, stack)

    while True:
        try:
            cmd, data = action_queue.get()

            if cmd == 'step':
                action = data
                next_obs, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                next_obs_t = np.transpose(next_obs, (1, 2, 0))

                if done:
                    next_obs_t, _ = env.reset()
                    next_obs_t = np.transpose(next_obs_t, (1, 2, 0))

                result_queue.put((worker_id, obs, action, reward, done, next_obs_t))
                obs = next_obs_t

            elif cmd == 'reset':
                obs, _ = env.reset()
                obs = np.transpose(obs, (1, 2, 0))
                result_queue.put((worker_id, obs))

            elif cmd == 'close':
                env.close()
                break

        except Exception:
            break


class ParallelEnv:
    """Runs `num_envs` environments in separate processes and steps them in lockstep."""

    def __init__(self, num_envs=4):
        self.num_envs = num_envs
        self.action_queues = []
        self.result_queue = SimpleQueue()
        self.processes = []

        for i in range(num_envs):
            action_queue = Queue(maxsize=1)
            self.action_queues.append(action_queue)
            p = Process(target=worker_loop, args=(i, action_queue, self.result_queue))
            p.start()
            self.processes.append(p)

        # Warm up: reset every worker once and drain the replies.
        for i in range(num_envs):
            self.action_queues[i].put(('reset', None))

        for _ in range(num_envs):
            self.result_queue.get()

    def reset(self):
        for i in range(self.num_envs):
            self.action_queues[i].put(('reset', None))

        # Order observations by worker id so indices line up with step() results.
        results = {}
        for _ in range(self.num_envs):
            worker_id, obs = self.result_queue.get()
            results[worker_id] = obs

        return np.array([results[i] for i in range(self.num_envs)])

    def step(self, actions):
        for i in range(self.num_envs):
            self.action_queues[i].put(('step', actions[i]))

        results = {}
        for _ in range(self.num_envs):
            item = self.result_queue.get()
            results[item[0]] = item[1:]

        obs_list = []
        reward_list = []
        done_list = []
        next_obs_list = []

        for i in range(self.num_envs):
            data = results[i]
            obs, action, reward, done = data[:4]
            next_obs = data[4] if len(data) > 4 else obs
            obs_list.append(obs)
            reward_list.append(reward)
            done_list.append(done)
            next_obs_list.append(next_obs)

        return np.array(next_obs_list), np.array(reward_list), np.array(done_list)

    def close(self):
        for i in range(self.num_envs):
            try:
                self.action_queues[i].put(('close', None))
            except Exception:
                pass

        time.sleep(0.5)
        for p in self.processes:
            if p.is_alive():
                p.terminate()
            p.join(timeout=1)


def get_device():
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        device = torch.device("cpu")
        print("Using CPU")
    return device


class Actor(nn.Module):
    def __init__(self, state_shape=(84, 84, 4), action_dim=3):
        super().__init__()
        c, h, w = state_shape[2], state_shape[0], state_shape[1]

        self.conv = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.LeakyReLU(0.2),
        )

        out_h = (h - 8) // 4 + 1
        out_h = (out_h - 4) // 2 + 1
        out_h = (out_h - 3) // 1 + 1
        feat_size = 64 * out_h * out_h

        self.fc = nn.Sequential(
            nn.Linear(feat_size, 512),
            nn.LeakyReLU(0.2),
        )
        self.mu_head = nn.Linear(512, action_dim)
        self.log_std_head = nn.Linear(512, action_dim)

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        nn.init.orthogonal_(self.mu_head.weight, gain=0.01)
        nn.init.orthogonal_(self.log_std_head.weight, gain=0.01)

    def forward(self, x):
        x = x / 255.0
        x = self.conv(x)
        x = x.flatten(1)
        x = self.fc(x)
        mu = torch.tanh(self.mu_head(x))
        log_std = torch.clamp(self.log_std_head(x), -20, 2)
        return mu, log_std.exp()


class Critic(nn.Module):
    def __init__(self, state_shape=(84, 84, 4)):
        super().__init__()
        c, h, w = state_shape[2], state_shape[0], state_shape[1]

        self.conv = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.LeakyReLU(0.2),
        )

        out_h = (h - 8) // 4 + 1
        out_h = (out_h - 4) // 2 + 1
        out_h = (out_h - 3) // 1 + 1
        feat_size = 64 * out_h * out_h

        self.fc = nn.Sequential(nn.Linear(feat_size, 512), nn.LeakyReLU(0.2), nn.Linear(512, 1))

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = x / 255.0
        x = self.conv(x)
        x = x.flatten(1)
        return self.fc(x)


class RolloutBuffer:
    def __init__(self, buffer_size, state_shape, action_dim):
        self.buffer_size = buffer_size
        self.ptr = 0
        self.size = 0

        self.states = np.zeros((buffer_size, *state_shape), dtype=np.uint8)
        self.actions = np.zeros((buffer_size, action_dim), dtype=np.float32)
        self.rewards = np.zeros(buffer_size, dtype=np.float32)
        self.dones = np.zeros(buffer_size, dtype=np.bool_)
        self.values = np.zeros(buffer_size, dtype=np.float32)
        self.log_probs = np.zeros(buffer_size, dtype=np.float32)

    def add(self, state, action, reward, done, value, log_prob):
        self.states[self.ptr] = state
        self.actions[self.ptr] = action
        self.rewards[self.ptr] = reward
        self.dones[self.ptr] = done
        self.values[self.ptr] = value
        self.log_probs[self.ptr] = log_prob
        self.ptr = (self.ptr + 1) % self.buffer_size
        self.size = min(self.size + 1, self.buffer_size)

    def compute_returns(self, last_value, gamma=0.99, gae_lambda=0.98):
        """Generalized Advantage Estimation over the stored trajectory."""
        advantages = np.zeros(self.size, dtype=np.float32)
        last_gae = 0

        for t in reversed(range(self.size)):
            if t == self.size - 1:
                next_value = last_value
            else:
                next_value = self.values[t + 1]

            delta = self.rewards[t] + gamma * next_value * (1 - self.dones[t]) - self.values[t]
            last_gae = delta + gamma * gae_lambda * (1 - self.dones[t]) * last_gae
            advantages[t] = last_gae

        returns = advantages + self.values[: self.size]
        return returns, advantages

    def get(self):
        return (
            self.states[: self.size],
            self.actions[: self.size],
            self.rewards[: self.size],
            self.dones[: self.size],
            self.values[: self.size],
            self.log_probs[: self.size],
        )

    def reset(self):
        self.ptr = 0
        self.size = 0


class PPOTrainer:
    def __init__(
        self,
        actor,
        critic,
        rollout_buffer,
        device,
        clip_eps=0.1,
        gamma=0.99,
        gae_lambda=0.98,
        lr=3e-4,
        ent_coef=0.005,
        vf_coef=0.75,
        max_grad_norm=0.5,
        ppo_epochs=10,
        mini_batch_size=128,
    ):
        self.actor = actor
        self.critic = critic
        self.buffer = rollout_buffer
        self.device = device
        self.clip_eps = clip_eps
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm
        self.ppo_epochs = ppo_epochs
        self.mini_batch_size = mini_batch_size

        self.actor_optim = torch.optim.Adam(actor.parameters(), lr=lr, eps=1e-5)
        self.critic_optim = torch.optim.Adam(critic.parameters(), lr=lr, eps=1e-5)

    def update(self, last_value):
        states, actions, rewards, dones, values, log_probs_old = self.buffer.get()

        returns, advantages = self.buffer.compute_returns(last_value, self.gamma, self.gae_lambda)
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        states_t = torch.from_numpy(states).float().permute(0, 3, 1, 2).to(self.device)
        actions_t = torch.from_numpy(actions).float().to(self.device)
        log_probs_old_t = torch.from_numpy(log_probs_old).float().to(self.device)
        returns_t = torch.from_numpy(returns).float().to(self.device)
        advantages_t = torch.from_numpy(advantages).float().to(self.device)

        dataset = torch.utils.data.TensorDataset(
            states_t, actions_t, log_probs_old_t, returns_t, advantages_t
        )
        loader = torch.utils.data.DataLoader(dataset, batch_size=self.mini_batch_size, shuffle=True)

        total_actor_loss = 0
        total_critic_loss = 0
        total_entropy = 0
        count = 0

        for _ in range(self.ppo_epochs):
            for batch in loader:
                s, a, log_pi_old, ret, adv = batch

                mu, std = self.actor(s)
                dist = torch.distributions.Normal(mu, std)
                log_pi = dist.log_prob(a).sum(dim=-1)
                entropy = dist.entropy().sum(dim=-1)

                # Clipped surrogate objective.
                ratio = torch.exp(log_pi - log_pi_old)
                surr1 = ratio * adv
                surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * adv
                actor_loss = -torch.min(surr1, surr2).mean()

                value = self.critic(s)
                critic_loss = nn.MSELoss()(value.squeeze(), ret)

                loss = actor_loss + self.vf_coef * critic_loss - self.ent_coef * entropy.mean()

                self.actor_optim.zero_grad()
                self.critic_optim.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
                nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
                self.actor_optim.step()
                self.critic_optim.step()

                total_actor_loss += actor_loss.item()
                total_critic_loss += critic_loss.item()
                total_entropy += entropy.mean().item()
                count += 1

        avg_actor = total_actor_loss / count
        avg_critic = total_critic_loss / count
        avg_entropy = total_entropy / count

        self.buffer.reset()
        return avg_actor, avg_critic, avg_entropy


def collect_parallel_rollout(actor, critic, parallel_env, buffer, device, rollout_steps, num_envs):
    """Collects roughly `rollout_steps` transitions, split evenly across the parallel envs."""
    obs_batch = parallel_env.reset()
    episode_rewards = []
    current_ep_rewards = [0.0] * num_envs

    steps_per_env = rollout_steps // num_envs

    for step in range(steps_per_env):
        obs_t = torch.from_numpy(obs_batch).float().permute(0, 3, 1, 2).to(device)

        with torch.no_grad():
            mu, std = actor(obs_t)
            dist = torch.distributions.Normal(mu, std)
            action = dist.sample()
            action = torch.clamp(action, -1, 1)
            log_prob = dist.log_prob(action).sum(dim=-1)
            value = critic(obs_t).squeeze(-1)

        action_np = action.cpu().numpy()
        log_prob_np = log_prob.cpu().numpy()
        value_np = value.cpu().numpy()

        next_obs_batch, reward_batch, done_batch = parallel_env.step(action_np)

        for i in range(num_envs):
            buffer.add(obs_batch[i], action_np[i], reward_batch[i],
                       done_batch[i], value_np[i], log_prob_np[i])
            current_ep_rewards[i] += reward_batch[i]

            if done_batch[i]:
                episode_rewards.append(current_ep_rewards[i])
                current_ep_rewards[i] = 0.0

        obs_batch = next_obs_batch

    return obs_batch, (episode_rewards if episode_rewards else [sum(current_ep_rewards) / num_envs])


def train(
    total_steps=2000000,
    num_envs=4,
    rollout_steps=4096,
    eval_interval=10,
    save_interval=50,
    device=None,
):
    if device is None:
        device = get_device()

    print(f"Creating {num_envs} parallel environments with reward shaping...")
    parallel_env = ParallelEnv(num_envs=num_envs)

    state_shape = (84, 84, 4)
    action_dim = 3

    actor = Actor(state_shape=state_shape, action_dim=action_dim).to(device)
    critic = Critic(state_shape=state_shape).to(device)

    buffer = RolloutBuffer(rollout_steps, state_shape, action_dim)

    trainer = PPOTrainer(
        actor=actor,
        critic=critic,
        rollout_buffer=buffer,
        device=device,
        clip_eps=0.1,
        gamma=0.99,
        gae_lambda=0.98,
        lr=3e-4,
        ent_coef=0.005,
        vf_coef=0.75,
        max_grad_norm=0.5,
        ppo_epochs=10,
        mini_batch_size=256,
    )

    log_dir = os.path.join("logs", "tensorboard", f"run_parallel_improved_{int(time.time())}")
    writer = SummaryWriter(log_dir)

    print(f"Training on {device} with {num_envs} parallel envs")
    print(f"Log directory: {log_dir}")
    print("Improvements: Parallel + Reward Shaping + Larger Rollout")

    episode = 0
    total_timesteps = 0
    episode_rewards = []
    best_eval = -float("inf")

    while total_timesteps < total_steps:
        obs_batch, batch_rewards = collect_parallel_rollout(
            actor, critic, parallel_env, buffer, device, rollout_steps, num_envs
        )

        with torch.no_grad():
            obs_t = torch.from_numpy(obs_batch).float().permute(0, 3, 1, 2).to(device)
            last_value = critic(obs_t).mean().item()

        actor_loss, critic_loss, entropy = trainer.update(last_value)

        writer.add_scalar("Loss/Actor", actor_loss, total_timesteps)
        writer.add_scalar("Loss/Critic", critic_loss, total_timesteps)
        writer.add_scalar("Loss/Entropy", entropy, total_timesteps)

        total_timesteps += rollout_steps
        episode += 1

        episode_rewards.extend(batch_rewards)
        recent_rewards = episode_rewards[-50:] if len(episode_rewards) > 50 else episode_rewards
        avg_reward = np.mean(recent_rewards)
        mean_batch_reward = np.mean(batch_rewards)

        writer.add_scalar("Reward/EpisodeMean", mean_batch_reward, total_timesteps)
        writer.add_scalar("Reward/AvgLast50", avg_reward, total_timesteps)

        print(f"Episode {episode}, steps {total_timesteps}, mean_reward={mean_batch_reward:.1f}, avg_50={avg_reward:.1f}")

        if episode % eval_interval == 0:
            eval_returns = []
            for _ in range(5):
                eval_env = make_env()
                eval_obs, _ = eval_env.reset()
                eval_obs = np.transpose(eval_obs, (1, 2, 0))
                eval_reward = 0
                done = False

                while not done:
                    with torch.no_grad():
                        eval_obs_t = (
                            torch.from_numpy(eval_obs)
                            .float()
                            .unsqueeze(0)
                            .permute(0, 3, 1, 2)
                            .to(device)
                        )
                        mu, std = actor(eval_obs_t)
                        action = torch.clamp(mu, -1, 1).squeeze(0).cpu().numpy()
                    eval_obs, reward, terminated, truncated, _ = eval_env.step(action)
                    eval_obs = np.transpose(eval_obs, (1, 2, 0))
                    eval_reward += reward
                    done = terminated or truncated

                eval_returns.append(eval_reward)
                eval_env.close()

            mean_eval = np.mean(eval_returns)
            writer.add_scalar("Eval/MeanReturn", mean_eval, episode)
            print(f" Eval: mean_return={mean_eval:.2f}")

            if mean_eval > best_eval:
                best_eval = mean_eval
                os.makedirs("models", exist_ok=True)
                torch.save(
                    {
                        "actor": actor.state_dict(),
                        "critic": critic.state_dict(),
                        "episode": episode,
                        "timesteps": total_timesteps,
                        "best_eval": best_eval,
                    },
                    os.path.join("models", "ppo_parallel_improved_best.pt"),
                )
                print(f" New best model saved! eval={best_eval:.2f}")

        if episode % save_interval == 0:
            os.makedirs("models", exist_ok=True)
            torch.save(
                {
                    "actor": actor.state_dict(),
                    "critic": critic.state_dict(),
                    "episode": episode,
                    "timesteps": total_timesteps,
                },
                os.path.join("models", f"ppo_parallel_improved_ep{episode}.pt"),
            )
            print(f" Saved model at episode {episode}")

    os.makedirs("models", exist_ok=True)
    torch.save(
        {
            "actor": actor.state_dict(),
            "critic": critic.state_dict(),
            "episode": episode,
            "timesteps": total_timesteps,
            "best_eval": best_eval,
        },
        os.path.join("models", "ppo_parallel_improved_final.pt"),
    )

    writer.close()
    parallel_env.close()
    print(f"Training complete! Total episodes: {episode}, Best eval: {best_eval:.2f}")


if __name__ == "__main__":
    try:
        set_start_method('fork')
    except RuntimeError:
        pass

    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=2000000, help="Total training steps")
    parser.add_argument("--num_envs", type=int, default=4, help="Number of parallel environments")
    parser.add_argument("--rollout", type=int, default=4096, help="Rollout buffer size")
    args = parser.parse_args()

    device = get_device()
    train(
        total_steps=args.steps,
        num_envs=args.num_envs,
        rollout_steps=args.rollout,
        device=device,
    )