Checkpoint 4
This commit is contained in:
Executable
+166
@@ -0,0 +1,166 @@
|
||||
#!/bin/bash
|
||||
# tools/auto_dagger.sh — automated DAgger collection across many headless
|
||||
# Webots runs.
|
||||
#
|
||||
# For each (flock_size, run_index) combination, generates a world with N
|
||||
# active sheep at randomised positions, launches Webots in fast/headless
|
||||
# mode, lets the controller log (lidar_obs, teacher_action) pairs for up
|
||||
# to RUN_SEC seconds, kills the run, and moves on. The dog controller's
|
||||
# 500-step periodic flush means each run produces a complete .npz even
|
||||
# when killed by timeout.
|
||||
#
|
||||
# Usage:
|
||||
# tools/auto_dagger.sh [RUNS_PER_FLOCK] [SECONDS_PER_RUN]
|
||||
# RUNS_PER_FLOCK : how many randomised runs per flock size (default 3)
|
||||
# SECONDS_PER_RUN: wall-clock cap per Webots run (default 60)
|
||||
#
|
||||
# Env-var overrides:
|
||||
# HERDING_POLICY_DIR : policy the controller loads (only used when
|
||||
# HERDING_DAGGER_DRIVER=student). Default bc_v3.
|
||||
# HERDING_DAGGER_DRIVER : "teacher" (default) or "student".
|
||||
# HEADLESS=1 : force --no-rendering (default on).
|
||||
# FLOCKS="1 3 5 8 10" : space-separated flock sizes to iterate over.
|
||||
#
|
||||
# Output:
|
||||
# training/dagger/dagger_<ts>.npz — one per Webots run.
|
||||
#
|
||||
# After collection, run:
|
||||
# python -m tools.dagger_merge_train --out training/runs/bc_dagger
|
||||
|
||||
set -e

# Abort with a clear message unless $2 is a plain non-negative decimal
# integer. $1 is the setting's name and is used only in the error text.
# Leading zeros are rejected because bash arithmetic would parse a value
# such as "08" as an invalid octal number later on.
_require_uint() {
    local re='^(0|[1-9][0-9]*)$'
    if [[ ! "$2" =~ $re ]]; then
        echo "$1 must be a non-negative integer, got '$2'" >&2
        exit 1
    fi
}

# Positional arguments (defaults: 3 runs per flock size, 60 s per run).
RUNS_PER_FLOCK=${1:-3}
RUN_SEC=${2:-60}
# Environment overrides.
FLOCKS=${FLOCKS:-"1 3 5 8 10"}
HEADLESS=${HEADLESS:-1}

# Fail fast on bad numbers: a non-numeric RUN_SEC would otherwise evaluate
# to 0 inside (( ... )) and make every run time out on its first poll.
_require_uint "RUNS_PER_FLOCK" "$RUNS_PER_FLOCK"
_require_uint "SECONDS_PER_RUN" "$RUN_SEC"
for _flock in $FLOCKS; do
    _require_uint "FLOCKS entry" "$_flock"
done
unset _flock

# Project root = parent of the directory holding this script.
ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )"
SRC="$ROOT/worlds/field.wbt"        # template world, never modified
DST="$ROOT/worlds/field_test.wbt"   # per-run generated world
POLICY_DIR="${HERDING_POLICY_DIR:-$ROOT/training/runs/bc_v3}"
DRIVER="${HERDING_DAGGER_DRIVER:-teacher}"
DONE_FILE="$ROOT/training/dagger/.DONE"
# PID of the Webots process currently in flight; empty between runs.
WEBOTS_PID=""
|
||||
|
||||
# Signal handler for INT/TERM: stop the Webots run that is currently in
# flight (if any), reap it, and abort the whole collection with status 1.
# Reads the global WEBOTS_PID, which is empty when no run is active.
cleanup() {
    echo "Caught interrupt — killing Webots (pid=$WEBOTS_PID) and exiting."
    if [[ -n "$WEBOTS_PID" ]]; then
        # Best effort: Webots may already have exited on its own, in which
        # case kill/wait fail. That must not trip `set -e` inside the
        # handler before the process has been reaped.
        kill "$WEBOTS_PID" 2>/dev/null || true
        wait "$WEBOTS_PID" 2>/dev/null || true
    fi
    exit 1
}
trap cleanup INT TERM
|
||||
|
||||
# Command-line flags shared by every Webots launch. Kept in an array so
# each flag stays a separate word when expanded as "${webots_args[@]}".
webots_args=(--mode=fast --batch --minimize)
if [[ "$HEADLESS" == "1" ]]; then
    webots_args+=(--no-rendering)
