b32490ae03
Fix the log_probs dimension error in the replay buffer, changing the shape from (buffer_size, action_dim) to (buffer_size,); fix the training-time state tensor format by converting (N, H, W, C) to (N, C, H, W); update collect_rollout to return observations and correct the log_prob computation; add a project configuration file and a training-curve generation script.
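A minimal sketch of the two shape fixes described in the commit message; the batch size, the 96x96x3 observation shape, and the variable names below are illustrative assumptions, not the project's actual code.

import torch

# Illustrative batch: 32 channels-last image observations and 3-dimensional actions.
obs_batch = torch.randn(32, 96, 96, 3)   # (N, H, W, C), as collected from the environment
actions = torch.randn(32, 3)             # (N, action_dim)

# The joint log-probability of a multi-dimensional continuous action is one scalar
# per sample, so the stored log_probs have shape (buffer_size,), not (buffer_size, action_dim).
mu, std = torch.zeros(32, 3), torch.ones(32, 3)
dist = torch.distributions.Normal(mu, std)
log_probs = dist.log_prob(actions).sum(dim=-1)   # shape: (32,)

# PyTorch convolutions expect channels-first input, so training states are
# permuted from (N, H, W, C) to (N, C, H, W) before the forward pass.
obs_batch = obs_batch.permute(0, 3, 1, 2)        # shape: (32, 3, 96, 96)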
96 lines
3.1 KiB
Python
"""Evaluation script for trained PPO agent."""
|
|
import sys
|
|
import os
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
import torch
|
|
import numpy as np
|
|
import gymnasium as gym
|
|
from src.utils import make_env, get_device
|
|
from src.network import Actor, Critic
|
|
|
|
|
|
def evaluate(actor, env, num_episodes=10, device=torch.device("cpu")):
    """Evaluate the actor and return the mean and std of episode returns."""
    actor.eval()
    returns = []

    for ep in range(num_episodes):
        obs, _ = env.reset()
        obs = np.transpose(obs, (1, 2, 0))  # (C, H, W) -> (H, W, C); converted back to channels-first before the forward pass
        total_reward = 0
        done = False
        steps = 0

        while not done and steps < 1000:
            with torch.no_grad():
                # Convert to a (B, C, H, W) float tensor for the network
                obs_t = torch.from_numpy(obs).float().unsqueeze(0).permute(0, 3, 1, 2).to(device)
                mu, std = actor(obs_t)
                # Sample an action from the policy distribution and clip to the valid range
                dist = torch.distributions.Normal(mu, std)
                action = dist.sample()
                action = torch.clamp(action, -1, 1).squeeze(0).cpu().numpy()

            obs, reward, terminated, truncated, _ = env.step(action)
            # Keep the observation channels-last: (C, H, W) -> (H, W, C)
            obs = np.transpose(obs, (1, 2, 0))
            total_reward += reward
            done = terminated or truncated
            steps += 1

        returns.append(total_reward)
        print(f"Episode {ep+1}/{num_episodes}: return={total_reward:.1f}, steps={steps}")

    actor.train()
    return np.mean(returns), np.std(returns)

def evaluate_render(actor, env, device):
    """Render and evaluate the agent with visualization."""
    actor.eval()
    obs, _ = env.reset()
    obs = np.transpose(obs, (1, 2, 0))

    # Note: gymnasium envs normally fix render_mode at construction time; this
    # assumes the env returned by make_env honours changing it afterwards.
    env.render_mode = "human"
    done = False
    total_reward = 0

    while not done:
        with torch.no_grad():
            obs_t = torch.from_numpy(obs).float().unsqueeze(0).permute(0, 3, 1, 2).to(device)
            mu, std = actor(obs_t)
            dist = torch.distributions.Normal(mu, std)
            action = dist.sample()
            action = torch.clamp(action, -1, 1).squeeze(0).cpu().numpy()

        obs, reward, terminated, truncated, _ = env.step(action)
        obs = np.transpose(obs, (1, 2, 0))
        total_reward += reward
        done = terminated or truncated
        env.render()

    actor.train()
    print(f"Final return: {total_reward:.1f}")

if __name__ == "__main__":
|
|
import argparse
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--model", type=str, required=True, help="Path to trained model")
|
|
parser.add_argument("--episodes", type=int, default=5, help="Number of evaluation episodes")
|
|
args = parser.parse_args()
|
|
|
|
device = get_device()
|
|
env = make_env()
|
|
|
|
actor = Actor().to(device)
|
|
critic = Critic().to(device)
|
|
|
|
# Load model
|
|
checkpoint = torch.load(args.model, map_location=device, weights_only=False)
|
|
actor.load_state_dict(checkpoint["actor"])
|
|
print(f"Loaded model from {args.model}")
|
|
|
|
mean_return, std_return = evaluate(actor, env, num_episodes=args.episodes, device=device)
|
|
print(f"\nEvaluation: mean={mean_return:.2f}, std={std_return:.2f}") |