# Provenance: contents of training/visualize.py as created by patch
# e0426bf320669f288467700146a81a418c7d6923
# (Johnny Fernandes, Fri, 24 Apr 2026 16:46:02 +0100,
#  "[PATCH] Sheep training flock of 10 fix?").
"""
Single-episode visualization for the herding policy.

Outputs (all saved to --out-dir):
  trajectory.png — full field view: dog path + every sheep path
  timeseries.png — radius, per-sheep pen distance, action magnitude, reward
  episode.gif — animated replay (slow enough to read)

Run with no model to watch a RANDOM policy (useful baseline):
  python visualize.py --random --n-sheep 3 --out-dir vis_random/

Usage:
  python visualize.py \\
      --model runs/ppo_consolidation/final_model.zip \\
      --vecnorm runs/ppo_consolidation/vecnorm.pkl \\
      --n-sheep 3 --out-dir vis_out/
"""

import argparse
import os
import math
import numpy as np
import matplotlib
# The backend must be selected before pyplot is imported; "Agg" renders to
# files only, so the script works on machines without a display.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.animation as animation
from matplotlib.collections import LineCollection
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from herding_env import HerdingEnv


# ── colours ──────────────────────────────────────────────────────────────────
SHEEP_COLORS = [
    "#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00",
    "#a65628", "#f781bf", "#999999", "#66c2a5", "#fc8d62",
]
DOG_COLOR = "#4e342e"
PEN_COLOR = "#ffe082"
FIELD_COLOR = "#dcedc8"
FENCE_COLOR = "#795548"
PENNED_COLOR = "#ff69b4"   # marker colour for a sheep once it is penned (GIF)

# Number of past steps drawn as a trail behind each agent in the GIF.
TRAIL_STEPS = 30


def make_env(n_sheep, max_steps, seed=42):
    """Return a zero-argument factory (as DummyVecEnv expects) for HerdingEnv.

    The factory builds the env and resets it once with `seed` so that the
    env's random state is seeded before the vectorised wrapper takes over.
    """
    def _init():
        env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps)
        env.reset(seed=seed)
        return env
    return _init


def _result_label(hist):
    """Human-readable outcome string shared by the plots and the console."""
    if hist["success"]:
        return "SUCCESS"
    return f"FAIL ({hist['n_penned']}/{hist['n_sheep']} penned)"


def run_episode(model, env, n_sheep, max_steps):
    """Run one deterministic episode; return recorded history.

    Args:
        model: object with ``predict(obs, deterministic=True)``, or None to
            sample random actions from ``env.action_space``.
        env: a vectorised env holding exactly one HerdingEnv (``env.envs[0]``).
        n_sheep: number of sheep tracked in the history.
        max_steps: hard cap on the number of steps taken here, in case the
            env never reports ``done``.

    Returns:
        dict with per-step lists ``dog_xs``, ``dog_ys``, ``radii``,
        ``action_mags``, ``rewards``; per-sheep lists of per-step values
        ``sheep_xs``, ``sheep_ys``, ``pen_dists``; ``penned_at`` (per sheep,
        the 0-based index into the per-step lists at which the sheep was first
        seen penned, or None); and the scalars ``n_penned``, ``n_sheep``,
        ``success``, ``steps``. Every per-step list has length ``steps``.
    """
    obs = env.reset()
    inner = env.envs[0]

    dog_xs, dog_ys = [], []
    sheep_xs = [[] for _ in range(n_sheep)]
    sheep_ys = [[] for _ in range(n_sheep)]
    radii = []
    pen_dists = [[] for _ in range(n_sheep)]
    action_mags = []
    rewards = []
    # 0-based history index at which each sheep was first observed penned.
    # Stored as an index (not a 1-based step count) because every consumer
    # uses it to index the history lists or the 0-based time axis.
    penned_at = [None] * n_sheep

    info = {}
    done = False
    step = 0
    while not done and step < max_steps:
        if model is None:
            action = env.action_space.sample()[np.newaxis]
        else:
            action, _ = model.predict(obs, deterministic=True)

        obs, reward, dones, infos = env.step(action)
        done = bool(dones[0])
        info = infos[0]
        idx = step          # index of the sample recorded for this step
        step += 1

        rewards.append(float(reward[0]))
        action_mags.append(float(np.linalg.norm(action[0])))

        if done and idx > 0:
            # DummyVecEnv resets the underlying env inside step() as soon as
            # the episode ends, so `inner` now holds the *next* episode's
            # initial state. Reading it would teleport every agent in the
            # plots; hold the last genuine sample instead.
            dog_xs.append(dog_xs[-1])
            dog_ys.append(dog_ys[-1])
            radii.append(radii[-1])
            for i in range(n_sheep):
                sheep_xs[i].append(sheep_xs[i][-1])
                sheep_ys[i].append(sheep_ys[i][-1])
                pen_dists[i].append(pen_dists[i][-1])
            break

        dog_xs.append(float(inner.dog_pos[0]))
        dog_ys.append(float(inner.dog_pos[1]))

        _, radius, _ = inner._flock_stats()
        radii.append(float(radius))

        for i in range(n_sheep):
            pos = inner.sheep_pos[i]
            sheep_xs[i].append(float(pos[0]))
            sheep_ys[i].append(float(pos[1]))
            pen_dists[i].append(float(np.linalg.norm(pos - inner.PEN_CENTER)))
            if inner.penned[i] and penned_at[i] is None:
                penned_at[i] = idx

    n_penned = info.get("n_penned", 0)
    success = n_penned == n_sheep

    if success and step > 0:
        # The terminal state is not observable (see auto-reset note above),
        # so any sheep not yet marked must have been penned on the last step.
        last = step - 1
        penned_at = [last if pa is None else pa for pa in penned_at]

    return dict(
        dog_xs=dog_xs, dog_ys=dog_ys,
        sheep_xs=sheep_xs, sheep_ys=sheep_ys,
        radii=radii, pen_dists=pen_dists,
        action_mags=action_mags, rewards=rewards,
        penned_at=penned_at,
        n_penned=n_penned, n_sheep=n_sheep,
        success=success, steps=step,
    )


