Sheep training flock _ improver
This commit is contained in:
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
Load a saved run and evaluate the policy at every n_sheep from 1..N.
Tells you exactly where the curriculum stopped working.

Usage:
    python eval_per_sheep.py --run-dir runs/ppo_v3
    python eval_per_sheep.py --run-dir runs/ppo_v3 --max-sheep 10 --episodes 20
    python eval_per_sheep.py --model runs/ppo_v3/final_model.zip \
        --vecnorm runs/ppo_v3/vecnorm.pkl
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
|
||||
|
||||
from herding_env import HerdingEnv
|
||||
from train import _classify, COMPACT_RADIUS
|
||||
|
||||
|
||||
def evaluate(model, vn_template, n_sheep, n_episodes, max_steps):
    """Roll out ``model`` on a fresh ``HerdingEnv`` and summarise the episodes.

    A single-env ``DummyVecEnv`` wrapped in ``VecNormalize`` is built for the
    requested flock size. Normalisation is frozen (``training=False``), rewards
    are left unnormalised, and the running statistics are deep-copied from
    ``vn_template`` so the template itself is never mutated.

    Args:
        model: Policy object exposing ``predict(obs, deterministic=True)``.
        vn_template: Object whose ``obs_rms`` / ``ret_rms`` are copied.
        n_sheep: Flock size passed to ``HerdingEnv``.
        n_episodes: Number of episodes to run; must be at least 1.
        max_steps: ``max_steps`` passed to ``HerdingEnv``.

    Returns:
        dict with keys:
            ``n_sheep``          -- the flock size evaluated.
            ``success_rate``     -- fraction of episodes whose final info
                                    reported ``n_penned == n_sheep``.
            ``failure``          -- ``{mode: count}`` of ``_classify`` labels.
            ``mean_action``      -- mean L2 norm of the action over all steps.
            ``mean_min_radius``  -- mean over episodes of the smallest radius
                                    returned by ``_flock_stats``.
            ``mean_min_dog_com`` -- mean over episodes of the smallest
                                    ``dog_pos``-to-``com`` distance.
            ``mean_min_pen``     -- mean over episodes of the smallest
                                    ``com``-to-``PEN_CENTER`` distance.

    Raises:
        ValueError: If ``n_episodes`` is less than 1.
    """
    if n_episodes < 1:
        # Fail before any env is built; otherwise the summary would divide by
        # zero and average empty lists.
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")

    raw = DummyVecEnv([lambda: HerdingEnv(n_sheep=n_sheep, max_steps=max_steps)])
    vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
    vn.obs_rms = deepcopy(vn_template.obs_rms)
    vn.ret_rms = deepcopy(vn_template.ret_rms)

    failure = {}
    successes = 0
    act_mags, min_radii, min_dog_com, min_pen = [], [], [], []

    try:
        inner = vn.envs[0]
        for _ in range(n_episodes):
            obs = vn.reset()
            done = False
            infos = [{}]
            ep_radius, ep_com_dist, ep_dog_com, ep_act = [], [], [], []

            while not done:
                # Sample the flock state BEFORE stepping. DummyVecEnv resets a
                # finished env inside step(), so reading the env's internals
                # after the terminal step would return the next episode's
                # initial state and contaminate this episode's statistics.
                com, radius, _ = inner._flock_stats()
                ep_radius.append(radius)
                ep_com_dist.append(float(np.linalg.norm(com - inner.PEN_CENTER)))
                ep_dog_com.append(float(np.linalg.norm(inner.dog_pos - com)))

                action, _ = model.predict(obs, deterministic=True)
                obs, _, dones, infos = vn.step(action)
                done = dones[0]
                ep_act.append(float(np.linalg.norm(action[0])))

            # The info dict of the terminal step is preserved across the
            # automatic reset, so the penned count is still the episode's own.
            npen = infos[0].get("n_penned", 0)
            success = npen == n_sheep
            successes += int(success)
            mode = _classify(ep_radius, ep_com_dist, npen, n_sheep, success)
            failure[mode] = failure.get(mode, 0) + 1

            act_mags.extend(ep_act)
            min_radii.append(min(ep_radius))
            min_dog_com.append(min(ep_dog_com))
            min_pen.append(min(ep_com_dist))
    finally:
        # Release the env even if a rollout raised.
        vn.close()

    return {
        "n_sheep": n_sheep,
        "success_rate": successes / n_episodes,
        "failure": failure,
        "mean_action": float(np.mean(act_mags)),
        "mean_min_radius": float(np.mean(min_radii)),
        "mean_min_dog_com": float(np.mean(min_dog_com)),
        "mean_min_pen": float(np.mean(min_pen)),
    }
