"""
Evaluation script for a trained herding policy.

Runs N episodes and reports the three project metrics:
  1. Success rate     — fraction of episodes where all sheep are penned
  2. Time-to-pen      — mean steps across successful episodes (per sheep)
  3. Flock dispersion — mean pairwise distance among active sheep, averaged
                        over all timesteps (lower = tighter herding)

Usage
-----
    python evaluate.py --model runs/ppo_herding/best_model/best_model.zip \
                       --vecnorm runs/ppo_herding/vecnorm.pkl \
                       --n-sheep 5 --episodes 100

Add --render to watch the first episode in a matplotlib window.
"""

import argparse
from typing import Optional

import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

from herding_env import HerdingEnv


def make_single_env(n_sheep: int, max_steps: int,
                    render_mode: Optional[str] = None):
    """Return a zero-argument factory that builds one ``HerdingEnv``.

    ``DummyVecEnv`` expects a list of callables rather than environment
    instances, so construction is deferred to the returned closure.

    Args:
        n_sheep: Forwarded to ``HerdingEnv(n_sheep=...)``.
        max_steps: Forwarded to ``HerdingEnv(max_steps=...)``.
        render_mode: Forwarded to ``HerdingEnv(render_mode=...)``;
            ``None`` (the default) is passed through unchanged.

    Returns:
        A callable taking no arguments that creates a new ``HerdingEnv``
        each time it is invoked.
    """
    def _init():
        return HerdingEnv(n_sheep=n_sheep, max_steps=max_steps,
                          render_mode=render_mode)

    return _init


def pairwise_mean(positions: np.ndarray, n_active: int) -> float:
    """Mean pairwise Euclidean distance among the first ``n_active`` rows.

    Args:
        positions: Array whose leading axis indexes individual sheep; only
            the first ``n_active`` entries are used.
        n_active: Number of leading entries to include.

    Returns:
        The mean distance over all unordered pairs ``(i, j)`` with ``i < j``,
        or ``0.0`` when fewer than two entries are requested (no pairs).
    """
    if n_active < 2:
        return 0.0

    # One flat coordinate vector per sheep, so 1-D and N-D inputs both work.
    pts = np.asarray(positions, dtype=float)[:n_active].reshape(n_active, -1)

    # Full (n, n) distance matrix via broadcasting instead of a Python loop.
    deltas = pts[:, None, :] - pts[None, :, :]
    dist = np.linalg.norm(deltas, axis=-1)

    # The strict upper triangle holds each unordered pair exactly once.
    rows, cols = np.triu_indices(n_active, k=1)
    return float(dist[rows, cols].mean())


def _positive_int(text: str) -> int:
    """argparse ``type=`` converter that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {value}")
    return value


def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model``, ``vecnorm``,
        ``n_sheep``, ``episodes``, ``max_steps``, ``render`` and ``seed``.

    Raises:
        SystemExit: Raised by argparse on invalid arguments, including
            non-positive values for the count options.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--model", required=True,
                   help="Path to saved model .zip")
    p.add_argument("--vecnorm", default=None,
                   help="Path to VecNormalize stats .pkl (optional)")
    # Counts must be >= 1: time-to-pen is reported per sheep, so zero sheep
    # would divide by zero, and zero episodes would give no metrics at all.
    p.add_argument("--n-sheep", type=_positive_int, default=1)
    p.add_argument("--episodes", type=_positive_int, default=50)
    p.add_argument("--max-steps", type=_positive_int, default=2000)
    p.add_argument("--render", action="store_true",
                   help="Render first episode in matplotlib")
    p.add_argument("--seed", type=int, default=42)
    return p.parse_args(argv)


def main():
    """Evaluate a saved PPO policy on ``HerdingEnv`` and print the metrics.

    Builds a single-environment ``DummyVecEnv`` (optionally wrapped with saved
    ``VecNormalize`` statistics), runs ``--episodes`` deterministic rollouts
    and prints success rate, per-sheep time-to-pen and flock dispersion.
    """
    args = parse_args()

    # NOTE(review): render_mode is fixed at construction, so it applies to
    # every episode, while the --render help text mentions only the first
    # one. Confirm the intended behaviour against HerdingEnv.
    render_mode = "human" if args.render else None
    raw_env = DummyVecEnv([make_single_env(args.n_sheep, args.max_steps,
                                           render_mode)])
    if args.vecnorm:
        env = VecNormalize.load(args.vecnorm, raw_env)
        # Evaluation only: keep the loaded running statistics fixed and
        # leave rewards un-normalised.
        env.training = False
        env.norm_reward = False
    else:
        env = raw_env

    successes = []
    steps_to_pen = []   # steps for successful episodes
    dispersions = []    # per-episode mean flock dispersion

    try:
        model = PPO.load(args.model, env=env)

        # Apply --seed so repeated evaluations are reproducible.
        env.seed(args.seed)

        # The underlying HerdingEnv is the same object for the whole run,
        # with or without the VecNormalize wrapper.
        inner = raw_env.envs[0]

        for ep in range(args.episodes):
            obs = env.reset()
            done = False
            ep_steps = 0
            ep_dispersion = []
            infos = [{}]

            while not done:
                action, _ = model.predict(obs, deterministic=True)
                obs, _, dones, infos = env.step(action)
                done = bool(dones[0])
                ep_steps += 1

                if done:
                    # DummyVecEnv has already auto-reset the inner env, so
                    # its state now belongs to the next episode; sampling
                    # it would contaminate this episode's dispersion.
                    break

                # NOTE(review): the distance is taken over all n_sheep rows,
                # penned sheep included; confirm whether penned sheep should
                # be masked out to match "active sheep" in the docstring.
                if not inner.penned[:inner.n_sheep].all():
                    ep_dispersion.append(
                        pairwise_mean(inner.sheep_pos, inner.n_sheep)
                    )

            # On the terminal step infos[0] is the info returned by that
            # step, so it still describes the episode that just ended.
            info = infos[0]
            n_penned = info.get("n_penned", 0)
            n_sheep = info.get("n_sheep", args.n_sheep)
            success = n_penned == n_sheep

            successes.append(int(success))
            if success and n_sheep > 0:
                steps_to_pen.append(ep_steps / n_sheep)
            if ep_dispersion:
                dispersions.append(float(np.mean(ep_dispersion)))

            if (ep + 1) % 10 == 0:
                print(f"  Episode {ep + 1:>4}/{args.episodes}  "
                      f"success={int(success)}  steps={ep_steps}")
    finally:
        # Release the env (and any render window) even if a rollout fails.
        env.close()

    # -----------------------------------------------------------------------
    # Report
    # -----------------------------------------------------------------------
    success_rate = float(np.mean(successes)) if successes else float("nan")
    mean_ttp = float(np.mean(steps_to_pen)) if steps_to_pen else float("nan")
    mean_disp = float(np.mean(dispersions)) if dispersions else float("nan")

    print("\n" + "=" * 50)
    print(f"  Model   : {args.model}")
    print(f"  Sheep   : {args.n_sheep}")
    print(f"  Episodes: {args.episodes}")
    print("-" * 50)
    print(f"  Success rate    : {success_rate * 100:.1f}%"
          f"  ({sum(successes)}/{args.episodes})")
    print(f"  Time-to-pen     : {mean_ttp:.1f} steps/sheep"
          f"  (successful episodes only)")
    print(f"  Flock dispersion: {mean_disp:.2f} m"
          f"  (mean pairwise distance while active)")
    print("=" * 50)


# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()