# ── plot helpers ─────────────────────────────────────────────────────────────

def draw_field(ax):
    """Draw the static scene on `ax`: field background, fence and pen.

    NOTE(review): the geometry below (30x30 fence centred on the origin,
    3x7 pen at (10, -15)) is hard-coded here; presumably it mirrors
    HerdingEnv — confirm if the env's layout changes.
    """
    ax.set_xlim(-16, 16)
    ax.set_ylim(-16, 16)
    ax.set_aspect("equal")
    ax.set_facecolor(FIELD_COLOR)
    ax.add_patch(mpatches.Rectangle((-15, -15), 30, 30,
                                    fill=False, edgecolor=FENCE_COLOR, lw=2))
    ax.add_patch(mpatches.Rectangle((10, -15), 3, 7,
                                    facecolor=PEN_COLOR, edgecolor=FENCE_COLOR, lw=2))
    ax.text(11.5, -11.5, "pen", ha="center", va="center",
            fontsize=8, color=FENCE_COLOR)


def faded_path(ax, xs, ys, color, lw=1.5, label=None):
    """Draw a path with alpha fading from start (transparent) to end (opaque).

    A path with fewer than two points has no segments and draws nothing, but
    the legend entry is still registered when `label` is given.
    """
    if label:
        # LineCollection segments have per-segment colours and make a poor
        # legend handle, so register an empty proxy line for the legend.
        ax.plot([], [], color=color, lw=lw, label=label)
    if len(xs) < 2:
        return
    points = np.array([xs, ys]).T.reshape(-1, 1, 2)
    segs = np.concatenate([points[:-1], points[1:]], axis=1)
    rgb = matplotlib.colors.to_rgb(color)
    colors = [(*rgb, a) for a in np.linspace(0.15, 1.0, len(segs))]
    ax.add_collection(LineCollection(segs, colors=colors, linewidth=lw))


# ── main plots ────────────────────────────────────────────────────────────────

