Files
rl-atari/强化学习个人项目报告/train_parallel_improved.py
Serendipity 7dea00195e feat: add parallel training script and reward shaping to improve PPO performance
Introduce the parallel-environment training script train_parallel_improved.py, implementing multi-process parallel data collection
Add a reward-shaping wrapper that adjusts the reward signal based on speed, track position, and completed laps
Optimize the network architecture and training parameters, including a larger rollout buffer
Delete the old TensorBoard log files and create a new training run record
2026-05-01 09:26:39 +08:00
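
A minimal sketch of driving the script's train() entry point directly instead of the CLI; the num_envs and rollout values here are illustrative examples, not the defaults:

    from train_parallel_improved import train

    train(
        total_steps=2_000_000,  # maps to --steps
        num_envs=8,             # maps to --num_envs (example value)
        rollout_steps=8192,     # maps to --rollout (example value)
    )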

663 lines
21 KiB
Python

#!/usr/bin/env python3
"""Parallel training with reward shaping for CarRacing-v3 PPO."""
import os
import sys
import time
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from collections import deque
from multiprocessing import Process, Queue, SimpleQueue, set_start_method
import gymnasium as gym
import cv2


class RewardShapingWrapper(gym.Wrapper):
    """Shape the raw CarRacing reward using speed, track position and lap completion.

    The bonuses rely on ``speed``, ``offtrack`` and ``lap_complete`` entries in the
    ``info`` dict; if the underlying env does not provide them, each term falls back
    to its default and leaves the reward unchanged (or adds the on-track bonus).
    """

    def __init__(self, env):
        super().__init__(env)
        self.steps_on_track = 0

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self.steps_on_track = 0
        return obs, info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        shaped_reward = reward
        # Bonus for driving fast.
        if info.get("speed", 0) > 0.1:
            shaped_reward += info["speed"] * 0.1
        # Small bonus for staying on the track, penalty for leaving it.
        if not info.get("offtrack", False):
            shaped_reward += 0.1
            self.steps_on_track += 1
        else:
            shaped_reward -= 0.5
            self.steps_on_track = 0
        # Large bonus for completing a lap.
        if info.get("lap_complete", False):
            shaped_reward += 100
        return obs, shaped_reward, terminated, truncated, info


class GrayScaleWrapper(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        h, w = env.observation_space.shape[:2]
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(h, w), dtype=np.uint8)

    def observation(self, obs):
        # Standard luminance-weighted RGB-to-grayscale conversion.
        gray = 0.299 * obs[:, :, 0] + 0.587 * obs[:, :, 1] + 0.114 * obs[:, :, 2]
        return gray.astype(np.uint8)


class ResizeWrapper(gym.ObservationWrapper):
    def __init__(self, env, size=(84, 84)):
        super().__init__(env)
        self.size = size
        old_shape = env.observation_space.shape
        new_shape = (*size, *old_shape[2:]) if len(old_shape) == 3 else size
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=new_shape, dtype=np.uint8)

    def observation(self, obs):
        return cv2.resize(obs, self.size, interpolation=cv2.INTER_AREA)


class FrameStackWrapper(gym.ObservationWrapper):
    def __init__(self, env, num_stack=4):
        super().__init__(env)
        self.num_stack = num_stack
        self.frames = deque(maxlen=num_stack)
        obs_shape = env.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(num_stack, *obs_shape[-2:]), dtype=np.uint8
        )

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        for _ in range(self.num_stack):
            self.frames.append(obs)
        return self._get_observation(), info

    def observation(self, obs):
        self.frames.append(obs)
        return self._get_observation()

    def _get_observation(self):
        return np.stack(list(self.frames), axis=0)


def make_env(env_id="CarRacing-v3", gray_scale=True, resize=True, frame_stack=4):
    env = gym.make(env_id, render_mode="rgb_array")
    if resize:
        env = ResizeWrapper(env, size=(84, 84))
    if gray_scale:
        env = GrayScaleWrapper(env)
    if frame_stack > 1:
        env = FrameStackWrapper(env, num_stack=frame_stack)
    env = RewardShapingWrapper(env)
    return env
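
# Sketch of the resulting observation pipeline (assuming the default wrappers above):
#   env = make_env()
#   obs, _ = env.reset()
#   obs.shape == (4, 84, 84), dtype uint8  # frame_stack grayscale 84x84 frames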


def worker_loop(worker_id, action_queue, result_queue):
    """Environment worker process: steps its own env on request and returns transitions."""
    env = make_env()
    obs, _ = env.reset()
    # Convert (C, H, W) frame stacks to (H, W, C) before sending to the learner.
    obs = np.transpose(obs, (1, 2, 0))
    while True:
        try:
            cmd, data = action_queue.get()
            if cmd == 'step':
                action = data
                next_obs, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                next_obs_t = np.transpose(next_obs, (1, 2, 0))
                if done:
                    # Auto-reset so the learner always receives a valid next observation.
                    next_obs_t, _ = env.reset()
                    next_obs_t = np.transpose(next_obs_t, (1, 2, 0))
                result_queue.put((worker_id, obs, action, reward, done, next_obs_t))
                obs = next_obs_t
            elif cmd == 'reset':
                obs, _ = env.reset()
                obs = np.transpose(obs, (1, 2, 0))
                result_queue.put((worker_id, obs))
            elif cmd == 'close':
                env.close()
                break
        except Exception:
            break


class ParallelEnv:
    """Minimal multi-process vectorized environment driven by per-worker queues."""

    def __init__(self, num_envs=4):
        self.num_envs = num_envs
        self.action_queues = []
        self.result_queue = SimpleQueue()
        self.processes = []
        for i in range(num_envs):
            action_queue = Queue(maxsize=1)
            self.action_queues.append(action_queue)
            p = Process(target=worker_loop, args=(i, action_queue, self.result_queue))
            p.start()
            self.processes.append(p)
        # Warm up all workers with an initial reset.
        for i in range(num_envs):
            self.action_queues[i].put(('reset', None))
        for _ in range(num_envs):
            self.result_queue.get()

    def reset(self):
        for i in range(self.num_envs):
            self.action_queues[i].put(('reset', None))
        results = {}
        for _ in range(self.num_envs):
            worker_id, obs = self.result_queue.get()
            results[worker_id] = obs
        # Order observations by worker index so batch position i matches env i.
        return np.array([results[i] for i in range(self.num_envs)])

    def step(self, actions):
        for i in range(self.num_envs):
            self.action_queues[i].put(('step', actions[i]))
        results = {}
        for _ in range(self.num_envs):
            item = self.result_queue.get()
            results[item[0]] = item[1:]
        reward_list = []
        done_list = []
        next_obs_list = []
        for i in range(self.num_envs):
            obs, action, reward, done, next_obs = results[i]
            reward_list.append(reward)
            done_list.append(done)
            next_obs_list.append(next_obs)
        return np.array(next_obs_list), np.array(reward_list), np.array(done_list)

    def close(self):
        for i in range(self.num_envs):
            try:
                self.action_queues[i].put(('close', None))
            except Exception:
                pass
        time.sleep(0.5)
        for p in self.processes:
            if p.is_alive():
                p.terminate()
            p.join(timeout=1)


def get_device():
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        device = torch.device("cpu")
        print("Using CPU")
    return device


class Actor(nn.Module):
    def __init__(self, state_shape=(84, 84, 4), action_dim=3):
        super().__init__()
        c, h, w = state_shape[2], state_shape[0], state_shape[1]
        self.conv = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.LeakyReLU(0.2),
        )
        # Spatial size after the three conv layers (square input assumed).
        out_h = (h - 8) // 4 + 1
        out_h = (out_h - 4) // 2 + 1
        out_h = (out_h - 3) // 1 + 1
        feat_size = 64 * out_h * out_h
        self.fc = nn.Sequential(
            nn.Linear(feat_size, 512),
            nn.LeakyReLU(0.2),
        )
        self.mu_head = nn.Linear(512, action_dim)
        self.log_std_head = nn.Linear(512, action_dim)
        # Orthogonal initialization, with a small gain on the policy heads.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        nn.init.orthogonal_(self.mu_head.weight, gain=0.01)
        nn.init.orthogonal_(self.log_std_head.weight, gain=0.01)

    def forward(self, x):
        x = x / 255.0
        x = self.conv(x)
        x = x.flatten(1)
        x = self.fc(x)
        mu = torch.tanh(self.mu_head(x))
        log_std = torch.clamp(self.log_std_head(x), -20, 2)
        return mu, log_std.exp()


class Critic(nn.Module):
    def __init__(self, state_shape=(84, 84, 4)):
        super().__init__()
        c, h, w = state_shape[2], state_shape[0], state_shape[1]
        self.conv = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.LeakyReLU(0.2),
        )
        out_h = (h - 8) // 4 + 1
        out_h = (out_h - 4) // 2 + 1
        out_h = (out_h - 3) // 1 + 1
        feat_size = 64 * out_h * out_h
        self.fc = nn.Sequential(nn.Linear(feat_size, 512), nn.LeakyReLU(0.2), nn.Linear(512, 1))
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = x / 255.0
        x = self.conv(x)
        x = x.flatten(1)
        return self.fc(x)


class RolloutBuffer:
    def __init__(self, buffer_size, state_shape, action_dim):
        self.buffer_size = buffer_size
        self.ptr = 0
        self.size = 0
        self.states = np.zeros((buffer_size, *state_shape), dtype=np.uint8)
        self.actions = np.zeros((buffer_size, action_dim), dtype=np.float32)
        self.rewards = np.zeros(buffer_size, dtype=np.float32)
        self.dones = np.zeros(buffer_size, dtype=np.bool_)
        self.values = np.zeros(buffer_size, dtype=np.float32)
        self.log_probs = np.zeros(buffer_size, dtype=np.float32)

    def add(self, state, action, reward, done, value, log_prob):
        self.states[self.ptr] = state
        self.actions[self.ptr] = action
        self.rewards[self.ptr] = reward
        self.dones[self.ptr] = done
        self.values[self.ptr] = value
        self.log_probs[self.ptr] = log_prob
        self.ptr = (self.ptr + 1) % self.buffer_size
        self.size = min(self.size + 1, self.buffer_size)

    def compute_returns(self, last_value, gamma=0.99, gae_lambda=0.98):
        # Generalized Advantage Estimation over the stored trajectory.
        advantages = np.zeros(self.size, dtype=np.float32)
        last_gae = 0
        for t in reversed(range(self.size)):
            if t == self.size - 1:
                next_value = last_value
            else:
                next_value = self.values[t + 1]
            delta = self.rewards[t] + gamma * next_value * (1 - self.dones[t]) - self.values[t]
            last_gae = delta + gamma * gae_lambda * (1 - self.dones[t]) * last_gae
            advantages[t] = last_gae
        returns = advantages + self.values[: self.size]
        return returns, advantages

    def get(self):
        return (
            self.states[: self.size],
            self.actions[: self.size],
            self.rewards[: self.size],
            self.dones[: self.size],
            self.values[: self.size],
            self.log_probs[: self.size],
        )

    def reset(self):
        self.ptr = 0
        self.size = 0


class PPOTrainer:
    def __init__(
        self,
        actor,
        critic,
        rollout_buffer,
        device,
        clip_eps=0.1,
        gamma=0.99,
        gae_lambda=0.98,
        lr=3e-4,
        ent_coef=0.005,
        vf_coef=0.75,
        max_grad_norm=0.5,
        ppo_epochs=10,
        mini_batch_size=128,
    ):
        self.actor = actor
        self.critic = critic
        self.buffer = rollout_buffer
        self.device = device
        self.clip_eps = clip_eps
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm
        self.ppo_epochs = ppo_epochs
        self.mini_batch_size = mini_batch_size
        self.actor_optim = torch.optim.Adam(actor.parameters(), lr=lr, eps=1e-5)
        self.critic_optim = torch.optim.Adam(critic.parameters(), lr=lr, eps=1e-5)

    def update(self, last_value):
        states, actions, rewards, dones, values, log_probs_old = self.buffer.get()
        returns, advantages = self.buffer.compute_returns(last_value, self.gamma, self.gae_lambda)
        # Normalize advantages for more stable policy-gradient updates.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        states_t = torch.from_numpy(states).float().permute(0, 3, 1, 2).to(self.device)
        actions_t = torch.from_numpy(actions).float().to(self.device)
        log_probs_old_t = torch.from_numpy(log_probs_old).float().to(self.device)
        returns_t = torch.from_numpy(returns).float().to(self.device)
        advantages_t = torch.from_numpy(advantages).float().to(self.device)
        dataset = torch.utils.data.TensorDataset(
            states_t, actions_t, log_probs_old_t, returns_t, advantages_t
        )
        loader = torch.utils.data.DataLoader(dataset, batch_size=self.mini_batch_size, shuffle=True)
        total_actor_loss = 0
        total_critic_loss = 0
        total_entropy = 0
        count = 0
        for _ in range(self.ppo_epochs):
            for batch in loader:
                s, a, log_pi_old, ret, adv = batch
                mu, std = self.actor(s)
                dist = torch.distributions.Normal(mu, std)
                log_pi = dist.log_prob(a).sum(dim=-1)
                entropy = dist.entropy().sum(dim=-1)
                # Clipped PPO surrogate objective.
                ratio = torch.exp(log_pi - log_pi_old)
                surr1 = ratio * adv
                surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * adv
                actor_loss = -torch.min(surr1, surr2).mean()
                value = self.critic(s)
                critic_loss = nn.MSELoss()(value.squeeze(), ret)
                loss = actor_loss + self.vf_coef * critic_loss - self.ent_coef * entropy.mean()
                self.actor_optim.zero_grad()
                self.critic_optim.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
                nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
                self.actor_optim.step()
                self.critic_optim.step()
                total_actor_loss += actor_loss.item()
                total_critic_loss += critic_loss.item()
                total_entropy += entropy.mean().item()
                count += 1
        avg_actor = total_actor_loss / count
        avg_critic = total_critic_loss / count
        avg_entropy = total_entropy / count
        self.buffer.reset()
        return avg_actor, avg_critic, avg_entropy
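
# For reference, PPOTrainer.update() above optimizes the standard clipped surrogate
#   L_clip = E_t[ min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t) ],
#   where r_t = exp(log pi(a_t|s_t) - log pi_old(a_t|s_t)),
# plus a value-function MSE term (vf_coef) and minus an entropy bonus (ent_coef).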


def collect_parallel_rollout(actor, critic, parallel_env, buffer, device, rollout_steps, num_envs):
    obs_batch = parallel_env.reset()
    episode_rewards = []
    current_ep_rewards = [0.0] * num_envs
    # Split the rollout budget evenly across the parallel workers.
    steps_per_env = rollout_steps // num_envs
    for step in range(steps_per_env):
        obs_t = torch.from_numpy(obs_batch).float().permute(0, 3, 1, 2).to(device)
        with torch.no_grad():
            mu, std = actor(obs_t)
            dist = torch.distributions.Normal(mu, std)
            action = dist.sample()
            action = torch.clamp(action, -1, 1)
            log_prob = dist.log_prob(action).sum(dim=-1)
            value = critic(obs_t).squeeze(-1)
        action_np = action.cpu().numpy()
        log_prob_np = log_prob.cpu().numpy()
        value_np = value.cpu().numpy()
        next_obs_batch, reward_batch, done_batch = parallel_env.step(action_np)
        for i in range(num_envs):
            buffer.add(obs_batch[i], action_np[i], reward_batch[i],
                       done_batch[i], value_np[i], log_prob_np[i])
            current_ep_rewards[i] += reward_batch[i]
            if done_batch[i]:
                episode_rewards.append(current_ep_rewards[i])
                current_ep_rewards[i] = 0.0
        obs_batch = next_obs_batch
    return obs_batch, episode_rewards if episode_rewards else [sum(current_ep_rewards) / num_envs]


def train(
    total_steps=2000000,
    num_envs=4,
    rollout_steps=4096,
    eval_interval=10,
    save_interval=50,
    device=None,
):
    if device is None:
        device = get_device()
    print(f"Creating {num_envs} parallel environments with reward shaping...")
    parallel_env = ParallelEnv(num_envs=num_envs)
    state_shape = (84, 84, 4)
    action_dim = 3
    actor = Actor(state_shape=state_shape, action_dim=action_dim).to(device)
    critic = Critic(state_shape=state_shape).to(device)
    buffer = RolloutBuffer(rollout_steps, state_shape, action_dim)
    trainer = PPOTrainer(
        actor=actor,
        critic=critic,
        rollout_buffer=buffer,
        device=device,
        clip_eps=0.1,
        gamma=0.99,
        gae_lambda=0.98,
        lr=3e-4,
        ent_coef=0.005,
        vf_coef=0.75,
        max_grad_norm=0.5,
        ppo_epochs=10,
        mini_batch_size=256,
    )
    log_dir = os.path.join("logs", "tensorboard", f"run_parallel_improved_{int(time.time())}")
    writer = SummaryWriter(log_dir)
    print(f"Training on {device} with {num_envs} parallel envs")
    print(f"Log directory: {log_dir}")
    print("Improvements: Parallel + Reward Shaping + Larger Rollout")
    episode = 0
    total_timesteps = 0
    episode_rewards = []
    best_eval = -float("inf")
    while total_timesteps < total_steps:
        obs_batch, batch_rewards = collect_parallel_rollout(
            actor, critic, parallel_env, buffer, device, rollout_steps, num_envs
        )
        # Bootstrap the value of the final observations for GAE.
        with torch.no_grad():
            obs_t = torch.from_numpy(obs_batch).float().permute(0, 3, 1, 2).to(device)
            last_value = critic(obs_t).mean().item()
        actor_loss, critic_loss, entropy = trainer.update(last_value)
        writer.add_scalar("Loss/Actor", actor_loss, total_timesteps)
        writer.add_scalar("Loss/Critic", critic_loss, total_timesteps)
        writer.add_scalar("Loss/Entropy", entropy, total_timesteps)
        total_timesteps += rollout_steps
        episode += 1
        episode_rewards.extend(batch_rewards)
        recent_rewards = episode_rewards[-50:] if len(episode_rewards) > 50 else episode_rewards
        avg_reward = np.mean(recent_rewards)
        mean_batch_reward = np.mean(batch_rewards)
        writer.add_scalar("Reward/EpisodeMean", mean_batch_reward, total_timesteps)
        writer.add_scalar("Reward/AvgLast50", avg_reward, total_timesteps)
        print(f"Episode {episode}, steps {total_timesteps}, mean_reward={mean_batch_reward:.1f}, avg_50={avg_reward:.1f}")
        if episode % eval_interval == 0:
            # Deterministic evaluation: act with the policy mean in a fresh single env.
            eval_returns = []
            for _ in range(5):
                eval_env = make_env()
                eval_obs, _ = eval_env.reset()
                eval_obs = np.transpose(eval_obs, (1, 2, 0))
                eval_reward = 0
                done = False
                while not done:
                    with torch.no_grad():
                        eval_obs_t = (
                            torch.from_numpy(eval_obs)
                            .float()
                            .unsqueeze(0)
                            .permute(0, 3, 1, 2)
                            .to(device)
                        )
                        mu, std = actor(eval_obs_t)
                        action = torch.clamp(mu, -1, 1).squeeze(0).cpu().numpy()
                    eval_obs, reward, terminated, truncated, _ = eval_env.step(action)
                    eval_obs = np.transpose(eval_obs, (1, 2, 0))
                    eval_reward += reward
                    done = terminated or truncated
                eval_returns.append(eval_reward)
                eval_env.close()
            mean_eval = np.mean(eval_returns)
            writer.add_scalar("Eval/MeanReturn", mean_eval, episode)
            print(f" Eval: mean_return={mean_eval:.2f}")
            if mean_eval > best_eval:
                best_eval = mean_eval
                os.makedirs("models", exist_ok=True)
                torch.save(
                    {
                        "actor": actor.state_dict(),
                        "critic": critic.state_dict(),
                        "episode": episode,
                        "timesteps": total_timesteps,
                        "best_eval": best_eval,
                    },
                    os.path.join("models", "ppo_parallel_improved_best.pt"),
                )
                print(f" New best model saved! eval={best_eval:.2f}")
        if episode % save_interval == 0:
            os.makedirs("models", exist_ok=True)
            torch.save(
                {
                    "actor": actor.state_dict(),
                    "critic": critic.state_dict(),
                    "episode": episode,
                    "timesteps": total_timesteps,
                },
                os.path.join("models", f"ppo_parallel_improved_ep{episode}.pt"),
            )
            print(f" Saved model at episode {episode}")
    os.makedirs("models", exist_ok=True)
    torch.save(
        {
            "actor": actor.state_dict(),
            "critic": critic.state_dict(),
            "episode": episode,
            "timesteps": total_timesteps,
            "best_eval": best_eval,
        },
        os.path.join("models", "ppo_parallel_improved_final.pt"),
    )
    writer.close()
    parallel_env.close()
    print(f"Training complete! Total episodes: {episode}, Best eval: {best_eval:.2f}")


if __name__ == "__main__":
    try:
        # Use 'fork' so workers inherit the parent module state; ignore the error if
        # the start method is already set or 'fork' is unsupported on this platform.
        set_start_method('fork')
    except (RuntimeError, ValueError):
        pass
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=2000000, help="Total training steps")
    parser.add_argument("--num_envs", type=int, default=4, help="Number of parallel environments")
    parser.add_argument("--rollout", type=int, default=4096, help="Rollout buffer size")
    args = parser.parse_args()
    device = get_device()
    train(
        total_steps=args.steps,
        num_envs=args.num_envs,
        rollout_steps=args.rollout,
        device=device,
    )
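
# Example invocation (matching the argparse defaults above):
#   python train_parallel_improved.py --steps 2000000 --num_envs 4 --rollout 4096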