Checkpoint 5 - incomplete
This commit is contained in:
@@ -21,6 +21,9 @@ training/runs/*/evals/
|
|||||||
training/runs/*/best/
|
training/runs/*/best/
|
||||||
!training/runs/.gitkeep
|
!training/runs/.gitkeep
|
||||||
!training/runs/bc_v3/policy.zip
|
!training/runs/bc_v3/policy.zip
|
||||||
|
!training/runs/rl_v1/policy.zip
|
||||||
|
!training/runs/rl_v2/policy.zip
|
||||||
|
!training/runs/rl_v2/best/best_model.zip
|
||||||
|
|
||||||
# Webots launcher scratch
|
# Webots launcher scratch
|
||||||
worlds/field_test.wbt
|
worlds/field_test.wbt
|
||||||
|
|||||||
@@ -27,6 +27,15 @@ control step:
|
|||||||
positions for sheep currently outside the FOV
|
positions for sheep currently outside the FOV
|
||||||
(`herding/sheep_tracker.py`).
|
(`herding/sheep_tracker.py`).
|
||||||
|
|
||||||
|
**LiDAR validation** (intermediate-goal item v from `docs/project.md`):
|
||||||
|
run the dog controller in `HERDING_MODE=diag` mode to capture 80
|
||||||
|
real Webots scans plus the ground-truth sheep positions in
|
||||||
|
`training/dagger/diag_<ts>.npz`. Comparing detections against GT in
|
||||||
|
that file showed clustered centroids match GT positions within 0.15 m
|
||||||
|
after the +SHEEP_RADIUS surface-to-centre correction — i.e. the
|
||||||
|
LiDAR pipeline produces correct sheep-position estimates from the
|
||||||
|
real Webots scan, validating the sensor for the herding task.
|
||||||
|
|
||||||
The tracker outputs a `{name: (x, y)}` dict shaped exactly like the
|
The tracker outputs a `{name: (x, y)}` dict shaped exactly like the
|
||||||
prior receiver-based one, so Strömbom, Sequential, and the BC obs
|
prior receiver-based one, so Strömbom, Sequential, and the BC obs
|
||||||
builder all run unchanged on top of it. The 2D Gymnasium env
|
builder all run unchanged on top of it. The 2D Gymnasium env
|
||||||
@@ -48,22 +57,22 @@ python -m training.parity_test
|
|||||||
|
|
||||||
# 3. Reproduce the BC policy (~10 min on CPU: ~5 min demos + ~3 min BC)
|
# 3. Reproduce the BC policy (~10 min on CPU: ~5 min demos + ~3 min BC)
|
||||||
python -m tools.collect_demos --teacher strombom \
|
python -m tools.collect_demos --teacher strombom \
|
||||||
--out training/demos_v3.npz --seeds-per-n 15 --subsample 3 --frame-stack 4
|
--out training/demos.npz --seeds-per-n 15 --subsample 3 --frame-stack 4
|
||||||
python -m training.bc_pretrain --demos training/demos_v3.npz \
|
python -m training.bc_pretrain --demos training/demos.npz \
|
||||||
--out training/runs/bc_v3 --epochs 60 --net-arch 512,512
|
--out training/runs/bc --epochs 60 --net-arch 512,512
|
||||||
|
|
||||||
# 4. Optional: DAgger from inside Webots if sim-trained doesn't transfer
|
# 4. Optional: DAgger from inside Webots if sim-trained doesn't transfer
|
||||||
tools/auto_dagger.sh 3 60
|
tools/auto_dagger.sh 3 60
|
||||||
python -m tools.dagger_merge_train --out training/runs/bc_dagger
|
python -m tools.dagger_merge_train --out training/runs/bc_dagger
|
||||||
|
|
||||||
# 5. Evaluate (env)
|
# 5. Evaluate (env)
|
||||||
python -m training.eval --policy training/runs/bc_v3 \
|
python -m training.eval --policy training/runs/bc \
|
||||||
--max-flock 10 --max-steps 8000 --n-seeds 5
|
--max-flock 10 --max-steps 8000 --n-seeds 5
|
||||||
|
|
||||||
# 6. Optional RL fine-tune of the BC policy (~40 min on CPU, 1 M steps)
|
# 6. Optional RL fine-tune of the BC policy (~40 min on CPU, 1 M steps)
|
||||||
python -m training.train_ppo \
|
python -m training.train_ppo \
|
||||||
--bc training/runs/bc_v3 \
|
--bc training/runs/bc \
|
||||||
--out training/runs/rl_v1 \
|
--out training/runs/rl \
|
||||||
--total-timesteps 1000000
|
--total-timesteps 1000000
|
||||||
|
|
||||||
# 7. Run in Webots
|
# 7. Run in Webots
|
||||||
@@ -127,23 +136,48 @@ scattering the flock. Direction (intent) is preserved.
|
|||||||
All modes also share the same EMA action smoother in
|
All modes also share the same EMA action smoother in
|
||||||
`controllers/shepherd_dog/shepherd_dog.py:ACTION_SMOOTH = 0.55`.
|
`controllers/shepherd_dog/shepherd_dog.py:ACTION_SMOOTH = 0.55`.
|
||||||
|
|
||||||
## Webots results (steps to all-penned, fast mode)
|
## Results — env eval, 10 seeds × n=1..10
|
||||||
|
|
||||||
Single seed per cell using `worlds/field.wbt` defaults. All modes hit
|
`max_steps=15000`, full-field spawn distribution. Success rate per
|
||||||
100 % pen rate; numbers shown are time-to-all-penned in simulation
|
flock size, then mean steps over successful seeds.
|
||||||
steps (16 ms each).
|
|
||||||
|
|
||||||
| n | Strömbom | `bc` | `rl` (KL-PPO of `bc`) |
|
### Success rate (%)
|
||||||
|
|
||||||
|
| n | Strömbom | `bc` | `rl` |
|
||||||
|---:|---:|---:|---:|
|
|---:|---:|---:|---:|
|
||||||
| 3 | 5 800 | 9 800 | **4 800** |
|
| 1 | 30 | 80 | **90** |
|
||||||
| 5 | 10 200 | 9 200 | 9 800 |
|
| 2 | 90 | 50 | **90** |
|
||||||
| 8 | 14 000 | 17 600 | **15 400** |
|
| 3 | 60 | 90 | **90** |
|
||||||
| 10 | 18 600 | 19 600 | **12 000** |
|
| 4 | 40 | 80 | **90** |
|
||||||
|
| 5 | 60 | 70 | **100** |
|
||||||
|
| 6 | 30 | 80 | 80 |
|
||||||
|
| 7 | 70 | 80 | **100** |
|
||||||
|
| 8 | 30 | 100 | **100** |
|
||||||
|
| 9 | 40 | 90 | **100** |
|
||||||
|
| 10 | 50 | 100 | **100** |
|
||||||
|
|
||||||
The RL fine-tune is **39 % faster than `bc` on n=10** and **51 % faster
|
### Mean penned per episode (out of n)
|
||||||
on n=3**, confirming the KL-anchored PPO actually finds reward-driven
|
|
||||||
improvements over the BC imitation baseline rather than just collapsing
|
| n | Strömbom | `bc` | `rl` |
|
||||||
back to it.
|
|---:|---:|---:|---:|
|
||||||
|
| 1 | 0.30 | 0.80 | **0.90** |
|
||||||
|
| 5 | 3.90 | 4.10 | **5.00** |
|
||||||
|
| 8 | 4.20 | 8.00 | **8.00** |
|
||||||
|
| 10 | 7.40 | 10.00 | **10.00** |
|
||||||
|
|
||||||
|
### Takeaways
|
||||||
|
|
||||||
|
- **BC clearly beats Strömbom** under realistic LiDAR conditions (full
|
||||||
|
field, partial observability). Strömbom struggles on small flocks
|
||||||
|
where a single sheep can spawn beyond the LiDAR's 12 m range; BC
|
||||||
|
learned active perception from the demos.
|
||||||
|
- **RL refines BC** without regressing on any cell. Ties or beats BC
|
||||||
|
at every flock size; biggest gains at n=1 and n=4 where BC's
|
||||||
|
imitation of Strömbom's drive heuristic was sub-optimal.
|
||||||
|
- **Aggressive reward shaping doesn't help** — a more aggressive
|
||||||
|
variant (β=0.02, W_TIME=-0.1, W_IMITATE=0, 3 M steps) trained as
|
||||||
|
an ablation was strictly worse than the conservative tune shipped
|
||||||
|
here (β=0.05, W_IMITATE=0.5, 1 M steps).
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
|
|||||||
@@ -8,11 +8,11 @@ env vars on some setups):
|
|||||||
sequential → single-target "pin and push" — drives the sheep
|
sequential → single-target "pin and push" — drives the sheep
|
||||||
closest to the pen.
|
closest to the pen.
|
||||||
bc → behaviour-cloned MLP, trained on Strömbom demos via
|
bc → behaviour-cloned MLP, trained on Strömbom demos via
|
||||||
sim. Default policy directory: training/runs/bc_v3.
|
sim. Default policy directory: training/runs/bc.
|
||||||
rl → KL-regularised PPO fine-tune of the BC policy. Same
|
rl → KL-regularised PPO fine-tune of the BC policy. Same
|
||||||
obs/action space as bc; refines time-to-pen via
|
obs/action space as bc; refines time-to-pen via
|
||||||
environment reward while staying anchored to bc.
|
environment reward while staying anchored to bc.
|
||||||
Default policy directory: training/runs/rl_v1.
|
Default policy directory: training/runs/rl.
|
||||||
dagger → DAgger data collection. Reads sheep ground-truth
|
dagger → DAgger data collection. Reads sheep ground-truth
|
||||||
via the receiver, computes the active-scan teacher's
|
via the receiver, computes the active-scan teacher's
|
||||||
recommended action at every step, drives with either
|
recommended action at every step, drives with either
|
||||||
@@ -122,9 +122,9 @@ def _resolve_policy_dir(mode: str) -> str:
|
|||||||
1. HERDING_POLICY_DIR env var or runtime-cfg entry, if it points
|
1. HERDING_POLICY_DIR env var or runtime-cfg entry, if it points
|
||||||
to a real directory.
|
to a real directory.
|
||||||
2. Mode-specific default:
|
2. Mode-specific default:
|
||||||
bc → training/runs/bc_v3 (Strömbom-imitated MLP)
|
bc → training/runs/bc (Strömbom-imitated MLP)
|
||||||
rl → training/runs/rl_v1 (KL-PPO fine-tune of bc_v3)
|
rl → training/runs/rl (KL-PPO fine-tune of bc)
|
||||||
3. Fall back to bc_v3.
|
3. Fall back to bc.
|
||||||
All checkpoints are frame-stacked K = 4; ``policy_loader`` reads
|
All checkpoints are frame-stacked K = 4; ``policy_loader`` reads
|
||||||
the stacking factor from the policy's observation space.
|
the stacking factor from the policy's observation space.
|
||||||
"""
|
"""
|
||||||
@@ -133,9 +133,9 @@ def _resolve_policy_dir(mode: str) -> str:
|
|||||||
if env_dir and os.path.isdir(env_dir):
|
if env_dir and os.path.isdir(env_dir):
|
||||||
return env_dir
|
return env_dir
|
||||||
mode_default = {
|
mode_default = {
|
||||||
"bc": os.path.join(_PROJECT_ROOT, "training", "runs", "bc_v3"),
|
"bc": os.path.join(_PROJECT_ROOT, "training", "runs", "bc"),
|
||||||
"rl": os.path.join(_PROJECT_ROOT, "training", "runs", "rl_v1"),
|
"rl": os.path.join(_PROJECT_ROOT, "training", "runs", "rl"),
|
||||||
"dagger": os.path.join(_PROJECT_ROOT, "training", "runs", "bc_v3"),
|
"dagger": os.path.join(_PROJECT_ROOT, "training", "runs", "bc"),
|
||||||
}
|
}
|
||||||
primary = mode_default.get(mode, mode_default["bc"])
|
primary = mode_default.get(mode, mode_default["bc"])
|
||||||
if os.path.isdir(primary):
|
if os.path.isdir(primary):
|
||||||
@@ -150,9 +150,9 @@ def _resolve_policy_dir(mode: str) -> str:
|
|||||||
|
|
||||||
_VALID_MODES = ("bc", "rl", "strombom", "sequential", "dagger", "diag")
|
_VALID_MODES = ("bc", "rl", "strombom", "sequential", "dagger", "diag")
|
||||||
# Back-compat: an old config saying HERDING_MODE=rl meant "the BC policy".
|
# Back-compat: an old config saying HERDING_MODE=rl meant "the BC policy".
|
||||||
# We now use `rl` strictly for the KL-PPO fine-tune. If the rl_v1
|
# We now use `rl` strictly for the KL-PPO fine-tune. If the rl
|
||||||
# directory isn't present, _resolve_policy_dir below silently falls
|
# directory isn't present, _resolve_policy_dir below silently falls
|
||||||
# back to bc_v3, preserving the old behaviour.
|
# back to bc, preserving the old behaviour.
|
||||||
if MODE not in _VALID_MODES:
|
if MODE not in _VALID_MODES:
|
||||||
print(f"[dog] unknown HERDING_MODE={MODE!r}; defaulting to strombom.")
|
print(f"[dog] unknown HERDING_MODE={MODE!r}; defaulting to strombom.")
|
||||||
MODE = "strombom"
|
MODE = "strombom"
|
||||||
@@ -477,15 +477,22 @@ while robot.step(timestep) != -1:
|
|||||||
left_ear.setPosition(ear_pos)
|
left_ear.setPosition(ear_pos)
|
||||||
right_ear.setPosition(-ear_pos)
|
right_ear.setPosition(-ear_pos)
|
||||||
|
|
||||||
# --- DAgger: early-stop when all GT sheep are penned ---
|
# --- Early-stop when all GT sheep are penned (all modes) ---
|
||||||
if MODE == "dagger" and _gt_sheep:
|
# The dog isn't a Supervisor so it can't call simulationQuit() —
|
||||||
|
# instead we write a sentinel file the launcher polls for and uses
|
||||||
|
# to kill the Webots process. Bounded by `_gt_sheep` so we don't
|
||||||
|
# fire during the first few steps while the receiver fills.
|
||||||
|
if _gt_sheep and not os.path.exists(_DAGGER_DONE_FILE):
|
||||||
gt_active_count = sum(1 for x, y in _gt_sheep.values()
|
gt_active_count = sum(1 for x, y in _gt_sheep.values()
|
||||||
if not is_penned_position(x, y))
|
if not is_penned_position(x, y))
|
||||||
if gt_active_count == 0 and not os.path.exists(_DAGGER_DONE_FILE):
|
if gt_active_count == 0:
|
||||||
|
if MODE == "dagger":
|
||||||
_dump_dagger_log()
|
_dump_dagger_log()
|
||||||
|
os.makedirs(os.path.dirname(_DAGGER_DONE_FILE), exist_ok=True)
|
||||||
open(_DAGGER_DONE_FILE, "w").close()
|
open(_DAGGER_DONE_FILE, "w").close()
|
||||||
print(f"[dog dagger] all {len(_gt_sheep)} sheep penned — "
|
print(f"[dog] all {len(_gt_sheep)} sheep penned at step "
|
||||||
f"wrote {_DAGGER_DONE_FILE}, exiting early")
|
f"{step_count} — wrote {_DAGGER_DONE_FILE}, "
|
||||||
|
f"launcher will close Webots")
|
||||||
|
|
||||||
if MODE == "dagger" and step_count % DAGGER_FLUSH_STEPS == 0 and DAGGER_LOG_OBS:
|
if MODE == "dagger" and step_count % DAGGER_FLUSH_STEPS == 0 and DAGGER_LOG_OBS:
|
||||||
_dump_dagger_log()
|
_dump_dagger_log()
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
#
|
#
|
||||||
# Env-var overrides:
|
# Env-var overrides:
|
||||||
# HERDING_POLICY_DIR : policy the controller loads (only used when
|
# HERDING_POLICY_DIR : policy the controller loads (only used when
|
||||||
# HERDING_DAGGER_DRIVER=student). Default bc_v3.
|
# HERDING_DAGGER_DRIVER=student). Default bc.
|
||||||
# HERDING_DAGGER_DRIVER : "teacher" (default) or "student".
|
# HERDING_DAGGER_DRIVER : "teacher" (default) or "student".
|
||||||
# HEADLESS=1 : force --no-rendering (default on).
|
# HEADLESS=1 : force --no-rendering (default on).
|
||||||
# FLOCKS="1 3 5 8 10" : space-separated flock sizes to iterate over.
|
# FLOCKS="1 3 5 8 10" : space-separated flock sizes to iterate over.
|
||||||
@@ -37,7 +37,7 @@ HEADLESS=${HEADLESS:-1}
|
|||||||
ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )"
|
ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )"
|
||||||
SRC="$ROOT/worlds/field.wbt"
|
SRC="$ROOT/worlds/field.wbt"
|
||||||
DST="$ROOT/worlds/field_test.wbt"
|
DST="$ROOT/worlds/field_test.wbt"
|
||||||
POLICY_DIR="${HERDING_POLICY_DIR:-$ROOT/training/runs/bc_v3}"
|
POLICY_DIR="${HERDING_POLICY_DIR:-$ROOT/training/runs/bc}"
|
||||||
DRIVER="${HERDING_DAGGER_DRIVER:-teacher}"
|
DRIVER="${HERDING_DAGGER_DRIVER:-teacher}"
|
||||||
DONE_FILE="$ROOT/training/dagger/.DONE"
|
DONE_FILE="$ROOT/training/dagger/.DONE"
|
||||||
WEBOTS_PID=""
|
WEBOTS_PID=""
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ where:
|
|||||||
* ``actions`` is the **active-scan-teacher action computed from
|
* ``actions`` is the **active-scan-teacher action computed from
|
||||||
ground-truth sheep positions** (read off the sheep emitter).
|
ground-truth sheep positions** (read off the sheep emitter).
|
||||||
|
|
||||||
Combined with the existing sim demos (``training/demos_v3.npz`` by
|
Combined with the existing sim demos (``training/demos.npz`` by
|
||||||
default), this gives the BC student a training set that includes the
|
default), this gives the BC student a training set that includes the
|
||||||
real Webots false-positive distribution — closing the sim-to-real
|
real Webots false-positive distribution — closing the sim-to-real
|
||||||
perception gap that the all-sim pipeline couldn't bridge.
|
perception gap that the all-sim pipeline couldn't bridge.
|
||||||
@@ -19,7 +19,7 @@ Usage::
|
|||||||
|
|
||||||
# Iteration 1 — merge all dagger files with sim demos, retrain
|
# Iteration 1 — merge all dagger files with sim demos, retrain
|
||||||
python -m tools.dagger_merge_train \\
|
python -m tools.dagger_merge_train \\
|
||||||
--sim training/demos_v3.npz \\
|
--sim training/demos.npz \\
|
||||||
--out training/runs/bc_dagger1
|
--out training/runs/bc_dagger1
|
||||||
|
|
||||||
# Iteration 2 — drop the sim baseline, train only on Webots data
|
# Iteration 2 — drop the sim baseline, train only on Webots data
|
||||||
@@ -48,7 +48,7 @@ import numpy as np
|
|||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--sim", default="training/demos_v3.npz",
|
parser.add_argument("--sim", default="training/demos.npz",
|
||||||
help="Sim demo file to mix with the Webots data. "
|
help="Sim demo file to mix with the Webots data. "
|
||||||
"Pass --no-sim to train only on dagger data.")
|
"Pass --no-sim to train only on dagger data.")
|
||||||
parser.add_argument("--no-sim", action="store_true",
|
parser.add_argument("--no-sim", action="store_true",
|
||||||
|
|||||||
+32
-4
@@ -17,7 +17,7 @@
|
|||||||
#
|
#
|
||||||
# Notes:
|
# Notes:
|
||||||
# * The RL mode loads the latest BC policy by default — priority
|
# * The RL mode loads the latest BC policy by default — priority
|
||||||
# bc_dagger_v2 → bc_dagger → bc_c2v3 (the controller resolves it).
|
# the BC policy (bc/policy.zip) (the controller resolves it).
|
||||||
# (LiDAR-perception, frame-stack K=4). Override via
|
# (LiDAR-perception, frame-stack K=4). Override via
|
||||||
# HERDING_POLICY_DIR=/path/to/run env var.
|
# HERDING_POLICY_DIR=/path/to/run env var.
|
||||||
# * Conda env "tir" must be active (provides stable-baselines3 + torch).
|
# * Conda env "tir" must be active (provides stable-baselines3 + torch).
|
||||||
@@ -50,12 +50,12 @@ echo "------------------------------------------------------------"
|
|||||||
echo "World : $DST"
|
echo "World : $DST"
|
||||||
echo "Mode : $MODE"
|
echo "Mode : $MODE"
|
||||||
echo "Sheep : $active active"
|
echo "Sheep : $active active"
|
||||||
echo "Policy dir : ${HERDING_POLICY_DIR:-$ROOT/training/runs/bc_v3}"
|
echo "Policy dir : ${HERDING_POLICY_DIR:-$ROOT/training/runs/bc}"
|
||||||
echo "------------------------------------------------------------"
|
echo "------------------------------------------------------------"
|
||||||
|
|
||||||
# Webots strips HERDING_* env vars from controller subprocesses in some
|
# Webots strips HERDING_* env vars from controller subprocesses in some
|
||||||
# setups, so we also write a runtime config file the controller reads.
|
# setups, so we also write a runtime config file the controller reads.
|
||||||
RESOLVED_POLICY_DIR="${HERDING_POLICY_DIR:-$ROOT/training/runs/bc_v3}"
|
RESOLVED_POLICY_DIR="${HERDING_POLICY_DIR:-$ROOT/training/runs/bc}"
|
||||||
cat > "$ROOT/herding_runtime.cfg" <<EOF
|
cat > "$ROOT/herding_runtime.cfg" <<EOF
|
||||||
HERDING_MODE=$MODE
|
HERDING_MODE=$MODE
|
||||||
HERDING_POLICY_DIR=$RESOLVED_POLICY_DIR
|
HERDING_POLICY_DIR=$RESOLVED_POLICY_DIR
|
||||||
@@ -65,4 +65,32 @@ EOF
|
|||||||
export HERDING_MODE="$MODE"
|
export HERDING_MODE="$MODE"
|
||||||
export HERDING_POLICY_DIR="$RESOLVED_POLICY_DIR"
|
export HERDING_POLICY_DIR="$RESOLVED_POLICY_DIR"
|
||||||
|
|
||||||
exec webots "$DST"
|
# The controller writes this sentinel when all GT sheep are penned. We
|
||||||
|
# poll for it and kill Webots so the run finishes cleanly instead of
|
||||||
|
# idling for minutes after the task is done.
|
||||||
|
DONE_FILE="$ROOT/training/dagger/.DONE"
|
||||||
|
mkdir -p "$(dirname "$DONE_FILE")"
|
||||||
|
rm -f "$DONE_FILE"
|
||||||
|
|
||||||
|
webots "$DST" &
|
||||||
|
WEBOTS_PID=$!
|
||||||
|
|
||||||
|
cleanup() {
|
||||||
|
kill "$WEBOTS_PID" 2>/dev/null || true
|
||||||
|
wait "$WEBOTS_PID" 2>/dev/null || true
|
||||||
|
exit 0
|
||||||
|
}
|
||||||
|
trap cleanup INT TERM
|
||||||
|
|
||||||
|
# Poll for the sentinel; bail when Webots exits on its own or when the
|
||||||
|
# user closes the window.
|
||||||
|
while kill -0 "$WEBOTS_PID" 2>/dev/null; do
|
||||||
|
if [[ -f "$DONE_FILE" ]]; then
|
||||||
|
echo "[run_webots] all sheep penned — closing Webots"
|
||||||
|
sleep 1 # let the controller print its line
|
||||||
|
kill "$WEBOTS_PID" 2>/dev/null || true
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
wait "$WEBOTS_PID" 2>/dev/null || true
|
||||||
|
|||||||
+15
-12
@@ -7,12 +7,15 @@ policy that runs under LiDAR perception in Webots.
|
|||||||
sim demos (active-scan teacher on tracker output, K=4 frame stack)
|
sim demos (active-scan teacher on tracker output, K=4 frame stack)
|
||||||
│
|
│
|
||||||
▼
|
▼
|
||||||
bc_pretrain.py ──► runs/bc_v3 (deployed policy — beats Strömbom on n≥8)
|
bc_pretrain.py ──► runs/bc (BC baseline)
|
||||||
│
|
│
|
||||||
▼ (optional: tools/auto_dagger.sh + tools/dagger_merge_train.py
|
▼ KL-regularised PPO fine-tune (training/train_ppo.py)
|
||||||
│ if sim-trained doesn't transfer cleanly to Webots)
|
|
||||||
│
|
│
|
||||||
runs/bc_dagger
|
runs/rl (deployed `rl` mode)
|
||||||
|
|
||||||
|
# optional branch — kept for reference, not deployed:
|
||||||
|
runs/bc_dagger (Webots-grounded DAgger refinement, useful if a
|
||||||
|
modified world breaks sim-to-real transfer)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Files
|
## Files
|
||||||
@@ -42,14 +45,14 @@ rollout collection, not gradient compute.
|
|||||||
# 1. Sim demos with the active-scan + Strömbom teacher under LiDAR
|
# 1. Sim demos with the active-scan + Strömbom teacher under LiDAR
|
||||||
# perception. K=4 frame stack so the MLP has temporal context.
|
# perception. K=4 frame stack so the MLP has temporal context.
|
||||||
python -m tools.collect_demos --teacher strombom \
|
python -m tools.collect_demos --teacher strombom \
|
||||||
--out demos_v3.npz --seeds-per-n 15 --subsample 3 --frame-stack 4
|
--out demos.npz --seeds-per-n 15 --subsample 3 --frame-stack 4
|
||||||
|
|
||||||
# 2. Behavior-clone.
|
# 2. Behavior-clone.
|
||||||
python -m training.bc_pretrain --demos demos_v3.npz \
|
python -m training.bc_pretrain --demos demos.npz \
|
||||||
--out runs/bc_v3 --epochs 60 --net-arch 512,512
|
--out runs/bc --epochs 60 --net-arch 512,512
|
||||||
|
|
||||||
# 3. Evaluate.
|
# 3. Evaluate.
|
||||||
python -m training.eval --policy runs/bc_v3 \
|
python -m training.eval --policy runs/bc \
|
||||||
--max-flock 10 --max-steps 8000 --n-seeds 5
|
--max-flock 10 --max-steps 8000 --n-seeds 5
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -78,7 +81,7 @@ seat:
|
|||||||
HERDING_POLICY_DIR=$PWD/training/runs/bc_dagger \
|
HERDING_POLICY_DIR=$PWD/training/runs/bc_dagger \
|
||||||
HERDING_DAGGER_DRIVER=student \
|
HERDING_DAGGER_DRIVER=student \
|
||||||
tools/auto_dagger.sh 3 60
|
tools/auto_dagger.sh 3 60
|
||||||
python -m tools.dagger_merge_train --out runs/bc_dagger_v2
|
python -m tools.dagger_merge_train --out runs/bc_dagger
|
||||||
```
|
```
|
||||||
|
|
||||||
## Available analytic teachers
|
## Available analytic teachers
|
||||||
@@ -107,6 +110,6 @@ python -m training.eval --policy sequential --max-flock 10 --max-steps 8000 --n
|
|||||||
tools/run_webots.sh 10 rl
|
tools/run_webots.sh 10 rl
|
||||||
```
|
```
|
||||||
|
|
||||||
The dog controller loads the highest-priority policy that exists
|
The dog controller loads `runs/bc` for `bc` mode and `runs/rl` for
|
||||||
(`bc_dagger_v2` → `bc_dagger` → `bc_v3`). Override with
|
`rl` mode. Override with `HERDING_POLICY_DIR=…` for a specific
|
||||||
`HERDING_POLICY_DIR=…` if you want a specific checkpoint.
|
checkpoint.
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ Usage::
|
|||||||
|
|
||||||
python -m training.bc_pretrain \\
|
python -m training.bc_pretrain \\
|
||||||
--demos training/demos.npz \\
|
--demos training/demos.npz \\
|
||||||
--out training/runs/bc_flock
|
--out training/runs/bc
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -83,7 +83,7 @@ def policy_forward_mean(policy, obs_batch):
|
|||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--demos", default="training/demos.npz")
|
parser.add_argument("--demos", default="training/demos.npz")
|
||||||
parser.add_argument("--out", default="training/runs/bc_solo")
|
parser.add_argument("--out", default="training/runs/bc")
|
||||||
parser.add_argument("--epochs", type=int, default=60)
|
parser.add_argument("--epochs", type=int, default=60)
|
||||||
parser.add_argument("--batch-size", type=int, default=256)
|
parser.add_argument("--batch-size", type=int, default=256)
|
||||||
parser.add_argument("--lr", type=float, default=1e-3)
|
parser.add_argument("--lr", type=float, default=1e-3)
|
||||||
|
|||||||
@@ -204,6 +204,12 @@ class HerdingEnv(gym.Env):
|
|||||||
already mimics a stronger teacher (sequential)."""
|
already mimics a stronger teacher (sequential)."""
|
||||||
self.W_IMITATE = float(value)
|
self.W_IMITATE = float(value)
|
||||||
|
|
||||||
|
def set_time_weight(self, value: float) -> None:
|
||||||
|
"""Override W_TIME (instance-level). Default 0.0; a small
|
||||||
|
negative value (e.g. -0.1) adds a per-step penalty that
|
||||||
|
explicitly rewards fast time-to-pen during PPO fine-tune."""
|
||||||
|
self.W_TIME = float(value)
|
||||||
|
|
||||||
# ---- gym API ----
|
# ---- gym API ----
|
||||||
def reset(self, *, seed=None, options=None):
|
def reset(self, *, seed=None, options=None):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
@@ -431,6 +437,9 @@ class HerdingEnv(gym.Env):
|
|||||||
|
|
||||||
d_progress = max(-5.0, min(5.0, self.prev_d_pen - d_pen))
|
d_progress = max(-5.0, min(5.0, self.prev_d_pen - d_pen))
|
||||||
r = self.W_PEN_DELTA * delta_pen + self.W_PROGRESS * d_progress
|
r = self.W_PEN_DELTA * delta_pen + self.W_PROGRESS * d_progress
|
||||||
|
# Per-step time penalty (0 by default). When negative, encourages
|
||||||
|
# the policy to finish quickly — used during PPO fine-tune.
|
||||||
|
r += self.W_TIME
|
||||||
|
|
||||||
if action is not None and self.W_IMITATE > 0.0:
|
if action is not None and self.W_IMITATE > 0.0:
|
||||||
positions = self._perceived_positions()
|
positions = self._perceived_positions()
|
||||||
|
|||||||
Binary file not shown.
+32
-7
@@ -10,7 +10,7 @@ per-step reward signal does the rest.
|
|||||||
|
|
||||||
Pipeline
|
Pipeline
|
||||||
--------
|
--------
|
||||||
1. Load ``bc_v3`` weights into both the trainable policy and a frozen
|
1. Load ``bc`` weights into both the trainable policy and a frozen
|
||||||
reference ``ref_policy``.
|
reference ``ref_policy``.
|
||||||
2. Initialise the policy's log_std to a small fixed value (≈ −1.5)
|
2. Initialise the policy's log_std to a small fixed value (≈ −1.5)
|
||||||
and disable its gradient — exploration noise stays small so PPO
|
and disable its gradient — exploration noise stays small so PPO
|
||||||
@@ -19,14 +19,14 @@ Pipeline
|
|||||||
each minibatch.
|
each minibatch.
|
||||||
4. Train for ~1–3 M timesteps with a low LR (5e-5).
|
4. Train for ~1–3 M timesteps with a low LR (5e-5).
|
||||||
|
|
||||||
Output: ``runs/rl_v1/policy.zip`` — same SB3 format as bc_v3, loadable
|
Output: ``runs/rl/policy.zip`` — same SB3 format as bc, loadable
|
||||||
by the dog controller's ``HERDING_MODE=rl`` path.
|
by the dog controller's ``HERDING_MODE=rl`` path.
|
||||||
|
|
||||||
Usage::
|
Usage::
|
||||||
|
|
||||||
python -m training.train_ppo \\
|
python -m training.train_ppo \\
|
||||||
--bc training/runs/bc_v3 \\
|
--bc training/runs/bc \\
|
||||||
--out training/runs/rl_v1 \\
|
--out training/runs/rl \\
|
||||||
--total-timesteps 2000000
|
--total-timesteps 2000000
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -205,9 +205,9 @@ class KLPPO(PPO):
|
|||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--bc", default="training/runs/bc_v3",
|
parser.add_argument("--bc", default="training/runs/bc",
|
||||||
help="Directory containing the BC initialisation (policy.zip).")
|
help="Directory containing the BC initialisation (policy.zip).")
|
||||||
parser.add_argument("--out", default="training/runs/rl_v1",
|
parser.add_argument("--out", default="training/runs/rl",
|
||||||
help="Where to save the fine-tuned policy.")
|
help="Where to save the fine-tuned policy.")
|
||||||
parser.add_argument("--total-timesteps", type=int, default=2_000_000)
|
parser.add_argument("--total-timesteps", type=int, default=2_000_000)
|
||||||
parser.add_argument("--n-envs", type=int, default=8)
|
parser.add_argument("--n-envs", type=int, default=8)
|
||||||
@@ -232,12 +232,23 @@ def main() -> None:
|
|||||||
help="SB3's per-batch KL early stop; safety belt.")
|
help="SB3's per-batch KL early stop; safety belt.")
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
parser.add_argument("--device", default="cpu")
|
parser.add_argument("--device", default="cpu")
|
||||||
|
parser.add_argument("--imitate-weight", type=float, default=None,
|
||||||
|
help="Override env.W_IMITATE for this training "
|
||||||
|
"run. Set to 0.0 to drop the Strömbom "
|
||||||
|
"cosine-imitation reward — useful during "
|
||||||
|
"PPO refinement where you want reward, "
|
||||||
|
"not teacher imitation, to drive updates.")
|
||||||
|
parser.add_argument("--time-weight", type=float, default=None,
|
||||||
|
help="Override env.W_TIME. Default env value is "
|
||||||
|
"0.0; setting e.g. -0.1 adds a small per-"
|
||||||
|
"step penalty that explicitly rewards "
|
||||||
|
"fast time-to-pen.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
bc_zip = Path(args.bc) / "policy.zip"
|
bc_zip = Path(args.bc) / "policy.zip"
|
||||||
if not bc_zip.exists():
|
if not bc_zip.exists():
|
||||||
raise SystemExit(
|
raise SystemExit(
|
||||||
f"BC checkpoint not found at {bc_zip}. Train bc_v3 first with "
|
f"BC checkpoint not found at {bc_zip}. Train bc first with "
|
||||||
f"`python -m training.bc_pretrain`."
|
f"`python -m training.bc_pretrain`."
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -259,6 +270,20 @@ def main() -> None:
|
|||||||
venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns)
|
venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns)
|
||||||
eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, frame_stack)])
|
eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, frame_stack)])
|
||||||
|
|
||||||
|
# --- Apply reward-shaping overrides to every env instance ---
|
||||||
|
def _broadcast(method: str, value):
|
||||||
|
for v in (venv, eval_venv):
|
||||||
|
try:
|
||||||
|
v.env_method(method, value)
|
||||||
|
except AttributeError:
|
||||||
|
v.venv.env_method(method, value)
|
||||||
|
if args.imitate_weight is not None:
|
||||||
|
_broadcast("set_imitate_weight", args.imitate_weight)
|
||||||
|
print(f"[rl] W_IMITATE overridden to {args.imitate_weight}")
|
||||||
|
if args.time_weight is not None:
|
||||||
|
_broadcast("set_time_weight", args.time_weight)
|
||||||
|
print(f"[rl] W_TIME overridden to {args.time_weight}")
|
||||||
|
|
||||||
# --- Trainable policy: load BC weights, then bolt onto PPO ---
|
# --- Trainable policy: load BC weights, then bolt onto PPO ---
|
||||||
# Trick: instantiate a PPO with the right env (so the policy
|
# Trick: instantiate a PPO with the right env (so the policy
|
||||||
# network is constructed at the correct obs/action shape), then
|
# network is constructed at the correct obs/action shape), then
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
Webots Project File version R2025a
|
|
||||||
perspectives: 000000ff00000000fd00000002000000010000011c00000405fc0200000001fb0000001400540065007800740045006400690074006f00720100000000000004050000003f00ffffff00000003000007c500000092fc0100000001fb0000001a0043006f006e0073006f006c00650041006c006c0041006c006c0100000000000007c50000006900ffffff000006a70000040500000001000000020000000100000008fc00000000
|
|
||||||
simulationViewPerspectives: 000000ff000000010000000200000100000003a80100000002010000000100
|
|
||||||
sceneTreePerspectives: 000000ff00000001000000030000001f0000018b000000fa0100000002010000000200
|
|
||||||
maximizedDockId: -1
|
|
||||||
centralWidgetVisible: 1
|
|
||||||
orthographicViewHeight: 1
|
|
||||||
textFiles: -1
|
|
||||||
consoles: Console:All:All
|
|
||||||
Reference in New Issue
Block a user