Checkpoint 2

This commit is contained in:
Johnny Fernandes
2026-05-07 22:00:10 +01:00
parent 90aa3bbcb4
commit 1bb9415414
37 changed files with 3068 additions and 2912 deletions
+117
View File
@@ -0,0 +1,117 @@
"""Collect (obs, action) demonstrations from the sequential teacher.
Runs the sequential algorithm across a grid of (n_sheep, seed) combos
at full difficulty, logs the (observation, action) pair every Nth step,
and saves successful trajectories to a numpy ``.npz`` for behavior
cloning. Failed trajectories are dropped by default — we only want to
teach the policy from good examples.
Usage::
python -m tools.collect_demos --out training/demos.npz
"""
from __future__ import annotations
import argparse
import os
import sys
import time
from pathlib import Path
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
import numpy as np
from herding.geometry import PEN_ENTRY
from herding.sequential import compute_action
from training.herding_env import HerdingEnv
def collect_one(n_sheep: int, seed: int, max_steps: int, subsample: int):
env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps,
difficulty=1.0, seed=seed)
obs, _ = env.reset(seed=seed)
obs_list, action_list = [], []
for step in range(max_steps):
positions = {f"s{i}": (float(env.sheep_x[i]), float(env.sheep_y[i]))
for i in range(env.n_sheep) if not env.sheep_penned[i]}
if not positions:
break
vx, vy, _mode = compute_action(
(env.dog_x, env.dog_y), positions, PEN_ENTRY,
)
action = np.array([vx, vy], dtype=np.float32)
if step % subsample == 0:
obs_list.append(obs.copy())
action_list.append(action.copy())
obs, _r, term, trunc, _info = env.step(action)
if term or trunc:
break
success = bool(env.sheep_penned.all())
return (
np.asarray(obs_list, dtype=np.float32),
np.asarray(action_list, dtype=np.float32),
success,
env.steps,
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--out", default="training/demos.npz")
parser.add_argument("--n-sheep-list", default="1,2,3,5,8,10")
parser.add_argument("--seeds-per-n", type=int, default=15)
parser.add_argument("--max-steps", type=int, default=30000)
parser.add_argument("--subsample", type=int, default=5,
help="Keep every Nth (obs, action) pair.")
parser.add_argument("--keep-failures", action="store_true",
help="Include partial-success trajectories. Default off.")
args = parser.parse_args()
n_sheep_list = [int(x) for x in args.n_sheep_list.split(",")]
print(f"[demos] grid: n_sheep={n_sheep_list}, seeds={args.seeds_per_n}, "
f"max_steps={args.max_steps}, subsample={args.subsample}")
all_obs, all_actions, all_meta = [], [], []
t_start = time.time()
n_success = 0; n_total = 0
for n in n_sheep_list:
for seed in range(args.seeds_per_n):
obs, actions, success, total_steps = collect_one(
n, seed, args.max_steps, args.subsample,
)
n_total += 1
if success:
n_success += 1
keep = success or args.keep_failures
if keep and len(obs) > 0:
all_obs.append(obs)
all_actions.append(actions)
all_meta.append((n, seed, len(obs), int(success), total_steps))
tag = "" if success else ""
print(f" [{tag}] n={n:>2d} seed={seed:>2d} steps={total_steps:>6d} "
f"logged={len(obs):>5d}")
if not all_obs:
raise RuntimeError("No trajectories kept — try --keep-failures.")
obs = np.concatenate(all_obs, axis=0)
actions = np.concatenate(all_actions, axis=0)
meta = np.array(all_meta, dtype=np.int32)
Path(args.out).parent.mkdir(parents=True, exist_ok=True)
np.savez(args.out, obs=obs, actions=actions, meta=meta)
elapsed = time.time() - t_start
print(f"\n=== {n_success}/{n_total} trajectories successful ({100*n_success/n_total:.0f}%) ===")
print(f"=== {len(obs)} transitions saved to {args.out} ===")
print(f"=== obs={obs.shape}, actions={actions.shape}, elapsed={elapsed:.0f}s ===")
if __name__ == "__main__":
main()
+63
View File
@@ -0,0 +1,63 @@
#!/bin/bash
# Launch Webots with N sheep enabled and the chosen controller mode.
# Generates a temporary world file in worlds/field_test.wbt with sheep
# beyond N commented out, sets the env vars the dog controller reads,
# then execs Webots on it.
#
# Usage:
# tools/run_webots.sh [N] [MODE]
# N : number of active sheep (1..10), default 10
# MODE : "rl" | "strombom" | "sequential", default "rl"
#
# Examples:
# tools/run_webots.sh 10 rl # BC-trained RL policy, 10 sheep
# tools/run_webots.sh 5 sequential # the analytic teacher, 5 sheep
# tools/run_webots.sh 3 strombom # canonical baseline, 3 sheep
#
# Notes:
# * The RL mode loads training/runs/bc_pretrained/policy.zip by default.
# Override via HERDING_POLICY_DIR=/path/to/run env var.
# * Conda env "tir" must be active (provides stable-baselines3 + torch).
set -e
N=${1:-10}
MODE=${2:-rl}
if (( N < 1 || N > 10 )); then
echo "N must be 1..10, got $N" >&2; exit 1
fi
case "$MODE" in
rl|strombom|sequential) ;;
*) echo "MODE must be rl|strombom|sequential, got '$MODE'" >&2; exit 1 ;;
esac
ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )"
SRC="$ROOT/worlds/field.wbt"
DST="$ROOT/worlds/field_test.wbt"
cp "$SRC" "$DST"
# Comment out sheep N+1..10 by prefixing the matching Sheep { ... } line.
for i in $(seq $((N+1)) 10); do
sed -i "s|^Sheep .* \"sheep${i}\".*|# &|" "$DST"
done
active=$(grep -c '^Sheep' "$DST")
echo "------------------------------------------------------------"
echo "World : $DST"
echo "Mode : $MODE"
echo "Sheep : $active active"
echo "Policy dir : ${HERDING_POLICY_DIR:-$ROOT/training/runs/bc_pretrained}"
echo "------------------------------------------------------------"
# Webots strips HERDING_* env vars from controller subprocesses in some
# setups, so we also write a runtime config file the controller reads.
RESOLVED_POLICY_DIR="${HERDING_POLICY_DIR:-$ROOT/training/runs/bc_pretrained}"
cat > "$ROOT/herding_runtime.cfg" <<EOF
HERDING_MODE=$MODE
HERDING_POLICY_DIR=$RESOLVED_POLICY_DIR
EOF
export HERDING_MODE="$MODE"
export HERDING_POLICY_DIR="$RESOLVED_POLICY_DIR"
exec webots "$DST"