|
||||
|
||||
|
||||
def main():
    """Evaluate a saved PPO policy for every flock size from 1 to --max-sheep.

    The model and normalisation statistics are located either from
    ``--run-dir`` (``final_model.zip``, falling back to
    ``best_model/best_model.zip``, plus ``vecnorm.pkl``) or from an explicit
    ``--model`` / ``--vecnorm`` pair. One table row is printed per flock size.

    Exits with status 2 (via ``argparse``) when neither ``--run-dir`` nor both
    ``--model`` and ``--vecnorm`` are supplied.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--run-dir", type=str, default=None,
                   help="run directory holding the model and vecnorm.pkl")
    p.add_argument("--model", type=str, default=None,
                   help="explicit model .zip (used when --run-dir is absent)")
    p.add_argument("--vecnorm", type=str, default=None,
                   help="explicit VecNormalize .pkl (used with --model)")
    p.add_argument("--max-sheep", type=int, default=10)
    p.add_argument("--episodes", type=int, default=10)
    p.add_argument("--max-steps", type=int, default=2000)
    args = p.parse_args()

    if args.run_dir:
        model_path = os.path.join(args.run_dir, "final_model.zip")
        if not os.path.exists(model_path):
            model_path = os.path.join(args.run_dir, "best_model", "best_model.zip")
        vn_path = os.path.join(args.run_dir, "vecnorm.pkl")
    else:
        # Without this check a missing path reaches the loaders as None and
        # fails with an unrelated-looking error deep inside the library.
        if not args.model or not args.vecnorm:
            p.error("provide either --run-dir, or both --model and --vecnorm")
        model_path = args.model
        vn_path = args.vecnorm

    print(f"Loading model: {model_path}")
    print(f"Loading vecnorm: {vn_path}\n")
    model = PPO.load(model_path, device="cpu")
    # The template env exists only so VecNormalize.load has something to wrap;
    # evaluate() copies its statistics onto a fresh env per flock size.
    raw = DummyVecEnv([lambda: HerdingEnv(n_sheep=1, max_steps=args.max_steps)])
    try:
        vn_template = VecNormalize.load(vn_path, raw)

        print(f"{'n_sheep':>7} {'success':>8} {'act':>6} {'min_r':>7} "
              f"{'dog→com':>8} {'com→pen':>8} failure breakdown")
        print("-" * 90)
        for n in range(1, args.max_sheep + 1):
            r = evaluate(model, vn_template, n, args.episodes, args.max_steps)
            # Most frequent outcome first.
            fb = " ".join(f"{m}={c}" for m, c in
                          sorted(r["failure"].items(), key=lambda x: -x[1]))
            print(f"{n:>7d} {r['success_rate']*100:>6.0f}% "
                  f"{r['mean_action']:>6.2f} "
                  f"{r['mean_min_radius']:>6.2f}m "
                  f"{r['mean_min_dog_com']:>7.2f}m "
                  f"{r['mean_min_pen']:>7.2f}m {fb}")
    finally:
        # Release the template env even if loading or evaluation failed.
        raw.close()
|
||||
|
||||
|
||||
# Run the evaluation only when this file is executed as a script, so that
# importing the module does not trigger argument parsing or model loading.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user