Files
2026-04-22 23:34:58 +01:00

144 lines
4.9 KiB
Python

"""
Evaluation script for a trained herding policy.
Runs N episodes and reports the three project metrics:
1. Success rate — fraction of episodes where all sheep are penned
2. Time-to-pen — mean of (episode length in steps / number of sheep), over successful episodes only
3. Flock dispersion — mean pairwise distance among the episode's sheep, averaged over the timesteps in
which at least one sheep is still unpenned, then over episodes (lower = tighter herding)
Usage
-----
python evaluate.py --model runs/ppo_herding/best_model/best_model.zip \
--vecnorm runs/ppo_herding/vecnorm.pkl \
--n-sheep 5 --episodes 100
Add --render to view the rollout in a matplotlib window. The environment is created with render_mode='human' for every episode, not only the first.
"""
import argparse
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from herding_env import HerdingEnv
def make_single_env(n_sheep: int, max_steps: int, render_mode: str = None):
    """Build a zero-argument factory that constructs one ``HerdingEnv``.

    ``DummyVecEnv`` takes a list of callables and invokes each one to create
    its sub-environment, so the configuration is captured in a closure and
    the environment is only instantiated when the factory is called.
    """
    def _factory():
        env_kwargs = dict(
            n_sheep=n_sheep,
            max_steps=max_steps,
            render_mode=render_mode,
        )
        return HerdingEnv(**env_kwargs)

    return _factory
def pairwise_mean(positions: np.ndarray, n_active: int) -> float:
    """Mean pairwise Euclidean distance among the first ``n_active`` sheep.

    Parameters
    ----------
    positions : array-like
        Sheep positions, one row per sheep. Rows beyond ``n_active`` are
        ignored. A 1-D array is treated as one scalar coordinate per sheep.
    n_active : int
        Number of leading rows of ``positions`` to include.

    Returns
    -------
    float
        Mean distance over all unordered pairs ``(i, j)`` with ``i < j``,
        or ``0.0`` when fewer than two sheep are active.

    Raises
    ------
    ValueError
        If ``positions`` has fewer than ``n_active`` rows.
    """
    if n_active < 2:
        return 0.0
    pts = np.asarray(positions, dtype=float)
    if pts.shape[0] < n_active:
        raise ValueError(
            f"positions has {pts.shape[0]} rows but n_active={n_active}"
        )
    # Flatten each sheep to a coordinate row so 1-D input works too.
    pts = pts[:n_active].reshape(n_active, -1)
    # Index pairs of the strict upper triangle = each unordered pair once.
    i, j = np.triu_indices(n_active, k=1)
    return float(np.linalg.norm(pts[i] - pts[j], axis=1).mean())
