# Training pipeline

Behavior cloning of analytic herding teachers into a neural-network policy that runs under LiDAR perception in Webots.

```
sim demos  (active-scan teacher on tracker output, K=4 frame stack)
    │
    ▼
bc_pretrain.py  ──►  runs/bc  (BC baseline)
                     │
                     │  KL-regularised PPO fine-tune (training/train_ppo.py)
                     ▼
                     runs/rl  (deployed `rl` mode)

# optional branch, kept for reference but not deployed:
runs/bc_dagger  (Webots-grounded DAgger refinement; useful if a
                 modified world breaks sim-to-real transfer)
```

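The fine-tune stage lets PPO improve on the BC policy while a KL term keeps it close to that starting point. A minimal per-sample sketch of one common form of this objective, assuming diagonal Gaussian action distributions, a KL(pi || pi_bc) penalty added to the loss, and a hypothetical weight `beta`. `train_ppo.py` may use a different KL direction, an adaptive weight, or a reward-side penalty instead.

```python
import math


def gaussian_kl(mu_p, std_p, mu_q, std_q):
    """KL(p || q) between two diagonal Gaussians, summed over action dimensions."""
    kl = 0.0
    for mp, sp, mq, sq in zip(mu_p, std_p, mu_q, std_q):
        kl += math.log(sq / sp) + (sp ** 2 + (mp - mq) ** 2) / (2 * sq ** 2) - 0.5
    return kl


def kl_regularised_ppo_loss(ratio, advantage, kl_to_bc, clip_eps=0.2, beta=0.1):
    """Per-sample loss to minimise: -(PPO clipped surrogate) + beta * KL(pi || pi_bc).

    ratio     = pi_new(a|s) / pi_old(a|s) for the sampled action
    advantage = advantage estimate for that (s, a)
    kl_to_bc  = KL between the current policy and the frozen BC policy at s
    (beta is a hypothetical fixed weight, not a value taken from train_ppo.py)
    """
    clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    surrogate = min(ratio * advantage, clipped_ratio * advantage)
    return -surrogate + beta * kl_to_bc


# Example: the current policy mean has drifted 0.1 from the BC mean on one action axis.
kl = gaussian_kl([0.1, 0.0], [0.5, 0.5], [0.0, 0.0], [0.5, 0.5])
loss = kl_regularised_ppo_loss(ratio=1.5, advantage=2.0, kl_to_bc=kl)
```

With `beta=0` this reduces to the plain PPO clipped objective; raising `beta` trades return improvement for staying closer to the BC behaviour.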
## Files

```
herding_env.py   Gymnasium env (LiDAR raycast + tracker by default)
bc_pretrain.py   MSE + cosine BC of (obs, action) demos into an MlpPolicy
eval.py          analytic teachers + BC policies over the full n=1..10 flock-size grid
parity_test.py   shape / determinism / baseline smoke test
runs/            checkpoints (most are .gitignored; the deployed ones are
                 whitelisted in the top-level .gitignore)
```

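The `bc_pretrain.py` entry names an MSE + cosine objective. A minimal plain-Python sketch of such a loss, assuming the two terms are simply summed per sample with a hypothetical `cos_weight`; the real script's weighting and batching may differ.

```python
import math


def mse(pred, target):
    """Mean squared error between two action vectors."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def cosine_similarity(a, b, eps=1e-8):
    """Cosine of the angle between two vectors (eps guards against zero vectors)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / max(norm_a * norm_b, eps)


def bc_loss(pred_actions, teacher_actions, cos_weight=1.0):
    """Batch mean of MSE + cos_weight * (1 - cosine similarity); cos_weight is hypothetical."""
    total = 0.0
    for pred, target in zip(pred_actions, teacher_actions):
        total += mse(pred, target) + cos_weight * (1.0 - cosine_similarity(pred, target))
    return total / len(pred_actions)


# A perfect prediction costs nothing; a 90-degree heading error is penalised by both terms.
print(bc_loss([[1.0, 0.0]], [[1.0, 0.0]]))  # 0.0
print(bc_loss([[0.0, 1.0]], [[1.0, 0.0]]))  # 2.0
```

The cosine term scores heading agreement independently of action magnitude, which MSE alone under-weights when the teacher's actions are small.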
## Setup

```
pip install -r requirements.txt
```

CPU is the default and recommended device. SB3 PPO with an MLP policy of this size runs faster on CPU than on GPU, because the bottleneck is rollout collection, not gradient computation.

## The BC pipeline

```
# 1. Sim demos with the active-scan + Strömbom teacher under LiDAR
#    perception. K=4 frame stack so the MLP has temporal context.
python -m tools.collect_demos --teacher strombom \
    --out demos.npz --seeds-per-n 15 --subsample 3 --frame-stack 4

# 2. Behavior-clone.
python -m training.bc_pretrain --demos demos.npz \
    --out runs/bc --epochs 60 --net-arch 512,512

# 3. Evaluate.
python -m training.eval --policy runs/bc \
    --max-flock 10 --max-steps 8000 --n-seeds 5
```

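`--frame-stack 4` gives the MLP temporal context by concatenating the last K observations into one input vector. A minimal sketch of that bookkeeping, assuming oldest-first ordering and padding the history with copies of the first observation after a reset; both choices are assumptions and the env's actual layout may differ.

```python
from collections import deque


class FrameStack:
    """Concatenate the last k observations (oldest first) into one flat vector."""

    def __init__(self, k=4):
        self.k = k
        self.frames = deque(maxlen=k)

    def reset(self, obs):
        # Pad the history with copies of the first observation (assumed padding rule).
        self.frames.clear()
        for _ in range(self.k):
            self.frames.append(list(obs))
        return self._stacked()

    def step(self, obs):
        self.frames.append(list(obs))  # deque(maxlen=k) drops the oldest frame itself
        return self._stacked()

    def _stacked(self):
        return [x for frame in self.frames for x in frame]


stacker = FrameStack(k=4)
print(stacker.reset([0.0, 1.0]))  # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
print(stacker.step([2.0, 3.0]))   # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 2.0, 3.0]
```

Whatever layout is used, demo collection and inference must stack frames identically, or the policy sees inputs it was never trained on.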

`bc_pretrain.py` saves the **best-val_cos** snapshot (highest validation cosine similarity between predicted and teacher actions), not the final epoch: multi-modal teachers make training noisy, and the last epoch is often worse than an earlier one.

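The best-snapshot rule amounts to tracking the maximum validation score across epochs and copying the weights whenever it improves. A minimal framework-agnostic sketch; `train_one_epoch`, `evaluate_val_cos`, and `get_state` are hypothetical callbacks, not functions from `bc_pretrain.py`.

```python
import copy


def train_with_best_snapshot(epochs, train_one_epoch, evaluate_val_cos, get_state):
    """Train for `epochs` epochs and return (best_epoch, best_val_cos, best_state)."""
    best_epoch, best_cos, best_state = -1, float("-inf"), None
    for epoch in range(epochs):
        train_one_epoch()
        val_cos = evaluate_val_cos()
        if val_cos > best_cos:
            # Deep-copy so later training steps cannot overwrite the saved snapshot.
            best_epoch, best_cos, best_state = epoch, val_cos, copy.deepcopy(get_state())
    return best_epoch, best_cos, best_state


# Toy run: validation cosine peaks at epoch 1, then degrades.
scores = iter([0.50, 0.82, 0.77, 0.61])
weights = {"updates": 0}


def fake_epoch():
    weights["updates"] += 1


result = train_with_best_snapshot(4, fake_epoch, lambda: next(scores), lambda: weights)
print(result)  # (1, 0.82, {'updates': 2})
```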
## DAgger from Webots

Sim-only BC plateaus because the env's 2D raycast can't reproduce all the false-positive clusters that Webots generates from real geometry. The fix is to collect (obs, teacher_action) pairs from inside Webots:

```
# Headless DAgger collection: 5 flock sizes × 3 runs each.
tools/auto_dagger.sh 3 60

# Merge with the sim baseline + retrain.
python -m tools.dagger_merge_train --out runs/bc_dagger
```

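Merging the Webots data with the sim baseline is DAgger's dataset aggregation: new (obs, teacher_action) pairs are appended to the existing ones and the student is retrained on the union. A minimal sketch with plain lists, assuming each demo set is a dict with hypothetical `obs` and `act` keys; the real `.npz` layout used by `tools.dagger_merge_train` may differ. The width check catches one plausible failure: Webots observations collected with a different frame-stack size than the sim demos.

```python
def merge_demo_sets(*demo_sets):
    """Concatenate (obs, act) demo sets into one, checking that widths agree."""
    merged = {"obs": [], "act": []}
    obs_dim = act_dim = None
    for demos in demo_sets:
        if len(demos["obs"]) != len(demos["act"]):
            raise ValueError("obs/act length mismatch within a demo set")
        for obs, act in zip(demos["obs"], demos["act"]):
            if obs_dim is None:
                obs_dim, act_dim = len(obs), len(act)
            if len(obs) != obs_dim or len(act) != act_dim:
                raise ValueError("inconsistent obs/action width (frame-stack mismatch?)")
            merged["obs"].append(list(obs))
            merged["act"].append(list(act))
    return merged


sim = {"obs": [[0.0, 1.0], [1.0, 2.0]], "act": [[1.0, 0.0], [0.0, 1.0]]}
webots = {"obs": [[3.0, 4.0]], "act": [[1.0, 1.0]]}
merged = merge_demo_sets(sim, webots)
print(len(merged["obs"]))  # 3
```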
Iterate by re-running collection with the new student in the driver's seat:

```
HERDING_POLICY_DIR=$PWD/training/runs/bc_dagger \
HERDING_DAGGER_DRIVER=student \
tools/auto_dagger.sh 3 60
python -m tools.dagger_merge_train --out runs/bc_dagger
```

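Switching the driver changes only which action is executed; the teacher still labels every visited observation, which is what lets DAgger gather corrective labels in the states the student actually reaches. A minimal sketch with a toy 1D environment; `dagger_rollout` and its callback signatures are hypothetical, and reading `HERDING_DAGGER_DRIVER=student` as this switch is an interpretation of the command above.

```python
def dagger_rollout(env_reset, env_step, teacher, student, driver="teacher", max_steps=100):
    """Collect (obs, teacher_action) pairs while `driver` chooses the executed action."""
    if driver not in ("teacher", "student"):
        raise ValueError(f"unknown driver {driver!r}")
    obs = env_reset()
    pairs = []
    for _ in range(max_steps):
        label = teacher(obs)  # the teacher labels every visited state
        pairs.append((obs, label))
        action = label if driver == "teacher" else student(obs)
        obs, done = env_step(action)
        if done:
            break
    return pairs


# Toy 1D env: the goal is to bring x to 0. The student only makes half the correction.
state = {"x": 5.0}


def reset():
    state["x"] = 5.0
    return state["x"]


def step(action):
    state["x"] += action
    return state["x"], abs(state["x"]) < 1e-9


def teacher(x):
    return -x


def student(x):
    return -0.5 * x


print(dagger_rollout(reset, step, teacher, student, driver="teacher"))
# [(5.0, -5.0)]
print(dagger_rollout(reset, step, teacher, student, driver="student", max_steps=3))
# [(5.0, -5.0), (2.5, -2.5), (1.25, -1.25)]
```

The student-driven run visits states (2.5, 1.25) that the teacher-driven run never reaches, and each one comes with a teacher label.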
## Available analytic teachers

| Name | What it does | Notes |
|---|---|---|
| `strombom` | Canonical Strömbom: collect when the flock is scattered, drive the centre of mass (CoM) otherwise | Default; works well for n=1–10 under tight cohesion |
| `sequential` | Pick the sheep closest to the pen and drive only that one | Alternative; needs a loose-cohesion regime |

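The two teachers differ mainly in how they pick the dog's target point. A minimal 2D sketch, assuming the commonly cited Strömbom form: a scatter threshold of `r_a * n**(2/3)` around the flock's centre of mass, a collecting point `r_a` behind the furthest sheep, and a driving point `r_a * sqrt(n)` behind the centre of mass. `r_a` and its default are assumptions, not this repo's settings.

```python
import math


def centroid(points):
    n = len(points)
    return (sum(p[0] for p in points) / n, sum(p[1] for p in points) / n)


def point_behind(target, away_from, standoff):
    """Point `standoff` beyond `target`, on the side opposite `away_from`."""
    dx, dy = target[0] - away_from[0], target[1] - away_from[1]
    d = math.hypot(dx, dy) or 1.0  # avoid division by zero if the points coincide
    return (target[0] + standoff * dx / d, target[1] + standoff * dy / d)


def strombom_target(sheep, goal, r_a=2.0):
    """Return ("collect" | "drive", target point) for the dog."""
    n = len(sheep)
    com = centroid(sheep)
    furthest = max(sheep, key=lambda s: math.dist(s, com))
    if math.dist(furthest, com) > r_a * n ** (2 / 3):
        # Scattered flock: get behind the furthest sheep and push it toward the CoM.
        return "collect", point_behind(furthest, com, r_a)
    # Cohesive flock: get behind the CoM and push the whole flock toward the goal.
    return "drive", point_behind(com, goal, r_a * math.sqrt(n))


def sequential_target(sheep, goal, r_a=2.0):
    """Drive only the sheep closest to the pen."""
    nearest = min(sheep, key=lambda s: math.dist(s, goal))
    return "drive_one", point_behind(nearest, goal, r_a)


print(strombom_target([(0, 0), (0, 0), (20, 0)], goal=(50, 0)))  # ('collect', (22.0, 0.0))
print(sequential_target([(0, 0), (8, 0)], goal=(10, 0)))         # ('drive_one', (6.0, 0.0))
```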
Both are wrapped at demo-collection time in `herding/active_scan.py:ActiveScanTeacher`, which adds an opening in-place rotation, a walk-to-centre fallback when the LiDAR sees nothing, and near-sheep speed modulation (the same modulation `herding/control.py` applies to every dog mode at inference).

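The three behaviours added by `ActiveScanTeacher` can be read as a small priority rule wrapped around the teacher. A minimal sketch in which the class name, step count, radii, and the linear speed ramp are all hypothetical; it is not the code in `herding/active_scan.py` or `herding/control.py`.

```python
import math


class ActiveScanWrapper:
    """Hypothetical wrapper: scan first, fall back to the arena centre, slow down near sheep."""

    def __init__(self, teacher, arena_centre=(0.0, 0.0), scan_steps=20,
                 slow_radius=1.5, min_speed_scale=0.3):
        self.teacher = teacher  # maps a list of sheep detections to a target point
        self.arena_centre = arena_centre
        self.scan_steps = scan_steps
        self.slow_radius = slow_radius
        self.min_speed_scale = min_speed_scale
        self.steps = 0

    def speed_scale(self, nearest_dist):
        """Scale speed linearly with distance to the nearest sheep, clamped to [min, 1.0]."""
        if nearest_dist >= self.slow_radius:
            return 1.0
        return max(self.min_speed_scale, nearest_dist / self.slow_radius)

    def act(self, dog_pos, detections):
        """Return (mode, target_point, speed_scale)."""
        self.steps += 1
        if self.steps <= self.scan_steps:
            return "scan", None, 0.0  # opening in-place rotation: no forward motion yet
        if not detections:
            return "to_centre", self.arena_centre, 1.0
        nearest = min(math.dist(dog_pos, s) for s in detections)
        return "teacher", self.teacher(detections), self.speed_scale(nearest)


wrapper = ActiveScanWrapper(teacher=lambda dets: dets[0], scan_steps=2)
print(wrapper.act((0.0, 0.0), []))             # ('scan', None, 0.0)
print(wrapper.act((0.0, 0.0), []))             # ('scan', None, 0.0)
print(wrapper.act((0.0, 0.0), []))             # ('to_centre', (0.0, 0.0), 1.0)
print(wrapper.act((0.0, 0.0), [(0.75, 0.0)]))  # ('teacher', (0.75, 0.0), 0.5)
```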
## Evaluating analytic teachers directly

```
python -m training.eval --policy strombom --max-flock 10 --max-steps 8000 --n-seeds 5
python -m training.eval --policy sequential --max-flock 10 --max-steps 8000 --n-seeds 5
```

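The evaluation sweeps flock sizes n=1..`--max-flock`, running `--n-seeds` episodes per size, each capped at `--max-steps`. A minimal sketch of that grid and a per-n summary, assuming a hypothetical `run_episode(n, seed, max_steps)` callback that returns `(success, steps)`; the metrics `eval.py` actually reports may differ.

```python
def evaluate_grid(run_episode, max_flock=10, n_seeds=5, max_steps=8000):
    """Run every (flock size, seed) pair; summarise success rate and mean steps per n."""
    summary = {}
    for n in range(1, max_flock + 1):
        outcomes = [run_episode(n, seed, max_steps) for seed in range(n_seeds)]
        success_steps = [steps for ok, steps in outcomes if ok]
        summary[n] = {
            "success_rate": len(success_steps) / n_seeds,
            # None when no episode succeeded, so failures don't skew the average.
            "mean_steps_on_success": (
                sum(success_steps) / len(success_steps) if success_steps else None
            ),
        }
    return summary


# Toy episode: even seeds succeed, and larger flocks take longer.
def fake_episode(n, seed, max_steps):
    return seed % 2 == 0, min(100 * n, max_steps)


grid = evaluate_grid(fake_episode)
print(grid[3])  # {'success_rate': 0.6, 'mean_steps_on_success': 300.0}
```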
## Webots inference

```
tools/run_webots.sh 10 rl
```

The dog controller loads `runs/bc` for `bc` mode and `runs/rl` for `rl` mode. Override with `HERDING_POLICY_DIR=…` to load a specific checkpoint.

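The checkpoint lookup reduces to a mode-to-directory table with an environment-variable override. A minimal sketch, assuming the relative `runs/` paths are resolved against the `training/` directory (as the `$PWD/training/runs/bc_dagger` path in the DAgger example suggests); `resolve_policy_dir` is hypothetical and the dog controller's real lookup may differ.

```python
import os

# Mode -> checkpoint directory, relative to the training/ directory (assumption).
MODE_DIRS = {"bc": os.path.join("runs", "bc"), "rl": os.path.join("runs", "rl")}


def resolve_policy_dir(mode, training_dir, env=None):
    """Return the checkpoint directory for `mode`, honouring HERDING_POLICY_DIR."""
    env = os.environ if env is None else env
    override = env.get("HERDING_POLICY_DIR")
    if override:
        return override  # an explicit checkpoint always wins
    if mode not in MODE_DIRS:
        raise ValueError(f"no default checkpoint for mode {mode!r}")
    return os.path.join(training_dir, MODE_DIRS[mode])


print(resolve_policy_dir("rl", "/repo/training", env={}))  # /repo/training/runs/rl on POSIX
print(resolve_policy_dir("rl", "/repo/training",
                         env={"HERDING_POLICY_DIR": "/repo/training/runs/bc_dagger"}))
```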