fi

# Banner: echo the effective settings before the first run starts.
echo "Auto-dagger collection"
echo " flock sizes : $FLOCKS"
echo " runs per size : $RUNS_PER_FLOCK"
echo " seconds per run : $RUN_SEC"
echo " policy dir : $POLICY_DIR (used only when driver=student)"
echo " driver : $DRIVER"
echo " webots flags : ${webots_args[*]}"
echo
|
||||
|
||||
# Runtime config file for the controller. This script writes it once, up
# front, and does not rewrite it per run; a manual Webots launch made while
# the collection is going would pick up the same settings.
cat > "$ROOT/herding_runtime.cfg" <<EOF
HERDING_MODE=dagger
HERDING_POLICY_DIR=$POLICY_DIR
HERDING_DAGGER_DRIVER=$DRIVER
EOF

# Count files before, so we can summarise what was added.
mkdir -p "$ROOT/training/dagger"
# Glob instead of parsing `ls`; the -e test discards the literal pattern
# that bash leaves in place when nothing matches.
before_count=0
for existing in "$ROOT/training/dagger"/dagger_*.npz; do
    if [[ -e "$existing" ]]; then
        before_count=$((before_count + 1))
    fi
done
|
||||
|
||||
# Progress bookkeeping: total = (number of flock sizes) x RUNS_PER_FLOCK.
run_idx=0
total_runs=0
for _size in $FLOCKS; do
    total_runs=$((total_runs + RUNS_PER_FLOCK))
done

# Webots stdout/stderr of the most recent run (overwritten every run).
run_log=/tmp/webots_dagger_run.log

# Terminate the in-flight Webots process and reap it. Both steps are best
# effort: Webots can exit on its own between the liveness check and the
# kill, and under `set -e` a bare failing `kill` would abort the whole
# collection.
stop_webots() {
    kill "$WEBOTS_PID" 2>/dev/null || true
    wait "$WEBOTS_PID" 2>/dev/null || true
}

for flock in $FLOCKS; do
    for ((run = 1; run <= RUNS_PER_FLOCK; run++)); do
        run_idx=$((run_idx + 1))
        # Deterministic per-(flock, run) seed so a run can be reproduced.
        seed=$((1000 * flock + run))
        echo "=== [$run_idx/$total_runs] flock=$flock run=$run seed=$seed ==="

        # Generate randomised world: start from the template and comment
        # out the Sheep lines for sheep(flock+1)..sheep10.
        cp "$SRC" "$DST"
        for ((i = flock + 1; i <= 10; i++)); do
            sed -i "s|^Sheep .* \"sheep${i}\".*|# &|" "$DST"
        done
        # Inline Python: jitter sheep1..flock translations. A sheep whose
        # line does not match the expected pattern is reported on stderr
        # instead of being skipped silently.
        python3 - "$DST" "$flock" "$seed" <<'PYEOF'
import re, random, sys
path, n_str, seed = sys.argv[1], sys.argv[2], sys.argv[3]
n = int(n_str); random.seed(int(seed))
with open(path) as f:
    txt = f.read()
def rand_pos():
    while True:
        x = random.uniform(-12.0, 12.0)
        y = random.uniform(-10.0, 12.0)   # avoid the gate strip
        if x * x + y * y > 9.0:           # at least 3 m from dog spawn
            return x, y
for i in range(1, n + 1):
    x, y = rand_pos()
    pat = re.compile(
        r'Sheep \{ translation\s+\S+\s+\S+\s+(\S+)\s+name "sheep' + str(i) + r'"'
    )
    txt, hits = pat.subn(rf'Sheep {{ translation {x:.2f} {y:.2f} \g<1> name "sheep{i}"', txt, count=1)
    if hits == 0:
        print(f"  WARNING: no Sheep line for sheep{i} in {path}; translation not randomised",
              file=sys.stderr)
with open(path, "w") as f:
    f.write(txt)
