Checkpoint 5 - incomplete

Johnny Fernandes
2026-05-11 10:35:39 +01:00
parent 6688325d89
commit b457155538
13 changed files with 174 additions and 74 deletions
+3
@@ -21,6 +21,9 @@ training/runs/*/evals/
training/runs/*/best/ training/runs/*/best/
!training/runs/.gitkeep !training/runs/.gitkeep
!training/runs/bc_v3/policy.zip !training/runs/bc_v3/policy.zip
!training/runs/rl_v1/policy.zip
!training/runs/rl_v2/policy.zip
!training/runs/rl_v2/best/best_model.zip
# Webots launcher scratch # Webots launcher scratch
worlds/field_test.wbt worlds/field_test.wbt
+53 -19
@@ -27,6 +27,15 @@ control step:
positions for sheep currently outside the FOV positions for sheep currently outside the FOV
(`herding/sheep_tracker.py`). (`herding/sheep_tracker.py`).
**LiDAR validation** (intermediate-goal item v from `docs/project.md`):
run the dog controller in `HERDING_MODE=diag` mode to capture 80
real Webots scans plus the ground-truth sheep positions in
`training/dagger/diag_<ts>.npz`. Comparing detections against GT in
that file showed clustered centroids match GT positions within 0.15 m
after the +SHEEP_RADIUS surface-to-centre correction — i.e. the
LiDAR pipeline produces correct sheep-position estimates from the
real Webots scan, validating the sensor for the herding task.
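A minimal sketch of the comparison described above, assuming detections and ground truth are available as `(N, 2)` and `(M, 2)` arrays of x/y positions. The helper names, the `SHEEP_RADIUS` value, and the choice to apply the correction along the ray from the sensor are illustrative assumptions, not taken from the repository or from the layout of the `.npz` file.

```python
import numpy as np

SHEEP_RADIUS = 0.3  # hypothetical value in metres; the project defines its own constant


def surface_to_centre(points, sensor_xy, radius=SHEEP_RADIUS):
    """Push each surface-hit centroid `radius` metres further along the sensor ray."""
    points = np.asarray(points, dtype=float)
    rays = points - np.asarray(sensor_xy, dtype=float)
    dist = np.linalg.norm(rays, axis=1, keepdims=True)
    return points + radius * rays / np.maximum(dist, 1e-9)


def nearest_gt_errors(detections, ground_truth):
    """Distance from each detection to its nearest ground-truth sheep."""
    det = np.asarray(detections, dtype=float)[:, None, :]
    gt = np.asarray(ground_truth, dtype=float)[None, :, :]
    return np.linalg.norm(det - gt, axis=2).min(axis=1)


sensor = (0.0, 0.0)
gt = np.array([[4.0, 0.0], [0.0, 6.0]])
# Synthetic surface centroids: each sits SHEEP_RADIUS short of the true centre.
surface = np.array([[4.0 - SHEEP_RADIUS, 0.0], [0.0, 6.0 - SHEEP_RADIUS]])
errors = nearest_gt_errors(surface_to_centre(surface, sensor), gt)
print(errors.max() < 0.15)  # the 0.15 m tolerance quoted above
```

Nearest-neighbour matching is the simplest choice here; with closely packed sheep a one-to-one assignment would be a safer way to pair detections with ground truth.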
The tracker outputs a `{name: (x, y)}` dict shaped exactly like the The tracker outputs a `{name: (x, y)}` dict shaped exactly like the
prior receiver-based one, so Strömbom, Sequential, and the BC obs prior receiver-based one, so Strömbom, Sequential, and the BC obs
builder all run unchanged on top of it. The 2D Gymnasium env builder all run unchanged on top of it. The 2D Gymnasium env
@@ -48,22 +57,22 @@ python -m training.parity_test
# 3. Reproduce the BC policy (~10 min on CPU: ~5 min demos + ~3 min BC) # 3. Reproduce the BC policy (~10 min on CPU: ~5 min demos + ~3 min BC)
python -m tools.collect_demos --teacher strombom \ python -m tools.collect_demos --teacher strombom \
--out training/demos_v3.npz --seeds-per-n 15 --subsample 3 --frame-stack 4 --out training/demos.npz --seeds-per-n 15 --subsample 3 --frame-stack 4
python -m training.bc_pretrain --demos training/demos_v3.npz \ python -m training.bc_pretrain --demos training/demos.npz \
--out training/runs/bc_v3 --epochs 60 --net-arch 512,512 --out training/runs/bc --epochs 60 --net-arch 512,512
# 4. Optional: DAgger from inside Webots if sim-trained doesn't transfer # 4. Optional: DAgger from inside Webots if sim-trained doesn't transfer
tools/auto_dagger.sh 3 60 tools/auto_dagger.sh 3 60
python -m tools.dagger_merge_train --out training/runs/bc_dagger python -m tools.dagger_merge_train --out training/runs/bc_dagger
# 5. Evaluate (env) # 5. Evaluate (env)
python -m training.eval --policy training/runs/bc_v3 \ python -m training.eval --policy training/runs/bc \
--max-flock 10 --max-steps 8000 --n-seeds 5 --max-flock 10 --max-steps 8000 --n-seeds 5
# 6. Optional RL fine-tune of the BC policy (~40 min on CPU, 1 M steps) # 6. Optional RL fine-tune of the BC policy (~40 min on CPU, 1 M steps)
python -m training.train_ppo \ python -m training.train_ppo \
--bc training/runs/bc_v3 \ --bc training/runs/bc \
--out training/runs/rl_v1 \ --out training/runs/rl \
--total-timesteps 1000000 --total-timesteps 1000000
# 7. Run in Webots # 7. Run in Webots
@@ -127,23 +136,48 @@ scattering the flock. Direction (intent) is preserved.
All modes also share the same EMA action smoother in All modes also share the same EMA action smoother in
`controllers/shepherd_dog/shepherd_dog.py:ACTION_SMOOTH = 0.55`. `controllers/shepherd_dog/shepherd_dog.py:ACTION_SMOOTH = 0.55`.
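As a rough illustration of what an EMA action smoother with `ACTION_SMOOTH = 0.55` does, the sketch below blends the previous smoothed action with the new raw one. It assumes the constant weights the previous output and that actions are 2-D vectors; the controller may use the opposite convention or a different action shape, and `smooth_action` is a hypothetical name.

```python
import numpy as np

ACTION_SMOOTH = 0.55  # value quoted above


def smooth_action(prev, raw, alpha=ACTION_SMOOTH):
    """Exponential moving average: alpha weights the previous output (assumed convention)."""
    prev = np.asarray(prev, dtype=float)
    raw = np.asarray(raw, dtype=float)
    return alpha * prev + (1.0 - alpha) * raw


# Step response: the raw command jumps from 0 to 1 on the first axis.
smoothed = np.zeros(2)
for _ in range(3):
    smoothed = smooth_action(smoothed, np.array([1.0, 0.0]))
print(smoothed)
```

Under this convention the smoothed command covers a fraction `1 - alpha**k` of a step change after `k` control steps, so a larger `alpha` means slower but less jittery steering.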
## Webots results (steps to all-penned, fast mode) ## Results — env eval, 10 seeds × n=1..10
Single seed per cell using `worlds/field.wbt` defaults. All modes hit `max_steps=15000`, full-field spawn distribution. Success rate per
100 % pen rate; numbers shown are time-to-all-penned in simulation flock size, then mean steps over successful seeds.
steps (16 ms each).
| n | Strömbom | `bc` | `rl` (KL-PPO of `bc`) | ### Success rate (%)
| n | Strömbom | `bc` | `rl` |
|---:|---:|---:|---:| |---:|---:|---:|---:|
| 3 | 5 800 | 9 800 | **4 800** | | 1 | 30 | 80 | **90** |
| 5 | 10 200 | 9 200 | 9 800 | | 2 | 90 | 50 | **90** |
| 8 | 14 000 | 17 600 | **15 400** | | 3 | 60 | 90 | **90** |
| 10 | 18 600 | 19 600 | **12 000** | | 4 | 40 | 80 | **90** |
| 5 | 60 | 70 | **100** |
| 6 | 30 | 80 | 80 |
| 7 | 70 | 80 | **100** |
| 8 | 30 | 100 | **100** |
| 9 | 40 | 90 | **100** |
| 10 | 50 | 100 | **100** |
The RL fine-tune is **39 % faster than `bc` on n=10** and **51 % faster ### Mean penned per episode (out of n)
on n=3**, confirming the KL-anchored PPO actually finds reward-driven
improvements over the BC imitation baseline rather than just collapsing | n | Strömbom | `bc` | `rl` |
back to it. |---:|---:|---:|---:|
| 1 | 0.30 | 0.80 | **0.90** |
| 5 | 3.90 | 4.10 | **5.00** |
| 8 | 4.20 | 8.00 | **8.00** |
| 10 | 7.40 | 10.00 | **10.00** |
### Takeaways
- **BC beats Strömbom at every flock size except n=2** (where Strömbom
  reaches 90 % against BC's 50 %) under realistic LiDAR conditions (full
  field, partial observability). Strömbom struggles on small flocks
  where a single sheep can spawn beyond the LiDAR's 12 m range; BC
  learned active perception from the demos.
- **RL refines BC** without regressing on any cell: it ties or beats BC
  at every flock size. The biggest success-rate gains are at n=2
  (50 → 90 %) and n=5 (70 → 100 %), where BC's imitation of Strömbom's
  drive heuristic was sub-optimal.
- **Aggressive reward shaping doesn't help** — a more aggressive
variant (β=0.02, W_TIME=-0.1, W_IMITATE=0, 3 M steps) trained as
an ablation was strictly worse than the conservative tune shipped
here (β=0.05, W_IMITATE=0.5, 1 M steps).
## License ## License
+22 -15
@@ -8,11 +8,11 @@ env vars on some setups):
sequential → single-target "pin and push" — drives the sheep sequential → single-target "pin and push" — drives the sheep
closest to the pen. closest to the pen.
bc → behaviour-cloned MLP, trained on Strömbom demos via bc → behaviour-cloned MLP, trained on Strömbom demos via
sim. Default policy directory: training/runs/bc_v3. sim. Default policy directory: training/runs/bc.
rl → KL-regularised PPO fine-tune of the BC policy. Same rl → KL-regularised PPO fine-tune of the BC policy. Same
obs/action space as bc; refines time-to-pen via obs/action space as bc; refines time-to-pen via
environment reward while staying anchored to bc. environment reward while staying anchored to bc.
Default policy directory: training/runs/rl_v1. Default policy directory: training/runs/rl.
dagger → DAgger data collection. Reads sheep ground-truth dagger → DAgger data collection. Reads sheep ground-truth
via the receiver, computes the active-scan teacher's via the receiver, computes the active-scan teacher's
recommended action at every step, drives with either recommended action at every step, drives with either
@@ -122,9 +122,9 @@ def _resolve_policy_dir(mode: str) -> str:
1. HERDING_POLICY_DIR env var or runtime-cfg entry, if it points 1. HERDING_POLICY_DIR env var or runtime-cfg entry, if it points
to a real directory. to a real directory.
2. Mode-specific default: 2. Mode-specific default:
bc → training/runs/bc_v3 (Strömbom-imitated MLP) bc → training/runs/bc (Strömbom-imitated MLP)
rl → training/runs/rl_v1 (KL-PPO fine-tune of bc_v3) rl → training/runs/rl (KL-PPO fine-tune of bc)
3. Fall back to bc_v3. 3. Fall back to bc.
All checkpoints are frame-stacked K = 4; ``policy_loader`` reads All checkpoints are frame-stacked K = 4; ``policy_loader`` reads
the stacking factor from the policy's observation space. the stacking factor from the policy's observation space.
""" """
@@ -133,9 +133,9 @@ def _resolve_policy_dir(mode: str) -> str:
if env_dir and os.path.isdir(env_dir): if env_dir and os.path.isdir(env_dir):
return env_dir return env_dir
mode_default = { mode_default = {
"bc": os.path.join(_PROJECT_ROOT, "training", "runs", "bc_v3"), "bc": os.path.join(_PROJECT_ROOT, "training", "runs", "bc"),
"rl": os.path.join(_PROJECT_ROOT, "training", "runs", "rl_v1"), "rl": os.path.join(_PROJECT_ROOT, "training", "runs", "rl"),
"dagger": os.path.join(_PROJECT_ROOT, "training", "runs", "bc_v3"), "dagger": os.path.join(_PROJECT_ROOT, "training", "runs", "bc"),
} }
primary = mode_default.get(mode, mode_default["bc"]) primary = mode_default.get(mode, mode_default["bc"])
if os.path.isdir(primary): if os.path.isdir(primary):
@@ -150,9 +150,9 @@ def _resolve_policy_dir(mode: str) -> str:
_VALID_MODES = ("bc", "rl", "strombom", "sequential", "dagger", "diag") _VALID_MODES = ("bc", "rl", "strombom", "sequential", "dagger", "diag")
# Back-compat: an old config saying HERDING_MODE=rl meant "the BC policy". # Back-compat: an old config saying HERDING_MODE=rl meant "the BC policy".
# We now use `rl` strictly for the KL-PPO fine-tune. If the rl_v1 # We now use `rl` strictly for the KL-PPO fine-tune. If the rl
# directory isn't present, _resolve_policy_dir below silently falls # directory isn't present, _resolve_policy_dir below silently falls
# back to bc_v3, preserving the old behaviour. # back to bc, preserving the old behaviour.
if MODE not in _VALID_MODES: if MODE not in _VALID_MODES:
print(f"[dog] unknown HERDING_MODE={MODE!r}; defaulting to strombom.") print(f"[dog] unknown HERDING_MODE={MODE!r}; defaulting to strombom.")
MODE = "strombom" MODE = "strombom"
@@ -477,15 +477,22 @@ while robot.step(timestep) != -1:
left_ear.setPosition(ear_pos) left_ear.setPosition(ear_pos)
right_ear.setPosition(-ear_pos) right_ear.setPosition(-ear_pos)
# --- DAgger: early-stop when all GT sheep are penned --- # --- Early-stop when all GT sheep are penned (all modes) ---
if MODE == "dagger" and _gt_sheep: # The dog isn't a Supervisor so it can't call simulationQuit() —
# instead we write a sentinel file the launcher polls for and uses
# to kill the Webots process. Bounded by `_gt_sheep` so we don't
# fire during the first few steps while the receiver fills.
if _gt_sheep and not os.path.exists(_DAGGER_DONE_FILE):
gt_active_count = sum(1 for x, y in _gt_sheep.values() gt_active_count = sum(1 for x, y in _gt_sheep.values()
if not is_penned_position(x, y)) if not is_penned_position(x, y))
if gt_active_count == 0 and not os.path.exists(_DAGGER_DONE_FILE): if gt_active_count == 0:
if MODE == "dagger":
_dump_dagger_log() _dump_dagger_log()
os.makedirs(os.path.dirname(_DAGGER_DONE_FILE), exist_ok=True)
open(_DAGGER_DONE_FILE, "w").close() open(_DAGGER_DONE_FILE, "w").close()
print(f"[dog dagger] all {len(_gt_sheep)} sheep penned " print(f"[dog] all {len(_gt_sheep)} sheep penned at step "
f"- wrote {_DAGGER_DONE_FILE}, exiting early") f"{step_count} - wrote {_DAGGER_DONE_FILE}, "
f"launcher will close Webots")
if MODE == "dagger" and step_count % DAGGER_FLUSH_STEPS == 0 and DAGGER_LOG_OBS: if MODE == "dagger" and step_count % DAGGER_FLUSH_STEPS == 0 and DAGGER_LOG_OBS:
_dump_dagger_log() _dump_dagger_log()
+2 -2
@@ -16,7 +16,7 @@
# #
# Env-var overrides: # Env-var overrides:
# HERDING_POLICY_DIR : policy the controller loads (only used when # HERDING_POLICY_DIR : policy the controller loads (only used when
# HERDING_DAGGER_DRIVER=student). Default bc_v3. # HERDING_DAGGER_DRIVER=student). Default bc.
# HERDING_DAGGER_DRIVER : "teacher" (default) or "student". # HERDING_DAGGER_DRIVER : "teacher" (default) or "student".
# HEADLESS=1 : force --no-rendering (default on). # HEADLESS=1 : force --no-rendering (default on).
# FLOCKS="1 3 5 8 10" : space-separated flock sizes to iterate over. # FLOCKS="1 3 5 8 10" : space-separated flock sizes to iterate over.
@@ -37,7 +37,7 @@ HEADLESS=${HEADLESS:-1}
ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )" ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )"
SRC="$ROOT/worlds/field.wbt" SRC="$ROOT/worlds/field.wbt"
DST="$ROOT/worlds/field_test.wbt" DST="$ROOT/worlds/field_test.wbt"
POLICY_DIR="${HERDING_POLICY_DIR:-$ROOT/training/runs/bc_v3}" POLICY_DIR="${HERDING_POLICY_DIR:-$ROOT/training/runs/bc}"
DRIVER="${HERDING_DAGGER_DRIVER:-teacher}" DRIVER="${HERDING_DAGGER_DRIVER:-teacher}"
DONE_FILE="$ROOT/training/dagger/.DONE" DONE_FILE="$ROOT/training/dagger/.DONE"
WEBOTS_PID="" WEBOTS_PID=""
+3 -3
@@ -10,7 +10,7 @@ where:
* ``actions`` is the **active-scan-teacher action computed from * ``actions`` is the **active-scan-teacher action computed from
ground-truth sheep positions** (read off the sheep emitter). ground-truth sheep positions** (read off the sheep emitter).
Combined with the existing sim demos (``training/demos_v3.npz`` by Combined with the existing sim demos (``training/demos.npz`` by
default), this gives the BC student a training set that includes the default), this gives the BC student a training set that includes the
real Webots false-positive distribution — closing the sim-to-real real Webots false-positive distribution — closing the sim-to-real
perception gap that the all-sim pipeline couldn't bridge. perception gap that the all-sim pipeline couldn't bridge.
@@ -19,7 +19,7 @@ Usage::
# Iteration 1 — merge all dagger files with sim demos, retrain # Iteration 1 — merge all dagger files with sim demos, retrain
python -m tools.dagger_merge_train \\ python -m tools.dagger_merge_train \\
--sim training/demos_v3.npz \\ --sim training/demos.npz \\
--out training/runs/bc_dagger1 --out training/runs/bc_dagger1
# Iteration 2 — drop the sim baseline, train only on Webots data # Iteration 2 — drop the sim baseline, train only on Webots data
@@ -48,7 +48,7 @@ import numpy as np
def main() -> None: def main() -> None:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--sim", default="training/demos_v3.npz", parser.add_argument("--sim", default="training/demos.npz",
help="Sim demo file to mix with the Webots data. " help="Sim demo file to mix with the Webots data. "
"Pass --no-sim to train only on dagger data.") "Pass --no-sim to train only on dagger data.")
parser.add_argument("--no-sim", action="store_true", parser.add_argument("--no-sim", action="store_true",
+32 -4
@@ -17,7 +17,7 @@
# #
# Notes: # Notes:
# * The RL mode loads the latest BC policy by default — priority # * The RL mode loads the latest BC policy by default — priority
# bc_dagger_v2 → bc_dagger → bc_c2v3 (the controller resolves it). # the BC policy (bc/policy.zip) (the controller resolves it).
# (LiDAR-perception, frame-stack K=4). Override via # (LiDAR-perception, frame-stack K=4). Override via
# HERDING_POLICY_DIR=/path/to/run env var. # HERDING_POLICY_DIR=/path/to/run env var.
# * Conda env "tir" must be active (provides stable-baselines3 + torch). # * Conda env "tir" must be active (provides stable-baselines3 + torch).
@@ -50,12 +50,12 @@ echo "------------------------------------------------------------"
echo "World : $DST" echo "World : $DST"
echo "Mode : $MODE" echo "Mode : $MODE"
echo "Sheep : $active active" echo "Sheep : $active active"
echo "Policy dir : ${HERDING_POLICY_DIR:-$ROOT/training/runs/bc_v3}" echo "Policy dir : ${HERDING_POLICY_DIR:-$ROOT/training/runs/bc}"
echo "------------------------------------------------------------" echo "------------------------------------------------------------"
# Webots strips HERDING_* env vars from controller subprocesses in some # Webots strips HERDING_* env vars from controller subprocesses in some
# setups, so we also write a runtime config file the controller reads. # setups, so we also write a runtime config file the controller reads.
RESOLVED_POLICY_DIR="${HERDING_POLICY_DIR:-$ROOT/training/runs/bc_v3}" RESOLVED_POLICY_DIR="${HERDING_POLICY_DIR:-$ROOT/training/runs/bc}"
cat > "$ROOT/herding_runtime.cfg" <<EOF cat > "$ROOT/herding_runtime.cfg" <<EOF
HERDING_MODE=$MODE HERDING_MODE=$MODE
HERDING_POLICY_DIR=$RESOLVED_POLICY_DIR HERDING_POLICY_DIR=$RESOLVED_POLICY_DIR
@@ -65,4 +65,32 @@ EOF
export HERDING_MODE="$MODE" export HERDING_MODE="$MODE"
export HERDING_POLICY_DIR="$RESOLVED_POLICY_DIR" export HERDING_POLICY_DIR="$RESOLVED_POLICY_DIR"
exec webots "$DST" # The controller writes this sentinel when all GT sheep are penned. We
# poll for it and kill Webots so the run finishes cleanly instead of
# idling for minutes after the task is done.
DONE_FILE="$ROOT/training/dagger/.DONE"
mkdir -p "$(dirname "$DONE_FILE")"
rm -f "$DONE_FILE"
webots "$DST" &
WEBOTS_PID=$!
cleanup() {
kill "$WEBOTS_PID" 2>/dev/null || true
wait "$WEBOTS_PID" 2>/dev/null || true
exit 0
}
trap cleanup INT TERM
# Poll for the sentinel; bail when Webots exits on its own or when the
# user closes the window.
while kill -0 "$WEBOTS_PID" 2>/dev/null; do
if [[ -f "$DONE_FILE" ]]; then
echo "[run_webots] all sheep penned — closing Webots"
sleep 1 # let the controller print its line
kill "$WEBOTS_PID" 2>/dev/null || true
break
fi
sleep 1
done
wait "$WEBOTS_PID" 2>/dev/null || true
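The poll-and-kill pattern added to the launcher above can be sketched in Python as follows. This is an illustrative equivalent, not part of the commit: `run_until_sentinel` is a hypothetical helper, the child process is a stand-in for Webots, and the timeout is an extra safety net the shell script does not have.

```python
import subprocess
import sys
import tempfile
import time
from pathlib import Path


def run_until_sentinel(cmd, done_file, poll_s=0.1, timeout_s=30.0):
    """Start `cmd`, poll for `done_file`, and terminate the process once it appears.

    Returns True if the sentinel triggered the shutdown, False if the
    process exited on its own or the timeout elapsed first.
    """
    done_file = Path(done_file)
    done_file.parent.mkdir(parents=True, exist_ok=True)
    done_file.unlink(missing_ok=True)  # clear a stale sentinel from an earlier run
    proc = subprocess.Popen(cmd)
    deadline = time.monotonic() + timeout_s
    triggered = False
    try:
        while proc.poll() is None and time.monotonic() < deadline:
            if done_file.exists():
                triggered = True
                break
            time.sleep(poll_s)
    finally:
        if proc.poll() is None:
            proc.terminate()
        proc.wait()
    return triggered


# Stand-in for the simulator: creates the sentinel after 0.3 s, then idles.
child_code = (
    "import pathlib, sys, time\n"
    "time.sleep(0.3)\n"
    "pathlib.Path(sys.argv[1]).touch()\n"
    "time.sleep(60)\n"
)
with tempfile.TemporaryDirectory() as tmp:
    sentinel = Path(tmp) / ".DONE"
    stopped_by_sentinel = run_until_sentinel(
        [sys.executable, "-c", child_code, str(sentinel)], sentinel
    )
print(stopped_by_sentinel)
```

Removing the stale sentinel before launch matters as much as polling for it: without that step a leftover file from the previous run would end the next run immediately.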
+15 -12
@@ -7,12 +7,15 @@ policy that runs under LiDAR perception in Webots.
sim demos (active-scan teacher on tracker output, K=4 frame stack) sim demos (active-scan teacher on tracker output, K=4 frame stack)
bc_pretrain.py ──► runs/bc_v3 (deployed policy — beats Strömbom on n≥8) bc_pretrain.py ──► runs/bc (BC baseline)
(optional: tools/auto_dagger.sh + tools/dagger_merge_train.py KL-regularised PPO fine-tune (training/train_ppo.py)
│ if sim-trained doesn't transfer cleanly to Webots)
runs/bc_dagger runs/rl (deployed `rl` mode)
# optional branch — kept for reference, not deployed:
runs/bc_dagger (Webots-grounded DAgger refinement, useful if a
modified world breaks sim-to-real transfer)
``` ```
## Files ## Files
@@ -42,14 +45,14 @@ rollout collection, not gradient compute.
# 1. Sim demos with the active-scan + Strömbom teacher under LiDAR # 1. Sim demos with the active-scan + Strömbom teacher under LiDAR
# perception. K=4 frame stack so the MLP has temporal context. # perception. K=4 frame stack so the MLP has temporal context.
python -m tools.collect_demos --teacher strombom \ python -m tools.collect_demos --teacher strombom \
--out demos_v3.npz --seeds-per-n 15 --subsample 3 --frame-stack 4 --out demos.npz --seeds-per-n 15 --subsample 3 --frame-stack 4
# 2. Behavior-clone. # 2. Behavior-clone.
python -m training.bc_pretrain --demos demos_v3.npz \ python -m training.bc_pretrain --demos demos.npz \
--out runs/bc_v3 --epochs 60 --net-arch 512,512 --out runs/bc --epochs 60 --net-arch 512,512
# 3. Evaluate. # 3. Evaluate.
python -m training.eval --policy runs/bc_v3 \ python -m training.eval --policy runs/bc \
--max-flock 10 --max-steps 8000 --n-seeds 5 --max-flock 10 --max-steps 8000 --n-seeds 5
``` ```
@@ -78,7 +81,7 @@ seat:
HERDING_POLICY_DIR=$PWD/training/runs/bc_dagger \ HERDING_POLICY_DIR=$PWD/training/runs/bc_dagger \
HERDING_DAGGER_DRIVER=student \ HERDING_DAGGER_DRIVER=student \
tools/auto_dagger.sh 3 60 tools/auto_dagger.sh 3 60
python -m tools.dagger_merge_train --out runs/bc_dagger_v2 python -m tools.dagger_merge_train --out runs/bc_dagger
``` ```
## Available analytic teachers ## Available analytic teachers
@@ -107,6 +110,6 @@ python -m training.eval --policy sequential --max-flock 10 --max-steps 8000 --n
tools/run_webots.sh 10 rl tools/run_webots.sh 10 rl
``` ```
The dog controller loads the highest-priority policy that exists The dog controller loads `runs/bc` for `bc` mode and `runs/rl` for
(`bc_dagger_v2` → `bc_dagger` → `bc_v3`). Override with `rl` mode. Override with `HERDING_POLICY_DIR=…` for a specific
`HERDING_POLICY_DIR=…` if you want a specific checkpoint. checkpoint.
+2 -2
@@ -15,7 +15,7 @@ Usage::
python -m training.bc_pretrain \\ python -m training.bc_pretrain \\
--demos training/demos.npz \\ --demos training/demos.npz \\
--out training/runs/bc_flock --out training/runs/bc
""" """
from __future__ import annotations from __future__ import annotations
@@ -83,7 +83,7 @@ def policy_forward_mean(policy, obs_batch):
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--demos", default="training/demos.npz") parser.add_argument("--demos", default="training/demos.npz")
parser.add_argument("--out", default="training/runs/bc_solo") parser.add_argument("--out", default="training/runs/bc")
parser.add_argument("--epochs", type=int, default=60) parser.add_argument("--epochs", type=int, default=60)
parser.add_argument("--batch-size", type=int, default=256) parser.add_argument("--batch-size", type=int, default=256)
parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--lr", type=float, default=1e-3)
+9
@@ -204,6 +204,12 @@ class HerdingEnv(gym.Env):
already mimics a stronger teacher (sequential).""" already mimics a stronger teacher (sequential)."""
self.W_IMITATE = float(value) self.W_IMITATE = float(value)
def set_time_weight(self, value: float) -> None:
"""Override W_TIME (instance-level). Default 0.0; a small
negative value (e.g. -0.1) adds a per-step penalty that
explicitly rewards fast time-to-pen during PPO fine-tune."""
self.W_TIME = float(value)
# ---- gym API ---- # ---- gym API ----
def reset(self, *, seed=None, options=None): def reset(self, *, seed=None, options=None):
super().reset(seed=seed) super().reset(seed=seed)
@@ -431,6 +437,9 @@ class HerdingEnv(gym.Env):
d_progress = max(-5.0, min(5.0, self.prev_d_pen - d_pen)) d_progress = max(-5.0, min(5.0, self.prev_d_pen - d_pen))
r = self.W_PEN_DELTA * delta_pen + self.W_PROGRESS * d_progress r = self.W_PEN_DELTA * delta_pen + self.W_PROGRESS * d_progress
# Per-step time penalty (0 by default). When negative, encourages
# the policy to finish quickly — used during PPO fine-tune.
r += self.W_TIME
if action is not None and self.W_IMITATE > 0.0: if action is not None and self.W_IMITATE > 0.0:
positions = self._perceived_positions() positions = self._perceived_positions()
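To make the effect of the new `W_TIME` term concrete, here is a stand-alone sketch of the reward terms visible in this hunk. The weight values are placeholders rather than the environment's real defaults, the imitation term is left out, and `shaped_reward` is a hypothetical function, not a method of `HerdingEnv`.

```python
def shaped_reward(delta_pen, d_pen, prev_d_pen,
                  w_pen_delta=1.0, w_progress=0.1, w_time=0.0):
    """Penning bonus + clipped progress toward the pen + constant per-step term.

    All weights are illustrative placeholders; the imitation reward is omitted.
    """
    d_progress = max(-5.0, min(5.0, prev_d_pen - d_pen))
    return w_pen_delta * delta_pen + w_progress * d_progress + w_time


# Same transition, with and without the per-step time penalty.
base = shaped_reward(delta_pen=1, d_pen=9.0, prev_d_pen=10.0)
timed = shaped_reward(delta_pen=1, d_pen=9.0, prev_d_pen=10.0, w_time=-0.1)
# A large jump in distance is clipped to 5.0 before weighting.
clipped = shaped_reward(delta_pen=0, d_pen=0.0, prev_d_pen=100.0)
print(base, timed, clipped)
```

Because the term is a constant added on every step, its total contribution grows linearly with episode length, which is why even a small negative value changes the return noticeably over a multi-thousand-step episode.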
Binary file not shown.
+32 -7
@@ -10,7 +10,7 @@ per-step reward signal does the rest.
Pipeline Pipeline
-------- --------
1. Load ``bc_v3`` weights into both the trainable policy and a frozen 1. Load ``bc`` weights into both the trainable policy and a frozen
reference ``ref_policy``. reference ``ref_policy``.
2. Initialise the policy's log_std to a small fixed value (≈ -1.5) 2. Initialise the policy's log_std to a small fixed value (≈ -1.5)
and disable its gradient — exploration noise stays small so PPO and disable its gradient — exploration noise stays small so PPO
@@ -19,14 +19,14 @@ Pipeline
each minibatch. each minibatch.
4. Train for ~1-3 M timesteps with a low LR (5e-5). 4. Train for ~1-3 M timesteps with a low LR (5e-5).
Output: ``runs/rl_v1/policy.zip`` — same SB3 format as bc_v3, loadable Output: ``runs/rl/policy.zip`` — same SB3 format as bc, loadable
by the dog controller's ``HERDING_MODE=rl`` path. by the dog controller's ``HERDING_MODE=rl`` path.
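For orientation, the sketch below shows the closed-form KL divergence between two diagonal Gaussian action distributions and one way it could be added to a PPO loss as a penalty. This is an assumption-laden illustration: the hunk does not show which KL direction `KLPPO` uses or how the penalty is aggregated, the function names are hypothetical, and `beta=0.05` is simply the coefficient quoted in the README.

```python
import numpy as np


def diag_gaussian_kl(mean_p, log_std_p, mean_q, log_std_q):
    """KL(p || q) for diagonal Gaussians, summed over action dimensions."""
    mean_p, log_std_p, mean_q, log_std_q = (
        np.asarray(a, dtype=float) for a in (mean_p, log_std_p, mean_q, log_std_q)
    )
    var_p = np.exp(2.0 * log_std_p)
    var_q = np.exp(2.0 * log_std_q)
    per_dim = (log_std_q - log_std_p
               + (var_p + (mean_p - mean_q) ** 2) / (2.0 * var_q)
               - 0.5)
    return per_dim.sum(axis=-1)


def kl_regularised_loss(ppo_loss, kl_per_sample, beta=0.05):
    """PPO loss plus beta times the mean KL to the frozen reference (assumed form)."""
    return ppo_loss + beta * float(np.mean(kl_per_sample))


log_std = np.full(2, -1.0)  # hypothetical shared log-std for both policies
kl_same = diag_gaussian_kl([0.0, 0.0], log_std, [0.0, 0.0], log_std)
kl_shift = diag_gaussian_kl([0.2, 0.0], log_std, [0.0, 0.0], log_std)
total = kl_regularised_loss(1.0, np.array([kl_same, kl_shift]))
print(kl_same, kl_shift, total)
```

With a shared, frozen log-std the divergence reduces to the squared mean difference scaled by `1 / (2 * std**2)`, so the penalty directly limits how far the fine-tuned action means can drift from the BC reference.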
Usage:: Usage::
python -m training.train_ppo \\ python -m training.train_ppo \\
--bc training/runs/bc_v3 \\ --bc training/runs/bc \\
--out training/runs/rl_v1 \\ --out training/runs/rl \\
--total-timesteps 2000000 --total-timesteps 2000000
""" """
@@ -205,9 +205,9 @@ class KLPPO(PPO):
def main() -> None: def main() -> None:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--bc", default="training/runs/bc_v3", parser.add_argument("--bc", default="training/runs/bc",
help="Directory containing the BC initialisation (policy.zip).") help="Directory containing the BC initialisation (policy.zip).")
parser.add_argument("--out", default="training/runs/rl_v1", parser.add_argument("--out", default="training/runs/rl",
help="Where to save the fine-tuned policy.") help="Where to save the fine-tuned policy.")
parser.add_argument("--total-timesteps", type=int, default=2_000_000) parser.add_argument("--total-timesteps", type=int, default=2_000_000)
parser.add_argument("--n-envs", type=int, default=8) parser.add_argument("--n-envs", type=int, default=8)
@@ -232,12 +232,23 @@ def main() -> None:
help="SB3's per-batch KL early stop; safety belt.") help="SB3's per-batch KL early stop; safety belt.")
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--device", default="cpu") parser.add_argument("--device", default="cpu")
parser.add_argument("--imitate-weight", type=float, default=None,
help="Override env.W_IMITATE for this training "
"run. Set to 0.0 to drop the Strömbom "
"cosine-imitation reward — useful during "
"PPO refinement where you want reward, "
"not teacher imitation, to drive updates.")
parser.add_argument("--time-weight", type=float, default=None,
help="Override env.W_TIME. Default env value is "
"0.0; setting e.g. -0.1 adds a small per-"
"step penalty that explicitly rewards "
"fast time-to-pen.")
args = parser.parse_args() args = parser.parse_args()
bc_zip = Path(args.bc) / "policy.zip" bc_zip = Path(args.bc) / "policy.zip"
if not bc_zip.exists(): if not bc_zip.exists():
raise SystemExit( raise SystemExit(
f"BC checkpoint not found at {bc_zip}. Train bc_v3 first with " f"BC checkpoint not found at {bc_zip}. Train bc first with "
f"`python -m training.bc_pretrain`." f"`python -m training.bc_pretrain`."
) )
@@ -259,6 +270,20 @@ def main() -> None:
venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns) venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns)
eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, frame_stack)]) eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, frame_stack)])
# --- Apply reward-shaping overrides to every env instance ---
def _broadcast(method: str, value):
for v in (venv, eval_venv):
try:
v.env_method(method, value)
except AttributeError:
v.venv.env_method(method, value)
if args.imitate_weight is not None:
_broadcast("set_imitate_weight", args.imitate_weight)
print(f"[rl] W_IMITATE overridden to {args.imitate_weight}")
if args.time_weight is not None:
_broadcast("set_time_weight", args.time_weight)
print(f"[rl] W_TIME overridden to {args.time_weight}")
# --- Trainable policy: load BC weights, then bolt onto PPO --- # --- Trainable policy: load BC weights, then bolt onto PPO ---
# Trick: instantiate a PPO with the right env (so the policy # Trick: instantiate a PPO with the right env (so the policy
# network is constructed at the correct obs/action shape), then # network is constructed at the correct obs/action shape), then
-9
@@ -1,9 +0,0 @@
Webots Project File version R2025a
perspectives: 000000ff00000000fd00000002000000010000011c00000405fc0200000001fb0000001400540065007800740045006400690074006f00720100000000000004050000003f00ffffff00000003000007c500000092fc0100000001fb0000001a0043006f006e0073006f006c00650041006c006c0041006c006c0100000000000007c50000006900ffffff000006a70000040500000001000000020000000100000008fc00000000
simulationViewPerspectives: 000000ff000000010000000200000100000003a80100000002010000000100
sceneTreePerspectives: 000000ff00000001000000030000001f0000018b000000fa0100000002010000000200
maximizedDockId: -1
centralWidgetVisible: 1
orthographicViewHeight: 1
textFiles: -1
consoles: Console:All:All