"""KL-regularised PPO fine-tune of a behaviour-cloned policy.

The PPO-from-scratch and unregularised PPO-fine-tune-of-BC versions we
tried earlier failed for the standard reasons (sparse pen reward, long
horizons, exploration noise destroying the BC weights). The fix is to
anchor the policy to its BC initialisation with a KL penalty in the loss:
the policy is free to refine the BC mean within a trust-region-like ball
around the reference, and the dense-enough per-step reward signal does
the rest.

Pipeline
--------
1. Load ``bc_v3`` weights into both the trainable policy and a frozen
   reference ``ref_policy``.
2. Initialise the policy's log_std to a small fixed value (≈ −1.5) and,
   by default, disable its gradient — exploration noise stays small so
   PPO updates don't blow up the BC mean before reward can stabilise.
3. Override ``PPO.train()`` to add ``β · KL(π ‖ π_ref)`` to the loss on
   every minibatch.
4. Train for ~1–3 M timesteps with a low learning rate (5e-5).

Output: ``runs/rl_v1/policy.zip`` — same SB3 format as bc_v3, loadable by
the dog controller's ``HERDING_MODE=rl`` path.

Usage::

    python -m training.train_ppo \\
        --bc training/runs/bc_v3 \\
        --out training/runs/rl_v1 \\
        --total-timesteps 2000000
"""

from __future__ import annotations

import argparse
import os
import sys
from pathlib import Path

# Make the project root importable so `herding` / `training` resolve when
# this file is run directly rather than via `python -m`.
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

import numpy as np  # noqa: E402  (imports must follow the sys.path tweak)
import torch as th  # noqa: E402
import torch.nn.functional as F  # noqa: E402
from stable_baselines3 import PPO  # noqa: E402
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback  # noqa: E402
from stable_baselines3.common.distributions import CategoricalDistribution  # noqa: E402
from stable_baselines3.common.monitor import Monitor  # noqa: E402
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv  # noqa: E402

from herding.obs import OBS_DIM  # noqa: E402
from training.herding_env import HerdingEnv  # noqa: E402


# --------------------------------------------------------------------
# Env factory
# --------------------------------------------------------------------

def _make_env(rank: int, seed: int, frame_stack: int):
    """Return a thunk building one Monitor-wrapped ``HerdingEnv``.

    Each env is seeded with ``seed + rank``. The Monitor records the
    ``is_success`` / ``n_sheep`` / ``n_penned`` info keys per episode.
    """
    def _thunk():
        env = HerdingEnv(seed=seed + rank, frame_stack=frame_stack)
        env = Monitor(env, info_keywords=("is_success", "n_sheep", "n_penned"))
        return env
    return _thunk


# --------------------------------------------------------------------
# KL-regularised PPO
# --------------------------------------------------------------------

