Checkpoint 4
This commit is contained in:
@@ -0,0 +1,135 @@
|
||||
"""Merge Webots DAgger demos with sim demos and retrain the BC policy.
|
||||
|
||||
The dog controller in ``HERDING_MODE=dagger`` writes per-run files to
|
||||
``training/dagger/dagger_<ts>.npz`` containing ``(obs, actions)`` pairs
|
||||
where:
|
||||
|
||||
* ``obs`` is the **stacked LiDAR observation** as built by the live
|
||||
Webots tracker — exactly the input distribution the deployed
|
||||
controller sees.
|
||||
* ``actions`` is the **active-scan-teacher action computed from
|
||||
ground-truth sheep positions** (read off the sheep emitter).
|
||||
|
||||
Combined with the existing sim demos (``training/demos_v3.npz`` by
|
||||
default), this gives the BC student a training set that includes the
|
||||
real Webots false-positive distribution — closing the sim-to-real
|
||||
perception gap that the all-sim pipeline couldn't bridge.
|
||||
|
||||
Usage::
|
||||
|
||||
# Iteration 1 — merge all dagger files with sim demos, retrain
|
||||
python -m tools.dagger_merge_train \\
|
||||
--sim training/demos_v3.npz \\
|
||||
--out training/runs/bc_dagger1
|
||||
|
||||
# Iteration 2 — drop the sim baseline, train only on Webots data
|
||||
python -m tools.dagger_merge_train --no-sim --out training/runs/bc_dagger2
|
||||
|
||||
The new policy is saved as ``<out>/policy.zip`` and is auto-loaded by
|
||||
the controller's resolution priority on the next Webots run.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
_HERE = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--sim", default="training/demos_v3.npz",
|
||||
help="Sim demo file to mix with the Webots data. "
|
||||
"Pass --no-sim to train only on dagger data.")
|
||||
parser.add_argument("--no-sim", action="store_true",
|
||||
help="Skip the sim demos entirely.")
|
||||
parser.add_argument("--dagger-glob", default="training/dagger/dagger_*.npz",
|
||||
help="Glob for Webots-collected dagger files.")
|
||||
parser.add_argument("--merged-out", default="training/demos_dagger.npz",
|
||||
help="Where to write the merged demo file.")
|
||||
parser.add_argument("--out", default="training/runs/bc_dagger",
|
||||
help="Where to write the BC policy.")
|
||||
parser.add_argument("--epochs", type=int, default=60)
|
||||
parser.add_argument("--batch-size", type=int, default=256)
|
||||
parser.add_argument("--net-arch", default="512,512")
|
||||
parser.add_argument("--cos-weight", type=float, default=1.0)
|
||||
args = parser.parse_args()
|
||||
|
||||
# --- Gather Webots files ---
|
||||
dagger_paths = sorted(glob.glob(args.dagger_glob))
|
||||
if not dagger_paths:
|
||||
raise SystemExit(f"No dagger files found at {args.dagger_glob} — "
|
||||
"run Webots in HERDING_MODE=dagger first.")
|
||||
|
||||
chunks_obs: list[np.ndarray] = []
|
||||
chunks_act: list[np.ndarray] = []
|
||||
total_dagger = 0
|
||||
for p in dagger_paths:
|
||||
data = np.load(p)
|
||||
obs = data["obs"].astype(np.float32)
|
||||
act = data["actions"].astype(np.float32)
|
||||
chunks_obs.append(obs)
|
||||
chunks_act.append(act)
|
||||
total_dagger += len(obs)
|
||||
print(f" + {p}: {obs.shape[0]} pairs (obs dim {obs.shape[1]})")
|
||||
print(f"[merge] total dagger pairs: {total_dagger}")
|
||||
|
||||
obs_dim = chunks_obs[0].shape[1]
|
||||
if any(c.shape[1] != obs_dim for c in chunks_obs):
|
||||
raise SystemExit(
|
||||
"Dagger files have inconsistent obs dims — they were collected "
|
||||
"with different frame_stack settings. Either rerun with a "
|
||||
"consistent setting or filter the glob."
|
||||
)
|
||||
|
||||
# --- Optionally include sim demos ---
|
||||
if not args.no_sim:
|
||||
sim = np.load(args.sim)
|
||||
sim_obs = sim["obs"].astype(np.float32)
|
||||
sim_act = sim["actions"].astype(np.float32)
|
||||
if sim_obs.shape[1] != obs_dim:
|
||||
raise SystemExit(
|
||||
f"Sim demos have obs dim {sim_obs.shape[1]} but dagger demos "
|
||||
f"have {obs_dim}. Recollect sim demos at the same frame_stack."
|
||||
)
|
||||
chunks_obs.append(sim_obs)
|
||||
chunks_act.append(sim_act)
|
||||
print(f"[merge] + sim demos: {sim_obs.shape[0]} pairs from {args.sim}")
|
||||
|
||||
obs_all = np.concatenate(chunks_obs, axis=0)
|
||||
act_all = np.concatenate(chunks_act, axis=0)
|
||||
# Empty meta — bc_pretrain doesn't actually use it but the file format
|
||||
# has it.
|
||||
meta = np.zeros((0, 5), dtype=np.int32)
|
||||
|
||||
Path(args.merged_out).parent.mkdir(parents=True, exist_ok=True)
|
||||
np.savez(args.merged_out, obs=obs_all, actions=act_all, meta=meta)
|
||||
print(f"[merge] wrote {len(obs_all)} pairs → {args.merged_out}")
|
||||
print(f"[merge] obs shape {obs_all.shape}, action shape {act_all.shape}")
|
||||
|
||||
# --- Run BC training ---
|
||||
cmd = [
|
||||
sys.executable, "-m", "training.bc_pretrain",
|
||||
"--demos", args.merged_out,
|
||||
"--out", args.out,
|
||||
"--epochs", str(args.epochs),
|
||||
"--batch-size", str(args.batch_size),
|
||||
"--net-arch", args.net_arch,
|
||||
"--cos-weight", str(args.cos_weight),
|
||||
]
|
||||
print(f"\n[merge] launching: {' '.join(cmd)}")
|
||||
subprocess.run(cmd, check=True, cwd=_PROJECT_ROOT)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user