""" PPO training script for the herding task. Usage examples -------------- # Start fresh with curriculum (1 → 5 sheep): python train.py --curriculum # Resume from checkpoint, skip directly to 3 sheep: python train.py --resume runs/ppo_herding/ckpt_200000_steps.zip --n-sheep 3 # Quick smoke-test (no curriculum, single env): python train.py --n-envs 1 --total-steps 50000 """ import argparse import os import numpy as np from stable_baselines3 import PPO from stable_baselines3.common.callbacks import ( BaseCallback, CallbackList, CheckpointCallback, EvalCallback, ) from stable_baselines3.common.vec_env import SubprocVecEnv, VecNormalize from herding_env import HerdingEnv # --------------------------------------------------------------------------- # Curriculum callback # --------------------------------------------------------------------------- class CurriculumCallback(BaseCallback): """ Advances the curriculum (number of active sheep) when the rolling mean episode success rate exceeds a threshold. Success = episode terminated (all sheep penned) rather than truncated. """ THRESHOLD = 0.75 # success rate to graduate WINDOW = 100 # episodes to average over MIN_EPISODES = 50 # don't graduate before seeing this many episodes def __init__(self, start_sheep: int, max_sheep: int, verbose: int = 1): super().__init__(verbose) self.max_sheep = max_sheep self._successes = [] self._cur_sheep = start_sheep def _on_step(self) -> bool: for info, done in zip(self.locals["infos"], self.locals["dones"]): if done: truncated = info.get("TimeLimit.truncated", False) self._successes.append(0 if truncated else 1) if len(self._successes) > self.WINDOW: self._successes.pop(0) if (self._cur_sheep < self.max_sheep and len(self._successes) >= self.MIN_EPISODES and np.mean(self._successes) >= self.THRESHOLD): self._cur_sheep += 1 self.training_env.env_method("set_n_sheep", self._cur_sheep) self._successes.clear() if self.verbose: print(f"\n[Curriculum] Advanced to {self._cur_sheep} sheep " f"at step {self.num_timesteps}\n") return True # --------------------------------------------------------------------------- # Environment factory # --------------------------------------------------------------------------- def make_env(n_sheep: int, seed: int, max_steps: int): def _init(): env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps) env.reset(seed=seed) return env return _init # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def parse_args(): p = argparse.ArgumentParser() p.add_argument("--n-sheep", type=int, default=1, help="Starting number of sheep (or fixed count if no curriculum)") p.add_argument("--max-sheep", type=int, default=5, help="Maximum sheep for curriculum (ignored without --curriculum)") p.add_argument("--n-envs", type=int, default=8, help="Number of parallel environments") p.add_argument("--total-steps", type=int, default=5_000_000, help="Total environment steps to train for") p.add_argument("--max-steps", type=int, default=2000, help="Episode step limit inside each env") p.add_argument("--curriculum", action="store_true", help="Enable automatic curriculum advancement") p.add_argument("--resume", type=str, default=None, help="Path to a .zip checkpoint to resume training from") p.add_argument("--run-dir", type=str, default="runs/ppo_herding", help="Output directory for checkpoints and logs") p.add_argument("--save-freq", type=int, default=100_000, help="Checkpoint every N steps (per-env, not total)") p.add_argument("--eval-freq", type=int, default=50_000, help="Evaluate every N steps") p.add_argument("--eval-eps", type=int, default=20, help="Episodes per evaluation run") return p.parse_args() def main(): args = parse_args() os.makedirs(args.run_dir, exist_ok=True) ckpt_dir = os.path.join(args.run_dir, "checkpoints") best_dir = os.path.join(args.run_dir, "best_model") norm_path = os.path.join(args.run_dir, "vecnorm.pkl") os.makedirs(ckpt_dir, exist_ok=True) # Training envs train_env = SubprocVecEnv([ make_env(args.n_sheep, seed=i, max_steps=args.max_steps) for i in range(args.n_envs) ]) if args.resume and os.path.exists(norm_path): train_env = VecNormalize.load(norm_path, train_env) train_env.training = True train_env.norm_reward = True else: train_env = VecNormalize(train_env, norm_obs=True, norm_reward=True, clip_obs=10.0) # Eval env (no reward normalisation, deterministic) eval_env = SubprocVecEnv([ make_env(args.n_sheep, seed=1000 + i, max_steps=args.max_steps) for i in range(2) ]) eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False, clip_obs=10.0, training=False) # Callbacks checkpoint_cb = CheckpointCallback( save_freq=max(args.save_freq // args.n_envs, 1), save_path=ckpt_dir, name_prefix="ckpt", save_vecnormalize=True, ) eval_cb = EvalCallback( eval_env, best_model_save_path=best_dir, log_path=args.run_dir, eval_freq=max(args.eval_freq // args.n_envs, 1), n_eval_episodes=args.eval_eps, deterministic=True, verbose=1, ) callbacks = [checkpoint_cb, eval_cb] if args.curriculum: callbacks.append(CurriculumCallback(start_sheep=args.n_sheep, max_sheep=args.max_sheep)) callback_list = CallbackList(callbacks) # Model ppo_kwargs = dict( policy = "MlpPolicy", env = train_env, learning_rate = 3e-4, n_steps = 2048, batch_size = 256, n_epochs = 10, gamma = 0.995, gae_lambda = 0.95, clip_range = 0.2, ent_coef = 0.005, vf_coef = 0.5, max_grad_norm = 0.5, policy_kwargs = dict(net_arch=[256, 256]), tensorboard_log = args.run_dir, verbose = 1, ) if args.resume: print(f"Resuming from {args.resume}") model = PPO.load(args.resume, env=train_env, **{ k: v for k, v in ppo_kwargs.items() if k not in ("policy", "env") }) else: model = PPO(**ppo_kwargs) model.learn( total_timesteps=args.total_steps, callback=callback_list, reset_num_timesteps=args.resume is None, tb_log_name="ppo", ) # Save final artefacts model.save(os.path.join(args.run_dir, "final_model")) train_env.save(norm_path) print(f"\nTraining complete. Artefacts saved to {args.run_dir}/") if __name__ == "__main__": main()