PYEOF

        # Run Webots in the background; poll for the .DONE sentinel or
        # the wall-clock timeout, whichever comes first.
        rm -f "$DONE_FILE"
        webots "${webots_args[@]}" "$DST" > "$run_log" 2>&1 &
        WEBOTS_PID=$!

        # Give the controller 10 s to start before polling the sentinel,
        # otherwise a sheep that spawns already penned triggers an instant
        # false-positive kill.
        elapsed=0
        grace=10
        while kill -0 "$WEBOTS_PID" 2>/dev/null; do
            if (( elapsed >= grace )) && [[ -f "$DONE_FILE" ]]; then
                echo " sentinel .DONE detected — killing Webots early"
                stop_webots
                break
            fi
            if (( elapsed >= RUN_SEC )); then
                echo " timeout ($RUN_SEC s) — killing Webots"
                stop_webots
                break
            fi
            sleep 2
            elapsed=$((elapsed + 2))
        done
        # Cleared so the INT/TERM handler does not signal a finished run.
        WEBOTS_PID=""

        # Quick sanity from the log: did the controller actually run?
        if grep -q "running in mode=dagger" "$run_log"; then
            new_pairs=$(tail -50 "$run_log" | grep -oE 'logged=[0-9]+' | tail -1)
            echo " controller ran ($new_pairs)"
        else
            echo " WARNING: controller may not have started (see $run_log)"
        fi
    done
done
|
||||
|
||||
# Count the dagger files present now. Globbing avoids parsing `ls`; the -e
# test discards the literal pattern bash leaves when nothing matches.
after_count=0
for produced in "$ROOT/training/dagger"/dagger_*.npz; do
    if [[ -e "$produced" ]]; then
        after_count=$((after_count + 1))
    fi
done
# Files added by this invocation = now minus the count taken at start-up.
new_files=$((after_count - before_count))

