375 lines
16 KiB
Python
375 lines
16 KiB
Python
"""KL-regularised PPO fine-tune of a behaviour-cloned policy.

The PPO-from-scratch and unregularised PPO-fine-tune-of-BC versions
we tried earlier failed for the standard reasons (sparse pen reward,
long horizons, exploration noise destroying BC weights). The fix is
to anchor the policy to its BC initialisation with a KL penalty in
the loss — the policy is free to refine the BC mean within a
trust-region-like ball around the reference, and the dense-enough
per-step reward signal does the rest.

Pipeline
--------
1. Load ``bc`` weights into both the trainable policy and a frozen
   reference ``ref_policy``.
2. Initialise the policy's log_std to a small fixed value (≈ −1.5)
   and disable its gradient — exploration noise stays small so PPO
   updates don't blow up the BC mean before reward can stabilise.
3. Override ``PPO.train()`` to add ``β · KL(π ‖ π_ref)`` to the loss
   each minibatch.
4. Train for ~1–3 M timesteps with a low LR (5e-5).

Output: ``<out>/policy.zip`` (``training/runs/rl/policy.zip`` with the
default ``--out``) — same SB3 format as bc, loadable by the dog
controller's ``HERDING_MODE=rl`` path.

Usage::

    python -m training.train_ppo \\
        --bc training/runs/bc \\
        --out training/runs/rl \\
        --total-timesteps 2000000
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import os
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
# Put the directory above this file's directory on ``sys.path`` so the
# ``herding`` and ``training`` packages imported below resolve even when
# the script is launched from another working directory.
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, os.pardir))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)
|
||
|
||
import numpy as np
|
||
import torch as th
|
||
import torch.nn.functional as F
|
||
from stable_baselines3 import PPO
|
||
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
|
||
from stable_baselines3.common.monitor import Monitor
|
||
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
|
||
|
||
from herding.obs import OBS_DIM
|
||
from training.herding_env import HerdingEnv
|
||
|
||
|
||
# --------------------------------------------------------------------
|
||
# Env factory
|
||
# --------------------------------------------------------------------
|
||
|
||
def _make_env(rank: int, seed: int, frame_stack: int):
    """Return a zero-argument factory for one monitored ``HerdingEnv``.

    The environment is only constructed when the returned callable is
    invoked (which is how the vectorised-env wrappers consume it). Each
    env is seeded with ``seed + rank`` and wrapped in ``Monitor`` so the
    ``is_success``, ``n_sheep`` and ``n_penned`` info entries are logged.
    """

    def _init():
        return Monitor(
            HerdingEnv(seed=seed + rank, frame_stack=frame_stack),
            info_keywords=("is_success", "n_sheep", "n_penned"),
        )

    return _init
|
||
|
||
|
||
# --------------------------------------------------------------------
|
||
# KL-regularised PPO
|
||
# --------------------------------------------------------------------
|
||
|
||
class KLPPO(PPO):
    """PPO with an extra KL-to-reference penalty in the loss.

    Subclasses SB3's PPO and overrides ``train()`` to add
    ``kl_coef * KL(pi || pi_ref)`` to the minibatch loss; the rollout
    buffer, clipped surrogate, value loss and entropy bonus follow the
    stock implementation.

    Args:
        ref_policy: Frozen reference policy the trainable policy is
            anchored to. May be ``None`` at construction time and
            assigned later (``model.ref_policy = ...``), but it must be
            set before ``train()`` runs.
        kl_coef: Weight of the KL-to-reference term in the loss.
        *args, **kwargs: Forwarded unchanged to ``PPO.__init__``.
    """

    def __init__(self, *args, ref_policy=None, kl_coef: float = 0.05, **kwargs):
        super().__init__(*args, **kwargs)
        # ref_policy may be set after construction (the caller can build
        # it from the BC checkpoint once `self.policy` exists).
        self.ref_policy = ref_policy
        if self.ref_policy is not None:
            # The reference is a fixed anchor: eval mode, no gradients.
            self.ref_policy.set_training_mode(False)
            for p in self.ref_policy.parameters():
                p.requires_grad = False
        self.kl_coef = kl_coef

    def train(self) -> None:
        """Run the PPO update on the current rollout buffer with the KL term.

        Raises:
            RuntimeError: if ``ref_policy`` has not been set.
        """
        # Copied from stable_baselines3.ppo.PPO.train (v2.x), with the
        # KL-to-reference term added. Keeping the structure intact so
        # behavioural parity with stock PPO is obvious.
        # Local import: keeps this class self-contained without touching
        # the module's import block.
        from stable_baselines3.common.distributions import CategoricalDistribution

        # Fail fast, before any optimiser state is touched, instead of
        # re-checking on every minibatch.
        if self.ref_policy is None:
            raise RuntimeError("KLPPO.train called without ref_policy")

        self.policy.set_training_mode(True)
        self._update_learning_rate(self.policy.optimizer)
        clip_range = self.clip_range(self._current_progress_remaining)
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)

        # Stock PPO casts actions to long for Discrete action spaces.
        # The previous check compared the gym space against torch
        # Distribution base classes and therefore could never be true.
        # A categorical action distribution is what SB3 builds for a
        # Discrete space, so test for that instead.
        discrete_actions = isinstance(
            getattr(self.policy, "action_dist", None), CategoricalDistribution)

        entropy_losses, pg_losses, value_losses, kl_losses = [], [], [], []
        clip_fractions = []
        continue_training = True

        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if discrete_actions:
                    actions = rollout_data.actions.long().flatten()

                # Same as stock PPO: re-sample the gSDE noise matrix
                # because log_std may have changed since the last step.
                if self.use_sde:
                    self.policy.reset_noise(self.batch_size)

                values, log_prob, entropy = self.policy.evaluate_actions(
                    rollout_data.observations, actions)
                values = values.flatten()
                advantages = rollout_data.advantages
                if self.normalize_advantage and len(advantages) > 1:
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # Clipped surrogate objective.
                ratio = th.exp(log_prob - rollout_data.old_log_prob)
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    values_pred = values
                else:
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf)
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                if entropy is None:
                    # No analytical entropy: approximate it from samples.
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)
                entropy_losses.append(entropy_loss.item())

                # --- KL-to-reference term ----------------------------
                # KL(π ‖ π_ref) evaluated on the rollout's observations
                # so the penalty reflects what the agent actually saw.
                # The reference side carries no gradient.
                with th.no_grad():
                    ref_dist = self.ref_policy.get_distribution(rollout_data.observations)
                pi_dist = self.policy.get_distribution(rollout_data.observations)
                per_sample_kl = th.distributions.kl.kl_divergence(
                    pi_dist.distribution, ref_dist.distribution)
                # Diagonal Gaussians give one KL value per action
                # dimension: sum over the action axis to get the total
                # KL per sample. Distributions without an event axis
                # already return one value per sample.
                if per_sample_kl.dim() > 1:
                    per_sample_kl = per_sample_kl.sum(dim=-1)
                kl_div = per_sample_kl.mean()
                kl_losses.append(kl_div.item())
                # ----------------------------------------------------

                loss = (policy_loss
                        + self.ent_coef * entropy_loss
                        + self.vf_coef * value_loss
                        + self.kl_coef * kl_div)

                # Approximate KL between the old and new policy, used
                # only for the early-stopping safety check below.
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                self.policy.optimizer.zero_grad()
                loss.backward()
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            self._n_updates += 1
            if not continue_training:
                break

        explained_var = self._explained_variance()
        self.logger.record("train/entropy_loss", float(np.mean(entropy_losses)))
        self.logger.record("train/policy_gradient_loss", float(np.mean(pg_losses)))
        self.logger.record("train/value_loss", float(np.mean(value_losses)))
        self.logger.record("train/kl_to_reference", float(np.mean(kl_losses)))
        self.logger.record("train/approx_kl", float(np.mean(approx_kl_divs)))
        self.logger.record("train/clip_fraction", float(np.mean(clip_fractions)))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", float(explained_var))
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

    def _explained_variance(self) -> float:
        """Return ``1 - Var[returns - values] / Var[returns]`` for the buffer.

        Returns ``nan`` when the returns have zero variance.
        """
        # SB3 doesn't expose this as a method; replicate the computation.
        y_pred = self.rollout_buffer.values.flatten()
        y_true = self.rollout_buffer.returns.flatten()
        var_y = np.var(y_true)
        return float("nan") if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
|
||
|
||
|
||
# --------------------------------------------------------------------
|
||
# Main
|
||
# --------------------------------------------------------------------
|
||
|
||
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the KL-regularised PPO fine-tune."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--bc", default="training/runs/bc",
                        help="Directory containing the BC initialisation (policy.zip).")
    parser.add_argument("--out", default="training/runs/rl",
                        help="Where to save the fine-tuned policy.")
    parser.add_argument("--total-timesteps", type=int, default=2_000_000)
    parser.add_argument("--n-envs", type=int, default=8)
    parser.add_argument("--learning-rate", type=float, default=5e-5,
                        help="Low LR keeps PPO close to the BC mean.")
    parser.add_argument("--kl-coef", type=float, default=0.05,
                        help="KL-to-reference penalty coefficient.")
    parser.add_argument("--log-std", type=float, default=-1.5,
                        help="Initial (and frozen) log_std. σ ≈ exp(-1.5) ≈ 0.22.")
    # Freezing is the default. The flag used to be `store_true` with
    # `default=True`, which made it impossible to turn off; the paired
    # `--no-freeze-log-std` switch writes to the same destination.
    parser.add_argument("--freeze-log-std", dest="freeze_log_std",
                        action="store_true", default=True,
                        help="Keep log_std fixed; only the policy mean updates.")
    parser.add_argument("--no-freeze-log-std", dest="freeze_log_std",
                        action="store_false",
                        help="Let log_std train along with the policy mean.")
    parser.add_argument("--n-steps", type=int, default=2048,
                        help="Steps per rollout per env.")
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--n-epochs", type=int, default=10)
    parser.add_argument("--gamma", type=float, default=0.995)
    parser.add_argument("--gae-lambda", type=float, default=0.95)
    parser.add_argument("--clip-range", type=float, default=0.1,
                        help="Tight clip range — keep updates conservative.")
    parser.add_argument("--ent-coef", type=float, default=0.0)
    parser.add_argument("--target-kl", type=float, default=0.02,
                        help="SB3's per-batch KL early stop; safety belt.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--imitate-weight", type=float, default=None,
                        help="Override env.W_IMITATE for this training "
                             "run. Set to 0.0 to drop the Strömbom "
                             "cosine-imitation reward — useful during "
                             "PPO refinement where you want reward, "
                             "not teacher imitation, to drive updates.")
    parser.add_argument("--time-weight", type=float, default=None,
                        help="Override env.W_TIME. Default env value is "
                             "0.0; setting e.g. -0.1 adds a small per-"
                             "step penalty that explicitly rewards "
                             "fast time-to-pen.")
    return parser


def main(argv: "list[str] | None" = None) -> None:
    """Fine-tune the BC policy with KL-regularised PPO and save it.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Raises:
        SystemExit: if the BC checkpoint is missing, or its observation
            size is not a multiple of ``OBS_DIM``.
    """
    args = _build_arg_parser().parse_args(argv)

    bc_zip = Path(args.bc) / "policy.zip"
    if not bc_zip.exists():
        raise SystemExit(
            f"BC checkpoint not found at {bc_zip}. Train bc first with "
            f"`python -m training.bc_pretrain`."
        )

    out = Path(args.out)
    out.mkdir(parents=True, exist_ok=True)
    (out / "checkpoints").mkdir(exist_ok=True)
    (out / "best").mkdir(exist_ok=True)

    # --- Inspect BC obs dim → infer frame_stack ---
    ref_only = PPO.load(str(bc_zip), device=args.device)
    obs_dim = int(ref_only.observation_space.shape[0])
    if obs_dim % OBS_DIM != 0:
        raise SystemExit(f"BC obs dim {obs_dim} is not a multiple of {OBS_DIM}.")
    frame_stack = obs_dim // OBS_DIM
    print(f"[rl] BC obs dim {obs_dim} → frame_stack={frame_stack}")

    # --- Vectorised envs (match BC obs space) ---
    env_fns = [_make_env(i, args.seed, frame_stack) for i in range(args.n_envs)]
    venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns)
    eval_venv = None
    # Everything below runs under try/finally so the env workers are
    # shut down even when setup or training raises.
    try:
        eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, frame_stack)])

        # --- Apply reward-shaping overrides to every env instance ---
        def _broadcast(method: str, value):
            for v in (venv, eval_venv):
                try:
                    v.env_method(method, value)
                except AttributeError:
                    # Fall back to the wrapped vec-env if `v` is a wrapper
                    # that does not forward `env_method` itself.
                    v.venv.env_method(method, value)

        if args.imitate_weight is not None:
            _broadcast("set_imitate_weight", args.imitate_weight)
            print(f"[rl] W_IMITATE overridden to {args.imitate_weight}")
        if args.time_weight is not None:
            _broadcast("set_time_weight", args.time_weight)
            print(f"[rl] W_TIME overridden to {args.time_weight}")

        # --- Trainable policy: load BC weights, then bolt onto PPO ---
        # Trick: instantiate a PPO with the right env (so the policy
        # network is constructed at the correct obs/action shape), then
        # copy BC weights into it.
        model = KLPPO(
            "MlpPolicy", venv,
            ref_policy=None,  # filled in below
            kl_coef=args.kl_coef,
            learning_rate=args.learning_rate,
            n_steps=args.n_steps,
            batch_size=args.batch_size,
            n_epochs=args.n_epochs,
            gamma=args.gamma,
            gae_lambda=args.gae_lambda,
            clip_range=args.clip_range,
            ent_coef=args.ent_coef,
            target_kl=args.target_kl,
            policy_kwargs=dict(
                net_arch=dict(pi=[512, 512], vf=[512, 512]),
                log_std_init=args.log_std,
            ),
            verbose=1,
            seed=args.seed,
            device=args.device,
            tensorboard_log=str(out / "tb"),
        )

        # --- Load BC weights into both `model.policy` and `ref_policy` ---
        bc_state = ref_only.policy.state_dict()
        # Strict=False because the value head may not have been trained in
        # BC — that's fine, PPO will train it from scratch.
        missing, unexpected = model.policy.load_state_dict(bc_state, strict=False)
        print(f"[rl] BC → policy: missing={len(missing)} unexpected={len(unexpected)}")

        # Build a separate reference policy with identical architecture and
        # the BC weights, frozen.
        ref_policy = type(model.policy)(
            observation_space=model.observation_space,
            action_space=model.action_space,
            lr_schedule=lambda _: 0.0,
            net_arch=dict(pi=[512, 512], vf=[512, 512]),
            log_std_init=args.log_std,
        ).to(args.device)
        ref_policy.load_state_dict(bc_state, strict=False)
        model.ref_policy = ref_policy
        model.ref_policy.set_training_mode(False)
        for p in model.ref_policy.parameters():
            p.requires_grad = False

        # Align both policies' log_std. BC was trained with log_std≈0.5
        # (σ≈1.65), which would make the KL term huge from a std mismatch
        # rather than the mean drift we actually care about. Force both to
        # the same small value so KL measures only how far the policy mean
        # has drifted from the BC mean.
        with th.no_grad():
            model.policy.log_std.fill_(args.log_std)
            model.ref_policy.log_std.fill_(args.log_std)
        if args.freeze_log_std:
            model.policy.log_std.requires_grad = False
            print(f"[rl] log_std frozen at {args.log_std} (σ ≈ {np.exp(args.log_std):.3f})")
        else:
            print(f"[rl] log_std initialised at {args.log_std} "
                  f"(σ ≈ {np.exp(args.log_std):.3f}), trainable")

        # --- Callbacks ---
        # The frequencies are divided by n_envs so the intervals are
        # roughly 50k / 20k total timesteps across all envs.
        ckpt_cb = CheckpointCallback(
            save_freq=max(1, 50_000 // args.n_envs),
            save_path=str(out / "checkpoints"),
            name_prefix="ppo",
        )
        eval_cb = EvalCallback(
            eval_venv,
            best_model_save_path=str(out / "best"),
            log_path=str(out / "evals"),
            eval_freq=max(1, 20_000 // args.n_envs),
            n_eval_episodes=5,
            deterministic=True,
        )

        print(f"[rl] training: total_timesteps={args.total_timesteps} "
              f"n_envs={args.n_envs} lr={args.learning_rate} kl_coef={args.kl_coef}")
        model.learn(total_timesteps=args.total_timesteps,
                    callback=[ckpt_cb, eval_cb], progress_bar=True)

        # --- Save final checkpoint in the SB3 zip the controller expects ---
        model.save(out / "policy.zip")
        print(f"[rl] saved fine-tuned policy → {out/'policy.zip'}")
    finally:
        venv.close()
        if eval_venv is not None:
            eval_venv.close()
|
||
|
||
|
||
# Script entry point: run the training CLI only when this module is
# executed directly (e.g. ``python -m training.train_ppo``), not on import.
if __name__ == "__main__":
    main()
|