"""Collect (obs, action) demonstrations from the sequential teacher.

Runs the sequential algorithm across a grid of (n_sheep, seed) combos at
full difficulty, logs the (observation, action) pair every Nth step, and
saves successful trajectories to a numpy ``.npz`` for behavior cloning.

Failed trajectories are dropped by default — we only want to teach the
policy from good examples.

Usage::

    python -m tools.collect_demos --out training/demos.npz
"""
from __future__ import annotations

import argparse
import os
import sys
import time
from pathlib import Path

# Make the project root importable when this file is run directly as a script.
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

import numpy as np

from herding.geometry import PEN_ENTRY
from herding.sequential import compute_action
from training.herding_env import HerdingEnv


def collect_one(
    n_sheep: int, seed: int, max_steps: int, subsample: int
) -> tuple[np.ndarray, np.ndarray, bool, int]:
    """Roll out the sequential teacher for one episode and log demonstrations.

    Args:
        n_sheep: Number of sheep passed to ``HerdingEnv``.
        seed: Seed passed both to the ``HerdingEnv`` constructor and to
            ``reset``.
        max_steps: Upper bound on the number of teacher steps; also passed to
            the environment as its ``max_steps``.
        subsample: Log the (obs, action) pair on every ``subsample``-th step
            (steps 0, N, 2N, ...). Must be >= 1.

    Returns:
        ``(obs, actions, success, steps)`` where ``obs`` and ``actions`` are
        float32 arrays of the logged pairs, ``success`` is True when every
        entry of ``env.sheep_penned`` is set at the end of the rollout, and
        ``steps`` is ``env.steps`` as reported by the environment.

    Raises:
        ValueError: If ``subsample`` is smaller than 1.
    """
    # Fail fast with a clear message instead of a ZeroDivisionError from the
    # modulo below once the environment has already been built.
    if subsample < 1:
        raise ValueError(f"subsample must be >= 1, got {subsample}")

    env = HerdingEnv(
        n_sheep=n_sheep, max_steps=max_steps, difficulty=1.0, seed=seed
    )
    obs, _ = env.reset(seed=seed)

    obs_list, action_list = [], []
    for step in range(max_steps):
        # Only sheep that are not penned yet are handed to the teacher.
        positions = {
            f"s{i}": (float(env.sheep_x[i]), float(env.sheep_y[i]))
            for i in range(env.n_sheep)
            if not env.sheep_penned[i]
        }
        if not positions:
            break

        vx, vy, _mode = compute_action(
            (env.dog_x, env.dog_y), positions, PEN_ENTRY,
        )
        action = np.array([vx, vy], dtype=np.float32)

        # Log the observation the action was computed for, *before* stepping.
        # Copies keep the log independent of any later in-place changes.
        if step % subsample == 0:
            obs_list.append(obs.copy())
            action_list.append(action.copy())

        obs, _r, term, trunc, _info = env.step(action)
        if term or trunc:
            break

    success = bool(env.sheep_penned.all())
    return (
        np.asarray(obs_list, dtype=np.float32),
        np.asarray(action_list, dtype=np.float32),
        success,
        env.steps,
    )


def _positive_int(text: str) -> int:
    """argparse ``type=`` callable that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected an integer, got {text!r}"
        ) from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
    return value


def _sheep_counts(text: str) -> list[int]:
    """argparse ``type=`` callable parsing ``"1,2,3"`` into positive ints.

    Blank tokens (e.g. from a trailing comma) are ignored; an empty list is
    rejected.
    """
    tokens = (tok.strip() for tok in text.split(","))
    counts = [_positive_int(tok) for tok in tokens if tok]
    if not counts:
        raise argparse.ArgumentTypeError("expected at least one sheep count")
    return counts


def main() -> None:
    """Run the (n_sheep, seed) grid and save the kept trajectories.

    Writes an ``.npz`` archive with three arrays: ``obs``, ``actions`` and
    ``meta``. ``meta`` has one int32 row per kept trajectory:
    ``(n_sheep, seed, n_logged, success, total_steps)``.

    Raises:
        RuntimeError: If no trajectory was kept.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--out", default="training/demos.npz")
    parser.add_argument("--n-sheep-list", type=_sheep_counts,
                        default="1,2,3,5,8,10",
                        help="Comma-separated sheep counts to sweep.")
    parser.add_argument("--seeds-per-n", type=_positive_int, default=15)
    parser.add_argument("--max-steps", type=_positive_int, default=30000)
    parser.add_argument("--subsample", type=_positive_int, default=5,
                        help="Keep every Nth (obs, action) pair.")
    parser.add_argument("--keep-failures", action="store_true",
                        help="Include partial-success trajectories. Default off.")
    args = parser.parse_args()

    n_sheep_list = args.n_sheep_list
    # np.savez appends ".npz" when the name lacks it; resolve the real file
    # name up front so the directory we create and the path we report match
    # what is actually written.
    out_file = args.out if args.out.endswith(".npz") else args.out + ".npz"

    print(f"[demos] grid: n_sheep={n_sheep_list}, seeds={args.seeds_per_n}, "
          f"max_steps={args.max_steps}, subsample={args.subsample}")

    all_obs, all_actions, all_meta = [], [], []
    t_start = time.perf_counter()
    n_success = 0
    n_total = 0
    for n in n_sheep_list:
        for seed in range(args.seeds_per_n):
            obs, actions, success, total_steps = collect_one(
                n, seed, args.max_steps, args.subsample,
            )
            n_total += 1
            if success:
                n_success += 1

            keep = success or args.keep_failures
            if keep and len(obs) > 0:
                all_obs.append(obs)
                all_actions.append(actions)
                all_meta.append((n, seed, len(obs), int(success), total_steps))

            tag = "✓" if success else "✗"
            print(f"  [{tag}] n={n:>2d} seed={seed:>2d} steps={total_steps:>6d} "
                  f"logged={len(obs):>5d}")

    if not all_obs:
        raise RuntimeError("No trajectories kept — try --keep-failures.")

    obs = np.concatenate(all_obs, axis=0)
    actions = np.concatenate(all_actions, axis=0)
    meta = np.array(all_meta, dtype=np.int32)

    Path(out_file).parent.mkdir(parents=True, exist_ok=True)
    np.savez(out_file, obs=obs, actions=actions, meta=meta)

    elapsed = time.perf_counter() - t_start
    print(f"\n=== {n_success}/{n_total} trajectories successful ({100*n_success/n_total:.0f}%) ===")
    print(f"=== {len(obs)} transitions saved to {out_file} ===")
    print(f"=== obs={obs.shape}, actions={actions.shape}, elapsed={elapsed:.0f}s ===")


if __name__ == "__main__":
    main()