rl-atari/强化学习个人项目报告/src/evaluate.py
Serendipity b32490ae03 fix(ppo): correct log-prob dimension and state tensor format
Fix the log_probs dimension bug in the replay buffer: (buffer_size, action_dim) -> (buffer_size,)
Fix the state tensor format during training: convert (N, H, W, C) to (N, C, H, W)
Update collect_rollout to return observations and fix the log_prob computation
Add project config files and a training-curve plotting script
2026-04-30 20:30:40 +08:00
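
In short, the buffer now stores a single joint log-probability per transition instead of one value per action dimension, and stored channels-last observations are permuted to channels-first before entering the network. A minimal sketch of the two conversions (shapes and names are illustrative, not this project's API):

import torch

N, action_dim = 8, 2
mu, std = torch.zeros(N, action_dim), torch.ones(N, action_dim)
dist = torch.distributions.Normal(mu, std)
action = dist.sample()                        # (N, action_dim)
log_prob = dist.log_prob(action).sum(dim=-1)  # joint log-prob: (N,), not (N, action_dim)

states_nhwc = torch.zeros(N, 84, 84, 4)       # channels-last, as stored in the buffer
states = states_nhwc.permute(0, 3, 1, 2)      # (N, H, W, C) -> (N, C, H, W)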

96 lines
3.1 KiB
Python

"""Evaluation script for trained PPO agent."""
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import numpy as np
import gymnasium as gym
from src.utils import make_env, get_device
from src.network import Actor, Critic
def evaluate(actor, env, num_episodes=10, device=torch.device("cpu")):
    """Evaluate the actor and return the mean and std of episode returns."""
    actor.eval()
    returns = []
    for ep in range(num_episodes):
        obs, _ = env.reset()
        obs = np.transpose(obs, (1, 2, 0))  # (C, H, W) -> (H, W, C), matching the buffer storage format
        total_reward = 0
        done = False
        steps = 0
        while not done and steps < 1000:
            with torch.no_grad():
                # Convert to a (B, C, H, W) float tensor for the network
                obs_t = torch.from_numpy(obs).float().unsqueeze(0).permute(0, 3, 1, 2).to(device)
                mu, std = actor(obs_t)
                # Sample a stochastic action and clip it to the action bounds
                dist = torch.distributions.Normal(mu, std)
                action = dist.sample()
                action = torch.clamp(action, -1, 1).squeeze(0).cpu().numpy()
            obs, reward, terminated, truncated, _ = env.step(action)
            obs = np.transpose(obs, (1, 2, 0))  # back to (H, W, C)
            total_reward += reward
            done = terminated or truncated
            steps += 1
        returns.append(total_reward)
        print(f"Episode {ep+1}/{num_episodes}: return={total_reward:.1f}, steps={steps}")
    actor.train()
    return np.mean(returns), np.std(returns)
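
# For deterministic evaluation, a common variant (not used here) acts on the
# policy mean instead of sampling:
#     mu, _ = actor(obs_t)
#     action = torch.clamp(mu, -1, 1).squeeze(0).cpu().numpy()
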
def evaluate_render(actor, env, device):
    """Run one episode with rendering for visual inspection."""
    actor.eval()
    obs, _ = env.reset()
    obs = np.transpose(obs, (1, 2, 0))
    # NOTE: most gymnasium environments only honor render_mode passed at
    # construction time; setting it here may have no effect.
    env.render_mode = "human"
    done = False
    total_reward = 0
    while not done:
        with torch.no_grad():
            obs_t = torch.from_numpy(obs).float().unsqueeze(0).permute(0, 3, 1, 2).to(device)
            mu, std = actor(obs_t)
            dist = torch.distributions.Normal(mu, std)
            action = dist.sample()
            action = torch.clamp(action, -1, 1).squeeze(0).cpu().numpy()
        obs, reward, terminated, truncated, _ = env.step(action)
        obs = np.transpose(obs, (1, 2, 0))
        total_reward += reward
        done = terminated or truncated
        env.render()
    actor.train()
    print(f"Final return: {total_reward:.1f}")

if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True, help="Path to trained model")
parser.add_argument("--episodes", type=int, default=5, help="Number of evaluation episodes")
args = parser.parse_args()
device = get_device()
env = make_env()
actor = Actor().to(device)
critic = Critic().to(device)
    # Load the actor weights from the checkpoint
    checkpoint = torch.load(args.model, map_location=device, weights_only=False)
    actor.load_state_dict(checkpoint["actor"])
    print(f"Loaded model from {args.model}")

    mean_return, std_return = evaluate(actor, env, num_episodes=args.episodes, device=device)
    print(f"\nEvaluation: mean={mean_return:.2f}, std={std_return:.2f}")