Files
TIR_PROJ/training/rl/train.py
T
Johnny Fernandes a01a5c9cef Checkpoint 7
2026-05-11 12:21:51 +01:00

343 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""KL-regularised PPO fine-tune of a behaviour-cloned policy.
The trainable policy is initialised from ``runs/bc/policy.zip``. A
frozen copy of the same weights becomes the reference; each PPO loss
gets an extra ``β · KL(π ‖ π_ref)`` penalty term, so the policy is
discouraged from drifting away from BC (a soft trust region rather than
a hard constraint). ``log_std`` is held at a small fixed value to keep
exploration tight.
Output: ``runs/rl/policy.zip`` — same SB3 format as the BC checkpoint,
loadable by ``HERDING_MODE=rl`` in the dog controller.
Usage::

    python -m training.rl.train \\
        --bc training/runs/bc \\
        --out training/runs/rl \\
        --total-timesteps 2000000
"""
from __future__ import annotations
import argparse
from pathlib import Path
import numpy as np
import torch as th
import torch.nn.functional as F
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from herding.perception.obs import OBS_DIM
from training.herding_env import HerdingEnv
# --------------------------------------------------------------------
# Env factory
# --------------------------------------------------------------------
def _make_env(rank: int, seed: int, frame_stack: int):
    """Return a zero-argument factory for one monitored ``HerdingEnv``.

    The environment is not built here; it is constructed when the returned
    callable is invoked. Each factory seeds its environment with
    ``seed + rank`` so that different ranks get different seeds.

    Args:
        rank: Index of the environment; added to ``seed``.
        seed: Base seed shared by all environments.
        frame_stack: Passed through to ``HerdingEnv``.

    Returns:
        A callable that builds the ``Monitor``-wrapped environment.
    """
    # Extra per-episode info fields that Monitor records.
    monitored_keys = ("is_success", "n_sheep", "n_penned")

    def _build():
        return Monitor(
            HerdingEnv(seed=seed + rank, frame_stack=frame_stack),
            info_keywords=monitored_keys,
        )

    return _build
# --------------------------------------------------------------------
# KL-regularised PPO
# --------------------------------------------------------------------
class KLPPO(PPO):
    """PPO with an extra KL-to-reference penalty added to the loss.

    Only ``train()`` is overridden. The clipped surrogate, value loss and
    entropy bonus are computed the same way as in stock SB3 PPO, and
    ``kl_coef * KL(π ‖ π_ref)`` is added on top of them.

    The penalty is the closed-form KL between two diagonal Gaussians
    (summed over the action axis), so only continuous-action Gaussian
    policies are supported.

    Args:
        *args: Forwarded unchanged to ``PPO.__init__``.
        ref_policy: Reference policy the trainable policy is kept close to.
            It is switched to eval mode and its parameters are frozen. May
            be ``None`` at construction time and assigned afterwards, but
            it must be set before ``train()`` is called.
        kl_coef: Weight of the KL-to-reference term in the total loss.
        **kwargs: Forwarded unchanged to ``PPO.__init__``.
    """

    def __init__(self, *args, ref_policy=None, kl_coef: float = 0.05, **kwargs):
        super().__init__(*args, **kwargs)
        self.ref_policy = ref_policy
        if self.ref_policy is not None:
            # The reference is never trained: eval mode, no gradients.
            self.ref_policy.set_training_mode(False)
            for p in self.ref_policy.parameters():
                p.requires_grad = False
        self.kl_coef = kl_coef

    def train(self) -> None:
        """Run the PPO update over the rollout buffer with the KL penalty.

        Follows the structure of stock SB3 ``PPO.train()``; the only
        addition is the KL-to-reference term inside the minibatch loop.

        Raises:
            RuntimeError: If no reference policy has been attached.
        """
        # Loop-invariant precondition: fail before doing any work rather
        # than after the first minibatch forward pass.
        if self.ref_policy is None:
            raise RuntimeError("KLPPO.train called without ref_policy")

        self.policy.set_training_mode(True)
        self._update_learning_rate(self.policy.optimizer)
        clip_range = self.clip_range(self._current_progress_remaining)
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)

        entropy_losses, pg_losses, value_losses, kl_losses = [], [], [], []
        clip_fractions = []
        continue_training = True

        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                # Actions are used exactly as stored. The previous
                # discrete-action branch compared the action space against
                # torch ``Distribution`` classes and so could never run;
                # the Gaussian KL term below has no discrete path anyway.
                actions = rollout_data.actions

                values, log_prob, entropy = self.policy.evaluate_actions(
                    rollout_data.observations, actions)
                values = values.flatten()

                advantages = rollout_data.advantages
                if self.normalize_advantage and len(advantages) > 1:
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # Clipped surrogate objective.
                ratio = th.exp(log_prob - rollout_data.old_log_prob)
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                # Value loss, optionally with clipped value updates.
                if self.clip_range_vf is None:
                    values_pred = values
                else:
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf)
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy bonus; fall back to -log_prob when the
                # distribution provides no entropy.
                if entropy is None:
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)
                entropy_losses.append(entropy_loss.item())

                # KL-to-reference: closed-form KL between two diagonal
                # Gaussians, summed over the action axis, mean over batch.
                # Only the reference side is computed without gradients.
                with th.no_grad():
                    ref_dist = self.ref_policy.get_distribution(rollout_data.observations)
                pi_dist = self.policy.get_distribution(rollout_data.observations)
                kl_div = th.distributions.kl.kl_divergence(
                    pi_dist.distribution, ref_dist.distribution).sum(dim=-1).mean()
                kl_losses.append(kl_div.item())

                loss = (policy_loss
                        + self.ent_coef * entropy_loss
                        + self.vf_coef * value_loss
                        + self.kl_coef * kl_div)

                # Approximate KL between the old and new policy, used only
                # for the early-stopping guard and for logging.
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                self.policy.optimizer.zero_grad()
                loss.backward()
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            # Counted once per epoch, as in stock SB3 PPO.
            self._n_updates += 1
            if not continue_training:
                break

        explained_var = self._explained_variance()
        self.logger.record("train/entropy_loss", float(np.mean(entropy_losses)))
        self.logger.record("train/policy_gradient_loss", float(np.mean(pg_losses)))
        self.logger.record("train/value_loss", float(np.mean(value_losses)))
        self.logger.record("train/kl_to_reference", float(np.mean(kl_losses)))
        self.logger.record("train/approx_kl", float(np.mean(approx_kl_divs)))
        self.logger.record("train/clip_fraction", float(np.mean(clip_fractions)))
        self.logger.record("train/explained_variance", float(explained_var))
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

    def _explained_variance(self) -> float:
        """Return ``1 - Var(returns - values) / Var(returns)`` for the buffer.

        Returns ``nan`` when the returns have zero variance.
        """
        y_pred = self.rollout_buffer.values.flatten()
        y_true = self.rollout_buffer.returns.flatten()
        var_y = np.var(y_true)
        return float("nan") if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
# --------------------------------------------------------------------
# Main
# --------------------------------------------------------------------
def main(argv: list[str] | None = None) -> None:
    """Fine-tune the BC policy with KL-regularised PPO and save the result.

    Steps: parse the command line, load ``<bc>/policy.zip``, build the
    training and evaluation environments, create a ``KLPPO`` model plus a
    frozen reference policy from the BC weights, train, then write
    ``final.zip`` and the canonical ``policy.zip`` into the output directory.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour.

    Raises:
        SystemExit: If the BC checkpoint is missing, or its observation
            dimension is not a multiple of ``OBS_DIM``.
    """
    import shutil

    parser = argparse.ArgumentParser()
    parser.add_argument("--bc", default="training/runs/bc",
                        help="Directory containing the BC initialisation.")
    parser.add_argument("--out", default="training/runs/rl",
                        help="Where to save the fine-tuned policy.")
    parser.add_argument("--total-timesteps", type=int, default=2_000_000)
    parser.add_argument("--n-envs", type=int, default=8)
    parser.add_argument("--learning-rate", type=float, default=5e-5)
    parser.add_argument("--kl-coef", type=float, default=0.05,
                        help="Coefficient of the KL-to-reference penalty.")
    parser.add_argument("--log-std", type=float, default=-1.5,
                        help="Initial (and frozen) log_std for exploration.")
    # ``store_true`` with ``default=True`` alone can never be switched off,
    # so a paired ``--no-...`` flag writes False into the same destination.
    parser.add_argument("--freeze-log-std", dest="freeze_log_std",
                        action="store_true", default=True,
                        help="Keep log_std fixed during fine-tuning (default).")
    parser.add_argument("--no-freeze-log-std", dest="freeze_log_std",
                        action="store_false",
                        help="Let PPO train log_std instead of freezing it.")
    parser.add_argument("--n-steps", type=int, default=2048)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--n-epochs", type=int, default=10)
    parser.add_argument("--gamma", type=float, default=0.995)
    parser.add_argument("--gae-lambda", type=float, default=0.95)
    parser.add_argument("--clip-range", type=float, default=0.1)
    parser.add_argument("--ent-coef", type=float, default=0.0)
    parser.add_argument("--target-kl", type=float, default=0.02,
                        help="SB3 per-batch KL early-stop guard.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--imitate-weight", type=float, default=None,
                        help="Override env.W_IMITATE (e.g. 0.0 to drop "
                             "Strömbom imitation during fine-tune).")
    parser.add_argument("--time-weight", type=float, default=None,
                        help="Override env.W_TIME (e.g. -0.1 for a "
                             "per-step time penalty).")
    args = parser.parse_args(argv)

    bc_zip = Path(args.bc) / "policy.zip"
    if not bc_zip.exists():
        raise SystemExit(
            f"BC checkpoint not found at {bc_zip}. Train bc first with "
            f"`python -m training.bc.pretrain`."
        )

    out = Path(args.out)
    out.mkdir(parents=True, exist_ok=True)
    (out / "checkpoints").mkdir(exist_ok=True)
    (out / "best").mkdir(exist_ok=True)

    # Infer frame_stack from the BC checkpoint's obs space.
    ref_only = PPO.load(str(bc_zip), device=args.device)
    obs_dim = int(ref_only.observation_space.shape[0])
    if obs_dim % OBS_DIM != 0:
        raise SystemExit(f"BC obs dim {obs_dim} is not a multiple of {OBS_DIM}.")
    frame_stack = obs_dim // OBS_DIM
    print(f"[rl] BC obs dim {obs_dim} → frame_stack={frame_stack}")

    env_fns = [_make_env(i, args.seed, frame_stack) for i in range(args.n_envs)]
    venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns)
    eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, frame_stack)])

    # Reward-shaping overrides (broadcast to every env instance).
    def _broadcast(method: str, value):
        for v in (venv, eval_venv):
            try:
                v.env_method(method, value)
            except AttributeError:
                # NOTE(review): fallback presumably targets a wrapper that
                # exposes the inner vec env as ``.venv`` — confirm.
                v.venv.env_method(method, value)

    if args.imitate_weight is not None:
        _broadcast("set_imitate_weight", args.imitate_weight)
        print(f"[rl] W_IMITATE overridden to {args.imitate_weight}")
    if args.time_weight is not None:
        _broadcast("set_time_weight", args.time_weight)
        print(f"[rl] W_TIME overridden to {args.time_weight}")

    # Build a fresh KLPPO at the right obs/action shape, then copy BC
    # weights into both the trainable policy and the frozen reference.
    model = KLPPO(
        "MlpPolicy", venv,
        ref_policy=None,  # filled in below
        kl_coef=args.kl_coef,
        learning_rate=args.learning_rate,
        n_steps=args.n_steps,
        batch_size=args.batch_size,
        n_epochs=args.n_epochs,
        gamma=args.gamma,
        gae_lambda=args.gae_lambda,
        clip_range=args.clip_range,
        ent_coef=args.ent_coef,
        target_kl=args.target_kl,
        policy_kwargs=dict(
            net_arch=dict(pi=[512, 512], vf=[512, 512]),
            log_std_init=args.log_std,
        ),
        verbose=1,
        seed=args.seed,
        device=args.device,
        tensorboard_log=str(out / "tb"),
    )

    # strict=False — the BC value head wasn't trained; PPO trains it.
    bc_state = ref_only.policy.state_dict()
    missing, unexpected = model.policy.load_state_dict(bc_state, strict=False)
    print(f"[rl] BC → policy: missing={len(missing)} unexpected={len(unexpected)}")

    # The reference is a second policy of the same class and architecture,
    # loaded from the same BC weights and never updated.
    ref_policy = type(model.policy)(
        observation_space=model.observation_space,
        action_space=model.action_space,
        lr_schedule=lambda _: 0.0,
        net_arch=dict(pi=[512, 512], vf=[512, 512]),
        log_std_init=args.log_std,
    ).to(args.device)
    ref_policy.load_state_dict(bc_state, strict=False)
    model.ref_policy = ref_policy
    model.ref_policy.set_training_mode(False)
    for p in model.ref_policy.parameters():
        p.requires_grad = False

    # Force both policies to the same log_std so the KL term measures
    # mean drift only, not a std mismatch carried over from BC.
    with th.no_grad():
        model.policy.log_std.fill_(args.log_std)
        model.ref_policy.log_std.fill_(args.log_std)
    if args.freeze_log_std:
        model.policy.log_std.requires_grad = False
        print(f"[rl] log_std frozen at {args.log_std} (σ{np.exp(args.log_std):.3f})")

    ckpt_cb = CheckpointCallback(
        save_freq=max(1, 50_000 // args.n_envs),
        save_path=str(out / "checkpoints"),
        name_prefix="ppo",
    )
    # EvalCallback writes <save_path>/best_model.zip on every new best
    # eval reward. We send it straight to ``out/`` and rename to
    # ``policy.zip`` after training so the deployed file lives at the
    # canonical path.
    eval_cb = EvalCallback(
        eval_venv,
        best_model_save_path=str(out),
        log_path=str(out / "evals"),
        eval_freq=max(1, 20_000 // args.n_envs),
        n_eval_episodes=5,
        deterministic=True,
    )

    print(f"[rl] training: total_timesteps={args.total_timesteps} "
          f"n_envs={args.n_envs} lr={args.learning_rate} kl_coef={args.kl_coef}")
    model.learn(total_timesteps=args.total_timesteps,
                callback=[ckpt_cb, eval_cb], progress_bar=True)

    # Save the end-of-training state for debugging convergence behaviour.
    model.save(out / "final.zip")

    # Promote the EvalCallback's best-by-eval-reward snapshot to the
    # canonical ``policy.zip`` (what the controller loads). Fall back
    # to the final state if eval never recorded a "best".
    best_zip = out / "best_model.zip"
    policy_zip = out / "policy.zip"
    if best_zip.exists():
        # ``Path.replace`` overwrites an existing target in a single call,
        # so no separate unlink of the old ``policy.zip`` is needed.
        best_zip.replace(policy_zip)
        print(f"[rl] best snapshot → {policy_zip} (final state kept at {out/'final.zip'})")
    else:
        shutil.copy(out / "final.zip", policy_zip)
        print(f"[rl] no best snapshot recorded; using final → {policy_zip}")
# Script entry point: start training only when this module is executed
# directly (e.g. ``python -m training.rl.train``), not when it is imported.
if __name__ == "__main__":
    main()