"""Behavior cloning of an analytic teacher into an SB3-compatible policy. Trains the policy network (mean-action head) of an SB3 ``MlpPolicy`` to mimic the (obs, action) demonstrations produced by ``tools.collect_demos``. The saved zip is loadable via ``PPO.load(...)`` and is what the Webots dog controller uses in ``HERDING_MODE=rl``. Loss: MSE + (1 - cosine similarity). The cosine term is what stops the policy mean from collapsing toward zero against unit-vector targets. Best-by-val_cos checkpoint is restored at the end of training so noisy multi-modal teachers (e.g. Strömbom) don't lose progress when the last epoch lands on a bad gradient step. Usage:: python -m training.bc_pretrain \\ --demos training/demos.npz \\ --out training/runs/bc_flock """ from __future__ import annotations import argparse import os import sys import time from pathlib import Path _HERE = os.path.dirname(os.path.abspath(__file__)) _PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, "..")) if _PROJECT_ROOT not in sys.path: sys.path.insert(0, _PROJECT_ROOT) import numpy as np import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, TensorDataset from stable_baselines3 import PPO from stable_baselines3.common.vec_env import DummyVecEnv from training.herding_env import HerdingEnv def build_model(net_arch_pi, net_arch_vf, log_std_init: float, frame_stack: int = 1): """Build a fresh SB3 PPO solely as a vehicle for the policy weights. PPO's training-loop plumbing isn't used during BC. ``frame_stack`` must match the demo file so the env's obs space agrees with the recorded obs shape. """ env = DummyVecEnv([lambda: HerdingEnv(frame_stack=frame_stack)]) model = PPO( "MlpPolicy", env, policy_kwargs=dict( net_arch=dict(pi=net_arch_pi, vf=net_arch_vf), log_std_init=log_std_init, ), verbose=0, ) return model, env def policy_forward_mean(policy, obs_batch): """Return the policy's deterministic mean action for a batch. SB3's ActorCriticPolicy doesn't expose this directly — it goes through a Distribution wrapper. We replicate the forward path: extract_features → mlp_extractor → action_net. """ features = policy.extract_features(obs_batch) if isinstance(features, tuple): # SB3 ≥ 2.0 sometimes returns (pi_features, vf_features) pi_features = features[0] else: pi_features = features latent_pi, _latent_vf = policy.mlp_extractor(pi_features) return policy.action_net(latent_pi) def main(): parser = argparse.ArgumentParser() parser.add_argument("--demos", default="training/demos.npz") parser.add_argument("--out", default="training/runs/bc_solo") parser.add_argument("--epochs", type=int, default=60) parser.add_argument("--batch-size", type=int, default=256) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--val-split", type=float, default=0.1) parser.add_argument("--net-arch", default="256,256", help="Comma-separated hidden layer widths.") parser.add_argument("--log-std-init", type=float, default=0.5) parser.add_argument("--cos-weight", type=float, default=1.0, help="Weight on (1 - cosine similarity) loss term. " "MSE alone shrinks policy output toward zero " "(zero-magnitude action minimises mean squared " "error against ±1 targets); cos loss keeps " "the action pointed correctly even at small " "magnitudes.") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--device", default="cpu") args = parser.parse_args() torch.manual_seed(args.seed) np.random.seed(args.seed) # --- Load demos --- print(f"[bc] loading demos from {args.demos}") data = np.load(args.demos) obs = data["obs"].astype(np.float32) actions = data["actions"].astype(np.float32) meta = data["meta"] print(f"[bc] obs={obs.shape} actions={actions.shape} trajectories={len(meta)}") if obs.size == 0: raise RuntimeError("Empty demo file.") # Action sanity check — sequential outputs unit vectors. a_norms = np.linalg.norm(actions, axis=1) print(f"[bc] action L2 norm: mean={a_norms.mean():.3f} " f"min={a_norms.min():.3f} max={a_norms.max():.3f}") # --- Train/val split --- n = len(obs) perm = np.random.permutation(n) n_val = int(n * args.val_split) val_idx, train_idx = perm[:n_val], perm[n_val:] print(f"[bc] train={len(train_idx)} val={len(val_idx)}") obs_t = torch.from_numpy(obs) act_t = torch.from_numpy(actions) train_loader = DataLoader( TensorDataset(obs_t[train_idx], act_t[train_idx]), batch_size=args.batch_size, shuffle=True, ) val_loader = DataLoader( TensorDataset(obs_t[val_idx], act_t[val_idx]), batch_size=args.batch_size, shuffle=False, ) # --- Build model --- net_arch_pi = [int(x) for x in args.net_arch.split(",")] net_arch_vf = net_arch_pi[:] # Auto-detect frame stacking from the demo file so a stacked-obs # demo trains a stacked-obs policy without an extra CLI flag. obs_dim = obs.shape[1] from herding.obs import OBS_DIM as _SINGLE if obs_dim % _SINGLE != 0: raise RuntimeError(f"demo obs dim {obs_dim} is not a multiple of {_SINGLE}") frame_stack = obs_dim // _SINGLE if frame_stack > 1: print(f"[bc] inferred frame_stack={frame_stack} from demo obs dim {obs_dim}") model, _env = build_model(net_arch_pi, net_arch_vf, args.log_std_init, frame_stack=frame_stack) policy = model.policy.to(args.device) optimizer = optim.Adam(policy.parameters(), lr=args.lr) # --- Train --- print(f"[bc] training: epochs={args.epochs} batch={args.batch_size} " f"lr={args.lr} device={args.device}") t_start = time.time() best_val = float("inf") best_cos = -1.0 # Snapshot the best-by-val_cos policy weights and restore at the end — # training is noisy on multi-modal teachers (e.g. Strömbom collect/drive), # so the last epoch is often worse than an earlier one. best_state = None def combined_loss(pred, target): mse = nn.functional.mse_loss(pred, target) p_norm = pred.norm(dim=1).clamp_min(1e-6) t_norm = target.norm(dim=1).clamp_min(1e-6) cos_sim = (pred * target).sum(dim=1) / (p_norm * t_norm) cos_loss = (1.0 - cos_sim).mean() return mse + args.cos_weight * cos_loss, mse.item(), cos_sim.mean().item() for epoch in range(args.epochs): policy.train() train_loss_total, train_mse_total, train_cos_total, train_count = 0.0, 0.0, 0.0, 0 for ob_batch, act_batch in train_loader: ob_batch = ob_batch.to(args.device) act_batch = act_batch.to(args.device) optimizer.zero_grad() mean_action = policy_forward_mean(policy, ob_batch) loss, mse_val, cos_val = combined_loss(mean_action, act_batch) loss.backward() optimizer.step() bs = ob_batch.size(0) train_loss_total += loss.item() * bs train_mse_total += mse_val * bs train_cos_total += cos_val * bs train_count += bs train_mse = train_mse_total / max(1, train_count) train_cos = train_cos_total / max(1, train_count) policy.eval() val_total, val_count = 0.0, 0 cos_sim_total = 0.0 with torch.no_grad(): for ob_batch, act_batch in val_loader: ob_batch = ob_batch.to(args.device) act_batch = act_batch.to(args.device) mean_action = policy_forward_mean(policy, ob_batch) bs = ob_batch.size(0) val_total += nn.functional.mse_loss( mean_action, act_batch, reduction="sum", ).item() # Cosine similarity in action space — useful sanity for # "is the policy pointing the same way as the teacher?". m_norm = mean_action.norm(dim=1).clamp_min(1e-6) a_norm = act_batch.norm(dim=1).clamp_min(1e-6) cos = (mean_action * act_batch).sum(dim=1) / (m_norm * a_norm) cos_sim_total += cos.sum().item() val_count += bs val_mse = val_total / max(1, val_count) / actions.shape[1] cos_sim = cos_sim_total / max(1, val_count) print(f" epoch {epoch+1:>2d}/{args.epochs} " f"train_mse={train_mse:.4f} train_cos={train_cos:+.3f} " f"val_mse={val_mse:.4f} val_cos={cos_sim:+.3f}") if val_mse < best_val: best_val = val_mse if cos_sim > best_cos: best_cos = cos_sim best_state = {k: v.detach().cpu().clone() for k, v in policy.state_dict().items()} if best_state is not None: policy.load_state_dict(best_state) print(f"[bc] restored best-val_cos snapshot (cos={best_cos:.3f})") elapsed = time.time() - t_start print(f"[bc] done in {elapsed:.0f}s best_val_mse={best_val:.4f}") # --- Save --- out_dir = Path(args.out) out_dir.mkdir(parents=True, exist_ok=True) model.save(out_dir / "policy.zip") print(f"[bc] saved policy to {out_dir / 'policy.zip'}") print(f"\n[bc] verify with: " f"python -m training.eval --policy {out_dir}") if __name__ == "__main__": main()