Sheep training flock of 10 fix?

This commit is contained in:
Johnny Fernandes
2026-04-24 10:58:36 +01:00
parent 4189cc8dba
commit 17eb25864e
3 changed files with 280 additions and 40 deletions
+223
View File
@@ -0,0 +1,223 @@
"""
Episode-level diagnostics for the herding policy.
Runs N episodes and for each one tracks:
- flock radius over time
- COM-to-pen distance over time
- dog position over time
- when (if ever) the flock first became compact
- failure mode classification
Then produces:
1. Console summary of failure modes
2. Per-episode time-series plots (radius + com_dist)
3. Optional live rendering of the episodes while they run (--render)
Usage
-----
python diagnose.py --model runs/ppo_consolidation/final_model.zip \
--vecnorm runs/ppo_consolidation/vecnorm.pkl \
--n-sheep 5 --episodes 20
# Watch the policy live (first episode rendered):
python diagnose.py ... --render
# Save plots to a directory instead of showing interactively:
python diagnose.py ... --plot-dir debug_plots/
"""
import argparse
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from herding_env import HerdingEnv
# ── failure mode constants ────────────────────────────────────────────────────
COMPACT_RADIUS = 5.0  # must match DRIVE_GATE_RADIUS in herding_env.py


def classify_failure(ep_radius, ep_com_dist, n_penned, n_sheep, success,
                     compact_radius=None, pen_close=3.0):
    """Label the outcome of one episode with a coarse failure mode.

    Args:
        ep_radius: Per-step flock radius samples (m), in time order.
        ep_com_dist: Per-step distance (m) from the flock centre of mass to
            the pen, aligned index-for-index with ``ep_radius``.
        n_penned: Number of sheep penned when the episode ended.
        n_sheep: Total number of sheep in the episode.
        success: True when the episode counts as solved.
        compact_radius: Radius at or below which the flock counts as
            compact. ``None`` (the default) uses ``COMPACT_RADIUS``.
        pen_close: The flock "got close" to the pen when its centre of mass
            came within this distance (m) at or after the first compact step.

    Returns:
        One of ``"SUCCESS"``, ``"NEVER_COMPACT"``, ``"COMPACT_CANT_DRIVE"``,
        ``"DROVE_NO_SHEEP"`` or ``"PARTIAL_<k>of<n>"``.
    """
    if success:
        return "SUCCESS"

    if compact_radius is None:
        compact_radius = COMPACT_RADIUS

    # Index of the first compact sample. An empty series has none, so it is
    # reported as never compact instead of crashing in min().
    first_compact = next(
        (i for i, r in enumerate(ep_radius) if r <= compact_radius), None)
    if first_compact is None:
        return "NEVER_COMPACT"  # flock was always too scattered

    # Only progress made once the flock was compact counts as "driving".
    # default=inf keeps a too-short distance series from raising.
    min_com_after = min(ep_com_dist[first_compact:], default=float("inf"))
    if min_com_after > pen_close:
        return "COMPACT_CANT_DRIVE"  # compacted but never drove to pen

    if n_penned == 0:
        return "DROVE_NO_SHEEP"  # got near pen, nothing went in

    return f"PARTIAL_{n_penned}of{n_sheep}"  # some in, not all
