Checkpoint 7
This commit is contained in:
@@ -1,158 +0,0 @@
|
||||
"""Collect (obs, action) demonstrations from the sequential teacher.
|
||||
|
||||
Runs the sequential algorithm across a grid of (n_sheep, seed) combos
|
||||
at full difficulty, logs the (observation, action) pair every Nth step,
|
||||
and saves successful trajectories to a numpy ``.npz`` for behavior
|
||||
cloning. Failed trajectories are dropped by default — we only want to
|
||||
teach the policy from good examples.
|
||||
|
||||
Usage::
|
||||
|
||||
python -m tools.collect_demos --out training/demos.npz
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
_HERE = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
import numpy as np
|
||||
|
||||
from herding.control.active_scan import ActiveScanTeacher
|
||||
from herding.world.geometry import PEN_ENTRY
|
||||
from herding.control.sequential import compute_action as sequential_action
|
||||
from herding.control.strombom import compute_action as strombom_action
|
||||
from training.herding_env import HerdingEnv
|
||||
|
||||
|
||||
# Base analytic teachers (no scanning). The default at demo-collection
|
||||
# time wraps these in ActiveScanTeacher, which under LiDAR makes the
|
||||
# teacher operate on the same partial obs as the student.
|
||||
TEACHERS = {
|
||||
"sequential": sequential_action,
|
||||
"strombom": strombom_action,
|
||||
}
|
||||
|
||||
|
||||
def collect_one(n_sheep: int, seed: int, max_steps: int, subsample: int,
|
||||
teacher_fn, frame_stack: int = 1, privileged: bool = False):
|
||||
env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps,
|
||||
difficulty=1.0, seed=seed, frame_stack=frame_stack)
|
||||
obs, _ = env.reset(seed=seed)
|
||||
obs_list, action_list = [], []
|
||||
# Active-scan wrapper: scan first, then run the base teacher on the
|
||||
# tracker dict. Reset state per episode so the opening scan kicks in.
|
||||
scan_teacher = ActiveScanTeacher(teacher_fn)
|
||||
for step in range(max_steps):
|
||||
if privileged:
|
||||
# Asymmetric "learning by cheating": teacher reads GT, student
|
||||
# gets LiDAR obs. Kept available for ablation; default off.
|
||||
positions = {f"s{i}": (float(env.sheep_x[i]), float(env.sheep_y[i]))
|
||||
for i in range(env.n_sheep) if not env.sheep_penned[i]}
|
||||
if not positions:
|
||||
break
|
||||
vx, vy, _mode = teacher_fn(
|
||||
(env.dog_x, env.dog_y), positions, PEN_ENTRY,
|
||||
)
|
||||
else:
|
||||
# Matched-perception teacher: it sees what the student sees
|
||||
# (the tracker dict), with active scanning to fill the
|
||||
# tracker before driving.
|
||||
positions = env.perceived_positions()
|
||||
vx, vy, _mode = scan_teacher(
|
||||
(env.dog_x, env.dog_y), env.dog_heading,
|
||||
positions, PEN_ENTRY,
|
||||
)
|
||||
action = np.array([vx, vy], dtype=np.float32)
|
||||
if step % subsample == 0:
|
||||
obs_list.append(obs.copy())
|
||||
action_list.append(action.copy())
|
||||
obs, _r, term, trunc, _info = env.step(action)
|
||||
if term or trunc:
|
||||
break
|
||||
success = bool(env.sheep_penned.all())
|
||||
return (
|
||||
np.asarray(obs_list, dtype=np.float32),
|
||||
np.asarray(action_list, dtype=np.float32),
|
||||
success,
|
||||
env.steps,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--out", default="training/demos.npz")
|
||||
parser.add_argument("--n-sheep-list", default="1,2,3,5,8,10")
|
||||
parser.add_argument("--seeds-per-n", type=int, default=15)
|
||||
parser.add_argument("--max-steps", type=int, default=30000)
|
||||
parser.add_argument("--subsample", type=int, default=5,
|
||||
help="Keep every Nth (obs, action) pair.")
|
||||
parser.add_argument("--keep-failures", action="store_true",
|
||||
help="Include partial-success trajectories. Default off.")
|
||||
parser.add_argument("--teacher", default="sequential",
|
||||
choices=list(TEACHERS.keys()),
|
||||
help="Which analytic teacher to demonstrate.")
|
||||
parser.add_argument("--frame-stack", type=int, default=1,
|
||||
help="K — concatenate the last K env obs into a "
|
||||
"single (32·K)-D vector. Lets a memoryless "
|
||||
"MLP recover temporal info under partial "
|
||||
"LiDAR observability.")
|
||||
parser.add_argument("--privileged", action="store_true",
|
||||
help="Teacher reads ground truth (asymmetric BC). "
|
||||
"Default: matched-perception with active scan.")
|
||||
args = parser.parse_args()
|
||||
teacher_fn = TEACHERS[args.teacher]
|
||||
print(f"[demos] teacher: {args.teacher}")
|
||||
|
||||
n_sheep_list = [int(x) for x in args.n_sheep_list.split(",")]
|
||||
print(f"[demos] grid: n_sheep={n_sheep_list}, seeds={args.seeds_per_n}, "
|
||||
f"max_steps={args.max_steps}, subsample={args.subsample}")
|
||||
|
||||
all_obs, all_actions, all_meta = [], [], []
|
||||
t_start = time.time()
|
||||
n_success = 0; n_total = 0
|
||||
|
||||
for n in n_sheep_list:
|
||||
for seed in range(args.seeds_per_n):
|
||||
obs, actions, success, total_steps = collect_one(
|
||||
n, seed, args.max_steps, args.subsample, teacher_fn,
|
||||
frame_stack=args.frame_stack, privileged=args.privileged,
|
||||
)
|
||||
n_total += 1
|
||||
if success:
|
||||
n_success += 1
|
||||
keep = success or args.keep_failures
|
||||
if keep and len(obs) > 0:
|
||||
all_obs.append(obs)
|
||||
all_actions.append(actions)
|
||||
all_meta.append((n, seed, len(obs), int(success), total_steps))
|
||||
tag = "✓" if success else "✗"
|
||||
print(f" [{tag}] n={n:>2d} seed={seed:>2d} steps={total_steps:>6d} "
|
||||
f"logged={len(obs):>5d}")
|
||||
|
||||
if not all_obs:
|
||||
raise RuntimeError("No trajectories kept — try --keep-failures.")
|
||||
|
||||
obs = np.concatenate(all_obs, axis=0)
|
||||
actions = np.concatenate(all_actions, axis=0)
|
||||
meta = np.array(all_meta, dtype=np.int32)
|
||||
|
||||
Path(args.out).parent.mkdir(parents=True, exist_ok=True)
|
||||
np.savez(args.out, obs=obs, actions=actions, meta=meta)
|
||||
|
||||
elapsed = time.time() - t_start
|
||||
print(f"\n=== {n_success}/{n_total} trajectories successful ({100*n_success/n_total:.0f}%) ===")
|
||||
print(f"=== {len(obs)} transitions saved to {args.out} ===")
|
||||
print(f"=== obs={obs.shape}, actions={actions.shape}, elapsed={elapsed:.0f}s ===")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user