Files
TIR_PROJ/tools/dagger_merge_train.py
T
2026-05-11 10:35:39 +01:00

136 lines
5.2 KiB
Python

"""Merge Webots DAgger demos with sim demos and retrain the BC policy.
The dog controller in ``HERDING_MODE=dagger`` writes per-run files to
``training/dagger/dagger_<ts>.npz`` containing ``(obs, actions)`` pairs
where:
* ``obs`` is the **stacked LiDAR observation** as built by the live
Webots tracker — exactly the input distribution the deployed
controller sees.
* ``actions`` is the **active-scan-teacher action computed from
ground-truth sheep positions** (read off the sheep emitter).
Combined with the existing sim demos (``training/demos.npz`` by
default), this gives the BC student a training set that includes the
real Webots false-positive distribution — closing the sim-to-real
perception gap that the all-sim pipeline couldn't bridge.
Usage::

    # Iteration 1 — merge all dagger files with the sim demos, retrain
    python -m tools.dagger_merge_train \\
        --sim training/demos.npz \\
        --out training/runs/bc_dagger1

    # Iteration 2 — drop the sim baseline, train only on Webots data
    python -m tools.dagger_merge_train --no-sim --out training/runs/bc_dagger2

The new policy is saved as ``<out>/policy.zip`` and is auto-loaded by
the controller's resolution priority on the next Webots run.
"""
from __future__ import annotations
import argparse
import glob
import os
import subprocess
import sys
from pathlib import Path
# Put the project root (the parent of this file's directory) at the front of
# sys.path, so project packages resolve even when the script is launched by
# file path rather than via ``python -m`` from the project root.
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, os.pardir))
if _PROJECT_ROOT in sys.path:
    pass  # already importable; avoid a duplicate entry
else:
    sys.path.insert(0, _PROJECT_ROOT)
import numpy as np
def _load_pairs(path: str) -> tuple[np.ndarray, np.ndarray]:
    """Load one demo file and return its ``(obs, actions)`` arrays as float32.

    The ``NpzFile`` handle is closed before returning; ``astype`` copies the
    arrays into memory, so the returned data stays valid afterwards.

    Raises:
        FileNotFoundError: *path* does not exist (propagated from ``np.load``).
        SystemExit: ``obs`` is not at least 2-D, or ``obs`` and ``actions``
            disagree on the number of pairs.
    """
    with np.load(path) as data:
        obs = data["obs"].astype(np.float32)
        act = data["actions"].astype(np.float32)
    if obs.ndim < 2:
        raise SystemExit(f"{path}: expected 2-D obs, got shape {obs.shape}.")
    # Compare leading dimensions (row counts); works even for 0-d arrays.
    if act.shape[:1] != obs.shape[:1]:
        raise SystemExit(
            f"{path}: obs has shape {obs.shape} but actions has shape "
            f"{act.shape} — the row counts must match."
        )
    return obs, act