echo
echo "Done."
echo " new dagger files : $new_files"
echo " total in dir : $after_count"
echo
echo "Next:"
echo " python -m tools.dagger_merge_train --out training/runs/bc_dagger"
|
||||
+37
-9
@@ -26,12 +26,16 @@ if _PROJECT_ROOT not in sys.path:
|
||||
|
||||
import numpy as np
|
||||
|
||||
from herding.active_scan import ActiveScanTeacher
|
||||
from herding.geometry import PEN_ENTRY
|
||||
from herding.sequential import compute_action as sequential_action
|
||||
from herding.strombom import compute_action as strombom_action
|
||||
from training.herding_env import HerdingEnv
|
||||
|
||||
|
||||
# Base analytic teachers (no scanning). The default at demo-collection
|
||||
# time wraps these in ActiveScanTeacher, which under LiDAR makes the
|
||||
# teacher operate on the same partial obs as the student.
|
||||
TEACHERS = {
|
||||
"sequential": sequential_action,
|
||||
"strombom": strombom_action,
|
||||
@@ -39,19 +43,34 @@ TEACHERS = {
|
||||
|
||||
|
||||
def collect_one(n_sheep: int, seed: int, max_steps: int, subsample: int,
|
||||
teacher_fn):
|
||||
teacher_fn, frame_stack: int = 1, privileged: bool = False):
|
||||
env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps,
|
||||
difficulty=1.0, seed=seed)
|
||||
difficulty=1.0, seed=seed, frame_stack=frame_stack)
|
||||
obs, _ = env.reset(seed=seed)
|
||||
obs_list, action_list = [], []
|
||||
# Active-scan wrapper: scan first, then run the base teacher on the
|
||||
# tracker dict. Reset state per episode so the opening scan kicks in.
|
||||
scan_teacher = ActiveScanTeacher(teacher_fn)
|
||||
for step in range(max_steps):
|
||||
positions = {f"s{i}": (float(env.sheep_x[i]), float(env.sheep_y[i]))
|
||||
for i in range(env.n_sheep) if not env.sheep_penned[i]}
|
||||
if not positions:
|
||||
break
|
||||
vx, vy, _mode = teacher_fn(
|
||||
(env.dog_x, env.dog_y), positions, PEN_ENTRY,
|
||||
)
|
||||
if privileged:
|
||||
# Asymmetric "learning by cheating": teacher reads GT, student
|
||||
# gets LiDAR obs. Kept available for ablation; default off.
|
||||
positions = {f"s{i}": (float(env.sheep_x[i]), float(env.sheep_y[i]))
|
||||
for i in range(env.n_sheep) if not env.sheep_penned[i]}
|
||||
if not positions:
|
||||
break
|
||||
vx, vy, _mode = teacher_fn(
|
||||
(env.dog_x, env.dog_y), positions, PEN_ENTRY,
|
||||
)
|
||||
else:
|
||||
# Matched-perception teacher: it sees what the student sees
|
||||
# (the tracker dict), with active scanning to fill the
|
||||
# tracker before driving.
|
||||
positions = env.perceived_positions()
|
||||
vx, vy, _mode = scan_teacher(
|
||||
(env.dog_x, env.dog_y), env.dog_heading,
|
||||
positions, PEN_ENTRY,
|
||||
)
|
||||
action = np.array([vx, vy], dtype=np.float32)
|
||||
if step % subsample == 0:
|
||||
obs_list.append(obs.copy())
|
||||
@@ -81,6 +100,14 @@ def main():
|
||||
parser.add_argument("--teacher", default="sequential",
|
||||
choices=list(TEACHERS.keys()),
|
||||
help="Which analytic teacher to demonstrate.")
|
||||
parser.add_argument("--frame-stack", type=int, default=1,
|
||||
help="K — concatenate the last K env obs into a "
|
||||
"single (32·K)-D vector. Lets a memoryless "
|
||||
"MLP recover temporal info under partial "
|
||||
"LiDAR observability.")
|
||||
parser.add_argument("--privileged", action="store_true",
|
||||
help="Teacher reads ground truth (asymmetric BC). "
|
||||
"Default: matched-perception with active scan.")
|
||||
args = parser.parse_args()
|
||||
teacher_fn = TEACHERS[args.teacher]
|
||||
print(f"[demos] teacher: {args.teacher}")
|
||||
@@ -97,6 +124,7 @@ def main():
|
||||
for seed in range(args.seeds_per_n):
|
||||
obs, actions, success, total_steps = collect_one(
|
||||
n, seed, args.max_steps, args.subsample, teacher_fn,
|
||||
frame_stack=args.frame_stack, privileged=args.privileged,
|
||||
)
|
||||
n_total += 1
|
||||
if success:
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
"""Merge Webots DAgger demos with sim demos and retrain the BC policy.
|
||||
|
||||
The dog controller in ``HERDING_MODE=dagger`` writes per-run files to
|
||||
``training/dagger/dagger_<ts>.npz`` containing ``(obs, actions)`` pairs
|
||||
where:
|
||||
|
||||
* ``obs`` is the **stacked LiDAR observation** as built by the live
|
||||
Webots tracker — exactly the input distribution the deployed
|
||||
controller sees.
|
||||
* ``actions`` is the **active-scan-teacher action computed from
|
||||
ground-truth sheep positions** (read off the sheep emitter).
|
||||
|
||||
Combined with the existing sim demos (``training/demos_v3.npz`` by
|
||||
default), this gives the BC student a training set that includes the
|
||||
real Webots false-positive distribution — closing the sim-to-real
|
||||
perception gap that the all-sim pipeline couldn't bridge.
|
||||
|
||||
Usage::
|
||||
|
||||
# Iteration 1 — merge all dagger files with sim demos, retrain
|
||||
python -m tools.dagger_merge_train \\
|
||||
--sim training/demos_v3.npz \\
|
||||
--out training/runs/bc_dagger1
|
||||
|
||||
# Iteration 2 — drop the sim baseline, train only on Webots data
|
||||
python -m tools.dagger_merge_train --no-sim --out training/runs/bc_dagger2
|
||||
|
||||
The new policy is saved as ``<out>/policy.zip`` and is auto-loaded by
|
||||
the controller's resolution priority on the next Webots run.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
_HERE = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def main() -> None:
    """Merge Webots DAgger files (optionally with sim demos) and retrain BC.

    Steps:
      1. Load every file matching ``--dagger-glob``; each must hold ``obs``
         and ``actions`` arrays of equal length. Empty files are skipped.
      2. Unless ``--no-sim`` is given, append the ``--sim`` demo file.
      3. Write the concatenated arrays to ``--merged-out``.
      4. Launch ``python -m training.bc_pretrain`` on the merged file, with
         the project root as working directory.

    Raises:
        SystemExit: no dagger files match, a file is malformed (non-2-D
            obs, obs/actions length mismatch), or obs dimensions disagree
            between files or with the sim demos.
        subprocess.CalledProcessError: the BC training subprocess failed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--sim", default="training/demos_v3.npz",
                        help="Sim demo file to mix with the Webots data. "
                             "Pass --no-sim to train only on dagger data.")
    parser.add_argument("--no-sim", action="store_true",
                        help="Skip the sim demos entirely.")
    parser.add_argument("--dagger-glob", default="training/dagger/dagger_*.npz",
                        help="Glob for Webots-collected dagger files.")
    parser.add_argument("--merged-out", default="training/demos_dagger.npz",
                        help="Where to write the merged demo file.")
    parser.add_argument("--out", default="training/runs/bc_dagger",
                        help="Where to write the BC policy.")
    parser.add_argument("--epochs", type=int, default=60)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--net-arch", default="512,512")
    parser.add_argument("--cos-weight", type=float, default=1.0)
    args = parser.parse_args()

    # --- Gather Webots files ---
    dagger_paths = sorted(glob.glob(args.dagger_glob))
    if not dagger_paths:
        raise SystemExit(f"No dagger files found at {args.dagger_glob} — "
                         "run Webots in HERDING_MODE=dagger first.")

    chunks_obs: list[np.ndarray] = []
    chunks_act: list[np.ndarray] = []
    total_dagger = 0
    for p in dagger_paths:
        # Context manager closes the underlying zip handle; astype() copies
        # the data out, so the arrays stay valid after the file is closed.
        with np.load(p) as data:
            obs = data["obs"].astype(np.float32)
            act = data["actions"].astype(np.float32)
        if obs.size == 0:
            # Nothing usable in this file; skipping it also avoids an
            # IndexError on shape[1] when obs was saved as a 1-D array.
            print(f" - {p}: empty, skipped")
            continue
        if obs.ndim != 2:
            raise SystemExit(
                f"{p}: expected a 2-D obs array (pairs x obs_dim), "
                f"got shape {obs.shape}."
            )
        if len(obs) != len(act):
            # Misaligned pairs would silently corrupt the training set.
            raise SystemExit(
                f"{p}: obs has {len(obs)} rows but actions has {len(act)}."
            )
        chunks_obs.append(obs)
        chunks_act.append(act)
        total_dagger += len(obs)
        print(f" + {p}: {obs.shape[0]} pairs (obs dim {obs.shape[1]})")
    print(f"[merge] total dagger pairs: {total_dagger}")

    if not chunks_obs:
        raise SystemExit(
            f"All dagger files matching {args.dagger_glob} are empty."
        )

    obs_dim = chunks_obs[0].shape[1]
    if any(c.shape[1] != obs_dim for c in chunks_obs):
        raise SystemExit(
            "Dagger files have inconsistent obs dims — they were collected "
            "with different frame_stack settings. Either rerun with a "
            "consistent setting or filter the glob."
        )

    # --- Optionally include sim demos ---
    if not args.no_sim:
        with np.load(args.sim) as sim:
            sim_obs = sim["obs"].astype(np.float32)
            sim_act = sim["actions"].astype(np.float32)
        if sim_obs.shape[1] != obs_dim:
            raise SystemExit(
                f"Sim demos have obs dim {sim_obs.shape[1]} but dagger demos "
                f"have {obs_dim}. Recollect sim demos at the same frame_stack."
            )
        chunks_obs.append(sim_obs)
        chunks_act.append(sim_act)
        print(f"[merge] + sim demos: {sim_obs.shape[0]} pairs from {args.sim}")

    obs_all = np.concatenate(chunks_obs, axis=0)
    act_all = np.concatenate(chunks_act, axis=0)
    # Empty meta — bc_pretrain doesn't actually use it but the file format
    # has it.
    meta = np.zeros((0, 5), dtype=np.int32)

    Path(args.merged_out).parent.mkdir(parents=True, exist_ok=True)
    np.savez(args.merged_out, obs=obs_all, actions=act_all, meta=meta)
    print(f"[merge] wrote {len(obs_all)} pairs → {args.merged_out}")
    print(f"[merge] obs shape {obs_all.shape}, action shape {act_all.shape}")

    # --- Run BC training ---
    # The subprocess runs with cwd=_PROJECT_ROOT, so hand it an absolute
    # path: a relative --merged-out was written relative to *our* cwd and
    # would not be found when this script is started from another directory.
    merged_path = os.path.abspath(args.merged_out)
    cmd = [
        sys.executable, "-m", "training.bc_pretrain",
        "--demos", merged_path,
        "--out", args.out,
        "--epochs", str(args.epochs),
        "--batch-size", str(args.batch_size),
        "--net-arch", args.net_arch,
        "--cos-weight", str(args.cos_weight),
    ]
    print(f"\n[merge] launching: {' '.join(cmd)}")
    subprocess.run(cmd, check=True, cwd=_PROJECT_ROOT)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+14
