"""
All visualization for the herding policy: trajectory plots, timeseries plots,
success-rate bar chart, and animated GIFs.

Used both by train.py (auto-rendered after each curriculum stage) and as a CLI
to render a fresh episode against a saved model.

CLI usage:
    python viz.py --run-dir runs/v1 --n-sheep 5
    python viz.py --run-dir runs/v1 --n-sheep 10 --no-gif
    python viz.py --model runs/v1/final_model.zip --vecnorm runs/v1/vecnorm.pkl \\
                  --n-sheep 3 --out-dir vis_v1_3sheep
"""
import argparse
import os
import json
from copy import deepcopy

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.animation as animation
from matplotlib.collections import LineCollection
import numpy as np

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

from herding_env import HerdingEnv


# ── Palette ──────────────────────────────────────────────────────────────────
SHEEP_COLORS = [
    "#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00",
    "#a65628", "#f781bf", "#999999", "#66c2a5", "#fc8d62",
]
DOG_COLOR = "#4e342e"


# ── Common drawing primitives ────────────────────────────────────────────────
def draw_field(ax):
    """Draw the 30x30 field (fence at +-15) and the pen rectangle on *ax*.

    The pen drawn here spans x in [10, 13], y in [-15, -8]; it is purely
    cosmetic. NOTE(review): keep in sync with HerdingEnv's pen geometry.
    """
    ax.set_xlim(-16, 16)
    ax.set_ylim(-16, 16)
    ax.set_aspect("equal")
    ax.set_facecolor("#dcedc8")
    ax.add_patch(mpatches.Rectangle(
        (-15, -15), 30, 30, fill=False, edgecolor="#795548", lw=2))
    ax.add_patch(mpatches.Rectangle(
        (10, -15), 3, 7, facecolor="#ffe082", edgecolor="#795548", lw=2))
    ax.text(11.5, -11.5, "pen", ha="center", va="center",
            fontsize=8, color="#795548")


def faded_path(ax, xs, ys, color, lw=1.5, label=None):
    """Draw the polyline (xs, ys) with alpha ramping from 0.15 to 1.0.

    Older segments are fainter, so direction of travel is visible. Paths with
    fewer than two points are skipped. When *label* is given, an invisible
    proxy line is added so the path appears in the legend (LineCollection
    itself is not picked up by ``ax.legend``).
    """
    if len(xs) < 2:
        return
    points = np.array([xs, ys]).T.reshape(-1, 1, 2)
    segs = np.concatenate([points[:-1], points[1:]], axis=1)
    alphas = np.linspace(0.15, 1.0, len(segs))
    colors = [(*matplotlib.colors.to_rgb(color), a) for a in alphas]
    ax.add_collection(LineCollection(segs, colors=colors, linewidth=lw))
    if label:
        ax.plot([], [], color=color, lw=lw, label=label)


# ── Episode rollout ──────────────────────────────────────────────────────────
def make_eval_env(n_sheep, seed, max_steps, reward_cfg=None):
    """Return a zero-arg factory that builds a HerdingEnv seeded with *seed*.

    The factory calls ``env.reset(seed=seed)`` once at construction so the
    env's RNG is seeded; the caller performs the episode's own reset.
    """
    def _init():
        env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps,
                         reward_cfg=reward_cfg)
        env.reset(seed=seed)
        return env
    return _init


def _reset_env(env):
    """Reset *env* and return only the observation.

    Handles both the Gymnasium ``(obs, info)`` and legacy gym ``obs`` returns.
    """
    out = env.reset()
    return out[0] if isinstance(out, tuple) else out


def _step_env(env, action):
    """Step *env* and return ``(obs, reward, done, info)``.

    Accepts both the Gymnasium 5-tuple (``terminated``/``truncated``) and the
    legacy gym 4-tuple step API; ``done`` is ``terminated or truncated``.
    """
    out = env.step(action)
    if len(out) == 5:
        obs, reward, terminated, truncated, info = out
        return obs, reward, bool(terminated or truncated), info
    obs, reward, done, info = out
    return obs, reward, bool(done), info