class KLPPO(PPO):
    """PPO with an extra KL-to-reference penalty in the policy loss.

    Subclasses SB3's PPO and overrides ``train()`` only to add the KL
    term. Everything else (rollout buffer, clipped surrogate, value loss,
    entropy bonus, target-KL early stop) follows stock SB3 v2.x PPO.

    Args:
        ref_policy: Frozen reference policy with the same architecture as
            ``self.policy``. May be ``None`` at construction and assigned
            later; ``train()`` raises if it is still missing.
        kl_coef: Weight β of the ``KL(π ‖ π_ref)`` term.
    """

    def __init__(self, *args, ref_policy=None, kl_coef: float = 0.05, **kwargs):
        super().__init__(*args, **kwargs)
        # ref_policy is usually set after construction (the caller builds
        # it from the BC checkpoint once `self.policy` exists).
        self.ref_policy = ref_policy
        if self.ref_policy is not None:
            self._freeze_reference()
        self.kl_coef = kl_coef

    def _freeze_reference(self) -> None:
        """Put ``ref_policy`` in eval mode and stop gradients through it."""
        self.ref_policy.set_training_mode(False)
        for p in self.ref_policy.parameters():
            p.requires_grad = False

    def _excluded_save_params(self) -> list[str]:
        # The reference policy is a full nn.Module that is only needed while
        # training. Without this, SB3 would serialise (cloudpickle) it into
        # every saved zip (final, checkpoints, best model).
        return [*super()._excluded_save_params(), "ref_policy"]

    def _kl_to_reference(self, observations) -> th.Tensor:
        """Return the batch-mean ``KL(π ‖ π_ref)`` on ``observations``.

        Gradients flow only into ``self.policy``; the reference is
        evaluated under ``no_grad``. The observations are the rollout's own,
        so the penalty reflects the states the agent actually visited.
        """
        with th.no_grad():
            ref_dist = self.ref_policy.get_distribution(observations).distribution
        pi_dist = self.policy.get_distribution(observations).distribution
        kl = th.distributions.kl.kl_divergence(pi_dist, ref_dist)
        # A diagonal Gaussian yields one KL per action dimension: sum over
        # the action axis to get the joint KL per sample. A single
        # Categorical already yields one value per sample, so it must not
        # be summed (that would sum over the batch).
        if kl.dim() > 1:
            kl = kl.sum(dim=-1)
        return kl.mean()

    def train(self) -> None:
        """Run one PPO update over the current rollout buffer.

        Mirrors ``stable_baselines3.ppo.PPO.train`` (v2.x) so behavioural
        parity with stock PPO is easy to check. The only addition is the
        ``kl_coef · KL(π ‖ π_ref)`` term in the loss.

        Raises:
            RuntimeError: if no reference policy has been set.
        """
        # getattr: a KLPPO restored via `load` has no saved ref_policy.
        ref_policy = getattr(self, "ref_policy", None)
        if ref_policy is None:
            raise RuntimeError("KLPPO.train called without ref_policy")
        ref_policy.set_training_mode(False)

        self.policy.set_training_mode(True)
        self._update_learning_rate(self.policy.optimizer)
        clip_range = self.clip_range(self._current_progress_remaining)
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)

        # Discrete actions come out of the buffer as floats with a trailing
        # axis; evaluate_actions needs flat integer indices.
        discrete_actions = isinstance(self.policy.action_dist, CategoricalDistribution)

        entropy_losses, pg_losses, value_losses, kl_losses = [], [], [], []
        clip_fractions = []
        approx_kl_divs = []
        loss = None
        continue_training = True

        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if discrete_actions:
                    actions = rollout_data.actions.long().flatten()

                # Re-sample the gSDE noise matrix because log_std may have changed.
                if self.use_sde:
                    self.policy.reset_noise(self.batch_size)

                values, log_prob, entropy = self.policy.evaluate_actions(
                    rollout_data.observations, actions)
                values = values.flatten()

                advantages = rollout_data.advantages
                if self.normalize_advantage and len(advantages) > 1:
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # Clipped surrogate objective.
                ratio = th.exp(log_prob - rollout_data.old_log_prob)
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    values_pred = values
                else:
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf)
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # No analytical entropy: approximate it with -log_prob.
                if entropy is None:
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)
                entropy_losses.append(entropy_loss.item())

                kl_div = self._kl_to_reference(rollout_data.observations)
                kl_losses.append(kl_div.item())

                loss = (policy_loss
                        + self.ent_coef * entropy_loss
                        + self.vf_coef * value_loss
                        + self.kl_coef * kl_div)

                # Target-KL early stop, measured against the rollout policy
                # (not the BC reference).
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                self.policy.optimizer.zero_grad()
                loss.backward()
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            # SB3 v2.x counts one update per epoch.
            self._n_updates += 1
            if not continue_training:
                break

        explained_var = self._explained_variance()
        self.logger.record("train/entropy_loss", float(np.mean(entropy_losses)))
        self.logger.record("train/policy_gradient_loss", float(np.mean(pg_losses)))
        self.logger.record("train/value_loss", float(np.mean(value_losses)))
        self.logger.record("train/kl_to_reference", float(np.mean(kl_losses)))
        self.logger.record("train/approx_kl", float(np.mean(approx_kl_divs)))
        self.logger.record("train/clip_fraction", float(np.mean(clip_fractions)))
        if loss is not None:
            self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", float(explained_var))
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

    def _explained_variance(self) -> float:
        """Explained variance of the value predictions over the rollout buffer.

        Same formula as ``stable_baselines3.common.utils.explained_variance``:
        ``1 - Var[y - ŷ] / Var[y]``, or NaN when ``Var[y] == 0``.
        """
        y_pred = self.rollout_buffer.values.flatten()
        y_true = self.rollout_buffer.returns.flatten()
        var_y = np.var(y_true)
        return float("nan") if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y


