r"""
PPO training script for the herding task.

Usage examples
--------------
# Proper 5-sheep curriculum, 1 M steps per stage:
python train.py --curriculum --steps-per-stage 1000000 --total-steps 5000000

# Success-rate curriculum (advances when 70 % success over 100 episodes):
python train.py --curriculum --threshold 0.70

# Resume from checkpoint at stage 3. Checkpoints live in <run-dir>/checkpoints/
# and the matching VecNormalize stats next to them are loaded automatically.
# Note: when resuming, --total-steps is trained *in addition to* the
# checkpoint's step counter (stable-baselines3 reset_num_timesteps=False).
python train.py --resume runs/ppo_herding/checkpoints/ckpt_3000000_steps.zip --n-sheep 3 \
    --curriculum --steps-per-stage 1000000 --total-steps 5000000

# Quick smoke-test:
python train.py --n-envs 1 --total-steps 50000
"""
import argparse
import os
import re
from collections import Counter, deque
from copy import deepcopy

import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import (
    BaseCallback,
    CallbackList,
    CheckpointCallback,
    EvalCallback,
)
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecNormalize

from herding_env import HerdingEnv

# Flock radius (as reported by HerdingEnv._flock_stats) at or below which the
# flock is considered "compact" for failure-mode diagnostics.
COMPACT_RADIUS = 5.0
# Distance between the flock centre of mass and HerdingEnv.PEN_CENTER below
# which the flock counts as having been driven to the pen (diagnostics only).
PEN_REACH_DIST = 3.0


def _classify(ep_radius, ep_com_dist, n_penned, n_sheep, success):
    """Label one diagnostic episode with a failure mode.

    Args:
        ep_radius: per-step flock radius samples for the episode.
        ep_com_dist: per-step distance from flock centre of mass to the pen,
            aligned index-for-index with ``ep_radius``.
        n_penned: number of sheep penned at episode end.
        n_sheep: number of sheep in the episode.
        success: whether every sheep was penned.

    Returns:
        "SUCCESS", "NEVER_COMPACT", "COMPACT_CANT_DRIVE", "DROVE_NO_SHEEP",
        or "PARTIAL_<k>of<n>".
    """
    if success:
        return "SUCCESS"
    # Index of the first sample where the flock was compact (None if never;
    # also covers an empty sample list instead of crashing in min()).
    first = next((i for i, r in enumerate(ep_radius) if r <= COMPACT_RADIUS), None)
    if first is None:
        return "NEVER_COMPACT"
    if min(ep_com_dist[first:]) > PEN_REACH_DIST:
        return "COMPACT_CANT_DRIVE"
    if n_penned == 0:
        return "DROVE_NO_SHEEP"
    return f"PARTIAL_{n_penned}of{n_sheep}"


# ---------------------------------------------------------------------------
# Curriculum callback
# ---------------------------------------------------------------------------

class CurriculumCallback(BaseCallback):
    """
    Advances n_sheep on both training and eval envs.

    Two modes (mutually exclusive):
      steps_per_stage — advance every N environment steps regardless of
                        success rate (recommended for reliability).
      threshold       — advance when rolling success rate exceeds this value
                        (requires the policy to actually reach the threshold).

    Stage timing is measured from the step count at which ``learn()`` starts,
    so resuming from a checkpoint (non-zero num_timesteps) does not trigger an
    immediate advance.
    """

    def __init__(self, start_sheep: int, max_sheep: int, eval_env=None,
                 steps_per_stage: int = None, threshold: float = 0.75,
                 window: int = 100, min_episodes: int = 50, verbose: int = 1):
        super().__init__(verbose)
        self.max_sheep = max_sheep
        self.eval_env = eval_env
        self.steps_per_stage = steps_per_stage
        self.threshold = threshold
        self.window = window
        self.min_episodes = min_episodes
        self._cur_sheep = start_sheep
        # Rolling window of episode outcomes (1 = success, 0 = failure).
        self._successes = deque(maxlen=window)
        # num_timesteps at which the current stage began.
        self._stage_start = 0

    def _on_training_start(self) -> None:
        # When resuming with reset_num_timesteps=False, num_timesteps starts at
        # the checkpoint's count; anchor the current stage there, not at 0.
        self._stage_start = self.num_timesteps

    def _advance(self):
        """Move to the next sheep count on train + eval envs and reset stats."""
        self._cur_sheep += 1
        self.training_env.env_method("set_n_sheep", self._cur_sheep)
        if self.eval_env is not None:
            self.eval_env.env_method("set_n_sheep", self._cur_sheep)
        self._stage_start = self.num_timesteps
        self._successes.clear()
        if self.verbose:
            print(f"\n[Curriculum] → {self._cur_sheep} sheep "
                  f"at step {self.num_timesteps:,}\n")

    def _on_step(self) -> bool:
        if self._cur_sheep >= self.max_sheep:
            return True

        if self.steps_per_stage is not None:
            # Time-based: advance every steps_per_stage env steps
            if self.num_timesteps - self._stage_start >= self.steps_per_stage:
                self._advance()
            return True

        # Success-rate based.
        # NOTE(review): any episode that ends without hitting the time limit is
        # counted as a success — confirm HerdingEnv only terminates on success.
        for info, done in zip(self.locals["infos"], self.locals["dones"]):
            if done:
                truncated = info.get("TimeLimit.truncated", False)
                self._successes.append(0 if truncated else 1)
        if (len(self._successes) >= self.min_episodes
                and np.mean(self._successes) >= self.threshold):
            self._advance()
        return True


