From bf9fe902d9f4f1ab166e8d6877eeb2f06258f34e Mon Sep 17 00:00:00 2001 From: Johnny Fernandes Date: Fri, 24 Apr 2026 17:49:42 +0100 Subject: [PATCH] Sheep training flock of 10 fix? --- .../shepherd_dog_rl/shepherd_dog_rl.py | 6 +- training/herding_env.py | 10 +- training/smoke_test.py | 92 ++++++++++++++++++- 3 files changed, 98 insertions(+), 10 deletions(-) diff --git a/controllers/shepherd_dog_rl/shepherd_dog_rl.py b/controllers/shepherd_dog_rl/shepherd_dog_rl.py index 1219878..c015c06 100644 --- a/controllers/shepherd_dog_rl/shepherd_dog_rl.py +++ b/controllers/shepherd_dog_rl/shepherd_dog_rl.py @@ -97,9 +97,9 @@ def build_obs(dog_pos: np.ndarray, return np.array([ dog_pos[0] / FIELD, dog_pos[1] / FIELD, (com[0] - dog_pos[0]) / D, (com[1] - dog_pos[1]) / D, - (far1[0] - dog_pos[0]) / D, (far1[1] - dog_pos[1]) / D, - (far2[0] - dog_pos[0]) / D, (far2[1] - dog_pos[1]) / D, - (far3[0] - dog_pos[0]) / D, (far3[1] - dog_pos[1]) / D, + (far1[0] - com[0]) / D, (far1[1] - com[1]) / D, + (far2[0] - com[0]) / D, (far2[1] - com[1]) / D, + (far3[0] - com[0]) / D, (far3[1] - com[1]) / D, (PEN_CENTER[0] - com[0]) / D, (PEN_CENTER[1] - com[1]) / D, (PEN_CENTER[0] - far1[0]) / D, (PEN_CENTER[1] - far1[1]) / D, radius / D, diff --git a/training/herding_env.py b/training/herding_env.py index 034bb17..4eba6e6 100644 --- a/training/herding_env.py +++ b/training/herding_env.py @@ -279,12 +279,16 @@ class HerdingEnv(gym.Env): S = self.FIELD D = 2 * self.FIELD + # far1/far2/far3 expressed relative to COM, not dog. + # For 1 sheep: far1-COM = far2-COM = far3-COM = [0,0] → cleanly ignorable. + # For 3+ sheep: non-zero vectors tell the dog where each straggler is + # within the group, without conflicting with weights trained on 1 sheep. return np.array([ self.dog_pos[0] / S, self.dog_pos[1] / S, (com[0] - self.dog_pos[0]) / D, (com[1] - self.dog_pos[1]) / D, - (far1[0] - self.dog_pos[0]) / D, (far1[1] - self.dog_pos[1]) / D, - (far2[0] - self.dog_pos[0]) / D, (far2[1] - self.dog_pos[1]) / D, - (far3[0] - self.dog_pos[0]) / D, (far3[1] - self.dog_pos[1]) / D, + (far1[0] - com[0]) / D, (far1[1] - com[1]) / D, + (far2[0] - com[0]) / D, (far2[1] - com[1]) / D, + (far3[0] - com[0]) / D, (far3[1] - com[1]) / D, (self.PEN_CENTER[0] - com[0]) / D, (self.PEN_CENTER[1] - com[1]) / D, (self.PEN_CENTER[0] - far1[0]) / D, (self.PEN_CENTER[1] - far1[1]) / D, radius / D, diff --git a/training/smoke_test.py b/training/smoke_test.py index 99413fa..6c929c5 100644 --- a/training/smoke_test.py +++ b/training/smoke_test.py @@ -16,6 +16,12 @@ import sys import numpy as np from copy import deepcopy +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +from matplotlib.collections import LineCollection + from stable_baselines3 import PPO from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecNormalize @@ -148,6 +154,83 @@ def report(n_sheep, success_rate, failure_counts, n_episodes, threshold=PASS_THR return passed +SHEEP_COLORS = ["#e41a1c","#377eb8","#4daf4a","#984ea3","#ff7f00", + "#a65628","#f781bf","#999999","#66c2a5","#fc8d62"] + +def _save_smoke_vis(model, vn, n_sheep, save_dir, seed=42, max_steps=2000): + """Run one episode and save trajectory + timeseries PNGs.""" + from copy import deepcopy + raw = DummyVecEnv([make_env(n_sheep, max_steps, seed)]) + env = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False) + env.obs_rms = deepcopy(vn.obs_rms) + env.ret_rms = deepcopy(vn.ret_rms) + + obs = env.reset() + inner = env.envs[0] + dog_xs, dog_ys = [], [] + sheep_xs = [[] for _ in range(n_sheep)] + sheep_ys = [[] for _ in range(n_sheep)] + radii, action_mags, rewards = [], [], [] + pen_dists = [[] for _ in range(n_sheep)] + done = False + + while not done: + action, _ = model.predict(obs, deterministic=True) + obs, reward, dones, _ = env.step(action) + done = dones[0] + dog_xs.append(float(inner.dog_pos[0])); dog_ys.append(float(inner.dog_pos[1])) + com, radius, _ = inner._flock_stats() + radii.append(radius) + rewards.append(float(reward[0])) + action_mags.append(float(np.linalg.norm(action[0]))) + for i in range(n_sheep): + sheep_xs[i].append(float(inner.sheep_pos[i][0])) + sheep_ys[i].append(float(inner.sheep_pos[i][1])) + pen_dists[i].append(float(np.linalg.norm(inner.sheep_pos[i] - inner.PEN_CENTER))) + env.close() + + steps = len(dog_xs) + # Trajectory + fig, ax = plt.subplots(figsize=(6,6)) + ax.set_xlim(-16,16); ax.set_ylim(-16,16); ax.set_aspect("equal") + ax.set_facecolor("#dcedc8") + ax.add_patch(mpatches.Rectangle((-15,-15),30,30,fill=False,edgecolor="#795548",lw=2)) + ax.add_patch(mpatches.Rectangle((10,-15),3,7,facecolor="#ffe082",edgecolor="#795548",lw=2)) + ax.text(11.5,-11.5,"pen",ha="center",va="center",fontsize=8,color="#795548") + for i in range(n_sheep): + c = SHEEP_COLORS[i % len(SHEEP_COLORS)] + ax.plot(sheep_xs[i], sheep_ys[i], color=c, lw=1, alpha=0.6, label=f"sheep {i+1}") + ax.plot(sheep_xs[i][0], sheep_ys[i][0], "o", color=c, ms=7) + ax.plot(sheep_xs[i][-1], sheep_ys[i][-1], "*", color=c, ms=10) + ax.plot(dog_xs, dog_ys, color="#4e342e", lw=1.5, label="dog", alpha=0.8) + ax.plot(dog_xs[0], dog_ys[0], "s", color="#4e342e", ms=9) + ax.plot(dog_xs[-1], dog_ys[-1], "D", color="#4e342e", ms=9) + ax.set_title(f"n_sheep={n_sheep} {steps} steps min_r={min(radii):.1f}m") + ax.legend(fontsize=7, loc="upper left") + plt.tight_layout() + fig.savefig(os.path.join(save_dir, "trajectory.png"), dpi=100) + plt.close(fig) + + # Timeseries + t = np.arange(steps) + fig, axes = plt.subplots(4,1,figsize=(10,8),sharex=True) + axes[0].plot(t, radii, color="steelblue"); axes[0].axhline(5,color="orange",ls="--",lw=1) + axes[0].set_ylabel("radius (m)"); axes[0].set_title("Flock radius (orange=5m threshold)") + for i in range(n_sheep): + axes[1].plot(t, pen_dists[i], color=SHEEP_COLORS[i%len(SHEEP_COLORS)], lw=1, label=f"sheep {i+1}") + axes[1].set_ylabel("pen dist (m)"); axes[1].set_title("Per-sheep distance to pen"); axes[1].legend(fontsize=7) + axes[2].plot(t, action_mags, color="tomato", lw=1, alpha=0.8) + axes[2].axhline(1.0,color="gray",ls="--",lw=1); axes[2].set_ylim(0,1.5) + axes[2].set_ylabel("action mag"); axes[2].set_title("Dog action magnitude (0=stopped)") + axes[3].plot(t, rewards, color="purple", lw=1, alpha=0.7); axes[3].axhline(0,color="black",lw=0.5) + axes[3].set_ylabel("reward"); axes[3].set_xlabel("step"); axes[3].set_title("Reward per step") + fig.suptitle(f"Smoke stage n_sheep={n_sheep}", fontsize=12) + plt.tight_layout() + fig.savefig(os.path.join(save_dir, "timeseries.png"), dpi=100) + plt.close(fig) + print(f" Viz saved to {save_dir}/trajectory.png + timeseries.png") + + def main(): p = argparse.ArgumentParser() p.add_argument("--steps", type=int, default=500_000, @@ -158,10 +241,10 @@ def main(): p.add_argument("--render", action="store_true") args = p.parse_args() - # Stage 1 (1 sheep, 500k): fast sanity check — obs/reward structurally correct? - # Stage 2 (3 sheep, 1.5M): real test at curriculum pace — if it fails here, - # there is a genuine problem worth fixing before committing to 15M steps. - stages = [(1, args.steps, 0.60), (3, args.steps * 3, 0.40)] + # 1 sheep (500k): sanity check — obs/reward structurally correct? + # 2 sheep (1M): first multi-agent step — gradual transfer + # 3 sheep (1.5M): real multi-sheep test at curriculum pace + stages = [(1, args.steps, 0.60), (2, args.steps * 2, 0.40), (3, args.steps * 3, 0.35)] model, vn = None, None all_passed = True @@ -184,6 +267,7 @@ def main(): model.save(os.path.join(save_dir, "model")) vn.save(os.path.join(save_dir, "vecnorm.pkl")) print(f" Model saved to {save_dir}/") + _save_smoke_vis(model, vn, n_sheep, save_dir) passed = report(n_sheep, success_rate, failure_counts, args.episodes, threshold) if not passed: