"""Collect (obs, action) demonstrations from an analytic teacher. Runs the chosen teacher across a grid of ``(n_sheep, seed)`` combos at full difficulty, logs every Nth ``(obs, action)`` pair, and saves successful trajectories to ``.npz`` for behaviour cloning. The teacher is wrapped in :class:`ActiveScanTeacher` by default so it operates on the same partial-obs view the student will have at deployment. Usage:: python -m training.bc.collect --teacher strombom \\ --out training/bc/demos.npz --frame-stack 4 """ from __future__ import annotations import argparse import os import time from pathlib import Path import numpy as np # Configure field geometry before other herding imports read it at module level. from herding.world.geometry import configure_from_args as _configure_from_args _configure_from_args() from herding.control.active_scan import ActiveScanTeacher from herding.world.geometry import PEN_ENTRY, FIELD_SHAPE from herding.control.sequential import compute_action as sequential_action from herding.control.strombom import compute_action as strombom_action from herding.control.universal import compute_action as universal_action from training.herding_env import HerdingEnv TEACHERS = { "sequential": sequential_action, "strombom": strombom_action, "universal": universal_action, } def _call_teacher(fn, dog_xy, dog_heading, sheep_positions, pen_target, drive_mode="differential"): """Call any teacher function and return (vx, vy, omega, mode). Normalizes across 3-tuple teachers (vx, vy, mode) and 4-tuple universal teacher (vx, vy, omega, mode). ActiveScanTeacher (when invoked with drive_mode="mecanum") propagates the base teacher's omega — see test_active_scan_preserves_mecanum_omega. """ # The universal teacher and ActiveScanTeacher accept the extended # (dog_xy, heading, sheep, pen, drive_mode) signature. Older # teachers accept (dog_xy, sheep, pen). Detect by trying the # extended call first. try: result = fn(dog_xy, dog_heading, sheep_positions, pen_target, drive_mode) except TypeError: try: result = fn(dog_xy, dog_heading, sheep_positions, pen_target) except TypeError: result = fn(dog_xy, sheep_positions, pen_target) if len(result) == 4: return result # (vx, vy, omega, mode) vx, vy, mode = result return vx, vy, 0.0, mode def collect_one(n_sheep: int, seed: int, max_steps: int, subsample: int, teacher_fn, frame_stack: int = 1, privileged: bool = False, drive_mode: str = "differential", herding_cfg=None, actor_policy=None): """Collect (obs, teacher_action) pairs from one episode. ``actor_policy`` (DAgger mode): a callable ``policy(obs) -> action`` that drives the env. The teacher still labels each visited state. If ``None`` (default), the teacher drives. """ env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps, difficulty=1.0, seed=seed, frame_stack=frame_stack, drive_mode=drive_mode, herding_cfg=herding_cfg) obs, _ = env.reset(seed=seed) obs_list, action_list = [], [] scan_teacher = ActiveScanTeacher(teacher_fn) for step in range(max_steps): if privileged: positions = {f"s{i}": (float(env.sheep_x[i]), float(env.sheep_y[i])) for i in range(env.n_sheep) if not env.sheep_penned[i]} if not positions: break vx, vy, omega, _mode = _call_teacher( teacher_fn, (env.dog_x, env.dog_y), env.dog_heading, positions, PEN_ENTRY, drive_mode, ) else: positions = env.perceived_positions() result = _call_teacher( scan_teacher, (env.dog_x, env.dog_y), env.dog_heading, positions, PEN_ENTRY, drive_mode, ) vx, vy, omega, _mode = result if drive_mode == "mecanum": teacher_action = np.array([vx, vy, omega], dtype=np.float32) else: teacher_action = np.array([vx, vy], dtype=np.float32) if step % subsample == 0: obs_list.append(obs.copy()) action_list.append(teacher_action.copy()) # In DAgger mode the policy drives; otherwise the teacher does. step_action = (actor_policy(obs) if actor_policy is not None else teacher_action) obs, _r, term, trunc, _info = env.step(step_action) if term or trunc: break success = bool(env.sheep_penned.all()) return ( np.asarray(obs_list, dtype=np.float32), np.asarray(action_list, dtype=np.float32), success, env.steps, ) def main(): parser = argparse.ArgumentParser() parser.add_argument("--out", default="training/bc/demos.npz") parser.add_argument("--n-sheep-list", default="1,2,3,5,8,10") parser.add_argument("--seeds-per-n", type=int, default=15) parser.add_argument("--max-steps", type=int, default=30000) parser.add_argument("--subsample", type=int, default=5, help="Keep every Nth (obs, action) pair.") parser.add_argument("--keep-failures", action="store_true", help="Include partial-success trajectories. Default off.") parser.add_argument("--teacher", default="universal", choices=list(TEACHERS.keys()), help="Which analytic teacher to demonstrate.") parser.add_argument("--frame-stack", type=int, default=1, help="Concatenate the last K obs into a " "(32·K)-D vector for the policy.") parser.add_argument("--privileged", action="store_true", help="Teacher reads ground truth instead of " "tracker output (asymmetric BC).") parser.add_argument("--drive-mode", default="differential", choices=["differential", "mecanum"], help="Drive mode for the dog robot.") parser.add_argument("--world", default=None, choices=["field", "field_round"], help="World shape. If not set, uses HERDING_WORLD " "env var or defaults to 'field'. Must be set " "before geometry is imported.") # Domain randomisation — applied to the gym env during collection so # the teacher demonstrates under the same noise the policy will face. parser.add_argument("--fp-rate", type=float, default=0.0, help="Mean false-positive detections injected per " "step (Poisson λ). 0 = clean sim (default).") parser.add_argument("--action-smooth", type=float, default=0.0, help="EMA coefficient on dog actions (0 = none). " "Set to 0.55 to match the Webots controller.") parser.add_argument("--wheel-slip-std", type=float, default=0.0, help="Gaussian noise (rad/s) on wheel speeds for " "mecanum dynamics domain randomisation.") parser.add_argument("--dagger-policy", default=None, help="Path to a BC/PPO policy directory. When set, " "the policy drives the env (DAgger) while the " "teacher labels every visited state.") parser.add_argument("--use-webots-preset", action="store_true", help="Use HERDING_WEBOTS preset (140° FOV + tight " "tracker). Match this to deployment for DAgger.") args = parser.parse_args() # Validate --world matches geometry (already configured by the # early pre-parse above, or by HERDING_WORLD env var). if args.world is not None and args.world != FIELD_SHAPE: print(f"[demos] WARNING: --world={args.world} but geometry is " f"'{FIELD_SHAPE}'. This should not happen — file a bug.") from herding.config import ( HerdingConfig, HERDING_WEBOTS, HERDING_MEC_WEBOTS, DomainRandomConfig, RobotConfig, ) if args.use_webots_preset: # Pick the drive-matched Webots preset — for mecanum we use the # variant that simulates the physical-roller proto's strafe # efficiency and forward bleed so the policy trains under the # same imperfect kinematics it sees at deployment. base = HERDING_MEC_WEBOTS if args.drive_mode == "mecanum" else HERDING_WEBOTS herding_cfg = base.replace( domain_random=DomainRandomConfig( fp_rate=args.fp_rate, wheel_slip_std=args.wheel_slip_std, ), robot=RobotConfig( action_smooth=args.action_smooth, strafe_efficiency=base.robot.strafe_efficiency, strafe_to_forward_bleed=base.robot.strafe_to_forward_bleed, ), ) preset_name = "HERDING_MEC_WEBOTS" if args.drive_mode == "mecanum" else "HERDING_WEBOTS" print(f"[demos] {preset_name} preset + DR: fp_rate={args.fp_rate} " f"action_smooth={args.action_smooth} wheel_slip_std={args.wheel_slip_std} " f"strafe_eff={herding_cfg.robot.strafe_efficiency:.2f}") else: herding_cfg = None if args.fp_rate > 0.0 or args.action_smooth > 0.0 or args.wheel_slip_std > 0.0: herding_cfg = HerdingConfig( domain_random=DomainRandomConfig( fp_rate=args.fp_rate, wheel_slip_std=args.wheel_slip_std, ), robot=RobotConfig(action_smooth=args.action_smooth), ) print(f"[demos] domain-random: fp_rate={args.fp_rate} " f"action_smooth={args.action_smooth} " f"wheel_slip_std={args.wheel_slip_std}") actor_policy = None if args.dagger_policy is not None: # DAgger: failures are the most valuable data (off-policy states # where the student needs teacher correction). Always keep them. args.keep_failures = True from stable_baselines3 import PPO from pathlib import Path as _P run = _P(args.dagger_policy) for name in ("policy.zip", "final.zip"): if (run / name).exists(): zip_path = run / name break else: raise FileNotFoundError( f"No policy found in {run} (tried policy.zip, final.zip)") _model = PPO.load(str(zip_path), device="auto") print(f"[demos] DAgger mode: actor = {zip_path}") def actor_policy(obs): obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1) a, _ = _model.predict(obs_b, deterministic=True) return a[0] teacher_fn = TEACHERS[args.teacher] print(f"[demos] teacher: {args.teacher} world: {FIELD_SHAPE}") n_sheep_list = [int(x) for x in args.n_sheep_list.split(",")] print(f"[demos] grid: n_sheep={n_sheep_list}, seeds={args.seeds_per_n}, " f"max_steps={args.max_steps}, subsample={args.subsample}") all_obs, all_actions, all_meta = [], [], [] t_start = time.time() n_success = 0; n_total = 0 for n in n_sheep_list: for seed in range(args.seeds_per_n): obs, actions, success, total_steps = collect_one( n, seed, args.max_steps, args.subsample, teacher_fn, frame_stack=args.frame_stack, privileged=args.privileged, drive_mode=args.drive_mode, herding_cfg=herding_cfg, actor_policy=actor_policy, ) n_total += 1 if success: n_success += 1 keep = success or args.keep_failures if keep and len(obs) > 0: all_obs.append(obs) all_actions.append(actions) all_meta.append((n, seed, len(obs), int(success), total_steps)) tag = "✓" if success else "✗" print(f" [{tag}] n={n:>2d} seed={seed:>2d} steps={total_steps:>6d} " f"logged={len(obs):>5d}") if not all_obs: raise RuntimeError("No trajectories kept — try --keep-failures.") obs = np.concatenate(all_obs, axis=0) actions = np.concatenate(all_actions, axis=0) meta = np.array(all_meta, dtype=np.int32) Path(args.out).parent.mkdir(parents=True, exist_ok=True) np.savez(args.out, obs=obs, actions=actions, meta=meta) elapsed = time.time() - t_start print(f"\n=== {n_success}/{n_total} trajectories successful ({100*n_success/n_total:.0f}%) ===") print(f"=== {len(obs)} transitions saved to {args.out} ===") print(f"=== obs={obs.shape}, actions={actions.shape}, elapsed={elapsed:.0f}s ===") if __name__ == "__main__": main()