def parse_args():
    """Read the evaluation options from the command line.

    Returns an ``argparse.Namespace`` with ``model`` (required path),
    ``vecnorm`` (optional stats path), the integer settings ``n_sheep``,
    ``episodes``, ``max_steps`` and ``seed``, and the boolean ``render``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", required=True, help="Path to saved model .zip"
    )
    parser.add_argument(
        "--vecnorm",
        default=None,
        help="Path to VecNormalize stats .pkl (optional)",
    )
    # Integer knobs share the same shape; register them in their original order.
    for flag, fallback in (
        ("--n-sheep", 1),
        ("--episodes", 50),
        ("--max-steps", 2000),
    ):
        parser.add_argument(flag, type=int, default=fallback)
    parser.add_argument(
        "--render",
        action="store_true",
        help="Render first episode in matplotlib",
    )
    parser.add_argument("--seed", type=int, default=42)
    namespace = parser.parse_args()
    return namespace
def _run_episode(model, env, inner, default_n_sheep: int):
    """Roll out one deterministic episode and collect its raw metrics.

    Parameters
    ----------
    model
        Loaded policy; actions come from ``model.predict(..., deterministic=True)``.
    env
        Single-environment VecEnv (optionally VecNormalize-wrapped) that the
        policy acts in.
    inner
        The HerdingEnv instance inside ``env``. It is read directly for
        ``penned``, ``sheep_pos`` and ``n_sheep``.
    default_n_sheep : int
        Flock size to assume if the final step's info lacks ``"n_sheep"``.

    Returns
    -------
    tuple
        ``(success, ep_steps, n_sheep, dispersion_samples)``. Here
        ``dispersion_samples`` holds one flock-dispersion value for each
        non-terminal step on which at least one sheep was still unpenned.
    """
    obs = env.reset()
    ep_steps = 0
    dispersion_samples = []
    while True:
        action, _ = model.predict(obs, deterministic=True)
        obs, _, dones, infos = env.step(action)
        ep_steps += 1
        if dones[0]:
            # The VecEnv auto-resets the inner env when an episode ends, so
            # its state now belongs to the NEXT episode; don't sample it.
            break
        if not inner.penned[:inner.n_sheep].all():
            dispersion_samples.append(
                pairwise_mean(inner.sheep_pos, inner.n_sheep)
            )
    # The info from the final step describes the episode that just ended.
    info = infos[0]
    n_sheep = info.get("n_sheep", default_n_sheep)
    success = info.get("n_penned", 0) == n_sheep
    return success, ep_steps, n_sheep, dispersion_samples


def _print_report(args, successes, steps_to_pen, dispersions):
    """Print the summary table of the three evaluation metrics.

    ``successes`` holds one 0/1 entry per episode. ``steps_to_pen`` holds
    steps-per-sheep for successful episodes only. ``dispersions`` holds each
    episode's mean flock dispersion. An empty metric list is reported as NaN.
    """
    success_rate = float(np.mean(successes))
    mean_ttp = float(np.mean(steps_to_pen)) if steps_to_pen else float("nan")
    mean_disp = float(np.mean(dispersions)) if dispersions else float("nan")
    print("\n" + "=" * 50)
    print(f" Model : {args.model}")
    print(f" Sheep : {args.n_sheep}")
    print(f" Episodes : {args.episodes}")
    print("-" * 50)
    print(f" Success rate : {success_rate * 100:.1f}%"
          f" ({sum(successes)}/{args.episodes})")
    print(f" Time-to-pen : {mean_ttp:.1f} steps/sheep"
          f" (successful episodes only)")
    print(f" Flock dispersion: {mean_disp:.2f} m"
          f" (mean pairwise distance while active)")
    print("=" * 50)


def main():
    """Evaluate a saved PPO herding policy and print the project metrics.

    Runs ``--episodes`` deterministic episodes, seeded with ``--seed``, and
    reports success rate, time-to-pen and flock dispersion.

    Raises
    ------
    SystemExit
        If ``--episodes`` is less than 1.
    """
    args = parse_args()
    if args.episodes < 1:
        raise SystemExit("--episodes must be at least 1")
    # NOTE(review): render_mode is fixed when the env is built, so with
    # --render every episode runs in "human" mode, not only the first.
    render_mode = "human" if args.render else None
    raw_env = DummyVecEnv([make_single_env(args.n_sheep, args.max_steps,
                                           render_mode)])
    env = raw_env
    successes = []
    steps_to_pen = []  # steps per sheep, successful episodes only
    dispersions = []   # per-episode mean flock dispersion
    try:
        if args.vecnorm:
            # Freeze the loaded normalisation statistics for evaluation and
            # leave rewards un-normalised.
            env = VecNormalize.load(args.vecnorm, raw_env)
            env.training = False
            env.norm_reward = False
        model = PPO.load(args.model, env=env)
        # Seed the global RNGs, the action space and the env (previously
        # --seed was parsed but never used, so runs were not reproducible).
        model.set_random_seed(args.seed)
        # VecNormalize wraps raw_env, so raw_env.envs[0] is always the
        # underlying HerdingEnv regardless of whether stats were loaded.
        inner = raw_env.envs[0]
        for ep in range(args.episodes):
            success, ep_steps, n_sheep, samples = _run_episode(
                model, env, inner, args.n_sheep
            )
            successes.append(int(success))
            if success:
                # Guard against a degenerate zero-sheep configuration.
                steps_to_pen.append(ep_steps / max(n_sheep, 1))
            if samples:
                dispersions.append(float(np.mean(samples)))
            if (ep + 1) % 10 == 0:
                print(f" Episode {ep + 1:>4}/{args.episodes} "
                      f"success={int(success)} steps={ep_steps}")
    finally:
        # Always release the env (and any render window), even on error or
        # Ctrl-C.
        env.close()
    _print_report(args, successes, steps_to_pen, dispersions)
if __name__ == "__main__":
    # Script entry point; importing this module does not start an evaluation.
    main()