def run_and_record(model, vn_template, n_sheep, max_steps,
                   reward_cfg=None, seed=42):
    """Run one deterministic episode and return full trajectory history.

    The env is stepped directly rather than through a ``DummyVecEnv``: a
    VecEnv auto-resets on ``done``, so reading env state after the final
    step would capture the *next* episode's start instead of the terminal
    state (and miss the sheep penned on that last step). Observations are
    normalised with ``vn_template.normalize_obs`` so they match training
    exactly (same running stats, clipping and epsilon); the template's
    statistics are only read, never updated.

    Args:
        model: policy exposing ``predict(obs, deterministic=True)``.
        vn_template: the ``VecNormalize`` used in training (or loaded from
            ``vecnorm.pkl``).
        n_sheep, max_steps, reward_cfg: forwarded to ``HerdingEnv``.
        seed: seeds the env's RNG at construction (see ``make_eval_env``).

    Returns:
        dict with per-step lists (one entry per env step, index 0 = state
        after step 1): ``dog_xs``, ``dog_ys``, ``sheep_xs[i]``,
        ``sheep_ys[i]``, ``sheep_penned[i]``, ``radii``, ``pen_dists[i]``,
        ``action_mags``, ``rewards``; plus ``penned_at[i]`` (1-based step at
        which sheep i was first observed penned, or None — index per-step
        lists with ``penned_at[i] - 1``), ``n_penned`` (from the final info
        dict), ``n_sheep``, ``success`` and ``steps``.
    """
    env = make_eval_env(n_sheep, seed, max_steps, reward_cfg)()

    dog_xs, dog_ys = [], []
    sheep_xs = [[] for _ in range(n_sheep)]
    sheep_ys = [[] for _ in range(n_sheep)]
    sheep_penned = [[] for _ in range(n_sheep)]
    radii = []
    pen_dists = [[] for _ in range(n_sheep)]
    action_mags = []
    rewards = []
    penned_at = [None] * n_sheep
    info = {}
    step = 0

    try:
        # The env was seeded at construction; the episode comes from this
        # following unseeded reset (the same call sequence a DummyVecEnv
        # would perform), so each seed renders the same episode as before.
        obs = _reset_env(env)
        done = False
        while not done:
            action, _ = model.predict(vn_template.normalize_obs(obs),
                                      deterministic=True)
            obs, reward, done, info = _step_env(env, action)
            step += 1

            # No auto-reset has happened, so this is the post-step state,
            # including the terminal one.
            dog_xs.append(float(env.dog_pos[0]))
            dog_ys.append(float(env.dog_pos[1]))
            _, radius, _ = env._flock_stats()
            radii.append(radius)
            rewards.append(float(reward))
            action_mags.append(float(np.linalg.norm(action)))
            for i in range(n_sheep):
                pos = env.sheep_pos[i]
                is_penned = bool(env.penned[i])
                sheep_xs[i].append(float(pos[0]))
                sheep_ys[i].append(float(pos[1]))
                sheep_penned[i].append(is_penned)
                pen_dists[i].append(
                    float(np.linalg.norm(pos - env.PEN_CENTER)))
                if is_penned and penned_at[i] is None:
                    penned_at[i] = step
    finally:
        env.close()

    n_penned = info.get("n_penned", 0)
    return dict(
        dog_xs=dog_xs, dog_ys=dog_ys,
        sheep_xs=sheep_xs, sheep_ys=sheep_ys, sheep_penned=sheep_penned,
        radii=radii, pen_dists=pen_dists,
        action_mags=action_mags, rewards=rewards,
        penned_at=penned_at, n_penned=n_penned, n_sheep=n_sheep,
        success=n_penned == n_sheep, steps=step,
    )


# ── Static plots ─────────────────────────────────────────────────────────────
def plot_trajectory(hist, out_path):
    """Save a top-down trajectory plot of *hist* (from run_and_record).

    Circles mark sheep starts; stars mark where each sheep was penned (or its
    final position if never penned). The dog's start is a square, its end a
    diamond.
    """
    fig, ax = plt.subplots(figsize=(7, 7))
    draw_field(ax)

    for i in range(hist["n_sheep"]):
        c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
        xs, ys = hist["sheep_xs"][i], hist["sheep_ys"][i]
        faded_path(ax, xs, ys, c, lw=1.2, label=f"sheep {i+1}")
        ax.plot(xs[0], ys[0], "o", color=c, ms=7, zorder=4)
        # penned_at is a 1-based step count; per-step lists are 0-based.
        penned_step = hist["penned_at"][i]
        end = penned_step - 1 if penned_step is not None else -1
        ax.plot(xs[end], ys[end], "*", color=c, ms=11, zorder=5)

    faded_path(ax, hist["dog_xs"], hist["dog_ys"], DOG_COLOR,
               lw=2.0, label="dog")
    ax.plot(hist["dog_xs"][0], hist["dog_ys"][0], "s",
            color=DOG_COLOR, ms=10, zorder=5)
    ax.plot(hist["dog_xs"][-1], hist["dog_ys"][-1], "D",
            color=DOG_COLOR, ms=10, zorder=5)

    result = ("SUCCESS" if hist["success"]
              else f"FAIL ({hist['n_penned']}/{hist['n_sheep']})")
    ax.set_title(f"n={hist['n_sheep']}  {result}  {hist['steps']} steps",
                 fontsize=12)
    ax.legend(loc="upper left", fontsize=8)
    plt.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)


