Checkpoint 7

This commit is contained in:
Johnny Fernandes
2026-05-11 12:21:51 +01:00
parent fce0e0c786
commit a01a5c9cef
34 changed files with 1266 additions and 1038 deletions
View File
+342
View File
@@ -0,0 +1,342 @@
"""KL-regularised PPO fine-tune of a behaviour-cloned policy.
The trainable policy is initialised from ``runs/bc/policy.zip``. A
frozen copy of the same weights becomes the reference; each PPO loss
gets an extra ``β · KL(π ‖ π_ref)`` term so the policy can only move
within a trust region around BC. ``log_std`` is fixed small to keep
exploration tight.
Output: ``runs/rl/policy.zip`` — same SB3 format as the BC checkpoint,
loadable by ``HERDING_MODE=rl`` in the dog controller.
Usage::
python -m training.rl.train \\
--bc training/runs/bc \\
--out training/runs/rl \\
--total-timesteps 2000000
"""
from __future__ import annotations
import argparse
from pathlib import Path
import numpy as np
import torch as th
import torch.nn.functional as F
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from herding.perception.obs import OBS_DIM
from training.herding_env import HerdingEnv
# --------------------------------------------------------------------
# Env factory
# --------------------------------------------------------------------
def _make_env(rank: int, seed: int, frame_stack: int):
def _thunk():
env = HerdingEnv(seed=seed + rank, frame_stack=frame_stack)
env = Monitor(env, info_keywords=("is_success", "n_sheep", "n_penned"))
return env
return _thunk
# --------------------------------------------------------------------
# KL-regularised PPO
# --------------------------------------------------------------------
class KLPPO(PPO):
    """PPO with an extra KL-to-reference penalty in the policy loss.

    Overrides only ``train()``; rollout buffer, clipped surrogate, value
    loss and entropy bonus follow stock SB3 PPO. Each minibatch loss gets
    an additional ``kl_coef * KL(pi || pi_ref)`` term, where ``pi_ref`` is
    a frozen reference policy.

    The KL term is computed per action dimension and summed over the last
    axis, i.e. it is written for diagonal-Gaussian (continuous-action)
    policies. Discrete action spaces are not handled by this class.

    Args:
        ref_policy: Frozen reference policy. May be ``None`` at
            construction time and assigned to ``self.ref_policy`` later,
            but it must be set before ``train()`` runs.
        kl_coef: Weight of the KL-to-reference penalty in the loss.
        *args, **kwargs: Forwarded unchanged to ``PPO.__init__``.
    """

    def __init__(self, *args, ref_policy=None, kl_coef: float = 0.05, **kwargs):
        super().__init__(*args, **kwargs)
        self.ref_policy = ref_policy
        if self.ref_policy is not None:
            # The reference is never optimised: eval mode, no gradients.
            self.ref_policy.set_training_mode(False)
            for p in self.ref_policy.parameters():
                p.requires_grad = False
        self.kl_coef = kl_coef

    def train(self) -> None:
        """Run the PPO update on the current rollout buffer with the KL penalty.

        Raises:
            RuntimeError: If ``self.ref_policy`` has not been set.
        """
        # Fail fast, before touching the learning rate or running any
        # forward pass, instead of discovering this inside the first batch.
        if self.ref_policy is None:
            raise RuntimeError("KLPPO.train called without ref_policy")

        # Stock SB3 PPO.train() structure with the KL-to-reference term
        # added inside the inner minibatch loop.
        self.policy.set_training_mode(True)
        self._update_learning_rate(self.policy.optimizer)
        clip_range = self.clip_range(self._current_progress_remaining)
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)

        entropy_losses, pg_losses, value_losses, kl_losses = [], [], [], []
        clip_fractions = []
        continue_training = True

        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                observations = rollout_data.observations
                # Continuous actions are used as stored in the buffer.
                actions = rollout_data.actions

                values, log_prob, entropy = self.policy.evaluate_actions(
                    observations, actions)
                values = values.flatten()

                advantages = rollout_data.advantages
                # Normalising a single-element batch would divide by ~0.
                if self.normalize_advantage and len(advantages) > 1:
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # Clipped surrogate objective.
                ratio = th.exp(log_prob - rollout_data.old_log_prob)
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
                pg_losses.append(policy_loss.item())

                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                # Value loss, optionally clipped around the old value estimate.
                if self.clip_range_vf is None:
                    values_pred = values
                else:
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf)
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy bonus; fall back to -log_prob when the
                # distribution provides no analytic entropy.
                if entropy is None:
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)
                entropy_losses.append(entropy_loss.item())

                # KL-to-reference: closed-form KL between two diagonal
                # Gaussians, summed over the action axis, mean over batch.
                # Only the trainable policy's side carries gradients.
                with th.no_grad():
                    ref_dist = self.ref_policy.get_distribution(observations)
                pi_dist = self.policy.get_distribution(observations)
                kl_div = th.distributions.kl.kl_divergence(
                    pi_dist.distribution, ref_dist.distribution).sum(dim=-1).mean()
                kl_losses.append(kl_div.item())

                loss = (policy_loss
                        + self.ent_coef * entropy_loss
                        + self.vf_coef * value_loss
                        + self.kl_coef * kl_div)

                # Approximate KL between the rollout policy and the current
                # policy, used only for the early-stop guard below.
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                self.policy.optimizer.zero_grad()
                loss.backward()
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            self._n_updates += 1
            if not continue_training:
                break

        explained_var = self._explained_variance()

        self.logger.record("train/entropy_loss", float(np.mean(entropy_losses)))
        self.logger.record("train/policy_gradient_loss", float(np.mean(pg_losses)))
        self.logger.record("train/value_loss", float(np.mean(value_losses)))
        self.logger.record("train/kl_to_reference", float(np.mean(kl_losses)))
        self.logger.record("train/approx_kl", float(np.mean(approx_kl_divs)))
        self.logger.record("train/clip_fraction", float(np.mean(clip_fractions)))
        # Total loss of the last evaluated minibatch, as in stock PPO.
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", float(explained_var))
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

    def _explained_variance(self) -> float:
        """Return ``1 - Var(returns - values) / Var(returns)`` for the buffer.

        Returns ``nan`` when the returns have zero variance.
        """
        y_pred = self.rollout_buffer.values.flatten()
        y_true = self.rollout_buffer.returns.flatten()
        var_y = np.var(y_true)
        return float("nan") if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
