"""KL-regularised PPO fine-tune of a behaviour-cloned policy. The trainable policy is initialised from ``runs/bc/policy.zip``. A frozen copy of the same weights becomes the reference; each PPO loss gets an extra ``β · KL(π ‖ π_ref)`` term so the policy can only move within a trust region around BC. ``log_std`` is fixed small to keep exploration tight. Output: ``runs/rl/policy.zip`` — same SB3 format as the BC checkpoint, loadable by ``HERDING_MODE=rl`` in the dog controller. Usage:: python -m training.rl.train \\ --bc training/runs/bc \\ --out training/runs/rl \\ --total-timesteps 2000000 """ from __future__ import annotations import argparse import os from pathlib import Path # Early CLI pre-parse for --world so geometry is configured before any # herding.* / training.* import binds geometry constants. Matches the # pattern used by training.bc.collect and training.eval. _pre_argv = [a for a in os.sys.argv[1:]] _pre_world = None for i, a in enumerate(_pre_argv): if a == "--world" and i + 1 < len(_pre_argv): _pre_world = _pre_argv[i + 1] break if a.startswith("--world="): _pre_world = a.split("=", 1)[1] break if _pre_world is not None: from herding.world.geometry import configure as _geo_configure _geo_configure(_pre_world) os.environ["HERDING_WORLD"] = _pre_world import numpy as np import torch as th import torch.nn.functional as F from stable_baselines3 import PPO from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv from herding.perception.obs import OBS_DIM from training.herding_env import HerdingEnv # -------------------------------------------------------------------- # Env factory # -------------------------------------------------------------------- def _make_env(rank: int, seed: int, frame_stack: int, drive_mode: str = "differential", difficulty: float = 1.0, max_n_sheep: int = 10): def _thunk(): env = HerdingEnv(seed=seed + rank, frame_stack=frame_stack, drive_mode=drive_mode, difficulty=difficulty, max_n_sheep=max_n_sheep) env = Monitor(env, info_keywords=("is_success", "n_sheep", "n_penned")) return env return _thunk # -------------------------------------------------------------------- # KL-regularised PPO # -------------------------------------------------------------------- class KLPPO(PPO): """PPO with an extra KL-to-reference penalty in the policy loss. Overrides only ``train()``; rollout buffer, clipped surrogate, value loss and entropy bonus are unchanged from stock SB3 PPO. """ def __init__(self, *args, ref_policy=None, kl_coef: float = 0.05, **kwargs): super().__init__(*args, **kwargs) self.ref_policy = ref_policy if self.ref_policy is not None: self.ref_policy.set_training_mode(False) for p in self.ref_policy.parameters(): p.requires_grad = False self.kl_coef = kl_coef def train(self) -> None: # Stock SB3 PPO.train() structure with the KL-to-reference term # added inside the inner minibatch loop. self.policy.set_training_mode(True) self._update_learning_rate(self.policy.optimizer) clip_range = self.clip_range(self._current_progress_remaining) if self.clip_range_vf is not None: clip_range_vf = self.clip_range_vf(self._current_progress_remaining) entropy_losses, pg_losses, value_losses, kl_losses = [], [], [], [] clip_fractions = [] continue_training = True for epoch in range(self.n_epochs): approx_kl_divs = [] for rollout_data in self.rollout_buffer.get(self.batch_size): actions = rollout_data.actions if isinstance(self.action_space, th.distributions.Categorical.__bases__): actions = rollout_data.actions.long().flatten() values, log_prob, entropy = self.policy.evaluate_actions( rollout_data.observations, actions) values = values.flatten() advantages = rollout_data.advantages if self.normalize_advantage and len(advantages) > 1: advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) ratio = th.exp(log_prob - rollout_data.old_log_prob) policy_loss_1 = advantages * ratio policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) policy_loss = -th.min(policy_loss_1, policy_loss_2).mean() pg_losses.append(policy_loss.item()) clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item() clip_fractions.append(clip_fraction) if self.clip_range_vf is None: values_pred = values else: values_pred = rollout_data.old_values + th.clamp( values - rollout_data.old_values, -clip_range_vf, clip_range_vf) value_loss = F.mse_loss(rollout_data.returns, values_pred) value_losses.append(value_loss.item()) if entropy is None: entropy_loss = -th.mean(-log_prob) else: entropy_loss = -th.mean(entropy) entropy_losses.append(entropy_loss.item()) # KL-to-reference: closed-form KL between two diagonal # Gaussians, summed over the action axis, mean over batch. if self.ref_policy is None: raise RuntimeError("KLPPO.train called without ref_policy") with th.no_grad(): ref_dist = self.ref_policy.get_distribution(rollout_data.observations) pi_dist = self.policy.get_distribution(rollout_data.observations) kl_div = th.distributions.kl.kl_divergence( pi_dist.distribution, ref_dist.distribution).sum(dim=-1).mean() kl_losses.append(kl_div.item()) loss = (policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + self.kl_coef * kl_div) with th.no_grad(): log_ratio = log_prob - rollout_data.old_log_prob approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy() approx_kl_divs.append(approx_kl_div) if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: continue_training = False if self.verbose >= 1: print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}") break self.policy.optimizer.zero_grad() loss.backward() th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) self.policy.optimizer.step() self._n_updates += 1 if not continue_training: break explained_var = self._explained_variance() self.logger.record("train/entropy_loss", float(np.mean(entropy_losses))) self.logger.record("train/policy_gradient_loss", float(np.mean(pg_losses))) self.logger.record("train/value_loss", float(np.mean(value_losses))) self.logger.record("train/kl_to_reference", float(np.mean(kl_losses))) self.logger.record("train/approx_kl", float(np.mean(approx_kl_divs))) self.logger.record("train/clip_fraction", float(np.mean(clip_fractions))) self.logger.record("train/explained_variance", float(explained_var)) if hasattr(self.policy, "log_std"): self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) def _explained_variance(self) -> float: y_pred = self.rollout_buffer.values.flatten() y_true = self.rollout_buffer.returns.flatten() var_y = np.var(y_true) return float("nan") if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y # -------------------------------------------------------------------- # Main # -------------------------------------------------------------------- def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--bc", default="training/runs/bc", help="Directory containing the BC initialisation.") parser.add_argument("--out", default="training/runs/rl", help="Where to save the fine-tuned policy.") parser.add_argument("--total-timesteps", type=int, default=2_000_000) parser.add_argument("--n-envs", type=int, default=8) parser.add_argument("--learning-rate", type=float, default=5e-5) parser.add_argument("--kl-coef", type=float, default=0.05, help="Coefficient of the KL-to-reference penalty.") parser.add_argument("--log-std", type=float, default=-1.5, help="Initial (and frozen) log_std for exploration.") parser.add_argument("--freeze-log-std", action="store_true", default=True) parser.add_argument("--n-steps", type=int, default=2048) parser.add_argument("--batch-size", type=int, default=256) parser.add_argument("--n-epochs", type=int, default=10) parser.add_argument("--gamma", type=float, default=0.995) parser.add_argument("--gae-lambda", type=float, default=0.95) parser.add_argument("--clip-range", type=float, default=0.1) parser.add_argument("--ent-coef", type=float, default=0.0) parser.add_argument("--target-kl", type=float, default=0.02, help="SB3 per-batch KL early-stop guard.") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--device", default="cpu") parser.add_argument("--drive-mode", default=None, choices=["differential", "mecanum"], help="Drive mode. If not set, inferred from " "BC action dimension (2→differential, 3→mecanum).") parser.add_argument("--imitate-weight", type=float, default=None, help="Override env.W_IMITATE (e.g. 0.0 to drop " "Strömbom imitation during fine-tune).") parser.add_argument("--time-weight", type=float, default=None, help="Override env.W_TIME (e.g. -0.1 for a " "per-step time penalty).") parser.add_argument("--difficulty", type=float, default=1.0, help="HerdingEnv difficulty for PPO rollouts. " "Must match eval (1.0) to avoid train/eval " "distribution mismatch.") parser.add_argument("--max-n-sheep", type=int, default=10, help="Upper bound on flock size sampled each reset.") parser.add_argument("--world", default=None, choices=["field", "field_round"], help="World shape. If not set, uses HERDING_WORLD " "env var or defaults to 'field'.") args = parser.parse_args() # --world was already honoured in the early pre-parse above; here we # just sanity-check that the final argparse view agrees. if args.world is not None: from herding.world.geometry import FIELD_SHAPE as _CURRENT_SHAPE if args.world != _CURRENT_SHAPE: print(f"[rl] WARNING: --world={args.world} but geometry is " f"'{_CURRENT_SHAPE}'. File a bug.") bc_zip = Path(args.bc) / "policy.zip" if not bc_zip.exists(): raise SystemExit( f"BC checkpoint not found at {bc_zip}. Train bc first with " f"`python -m training.bc.pretrain`." ) out = Path(args.out) out.mkdir(parents=True, exist_ok=True) (out / "checkpoints").mkdir(exist_ok=True) (out / "best").mkdir(exist_ok=True) # Infer frame_stack from the BC checkpoint's obs space. ref_only = PPO.load(str(bc_zip), device=args.device) obs_dim = int(ref_only.observation_space.shape[0]) if obs_dim % OBS_DIM != 0: raise SystemExit(f"BC obs dim {obs_dim} is not a multiple of {OBS_DIM}.") frame_stack = obs_dim // OBS_DIM print(f"[rl] BC obs dim {obs_dim} → frame_stack={frame_stack}") # Infer drive mode from BC action dim if not explicitly set. bc_action_dim = int(ref_only.action_space.shape[0]) if args.drive_mode is not None: drive_mode = args.drive_mode elif bc_action_dim == 3: drive_mode = "mecanum" else: drive_mode = "differential" print(f"[rl] drive_mode={drive_mode} (BC action_dim={bc_action_dim})") env_fns = [_make_env(i, args.seed, frame_stack, drive_mode, difficulty=args.difficulty, max_n_sheep=args.max_n_sheep) for i in range(args.n_envs)] venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns) eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, frame_stack, drive_mode, difficulty=args.difficulty, max_n_sheep=args.max_n_sheep)]) print(f"[rl] difficulty={args.difficulty} max_n_sheep={args.max_n_sheep}") # Reward-shaping overrides (broadcast to every env instance). def _broadcast(method: str, value): for v in (venv, eval_venv): try: v.env_method(method, value) except AttributeError: v.venv.env_method(method, value) if args.imitate_weight is not None: _broadcast("set_imitate_weight", args.imitate_weight) print(f"[rl] W_IMITATE overridden to {args.imitate_weight}") if args.time_weight is not None: _broadcast("set_time_weight", args.time_weight) print(f"[rl] W_TIME overridden to {args.time_weight}") # Build a fresh KLPPO at the right obs/action shape, then copy BC # weights into both the trainable policy and the frozen reference. model = KLPPO( "MlpPolicy", venv, ref_policy=None, # filled in below kl_coef=args.kl_coef, learning_rate=args.learning_rate, n_steps=args.n_steps, batch_size=args.batch_size, n_epochs=args.n_epochs, gamma=args.gamma, gae_lambda=args.gae_lambda, clip_range=args.clip_range, ent_coef=args.ent_coef, target_kl=args.target_kl, policy_kwargs=dict( net_arch=dict(pi=[512, 512], vf=[512, 512]), log_std_init=args.log_std, ), verbose=1, seed=args.seed, device=args.device, tensorboard_log=str(out / "tb"), ) # strict=False — the BC value head wasn't trained; PPO trains it. bc_state = ref_only.policy.state_dict() missing, unexpected = model.policy.load_state_dict(bc_state, strict=False) print(f"[rl] BC → policy: missing={len(missing)} unexpected={len(unexpected)}") ref_policy = type(model.policy)( observation_space=model.observation_space, action_space=model.action_space, lr_schedule=lambda _: 0.0, net_arch=dict(pi=[512, 512], vf=[512, 512]), log_std_init=args.log_std, ).to(args.device) ref_policy.load_state_dict(bc_state, strict=False) model.ref_policy = ref_policy model.ref_policy.set_training_mode(False) for p in model.ref_policy.parameters(): p.requires_grad = False # Force both policies to the same log_std so the KL term measures # mean drift only, not a std mismatch carried over from BC. with th.no_grad(): model.policy.log_std.fill_(args.log_std) model.ref_policy.log_std.fill_(args.log_std) if args.freeze_log_std: model.policy.log_std.requires_grad = False print(f"[rl] log_std frozen at {args.log_std} (σ ≈ {np.exp(args.log_std):.3f})") ckpt_cb = CheckpointCallback( save_freq=max(1, 50_000 // args.n_envs), save_path=str(out / "checkpoints"), name_prefix="ppo", ) # EvalCallback writes /best_model.zip on every new best # eval reward. We send it straight to ``out/`` and rename to # ``policy.zip`` after training so the deployed file lives at the # canonical path. eval_cb = EvalCallback( eval_venv, best_model_save_path=str(out), log_path=str(out / "evals"), eval_freq=max(1, 20_000 // args.n_envs), n_eval_episodes=5, deterministic=True, ) print(f"[rl] training: total_timesteps={args.total_timesteps} " f"n_envs={args.n_envs} lr={args.learning_rate} kl_coef={args.kl_coef}") model.learn(total_timesteps=args.total_timesteps, callback=[ckpt_cb, eval_cb], progress_bar=True) # Save the end-of-training state for debugging convergence behaviour. model.save(out / "final.zip") # Promote the EvalCallback's best-by-eval-reward snapshot to the # canonical ``policy.zip`` (what the controller loads). Fall back # to the final state if eval never recorded a "best". import shutil best_zip = out / "best_model.zip" policy_zip = out / "policy.zip" if best_zip.exists(): if policy_zip.exists(): policy_zip.unlink() best_zip.rename(policy_zip) print(f"[rl] best snapshot → {policy_zip} (final state kept at {out/'final.zip'})") else: shutil.copy(out / "final.zip", policy_zip) print(f"[rl] no best snapshot recorded; using final → {policy_zip}") if __name__ == "__main__": main()