Checkpoint 4

This commit is contained in:
Johnny Fernandes
2026-05-11 00:42:52 +01:00
parent 2a6db038df
commit 6688325d89
26 changed files with 2018 additions and 503 deletions
+275 -206
View File
@@ -1,31 +1,33 @@
"""PPO trainer for the shepherd-dog policy — EXPERIMENTAL.
"""KL-regularised PPO fine-tune of a behaviour-cloned policy.
The deliverable pipeline is `bc_pretrain.py` (see ``training/README.md``).
This script is kept in the tree because it implements:
The PPO-from-scratch and unregularised PPO-fine-tune-of-BC versions
we tried earlier failed for the standard reasons (sparse pen reward,
long horizons, exploration noise destroying BC weights). The fix is
to anchor the policy to its BC initialisation with a KL penalty in
the loss — the policy is free to refine the BC mean within a
trust-region-like ball around the reference, and the dense-enough
per-step reward signal does the rest.
* PPO from scratch with curriculum over flock size + spawn area, and
* PPO fine-tune of a behavior-cloned policy.
Pipeline
--------
1. Load ``bc_v3`` weights into both the trainable policy and a frozen
reference ``ref_policy``.
2. Initialise the policy's log_std to a small fixed value (≈ 1.5)
and disable its gradient — exploration noise stays small so PPO
updates don't blow up the BC mean before reward can stabilise.
3. Override ``PPO.train()`` to add ``β · KL(π ‖ π_ref)`` to the loss
each minibatch.
4. Train for ~13 M timesteps with a low LR (5e-5).
Both ran into stability issues in our setting (long-horizon credit
assignment for sparse pen reward, BC-degradation under PPO exploration
noise). The abstractions are reusable for follow-up work — e.g.
KL-regularised fine-tune with a frozen reference policy — so we leave
the code in place.
Output: ``runs/rl_v1/policy.zip`` — same SB3 format as bc_v3, loadable
by the dog controller's ``HERDING_MODE=rl`` path.
Usage (PPO from scratch)::
Usage::
python -m training.train_ppo \
--config training/configs/ppo_default.yaml \
--out-dir training/runs/ppo_scratch
Usage (PPO fine-tune of BC)::
python -m training.train_ppo \
--resume training/runs/bc_flock/policy.zip \
--out-dir training/runs/bc_ppo \
--no-vecnorm --no-curriculum --imitate-weight 0 \
--difficulty 1.0 --log-std -1.5 --learning-rate 5e-5 \
--total-timesteps 3000000
python -m training.train_ppo \\
--bc training/runs/bc_v3 \\
--out training/runs/rl_v1 \\
--total-timesteps 2000000
"""
from __future__ import annotations
@@ -35,8 +37,6 @@ import os
import sys
from pathlib import Path
import yaml
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
if _PROJECT_ROOT not in sys.path:
@@ -44,236 +44,305 @@ if _PROJECT_ROOT not in sys.path:
import numpy as np
import torch as th
import torch.nn.functional as F
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import (
BaseCallback, CheckpointCallback, EvalCallback,
)
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import (
DummyVecEnv, SubprocVecEnv, VecNormalize,
)
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from herding.obs import OBS_DIM
from training.herding_env import HerdingEnv
# --------------------------------------------------------------------------
# Env factories
# --------------------------------------------------------------------------
# --------------------------------------------------------------------
# Env factory
# --------------------------------------------------------------------
def _make_env(rank: int, seed: int = 0):
def _make_env(rank: int, seed: int, frame_stack: int):
def _thunk():
env = HerdingEnv(seed=seed + rank)
env = HerdingEnv(seed=seed + rank, frame_stack=frame_stack)
env = Monitor(env, info_keywords=("is_success", "n_sheep", "n_penned"))
return env
return _thunk
# --------------------------------------------------------------------------
# Curriculum callback
# --------------------------------------------------------------------------
# --------------------------------------------------------------------
# KL-regularised PPO
# --------------------------------------------------------------------
class CurriculumCallback(BaseCallback):
"""Drive the env's flock-size + state-space difficulty curriculum.
class KLPPO(PPO):
"""PPO with an extra KL-to-reference penalty in the policy loss.
Schedule entries: {step, max_n_sheep, difficulty}. The largest entry
whose step <= num_timesteps wins; both knobs update together.
Subclasses SB3's PPO and overrides ``train()`` only to add a single
line for the KL term — everything else (rollout buffer, clipped
surrogate, value loss, entropy bonus) is unchanged.
"""
def __init__(self, schedule, vec_envs, verbose: int = 0):
super().__init__(verbose)
self.schedule = sorted(schedule, key=lambda d: d["step"])
# Accept a list of envs so the eval env tracks training difficulty.
self.vec_envs = vec_envs if isinstance(vec_envs, (list, tuple)) else [vec_envs]
self._last_n = None
self._last_d = None
def __init__(self, *args, ref_policy=None, kl_coef: float = 0.05, **kwargs):
super().__init__(*args, **kwargs)
# ref_policy is set after construction (caller can build it
# from the BC checkpoint once `self.policy` exists).
self.ref_policy = ref_policy
if self.ref_policy is not None:
self.ref_policy.set_training_mode(False)
for p in self.ref_policy.parameters():
p.requires_grad = False
self.kl_coef = kl_coef
def _call(self, method, value):
for v in self.vec_envs:
try:
v.env_method(method, value)
except AttributeError:
v.venv.env_method(method, value)
def train(self) -> None:
# Copied from stable_baselines3.ppo.PPO.train (v2.x), with the
# KL-to-reference term added. Keeping the structure intact so
# behavioural parity with stock PPO is obvious.
self.policy.set_training_mode(True)
self._update_learning_rate(self.policy.optimizer)
clip_range = self.clip_range(self._current_progress_remaining)
if self.clip_range_vf is not None:
clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
def _on_step(self) -> bool:
t = self.num_timesteps
n = self.schedule[0]["max_n_sheep"]
d = self.schedule[0].get("difficulty", 1.0)
for entry in self.schedule:
if t >= entry["step"]:
n = entry["max_n_sheep"]
d = entry.get("difficulty", 1.0)
if n != self._last_n:
self._call("set_max_n_sheep", n)
self._last_n = n
if d != self._last_d:
self._call("set_difficulty", d)
self._last_d = d
if self.verbose:
print(f"[curriculum] t={t} → max_n_sheep={n} difficulty={d}")
return True
entropy_losses, pg_losses, value_losses, kl_losses = [], [], [], []
clip_fractions = []
continue_training = True
for epoch in range(self.n_epochs):
approx_kl_divs = []
for rollout_data in self.rollout_buffer.get(self.batch_size):
actions = rollout_data.actions
if isinstance(self.action_space, th.distributions.Categorical.__bases__):
actions = rollout_data.actions.long().flatten()
values, log_prob, entropy = self.policy.evaluate_actions(
rollout_data.observations, actions)
values = values.flatten()
advantages = rollout_data.advantages
if self.normalize_advantage and len(advantages) > 1:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
ratio = th.exp(log_prob - rollout_data.old_log_prob)
policy_loss_1 = advantages * ratio
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
pg_losses.append(policy_loss.item())
clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
clip_fractions.append(clip_fraction)
if self.clip_range_vf is None:
values_pred = values
else:
values_pred = rollout_data.old_values + th.clamp(
values - rollout_data.old_values, -clip_range_vf, clip_range_vf)
value_loss = F.mse_loss(rollout_data.returns, values_pred)
value_losses.append(value_loss.item())
if entropy is None:
entropy_loss = -th.mean(-log_prob)
else:
entropy_loss = -th.mean(entropy)
entropy_losses.append(entropy_loss.item())
# --- KL-to-reference term ----------------------------
# Both policies are diagonal Gaussian (ActorCriticPolicy).
# KL(π ‖ π_ref) per-action-dim; sum over the action axis
# to get total KL per sample, then mean over batch.
# Computed on the rollout's observations so the penalty
# reflects what the agent actually saw.
if self.ref_policy is None:
raise RuntimeError("KLPPO.train called without ref_policy")
with th.no_grad():
ref_dist = self.ref_policy.get_distribution(rollout_data.observations)
pi_dist = self.policy.get_distribution(rollout_data.observations)
kl_div = th.distributions.kl.kl_divergence(
pi_dist.distribution, ref_dist.distribution).sum(dim=-1).mean()
kl_losses.append(kl_div.item())
# ----------------------------------------------------
loss = (policy_loss
+ self.ent_coef * entropy_loss
+ self.vf_coef * value_loss
+ self.kl_coef * kl_div)
with th.no_grad():
log_ratio = log_prob - rollout_data.old_log_prob
approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
approx_kl_divs.append(approx_kl_div)
if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
continue_training = False
if self.verbose >= 1:
print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
break
self.policy.optimizer.zero_grad()
loss.backward()
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()
self._n_updates += 1
if not continue_training:
break
explained_var = self._explained_variance()
self.logger.record("train/entropy_loss", float(np.mean(entropy_losses)))
self.logger.record("train/policy_gradient_loss", float(np.mean(pg_losses)))
self.logger.record("train/value_loss", float(np.mean(value_losses)))
self.logger.record("train/kl_to_reference", float(np.mean(kl_losses)))
self.logger.record("train/approx_kl", float(np.mean(approx_kl_divs)))
self.logger.record("train/clip_fraction", float(np.mean(clip_fractions)))
self.logger.record("train/explained_variance", float(explained_var))
if hasattr(self.policy, "log_std"):
self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
def _explained_variance(self) -> float:
# SB3 doesn't expose this as a method; replicate the computation.
y_pred = self.rollout_buffer.values.flatten()
y_true = self.rollout_buffer.returns.flatten()
var_y = np.var(y_true)
return float("nan") if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
# --------------------------------------------------------------------------
# --------------------------------------------------------------------
# Main
# --------------------------------------------------------------------------
# --------------------------------------------------------------------
def main():
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--config", default=os.path.join(_HERE, "configs", "ppo_default.yaml"))
parser.add_argument("--out-dir", default=os.path.join(_HERE, "runs", "latest"))
parser.add_argument("--n-envs", type=int, default=None,
help="Override config n_envs.")
parser.add_argument("--total-timesteps", type=int, default=None,
help="Override config total_timesteps.")
parser.add_argument("--bc", default="training/runs/bc_v3",
help="Directory containing the BC initialisation (policy.zip).")
parser.add_argument("--out", default="training/runs/rl_v1",
help="Where to save the fine-tuned policy.")
parser.add_argument("--total-timesteps", type=int, default=2_000_000)
parser.add_argument("--n-envs", type=int, default=8)
parser.add_argument("--learning-rate", type=float, default=5e-5,
help="Low LR keeps PPO close to the BC mean.")
parser.add_argument("--kl-coef", type=float, default=0.05,
help="KL-to-reference penalty coefficient.")
parser.add_argument("--log-std", type=float, default=-1.5,
help="Initial (and frozen) log_std. σ ≈ exp(-1.5) ≈ 0.22.")
parser.add_argument("--freeze-log-std", action="store_true", default=True,
help="Keep log_std fixed; only the policy mean updates.")
parser.add_argument("--n-steps", type=int, default=2048,
help="Steps per rollout per env.")
parser.add_argument("--batch-size", type=int, default=256)
parser.add_argument("--n-epochs", type=int, default=10)
parser.add_argument("--gamma", type=float, default=0.995)
parser.add_argument("--gae-lambda", type=float, default=0.95)
parser.add_argument("--clip-range", type=float, default=0.1,
help="Tight clip range — keep updates conservative.")
parser.add_argument("--ent-coef", type=float, default=0.0)
parser.add_argument("--target-kl", type=float, default=0.02,
help="SB3's per-batch KL early stop; safety belt.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--resume", type=str, default=None,
help="Path to a SB3 zip to resume from.")
# SB3 recommends CPU for MlpPolicy — GPU helps CNN policies, not MLPs
# of this size. Override with --device cuda if you really want it.
parser.add_argument("--device", default="cpu")
parser.add_argument("--no-vecnorm", action="store_true",
help="Disable VecNormalize wrapper. Required when "
"resuming from a BC-pretrained policy that "
"wasn't trained under it.")
parser.add_argument("--no-curriculum", action="store_true",
help="Skip curriculum callback (resumed policy is "
"already competent across the distribution).")
parser.add_argument("--imitate-weight", type=float, default=None,
help="Override env W_IMITATE. Set to 0 to disable "
"Strömbom imitation reward.")
parser.add_argument("--difficulty", type=float, default=None,
help="Override env difficulty (0=easy, 1=hard). "
"Used in BC fine-tune to skip easy curriculum.")
parser.add_argument("--log-std", type=float, default=None,
help="Override the policy's log_std after load. "
"BC trained with std≈1.6 (log_std=0.5) which "
"is too noisy for fine-tune. Use -1.5 (std≈0.22) "
"to keep PPO close to the BC mean while still "
"exploring locally.")
parser.add_argument("--learning-rate", type=float, default=None,
help="Override config learning rate. For BC "
"fine-tune, 5e-5 is much safer than the 3e-4 "
"default.")
args = parser.parse_args()
with open(args.config) as f:
cfg = yaml.safe_load(f)
bc_zip = Path(args.bc) / "policy.zip"
if not bc_zip.exists():
raise SystemExit(
f"BC checkpoint not found at {bc_zip}. Train bc_v3 first with "
f"`python -m training.bc_pretrain`."
)
n_envs = args.n_envs or cfg["n_envs"]
total_timesteps = args.total_timesteps or cfg["total_timesteps"]
out = Path(args.out_dir)
out = Path(args.out)
out.mkdir(parents=True, exist_ok=True)
(out / "checkpoints").mkdir(exist_ok=True)
(out / "best").mkdir(exist_ok=True)
(out / "evals").mkdir(exist_ok=True)
print(f"[train] out={out} n_envs={n_envs} total={total_timesteps} device={args.device}")
# --- Inspect BC obs dim → infer frame_stack ---
ref_only = PPO.load(str(bc_zip), device=args.device)
obs_dim = int(ref_only.observation_space.shape[0])
if obs_dim % OBS_DIM != 0:
raise SystemExit(f"BC obs dim {obs_dim} is not a multiple of {OBS_DIM}.")
frame_stack = obs_dim // OBS_DIM
print(f"[rl] BC obs dim {obs_dim} → frame_stack={frame_stack}")
# --- Train env (vectorised, optionally normalised) ---
env_fns = [_make_env(i, seed=args.seed) for i in range(n_envs)]
venv = SubprocVecEnv(env_fns) if n_envs > 1 else DummyVecEnv(env_fns)
eval_venv = DummyVecEnv([_make_env(99, seed=args.seed + 999)])
if not args.no_vecnorm:
venv = VecNormalize(venv, norm_obs=True, norm_reward=False, clip_obs=10.0)
eval_venv = VecNormalize(eval_venv, norm_obs=True, norm_reward=False,
clip_obs=10.0, training=False)
eval_venv.obs_rms = venv.obs_rms
else:
print("[train] VecNormalize disabled (resumed policy was trained without it).")
# --- Vectorised envs (match BC obs space) ---
env_fns = [_make_env(i, args.seed, frame_stack) for i in range(args.n_envs)]
venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns)
eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, frame_stack)])
# Apply env-level overrides (used by BC fine-tune to disable Strömbom
# imitation and start at full deployment difficulty).
def _env_call(method, value):
for v in (venv, eval_venv):
try:
v.env_method(method, value)
except AttributeError:
v.venv.env_method(method, value)
if args.imitate_weight is not None:
_env_call("set_imitate_weight", args.imitate_weight)
print(f"[train] W_IMITATE overridden to {args.imitate_weight}")
if args.difficulty is not None:
_env_call("set_difficulty", args.difficulty)
print(f"[train] difficulty pinned to {args.difficulty}")
# --- Model ---
policy_kwargs = dict(
net_arch=dict(pi=cfg["net_arch_pi"], vf=cfg["net_arch_vf"]),
log_std_init=cfg.get("log_std_init", 0.0),
# --- Trainable policy: load BC weights, then bolt onto PPO ---
# Trick: instantiate a PPO with the right env (so the policy
# network is constructed at the correct obs/action shape), then
# copy BC weights into it.
model = KLPPO(
"MlpPolicy", venv,
ref_policy=None, # filled in below
kl_coef=args.kl_coef,
learning_rate=args.learning_rate,
n_steps=args.n_steps,
batch_size=args.batch_size,
n_epochs=args.n_epochs,
gamma=args.gamma,
gae_lambda=args.gae_lambda,
clip_range=args.clip_range,
ent_coef=args.ent_coef,
target_kl=args.target_kl,
policy_kwargs=dict(
net_arch=dict(pi=[512, 512], vf=[512, 512]),
log_std_init=args.log_std,
),
verbose=1,
seed=args.seed,
device=args.device,
tensorboard_log=str(out / "tb"),
)
if args.resume:
print(f"[train] resuming from {args.resume}")
custom_objects = {}
if args.learning_rate is not None:
custom_objects["learning_rate"] = args.learning_rate
model = PPO.load(args.resume, env=venv, device=args.device,
tensorboard_log=str(out / "tb"),
custom_objects=custom_objects or None)
if args.log_std is not None:
import torch as _th
with _th.no_grad():
model.policy.log_std.fill_(args.log_std)
print(f"[train] log_std overridden to {args.log_std} "
f"(std≈{2.71828 ** args.log_std:.2f})")
if args.learning_rate is not None:
print(f"[train] learning_rate overridden to {args.learning_rate}")
else:
model = PPO(
cfg["policy"], venv,
learning_rate=cfg["learning_rate"],
n_steps=cfg["n_steps"],
batch_size=cfg["batch_size"],
n_epochs=cfg["n_epochs"],
gamma=cfg["gamma"],
gae_lambda=cfg["gae_lambda"],
clip_range=cfg["clip_range"],
ent_coef=cfg["ent_coef"],
vf_coef=cfg["vf_coef"],
max_grad_norm=cfg["max_grad_norm"],
target_kl=cfg.get("target_kl"),
policy_kwargs=policy_kwargs,
tensorboard_log=str(out / "tb"),
seed=args.seed,
device=args.device,
verbose=1,
)
# --- Load BC weights into both `model.policy` and `ref_policy` ---
bc_state = ref_only.policy.state_dict()
# Strict=False because the value head may not have been trained in
# BC — that's fine, PPO will train it from scratch.
missing, unexpected = model.policy.load_state_dict(bc_state, strict=False)
print(f"[rl] BC → policy: missing={len(missing)} unexpected={len(unexpected)}")
# Build a separate reference policy with identical architecture and
# the BC weights, frozen.
ref_policy = type(model.policy)(
observation_space=model.observation_space,
action_space=model.action_space,
lr_schedule=lambda _: 0.0,
net_arch=dict(pi=[512, 512], vf=[512, 512]),
log_std_init=args.log_std,
).to(args.device)
ref_policy.load_state_dict(bc_state, strict=False)
model.ref_policy = ref_policy
model.ref_policy.set_training_mode(False)
for p in model.ref_policy.parameters():
p.requires_grad = False
# Align both policies' log_std. BC was trained with log_std≈0.5
# (σ≈1.65), which would make the KL term huge from a std mismatch
# rather than the mean drift we actually care about. Force both to
# the same small value so KL measures only how far the policy mean
# has drifted from the BC mean.
with th.no_grad():
model.policy.log_std.fill_(args.log_std)
model.ref_policy.log_std.fill_(args.log_std)
if args.freeze_log_std:
model.policy.log_std.requires_grad = False
print(f"[rl] log_std frozen at {args.log_std} (σ{np.exp(args.log_std):.3f})")
# --- Callbacks ---
ckpt_cb = CheckpointCallback(
save_freq=max(1, cfg["checkpoint_freq"] // n_envs),
save_path=str(out / "checkpoints"), name_prefix="ppo",
save_vecnormalize=True,
save_freq=max(1, 50_000 // args.n_envs),
save_path=str(out / "checkpoints"),
name_prefix="ppo",
)
eval_cb = EvalCallback(
eval_venv,
best_model_save_path=str(out / "best"),
log_path=str(out / "evals"),
eval_freq=max(1, cfg["eval_freq"] // n_envs),
n_eval_episodes=cfg["n_eval_episodes"],
eval_freq=max(1, 20_000 // args.n_envs),
n_eval_episodes=5,
deterministic=True,
)
callbacks = [ckpt_cb, eval_cb]
if not args.no_curriculum and "curriculum" in cfg and cfg["curriculum"]:
callbacks.append(CurriculumCallback(
cfg["curriculum"], [venv, eval_venv], verbose=1,
))
elif args.no_curriculum:
print("[train] curriculum disabled — env knobs left at their current values.")
# --- Train ---
model.learn(total_timesteps=total_timesteps, callback=callbacks,
progress_bar=True)
print(f"[rl] training: total_timesteps={args.total_timesteps} "
f"n_envs={args.n_envs} lr={args.learning_rate} kl_coef={args.kl_coef}")
model.learn(total_timesteps=args.total_timesteps,
callback=[ckpt_cb, eval_cb], progress_bar=True)
# --- Save final model + VecNormalize stats ---
model.save(out / "final.zip")
venv.save(str(out / "vecnormalize.pkl"))
# The EvalCallback already wrote best_model.zip into out/best/ — drop the
# VecNormalize stats next to it for the controller to pick up.
venv.save(str(out / "best" / "vecnormalize.pkl"))
print(f"[train] done. saved to {out}")
# --- Save final checkpoint in the SB3 zip the controller expects ---
model.save(out / "policy.zip")
print(f"[rl] saved fine-tuned policy → {out/'policy.zip'}")
if __name__ == "__main__":