# --------------------------------------------------------------------
# Main
# --------------------------------------------------------------------
def main() -> None:
    """Fine-tune the BC policy with KL-regularised PPO and save the result.

    Parses the command line, loads ``<bc>/policy.zip``, builds the training
    and evaluation vec envs, initialises a trainable ``KLPPO`` policy and a
    frozen reference from the BC weights, trains, and finally writes
    ``<out>/final.zip`` (end-of-training state) and ``<out>/policy.zip``
    (best-by-eval snapshot, or the final state if no best was recorded).

    Raises:
        SystemExit: If the BC checkpoint is missing, its observation size
            is not a multiple of ``OBS_DIM``, or the CLI arguments are
            invalid.
    """
    import shutil
    from contextlib import ExitStack

    parser = argparse.ArgumentParser()
    parser.add_argument("--bc", default="training/runs/bc",
                        help="Directory containing the BC initialisation.")
    parser.add_argument("--out", default="training/runs/rl",
                        help="Where to save the fine-tuned policy.")
    parser.add_argument("--total-timesteps", type=int, default=2_000_000)
    parser.add_argument("--n-envs", type=int, default=8)
    parser.add_argument("--learning-rate", type=float, default=5e-5)
    parser.add_argument("--kl-coef", type=float, default=0.05,
                        help="Coefficient of the KL-to-reference penalty.")
    parser.add_argument("--log-std", type=float, default=-1.5,
                        help="Initial (and frozen) log_std for exploration.")
    # Freezing is on by default; a plain store_true flag with default=True
    # could never be switched off, hence the explicit negative flag.
    parser.add_argument("--freeze-log-std", dest="freeze_log_std",
                        action="store_true", default=True,
                        help="Keep log_std fixed during training (default).")
    parser.add_argument("--no-freeze-log-std", dest="freeze_log_std",
                        action="store_false",
                        help="Let PPO train log_std.")
    parser.add_argument("--n-steps", type=int, default=2048)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--n-epochs", type=int, default=10)
    parser.add_argument("--gamma", type=float, default=0.995)
    parser.add_argument("--gae-lambda", type=float, default=0.95)
    parser.add_argument("--clip-range", type=float, default=0.1)
    parser.add_argument("--ent-coef", type=float, default=0.0)
    parser.add_argument("--target-kl", type=float, default=0.02,
                        help="SB3 per-batch KL early-stop guard.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--imitate-weight", type=float, default=None,
                        help="Override env.W_IMITATE (e.g. 0.0 to drop "
                             "Strömbom imitation during fine-tune).")
    parser.add_argument("--time-weight", type=float, default=None,
                        help="Override env.W_TIME (e.g. -0.1 for a "
                             "per-step time penalty).")
    args = parser.parse_args()

    # n_envs is used as a divisor for the callback frequencies and as the
    # number of env factories, so it must be positive.
    if args.n_envs < 1:
        parser.error("--n-envs must be >= 1")

    bc_zip = Path(args.bc) / "policy.zip"
    if not bc_zip.exists():
        raise SystemExit(
            f"BC checkpoint not found at {bc_zip}. Train bc first with "
            f"`python -m training.bc.pretrain`."
        )

    out = Path(args.out)
    out.mkdir(parents=True, exist_ok=True)
    (out / "checkpoints").mkdir(exist_ok=True)
    # NOTE: nothing below writes into ``best/``; EvalCallback saves to ``out``.
    (out / "best").mkdir(exist_ok=True)

    best_zip = out / "best_model.zip"
    final_zip = out / "final.zip"
    policy_zip = out / "policy.zip"
    # Drop a leftover best snapshot (e.g. from an aborted run) so that only
    # a snapshot produced by this run can be promoted to policy.zip.
    if best_zip.exists():
        best_zip.unlink()

    # Infer frame_stack from the BC checkpoint's obs space.
    ref_only = PPO.load(str(bc_zip), device=args.device)
    obs_dim = int(ref_only.observation_space.shape[0])
    if obs_dim % OBS_DIM != 0:
        raise SystemExit(f"BC obs dim {obs_dim} is not a multiple of {OBS_DIM}.")
    frame_stack = obs_dim // OBS_DIM
    print(f"[rl] BC obs dim {obs_dim} → frame_stack={frame_stack}")

    def _policy_kwargs() -> dict:
        # Fresh dict per call so the trainable and reference policies do
        # not share mutable configuration objects.
        return dict(
            net_arch=dict(pi=[512, 512], vf=[512, 512]),
            log_std_init=args.log_std,
        )

    # ExitStack closes the vec envs (and any worker subprocesses) on both
    # normal completion and error.
    with ExitStack() as stack:
        env_fns = [_make_env(i, args.seed, frame_stack) for i in range(args.n_envs)]
        venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns)
        stack.callback(venv.close)
        eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, frame_stack)])
        stack.callback(eval_venv.close)

        # Reward-shaping overrides (broadcast to every env instance).
        def _broadcast(method: str, value):
            for v in (venv, eval_venv):
                try:
                    v.env_method(method, value)
                except AttributeError:
                    # Fall back to the wrapped vec env if the outer object
                    # does not expose env_method itself.
                    v.venv.env_method(method, value)

        if args.imitate_weight is not None:
            _broadcast("set_imitate_weight", args.imitate_weight)
            print(f"[rl] W_IMITATE overridden to {args.imitate_weight}")
        if args.time_weight is not None:
            _broadcast("set_time_weight", args.time_weight)
            print(f"[rl] W_TIME overridden to {args.time_weight}")

        # Build a fresh KLPPO at the right obs/action shape, then copy BC
        # weights into both the trainable policy and the frozen reference.
        model = KLPPO(
            "MlpPolicy", venv,
            ref_policy=None,  # filled in below
            kl_coef=args.kl_coef,
            learning_rate=args.learning_rate,
            n_steps=args.n_steps,
            batch_size=args.batch_size,
            n_epochs=args.n_epochs,
            gamma=args.gamma,
            gae_lambda=args.gae_lambda,
            clip_range=args.clip_range,
            ent_coef=args.ent_coef,
            target_kl=args.target_kl,
            policy_kwargs=_policy_kwargs(),
            verbose=1,
            seed=args.seed,
            device=args.device,
            tensorboard_log=str(out / "tb"),
        )

        # strict=False tolerates keys that exist on only one side (the
        # original note: the BC value head wasn't trained; PPO trains it).
        # Tensors present on both sides are still copied, and shape
        # mismatches still raise.
        bc_state = ref_only.policy.state_dict()
        missing, unexpected = model.policy.load_state_dict(bc_state, strict=False)
        print(f"[rl] BC → policy: missing={len(missing)} unexpected={len(unexpected)}")

        ref_policy = type(model.policy)(
            observation_space=model.observation_space,
            action_space=model.action_space,
            lr_schedule=lambda _: 0.0,  # the reference is never optimised
            **_policy_kwargs(),
        ).to(args.device)
        ref_policy.load_state_dict(bc_state, strict=False)
        model.ref_policy = ref_policy
        model.ref_policy.set_training_mode(False)
        for p in model.ref_policy.parameters():
            p.requires_grad = False

        # Force both policies to the same log_std so the KL term measures
        # mean drift only, not a std mismatch carried over from BC.
        with th.no_grad():
            model.policy.log_std.fill_(args.log_std)
            model.ref_policy.log_std.fill_(args.log_std)
        if args.freeze_log_std:
            model.policy.log_std.requires_grad = False
            print(f"[rl] log_std frozen at {args.log_std} (σ{np.exp(args.log_std):.3f})")
        else:
            print(f"[rl] log_std initialised at {args.log_std} (trainable)")

        # Callback frequencies are divided by n_envs so they correspond to
        # roughly 50k / 20k total environment steps.
        ckpt_cb = CheckpointCallback(
            save_freq=max(1, 50_000 // args.n_envs),
            save_path=str(out / "checkpoints"),
            name_prefix="ppo",
        )
        # EvalCallback writes <save_path>/best_model.zip on every new best
        # eval reward. We send it straight to ``out/`` and rename to
        # ``policy.zip`` after training so the deployed file lives at the
        # canonical path.
        eval_cb = EvalCallback(
            eval_venv,
            best_model_save_path=str(out),
            log_path=str(out / "evals"),
            eval_freq=max(1, 20_000 // args.n_envs),
            n_eval_episodes=5,
            deterministic=True,
        )

        print(f"[rl] training: total_timesteps={args.total_timesteps} "
              f"n_envs={args.n_envs} lr={args.learning_rate} kl_coef={args.kl_coef}")
        model.learn(total_timesteps=args.total_timesteps,
                    callback=[ckpt_cb, eval_cb], progress_bar=True)

        # Save the end-of-training state for debugging convergence behaviour.
        model.save(final_zip)

    # Promote the EvalCallback's best-by-eval-reward snapshot to the
    # canonical ``policy.zip`` (what the controller loads). Fall back
    # to the final state if eval never recorded a "best".
    if best_zip.exists():
        # Path.replace overwrites an existing policy.zip in one step.
        best_zip.replace(policy_zip)
        print(f"[rl] best snapshot → {policy_zip} (final state kept at {final_zip})")
    else:
        shutil.copy(final_zip, policy_zip)
        print(f"[rl] no best snapshot recorded; using final → {policy_zip}")


# Run the training CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()