def plot_timeseries(hist, out_path):
    """Save a 4-panel timeseries figure for *hist* (from run_and_record).

    Panels: flock radius, per-sheep distance to pen (dotted verticals at each
    penning step), dog action magnitude, and per-step reward. The x-axis is
    the 1-based step number, the same convention as ``penned_at``.
    """
    # Sample k in the per-step lists is the state after step k+1.
    t = np.arange(1, hist["steps"] + 1)
    fig, axes = plt.subplots(4, 1, figsize=(12, 10), sharex=True)

    axes[0].plot(t, hist["radii"], color="steelblue")
    axes[0].axhline(5.0, color="orange", ls="--", lw=1, label="compact (5m)")
    axes[0].set_ylabel("flock radius (m)")
    axes[0].legend(fontsize=8)
    axes[0].set_title("Flock radius")

    for i in range(hist["n_sheep"]):
        c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
        axes[1].plot(t, hist["pen_dists"][i], color=c, lw=1,
                     label=f"sheep {i+1}")
        if hist["penned_at"][i] is not None:
            axes[1].axvline(hist["penned_at"][i], color=c, ls=":", lw=1)
    axes[1].set_ylabel("dist to pen (m)")
    axes[1].legend(fontsize=7, ncol=min(hist["n_sheep"], 5))
    axes[1].set_title("Per-sheep distance to pen")

    axes[2].plot(t, hist["action_mags"], color="tomato", lw=1)
    axes[2].axhline(1.0, color="gray", ls="--", lw=1, label="max")
    axes[2].set_ylabel("action ||(vx,vy)||")
    axes[2].set_ylim(0, 1.5)
    axes[2].set_title("Dog action magnitude")
    axes[2].legend(fontsize=8)

    axes[3].plot(t, hist["rewards"], color="purple", lw=1, alpha=0.7)
    axes[3].axhline(0, color="black", lw=0.5)
    axes[3].set_ylabel("reward")
    axes[3].set_xlabel("step")
    axes[3].set_title("Reward per step")

    result = ("SUCCESS" if hist["success"]
              else f"FAIL ({hist['n_penned']}/{hist['n_sheep']})")
    fig.suptitle(f"n_sheep={hist['n_sheep']}  {result}  {hist['steps']} steps",
                 fontsize=13)
    plt.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)