def plot_trajectory(hist, out_path):
    """Save the full-field trajectory figure for one episode to `out_path`.

    Markers: circle = sheep start, star = where the sheep was penned (or its
    last position), square = dog start, diamond = dog end.
    """
    fig, ax = plt.subplots(figsize=(7, 7))
    draw_field(ax)

    # Sheep paths
    for i in range(hist["n_sheep"]):
        c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
        xs, ys = hist["sheep_xs"][i], hist["sheep_ys"][i]
        faded_path(ax, xs, ys, c, lw=1.2, label=f"sheep {i+1}")
        ax.plot(xs[0], ys[0], "o", color=c, ms=7, zorder=4)
        pa = hist["penned_at"][i]
        # penned_at is a 0-based history index, so it can index xs directly.
        end = pa if pa is not None else -1
        ax.plot(xs[end], ys[end], "*", color=c, ms=11, zorder=5)

    # Dog path
    faded_path(ax, hist["dog_xs"], hist["dog_ys"], DOG_COLOR, lw=2.0, label="dog")
    ax.plot(hist["dog_xs"][0], hist["dog_ys"][0], "s", color=DOG_COLOR, ms=10, zorder=5)
    ax.plot(hist["dog_xs"][-1], hist["dog_ys"][-1], "D", color=DOG_COLOR, ms=10, zorder=5)

    ax.set_title(f"Trajectory — {_result_label(hist)} — {hist['steps']} steps", fontsize=12)
    ax.legend(loc="upper left", fontsize=8)
    plt.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)
    print(f" saved {out_path}")


def plot_timeseries(hist, out_path):
    """Save a 4-panel per-step figure (radius, pen distance, action, reward)."""
    t = np.arange(hist["steps"])
    fig, axes = plt.subplots(4, 1, figsize=(12, 10), sharex=True)

    # 1. Flock radius
    axes[0].plot(t, hist["radii"], color="steelblue")
    axes[0].axhline(5.0, color="orange", ls="--", lw=1, label="compact threshold (5m)")
    axes[0].set_ylabel("flock radius (m)")
    axes[0].legend(fontsize=8)
    axes[0].set_title("Flock radius — goal: get below 5m")

    # 2. Per-sheep distance to pen
    for i in range(hist["n_sheep"]):
        c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
        axes[1].plot(t, hist["pen_dists"][i], color=c, lw=1, label=f"sheep {i+1}")
        pa = hist["penned_at"][i]
        if pa is not None:
            # Same 0-based axis as `t`, so the marker lines up with the curve.
            axes[1].axvline(pa, color=c, ls=":", lw=1)
    axes[1].set_ylabel("dist to pen (m)")
    if hist["n_sheep"] > 0:
        axes[1].legend(fontsize=7, ncol=min(hist["n_sheep"], 5))
    axes[1].set_title("Per-sheep distance to pen — goal: all reach 0")

    # 3. Action magnitude (how fast dog is moving)
    axes[2].plot(t, hist["action_mags"], color="tomato", lw=1)
    axes[2].axhline(1.0, color="gray", ls="--", lw=1, label="max")
    axes[2].set_ylabel("action ||(vx,vy)||")
    axes[2].set_ylim(0, 1.5)
    axes[2].set_title("Dog action magnitude — 0=stopped, 1=full speed")
    axes[2].legend(fontsize=8)

    # 4. Reward per step
    axes[3].plot(t, hist["rewards"], color="purple", lw=1, alpha=0.7)
    axes[3].axhline(0, color="black", lw=0.5)
    axes[3].set_ylabel("reward")
    axes[3].set_xlabel("step")
    axes[3].set_title("Reward per step")

    fig.suptitle(f"n_sheep={hist['n_sheep']} {_result_label(hist)} {hist['steps']} steps",
                 fontsize=13)
    plt.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)
    print(f" saved {out_path}")


