r"""
Episode-level diagnostics for the herding policy.

Runs N episodes and for each one tracks:
  - flock radius over time
  - COM-to-pen distance over time
  - dog position over time
  - when (if ever) the flock first became compact
  - failure mode classification

Then produces:
  1. Console summary of failure modes
  2. Per-episode time-series plots (radius + com_dist)
  3. Optional live rendering (--render creates the env with render_mode="human")

Usage
-----
    python diagnose.py --model runs/ppo_consolidation/final_model.zip \
                       --vecnorm runs/ppo_consolidation/vecnorm.pkl \
                       --n-sheep 5 --episodes 20

    # Watch the policy live (env created with render_mode="human"):
    python diagnose.py ... --render

    # Save plots to a directory instead of showing interactively:
    python diagnose.py ... --plot-dir debug_plots/
"""

import argparse
import os

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches  # noqa: F401  (currently unused)

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

from herding_env import HerdingEnv

# ── failure mode constants ────────────────────────────────────────────────────
COMPACT_RADIUS = 5.0  # must match DRIVE_GATE_RADIUS in herding_env.py
PEN_CLOSE_DIST = 3.0  # COM within this many metres of the pen counts as "got close"

# Per-step quantities recorded for every episode.
_TRACE_KEYS = ("radius", "com_dist", "dog_x", "dog_y", "n_penned")


def classify_failure(ep_radius, ep_com_dist, n_penned, n_sheep, success):
    """Classify how an episode ended.

    Args:
        ep_radius: per-step flock radius samples (m).
        ep_com_dist: per-step COM-to-pen distance samples (m), aligned with
            ``ep_radius``.
        n_penned: number of sheep penned at the end of the episode.
        n_sheep: number of sheep in the episode.
        success: True if the episode counts as a full success.

    Returns:
        One of ``"SUCCESS"``, ``"NEVER_COMPACT"``, ``"COMPACT_CANT_DRIVE"``,
        ``"DROVE_NO_SHEEP"`` or ``"PARTIAL_<k>of<n>"``.
    """
    if success:
        return "SUCCESS"

    first_compact = next(
        (i for i, r in enumerate(ep_radius) if r <= COMPACT_RADIUS), None
    )
    if first_compact is None:
        # Flock was always too scattered (also covers an empty trace).
        return "NEVER_COMPACT"

    if min(ep_com_dist[first_compact:]) > PEN_CLOSE_DIST:
        return "COMPACT_CANT_DRIVE"  # compacted but never drove to pen
    if n_penned == 0:
        return "DROVE_NO_SHEEP"  # got near pen, nothing went in
    return f"PARTIAL_{n_penned}of{n_sheep}"  # some in, not all


# ── main ─────────────────────────────────────────────────────────────────────

def parse_args():
    """Parse the command-line options for the diagnostics run."""
    p = argparse.ArgumentParser()
    p.add_argument("--model", required=True)
    p.add_argument("--vecnorm", default=None)
    p.add_argument("--n-sheep", type=int, default=5)
    p.add_argument("--episodes", type=int, default=20)
    p.add_argument("--max-steps", type=int, default=4000)
    p.add_argument("--render", action="store_true",
                   help="Show matplotlib animation of the first episode")
    p.add_argument("--plot-dir", default=None,
                   help="Save time-series plots here (one per episode)")
    p.add_argument("--seed", type=int, default=0)
    return p.parse_args()


def make_env(n_sheep, max_steps, render_mode=None):
    """Return a zero-arg factory building a HerdingEnv, as DummyVecEnv expects."""
    def _init():
        return HerdingEnv(n_sheep=n_sheep, max_steps=max_steps,
                          render_mode=render_mode)
    return _init


def _unwrap_inner(env):
    """Return the single HerdingEnv under a DummyVecEnv (optionally VecNormalize-wrapped)."""
    return env.envs[0] if hasattr(env, "envs") else env.venv.envs[0]


def _record_state(inner, trace):
    """Append the inner env's current flock/dog state to the per-episode *trace*."""
    com, radius, _ = inner._flock_stats()
    trace["radius"].append(radius)
    trace["com_dist"].append(float(np.linalg.norm(com - inner.PEN_CENTER)))
    trace["dog_x"].append(float(inner.dog_pos[0]))
    trace["dog_y"].append(float(inner.dog_pos[1]))
    trace["n_penned"].append(int(inner.penned[:inner.n_sheep].sum()))


def _run_episode(env, model, inner):
    """Roll out one deterministic episode.

    Returns ``(trace, steps, info)``:
      - ``trace`` maps each key in ``_TRACE_KEYS`` to per-step samples;
      - ``steps`` is the number of ``env.step`` calls;
      - ``info`` is the info dict returned by the terminal step.

    Sample ``i`` is the state after ``i`` steps, i.e. the state the policy
    acted on. The state after the terminal step is not observable here
    (see the comment in the loop), so ``len(trace["radius"]) == steps``.
    """
    obs = env.reset()
    trace = {key: [] for key in _TRACE_KEYS}
    steps = 0
    info = {}
    done = False
    while not done:
        # Sample BEFORE stepping. DummyVecEnv auto-resets the inner env when
        # an episode ends, so reading it after the terminal step would record
        # the *next* episode's initial state instead of this episode's.
        _record_state(inner, trace)
        action, _ = model.predict(obs, deterministic=True)
        obs, _, dones, infos = env.step(action)
        done = bool(dones[0])
        info = infos[0]
        steps += 1
    return trace, steps, info