# --------------------------------------------------------------------
# Main
# --------------------------------------------------------------------

# Actor-side parameter prefixes of SB3's ActorCriticPolicy. If any are
# missing from the BC checkpoint, the "fine-tune" would start from a
# randomly initialised action head, so fail loudly instead.
_ACTOR_KEY_PREFIXES = ("mlp_extractor.policy_net.", "action_net.")


def _parse_args(argv=None) -> argparse.Namespace:
    """Parse the command line (``sys.argv[1:]`` when ``argv`` is None)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--bc", default="training/runs/bc_v3",
                        help="Directory containing the BC initialisation (policy.zip).")
    parser.add_argument("--out", default="training/runs/rl_v1",
                        help="Where to save the fine-tuned policy.")
    parser.add_argument("--total-timesteps", type=int, default=2_000_000)
    parser.add_argument("--n-envs", type=int, default=8)
    parser.add_argument("--learning-rate", type=float, default=5e-5,
                        help="Low LR keeps PPO close to the BC mean.")
    parser.add_argument("--kl-coef", type=float, default=0.05,
                        help="KL-to-reference penalty coefficient.")
    parser.add_argument("--log-std", type=float, default=-1.5,
                        help="Initial log_std (frozen unless --no-freeze-log-std). "
                             "σ ≈ exp(-1.5) ≈ 0.22.")
    # Paired flags: `store_true` with default=True alone could never be
    # switched off from the command line.
    parser.add_argument("--freeze-log-std", dest="freeze_log_std",
                        action="store_true", default=True,
                        help="Keep log_std fixed; only the policy mean updates (default).")
    parser.add_argument("--no-freeze-log-std", dest="freeze_log_std",
                        action="store_false",
                        help="Let PPO train log_std as well.")
    parser.add_argument("--n-steps", type=int, default=2048,
                        help="Steps per rollout per env.")
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--n-epochs", type=int, default=10)
    parser.add_argument("--gamma", type=float, default=0.995)
    parser.add_argument("--gae-lambda", type=float, default=0.95)
    parser.add_argument("--clip-range", type=float, default=0.1,
                        help="Tight clip range — keep updates conservative.")
    parser.add_argument("--ent-coef", type=float, default=0.0)
    parser.add_argument("--target-kl", type=float, default=0.02,
                        help="SB3's per-batch KL early stop; safety belt.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device", default="cpu")
    return parser.parse_args(argv)


def _load_bc_state(policy, bc_state, label: str) -> None:
    """Copy BC weights into ``policy``; tolerate gaps only outside the actor.

    ``strict=False`` because the value head may not have been trained in
    BC — PPO trains it from scratch. Missing actor weights are fatal.
    """
    missing, unexpected = policy.load_state_dict(bc_state, strict=False)
    print(f"[rl] BC → {label}: missing={len(missing)} unexpected={len(unexpected)}")
    missing_actor = [k for k in missing if k.startswith(_ACTOR_KEY_PREFIXES)]
    if missing_actor:
        raise SystemExit(
            f"BC checkpoint is missing actor weights for {label}: {missing_actor}")


def _build_reference_policy(model: KLPPO, bc_state):
    """Build a policy with ``model.policy``'s architecture holding the BC weights."""
    # Reuse the exact constructor arguments SB3 used for model.policy so the
    # two architectures cannot drift apart. The LR schedule is irrelevant
    # because the reference is never optimised.
    ref_policy = type(model.policy)(
        model.observation_space,
        model.action_space,
        lambda _: 0.0,
        use_sde=model.use_sde,
        **model.policy_kwargs,
    ).to(model.device)  # model.device is resolved, unlike a raw "auto"
    _load_bc_state(ref_policy, bc_state, "ref_policy")
    return ref_policy