def save_gif(hist, out_path, fps=15, skip=5):
    """Animated replay, every `skip` steps.

    The final recorded step is always rendered, even when it is not a
    multiple of `skip`, so the GIF ends on the episode's outcome.

    Raises:
        ValueError: if `fps` or `skip` is not a positive integer.
    """
    if fps < 1:
        raise ValueError(f"fps must be >= 1, got {fps}")
    if skip < 1:
        raise ValueError(f"skip must be >= 1, got {skip}")

    n = hist["n_sheep"]
    steps = hist["steps"]
    if steps < 1:
        print(f" no steps recorded; skipping {out_path}")
        return

    last = steps - 1
    idxs = list(range(0, steps, skip))
    if idxs[-1] != last:
        idxs.append(last)

    fig, ax = plt.subplots(figsize=(6, 6))

    def _frame(k):
        ax.clear()
        draw_field(ax)
        t = idxs[k]
        s0 = max(0, t - TRAIL_STEPS)

        for i in range(n):
            c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
            ax.plot(hist["sheep_xs"][i][s0:t+1],
                    hist["sheep_ys"][i][s0:t+1],
                    color=c, lw=0.8, alpha=0.5)
            pa = hist["penned_at"][i]
            is_penned = pa is not None and t >= pa
            ax.plot(hist["sheep_xs"][i][t], hist["sheep_ys"][i][t],
                    "o", color=PENNED_COLOR if is_penned else c, ms=10, zorder=4,
                    markeredgecolor="#555", markeredgewidth=1)

        ax.plot(hist["dog_xs"][s0:t+1], hist["dog_ys"][s0:t+1],
                color=DOG_COLOR, lw=1.5, alpha=0.6)
        ax.plot(hist["dog_xs"][t], hist["dog_ys"][t],
                "s", color=DOG_COLOR, ms=13, zorder=5,
                markeredgecolor="black", markeredgewidth=1.5)

        # The env-reported count is authoritative on the last frame; earlier
        # frames derive the running count from the recorded penning indices.
        if t == last:
            penned = hist["n_penned"]
        else:
            penned = sum(1 for pa in hist["penned_at"] if pa is not None and pa <= t)

        r = hist["radii"][t]
        ax.set_title(f"step {t}/{steps} radius={r:.1f}m "
                     f"penned={penned}/{n}",
                     fontsize=10)

    ani = animation.FuncAnimation(fig, _frame, frames=len(idxs), interval=1000 // fps)
    ani.save(out_path, writer="pillow", fps=fps)
    plt.close(fig)
    print(f" saved {out_path}")


# ── entry point ───────────────────────────────────────────────────────────────

def parse_args():
    """Parse the command-line options for the visualizer."""
    p = argparse.ArgumentParser()
    p.add_argument("--model", default=None,
                   help="Model .zip (omit for random policy)")
    p.add_argument("--vecnorm", default=None,
                   help="VecNormalize statistics .pkl saved alongside the model")
    p.add_argument("--n-sheep", type=int, default=3)
    p.add_argument("--max-steps", type=int, default=2000)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--out-dir", default="vis_out")
    p.add_argument("--random", action="store_true",
                   help="Use random policy (baseline comparison)")
    p.add_argument("--gif-fps", type=int, default=15)
    p.add_argument("--gif-skip", type=int, default=5,
                   help="Render every Nth step in the GIF")
    p.add_argument("--no-gif", action="store_true")
    return p.parse_args()


def main():
    """Run one episode with the chosen policy and write all visualizations."""
    args = parse_args()
    os.makedirs(args.out_dir, exist_ok=True)

    raw = DummyVecEnv([make_env(args.n_sheep, args.max_steps, args.seed)])
    env = raw
    try:
        if args.random or args.model is None:
            print("Using RANDOM policy")
            model = None
            # Seed the action sampler so the random baseline honours --seed.
            env.action_space.seed(args.seed)
        else:
            if args.vecnorm:
                # NOTE: VecNormalize.load unpickles the file — only load
                # statistics from a source you trust.
                env = VecNormalize.load(args.vecnorm, raw)
                # Evaluation mode: freeze running statistics and report
                # un-normalised rewards.
                env.training = False
                env.norm_reward = False
            model = PPO.load(args.model, env=env)
            print(f"Loaded model: {args.model}")

        print(f"Running episode n_sheep={args.n_sheep} seed={args.seed} ...")
        hist = run_episode(model, env, args.n_sheep, args.max_steps)
    finally:
        # Release the env even if loading or the rollout fails.
        env.close()

    print(f"Episode done: {_result_label(hist)} steps={hist['steps']}")
    print(f" min radius : {min(hist['radii']):.2f} m")
    print(f" mean reward: {np.mean(hist['rewards']):.4f}")
    print(f" mean action: {np.mean(hist['action_mags']):.3f}")

    plot_trajectory(hist, os.path.join(args.out_dir, "trajectory.png"))
    plot_timeseries(hist, os.path.join(args.out_dir, "timeseries.png"))
    if not args.no_gif:
        save_gif(hist, os.path.join(args.out_dir, "episode.gif"),
                 fps=args.gif_fps, skip=args.gif_skip)

    print(f"\nAll outputs saved to {args.out_dir}/")


if __name__ == "__main__":
    main()