Files
TIR_PROJ/training/eval_per_sheep.py
T
2026-04-25 11:31:39 +01:00

110 lines
4.1 KiB
Python

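r"""
Load a saved run and evaluate the policy at every n_sheep from 1..N.

For each flock size, prints the success rate, mean action magnitude, mean
per-episode minimum distances and a failure-mode breakdown, which shows
exactly where the curriculum stopped working.

With --run-dir, final_model.zip is loaded if it exists (otherwise
best_model/best_model.zip), together with vecnorm.pkl from the same directory.

Usage:
    python eval_per_sheep.py --run-dir runs/ppo_v3
    python eval_per_sheep.py --run-dir runs/ppo_v3 --max-sheep 10 --episodes 20
    python eval_per_sheep.py --model runs/ppo_v3/final_model.zip \
        --vecnorm runs/ppo_v3/vecnorm.pkl
"""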
import argparse
import os
from copy import deepcopy
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from herding_env import HerdingEnv
from train import _classify, COMPACT_RADIUS
def _flock_snapshot(env):
    """Measure the current flock state of a single, unwrapped HerdingEnv.

    Returns ``(radius, com_to_pen, dog_to_com)``:
      - the flock radius as reported by ``env._flock_stats()``;
      - the distance from the flock centre of mass to ``env.PEN_CENTER``;
      - the distance from ``env.dog_pos`` to that centre of mass.
    """
    com, radius, _ = env._flock_stats()
    com_to_pen = float(np.linalg.norm(com - env.PEN_CENTER))
    dog_to_com = float(np.linalg.norm(env.dog_pos - com))
    return radius, com_to_pen, dog_to_com


def evaluate(model, vn_template, n_sheep, n_episodes, max_steps):
    """Run ``n_episodes`` deterministic episodes with ``n_sheep`` sheep.

    Args:
        model: Policy exposing ``predict(obs, deterministic=True)``
            (a loaded PPO model in this script).
        vn_template: ``VecNormalize`` restored from training. Its observation
            statistics, ``clip_obs`` and ``epsilon`` are copied into a frozen
            normaliser for the evaluation env; the template is not modified.
        n_sheep: Flock size passed to ``HerdingEnv``.
        n_episodes: Number of episodes to run; must be >= 1.
        max_steps: Episode step limit passed to ``HerdingEnv``.

    Returns:
        dict with:
          - ``n_sheep``;
          - ``success_rate``: fraction of episodes whose final ``n_penned``
            equals ``n_sheep``;
          - ``failure``: mode -> count, as labelled by ``_classify``;
          - ``mean_action``: mean L2 norm of the actions over all steps;
          - ``mean_min_radius``, ``mean_min_dog_com``, ``mean_min_pen``:
            per-episode minima, averaged over episodes.

    Raises:
        ValueError: if ``n_episodes`` < 1.
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")

    raw = DummyVecEnv([lambda: HerdingEnv(n_sheep=n_sheep, max_steps=max_steps)])
    # Frozen normaliser (training=False). Reuse the training statistics *and*
    # clipping/epsilon so the policy sees observations on the same scale it
    # was trained on.
    vn = VecNormalize(
        raw,
        norm_obs=True,
        norm_reward=False,
        training=False,
        clip_obs=vn_template.clip_obs,
        epsilon=vn_template.epsilon,
    )
    vn.obs_rms = deepcopy(vn_template.obs_rms)
    vn.ret_rms = deepcopy(vn_template.ret_rms)
    inner = raw.envs[0]

    failure = {}
    successes = 0
    act_mags, min_radii, min_dog_com, min_pen = [], [], [], []
    try:
        for _ in range(n_episodes):
            obs = vn.reset()
            ep_act = []
            # Always include the episode's initial state, so the per-episode
            # series is never empty (even for a 1-step episode).
            snapshots = [_flock_snapshot(inner)]
            while True:
                action, _ = model.predict(obs, deterministic=True)
                obs, _, dones, infos = vn.step(action)
                ep_act.append(float(np.linalg.norm(action[0])))
                if dones[0]:
                    # DummyVecEnv resets a sub-env the moment its episode ends.
                    # At this point ``inner`` already holds the NEXT episode's
                    # initial state, so it must not be sampled. The terminal
                    # outcome is read from ``infos`` below instead.
                    break
                snapshots.append(_flock_snapshot(inner))
            ep_radius, ep_com_dist, ep_dog_com = (list(col) for col in zip(*snapshots))

            npen = infos[0].get("n_penned", 0)
            success = npen == n_sheep
            successes += int(success)
            # NOTE: the last entries of ep_radius/ep_com_dist are the final
            # *observable* (pre-terminal) state of this episode.
            mode = _classify(ep_radius, ep_com_dist, npen, n_sheep, success)
            failure[mode] = failure.get(mode, 0) + 1

            act_mags.extend(ep_act)
            min_radii.append(min(ep_radius))
            min_dog_com.append(min(ep_dog_com))
            min_pen.append(min(ep_com_dist))
    finally:
        vn.close()

    return {
        "n_sheep": n_sheep,
        "success_rate": successes / n_episodes,
        "failure": failure,
        "mean_action": float(np.mean(act_mags)),
        "mean_min_radius": float(np.mean(min_radii)),
        "mean_min_dog_com": float(np.mean(min_dog_com)),
        "mean_min_pen": float(np.mean(min_pen)),
    }
def main():
    """CLI entry point: load a model and its VecNormalize stats, then evaluate.

    Either ``--run-dir`` or both ``--model`` and ``--vecnorm`` are required.
    ``evaluate`` is run for every flock size from 1 to ``--max-sheep``, and
    one table row is printed per size.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--run-dir", type=str, default=None)
    p.add_argument("--model", type=str, default=None)
    p.add_argument("--vecnorm", type=str, default=None)
    p.add_argument("--max-sheep", type=int, default=10)
    p.add_argument("--episodes", type=int, default=10)
    p.add_argument("--max-steps", type=int, default=2000)
    args = p.parse_args()

    if args.run_dir:
        model_path = os.path.join(args.run_dir, "final_model.zip")
        if not os.path.exists(model_path):
            # No final checkpoint in the run dir: fall back to the best one.
            model_path = os.path.join(args.run_dir, "best_model", "best_model.zip")
        vn_path = os.path.join(args.run_dir, "vecnorm.pkl")
    elif args.model and args.vecnorm:
        model_path = args.model
        vn_path = args.vecnorm
    else:
        # Without this, PPO.load(None) fails later with an unhelpful error.
        p.error("pass --run-dir, or both --model and --vecnorm")

    print(f"Loading model: {model_path}")
    print(f"Loading vecnorm: {vn_path}\n")
    model = PPO.load(model_path, device="cpu")
    # VecNormalize.load needs an env to attach to; evaluate() builds its own
    # env per flock size and only copies settings/statistics from this one.
    raw = DummyVecEnv([lambda: HerdingEnv(n_sheep=1, max_steps=args.max_steps)])
    vn_template = VecNormalize.load(vn_path, raw)

    try:
        print(f"{'n_sheep':>7} {'success':>8} {'act':>6} {'min_r':>7} "
              f"{'dog→com':>8} {'com→pen':>8} failure breakdown")
        print("-" * 90)
        for n in range(1, args.max_sheep + 1):
            r = evaluate(model, vn_template, n, args.episodes, args.max_steps)
            # Most frequent failure modes first.
            fb = " ".join(f"{m}={c}" for m, c in
                          sorted(r["failure"].items(), key=lambda x: -x[1]))
            # Each column width matches its header above
            # (success: 7 digits + '%' = 8).
            print(f"{n:>7d} {r['success_rate']*100:>7.0f}% "
                  f"{r['mean_action']:>6.2f} "
                  f"{r['mean_min_radius']:>6.2f}m "
                  f"{r['mean_min_dog_com']:>7.2f}m "
                  f"{r['mean_min_pen']:>7.2f}m {fb}")
    finally:
        vn_template.close()


if __name__ == "__main__":
    main()