07d1ece3d4
After a deep investigation into the n=5 mecanum sim-to-real gap, all
attempted fixes (consensus tightening, wall_reject tightening, static-
phantom drop, deploy-time track merge, in-tracker track merge,
fp_rate-augmented retrain, max_range cap, 140° mecanum retrain) failed
to reliably pen n=5 in Webots without regressing n=10. The phantom
problem at 360° + small flock is genuinely hard and out of scope for
the deadline; documented in docs/status.md.
Results preserved from the previous mecanum work:
* 16/16 differential cells pen N/N.
* 4/8 mecanum cells (all n=10) pen 10/10 via Supervisor kinematic
injection (commit 27c0f65).
* n=5 mecanum is the known gap.
Small changes that survived the iteration:
* tests/test_config.py — strafe_efficiency=1.0 is now valid (kinematic
injection means the gym preset and Webots controller share the
formula, so textbook values produce gym-identical body motion).
* tools/run_webots.sh — refreshed the LiDAR-variant comment.
* training/rl/train.py — comment polish.
450 lines
20 KiB
Python
450 lines
20 KiB
Python
"""KL-regularised PPO fine-tune of a behaviour-cloned policy.

The trainable policy is initialised from ``runs/bc/policy.zip``. A
frozen copy of the same weights becomes the reference; each PPO loss
gets an extra ``β · KL(π ‖ π_ref)`` term so the policy can only move
within a trust region around BC. ``log_std`` is fixed small to keep
exploration tight.

Output: ``runs/rl/policy.zip`` — same SB3 format as the BC checkpoint,
loadable by ``HERDING_MODE=rl`` in the dog controller.

Usage::

    python -m training.rl.train \\
        --bc training/runs/bc \\
        --out training/runs/rl \\
        --total-timesteps 2000000
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import os
|
||
from pathlib import Path
|
||
|
||
# Configure field geometry before other herding imports read it at module level.
|
||
from herding.world.geometry import configure_from_args as _configure_from_args
|
||
_configure_from_args()
|
||
|
||
import numpy as np
|
||
import torch as th
|
||
import torch.nn.functional as F
|
||
from stable_baselines3 import PPO
|
||
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
|
||
from stable_baselines3.common.monitor import Monitor
|
||
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
|
||
|
||
from herding.perception.obs import OBS_DIM
|
||
from training.herding_env import HerdingEnv
|
||
|
||
|
||
# --------------------------------------------------------------------
|
||
# Env factory
|
||
# --------------------------------------------------------------------
|
||
|
||
def _make_env(rank: int, seed: int, frame_stack: int,
              drive_mode: str = "differential",
              difficulty: float = 1.0,
              max_n_sheep: int = 10,
              herding_cfg=None):
    """Return a zero-argument factory for one monitored ``HerdingEnv``.

    The factory form is what ``DummyVecEnv`` / ``SubprocVecEnv`` expect:
    the environment is only constructed when the returned callable runs.

    Args:
        rank: Index of this env; added to ``seed`` so each env gets its
            own seed.
        seed: Base seed shared by all envs of one vec-env.
        frame_stack: Forwarded to ``HerdingEnv``.
        drive_mode: Forwarded to ``HerdingEnv``.
        difficulty: Forwarded to ``HerdingEnv``.
        max_n_sheep: Forwarded to ``HerdingEnv``.
        herding_cfg: Optional config object forwarded to ``HerdingEnv``.

    Returns:
        A callable that builds the env wrapped in ``Monitor``.
    """
    env_kwargs = dict(
        seed=seed + rank,
        frame_stack=frame_stack,
        drive_mode=drive_mode,
        difficulty=difficulty,
        max_n_sheep=max_n_sheep,
        herding_cfg=herding_cfg,
    )
    # Extra ``info`` entries the Monitor wrapper is asked to record.
    logged_keys = ("is_success", "n_sheep", "n_penned")

    def _build():
        return Monitor(HerdingEnv(**env_kwargs), info_keywords=logged_keys)

    return _build
|
||
|
||
|
||
# --------------------------------------------------------------------
|
||
# KL-regularised PPO
|
||
# --------------------------------------------------------------------
|
||
|
||
class KLPPO(PPO):
    """PPO with an extra KL-to-reference penalty in the policy loss.

    ``train()`` follows the stock SB3 ``PPO.train()`` structure (clipped
    surrogate, value loss, entropy bonus, ``target_kl`` early stop) and
    adds ``kl_coef * KL(π ‖ π_ref)`` to the minibatch loss. Rollout
    collection is inherited unchanged.

    Args:
        *args: Forwarded to ``PPO.__init__``.
        ref_policy: Frozen reference policy. May be ``None`` at
            construction and assigned later, but must be set before
            ``train()`` runs.
        kl_coef: Weight of the KL-to-reference term in the loss.
        **kwargs: Forwarded to ``PPO.__init__``.
    """

    def __init__(self, *args, ref_policy=None, kl_coef: float = 0.05, **kwargs):
        super().__init__(*args, **kwargs)
        self.ref_policy = ref_policy
        if self.ref_policy is not None:
            # The reference is never optimised: eval mode, no gradients.
            self.ref_policy.set_training_mode(False)
            for p in self.ref_policy.parameters():
                p.requires_grad = False
        self.kl_coef = kl_coef

    def train(self) -> None:
        """Run the PPO epochs over the current rollout buffer.

        Raises:
            RuntimeError: If ``ref_policy`` has not been set.
        """
        # Local import: only needed to recognise discrete-action policies.
        from stable_baselines3.common.distributions import CategoricalDistribution

        # Loop-invariant precondition; fail before touching the optimiser.
        if self.ref_policy is None:
            raise RuntimeError("KLPPO.train called without ref_policy")

        self.policy.set_training_mode(True)
        self._update_learning_rate(self.policy.optimizer)
        clip_range = self.clip_range(self._current_progress_remaining)
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)

        # The buffer stores actions as floats; a categorical policy needs
        # them as a flat long tensor. (The previous check compared the
        # action space against a torch Distribution base class and could
        # never be true.)
        discrete_actions = isinstance(
            getattr(self.policy, "action_dist", None), CategoricalDistribution)

        entropy_losses, pg_losses, value_losses, kl_losses = [], [], [], []
        clip_fractions = []
        continue_training = True

        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if discrete_actions:
                    actions = rollout_data.actions.long().flatten()

                values, log_prob, entropy = self.policy.evaluate_actions(
                    rollout_data.observations, actions)
                values = values.flatten()
                advantages = rollout_data.advantages
                if self.normalize_advantage and len(advantages) > 1:
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # Clipped surrogate objective.
                ratio = th.exp(log_prob - rollout_data.old_log_prob)
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    values_pred = values
                else:
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf)
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                if entropy is None:
                    # No analytical entropy: approximate it from log-probs.
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)
                entropy_losses.append(entropy_loss.item())

                # KL-to-reference: closed-form KL between the current and
                # reference action distributions. Only the current policy
                # side carries gradients.
                with th.no_grad():
                    ref_dist = self.ref_policy.get_distribution(rollout_data.observations)
                pi_dist = self.policy.get_distribution(rollout_data.observations)
                kl_per_sample = th.distributions.kl.kl_divergence(
                    pi_dist.distribution, ref_dist.distribution)
                if kl_per_sample.dim() > 1:
                    # Diagonal Gaussian: one KL per action dim, sum them.
                    kl_per_sample = kl_per_sample.sum(dim=-1)
                kl_div = kl_per_sample.mean()
                kl_losses.append(kl_div.item())

                loss = (policy_loss
                        + self.ent_coef * entropy_loss
                        + self.vf_coef * value_loss
                        + self.kl_coef * kl_div)

                # Estimate of KL(old ‖ new) used for the early-stop guard.
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    # The minibatch that trips the guard is not applied.
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                self.policy.optimizer.zero_grad()
                loss.backward()
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            self._n_updates += 1
            if not continue_training:
                break

        explained_var = self._explained_variance()
        self.logger.record("train/entropy_loss", float(np.mean(entropy_losses)))
        self.logger.record("train/policy_gradient_loss", float(np.mean(pg_losses)))
        self.logger.record("train/value_loss", float(np.mean(value_losses)))
        self.logger.record("train/kl_to_reference", float(np.mean(kl_losses)))
        self.logger.record("train/approx_kl", float(np.mean(approx_kl_divs)))
        self.logger.record("train/clip_fraction", float(np.mean(clip_fractions)))
        self.logger.record("train/explained_variance", float(explained_var))
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

    def _explained_variance(self) -> float:
        """Return ``1 - Var[returns - values] / Var[returns]`` over the buffer.

        Returns ``nan`` when the returns have zero variance.
        """
        y_pred = self.rollout_buffer.values.flatten()
        y_true = self.rollout_buffer.returns.flatten()
        var_y = np.var(y_true)
        if var_y == 0:
            return float("nan")
        return float(1 - np.var(y_true - y_pred) / var_y)
|
||
|
||
|
||
# --------------------------------------------------------------------
|
||
# Main
|
||
# --------------------------------------------------------------------
|
||
|
||
def main() -> None:
    """Command-line entry point: KL-regularised PPO fine-tune of a BC policy.

    Steps, in order:

    1. Parse arguments and load ``<--bc>/policy.zip``; exit with a message
       if it is missing.
    2. Infer ``frame_stack`` from the BC observation size and, unless
       ``--drive-mode`` is given, the drive mode from the BC action size.
    3. Build training and evaluation vec-envs, optionally with a
       domain-randomised ``herding_cfg``.
    4. Create a ``KLPPO`` model, copy the BC weights into both the
       trainable policy and a frozen reference, and pin ``log_std``.
    5. Train with checkpoint and eval callbacks, save ``final.zip``, and
       promote the best eval snapshot (or the final state) to
       ``policy.zip`` in ``--out``.

    Raises:
        SystemExit: If the BC checkpoint is missing or its observation
            size is not a multiple of ``OBS_DIM``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--bc", default="training/runs/bc",
                        help="Directory containing the BC initialisation.")
    parser.add_argument("--out", default="training/runs/rl",
                        help="Where to save the fine-tuned policy.")
    parser.add_argument("--total-timesteps", type=int, default=2_000_000)
    parser.add_argument("--n-envs", type=int, default=8)
    parser.add_argument("--learning-rate", type=float, default=5e-5)
    parser.add_argument("--kl-coef", type=float, default=0.05,
                        help="Coefficient of the KL-to-reference penalty.")
    parser.add_argument("--log-std", type=float, default=-1.5,
                        help="Initial (and frozen) log_std for exploration.")
    # ``store_true`` with ``default=True`` alone can never be switched
    # off, so the flag is paired with an explicit negative form.
    parser.add_argument("--freeze-log-std", dest="freeze_log_std",
                        action="store_true", default=True,
                        help="Keep log_std fixed during fine-tuning (default).")
    parser.add_argument("--no-freeze-log-std", dest="freeze_log_std",
                        action="store_false",
                        help="Let PPO train log_std instead of freezing it.")
    parser.add_argument("--n-steps", type=int, default=2048)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--n-epochs", type=int, default=10)
    parser.add_argument("--gamma", type=float, default=0.995)
    parser.add_argument("--gae-lambda", type=float, default=0.95)
    parser.add_argument("--clip-range", type=float, default=0.1)
    parser.add_argument("--ent-coef", type=float, default=0.0)
    parser.add_argument("--target-kl", type=float, default=0.02,
                        help="SB3 per-batch KL early-stop guard.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--drive-mode", default=None,
                        choices=["differential", "mecanum"],
                        help="Drive mode. If not set, inferred from "
                             "BC action dimension (2→differential, 3→mecanum).")
    parser.add_argument("--imitate-weight", type=float, default=None,
                        help="Override env.W_IMITATE (e.g. 0.0 to drop "
                             "Strömbom imitation during fine-tune).")
    parser.add_argument("--time-weight", type=float, default=None,
                        help="Override env.W_TIME (e.g. -0.1 for a "
                             "per-step time penalty).")
    parser.add_argument("--difficulty", type=float, default=1.0,
                        help="HerdingEnv difficulty for PPO rollouts. "
                             "Must match eval (1.0) to avoid train/eval "
                             "distribution mismatch.")
    parser.add_argument("--max-n-sheep", type=int, default=10,
                        help="Upper bound on flock size sampled each reset.")
    parser.add_argument("--world", default=None,
                        choices=["field", "field_round"],
                        help="World shape. If not set, uses HERDING_WORLD "
                             "env var or defaults to 'field'.")
    # Domain randomisation
    parser.add_argument("--fp-rate", type=float, default=0.0,
                        help="Mean false-positive detections per step (Poisson λ).")
    parser.add_argument("--action-smooth", type=float, default=0.0,
                        help="EMA on dog actions (0=none, 0.55=Webots match).")
    parser.add_argument("--wheel-slip-std", type=float, default=0.0,
                        help="Gaussian wheel-speed noise std (rad/s).")
    args = parser.parse_args()

    # --world was already honoured by the module-level
    # ``_configure_from_args()`` call at import time; here we just
    # sanity-check that the final argparse view agrees.
    if args.world is not None:
        from herding.world.geometry import FIELD_SHAPE as _CURRENT_SHAPE
        if args.world != _CURRENT_SHAPE:
            print(f"[rl] WARNING: --world={args.world} but geometry is "
                  f"'{_CURRENT_SHAPE}'. File a bug.")

    bc_zip = Path(args.bc) / "policy.zip"
    if not bc_zip.exists():
        raise SystemExit(
            f"BC checkpoint not found at {bc_zip}. Train bc first with "
            f"`python -m training.bc.pretrain`."
        )

    out = Path(args.out)
    out.mkdir(parents=True, exist_ok=True)
    (out / "checkpoints").mkdir(exist_ok=True)
    (out / "best").mkdir(exist_ok=True)

    # Infer frame_stack from the BC checkpoint's obs space.
    ref_only = PPO.load(str(bc_zip), device=args.device)
    obs_dim = int(ref_only.observation_space.shape[0])
    if obs_dim % OBS_DIM != 0:
        raise SystemExit(f"BC obs dim {obs_dim} is not a multiple of {OBS_DIM}.")
    frame_stack = obs_dim // OBS_DIM
    print(f"[rl] BC obs dim {obs_dim} → frame_stack={frame_stack}")

    # Infer drive mode from BC action dim if not explicitly set.
    bc_action_dim = int(ref_only.action_space.shape[0])
    if args.drive_mode is not None:
        drive_mode = args.drive_mode
    elif bc_action_dim == 3:
        drive_mode = "mecanum"
    else:
        drive_mode = "differential"
    print(f"[rl] drive_mode={drive_mode} (BC action_dim={bc_action_dim})")

    from herding.config import (
        HerdingConfig, HERDING_MEC_WEBOTS_360, DomainRandomConfig, RobotConfig,
    )
    herding_cfg = None
    # Mecanum trains under HERDING_MEC_WEBOTS_360 (360° LiDAR +
    # kinematic-matched strafe scaling + small compass-noise DR).
    is_mecanum = (drive_mode == "mecanum")
    wants_dr = (args.fp_rate > 0.0 or args.action_smooth > 0.0
                or args.wheel_slip_std > 0.0)
    if is_mecanum or wants_dr:
        if is_mecanum:
            base = HERDING_MEC_WEBOTS_360
            strafe_eff = base.robot.strafe_efficiency
            strafe_bleed = base.robot.strafe_to_forward_bleed
            compass_std = 0.1  # heading robustness DR
            herding_cfg = base.replace(
                domain_random=DomainRandomConfig(
                    fp_rate=args.fp_rate,
                    wheel_slip_std=args.wheel_slip_std,
                    compass_noise_std=compass_std,
                ),
                robot=RobotConfig(
                    action_smooth=args.action_smooth,
                    strafe_efficiency=strafe_eff,
                    strafe_to_forward_bleed=strafe_bleed,
                ),
            )
        else:
            strafe_eff = 1.0
            strafe_bleed = 0.0
            # Reported below only; compass noise is not passed to the
            # config in the non-mecanum case.
            compass_std = 0.0
            herding_cfg = HerdingConfig(
                domain_random=DomainRandomConfig(
                    fp_rate=args.fp_rate,
                    wheel_slip_std=args.wheel_slip_std,
                ),
                robot=RobotConfig(
                    action_smooth=args.action_smooth,
                    strafe_efficiency=strafe_eff,
                    strafe_to_forward_bleed=strafe_bleed,
                ),
            )
        print(f"[rl] domain-random: fp_rate={args.fp_rate} "
              f"action_smooth={args.action_smooth} "
              f"wheel_slip_std={args.wheel_slip_std} "
              f"strafe_eff={strafe_eff:.2f} strafe_bleed={strafe_bleed:.2f} "
              f"compass_noise={compass_std}")

    env_fns = [_make_env(i, args.seed, frame_stack, drive_mode,
                         difficulty=args.difficulty,
                         max_n_sheep=args.max_n_sheep,
                         herding_cfg=herding_cfg)
               for i in range(args.n_envs)]
    venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns)
    eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, frame_stack,
                                       drive_mode,
                                       difficulty=args.difficulty,
                                       max_n_sheep=args.max_n_sheep,
                                       herding_cfg=herding_cfg)])
    print(f"[rl] difficulty={args.difficulty} max_n_sheep={args.max_n_sheep}")

    # Reward-shaping overrides (broadcast to every env instance).
    def _broadcast(method: str, value):
        for v in (venv, eval_venv):
            try:
                v.env_method(method, value)
            except AttributeError:
                # Fall back to the inner vec-env when ``v`` is a wrapper.
                v.venv.env_method(method, value)

    if args.imitate_weight is not None:
        _broadcast("set_imitate_weight", args.imitate_weight)
        print(f"[rl] W_IMITATE overridden to {args.imitate_weight}")
    if args.time_weight is not None:
        _broadcast("set_time_weight", args.time_weight)
        print(f"[rl] W_TIME overridden to {args.time_weight}")

    # Build a fresh KLPPO at the right obs/action shape, then copy BC
    # weights into both the trainable policy and the frozen reference.
    model = KLPPO(
        "MlpPolicy", venv,
        ref_policy=None,  # filled in below
        kl_coef=args.kl_coef,
        learning_rate=args.learning_rate,
        n_steps=args.n_steps,
        batch_size=args.batch_size,
        n_epochs=args.n_epochs,
        gamma=args.gamma,
        gae_lambda=args.gae_lambda,
        clip_range=args.clip_range,
        ent_coef=args.ent_coef,
        target_kl=args.target_kl,
        policy_kwargs=dict(
            net_arch=dict(pi=[512, 512], vf=[512, 512]),
            log_std_init=args.log_std,
        ),
        verbose=1,
        seed=args.seed,
        device=args.device,
        tensorboard_log=str(out / "tb"),
    )

    # strict=False — the BC value head wasn't trained; PPO trains it.
    bc_state = ref_only.policy.state_dict()
    missing, unexpected = model.policy.load_state_dict(bc_state, strict=False)
    print(f"[rl] BC → policy: missing={len(missing)} unexpected={len(unexpected)}")

    ref_policy = type(model.policy)(
        observation_space=model.observation_space,
        action_space=model.action_space,
        lr_schedule=lambda _: 0.0,
        net_arch=dict(pi=[512, 512], vf=[512, 512]),
        log_std_init=args.log_std,
    ).to(args.device)
    ref_policy.load_state_dict(bc_state, strict=False)
    model.ref_policy = ref_policy
    model.ref_policy.set_training_mode(False)
    for p in model.ref_policy.parameters():
        p.requires_grad = False

    # Force both policies to the same log_std so the KL term measures
    # mean drift only, not a std mismatch carried over from BC.
    with th.no_grad():
        model.policy.log_std.fill_(args.log_std)
        model.ref_policy.log_std.fill_(args.log_std)
    if args.freeze_log_std:
        model.policy.log_std.requires_grad = False
        print(f"[rl] log_std frozen at {args.log_std} (σ ≈ {np.exp(args.log_std):.3f})")

    ckpt_cb = CheckpointCallback(
        save_freq=max(1, 50_000 // args.n_envs),
        save_path=str(out / "checkpoints"),
        name_prefix="ppo",
    )
    # EvalCallback writes <save_path>/best_model.zip on every new best
    # eval reward. We send it straight to ``out/`` and rename to
    # ``policy.zip`` after training so the deployed file lives at the
    # canonical path.
    eval_cb = EvalCallback(
        eval_venv,
        best_model_save_path=str(out),
        log_path=str(out / "evals"),
        eval_freq=max(1, 20_000 // args.n_envs),
        n_eval_episodes=5,
        deterministic=True,
    )

    print(f"[rl] training: total_timesteps={args.total_timesteps} "
          f"n_envs={args.n_envs} lr={args.learning_rate} kl_coef={args.kl_coef}")
    model.learn(total_timesteps=args.total_timesteps,
                callback=[ckpt_cb, eval_cb], progress_bar=True)

    # Save the end-of-training state for debugging convergence behaviour.
    model.save(out / "final.zip")

    # Promote the EvalCallback's best-by-eval-reward snapshot to the
    # canonical ``policy.zip`` (what the controller loads). Fall back
    # to the final state if eval never recorded a "best".
    import shutil
    best_zip = out / "best_model.zip"
    policy_zip = out / "policy.zip"
    if best_zip.exists():
        # ``replace`` overwrites an existing target in one step, so there
        # is no window in which ``policy.zip`` is missing.
        best_zip.replace(policy_zip)
        print(f"[rl] best snapshot → {policy_zip} (final state kept at {out/'final.zip'})")
    else:
        shutil.copy(out / "final.zip", policy_zip)
        print(f"[rl] no best snapshot recorded; using final → {policy_zip}")
|
||
|
||
|
||
# Script entry point: run the fine-tune only when this module is executed
# directly (e.g. ``python -m training.rl.train``), not when it is imported.
if __name__ == "__main__":
    main()
|