-9
@@ -7,29 +7,33 @@
|
||||
# Usage:
|
||||
# tools/run_webots.sh [N] [MODE]
|
||||
# N : number of active sheep (1..10), default 10
|
||||
# MODE : "rl" | "strombom" | "sequential", default "rl"
|
||||
# MODE : "bc" | "rl" | "strombom" | "sequential" | "dagger", default "bc"
|
||||
#
|
||||
# Examples:
|
||||
# tools/run_webots.sh 10 rl # BC-trained RL policy, 10 sheep
|
||||
# tools/run_webots.sh 10 bc # BC-trained policy, 10 sheep
|
||||
# tools/run_webots.sh 10 rl # KL-PPO fine-tune of bc, 10 sheep
|
||||
# tools/run_webots.sh 5 sequential # the analytic teacher, 5 sheep
|
||||
# tools/run_webots.sh 3 strombom # canonical baseline, 3 sheep
|
||||
#
|
||||
# Notes:
|
||||
# * The RL mode loads training/runs/bc_solo/policy.zip by default.
|
||||
# Override via HERDING_POLICY_DIR=/path/to/run env var.
|
||||
# * The RL mode loads the latest BC policy by default — priority
|
||||
# bc_dagger_v2 → bc_dagger → bc_c2v3 (the controller resolves it).
|
||||
# (LiDAR-perception, frame-stack K=4). Override via
|
||||
# HERDING_POLICY_DIR=/path/to/run env var.
|
||||
# * Conda env "tir" must be active (provides stable-baselines3 + torch).
|
||||
|
||||
set -e
|
||||
N=${1:-10}
|
||||
MODE=${2:-rl}
|
||||
MODE=${2:-bc}
|
||||
|
||||
if (( N < 1 || N > 10 )); then
|
||||
echo "N must be 1..10, got $N" >&2; exit 1
|
||||
fi
|
||||
case "$MODE" in
|
||||
rl|strombom|sequential) ;;
|
||||
*) echo "MODE must be rl|strombom|sequential, got '$MODE'" >&2; exit 1 ;;
|
||||
bc|rl|strombom|sequential|dagger) ;;
|
||||
*) echo "MODE must be bc|rl|strombom|sequential|dagger, got '$MODE'" >&2; exit 1 ;;
|
||||
esac
|
||||
DAGGER_DRIVER=${HERDING_DAGGER_DRIVER:-teacher}
|
||||
|
||||
ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )"
|
||||
SRC="$ROOT/worlds/field.wbt"
|
||||
@@ -46,15 +50,16 @@ echo "------------------------------------------------------------"
|
||||
echo "World : $DST"
|
||||
echo "Mode : $MODE"
|
||||
echo "Sheep : $active active"
|
||||
echo "Policy dir : ${HERDING_POLICY_DIR:-$ROOT/training/runs/bc_solo}"
|
||||
echo "Policy dir : ${HERDING_POLICY_DIR:-$ROOT/training/runs/bc_v3}"
|
||||
echo "------------------------------------------------------------"
|
||||
|
||||
# Webots strips HERDING_* env vars from controller subprocesses in some
|
||||
# setups, so we also write a runtime config file the controller reads.
|
||||
RESOLVED_POLICY_DIR="${HERDING_POLICY_DIR:-$ROOT/training/runs/bc_solo}"
|
||||
RESOLVED_POLICY_DIR="${HERDING_POLICY_DIR:-$ROOT/training/runs/bc_v3}"
|
||||
cat > "$ROOT/herding_runtime.cfg" <<EOF
|
||||
HERDING_MODE=$MODE
|
||||
HERDING_POLICY_DIR=$RESOLVED_POLICY_DIR
|
||||
HERDING_DAGGER_DRIVER=$DAGGER_DRIVER
|
||||
EOF
|
||||
|
||||
export HERDING_MODE="$MODE"
|
||||
|
||||
Reference in New Issue
Block a user