def _plot_episode(ep, trace, compact_step, n_sheep, mode, plot_dir):
    """Plot radius and COM-to-pen time series; save to *plot_dir* or show briefly."""
    radius = trace["radius"]
    fig, (ax_r, ax_d) = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
    t = np.arange(len(radius))

    ax_r.plot(t, radius, color="steelblue", label="flock radius (m)")
    ax_r.axhline(COMPACT_RADIUS, color="orange", linestyle="--",
                 label=f"compact threshold ({COMPACT_RADIUS}m)")
    if compact_step is not None:
        ax_r.axvline(compact_step, color="green", linestyle=":", alpha=0.6,
                     label=f"first compact (step {compact_step})")
    ax_r.set_ylabel("radius (m)")
    ax_r.legend(fontsize=8)
    ax_r.set_title(f"ep {ep+1} | n_sheep={n_sheep} | {mode}")

    ax_d.plot(t, trace["com_dist"], color="tomato", label="COM-to-pen dist (m)")
    ax_d.set_ylabel("COM-to-pen (m)")
    ax_d.set_xlabel("step")
    ax_d.legend(fontsize=8)

    plt.tight_layout()
    if plot_dir:
        fig.savefig(os.path.join(plot_dir, f"ep{ep+1:03d}_{mode}.png"), dpi=100)
        plt.close(fig)
    else:
        plt.show(block=False)
        plt.pause(0.5)


def _print_summary(args, failure_counts):
    """Print the failure-mode histogram and a heuristic diagnosis."""
    print("\n" + "=" * 55)
    print(f" Model : {args.model}")
    print(f" n_sheep : {args.n_sheep} episodes : {args.episodes}")
    print("-" * 55)

    total = sum(failure_counts.values())
    if total == 0:
        # --episodes 0: avoid dividing by zero in the ratios below.
        print(" No episodes were run; nothing to diagnose.")
        print("=" * 55)
        return

    for mode, cnt in sorted(failure_counts.items(), key=lambda x: -x[1]):
        bar = "█" * cnt
        print(f" {mode:<26} {cnt:>3}/{total} {bar}")
    print("-" * 55)

    never_compact = failure_counts.get("NEVER_COMPACT", 0)
    cant_drive = failure_counts.get("COMPACT_CANT_DRIVE", 0)
    partial = sum(v for k, v in failure_counts.items() if k.startswith("PARTIAL"))
    successes = failure_counts.get("SUCCESS", 0)

    print("\n Diagnosis:")
    if never_compact / total > 0.5:
        print(" ► COLLECT problem: dog rarely compacts the flock.")
        print(" → Phase-gate W_DRIVE, increase W_COLLECT, check alignment reward.")
    if cant_drive / total > 0.3:
        print(" ► DRIVE problem: flock compacts but doesn't reach pen.")
        print(" → Check dog alignment, pen direction, W_DRIVE magnitude.")
    if partial / total > 0.3:
        print(" ► PARTIAL problem: some sheep penned, stragglers remain.")
        print(" → Flock splits; need better straggler-chasing behavior.")
    if successes / total > 0.5:
        print(" ► Mostly working! Fine-tune for consistency.")
    print("=" * 55)


def main():
    """Run the diagnostic episodes, print per-episode lines, plot, and summarize."""
    args = parse_args()

    if args.plot_dir:
        os.makedirs(args.plot_dir, exist_ok=True)
        # Headless backend when only saving files; switching after pyplot
        # import is supported by matplotlib.use().
        matplotlib.use("Agg")

    render_mode = "human" if args.render else None
    raw_env = DummyVecEnv([make_env(args.n_sheep, args.max_steps, render_mode)])
    if args.vecnorm:
        env = VecNormalize.load(args.vecnorm, raw_env)
        env.training = False      # freeze running obs statistics
        env.norm_reward = False
    else:
        env = raw_env

    model = PPO.load(args.model, env=env)
    # Seed after loading: PPO.load may itself seed the env with the training seed.
    env.seed(args.seed)
    inner = _unwrap_inner(env)

    failure_counts = {}
    # Full per-episode traces; kept for post-hoc analysis (not consumed yet).
    all_ep_data = []

    try:
        for ep in range(args.episodes):
            trace, step, info = _run_episode(env, model, inner)

            n_pen = info.get("n_penned", 0)
            n_sheep = info.get("n_sheep", args.n_sheep)
            success = n_pen == n_sheep
            mode = classify_failure(trace["radius"], trace["com_dist"],
                                    n_pen, n_sheep, success)
            failure_counts[mode] = failure_counts.get(mode, 0) + 1

            compact_step = next((i for i, r in enumerate(trace["radius"])
                                 if r <= COMPACT_RADIUS), None)
            min_radius = min(trace["radius"])
            min_com_dist = min(trace["com_dist"])

            print(f" ep {ep+1:>3} steps={step:>5} penned={n_pen}/{n_sheep}"
                  f" min_r={min_radius:.1f}m"
                  f" min_com={min_com_dist:.1f}m"
                  f" compact@step={compact_step if compact_step is not None else 'NEVER'}"
                  f" [{mode}]")

            all_ep_data.append(dict(
                ep=ep,
                radius=trace["radius"], com_dist=trace["com_dist"],
                dog_x=trace["dog_x"], dog_y=trace["dog_y"],
                n_penned=trace["n_penned"],
                steps=step, mode=mode, success=success,
            ))

            # ── per-episode time-series plot ──────────────────────────────
            if args.plot_dir or (not args.render and ep < 5):
                _plot_episode(ep, trace, compact_step, n_sheep, mode,
                              args.plot_dir)
    finally:
        env.close()

    # ── summary ──────────────────────────────────────────────────────────
    _print_summary(args, failure_counts)


if __name__ == "__main__":
    main()