"""Behaviour cloning of an analytic teacher into an SB3 MlpPolicy. Trains the mean-action head against ``(obs, action)`` demos from ``training.bc.collect`` using ``MSE + (1 − cos_sim)`` — the cosine term prevents collapse toward zero against unit-vector targets. The best-by-val_cos snapshot is restored at the end of training because multi-modal teachers make the last epoch unreliable. Output zip is loadable by ``PPO.load(...)`` and consumed by ``HERDING_MODE=bc`` in the dog controller. Usage:: python -m training.bc.pretrain \\ --demos training/bc/demos.npz \\ --out training/runs/bc """ from __future__ import annotations import argparse import time from pathlib import Path import numpy as np import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, TensorDataset from stable_baselines3 import PPO from stable_baselines3.common.vec_env import DummyVecEnv from training.herding_env import HerdingEnv def build_model(net_arch_pi, net_arch_vf, log_std_init: float, frame_stack: int = 1, drive_mode: str = "differential"): """Build a fresh SB3 PPO solely as a vehicle for the policy weights. PPO's training-loop plumbing isn't used during BC. ``frame_stack`` must match the demo file so the env's obs space agrees with the recorded obs shape. """ env = DummyVecEnv([lambda: HerdingEnv(frame_stack=frame_stack, drive_mode=drive_mode)]) model = PPO( "MlpPolicy", env, policy_kwargs=dict( net_arch=dict(pi=net_arch_pi, vf=net_arch_vf), log_std_init=log_std_init, ), verbose=0, ) return model, env def policy_forward_mean(policy, obs_batch): """Return the deterministic mean action for an obs batch. SB3's ActorCriticPolicy routes ``forward`` through a Distribution wrapper; we replicate the underlying chain ``extract_features → mlp_extractor → action_net``. """ features = policy.extract_features(obs_batch) pi_features = features[0] if isinstance(features, tuple) else features latent_pi, _ = policy.mlp_extractor(pi_features) return policy.action_net(latent_pi) def main(): parser = argparse.ArgumentParser() parser.add_argument("--demos", default="training/bc/demos.npz") parser.add_argument("--out", default="training/runs/bc") parser.add_argument("--epochs", type=int, default=60) parser.add_argument("--batch-size", type=int, default=256) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--val-split", type=float, default=0.1) parser.add_argument("--net-arch", default="256,256", help="Comma-separated hidden layer widths.") parser.add_argument("--log-std-init", type=float, default=0.5) parser.add_argument("--cos-weight", type=float, default=1.0, help="Weight of the (1 - cosine_similarity) loss " "term; balances against MSE.") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--device", default="cpu") parser.add_argument("--drive-mode", default=None, choices=["differential", "mecanum"], help="Drive mode. If not set, inferred from " "demo action dimension (2→differential, 3→mecanum).") args = parser.parse_args() torch.manual_seed(args.seed) np.random.seed(args.seed) # --- Load demos --- print(f"[bc] loading demos from {args.demos}") data = np.load(args.demos) obs = data["obs"].astype(np.float32) actions = data["actions"].astype(np.float32) meta = data["meta"] print(f"[bc] obs={obs.shape} actions={actions.shape} trajectories={len(meta)}") if obs.size == 0: raise RuntimeError("Empty demo file.") a_norms = np.linalg.norm(actions, axis=1) print(f"[bc] action L2 norm: mean={a_norms.mean():.3f} " f"min={a_norms.min():.3f} max={a_norms.max():.3f}") # --- Train/val split --- n = len(obs) perm = np.random.permutation(n) n_val = int(n * args.val_split) val_idx, train_idx = perm[:n_val], perm[n_val:] print(f"[bc] train={len(train_idx)} val={len(val_idx)}") obs_t = torch.from_numpy(obs) act_t = torch.from_numpy(actions) train_loader = DataLoader( TensorDataset(obs_t[train_idx], act_t[train_idx]), batch_size=args.batch_size, shuffle=True, ) val_loader = DataLoader( TensorDataset(obs_t[val_idx], act_t[val_idx]), batch_size=args.batch_size, shuffle=False, ) net_arch_pi = [int(x) for x in args.net_arch.split(",")] net_arch_vf = net_arch_pi[:] # Frame stack is inferred from the demo obs dim. obs_dim = obs.shape[1] from herding.perception.obs import OBS_DIM as _SINGLE if obs_dim % _SINGLE != 0: raise RuntimeError(f"demo obs dim {obs_dim} is not a multiple of {_SINGLE}") frame_stack = obs_dim // _SINGLE if frame_stack > 1: print(f"[bc] inferred frame_stack={frame_stack} from demo obs dim {obs_dim}") # Infer drive mode from action dimension if not explicitly set. action_dim = actions.shape[1] if args.drive_mode is not None: drive_mode = args.drive_mode elif action_dim == 3: drive_mode = "mecanum" else: drive_mode = "differential" print(f"[bc] drive_mode={drive_mode} (action_dim={action_dim})") model, _env = build_model(net_arch_pi, net_arch_vf, args.log_std_init, frame_stack=frame_stack, drive_mode=drive_mode) policy = model.policy.to(args.device) optimizer = optim.Adam(policy.parameters(), lr=args.lr) # --- Train --- print(f"[bc] training: epochs={args.epochs} batch={args.batch_size} " f"lr={args.lr} device={args.device}") t_start = time.time() best_val = float("inf") best_cos = -1.0 best_state = None # restored at the end so noisy last epochs don't win def combined_loss(pred, target): mse = nn.functional.mse_loss(pred, target) p_norm = pred.norm(dim=1).clamp_min(1e-6) t_norm = target.norm(dim=1).clamp_min(1e-6) cos_sim = (pred * target).sum(dim=1) / (p_norm * t_norm) cos_loss = (1.0 - cos_sim).mean() return mse + args.cos_weight * cos_loss, mse.item(), cos_sim.mean().item() for epoch in range(args.epochs): policy.train() train_loss_total, train_mse_total, train_cos_total, train_count = 0.0, 0.0, 0.0, 0 for ob_batch, act_batch in train_loader: ob_batch = ob_batch.to(args.device) act_batch = act_batch.to(args.device) optimizer.zero_grad() mean_action = policy_forward_mean(policy, ob_batch) loss, mse_val, cos_val = combined_loss(mean_action, act_batch) loss.backward() optimizer.step() bs = ob_batch.size(0) train_loss_total += loss.item() * bs train_mse_total += mse_val * bs train_cos_total += cos_val * bs train_count += bs train_mse = train_mse_total / max(1, train_count) train_cos = train_cos_total / max(1, train_count) policy.eval() val_total, val_count = 0.0, 0 cos_sim_total = 0.0 with torch.no_grad(): for ob_batch, act_batch in val_loader: ob_batch = ob_batch.to(args.device) act_batch = act_batch.to(args.device) mean_action = policy_forward_mean(policy, ob_batch) bs = ob_batch.size(0) val_total += nn.functional.mse_loss( mean_action, act_batch, reduction="sum", ).item() m_norm = mean_action.norm(dim=1).clamp_min(1e-6) a_norm = act_batch.norm(dim=1).clamp_min(1e-6) cos = (mean_action * act_batch).sum(dim=1) / (m_norm * a_norm) cos_sim_total += cos.sum().item() val_count += bs val_mse = val_total / max(1, val_count) / actions.shape[1] cos_sim = cos_sim_total / max(1, val_count) print(f" epoch {epoch+1:>2d}/{args.epochs} " f"train_mse={train_mse:.4f} train_cos={train_cos:+.3f} " f"val_mse={val_mse:.4f} val_cos={cos_sim:+.3f}") if val_mse < best_val: best_val = val_mse if cos_sim > best_cos: best_cos = cos_sim best_state = {k: v.detach().cpu().clone() for k, v in policy.state_dict().items()} if best_state is not None: policy.load_state_dict(best_state) print(f"[bc] restored best-val_cos snapshot (cos={best_cos:.3f})") elapsed = time.time() - t_start print(f"[bc] done in {elapsed:.0f}s best_val_mse={best_val:.4f}") # --- Save --- out_dir = Path(args.out) out_dir.mkdir(parents=True, exist_ok=True) model.save(out_dir / "policy.zip") print(f"[bc] saved policy to {out_dir / 'policy.zip'}") print(f"\n[bc] verify with: " f"python -m training.eval --policy {out_dir}") if __name__ == "__main__": main()