Files
TIR_PROJ/training/bc/collect.py
T
Johnny Fernandes a01a5c9cef Checkpoint 7
2026-05-11 12:21:51 +01:00

145 lines
5.7 KiB
Python

"""Collect (obs, action) demonstrations from an analytic teacher.
Runs the chosen teacher across a grid of ``(n_sheep, seed)`` combos at
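full difficulty (``difficulty=1.0``), logs every Nth ``(obs, action)`` pair,
and saves successful trajectories (all of them with ``--keep-failures``) to
``.npz`` for behaviour cloning. The teacher is wrapped in
:class:`ActiveScanTeacher` by default so it operates on the same partial-obs
view the student will have at deployment; ``--privileged`` instead lets the
teacher read ground-truth sheep positions.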
Usage::
python -m training.bc.collect --teacher strombom \\
--out training/bc/demos.npz --frame-stack 4
"""
from __future__ import annotations
import argparse
import time
from pathlib import Path
import numpy as np
from herding.control.active_scan import ActiveScanTeacher
from herding.world.geometry import PEN_ENTRY
from herding.control.sequential import compute_action as sequential_action
from herding.control.strombom import compute_action as strombom_action
from training.herding_env import HerdingEnv
# Registry of analytic teachers selectable via ``--teacher``. Each value is a
# ``compute_action`` callable whose result unpacks as ``(vx, vy, mode)``.
TEACHERS = dict(
    sequential=sequential_action,
    strombom=strombom_action,
)
def collect_one(n_sheep: int, seed: int, max_steps: int, subsample: int,
                teacher_fn, frame_stack: int = 1, privileged: bool = False):
    """Roll out one teacher episode and return a subsampled demonstration.

    Args:
        n_sheep: Flock size passed to :class:`HerdingEnv`.
        seed: Used both as the env constructor seed and the ``reset`` seed.
        max_steps: Upper bound on rollout steps (also passed to the env).
        subsample: Keep the ``(obs, action)`` pair of every step whose index
            is a multiple of this value. Must be >= 1.
        teacher_fn: Analytic controller called as
            ``teacher_fn(dog_pos, positions, PEN_ENTRY)``; its result is
            unpacked as ``(vx, vy, mode)``.
        frame_stack: Forwarded to :class:`HerdingEnv` as ``frame_stack``.
        privileged: If True, the teacher reads the ground-truth positions of
            unpenned sheep directly. Otherwise it is wrapped in
            :class:`ActiveScanTeacher` and reads ``env.perceived_positions()``.

    Returns:
        ``(obs, actions, success, steps)``:
        ``obs`` is a float32 array of shape ``(k, *obs_shape)``; ``actions``
        is a float32 array of shape ``(k, 2)`` (``k`` may be 0); ``success``
        is True iff every sheep is penned at the end; ``steps`` is
        ``env.steps`` at the end of the rollout.

    Raises:
        ValueError: If ``subsample`` < 1.
    """
    if subsample < 1:
        # ``step % 0`` would otherwise raise ZeroDivisionError mid-rollout.
        raise ValueError(f"subsample must be >= 1, got {subsample}")
    env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps,
                     difficulty=1.0, seed=seed, frame_stack=frame_stack)
    obs, _ = env.reset(seed=seed)
    # Remember the observation shape so an empty log still has the right rank.
    obs_shape = obs.shape
    obs_list, action_list = [], []
    # Wrap the base teacher so it opens with a rotation and walks to
    # centre when the tracker briefly empties — matches the student.
    scan_teacher = ActiveScanTeacher(teacher_fn)
    for step in range(max_steps):
        if privileged:
            # Asymmetric variant: teacher reads ground truth while the
            # student keeps the LiDAR obs. Default off.
            positions = {
                f"s{i}": (float(env.sheep_x[i]), float(env.sheep_y[i]))
                for i in range(env.n_sheep)
                if not env.sheep_penned[i]
            }
            if not positions:
                # No unpenned sheep remain; nothing left to demonstrate.
                break
            vx, vy, _mode = teacher_fn(
                (env.dog_x, env.dog_y), positions, PEN_ENTRY,
            )
        else:
            positions = env.perceived_positions()
            vx, vy, _mode = scan_teacher(
                (env.dog_x, env.dog_y), env.dog_heading,
                positions, PEN_ENTRY,
            )
        action = np.array([vx, vy], dtype=np.float32)
        if step % subsample == 0:
            # Pair the action with the observation from the same step.
            obs_list.append(obs.copy())
            action_list.append(action.copy())
        obs, _r, term, trunc, _info = env.step(action)
        if term or trunc:
            break
    success = bool(env.sheep_penned.all())
    return (
        # reshape keeps shapes consistent when nothing was logged
        # (np.asarray([]) would otherwise be 1-D with shape (0,)).
        np.asarray(obs_list, dtype=np.float32).reshape(-1, *obs_shape),
        np.asarray(action_list, dtype=np.float32).reshape(-1, 2),
        success,
        env.steps,
    )
def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for demonstration collection."""
    parser = argparse.ArgumentParser(
        description="Collect (obs, action) demonstrations from an analytic "
                    "teacher for behaviour cloning.")
    parser.add_argument("--out", default="training/bc/demos.npz",
                        help="Output path; numpy appends '.npz' if missing.")
    parser.add_argument("--n-sheep-list", default="1,2,3,5,8,10",
                        help="Comma-separated flock sizes to sweep.")
    parser.add_argument("--seeds-per-n", type=int, default=15,
                        help="Seeds 0..N-1 are run for every flock size.")
    parser.add_argument("--max-steps", type=int, default=30000,
                        help="Step cap per episode.")
    parser.add_argument("--subsample", type=int, default=5,
                        help="Keep every Nth (obs, action) pair.")
    parser.add_argument("--keep-failures", action="store_true",
                        help="Also keep trajectories that did not pen every "
                             "sheep. Default off.")
    parser.add_argument("--teacher", default="sequential",
                        choices=list(TEACHERS.keys()),
                        help="Which analytic teacher to demonstrate.")
    parser.add_argument("--frame-stack", type=int, default=1,
                        help="Concatenate the last K obs into a "
                             "(32·K)-D vector for the policy.")
    parser.add_argument("--privileged", action="store_true",
                        help="Teacher reads ground truth instead of "
                             "tracker output (asymmetric BC).")
    return parser


def _parse_n_sheep_list(text: str) -> list[int]:
    """Parse comma-separated flock sizes, ignoring blank entries.

    Raises:
        ValueError: If a non-blank entry is not an integer.
    """
    return [int(tok) for tok in (part.strip() for part in text.split(","))
            if tok]


def main():
    """Sweep the ``(n_sheep, seed)`` grid with a teacher and save demos.

    Each kept trajectory contributes its subsampled ``(obs, action)`` pairs.
    The ``.npz`` holds ``obs``, ``actions`` and an int32 ``meta`` array with
    one row per kept trajectory:
    ``(n_sheep, seed, n_logged, success, total_steps)``.

    Raises:
        RuntimeError: If no trajectory was kept.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Validate up front: bad values would otherwise crash mid-sweep
    # (e.g. ``--subsample 0`` -> ZeroDivisionError) or produce nothing.
    try:
        n_sheep_list = _parse_n_sheep_list(args.n_sheep_list)
    except ValueError:
        parser.error("--n-sheep-list must be comma-separated integers, "
                     f"got {args.n_sheep_list!r}")
    if not n_sheep_list or min(n_sheep_list) < 1:
        parser.error("--n-sheep-list needs at least one positive integer")
    for flag, value in (("--seeds-per-n", args.seeds_per_n),
                        ("--max-steps", args.max_steps),
                        ("--subsample", args.subsample),
                        ("--frame-stack", args.frame_stack)):
        if value < 1:
            parser.error(f"{flag} must be >= 1, got {value}")

    teacher_fn = TEACHERS[args.teacher]
    print(f"[demos] teacher: {args.teacher}")
    print(f"[demos] grid: n_sheep={n_sheep_list}, seeds={args.seeds_per_n}, "
          f"max_steps={args.max_steps}, subsample={args.subsample}")

    all_obs, all_actions, all_meta = [], [], []
    t_start = time.perf_counter()
    n_success = n_total = 0
    for n in n_sheep_list:
        for seed in range(args.seeds_per_n):
            obs, actions, success, total_steps = collect_one(
                n, seed, args.max_steps, args.subsample, teacher_fn,
                frame_stack=args.frame_stack, privileged=args.privileged,
            )
            n_total += 1
            if success:
                n_success += 1
            keep = success or args.keep_failures
            if keep and len(obs) > 0:
                all_obs.append(obs)
                all_actions.append(actions)
                all_meta.append((n, seed, len(obs), int(success), total_steps))
            tag = "✓" if success else "✗"
            print(f"  [{tag}] n={n:>2d} seed={seed:>2d} steps={total_steps:>6d} "
                  f"logged={len(obs):>5d}")

    if not all_obs:
        raise RuntimeError("No trajectories kept — try --keep-failures.")

    obs = np.concatenate(all_obs, axis=0)
    actions = np.concatenate(all_actions, axis=0)
    meta = np.array(all_meta, dtype=np.int32)

    # np.savez silently appends ".npz" when missing; mirror that so the
    # path we report is the file that actually exists on disk.
    out = args.out if args.out.endswith(".npz") else args.out + ".npz"
    Path(out).parent.mkdir(parents=True, exist_ok=True)
    np.savez(out, obs=obs, actions=actions, meta=meta)

    elapsed = time.perf_counter() - t_start
    print(f"\n=== {n_success}/{n_total} trajectories successful ({100*n_success/n_total:.0f}%) ===")
    print(f"=== {len(obs)} transitions saved to {out} ===")
    print(f"=== obs={obs.shape}, actions={actions.shape}, elapsed={elapsed:.0f}s ===")
if __name__ == "__main__":
    # Script entry point, e.g. ``python -m training.bc.collect ...``.
    main()