# ---------------------------------------------------------------------------
# Diagnostic callback — failure-mode breakdown every diag_freq steps
# ---------------------------------------------------------------------------

class DiagnosticCallback(BaseCallback):
    """
    Every diag_freq env steps: spin up a temporary eval env, run n_episodes
    deterministic episodes, and print a failure-mode breakdown.

    ``diag_freq`` is compared against ``num_timesteps`` (total env steps summed
    over all parallel envs), so it must NOT be divided by the number of envs.

    Aborts training (returns False) when the same non-SUCCESS dominant failure
    mode is seen at the same n_sheep on three consecutive checks (two repeats)
    — a sign that training has stalled and further steps are wasted.
    """

    def __init__(self, diag_freq: int = 500_000, n_episodes: int = 20,
                 max_steps: int = 2000, verbose: int = 1):
        super().__init__(verbose)
        if n_episodes < 1:
            raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")
        self.diag_freq = diag_freq
        self.n_episodes = n_episodes
        self.max_steps = max_steps
        self._last_diag = 0
        self._prev_dominant = None   # (n_sheep, mode) from last check
        self._stall_count = 0

    def _on_training_start(self) -> None:
        # Measure the first interval from where learn() starts (non-zero on
        # resume), rather than firing a diagnostic on the very first step.
        self._last_diag = self.num_timesteps

    def _make_eval_env(self, n_sheep):
        """Build a single-env VecEnv for diagnostics.

        Returns (vec_env, inner_env). When training uses VecNormalize, vec_env
        is wrapped in a frozen (training=False) VecNormalize carrying a copy of
        the training observation/return statistics.
        """
        raw = DummyVecEnv([lambda: HerdingEnv(n_sheep=n_sheep, max_steps=self.max_steps)])
        inner = raw.envs[0]
        train_norm = self.model.get_vec_normalize_env()
        if train_norm is None:
            return raw, inner
        vn = VecNormalize(raw, norm_obs=True, norm_reward=False,
                          clip_obs=train_norm.clip_obs, training=False)
        vn.obs_rms = deepcopy(train_norm.obs_rms)
        vn.ret_rms = deepcopy(train_norm.ret_rms)
        return vn, inner

    def _run_episode(self, venv, inner):
        """Run one deterministic episode; return (ep_radius, ep_com_dist, n_penned).

        Flock stats are sampled at reset and after every non-terminal step.
        The VecEnv auto-resets when an episode ends, so sampling after the
        terminal step would record the *next* episode's initial state.
        """
        ep_radius, ep_com_dist = [], []

        def record():
            com, radius, _ = inner._flock_stats()
            ep_radius.append(radius)
            ep_com_dist.append(float(np.linalg.norm(com - inner.PEN_CENTER)))

        obs = venv.reset()
        record()
        while True:
            action, _ = self.model.predict(obs, deterministic=True)
            obs, _, dones, infos = venv.step(action)
            if dones[0]:
                # infos[0] is the terminal step's info (not the reset's).
                return ep_radius, ep_com_dist, infos[0].get("n_penned", 0)
            record()

    def _on_step(self) -> bool:
        if self.num_timesteps - self._last_diag < self.diag_freq:
            return True
        self._last_diag = self.num_timesteps

        n_sheep = self.training_env.get_attr("n_sheep")[0]
        venv, inner = self._make_eval_env(n_sheep)

        failure_counts = Counter()
        successes = 0
        try:
            for _ in range(self.n_episodes):
                ep_radius, ep_com_dist, n_penned = self._run_episode(venv, inner)
                success = n_penned == n_sheep
                successes += int(success)
                failure_counts[_classify(ep_radius, ep_com_dist,
                                         n_penned, n_sheep, success)] += 1
        finally:
            venv.close()

        success_rate = successes / self.n_episodes
        ranked = failure_counts.most_common()
        dominant = ranked[0][0]

        if self.verbose:
            print(f"\n[Diag @ {self.num_timesteps:,} | n_sheep={n_sheep} | "
                  f"success={success_rate*100:.0f}%]")
            for m, c in ranked:
                print(f"  {m:<26} {c}/{self.n_episodes}")

        # Stall detection: same dominant failure at same n_sheep on three
        # consecutive checks (the current one plus two repeats).
        key = (n_sheep, dominant)
        if key == self._prev_dominant and dominant != "SUCCESS":
            self._stall_count += 1
            if self._stall_count >= 2:
                print(f"\n[Diag] STALL DETECTED — '{dominant}' on {n_sheep} sheep "
                      f"for {self._stall_count + 1} consecutive checks. "
                      f"Aborting training early.")
                return False
        else:
            self._stall_count = 0
        self._prev_dominant = key
        return True


