Sheep training flock of 10 fix?
This commit is contained in:
@@ -0,0 +1,316 @@
|
||||
"""
|
||||
Single-episode visualization for the herding policy.
|
||||
|
||||
Outputs (all saved to --out-dir):
|
||||
trajectory.png — full field view: dog path + every sheep path
|
||||
timeseries.png — radius, per-sheep pen distance, action magnitude, reward
|
||||
episode.gif — animated replay (slow enough to read)
|
||||
|
||||
Run with no model to watch a RANDOM policy (useful baseline):
|
||||
python visualize.py --random --n-sheep 3 --out-dir vis_random/
|
||||
|
||||
Usage:
|
||||
python visualize.py \\
|
||||
--model runs/ppo_consolidation/final_model.zip \\
|
||||
--vecnorm runs/ppo_consolidation/vecnorm.pkl \\
|
||||
--n-sheep 3 --out-dir vis_out/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import math
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
import matplotlib.animation as animation
|
||||
from matplotlib.collections import LineCollection
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
|
||||
from herding_env import HerdingEnv
|
||||
|
||||
|
||||
# ── colours ──────────────────────────────────────────────────────────────────
# Per-sheep line/marker colours; the plotting code indexes this list modulo
# its length, so flocks larger than ten sheep reuse colours.
SHEEP_COLORS = [
    "#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00",
    "#a65628", "#f781bf", "#999999", "#66c2a5", "#fc8d62",
]
DOG_COLOR = "#4e342e"  # dog path and dog markers
PEN_COLOR = "#ffe082"  # fill colour of the pen rectangle
FIELD_COLOR = "#dcedc8"  # axes background colour
|
||||
|
||||
|
||||
def make_env(n_sheep, max_steps, seed=42):
    """Return a zero-argument factory that builds one seeded ``HerdingEnv``.

    The factory constructs the environment with the requested flock size and
    step limit, resets it once with ``seed``, and hands back the instance.
    """

    def _factory():
        herding_env = HerdingEnv(max_steps=max_steps, n_sheep=n_sheep)
        # One seeded reset so the environment starts from a reproducible state.
        herding_env.reset(seed=seed)
        return herding_env

    return _factory
|
||||
|
||||
|
||||
def _snapshot_state(inner, n_sheep):
    """Copy the plottable state of the raw environment into plain Python values."""
    _, radius, _ = inner._flock_stats()
    return dict(
        dog=(float(inner.dog_pos[0]), float(inner.dog_pos[1])),
        sheep=[
            (float(inner.sheep_pos[i][0]), float(inner.sheep_pos[i][1]))
            for i in range(n_sheep)
        ],
        pen_dists=[
            float(np.linalg.norm(inner.sheep_pos[i] - inner.PEN_CENTER))
            for i in range(n_sheep)
        ],
        penned=[bool(inner.penned[i]) for i in range(n_sheep)],
        radius=radius,
    )


def run_episode(model, env, n_sheep, max_steps):
    """Run one deterministic episode; return recorded history.

    Args:
        model: Policy exposing ``predict(obs, deterministic=True)``, or
            ``None`` to sample random actions from ``env.action_space``.
        env: Vectorised environment holding a single sub-environment; the raw
            environment is read through ``env.envs[0]``.
        n_sheep: Number of sheep to record.
        max_steps: Unused here; kept so existing callers keep working.

    Returns:
        dict with per-step lists ``dog_xs``, ``dog_ys``, ``radii``,
        ``action_mags`` and ``rewards``, per-sheep lists of lists
        ``sheep_xs``, ``sheep_ys`` and ``pen_dists``, plus ``penned_at``,
        ``n_penned``, ``n_sheep``, ``success`` and ``steps``. Every per-step
        list has exactly ``steps`` entries.

        ``penned_at[i]`` is the 0-based index of the first recorded sample in
        which sheep ``i`` is penned (``None`` if never observed), so it can be
        used directly as an index into the per-step lists.
    """
    obs = env.reset()
    inner = env.envs[0]

    dog_xs, dog_ys = [], []
    sheep_xs = [[] for _ in range(n_sheep)]
    sheep_ys = [[] for _ in range(n_sheep)]
    radii = []
    pen_dists = [[] for _ in range(n_sheep)]
    action_mags = []
    rewards = []
    penned_at = [None] * n_sheep

    # Last state known to belong to THIS episode. Seeded from the post-reset
    # state so there is a fallback even if the very first step is terminal.
    state = _snapshot_state(inner, n_sheep)

    step = 0
    done = False
    while not done:
        if model is None:
            action = env.action_space.sample()[np.newaxis]
        else:
            action, _ = model.predict(obs, deterministic=True)

        obs, reward, dones, infos = env.step(action)
        done = bool(dones[0])
        step += 1
        sample_idx = step - 1  # position of this sample in the per-step lists

        # A vectorised env resets the wrapped env inside step() as soon as the
        # episode ends, so after a terminal step `inner` already describes the
        # NEXT episode. In that case repeat the last valid state instead of
        # recording the post-reset one.
        if not done:
            state = _snapshot_state(inner, n_sheep)

        dog_xs.append(state["dog"][0])
        dog_ys.append(state["dog"][1])
        radii.append(state["radius"])
        rewards.append(float(reward[0]))
        action_mags.append(float(np.linalg.norm(action[0])))

        for i in range(n_sheep):
            sx, sy = state["sheep"][i]
            sheep_xs[i].append(sx)
            sheep_ys[i].append(sy)
            pen_dists[i].append(state["pen_dists"][i])
            if state["penned"][i] and penned_at[i] is None:
                penned_at[i] = sample_idx

    info = infos[0]
    n_penned = info.get("n_penned", 0)
    success = n_penned == n_sheep

    if success:
        # Sheep that entered the pen on the terminal step could not be seen
        # through `inner` (already reset); attribute them to the last sample.
        final_idx = step - 1
        penned_at = [final_idx if p is None else p for p in penned_at]

    return dict(
        dog_xs=dog_xs, dog_ys=dog_ys,
        sheep_xs=sheep_xs, sheep_ys=sheep_ys,
        radii=radii, pen_dists=pen_dists,
        action_mags=action_mags, rewards=rewards,
        penned_at=penned_at,
        n_penned=n_penned, n_sheep=n_sheep,
        success=success, steps=step,
    )
|
||||
|
||||
|
||||
# ── plot helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
def draw_field(ax):
    """Paint the static scene on ``ax``: background, field fence and pen.

    The view spans [-16, 16] on both axes with equal aspect. The fence is an
    unfilled 30x30 rectangle anchored at (-15, -15); the pen is a filled 3x7
    rectangle anchored at (10, -15) with a "pen" label centred on it.
    """
    outline = "#795548"

    ax.set_xlim(-16, 16)
    ax.set_ylim(-16, 16)
    ax.set_aspect("equal")
    ax.set_facecolor(FIELD_COLOR)

    fence = mpatches.Rectangle(
        (-15, -15), 30, 30, fill=False, edgecolor=outline, lw=2
    )
    pen = mpatches.Rectangle(
        (10, -15), 3, 7, facecolor=PEN_COLOR, edgecolor=outline, lw=2
    )
    for patch in (fence, pen):
        ax.add_patch(patch)

    ax.text(
        11.5, -11.5, "pen",
        ha="center", va="center", fontsize=8, color=outline,
    )
|
||||
|
||||
|
||||
def faded_path(ax, xs, ys, color, lw=1.5, label=None):
    """Draw a path with alpha fading from start (transparent) to end (opaque).

    Args:
        ax: Axes to draw on.
        xs, ys: Equal-length coordinate sequences of the path.
        color: Any matplotlib colour spec; converted to RGB and combined with
            a per-segment alpha ramp from 0.15 to 1.0.
        lw: Line width of the segments and of the legend entry.
        label: Optional legend label.

    The legend entry is registered even when the path has fewer than two
    points (nothing can be drawn then), so a very short episode still
    produces a complete legend.
    """
    if label:
        # A LineCollection with per-segment colours makes a poor legend
        # handle, so an empty solid line carries the label instead.
        ax.plot([], [], color=color, lw=lw, label=label)

    if len(xs) < 2:
        return  # no segment to draw

    points = np.array([xs, ys]).T.reshape(-1, 1, 2)
    # Pair each point with its successor: one (start, end) segment per row.
    segs = np.concatenate([points[:-1], points[1:]], axis=1)
    alphas = np.linspace(0.15, 1.0, len(segs))
    rgb = matplotlib.colors.to_rgb(color)
    colors = [(*rgb, a) for a in alphas]
    ax.add_collection(LineCollection(segs, colors=colors, linewidth=lw))
|
||||
|
||||
|
||||
# ── main plots ────────────────────────────────────────────────────────────────
|
||||
|
||||
def plot_trajectory(hist, out_path):
    """Save a full-field figure with the dog path and every sheep path.

    Args:
        hist: Episode history dict; reads ``n_sheep``, ``sheep_xs``,
            ``sheep_ys``, ``penned_at``, ``dog_xs``, ``dog_ys``, ``success``,
            ``n_penned`` and ``steps``.
        out_path: Destination passed to ``fig.savefig``.

    Markers: a circle at each sheep's first sample and a star at the sample
    indexed by its ``penned_at`` entry (or its last sample if never penned);
    a square at the dog's first sample and a diamond at its last.
    """
    fig, ax = plt.subplots(figsize=(7, 7))
    draw_field(ax)

    # Sheep paths
    for i in range(hist["n_sheep"]):
        c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
        xs, ys = hist["sheep_xs"][i], hist["sheep_ys"][i]
        faded_path(ax, xs, ys, c, lw=1.2, label=f"sheep {i+1}")
        ax.plot(xs[0], ys[0], "o", color=c, ms=7, zorder=4)
        pa = hist["penned_at"][i]
        # penned_at is used as an index into the path. Clamp it to the last
        # sample so a sheep penned on the final step cannot index past the
        # end of the list.
        end = -1 if pa is None else min(pa, len(xs) - 1)
        ax.plot(xs[end], ys[end], "*", color=c, ms=11, zorder=5)

    # Dog path
    faded_path(ax, hist["dog_xs"], hist["dog_ys"], DOG_COLOR, lw=2.0, label="dog")
    ax.plot(hist["dog_xs"][0], hist["dog_ys"][0], "s", color=DOG_COLOR, ms=10, zorder=5)
    ax.plot(hist["dog_xs"][-1], hist["dog_ys"][-1], "D", color=DOG_COLOR, ms=10, zorder=5)

    result = "SUCCESS" if hist["success"] else f"FAIL ({hist['n_penned']}/{hist['n_sheep']} penned)"
    ax.set_title(f"Trajectory — {result} — {hist['steps']} steps", fontsize=12)
    ax.legend(loc="upper left", fontsize=8)
    plt.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)
    print(f" saved {out_path}")
|
||||
|
||||
|
||||
def plot_timeseries(hist, out_path):
    """Save a four-panel time-series figure for one recorded episode.

    Panels, top to bottom and sharing the step axis: flock radius with a 5 m
    reference line, per-sheep distance to the pen with a dotted vertical line
    at each ``penned_at`` value, action magnitude with a reference line at
    1.0, and per-step reward with a zero line.

    Args:
        hist: Episode history dict; reads ``steps``, ``radii``, ``n_sheep``,
            ``pen_dists``, ``penned_at``, ``action_mags``, ``rewards``,
            ``success`` and ``n_penned``.
        out_path: Destination passed to ``fig.savefig``.
    """
    flock_size = hist["n_sheep"]
    step_axis = np.arange(hist["steps"])
    fig, axes = plt.subplots(4, 1, figsize=(12, 10), sharex=True)
    radius_ax, pen_ax, action_ax, reward_ax = axes

    # Panel 1: flock radius
    radius_ax.plot(step_axis, hist["radii"], color="steelblue")
    radius_ax.axhline(5.0, color="orange", ls="--", lw=1, label="compact threshold (5m)")
    radius_ax.set_ylabel("flock radius (m)")
    radius_ax.legend(fontsize=8)
    radius_ax.set_title("Flock radius — goal: get below 5m")

    # Panel 2: distance of every sheep to the pen
    for idx, (dists, penned_step) in enumerate(zip(hist["pen_dists"][:flock_size],
                                                   hist["penned_at"][:flock_size])):
        shade = SHEEP_COLORS[idx % len(SHEEP_COLORS)]
        pen_ax.plot(step_axis, dists, color=shade, lw=1, label=f"sheep {idx + 1}")
        if penned_step is None:
            continue
        pen_ax.axvline(penned_step, color=shade, ls=":", lw=1)
    pen_ax.set_ylabel("dist to pen (m)")
    pen_ax.legend(fontsize=7, ncol=min(flock_size, 5))
    pen_ax.set_title("Per-sheep distance to pen — goal: all reach 0")

    # Panel 3: magnitude of the action vector
    action_ax.plot(step_axis, hist["action_mags"], color="tomato", lw=1)
    action_ax.axhline(1.0, color="gray", ls="--", lw=1, label="max")
    action_ax.set_ylabel("action ||(vx,vy)||")
    action_ax.set_ylim(0, 1.5)
    action_ax.set_title("Dog action magnitude — 0=stopped, 1=full speed")
    action_ax.legend(fontsize=8)

    # Panel 4: reward received at each step
    reward_ax.plot(step_axis, hist["rewards"], color="purple", lw=1, alpha=0.7)
    reward_ax.axhline(0, color="black", lw=0.5)
    reward_ax.set_ylabel("reward")
    reward_ax.set_xlabel("step")
    reward_ax.set_title("Reward per step")

    if hist["success"]:
        result = "SUCCESS"
    else:
        result = f"FAIL ({hist['n_penned']}/{hist['n_sheep']} penned)"
    fig.suptitle(f"n_sheep={hist['n_sheep']} {result} {hist['steps']} steps", fontsize=13)

    plt.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)
    print(f" saved {out_path}")
|
||||
|
||||
|
||||
def save_gif(hist, out_path, fps=15, skip=5):
    """Animated replay, every `skip` steps, always ending on the last sample.

    Args:
        hist: Episode history dict; reads ``n_sheep``, ``steps``,
            ``sheep_xs``, ``sheep_ys``, ``dog_xs``, ``dog_ys``, ``radii``,
            ``penned_at`` and ``n_penned``.
        out_path: Destination GIF path, written with the "pillow" writer.
        fps: Playback frame rate.
        skip: Render every ``skip``-th sample; values below 1 are treated
            as 1.

    Each frame shows a 30-sample trail behind every agent. A sheep turns pink
    once the frame index reaches its ``penned_at`` value, and the title shows
    how many sheep are pink in that frame (the final frame shows the
    episode's ``n_penned``).
    """
    n = hist["n_sheep"]
    steps = hist["steps"]
    last = steps - 1

    stride = max(1, skip)  # range() rejects a zero step
    idxs = list(range(0, steps, stride))
    # Without this the replay stops short of the end whenever `last` is not a
    # multiple of the stride, and the final outcome is never shown.
    if idxs and idxs[-1] != last:
        idxs.append(last)

    def _is_penned(i, t):
        pa = hist["penned_at"][i]
        return pa is not None and t >= pa

    fig, ax = plt.subplots(figsize=(6, 6))

    def _frame(k):
        ax.clear()
        draw_field(ax)
        t = idxs[k]
        s0 = max(0, t - 30)  # start of the short trail behind each agent

        for i in range(n):
            c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
            ax.plot(hist["sheep_xs"][i][s0:t+1],
                    hist["sheep_ys"][i][s0:t+1],
                    color=c, lw=0.8, alpha=0.5)
            color = "#ff69b4" if _is_penned(i, t) else c
            ax.plot(hist["sheep_xs"][i][t], hist["sheep_ys"][i][t],
                    "o", color=color, ms=10, zorder=4,
                    markeredgecolor="#555", markeredgewidth=1)

        ax.plot(hist["dog_xs"][s0:t+1], hist["dog_ys"][s0:t+1],
                color=DOG_COLOR, lw=1.5, alpha=0.6)
        ax.plot(hist["dog_xs"][t], hist["dog_ys"][t],
                "s", color=DOG_COLOR, ms=13, zorder=5,
                markeredgecolor="black", markeredgewidth=1.5)

        r = hist["radii"][t]
        if t == last:
            penned_now = hist["n_penned"]
        else:
            penned_now = sum(_is_penned(i, t) for i in range(n))
        ax.set_title(f"step {t}/{steps} radius={r:.1f}m "
                     f"penned={penned_now}/{n}",
                     fontsize=10)

    try:
        ani = animation.FuncAnimation(fig, _frame, frames=len(idxs), interval=1000//fps)
        ani.save(out_path, writer="pillow", fps=fps)
    finally:
        # Release the figure even if the writer fails.
        plt.close(fig)
    print(f" saved {out_path}")
|
||||
|
||||
|
||||
# ── entry point ───────────────────────────────────────────────────────────────
|
||||
|
||||
def parse_args():
    """Parse the command-line options of the visualisation script.

    Returns:
        argparse.Namespace with ``model``, ``vecnorm``, ``n_sheep``,
        ``max_steps``, ``seed``, ``out_dir``, ``random``, ``gif_fps``,
        ``gif_skip`` and ``no_gif``.
    """
    parser = argparse.ArgumentParser()

    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_table = (
        ("--model", {"default": None, "help": "Model .zip (omit for random policy)"}),
        ("--vecnorm", {"default": None}),
        ("--n-sheep", {"type": int, "default": 3}),
        ("--max-steps", {"type": int, "default": 2000}),
        ("--seed", {"type": int, "default": 42}),
        ("--out-dir", {"default": "vis_out"}),
        ("--random", {"action": "store_true",
                      "help": "Use random policy (baseline comparison)"}),
        ("--gif-fps", {"type": int, "default": 15}),
        ("--gif-skip", {"type": int, "default": 5,
                        "help": "Render every Nth step in the GIF"}),
        ("--no-gif", {"action": "store_true"}),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()
|
||||
|
||||
|
||||
def main():
    """Script entry point: run one episode and write the visualisations.

    Builds a single-environment ``DummyVecEnv``, selects the random policy or
    a loaded PPO model (optionally wrapped in saved ``VecNormalize``
    statistics), records one episode, prints a short summary, then writes
    trajectory.png, timeseries.png and, unless ``--no-gif`` is given,
    episode.gif into ``--out-dir``.
    """
    args = parse_args()
    os.makedirs(args.out_dir, exist_ok=True)

    env_factory = make_env(args.n_sheep, args.max_steps, args.seed)
    raw = DummyVecEnv([env_factory])

    use_random = args.random or args.model is None
    if use_random:
        print("Using RANDOM policy")
        env, model = raw, None
    else:
        env = raw
        if args.vecnorm:
            env = VecNormalize.load(args.vecnorm, raw)
            # Evaluation mode: freeze the running statistics and report
            # unnormalised rewards.
            env.training = False
            env.norm_reward = False
        model = PPO.load(args.model, env=env)
        print(f"Loaded model: {args.model}")

    print(f"Running episode n_sheep={args.n_sheep} seed={args.seed} ...")
    hist = run_episode(model, env, args.n_sheep, args.max_steps)

    if hist["success"]:
        result = "SUCCESS"
    else:
        result = f"FAIL ({hist['n_penned']}/{hist['n_sheep']} penned)"
    min_radius = min(hist["radii"])
    mean_reward = np.mean(hist["rewards"])
    mean_action = np.mean(hist["action_mags"])
    print(f"Episode done: {result} steps={hist['steps']}")
    print(f" min radius : {min_radius:.2f} m")
    print(f" mean reward: {mean_reward:.4f}")
    print(f" mean action: {mean_action:.3f}")

    env.close()

    trajectory_path = os.path.join(args.out_dir, "trajectory.png")
    timeseries_path = os.path.join(args.out_dir, "timeseries.png")
    plot_trajectory(hist, trajectory_path)
    plot_timeseries(hist, timeseries_path)
    if not args.no_gif:
        gif_path = os.path.join(args.out_dir, "episode.gif")
        save_gif(hist, gif_path, fps=args.gif_fps, skip=args.gif_skip)

    print(f"\nAll outputs saved to {args.out_dir}/")
|
||||
|
||||
|
||||
# Run the visualisation only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user