# ── main ─────────────────────────────────────────────────────────────────────
def parse_args(argv=None):
    """Parse the command-line options of the diagnostics script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model``, ``vecnorm``,
        ``n_sheep``, ``episodes``, ``max_steps``, ``render``, ``plot_dir``
        and ``seed``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--model", required=True,
                   help="Path of the saved PPO model to load")
    p.add_argument("--vecnorm", default=None,
                   help="Path of saved VecNormalize statistics (optional)")
    p.add_argument("--n-sheep", type=int, default=5,
                   help="n_sheep passed to HerdingEnv")
    p.add_argument("--episodes", type=int, default=20,
                   help="Number of episodes to run")
    p.add_argument("--max-steps", type=int, default=4000,
                   help="max_steps passed to HerdingEnv")
    p.add_argument("--render", action="store_true",
                   help="Show matplotlib animation of the first episode")
    p.add_argument("--plot-dir", default=None,
                   help="Save time-series plots here (one per episode)")
    p.add_argument("--seed", type=int, default=0)
    return p.parse_args(argv)
def make_env(n_sheep, max_steps, render_mode=None):
    """Return a zero-argument factory that constructs one HerdingEnv.

    Nothing is built here: the environment is only created when the
    returned callable is invoked, with the settings captured from this call.
    """
    env_kwargs = dict(n_sheep=n_sheep, max_steps=max_steps,
                      render_mode=render_mode)

    def _build():
        return HerdingEnv(**env_kwargs)

    return _build
def _plot_episode(ep, n_sheep, mode, ep_radius, ep_com_dist, compact_step,
                  plot_dir):
    """Plot flock radius and COM-to-pen distance over time for one episode.

    Saves a PNG into ``plot_dir`` when it is set, otherwise shows the figure
    without blocking.
    """
    fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
    t = np.arange(len(ep_radius))

    axes[0].plot(t, ep_radius, color="steelblue", label="flock radius (m)")
    axes[0].axhline(COMPACT_RADIUS, color="orange", linestyle="--",
                    label=f"compact threshold ({COMPACT_RADIUS}m)")
    if compact_step is not None:
        axes[0].axvline(compact_step, color="green", linestyle=":",
                        alpha=0.6, label=f"first compact (step {compact_step})")
    axes[0].set_ylabel("radius (m)")
    axes[0].legend(fontsize=8)
    axes[0].set_title(f"ep {ep+1} | n_sheep={n_sheep} | {mode}")

    axes[1].plot(t, ep_com_dist, color="tomato", label="COM-to-pen dist (m)")
    axes[1].set_ylabel("COM-to-pen (m)")
    axes[1].set_xlabel("step")
    axes[1].legend(fontsize=8)

    plt.tight_layout()
    if plot_dir:
        fig.savefig(os.path.join(plot_dir, f"ep{ep+1:03d}_{mode}.png"),
                    dpi=100)
        plt.close(fig)
    else:
        plt.show(block=False)
        plt.pause(0.5)


def _print_summary(model_path, n_sheep, episodes, failure_counts):
    """Print the failure-mode histogram and the heuristic diagnosis."""
    print("\n" + "=" * 55)
    print(f" Model : {model_path}")
    print(f" n_sheep : {n_sheep} episodes : {episodes}")
    print("-" * 55)

    total = sum(failure_counts.values())
    if total == 0:
        # Nothing was run (e.g. --episodes 0); the ratios below would
        # otherwise divide by zero.
        print(" No episodes were run; nothing to diagnose.")
        print("=" * 55)
        return

    for mode, cnt in sorted(failure_counts.items(), key=lambda x: -x[1]):
        bar = "█" * cnt  # one cell per episode
        print(f" {mode:<26} {cnt:>3}/{total} {bar}")
    print("-" * 55)

    never_compact = failure_counts.get("NEVER_COMPACT", 0)
    cant_drive = failure_counts.get("COMPACT_CANT_DRIVE", 0)
    partial = sum(v for k, v in failure_counts.items()
                  if k.startswith("PARTIAL"))
    successes = failure_counts.get("SUCCESS", 0)

    print("\n Diagnosis:")
    if never_compact / total > 0.5:
        print(" ► COLLECT problem: dog rarely compacts the flock.")
        print(" → Phase-gate W_DRIVE, increase W_COLLECT, check alignment reward.")
    if cant_drive / total > 0.3:
        print(" ► DRIVE problem: flock compacts but doesn't reach pen.")
        print(" → Check dog alignment, pen direction, W_DRIVE magnitude.")
    if partial / total > 0.3:
        print(" ► PARTIAL problem: some sheep penned, stragglers remain.")
        print(" → Flock splits; need better straggler-chasing behavior.")
    if successes / total > 0.5:
        print(" ► Mostly working! Fine-tune for consistency.")
    print("=" * 55)


def main():
    """Run the policy for ``--episodes`` episodes and report failure modes.

    For every episode the flock radius, COM-to-pen distance, dog position
    and penned count are sampled once per step, the episode is classified
    with ``classify_failure``, a one-line report is printed and (depending
    on the options) a time-series plot is saved or shown. A summary with a
    heuristic diagnosis is printed at the end.
    """
    args = parse_args()
    if args.plot_dir:
        os.makedirs(args.plot_dir, exist_ok=True)
        # Plots are only written to disk, so use a non-interactive backend.
        matplotlib.use("Agg")

    render_mode = "human" if args.render else None
    raw_env = DummyVecEnv([make_env(args.n_sheep, args.max_steps, render_mode)])
    if args.vecnorm:
        # NOTE(review): VecNormalize.load unpickles the file — only pass
        # statistics files from a trusted source.
        env = VecNormalize.load(args.vecnorm, raw_env)
        env.training = False
        env.norm_reward = False
    else:
        env = raw_env

    failure_counts = {}
    try:
        # --seed used to be parsed but ignored; seed the env so that runs
        # are reproducible.
        env.seed(args.seed)
        model = PPO.load(args.model, env=env)

        # The single wrapped environment; the same object for every episode.
        inner = raw_env.envs[0]

        for ep in range(args.episodes):
            obs = env.reset()
            done = False
            step = 0
            ep_radius = []
            ep_com_dist = []
            ep_dog_x = []
            ep_dog_y = []
            ep_n_penned = []
            infos = [{}]

            while not done:
                # Sample the state BEFORE stepping. DummyVecEnv resets a
                # finished environment inside step(), so reading `inner`
                # after the terminal step would record the next episode's
                # initial state instead of this episode's.
                com, radius, _ = inner._flock_stats()
                ep_radius.append(radius)
                ep_com_dist.append(
                    float(np.linalg.norm(com - inner.PEN_CENTER)))
                ep_dog_x.append(float(inner.dog_pos[0]))
                ep_dog_y.append(float(inner.dog_pos[1]))
                ep_n_penned.append(int(inner.penned[:inner.n_sheep].sum()))

                action, _ = model.predict(obs, deterministic=True)
                obs, _, dones, infos = env.step(action)
                done = bool(dones[0])
                step += 1

            # The info dict of the terminal step carries the final outcome.
            info = infos[0]
            n_pen = info.get("n_penned", 0)
            n_sheep = info.get("n_sheep", args.n_sheep)
            success = n_pen == n_sheep
            mode = classify_failure(ep_radius, ep_com_dist, n_pen, n_sheep,
                                    success)
            failure_counts[mode] = failure_counts.get(mode, 0) + 1

            compact_step = next((i for i, r in enumerate(ep_radius)
                                 if r <= COMPACT_RADIUS), None)
            min_radius = min(ep_radius)
            min_com_dist = min(ep_com_dist)
            print(f" ep {ep+1:>3} steps={step:>5} penned={n_pen}/{n_sheep}"
                  f" min_r={min_radius:.1f}m"
                  f" min_com={min_com_dist:.1f}m"
                  f" compact@step={compact_step if compact_step is not None else 'NEVER'}"
                  f" [{mode}]")

            # ── per-episode time-series plot ──────────────────────────────
            if args.plot_dir or (not args.render and ep < 5):
                _plot_episode(ep, n_sheep, mode, ep_radius, ep_com_dist,
                              compact_step, args.plot_dir)
    finally:
        # Release the environment even if an episode raised.
        env.close()

    # ── summary ──────────────────────────────────────────────────────────
    _print_summary(args.model, args.n_sheep, args.episodes, failure_counts)
# Run the diagnostics only when executed as a script, not when imported.
if __name__ == "__main__":
    main()