def main() -> None:
    """Fine-tune the BC policy with KL-regularised PPO and save ``policy.zip``."""
    args = _parse_args()

    bc_zip = Path(args.bc) / "policy.zip"
    if not bc_zip.exists():
        raise SystemExit(
            f"BC checkpoint not found at {bc_zip}. Train bc_v3 first with "
            f"`python -m training.bc_pretrain`."
        )

    out = Path(args.out)
    out.mkdir(parents=True, exist_ok=True)
    (out / "checkpoints").mkdir(exist_ok=True)
    (out / "best").mkdir(exist_ok=True)

    # --- Inspect BC obs dim → infer frame_stack ---
    ref_only = PPO.load(str(bc_zip), device=args.device)
    obs_dim = int(ref_only.observation_space.shape[0])
    if obs_dim % OBS_DIM != 0:
        raise SystemExit(f"BC obs dim {obs_dim} is not a multiple of {OBS_DIM}.")
    frame_stack = obs_dim // OBS_DIM
    print(f"[rl] BC obs dim {obs_dim} → frame_stack={frame_stack}")

    # --- Vectorised envs (match BC obs space) ---
    env_fns = [_make_env(i, args.seed, frame_stack) for i in range(args.n_envs)]
    venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns)
    eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, frame_stack)])

    try:
        # --- Trainable policy: build PPO at the right obs/action shape,
        # then copy the BC weights into it. ---
        model = KLPPO(
            "MlpPolicy", venv,
            ref_policy=None,  # filled in below
            kl_coef=args.kl_coef,
            learning_rate=args.learning_rate,
            n_steps=args.n_steps,
            batch_size=args.batch_size,
            n_epochs=args.n_epochs,
            gamma=args.gamma,
            gae_lambda=args.gae_lambda,
            clip_range=args.clip_range,
            ent_coef=args.ent_coef,
            target_kl=args.target_kl,
            policy_kwargs=dict(
                net_arch=dict(pi=[512, 512], vf=[512, 512]),
                log_std_init=args.log_std,
            ),
            verbose=1,
            seed=args.seed,
            device=args.device,
            tensorboard_log=str(out / "tb"),
        )

        # --- Load BC weights into both `model.policy` and a frozen reference ---
        bc_state = ref_only.policy.state_dict()
        _load_bc_state(model.policy, bc_state, "policy")
        model.ref_policy = _build_reference_policy(model, bc_state)
        model._freeze_reference()

        # Align both policies' log_std. BC was trained with log_std≈0.5
        # (σ≈1.65), which would make the KL term huge from a std mismatch
        # rather than the mean drift we actually care about. Force both to
        # the same small value so that, while log_std stays frozen, KL
        # measures only how far the policy mean has drifted from the BC mean.
        with th.no_grad():
            model.policy.log_std.fill_(args.log_std)
            model.ref_policy.log_std.fill_(args.log_std)
        if args.freeze_log_std:
            model.policy.log_std.requires_grad = False
            print(f"[rl] log_std frozen at {args.log_std} (σ ≈ {np.exp(args.log_std):.3f})")
        else:
            # The reference std stays fixed, so KL now also penalises std drift.
            print(f"[rl] log_std trainable, initialised at {args.log_std}")

        # --- Callbacks (frequencies are per-env steps) ---
        ckpt_cb = CheckpointCallback(
            save_freq=max(1, 50_000 // args.n_envs),
            save_path=str(out / "checkpoints"),
            name_prefix="ppo",
        )
        eval_cb = EvalCallback(
            eval_venv,
            best_model_save_path=str(out / "best"),
            log_path=str(out / "evals"),
            eval_freq=max(1, 20_000 // args.n_envs),
            n_eval_episodes=5,
            deterministic=True,
        )

        print(f"[rl] training: total_timesteps={args.total_timesteps} "
              f"n_envs={args.n_envs} lr={args.learning_rate} kl_coef={args.kl_coef}")
        model.learn(total_timesteps=args.total_timesteps,
                    callback=[ckpt_cb, eval_cb],
                    progress_bar=True)

        # --- Save final checkpoint in the SB3 zip the controller expects ---
        model.save(out / "policy.zip")
        print(f"[rl] saved fine-tuned policy → {out/'policy.zip'}")
    finally:
        # Shut down SubprocVecEnv worker processes even on error / Ctrl-C.
        venv.close()
        eval_venv.close()


if __name__ == "__main__":
    main()