"""Merge Webots DAgger demos with sim demos and retrain the BC policy. The dog controller in ``HERDING_MODE=dagger`` writes per-run files to ``training/dagger/dagger_.npz`` containing ``(obs, actions)`` pairs where: * ``obs`` is the **stacked LiDAR observation** as built by the live Webots tracker — exactly the input distribution the deployed controller sees. * ``actions`` is the **active-scan-teacher action computed from ground-truth sheep positions** (read off the sheep emitter). Combined with the existing sim demos (``training/demos.npz`` by default), this gives the BC student a training set that includes the real Webots false-positive distribution — closing the sim-to-real perception gap that the all-sim pipeline couldn't bridge. Usage:: # Iteration 1 — merge all dagger files with sim demos, retrain python -m tools.dagger_merge_train \\ --sim training/demos.npz \\ --out training/runs/bc_dagger1 # Iteration 2 — drop the sim baseline, train only on Webots data python -m tools.dagger_merge_train --no-sim --out training/runs/bc_dagger2 The new policy is saved as ``/policy.zip`` and is auto-loaded by the controller's resolution priority on the next Webots run. """ from __future__ import annotations import argparse import glob import os import subprocess import sys from pathlib import Path _HERE = os.path.dirname(os.path.abspath(__file__)) _PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, "..")) if _PROJECT_ROOT not in sys.path: sys.path.insert(0, _PROJECT_ROOT) import numpy as np def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--sim", default="training/demos.npz", help="Sim demo file to mix with the Webots data. " "Pass --no-sim to train only on dagger data.") parser.add_argument("--no-sim", action="store_true", help="Skip the sim demos entirely.") parser.add_argument("--dagger-glob", default="training/dagger/dagger_*.npz", help="Glob for Webots-collected dagger files.") parser.add_argument("--merged-out", default="training/demos_dagger.npz", help="Where to write the merged demo file.") parser.add_argument("--out", default="training/runs/bc_dagger", help="Where to write the BC policy.") parser.add_argument("--epochs", type=int, default=60) parser.add_argument("--batch-size", type=int, default=256) parser.add_argument("--net-arch", default="512,512") parser.add_argument("--cos-weight", type=float, default=1.0) args = parser.parse_args() # --- Gather Webots files --- dagger_paths = sorted(glob.glob(args.dagger_glob)) if not dagger_paths: raise SystemExit(f"No dagger files found at {args.dagger_glob} — " "run Webots in HERDING_MODE=dagger first.") chunks_obs: list[np.ndarray] = [] chunks_act: list[np.ndarray] = [] total_dagger = 0 for p in dagger_paths: data = np.load(p) obs = data["obs"].astype(np.float32) act = data["actions"].astype(np.float32) chunks_obs.append(obs) chunks_act.append(act) total_dagger += len(obs) print(f" + {p}: {obs.shape[0]} pairs (obs dim {obs.shape[1]})") print(f"[merge] total dagger pairs: {total_dagger}") obs_dim = chunks_obs[0].shape[1] if any(c.shape[1] != obs_dim for c in chunks_obs): raise SystemExit( "Dagger files have inconsistent obs dims — they were collected " "with different frame_stack settings. Either rerun with a " "consistent setting or filter the glob." ) # --- Optionally include sim demos --- if not args.no_sim: sim = np.load(args.sim) sim_obs = sim["obs"].astype(np.float32) sim_act = sim["actions"].astype(np.float32) if sim_obs.shape[1] != obs_dim: raise SystemExit( f"Sim demos have obs dim {sim_obs.shape[1]} but dagger demos " f"have {obs_dim}. Recollect sim demos at the same frame_stack." ) chunks_obs.append(sim_obs) chunks_act.append(sim_act) print(f"[merge] + sim demos: {sim_obs.shape[0]} pairs from {args.sim}") obs_all = np.concatenate(chunks_obs, axis=0) act_all = np.concatenate(chunks_act, axis=0) # Empty meta — bc_pretrain doesn't actually use it but the file format # has it. meta = np.zeros((0, 5), dtype=np.int32) Path(args.merged_out).parent.mkdir(parents=True, exist_ok=True) np.savez(args.merged_out, obs=obs_all, actions=act_all, meta=meta) print(f"[merge] wrote {len(obs_all)} pairs → {args.merged_out}") print(f"[merge] obs shape {obs_all.shape}, action shape {act_all.shape}") # --- Run BC training --- cmd = [ sys.executable, "-m", "training.bc_pretrain", "--demos", args.merged_out, "--out", args.out, "--epochs", str(args.epochs), "--batch-size", str(args.batch_size), "--net-arch", args.net_arch, "--cos-weight", str(args.cos_weight), ] print(f"\n[merge] launching: {' '.join(cmd)}") subprocess.run(cmd, check=True, cwd=_PROJECT_ROOT) if __name__ == "__main__": main()