# ---------------------------------------------------------------------------
# Environment factory
# ---------------------------------------------------------------------------

def make_env(n_sheep: int, seed: int, max_steps: int, random_n_sheep: bool = False):
    """Return a thunk that builds and seeds one HerdingEnv (for VecEnv use)."""
    def _init():
        env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps,
                         random_n_sheep=random_n_sheep)
        env.reset(seed=seed)
        return env
    return _init


def _vecnorm_path_for(ckpt_path: str, fallback: str):
    """Return the VecNormalize stats file to load when resuming, or None.

    CheckpointCallback(save_vecnormalize=True) writes
    ``<prefix>_vecnormalize_<N>_steps.pkl`` next to ``<prefix>_<N>_steps.zip``;
    that file matches the checkpointed policy, so prefer it. Otherwise fall
    back to ``fallback`` (the end-of-run stats file) if it exists.
    """
    directory, fname = os.path.split(ckpt_path)
    m = re.fullmatch(r"(.+)_(\d+)_steps(?:\.zip)?", fname)
    if m:
        candidate = os.path.join(
            directory, f"{m.group(1)}_vecnormalize_{m.group(2)}_steps.pkl")
        if os.path.exists(candidate):
            return candidate
    return fallback if os.path.exists(fallback) else None


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def parse_args():
    """Parse command-line options for a training run."""
    p = argparse.ArgumentParser()
    p.add_argument("--n-sheep", type=int, default=1,
                   help="Starting sheep count")
    p.add_argument("--max-sheep", type=int, default=5,
                   help="Final sheep count for curriculum")
    p.add_argument("--n-envs", type=int, default=8,
                   help="Parallel training environments")
    p.add_argument("--total-steps", type=int, default=5_000_000,
                   help="Env steps to train in this run (added to the step "
                        "counter of a --resume checkpoint)")
    p.add_argument("--max-steps", type=int, default=2000,
                   help="Episode step limit")
    p.add_argument("--curriculum", action="store_true",
                   help="Enable curriculum advancement")
    p.add_argument("--steps-per-stage", type=int, default=None,
                   help="Advance curriculum every N steps (overrides --threshold)")
    p.add_argument("--threshold", type=float, default=0.75,
                   help="Success-rate threshold to advance (used without --steps-per-stage)")
    p.add_argument("--resume", type=str, default=None,
                   help="Checkpoint .zip to resume from")
    p.add_argument("--run-dir", type=str, default="runs/ppo_herding")
    p.add_argument("--save-freq", type=int, default=100_000)
    p.add_argument("--eval-freq", type=int, default=50_000)
    p.add_argument("--eval-eps", type=int, default=20)
    p.add_argument("--diag-freq", type=int, default=500_000,
                   help="Run failure-mode diagnostics every N env steps")
    p.add_argument("--mixed", action="store_true",
                   help="Randomise n_sheep each episode (consolidation pass, "
                        "use with --resume after curriculum training)")
    return p.parse_args()