def main() -> None:
    """Merge Webots DAgger demos (optionally with sim demos) and retrain BC.

    Gathers every file matching ``--dagger-glob``, optionally appends the
    ``--sim`` demos, writes the concatenation to ``--merged-out`` and then
    runs ``training.bc_pretrain`` on it as a subprocess with the project root
    as working directory.

    Raises:
        SystemExit: no dagger files match, a file is malformed, the sim file
            is missing, or observation/action shapes of the inputs disagree.
        subprocess.CalledProcessError: the BC training subprocess fails.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--sim", default="training/demos.npz",
                        help="Sim demo file to mix with the Webots data. "
                             "Pass --no-sim to train only on dagger data.")
    parser.add_argument("--no-sim", action="store_true",
                        help="Skip the sim demos entirely.")
    parser.add_argument("--dagger-glob", default="training/dagger/dagger_*.npz",
                        help="Glob for Webots-collected dagger files.")
    parser.add_argument("--merged-out", default="training/demos_dagger.npz",
                        help="Where to write the merged demo file.")
    parser.add_argument("--out", default="training/runs/bc_dagger",
                        help="Where to write the BC policy.")
    parser.add_argument("--epochs", type=int, default=60)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--net-arch", default="512,512")
    parser.add_argument("--cos-weight", type=float, default=1.0)
    args = parser.parse_args()

    # --- Gather Webots files ---
    dagger_paths = sorted(glob.glob(args.dagger_glob))
    if not dagger_paths:
        raise SystemExit(f"No dagger files found at {args.dagger_glob} — "
                         "run Webots in HERDING_MODE=dagger first.")

    chunks_obs: list[np.ndarray] = []
    chunks_act: list[np.ndarray] = []
    total_dagger = 0
    for p in dagger_paths:
        obs, act = _load_pairs(p)
        chunks_obs.append(obs)
        chunks_act.append(act)
        total_dagger += len(obs)
        print(f"  + {p}: {obs.shape[0]} pairs (obs dim {obs.shape[1]})")
    print(f"[merge] total dagger pairs: {total_dagger}")

    obs_dim = chunks_obs[0].shape[1]
    if any(c.shape[1] != obs_dim for c in chunks_obs):
        raise SystemExit(
            "Dagger files have inconsistent obs dims — they were collected "
            "with different frame_stack settings. Either rerun with a "
            "consistent setting or filter the glob."
        )
    # Trailing action shape must agree too, otherwise np.concatenate fails
    # with an unhelpful error further down.
    act_shape = chunks_act[0].shape[1:]
    if any(c.shape[1:] != act_shape for c in chunks_act):
        raise SystemExit(
            "Dagger files have inconsistent action shapes — filter the glob "
            "so that all files come from the same controller version."
        )

    # --- Optionally include sim demos ---
    if not args.no_sim:
        try:
            sim_obs, sim_act = _load_pairs(args.sim)
        except FileNotFoundError:
            raise SystemExit(
                f"Sim demo file {args.sim} not found — pass --no-sim to "
                "train on dagger data only."
            ) from None
        if sim_obs.shape[1] != obs_dim:
            raise SystemExit(
                f"Sim demos have obs dim {sim_obs.shape[1]} but dagger demos "
                f"have {obs_dim}. Recollect sim demos at the same frame_stack."
            )
        if sim_act.shape[1:] != act_shape:
            raise SystemExit(
                f"Sim demos have action shape {sim_act.shape[1:]} but dagger "
                f"demos have {act_shape}."
            )
        chunks_obs.append(sim_obs)
        chunks_act.append(sim_act)
        print(f"[merge] + sim demos: {sim_obs.shape[0]} pairs from {args.sim}")

    obs_all = np.concatenate(chunks_obs, axis=0)
    act_all = np.concatenate(chunks_act, axis=0)
    # Empty meta — bc_pretrain doesn't actually use it but the file format
    # has it.
    meta = np.zeros((0, 5), dtype=np.int32)

    # Resolve against the caller's cwd: bc_pretrain runs with
    # cwd=_PROJECT_ROOT, so a relative path would otherwise be read from a
    # different location than the one it was written to.
    merged_out = os.path.abspath(args.merged_out)
    Path(merged_out).parent.mkdir(parents=True, exist_ok=True)
    # Write through a file handle: np.savez appends ".npz" to string paths
    # lacking that suffix, which would desynchronise the file on disk from
    # the --demos argument passed below.
    with open(merged_out, "wb") as fh:
        np.savez(fh, obs=obs_all, actions=act_all, meta=meta)
    print(f"[merge] wrote {len(obs_all)} pairs → {merged_out}")
    print(f"[merge] obs shape {obs_all.shape}, action shape {act_all.shape}")

    # --- Run BC training ---
    # NOTE: a relative --out is passed through unchanged, so it is presumably
    # resolved by bc_pretrain relative to _PROJECT_ROOT (the subprocess cwd).
    cmd = [
        sys.executable, "-m", "training.bc_pretrain",
        "--demos", merged_out,
        "--out", args.out,
        "--epochs", str(args.epochs),
        "--batch-size", str(args.batch_size),
        "--net-arch", args.net_arch,
        "--cos-weight", str(args.cos_weight),
    ]
    print(f"\n[merge] launching: {' '.join(cmd)}")
    subprocess.run(cmd, check=True, cwd=_PROJECT_ROOT)


if __name__ == "__main__":
    main()