"""PPO trainer for the shepherd-dog policy — EXPERIMENTAL. The deliverable pipeline is `bc_pretrain.py` (see ``training/README.md``). This script is kept in the tree because it implements: * PPO from scratch with curriculum over flock size + spawn area, and * PPO fine-tune of a behavior-cloned policy. Both ran into stability issues in our setting (long-horizon credit assignment for sparse pen reward, BC-degradation under PPO exploration noise). The abstractions are reusable for follow-up work — e.g. KL-regularised fine-tune with a frozen reference policy — so we leave the code in place. Usage (PPO from scratch):: python -m training.train_ppo \ --config training/configs/ppo_default.yaml \ --out-dir training/runs/ppo_scratch Usage (PPO fine-tune of BC):: python -m training.train_ppo \ --resume training/runs/bc_flock/policy.zip \ --out-dir training/runs/bc_ppo \ --no-vecnorm --no-curriculum --imitate-weight 0 \ --difficulty 1.0 --log-std -1.5 --learning-rate 5e-5 \ --total-timesteps 3000000 """ from __future__ import annotations import argparse import os import sys from pathlib import Path import yaml _HERE = os.path.dirname(os.path.abspath(__file__)) _PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, "..")) if _PROJECT_ROOT not in sys.path: sys.path.insert(0, _PROJECT_ROOT) import numpy as np import torch as th from stable_baselines3 import PPO from stable_baselines3.common.callbacks import ( BaseCallback, CheckpointCallback, EvalCallback, ) from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.vec_env import ( DummyVecEnv, SubprocVecEnv, VecNormalize, ) from training.herding_env import HerdingEnv # -------------------------------------------------------------------------- # Env factories # -------------------------------------------------------------------------- def _make_env(rank: int, seed: int = 0): def _thunk(): env = HerdingEnv(seed=seed + rank) env = Monitor(env, info_keywords=("is_success", "n_sheep", "n_penned")) return env return _thunk # -------------------------------------------------------------------------- # Curriculum callback # -------------------------------------------------------------------------- class CurriculumCallback(BaseCallback): """Drive the env's flock-size + state-space difficulty curriculum. Schedule entries: {step, max_n_sheep, difficulty}. The largest entry whose step <= num_timesteps wins; both knobs update together. """ def __init__(self, schedule, vec_envs, verbose: int = 0): super().__init__(verbose) self.schedule = sorted(schedule, key=lambda d: d["step"]) # Accept a list of envs so the eval env tracks training difficulty. self.vec_envs = vec_envs if isinstance(vec_envs, (list, tuple)) else [vec_envs] self._last_n = None self._last_d = None def _call(self, method, value): for v in self.vec_envs: try: v.env_method(method, value) except AttributeError: v.venv.env_method(method, value) def _on_step(self) -> bool: t = self.num_timesteps n = self.schedule[0]["max_n_sheep"] d = self.schedule[0].get("difficulty", 1.0) for entry in self.schedule: if t >= entry["step"]: n = entry["max_n_sheep"] d = entry.get("difficulty", 1.0) if n != self._last_n: self._call("set_max_n_sheep", n) self._last_n = n if d != self._last_d: self._call("set_difficulty", d) self._last_d = d if self.verbose: print(f"[curriculum] t={t} → max_n_sheep={n} difficulty={d}") return True # -------------------------------------------------------------------------- # Main # -------------------------------------------------------------------------- def main(): parser = argparse.ArgumentParser() parser.add_argument("--config", default=os.path.join(_HERE, "configs", "ppo_default.yaml")) parser.add_argument("--out-dir", default=os.path.join(_HERE, "runs", "latest")) parser.add_argument("--n-envs", type=int, default=None, help="Override config n_envs.") parser.add_argument("--total-timesteps", type=int, default=None, help="Override config total_timesteps.") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--resume", type=str, default=None, help="Path to a SB3 zip to resume from.") # SB3 recommends CPU for MlpPolicy — GPU helps CNN policies, not MLPs # of this size. Override with --device cuda if you really want it. parser.add_argument("--device", default="cpu") parser.add_argument("--no-vecnorm", action="store_true", help="Disable VecNormalize wrapper. Required when " "resuming from a BC-pretrained policy that " "wasn't trained under it.") parser.add_argument("--no-curriculum", action="store_true", help="Skip curriculum callback (resumed policy is " "already competent across the distribution).") parser.add_argument("--imitate-weight", type=float, default=None, help="Override env W_IMITATE. Set to 0 to disable " "Strömbom imitation reward.") parser.add_argument("--difficulty", type=float, default=None, help="Override env difficulty (0=easy, 1=hard). " "Used in BC fine-tune to skip easy curriculum.") parser.add_argument("--log-std", type=float, default=None, help="Override the policy's log_std after load. " "BC trained with std≈1.6 (log_std=0.5) which " "is too noisy for fine-tune. Use -1.5 (std≈0.22) " "to keep PPO close to the BC mean while still " "exploring locally.") parser.add_argument("--learning-rate", type=float, default=None, help="Override config learning rate. For BC " "fine-tune, 5e-5 is much safer than the 3e-4 " "default.") args = parser.parse_args() with open(args.config) as f: cfg = yaml.safe_load(f) n_envs = args.n_envs or cfg["n_envs"] total_timesteps = args.total_timesteps or cfg["total_timesteps"] out = Path(args.out_dir) out.mkdir(parents=True, exist_ok=True) (out / "checkpoints").mkdir(exist_ok=True) (out / "best").mkdir(exist_ok=True) (out / "evals").mkdir(exist_ok=True) print(f"[train] out={out} n_envs={n_envs} total={total_timesteps} device={args.device}") # --- Train env (vectorised, optionally normalised) --- env_fns = [_make_env(i, seed=args.seed) for i in range(n_envs)] venv = SubprocVecEnv(env_fns) if n_envs > 1 else DummyVecEnv(env_fns) eval_venv = DummyVecEnv([_make_env(99, seed=args.seed + 999)]) if not args.no_vecnorm: venv = VecNormalize(venv, norm_obs=True, norm_reward=False, clip_obs=10.0) eval_venv = VecNormalize(eval_venv, norm_obs=True, norm_reward=False, clip_obs=10.0, training=False) eval_venv.obs_rms = venv.obs_rms else: print("[train] VecNormalize disabled (resumed policy was trained without it).") # Apply env-level overrides (used by BC fine-tune to disable Strömbom # imitation and start at full deployment difficulty). def _env_call(method, value): for v in (venv, eval_venv): try: v.env_method(method, value) except AttributeError: v.venv.env_method(method, value) if args.imitate_weight is not None: _env_call("set_imitate_weight", args.imitate_weight) print(f"[train] W_IMITATE overridden to {args.imitate_weight}") if args.difficulty is not None: _env_call("set_difficulty", args.difficulty) print(f"[train] difficulty pinned to {args.difficulty}") # --- Model --- policy_kwargs = dict( net_arch=dict(pi=cfg["net_arch_pi"], vf=cfg["net_arch_vf"]), log_std_init=cfg.get("log_std_init", 0.0), ) if args.resume: print(f"[train] resuming from {args.resume}") custom_objects = {} if args.learning_rate is not None: custom_objects["learning_rate"] = args.learning_rate model = PPO.load(args.resume, env=venv, device=args.device, tensorboard_log=str(out / "tb"), custom_objects=custom_objects or None) if args.log_std is not None: import torch as _th with _th.no_grad(): model.policy.log_std.fill_(args.log_std) print(f"[train] log_std overridden to {args.log_std} " f"(std≈{2.71828 ** args.log_std:.2f})") if args.learning_rate is not None: print(f"[train] learning_rate overridden to {args.learning_rate}") else: model = PPO( cfg["policy"], venv, learning_rate=cfg["learning_rate"], n_steps=cfg["n_steps"], batch_size=cfg["batch_size"], n_epochs=cfg["n_epochs"], gamma=cfg["gamma"], gae_lambda=cfg["gae_lambda"], clip_range=cfg["clip_range"], ent_coef=cfg["ent_coef"], vf_coef=cfg["vf_coef"], max_grad_norm=cfg["max_grad_norm"], target_kl=cfg.get("target_kl"), policy_kwargs=policy_kwargs, tensorboard_log=str(out / "tb"), seed=args.seed, device=args.device, verbose=1, ) # --- Callbacks --- ckpt_cb = CheckpointCallback( save_freq=max(1, cfg["checkpoint_freq"] // n_envs), save_path=str(out / "checkpoints"), name_prefix="ppo", save_vecnormalize=True, ) eval_cb = EvalCallback( eval_venv, best_model_save_path=str(out / "best"), log_path=str(out / "evals"), eval_freq=max(1, cfg["eval_freq"] // n_envs), n_eval_episodes=cfg["n_eval_episodes"], deterministic=True, ) callbacks = [ckpt_cb, eval_cb] if not args.no_curriculum and "curriculum" in cfg and cfg["curriculum"]: callbacks.append(CurriculumCallback( cfg["curriculum"], [venv, eval_venv], verbose=1, )) elif args.no_curriculum: print("[train] curriculum disabled — env knobs left at their current values.") # --- Train --- model.learn(total_timesteps=total_timesteps, callback=callbacks, progress_bar=True) # --- Save final model + VecNormalize stats --- model.save(out / "final.zip") venv.save(str(out / "vecnormalize.pkl")) # The EvalCallback already wrote best_model.zip into out/best/ — drop the # VecNormalize stats next to it for the controller to pick up. venv.save(str(out / "best" / "vecnormalize.pkl")) print(f"[train] done. saved to {out}") if __name__ == "__main__": main()