def plot_success_rate(stage_results, out_path):
    """Save a bar chart of success rate per sheep count.

    Args:
        stage_results: iterable of dicts with ``n_sheep`` and ``sr`` (success
            rate as a fraction; it is multiplied by 100 for display).
        out_path: destination image path.
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    ns = [r["n_sheep"] for r in stage_results]
    srs = [r["sr"] * 100 for r in stage_results]
    bars = ax.bar(ns, srs, color="steelblue", edgecolor="white")
    ax.set_xlabel("Sheep count")
    ax.set_ylabel("Success rate (%)")
    ax.set_ylim(0, 105)
    ax.axhline(90, color="orange", ls="--", lw=1, label="90% target")
    for bar, sr in zip(bars, srs):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 1,
                f"{sr:.0f}%", ha="center", fontsize=9)
    ax.legend()
    ax.set_title("Evaluation success rate per sheep count")
    plt.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)


# ── Animated GIF ─────────────────────────────────────────────────────────────
def save_episode_gif(hist, out_path, fps=20, skip=3):
    """Render hist as an animated GIF.

    `skip` keeps every Nth frame (smaller file); the final step is always
    included. Penned sheep are drawn in deep pink, and the dog leaves a trail.
    """
    n_sheep = hist["n_sheep"]
    frames = list(range(0, hist["steps"], max(1, skip)))
    if frames[-1] != hist["steps"] - 1:
        frames.append(hist["steps"] - 1)

    fig, ax = plt.subplots(figsize=(6, 6))
    draw_field(ax)
    title = ax.text(0, 16.5, "", ha="center", fontsize=11)

    dog_marker, = ax.plot([], [], "s", color=DOG_COLOR, ms=12,
                          markeredgecolor="black", markeredgewidth=1.5,
                          zorder=5)
    sheep_markers = []
    for i in range(n_sheep):
        c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
        m, = ax.plot([], [], "o", color=c, ms=10,
                     markeredgecolor="#333", markeredgewidth=1, zorder=4)
        sheep_markers.append(m)
    dog_trail, = ax.plot([], [], color=DOG_COLOR, lw=1.0, alpha=0.5)

    def update(k):
        title.set_text(
            f"n={n_sheep}  step {k+1}/{hist['steps']}  "
            f"penned {sum(hist['sheep_penned'][i][k] for i in range(n_sheep))}/{n_sheep}")
        dog_marker.set_data([hist["dog_xs"][k]], [hist["dog_ys"][k]])
        dog_trail.set_data(hist["dog_xs"][:k+1], hist["dog_ys"][:k+1])
        for i, m in enumerate(sheep_markers):
            m.set_data([hist["sheep_xs"][i][k]], [hist["sheep_ys"][i][k]])
            penned = hist["sheep_penned"][i][k]
            m.set_color("deeppink" if penned
                        else SHEEP_COLORS[i % len(SHEEP_COLORS)])
        return [title, dog_marker, dog_trail, *sheep_markers]

    anim = animation.FuncAnimation(
        fig, update, frames=frames, interval=1000 / fps, blit=False)
    anim.save(out_path, writer=animation.PillowWriter(fps=fps), dpi=80)
    plt.close(fig)


# ── CLI ──────────────────────────────────────────────────────────────────────
def _resolve_paths(args):
    """Return (model_path, vecnorm_path, config_path) from parsed CLI args.

    --run-dir takes precedence and implies the standard file names inside it.
    """
    if args.run_dir:
        model_path = os.path.join(args.run_dir, "final_model.zip")
        vn_path = os.path.join(args.run_dir, "vecnorm.pkl")
        cfg_path = os.path.join(args.run_dir, "config.json")
    else:
        model_path = args.model
        vn_path = args.vecnorm
        cfg_path = args.config
    return model_path, vn_path, cfg_path


def main():
    """CLI entry point: roll out one episode and save PNGs (and a GIF)."""
    p = argparse.ArgumentParser(
        description="Render trajectory + timeseries + GIF for a saved policy.")
    p.add_argument("--run-dir", type=str, default=None,
                   help="Run directory containing final_model.zip + vecnorm.pkl + config.json")
    p.add_argument("--model", type=str, default=None)
    p.add_argument("--vecnorm", type=str, default=None)
    p.add_argument("--config", type=str, default=None)
    p.add_argument("--n-sheep", type=int, default=3)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--max-steps", type=int, default=2500)
    p.add_argument("--out-dir", type=str, default=None)
    p.add_argument("--no-gif", action="store_true",
                   help="Skip the animated GIF (PNG-only is faster).")
    p.add_argument("--gif-fps", type=int, default=20)
    p.add_argument("--gif-skip", type=int, default=3)
    args = p.parse_args()

    model_path, vn_path, cfg_path = _resolve_paths(args)
    if not (model_path and vn_path):
        p.error("either --run-dir or both --model and --vecnorm are required")

    # Only config keys that name HerdingEnv attributes are forwarded as the
    # env's reward_cfg; the rest (training hyper-parameters) are ignored.
    rcfg = None
    if cfg_path and os.path.exists(cfg_path):
        with open(cfg_path, encoding="utf-8") as f:
            cfg = json.load(f)
        rcfg = {k: v for k, v in cfg.items() if hasattr(HerdingEnv, k)}

    out_dir = args.out_dir or os.path.join(
        os.path.dirname(os.path.abspath(model_path)),
        f"vis_{args.n_sheep}s")
    os.makedirs(out_dir, exist_ok=True)

    print(f"Loading model:   {model_path}")
    print(f"Loading vecnorm: {vn_path}")
    model = PPO.load(model_path, device="cpu")
    # VecNormalize.load needs a venv to attach to; it is only used here as a
    # holder for the saved observation statistics.
    raw = DummyVecEnv([make_eval_env(args.n_sheep, args.seed,
                                     args.max_steps, rcfg)])
    vn = VecNormalize.load(vn_path, raw)

    print(f"Rolling out n_sheep={args.n_sheep} (seed={args.seed})...")
    try:
        hist = run_and_record(model, vn, args.n_sheep, args.max_steps,
                              reward_cfg=rcfg, seed=args.seed)
    finally:
        vn.close()
    result = "SUCCESS" if hist["success"] else f"FAIL ({hist['n_penned']}/{hist['n_sheep']})"
    print(f"  {result} in {hist['steps']} steps")

    plot_trajectory(hist, os.path.join(out_dir, "trajectory.png"))
    plot_timeseries(hist, os.path.join(out_dir, "timeseries.png"))
    print(f"  saved trajectory.png + timeseries.png to {out_dir}/")

    if not args.no_gif:
        gif_path = os.path.join(out_dir, "episode.gif")
        print(f"  rendering GIF (fps={args.gif_fps}, skip={args.gif_skip})...")
        save_episode_gif(hist, gif_path, fps=args.gif_fps, skip=args.gif_skip)
        print(f"  saved {gif_path}")


if __name__ == "__main__":
    main()