From fcfa2c35c8bbd18ad2579f3fde2c35b5a72cce6c Mon Sep 17 00:00:00 2001 From: Johnny Fernandes Date: Fri, 24 Apr 2026 14:54:20 +0100 Subject: [PATCH] Sheep training flock of 10 fix? --- .../shepherd_dog_rl/shepherd_dog_rl.py | 22 +- training/herding_env.py | 38 ++-- training/smoke_test.py | 198 ++++++++++++++++++ training/train.py | 117 ++++++++++- 4 files changed, 342 insertions(+), 33 deletions(-) create mode 100644 training/smoke_test.py diff --git a/controllers/shepherd_dog_rl/shepherd_dog_rl.py b/controllers/shepherd_dog_rl/shepherd_dog_rl.py index d94c574..1219878 100644 --- a/controllers/shepherd_dog_rl/shepherd_dog_rl.py +++ b/controllers/shepherd_dog_rl/shepherd_dog_rl.py @@ -84,23 +84,25 @@ def build_obs(dog_pos: np.ndarray, d_from_com = np.linalg.norm(active_pos - com, axis=1) sorted_idx = np.argsort(d_from_com)[::-1] radius = float(d_from_com[sorted_idx[0]]) - far = active_pos[sorted_idx[0]] - second_far_dist = float(d_from_com[sorted_idx[1]]) if len(sorted_idx) > 1 else 0.0 + def nth(n): + return active_pos[sorted_idx[n]] if len(sorted_idx) > n else com + far1, far2, far3 = nth(0), nth(1), nth(2) else: com = PEN_CENTER.copy() - radius = second_far_dist = 0.0 - far = PEN_CENTER.copy() + radius = 0.0 + far1 = far2 = far3 = PEN_CENTER.copy() frac_active = n_active / max(n_sheep, 1) return np.array([ dog_pos[0] / FIELD, dog_pos[1] / FIELD, - (com[0] - dog_pos[0]) / D, (com[1] - dog_pos[1]) / D, - (far[0] - dog_pos[0]) / D, (far[1] - dog_pos[1]) / D, - (PEN_CENTER[0] - com[0]) / D, (PEN_CENTER[1] - com[1]) / D, - (PEN_CENTER[0] - far[0]) / D, (PEN_CENTER[1] - far[1]) / D, - radius / D, - second_far_dist / D, + (com[0] - dog_pos[0]) / D, (com[1] - dog_pos[1]) / D, + (far1[0] - dog_pos[0]) / D, (far1[1] - dog_pos[1]) / D, + (far2[0] - dog_pos[0]) / D, (far2[1] - dog_pos[1]) / D, + (far3[0] - dog_pos[0]) / D, (far3[1] - dog_pos[1]) / D, + (PEN_CENTER[0] - com[0]) / D, (PEN_CENTER[1] - com[1]) / D, + (PEN_CENTER[0] - far1[0]) / D, (PEN_CENTER[1] - far1[1]) / D, + radius / D, frac_active, ], dtype=np.float32) diff --git a/training/herding_env.py b/training/herding_env.py index c20ff0c..52a0e7b 100644 --- a/training/herding_env.py +++ b/training/herding_env.py @@ -56,7 +56,7 @@ class HerdingEnv(gym.Env): W_DRIVE = 2.0 # progress: COM moved toward pen (only when compact) W_COLLECT = 4.0 # progress: radius shrank (2× stronger when scattered) W_ALIGN = 0.5 # position: dog on anti-pen side of COM - W_COMPACT_BONUS = 0.1 # per-step bonus for staying compact (sustained signal) + W_COMPACT_BONUS = 0.0 # disabled: 0.1/step over 4000 steps = 400 >> W_COMPLETE=100 W_PEN_BONUS = 10.0 # per sheep penned W_COMPLETE = 100.0 # all sheep penned W_STEP_COST = 0.002 # time penalty @@ -72,11 +72,11 @@ class HerdingEnv(gym.Env): self.render_mode = render_mode self.random_n_sheep = random_n_sheep # if True, randomise n_sheep each reset - # Fixed 13-dim observation regardless of n_sheep: - # dog_pos(2) + rel_com(2) + rel_far(2) + com_to_pen(2) - # + far_to_pen(2) + radius(1) + second_far_dist(1) + frac_penned(1) + # Fixed 17-dim observation regardless of n_sheep: + # dog_pos(2) + rel_com(2) + rel_far1(2) + rel_far2(2) + rel_far3(2) + # + com_to_pen(2) + far1_to_pen(2) + radius(1) + frac_penned(1) self.observation_space = spaces.Box( - low=-np.inf, high=np.inf, shape=(13,), dtype=np.float32 + low=-np.inf, high=np.inf, shape=(17,), dtype=np.float32 ) # Action: desired velocity (vx, vy) ∈ [-1, 1]², scaled by DOG_SPEED @@ -269,29 +269,25 @@ class HerdingEnv(gym.Env): pts = self.sheep_pos[:self.n_sheep][active_mask] dists = np.linalg.norm(pts - com, axis=1) sorted_idx = np.argsort(dists)[::-1] # farthest first - far = pts[sorted_idx[0]] - # 2nd farthest — if only 1 active sheep, reuse the same position - far2 = pts[sorted_idx[1]] if len(sorted_idx) > 1 else far - second_far_dist = float(dists[sorted_idx[1]]) if len(sorted_idx) > 1 else 0.0 + # Top-3 stragglers; pad with COM when fewer active sheep exist + def nth(n): + return pts[sorted_idx[n]] if len(sorted_idx) > n else com + far1, far2, far3 = nth(0), nth(1), nth(2) else: - far = far2 = self.PEN_CENTER.copy() - second_far_dist = 0.0 + far1 = far2 = far3 = self.PEN_CENTER.copy() S = self.FIELD D = 2 * self.FIELD return np.array([ self.dog_pos[0] / S, self.dog_pos[1] / S, - (com[0] - self.dog_pos[0]) / D, - (com[1] - self.dog_pos[1]) / D, - (far[0] - self.dog_pos[0]) / D, - (far[1] - self.dog_pos[1]) / D, - (self.PEN_CENTER[0] - com[0]) / D, - (self.PEN_CENTER[1] - com[1]) / D, - (self.PEN_CENTER[0] - far[0]) / D, - (self.PEN_CENTER[1] - far[1]) / D, - radius / D, - second_far_dist / D, # replaced mean_disp: 2nd farthest sheep from COM + (com[0] - self.dog_pos[0]) / D, (com[1] - self.dog_pos[1]) / D, + (far1[0] - self.dog_pos[0]) / D, (far1[1] - self.dog_pos[1]) / D, + (far2[0] - self.dog_pos[0]) / D, (far2[1] - self.dog_pos[1]) / D, + (far3[0] - self.dog_pos[0]) / D, (far3[1] - self.dog_pos[1]) / D, + (self.PEN_CENTER[0] - com[0]) / D, (self.PEN_CENTER[1] - com[1]) / D, + (self.PEN_CENTER[0] - far1[0]) / D, (self.PEN_CENTER[1] - far1[1]) / D, + radius / D, active_mask.sum() / self.n_sheep, ], dtype=np.float32) diff --git a/training/smoke_test.py b/training/smoke_test.py new file mode 100644 index 0000000..7ae92dc --- /dev/null +++ b/training/smoke_test.py @@ -0,0 +1,198 @@ +""" +Quick sanity check before committing to a full 15M-step training run. + +Trains 1 sheep for 500k steps (~5 min), then 3 sheep for 500k steps. +If both pass, the obs/reward setup is sound and full training is worth running. +If either fails, abort and fix before wasting 15M steps. + +Usage: + python smoke_test.py # fresh run + python smoke_test.py --render # watch episodes after each stage +""" + +import argparse +import os +import sys +import numpy as np +from copy import deepcopy + +from stable_baselines3 import PPO +from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecNormalize + +from herding_env import HerdingEnv + + +COMPACT_RADIUS = 5.0 +PASS_THRESHOLD = 0.60 # success rate required to pass each stage + + +def make_env(n_sheep, seed, max_steps=2000): + def _init(): + env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps) + env.reset(seed=seed) + return env + return _init + + +def classify_failure(ep_radius, ep_com_dist, n_penned, n_sheep, success): + if success: + return "SUCCESS" + if min(ep_radius) > COMPACT_RADIUS: + return "NEVER_COMPACT" + first_compact = next(i for i, r in enumerate(ep_radius) if r <= COMPACT_RADIUS) + if min(ep_com_dist[first_compact:]) > 3.0: + return "COMPACT_CANT_DRIVE" + if n_penned == 0: + return "DROVE_NO_SHEEP" + return f"PARTIAL_{n_penned}of{n_sheep}" + + +def run_episodes(model, eval_env, n_episodes=30, max_steps=2000, render=False): + """Run N deterministic episodes; return failure mode counts and success rate.""" + failure_counts = {} + successes = 0 + + for ep in range(n_episodes): + obs = eval_env.reset() + done = False + ep_radius, ep_com_dist = [], [] + n_penned = 0 + n_sheep = 1 + + while not done: + action, _ = model.predict(obs, deterministic=True) + obs, _, dones, infos = eval_env.step(action) + done = dones[0] + + inner = eval_env.envs[0] + com, radius, _ = inner._flock_stats() + com_dist = float(np.linalg.norm(com - inner.PEN_CENTER)) + ep_radius.append(radius) + ep_com_dist.append(com_dist) + + if render and ep == 0: + inner.render() + + info = infos[0] + n_penned = info.get("n_penned", 0) + n_sheep = info.get("n_sheep", 1) + success = n_penned == n_sheep + successes += int(success) + mode = classify_failure(ep_radius, ep_com_dist, n_penned, n_sheep, success) + failure_counts[mode] = failure_counts.get(mode, 0) + 1 + + success_rate = successes / n_episodes + return success_rate, failure_counts + + +def train_stage(n_sheep, steps, n_envs=4, prev_model=None, prev_vecnorm=None): + """Train one stage; return (model, vecnorm).""" + train_env = SubprocVecEnv([make_env(n_sheep, i) for i in range(n_envs)]) + + if prev_vecnorm is not None: + vn = deepcopy(prev_vecnorm) + vn.set_venv(train_env) + vn.training = True + vn.norm_reward = True + else: + vn = VecNormalize(train_env, norm_obs=True, norm_reward=True, clip_obs=10.0) + + if prev_model is not None: + model = PPO.load(prev_model, env=vn, + learning_rate=3e-4, n_steps=2048, batch_size=256, + n_epochs=10, gamma=0.995, gae_lambda=0.95, + clip_range=0.2, ent_coef=0.005, vf_coef=0.5, + max_grad_norm=0.5) + else: + model = PPO( + "MlpPolicy", vn, + learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10, + gamma=0.995, gae_lambda=0.95, clip_range=0.2, ent_coef=0.005, + vf_coef=0.5, max_grad_norm=0.5, + policy_kwargs=dict(net_arch=[256, 256]), + verbose=1, + ) + + model.learn(total_timesteps=steps, reset_num_timesteps=(prev_model is None)) + return model, vn + + +def make_eval_env(model, vecnorm, n_sheep, max_steps=2000): + raw = DummyVecEnv([make_env(n_sheep, seed=9999, max_steps=max_steps)]) + vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False) + vn.obs_rms = deepcopy(vecnorm.obs_rms) + vn.ret_rms = deepcopy(vecnorm.ret_rms) + return vn + + +def report(n_sheep, success_rate, failure_counts, n_episodes): + print(f"\n{'='*52}") + print(f" Stage n_sheep={n_sheep} | success={success_rate*100:.0f}% ({int(success_rate*n_episodes)}/{n_episodes})") + print(f" {'─'*48}") + for mode, cnt in sorted(failure_counts.items(), key=lambda x: -x[1]): + bar = "█" * cnt + print(f" {mode:<26} {cnt:>3}/{n_episodes} {bar}") + print(f"{'='*52}") + + passed = success_rate >= PASS_THRESHOLD + if passed: + print(f" ✓ PASS (threshold {PASS_THRESHOLD*100:.0f}%)") + else: + dominant = max(failure_counts, key=failure_counts.get) + print(f" ✗ FAIL — dominant: {dominant}") + if dominant == "NEVER_COMPACT": + print(" Dog can't compact flock. Check W_COLLECT, obs contains straggler positions?") + elif dominant == "COMPACT_CANT_DRIVE": + print(" Flock compacts but dog doesn't drive to pen. Check alignment reward / W_DRIVE.") + elif dominant.startswith("PARTIAL"): + print(" Flock splits near pen. Dog loses stragglers at the end.") + print() + return passed + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--steps", type=int, default=500_000, + help="Steps per smoke-test stage (default 500k)") + p.add_argument("--n-envs", type=int, default=4) + p.add_argument("--episodes", type=int, default=30, + help="Validation episodes per stage") + p.add_argument("--render", action="store_true") + args = p.parse_args() + + stages = [(1, args.steps), (3, args.steps)] + + model, vn = None, None + all_passed = True + + for n_sheep, steps in stages: + print(f"\n{'#'*52}") + print(f"# Smoke-test stage: n_sheep={n_sheep}, {steps:,} steps") + print(f"{'#'*52}") + + model, vn = train_stage(n_sheep, steps, args.n_envs, model, vn) + + eval_env = make_eval_env(model, vn, n_sheep) + success_rate, failure_counts = run_episodes( + model, eval_env, args.episodes, render=args.render + ) + eval_env.close() + + passed = report(n_sheep, success_rate, failure_counts, args.episodes) + if not passed: + all_passed = False + print(" Aborting smoke test — fix the issue above before full training.") + sys.exit(1) + + if all_passed: + print("\n All smoke-test stages passed.") + print(" Ready for full curriculum training:") + print() + print(" python train.py --curriculum --steps-per-stage 1500000 \\") + print(" --total-steps 15000000 --n-sheep 1 --max-sheep 10 \\") + print(" --n-envs 8 --run-dir runs/ppo_v2") + print() + + +if __name__ == "__main__": + main() diff --git a/training/train.py b/training/train.py index bd52050..9b6fd29 100644 --- a/training/train.py +++ b/training/train.py @@ -19,6 +19,7 @@ Usage examples import argparse import os +from copy import deepcopy import numpy as np from stable_baselines3 import PPO @@ -28,10 +29,25 @@ from stable_baselines3.common.callbacks import ( CheckpointCallback, EvalCallback, ) -from stable_baselines3.common.vec_env import SubprocVecEnv, VecNormalize +from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecNormalize from herding_env import HerdingEnv +COMPACT_RADIUS = HerdingEnv.DRIVE_GATE_RADIUS + + +def _classify(ep_radius, ep_com_dist, n_penned, n_sheep, success): + if success: + return "SUCCESS" + if min(ep_radius) > COMPACT_RADIUS: + return "NEVER_COMPACT" + first = next(i for i, r in enumerate(ep_radius) if r <= COMPACT_RADIUS) + if min(ep_com_dist[first:]) > 3.0: + return "COMPACT_CANT_DRIVE" + if n_penned == 0: + return "DROVE_NO_SHEEP" + return f"PARTIAL_{n_penned}of{n_sheep}" + # --------------------------------------------------------------------------- # Curriculum callback @@ -101,6 +117,96 @@ class CurriculumCallback(BaseCallback): return True +# --------------------------------------------------------------------------- +# Diagnostic callback — failure-mode breakdown every diag_freq steps +# --------------------------------------------------------------------------- + +class DiagnosticCallback(BaseCallback): + """ + Every diag_freq env steps: spin up a temporary eval env, run n_episodes + deterministic episodes, and print a failure-mode breakdown. + Aborts training (returns False) if the dominant failure mode hasn't + changed after two consecutive checks at the same n_sheep — a sign that + training has stalled and further steps are wasted. + """ + + def __init__(self, diag_freq: int = 500_000, n_episodes: int = 20, + max_steps: int = 2000, verbose: int = 1): + super().__init__(verbose) + self.diag_freq = diag_freq + self.n_episodes = n_episodes + self.max_steps = max_steps + self._last_diag = 0 + self._prev_dominant = None # (n_sheep, mode) from last check + self._stall_count = 0 + + def _on_step(self) -> bool: + if self.num_timesteps - self._last_diag < self.diag_freq: + return True + self._last_diag = self.num_timesteps + + n_sheep = self.training_env.get_attr("n_sheep")[0] + + # Build a temporary single-env with copied VecNorm stats + raw = DummyVecEnv([lambda: HerdingEnv(n_sheep=n_sheep, + max_steps=self.max_steps)]) + vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False) + vn.obs_rms = deepcopy(self.training_env.obs_rms) + vn.ret_rms = deepcopy(self.training_env.ret_rms) + + failure_counts = {} + successes = 0 + + for _ in range(self.n_episodes): + obs = vn.reset() + done = False + ep_radius, ep_com_dist = [], [] + n_penned = 0 + + while not done: + action, _ = self.model.predict(obs, deterministic=True) + obs, _, dones, infos = vn.step(action) + done = dones[0] + inner = vn.envs[0] + com, radius, _ = inner._flock_stats() + ep_radius.append(radius) + ep_com_dist.append( + float(np.linalg.norm(com - inner.PEN_CENTER)) + ) + + n_penned = infos[0].get("n_penned", 0) + success = n_penned == n_sheep + successes += int(success) + mode = _classify(ep_radius, ep_com_dist, n_penned, n_sheep, success) + failure_counts[mode] = failure_counts.get(mode, 0) + 1 + + vn.close() + + success_rate = successes / self.n_episodes + dominant = max(failure_counts, key=failure_counts.get) + + if self.verbose: + print(f"\n[Diag @ {self.num_timesteps:,} | n_sheep={n_sheep} | " + f"success={success_rate*100:.0f}%]") + for m, c in sorted(failure_counts.items(), key=lambda x: -x[1]): + print(f" {m:<26} {c}/{self.n_episodes}") + + # Stall detection: same dominant failure at same n_sheep twice in a row + key = (n_sheep, dominant) + if key == self._prev_dominant and dominant != "SUCCESS": + self._stall_count += 1 + if self._stall_count >= 2: + print(f"\n[Diag] STALL DETECTED — '{dominant}' on {n_sheep} sheep " + f"for {self._stall_count} consecutive checks. " + f"Aborting training early.") + return False + else: + self._stall_count = 0 + self._prev_dominant = key + + return True + + # --------------------------------------------------------------------------- # Environment factory # --------------------------------------------------------------------------- @@ -141,6 +247,8 @@ def parse_args(): p.add_argument("--save-freq", type=int, default=100_000) p.add_argument("--eval-freq", type=int, default=50_000) p.add_argument("--eval-eps", type=int, default=20) + p.add_argument("--diag-freq", type=int, default=500_000, + help="Run failure-mode diagnostics every N env steps") p.add_argument("--mixed", action="store_true", help="Randomise n_sheep each episode (consolidation pass, " "use with --resume after curriculum training)") @@ -193,7 +301,12 @@ def main(): deterministic=True, verbose=1, ) - callbacks = [checkpoint_cb, eval_cb] + diag_cb = DiagnosticCallback( + diag_freq=max(args.diag_freq // args.n_envs, 1), + n_episodes=20, + max_steps=args.max_steps, + ) + callbacks = [checkpoint_cb, eval_cb, diag_cb] if args.curriculum: cur_cb = CurriculumCallback(