Sheep training flock of 10 fix?
This commit is contained in:
@@ -0,0 +1,316 @@
|
|||||||
|
"""
|
||||||
|
Single-episode visualization for the herding policy.
|
||||||
|
|
||||||
|
Outputs (all saved to --out-dir):
|
||||||
|
trajectory.png — full field view: dog path + every sheep path
|
||||||
|
timeseries.png — radius, per-sheep pen distance, action magnitude, reward
|
||||||
|
episode.gif — animated replay (slow enough to read)
|
||||||
|
|
||||||
|
Run with no model to watch a RANDOM policy (useful baseline):
|
||||||
|
python visualize.py --random --n-sheep 3 --out-dir vis_random/
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python visualize.py \\
|
||||||
|
--model runs/ppo_consolidation/final_model.zip \\
|
||||||
|
--vecnorm runs/ppo_consolidation/vecnorm.pkl \\
|
||||||
|
--n-sheep 3 --out-dir vis_out/
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import matplotlib.patches as mpatches
|
||||||
|
import matplotlib.animation as animation
|
||||||
|
from matplotlib.collections import LineCollection
|
||||||
|
from stable_baselines3 import PPO
|
||||||
|
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
|
||||||
|
from herding_env import HerdingEnv
|
||||||
|
|
||||||
|
|
||||||
|
# ── colours ──────────────────────────────────────────────────────────────────
|
||||||
|
SHEEP_COLORS = [
|
||||||
|
"#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00",
|
||||||
|
"#a65628", "#f781bf", "#999999", "#66c2a5", "#fc8d62",
|
||||||
|
]
|
||||||
|
DOG_COLOR = "#4e342e"
|
||||||
|
PEN_COLOR = "#ffe082"
|
||||||
|
FIELD_COLOR = "#dcedc8"
|
||||||
|
|
||||||
|
|
||||||
|
def make_env(n_sheep, max_steps, seed=42):
|
||||||
|
def _init():
|
||||||
|
env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps)
|
||||||
|
env.reset(seed=seed)
|
||||||
|
return env
|
||||||
|
return _init
|
||||||
|
|
||||||
|
|
||||||
|
def run_episode(model, env, n_sheep, max_steps):
|
||||||
|
"""Run one deterministic episode; return recorded history."""
|
||||||
|
obs = env.reset()
|
||||||
|
inner = env.envs[0]
|
||||||
|
done = False
|
||||||
|
|
||||||
|
dog_xs, dog_ys = [], []
|
||||||
|
sheep_xs = [[] for _ in range(n_sheep)]
|
||||||
|
sheep_ys = [[] for _ in range(n_sheep)]
|
||||||
|
radii = []
|
||||||
|
pen_dists = [[] for _ in range(n_sheep)]
|
||||||
|
action_mags = []
|
||||||
|
rewards = []
|
||||||
|
penned_at = [None] * n_sheep # step when each sheep was penned
|
||||||
|
|
||||||
|
step = 0
|
||||||
|
while not done:
|
||||||
|
if model is None:
|
||||||
|
action = env.action_space.sample()[np.newaxis]
|
||||||
|
else:
|
||||||
|
action, _ = model.predict(obs, deterministic=True)
|
||||||
|
|
||||||
|
obs, reward, dones, infos = env.step(action)
|
||||||
|
done = dones[0]
|
||||||
|
step += 1
|
||||||
|
|
||||||
|
dx, dy = float(inner.dog_pos[0]), float(inner.dog_pos[1])
|
||||||
|
dog_xs.append(dx); dog_ys.append(dy)
|
||||||
|
|
||||||
|
com, radius, _ = inner._flock_stats()
|
||||||
|
radii.append(radius)
|
||||||
|
rewards.append(float(reward[0]))
|
||||||
|
|
||||||
|
act = action[0]
|
||||||
|
action_mags.append(float(np.linalg.norm(act)))
|
||||||
|
|
||||||
|
for i in range(n_sheep):
|
||||||
|
sx, sy = float(inner.sheep_pos[i][0]), float(inner.sheep_pos[i][1])
|
||||||
|
sheep_xs[i].append(sx)
|
||||||
|
sheep_ys[i].append(sy)
|
||||||
|
pen_dists[i].append(float(np.linalg.norm(inner.sheep_pos[i] - inner.PEN_CENTER)))
|
||||||
|
if inner.penned[i] and penned_at[i] is None:
|
||||||
|
penned_at[i] = step
|
||||||
|
|
||||||
|
info = infos[0]
|
||||||
|
n_penned = info.get("n_penned", 0)
|
||||||
|
success = n_penned == n_sheep
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
dog_xs=dog_xs, dog_ys=dog_ys,
|
||||||
|
sheep_xs=sheep_xs, sheep_ys=sheep_ys,
|
||||||
|
radii=radii, pen_dists=pen_dists,
|
||||||
|
action_mags=action_mags, rewards=rewards,
|
||||||
|
penned_at=penned_at,
|
||||||
|
n_penned=n_penned, n_sheep=n_sheep,
|
||||||
|
success=success, steps=step,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── plot helpers ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def draw_field(ax):
|
||||||
|
ax.set_xlim(-16, 16); ax.set_ylim(-16, 16)
|
||||||
|
ax.set_aspect("equal"); ax.set_facecolor(FIELD_COLOR)
|
||||||
|
ax.add_patch(mpatches.Rectangle((-15,-15), 30, 30,
|
||||||
|
fill=False, edgecolor="#795548", lw=2))
|
||||||
|
ax.add_patch(mpatches.Rectangle((10,-15), 3, 7,
|
||||||
|
facecolor=PEN_COLOR, edgecolor="#795548", lw=2))
|
||||||
|
ax.text(11.5, -11.5, "pen", ha="center", va="center",
|
||||||
|
fontsize=8, color="#795548")
|
||||||
|
|
||||||
|
|
||||||
|
def faded_path(ax, xs, ys, color, lw=1.5, label=None):
|
||||||
|
"""Draw a path with alpha fading from start (transparent) to end (opaque)."""
|
||||||
|
n = len(xs)
|
||||||
|
if n < 2:
|
||||||
|
return
|
||||||
|
points = np.array([xs, ys]).T.reshape(-1, 1, 2)
|
||||||
|
segs = np.concatenate([points[:-1], points[1:]], axis=1)
|
||||||
|
alphas = np.linspace(0.15, 1.0, len(segs))
|
||||||
|
colors = [(*matplotlib.colors.to_rgb(color), a) for a in alphas]
|
||||||
|
lc = LineCollection(segs, colors=colors, linewidth=lw)
|
||||||
|
ax.add_collection(lc)
|
||||||
|
if label:
|
||||||
|
ax.plot([], [], color=color, lw=lw, label=label)
|
||||||
|
|
||||||
|
|
||||||
|
# ── main plots ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def plot_trajectory(hist, out_path):
|
||||||
|
fig, ax = plt.subplots(figsize=(7, 7))
|
||||||
|
draw_field(ax)
|
||||||
|
|
||||||
|
# Sheep paths
|
||||||
|
for i in range(hist["n_sheep"]):
|
||||||
|
c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
|
||||||
|
xs, ys = hist["sheep_xs"][i], hist["sheep_ys"][i]
|
||||||
|
faded_path(ax, xs, ys, c, lw=1.2, label=f"sheep {i+1}")
|
||||||
|
ax.plot(xs[0], ys[0], "o", color=c, ms=7, zorder=4)
|
||||||
|
pa = hist["penned_at"][i]
|
||||||
|
end = pa if pa is not None else -1
|
||||||
|
ax.plot(xs[end], ys[end], "*", color=c, ms=11, zorder=5)
|
||||||
|
|
||||||
|
# Dog path
|
||||||
|
faded_path(ax, hist["dog_xs"], hist["dog_ys"], DOG_COLOR, lw=2.0, label="dog")
|
||||||
|
ax.plot(hist["dog_xs"][0], hist["dog_ys"][0], "s", color=DOG_COLOR, ms=10, zorder=5)
|
||||||
|
ax.plot(hist["dog_xs"][-1], hist["dog_ys"][-1], "D", color=DOG_COLOR, ms=10, zorder=5)
|
||||||
|
|
||||||
|
result = "SUCCESS" if hist["success"] else f"FAIL ({hist['n_penned']}/{hist['n_sheep']} penned)"
|
||||||
|
ax.set_title(f"Trajectory — {result} — {hist['steps']} steps", fontsize=12)
|
||||||
|
ax.legend(loc="upper left", fontsize=8)
|
||||||
|
plt.tight_layout()
|
||||||
|
fig.savefig(out_path, dpi=120)
|
||||||
|
plt.close(fig)
|
||||||
|
print(f" saved {out_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def plot_timeseries(hist, out_path):
|
||||||
|
t = np.arange(hist["steps"])
|
||||||
|
fig, axes = plt.subplots(4, 1, figsize=(12, 10), sharex=True)
|
||||||
|
|
||||||
|
# 1. Flock radius
|
||||||
|
axes[0].plot(t, hist["radii"], color="steelblue")
|
||||||
|
axes[0].axhline(5.0, color="orange", ls="--", lw=1, label="compact threshold (5m)")
|
||||||
|
axes[0].set_ylabel("flock radius (m)")
|
||||||
|
axes[0].legend(fontsize=8)
|
||||||
|
axes[0].set_title("Flock radius — goal: get below 5m")
|
||||||
|
|
||||||
|
# 2. Per-sheep distance to pen
|
||||||
|
for i in range(hist["n_sheep"]):
|
||||||
|
c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
|
||||||
|
axes[1].plot(t, hist["pen_dists"][i], color=c, lw=1, label=f"sheep {i+1}")
|
||||||
|
pa = hist["penned_at"][i]
|
||||||
|
if pa is not None:
|
||||||
|
axes[1].axvline(pa, color=c, ls=":", lw=1)
|
||||||
|
axes[1].set_ylabel("dist to pen (m)")
|
||||||
|
axes[1].legend(fontsize=7, ncol=min(hist["n_sheep"], 5))
|
||||||
|
axes[1].set_title("Per-sheep distance to pen — goal: all reach 0")
|
||||||
|
|
||||||
|
# 3. Action magnitude (how fast dog is moving)
|
||||||
|
axes[2].plot(t, hist["action_mags"], color="tomato", lw=1)
|
||||||
|
axes[2].axhline(1.0, color="gray", ls="--", lw=1, label="max")
|
||||||
|
axes[2].set_ylabel("action ||(vx,vy)||")
|
||||||
|
axes[2].set_ylim(0, 1.5)
|
||||||
|
axes[2].set_title("Dog action magnitude — 0=stopped, 1=full speed")
|
||||||
|
axes[2].legend(fontsize=8)
|
||||||
|
|
||||||
|
# 4. Reward per step
|
||||||
|
axes[3].plot(t, hist["rewards"], color="purple", lw=1, alpha=0.7)
|
||||||
|
axes[3].axhline(0, color="black", lw=0.5)
|
||||||
|
axes[3].set_ylabel("reward")
|
||||||
|
axes[3].set_xlabel("step")
|
||||||
|
axes[3].set_title("Reward per step")
|
||||||
|
|
||||||
|
result = "SUCCESS" if hist["success"] else f"FAIL ({hist['n_penned']}/{hist['n_sheep']} penned)"
|
||||||
|
fig.suptitle(f"n_sheep={hist['n_sheep']} {result} {hist['steps']} steps", fontsize=13)
|
||||||
|
plt.tight_layout()
|
||||||
|
fig.savefig(out_path, dpi=120)
|
||||||
|
plt.close(fig)
|
||||||
|
print(f" saved {out_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def save_gif(hist, out_path, fps=15, skip=5):
|
||||||
|
"""Animated replay, every `skip` steps."""
|
||||||
|
n = hist["n_sheep"]
|
||||||
|
idxs = list(range(0, hist["steps"], skip))
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(6, 6))
|
||||||
|
|
||||||
|
def _frame(k):
|
||||||
|
ax.clear()
|
||||||
|
draw_field(ax)
|
||||||
|
t = idxs[k]
|
||||||
|
|
||||||
|
for i in range(n):
|
||||||
|
c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
|
||||||
|
s0 = max(0, t - 30)
|
||||||
|
ax.plot(hist["sheep_xs"][i][s0:t+1],
|
||||||
|
hist["sheep_ys"][i][s0:t+1],
|
||||||
|
color=c, lw=0.8, alpha=0.5)
|
||||||
|
color = "#ff69b4" if (hist["penned_at"][i] is not None
|
||||||
|
and t >= hist["penned_at"][i]) else c
|
||||||
|
ax.plot(hist["sheep_xs"][i][t], hist["sheep_ys"][i][t],
|
||||||
|
"o", color=color, ms=10, zorder=4,
|
||||||
|
markeredgecolor="#555", markeredgewidth=1)
|
||||||
|
|
||||||
|
s0 = max(0, t - 30)
|
||||||
|
ax.plot(hist["dog_xs"][s0:t+1], hist["dog_ys"][s0:t+1],
|
||||||
|
color=DOG_COLOR, lw=1.5, alpha=0.6)
|
||||||
|
ax.plot(hist["dog_xs"][t], hist["dog_ys"][t],
|
||||||
|
"s", color=DOG_COLOR, ms=13, zorder=5,
|
||||||
|
markeredgecolor="black", markeredgewidth=1.5)
|
||||||
|
|
||||||
|
r = hist["radii"][t]
|
||||||
|
ax.set_title(f"step {t}/{hist['steps']} radius={r:.1f}m "
|
||||||
|
f"penned={hist['n_penned'] if t==hist['steps']-1 else '?'}/{n}",
|
||||||
|
fontsize=10)
|
||||||
|
|
||||||
|
ani = animation.FuncAnimation(fig, _frame, frames=len(idxs), interval=1000//fps)
|
||||||
|
ani.save(out_path, writer="pillow", fps=fps)
|
||||||
|
plt.close(fig)
|
||||||
|
print(f" saved {out_path}")
|
||||||
|
|
||||||
|
|
||||||
|
# ── entry point ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
p = argparse.ArgumentParser()
|
||||||
|
p.add_argument("--model", default=None, help="Model .zip (omit for random policy)")
|
||||||
|
p.add_argument("--vecnorm", default=None)
|
||||||
|
p.add_argument("--n-sheep", type=int, default=3)
|
||||||
|
p.add_argument("--max-steps", type=int, default=2000)
|
||||||
|
p.add_argument("--seed", type=int, default=42)
|
||||||
|
p.add_argument("--out-dir", default="vis_out")
|
||||||
|
p.add_argument("--random", action="store_true",
|
||||||
|
help="Use random policy (baseline comparison)")
|
||||||
|
p.add_argument("--gif-fps", type=int, default=15)
|
||||||
|
p.add_argument("--gif-skip", type=int, default=5,
|
||||||
|
help="Render every Nth step in the GIF")
|
||||||
|
p.add_argument("--no-gif", action="store_true")
|
||||||
|
return p.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
os.makedirs(args.out_dir, exist_ok=True)
|
||||||
|
|
||||||
|
raw = DummyVecEnv([make_env(args.n_sheep, args.max_steps, args.seed)])
|
||||||
|
|
||||||
|
if args.random or args.model is None:
|
||||||
|
print("Using RANDOM policy")
|
||||||
|
env = raw
|
||||||
|
model = None
|
||||||
|
else:
|
||||||
|
if args.vecnorm:
|
||||||
|
env = VecNormalize.load(args.vecnorm, raw)
|
||||||
|
env.training = False
|
||||||
|
env.norm_reward = False
|
||||||
|
else:
|
||||||
|
env = raw
|
||||||
|
model = PPO.load(args.model, env=env)
|
||||||
|
print(f"Loaded model: {args.model}")
|
||||||
|
|
||||||
|
print(f"Running episode n_sheep={args.n_sheep} seed={args.seed} ...")
|
||||||
|
hist = run_episode(model, env, args.n_sheep, args.max_steps)
|
||||||
|
|
||||||
|
result = "SUCCESS" if hist["success"] else f"FAIL ({hist['n_penned']}/{hist['n_sheep']} penned)"
|
||||||
|
print(f"Episode done: {result} steps={hist['steps']}")
|
||||||
|
print(f" min radius : {min(hist['radii']):.2f} m")
|
||||||
|
print(f" mean reward: {np.mean(hist['rewards']):.4f}")
|
||||||
|
print(f" mean action: {np.mean(hist['action_mags']):.3f}")
|
||||||
|
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
plot_trajectory(hist, os.path.join(args.out_dir, "trajectory.png"))
|
||||||
|
plot_timeseries(hist, os.path.join(args.out_dir, "timeseries.png"))
|
||||||
|
if not args.no_gif:
|
||||||
|
save_gif(hist, os.path.join(args.out_dir, "episode.gif"),
|
||||||
|
fps=args.gif_fps, skip=args.gif_skip)
|
||||||
|
|
||||||
|
print(f"\nAll outputs saved to {args.out_dir}/")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user