def main():
    """Build envs and callbacks, then train (or resume) PPO and save artefacts."""
    args = parse_args()
    os.makedirs(args.run_dir, exist_ok=True)
    ckpt_dir = os.path.join(args.run_dir, "checkpoints")
    best_dir = os.path.join(args.run_dir, "best_model")
    norm_path = os.path.join(args.run_dir, "vecnorm.pkl")
    os.makedirs(ckpt_dir, exist_ok=True)

    # Training envs
    train_env = SubprocVecEnv([
        make_env(args.n_sheep, seed=i, max_steps=args.max_steps,
                 random_n_sheep=args.mixed)
        for i in range(args.n_envs)
    ])
    vecnorm_src = _vecnorm_path_for(args.resume, norm_path) if args.resume else None
    if vecnorm_src is not None:
        print(f"Loading VecNormalize stats from {vecnorm_src}")
        train_env = VecNormalize.load(vecnorm_src, train_env)
        train_env.training = True
        train_env.norm_reward = True
    else:
        if args.resume:
            print("WARNING: no VecNormalize stats found for the resume "
                  "checkpoint; starting with fresh normalization statistics.")
        train_env = VecNormalize(train_env, norm_obs=True, norm_reward=True,
                                 clip_obs=10.0)

    # Eval env — starts at same difficulty, advances with curriculum callback.
    # EvalCallback syncs its normalization stats from train_env before each eval.
    eval_env = SubprocVecEnv([
        make_env(args.n_sheep, seed=1000 + i, max_steps=args.max_steps)
        for i in range(2)
    ])
    eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False,
                            clip_obs=10.0, training=False)

    try:
        # Callbacks. Checkpoint/Eval frequencies count callback calls (one per
        # vectorised step), hence the division by n_envs.
        checkpoint_cb = CheckpointCallback(
            save_freq=max(args.save_freq // args.n_envs, 1),
            save_path=ckpt_dir,
            name_prefix="ckpt",
            save_vecnormalize=True,
        )
        eval_cb = EvalCallback(
            eval_env,
            best_model_save_path=best_dir,
            log_path=args.run_dir,
            eval_freq=max(args.eval_freq // args.n_envs, 1),
            n_eval_episodes=args.eval_eps,
            deterministic=True,
            verbose=1,
        )
        # DiagnosticCallback compares against num_timesteps (total env steps),
        # so its frequency is passed through unscaled.
        diag_cb = DiagnosticCallback(
            diag_freq=max(args.diag_freq, 1),
            n_episodes=20,
            max_steps=args.max_steps,
        )
        callbacks = [checkpoint_cb, eval_cb, diag_cb]

        if args.curriculum:
            cur_cb = CurriculumCallback(
                start_sheep=args.n_sheep,
                max_sheep=args.max_sheep,
                eval_env=eval_env,
                steps_per_stage=args.steps_per_stage,
                threshold=args.threshold,
            )
            callbacks.append(cur_cb)

        callback_list = CallbackList(callbacks)

        # Model
        ppo_kwargs = dict(
            policy          = "MlpPolicy",
            env             = train_env,
            learning_rate   = 3e-4,
            n_steps         = 2048,
            batch_size      = 256,
            n_epochs        = 10,
            gamma           = 0.995,
            gae_lambda      = 0.95,
            clip_range      = 0.2,
            ent_coef        = 0.02,
            vf_coef         = 0.5,
            max_grad_norm   = 0.5,
            policy_kwargs   = dict(net_arch=[256, 256]),
            tensorboard_log = args.run_dir,
            verbose         = 1,
        )

        if args.resume:
            print(f"Resuming from {args.resume}")
            # policy/env are positional to load(); policy_kwargs is restored from
            # the checkpoint (passing a differing value makes load() raise).
            model = PPO.load(args.resume, env=train_env, **{
                k: v for k, v in ppo_kwargs.items()
                if k not in ("policy", "env", "policy_kwargs")
            })
        else:
            model = PPO(**ppo_kwargs)

        model.learn(
            total_timesteps=args.total_steps,
            callback=callback_list,
            reset_num_timesteps=args.resume is None,
            tb_log_name="ppo",
        )

        model.save(os.path.join(args.run_dir, "final_model"))
        train_env.save(norm_path)
        print(f"\nTraining complete. Artefacts saved to {args.run_dir}/")
    finally:
        # Shut down SubprocVecEnv worker processes even if training fails.
        train_env.close()
        eval_env.close()


if __name__ == "__main__":
    main()