Checkpoint 7
This commit is contained in:
@@ -0,0 +1,103 @@
|
|||||||
|
# Training pipeline for the shepherd-dog herding project.
|
||||||
|
# Stages chain via output files in training/.
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# make # full pipeline: bc_demos -> bc -> rl -> eval
|
||||||
|
# make bc_demos # generate sim demos
|
||||||
|
# make bc # behaviour clone (rebuilds bc_demos if missing)
|
||||||
|
# make rl # KL-PPO fine-tune (rebuilds bc if missing)
|
||||||
|
# make eval # 10-seed env eval of rl
|
||||||
|
# make test # pytest suite
|
||||||
|
# make webots N=10 MODE=rl # launch Webots in the chosen mode
|
||||||
|
# make clean # delete bc_demos and run artefacts
|
||||||
|
# make help # print the target table
|
||||||
|
#
|
||||||
|
# Override any hyperparameter on the command line, for example:
|
||||||
|
# make rl PPO_STEPS=2000000 KL=0.02
|
||||||
|
# make eval EVAL_SEEDS=20
|
||||||
|
|
||||||
|
|
||||||
|
# Interpreter used to launch every pipeline stage (`$(PY) -m ...`).
PY := python

# Per-stage run directories.
BC_DIR := training/runs/bc
RL_DIR := training/runs/rl

# File artefacts that chain the stages: each file rule further down targets
# one of these and names the previous one as its prerequisite.
BC_DEMOS := training/bc/demos.npz
BC_POLICY := $(BC_DIR)/policy.zip
RL_POLICY := $(RL_DIR)/policy.zip
|
||||||
|
|
||||||
|
# All knobs below are assigned with ?=, so a value supplied on the command
# line or through the environment replaces the default shown here.

# --- Demo collection (flags of training.bc.collect) ---
TEACHER ?= strombom
SEEDS_PER_N ?= 15
SUBSAMPLE ?= 3
FRAME_STACK ?= 4

# --- Behaviour cloning (flags of training.bc.pretrain) ---
BC_EPOCHS ?= 60
BC_NET_ARCH ?= 512,512

# --- KL-PPO fine-tune (flags of training.rl.train) ---
PPO_STEPS ?= 1000000
KL ?= 0.05

# --- Evaluation (flags of training.eval) ---
EVAL_SEEDS ?= 10
EVAL_MAX_STEPS ?= 15000

# --- Webots launcher (positional args of tools/run_webots.sh) ---
N ?= 10
MODE ?= rl
|
||||||
|
|
||||||
|
|
||||||
|
# These names are commands, not files. Declaring them phony keeps a stray
# file called e.g. `clean` or `test` from making the target look up to date.
.PHONY: all bc_demos bc rl eval test webots clean help

# The stages chain purely through output files (demos.npz, policy.zip), and
# the file rules rebuild only when a target is missing or stale. Without this
# directive, a collect/train run that fails midway can leave a partially
# written target behind, and the next `make` would treat it as finished.
# With it, Make deletes the target of any recipe that exits non-zero.
.DELETE_ON_ERROR:

# Default goal: `eval` pulls in rl -> bc -> bc_demos through its prerequisites.
all: eval
|
||||||
|
|
||||||
|
# Phony alias so `make bc_demos` builds the demo archive.
bc_demos: $(BC_DEMOS)

# Collect simulator demonstrations from the selected teacher into $@.
# The rule has no prerequisites, so it runs only while the archive is missing.
$(BC_DEMOS):
	$(PY) -m training.bc.collect \
		--teacher $(TEACHER) \
		--out $@ \
		--seeds-per-n $(SEEDS_PER_N) \
		--subsample $(SUBSAMPLE) \
		--frame-stack $(FRAME_STACK)
|
||||||
|
|
||||||
|
# Phony alias so `make bc` builds the behaviour-cloned policy.
bc: $(BC_POLICY)

# Behaviour-clone the demos ($<, the only prerequisite) into the BC run
# directory. Re-runs when the demo archive is newer than the policy file.
$(BC_POLICY): $(BC_DEMOS)
	$(PY) -m training.bc.pretrain \
		--demos $< \
		--out $(BC_DIR) \
		--epochs $(BC_EPOCHS) \
		--net-arch $(BC_NET_ARCH)
|
||||||
|
|
||||||
|
# Phony alias so `make rl` builds the fine-tuned policy.
rl: $(RL_POLICY)

# KL-PPO fine-tune starting from the BC run directory. Depends on the BC
# policy file, so the BC stage is (re)built first when missing or stale.
$(RL_POLICY): $(BC_POLICY)
	$(PY) -m training.rl.train \
		--bc $(BC_DIR) \
		--out $(RL_DIR) \
		--total-timesteps $(PPO_STEPS) \
		--kl-coef $(KL)
|
||||||
|
|
||||||
|
# Value passed to `--max-flock`. It used to be a literal 10 in the recipe,
# the one pipeline flag that could not be overridden; the default keeps the
# previous behaviour. Example: `make eval EVAL_MAX_FLOCK=20 EVAL_SEEDS=20`.
EVAL_MAX_FLOCK ?= 10

# Evaluate the fine-tuned policy from the RL run directory. Phony, so it
# runs on every invocation; the prerequisite only ensures the policy exists.
eval: $(RL_POLICY)
	$(PY) -m training.eval \
		--policy $(RL_DIR) \
		--max-flock $(EVAL_MAX_FLOCK) \
		--max-steps $(EVAL_MAX_STEPS) \
		--n-seeds $(EVAL_SEEDS)
|
||||||
|
|
||||||
|
# Run the pytest suite in tests/ with the same interpreter as the pipeline.
# No prerequisites: the tests do not need any trained artefact to be built.
test:
	$(PY) -m pytest tests/
|
||||||
|
|
||||||
|
# Launch Webots through the helper script; N and MODE are forwarded as its
# first and second positional arguments (e.g. `make webots N=10 MODE=bc`).
webots:
	tools/run_webots.sh $(N) $(MODE)
|
||||||
|
|
||||||
|
# Remove the demo archive and both run directories so the next `make`
# regenerates the whole chain from scratch.
# NOTE(review): the README describes training/runs/ as whitelisted in
# .gitignore; if checkpoints there are committed, this also deletes tracked
# files. Confirm that is intended.
clean:
	rm -rf $(BC_DEMOS) $(BC_DIR) $(RL_DIR)
|
||||||
|
|
||||||
|
# Print the target table followed by the tunable variables. The $(VAR)
# references are expanded by Make before echo runs, so the values printed are
# the ones in effect for this invocation (defaults unless overridden).
help:
	@echo "Targets:"
	@echo " make full pipeline (bc_demos -> bc -> rl -> eval)"
	@echo " make bc_demos sim demos via the '$(TEACHER)' teacher"
	@echo " make bc train BC (rebuilds bc_demos if missing)"
	@echo " make rl KL-PPO fine-tune (rebuilds bc if missing)"
	@echo " make eval $(EVAL_SEEDS)-seed env eval of rl"
	@echo " make test pytest suite"
	@echo " make webots [N=$(N)] [MODE=$(MODE)]"
	@echo " launch Webots in the chosen mode"
	@echo " make clean delete bc_demos and run artefacts"
	@echo ""
	@echo "Hyperparameter overrides (showing defaults):"
	@echo " TEACHER=$(TEACHER) SEEDS_PER_N=$(SEEDS_PER_N) SUBSAMPLE=$(SUBSAMPLE) FRAME_STACK=$(FRAME_STACK)"
	@echo " BC_EPOCHS=$(BC_EPOCHS) BC_NET_ARCH=$(BC_NET_ARCH)"
	@echo " PPO_STEPS=$(PPO_STEPS) KL=$(KL)"
	@echo " EVAL_SEEDS=$(EVAL_SEEDS) EVAL_MAX_STEPS=$(EVAL_MAX_STEPS)"
|
||||||
@@ -22,10 +22,10 @@ control step:
|
|||||||
|
|
||||||
1. Read `lidar.getRangeImage()`,
|
1. Read `lidar.getRangeImage()`,
|
||||||
2. Cluster returns into world-frame `(x, y)` estimates
|
2. Cluster returns into world-frame `(x, y)` estimates
|
||||||
(`herding/lidar_perception.py`),
|
(`herding/perception/lidar_perception.py`),
|
||||||
3. Fold them into a multi-target tracker that maintains last-seen
|
3. Fold them into a multi-target tracker that maintains last-seen
|
||||||
positions for sheep currently outside the FOV
|
positions for sheep currently outside the FOV
|
||||||
(`herding/sheep_tracker.py`).
|
(`herding/perception/sheep_tracker.py`).
|
||||||
|
|
||||||
**LiDAR validation** (intermediate-goal item v from `docs/project.md`):
|
**LiDAR validation** (intermediate-goal item v from `docs/project.md`):
|
||||||
during development a diagnostic-dump controller captured 80 real
|
during development a diagnostic-dump controller captured 80 real
|
||||||
@@ -39,7 +39,7 @@ task.
|
|||||||
The tracker outputs a `{name: (x, y)}` dict shaped exactly like the
|
The tracker outputs a `{name: (x, y)}` dict shaped exactly like the
|
||||||
prior receiver-based one, so Strömbom, Sequential, and the BC obs
|
prior receiver-based one, so Strömbom, Sequential, and the BC obs
|
||||||
builder all run unchanged on top of it. The 2D Gymnasium env
|
builder all run unchanged on top of it. The 2D Gymnasium env
|
||||||
(`herding/lidar_sim.py`) raycasts sheep discs at training time, so
|
(`herding/perception/lidar_sim.py`) raycasts sheep discs at training time, so
|
||||||
demos collected in the env match the perception the deployed
|
demos collected in the env match the perception the deployed
|
||||||
controller sees in Webots.
|
controller sees in Webots.
|
||||||
|
|
||||||
@@ -52,36 +52,32 @@ Privileged ground-truth perception is available for ablation —
|
|||||||
# 1. Set up the Python env (any venv with PyTorch + SB3)
|
# 1. Set up the Python env (any venv with PyTorch + SB3)
|
||||||
pip install -r training/requirements.txt
|
pip install -r training/requirements.txt
|
||||||
|
|
||||||
# 2. Smoke test
|
# 2. Smoke test (70 pytest cases, < 1 s)
|
||||||
python -m tests.parity_test
|
make test
|
||||||
|
|
||||||
# 3. Reproduce the BC policy (~10 min on CPU: ~5 min demos + ~3 min BC)
|
# 3. Reproduce the full pipeline (~30–60 min CPU)
|
||||||
python -m tools.collect_demos --teacher strombom \
|
make # demos -> bc -> rl -> eval
|
||||||
--out training/demos.npz --seeds-per-n 15 --subsample 3 --frame-stack 4
|
|
||||||
python -m training.bc_pretrain --demos training/demos.npz \
|
|
||||||
--out training/runs/bc --epochs 60 --net-arch 512,512
|
|
||||||
|
|
||||||
# 4. KL-PPO fine-tune of the BC policy (~30 min on CPU, 1 M steps)
|
# Individual stages (each rebuilds upstream artefacts if missing):
|
||||||
python -m training.train_ppo \
|
make bc_demos # sim demos
|
||||||
--bc training/runs/bc \
|
make bc # behaviour clone
|
||||||
--out training/runs/rl \
|
make rl # KL-PPO fine-tune
|
||||||
--total-timesteps 1000000
|
make eval # 10-seed env eval of rl
|
||||||
|
|
||||||
# 5. Evaluate (env)
|
# 4. Run in Webots
|
||||||
python -m training.eval --policy training/runs/rl \
|
make webots N=10 MODE=bc # behaviour-cloned MLP
|
||||||
--max-flock 10 --max-steps 15000 --n-seeds 10
|
make webots N=10 MODE=rl # KL-PPO fine-tune
|
||||||
|
make webots N=10 MODE=strombom # analytic baseline
|
||||||
# 6. Run in Webots
|
# (or invoke directly: tools/run_webots.sh 10 rl)
|
||||||
tools/run_webots.sh 10 bc # behaviour-cloned MLP
|
|
||||||
tools/run_webots.sh 10 rl # KL-PPO fine-tune
|
|
||||||
tools/run_webots.sh 10 strombom # analytic baseline
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
`make help` lists every target and the overridable hyperparameters
|
||||||
|
(e.g. `make rl PPO_STEPS=2000000 KL=0.02`).
|
||||||
|
|
||||||
## Layout
|
## Layout
|
||||||
|
|
||||||
```
|
```
|
||||||
herding/ — perception / control / world primitives
|
herding/ — perception / control / world primitives
|
||||||
obs.py — 32-D order-invariant observation builder
|
|
||||||
world/ — environment-side physics & geometry
|
world/ — environment-side physics & geometry
|
||||||
geometry.py field/pen constants, robot specs
|
geometry.py field/pen constants, robot specs
|
||||||
diffdrive.py differential-drive kinematics
|
diffdrive.py differential-drive kinematics
|
||||||
@@ -90,6 +86,7 @@ herding/ — perception / control / world primitives
|
|||||||
lidar_sim.py fast 2D raycast for the env
|
lidar_sim.py fast 2D raycast for the env
|
||||||
lidar_perception.py scan → world-frame cluster centroids + filters
|
lidar_perception.py scan → world-frame cluster centroids + filters
|
||||||
sheep_tracker.py multi-target NN tracker with FOV memory
|
sheep_tracker.py multi-target NN tracker with FOV memory
|
||||||
|
obs.py 32-D order-invariant observation builder
|
||||||
control/ — every dog mode's action source
|
control/ — every dog mode's action source
|
||||||
strombom.py canonical CoM collect/drive heuristic
|
strombom.py canonical CoM collect/drive heuristic
|
||||||
sequential.py single-target "pin-and-push" alternative
|
sequential.py single-target "pin-and-push" alternative
|
||||||
@@ -105,19 +102,28 @@ controllers/
|
|||||||
|
|
||||||
training/
|
training/
|
||||||
herding_env.py — Gymnasium env (LiDAR + tracker by default)
|
herding_env.py — Gymnasium env (LiDAR + tracker by default)
|
||||||
bc_pretrain.py — supervised BC of (obs, action) demos into MLP
|
bc/collect.py — sim demos via the active-scan teacher
|
||||||
train_ppo.py — KL-regularised PPO fine-tune of BC
|
bc/pretrain.py — supervised BC of (obs, action) demos into MLP
|
||||||
|
rl/train.py — KL-regularised PPO fine-tune of BC
|
||||||
eval.py — analytic + learned policy comparison harness
|
eval.py — analytic + learned policy comparison harness
|
||||||
|
bc/demos.npz — collected demonstrations (gitignored)
|
||||||
runs/ — checkpoints (whitelisted in .gitignore)
|
runs/ — checkpoints (whitelisted in .gitignore)
|
||||||
requirements.txt
|
requirements.txt
|
||||||
|
|
||||||
tests/
|
tests/
|
||||||
parity_test.py — shape / determinism / baseline smoke test
|
conftest.py — pytest setup (adds project root to sys.path)
|
||||||
|
test_geometry.py — geometric predicates + constants
|
||||||
|
test_diffdrive.py — kinematics and (vx, vy) → wheel-speed map
|
||||||
|
test_obs.py — observation builder (shape, normalisation, order)
|
||||||
|
test_control.py — speed modulation + analytic teachers + active scan
|
||||||
|
test_perception.py — LiDAR sim + clustering + tracker
|
||||||
|
test_env.py — Gymnasium contract + determinism + reward
|
||||||
|
|
||||||
tools/
|
tools/
|
||||||
collect_demos.py — sim demos via the active-scan teacher
|
|
||||||
run_webots.sh — launch Webots with N sheep + chosen mode
|
run_webots.sh — launch Webots with N sheep + chosen mode
|
||||||
|
|
||||||
|
Makefile — pipeline orchestrator (make / make rl / make test / …)
|
||||||
|
|
||||||
worlds/
|
worlds/
|
||||||
field.wbt — main world (3 m gate, external pen)
|
field.wbt — main world (3 m gate, external pen)
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Backwards-compat shim — flocking logic now lives in ``herding.flocking_sim``.
|
"""Backwards-compat shim — flocking logic now lives in ``herding.world.flocking_sim``.
|
||||||
|
|
||||||
Kept so any external reference still resolves.
|
Kept so any external reference still resolves.
|
||||||
"""
|
"""
|
||||||
|
|||||||
+19
-36
@@ -1,14 +1,13 @@
|
|||||||
"""Sheep flocking controller (Webots).
|
"""Sheep flocking controller (Webots).
|
||||||
|
|
||||||
Each sheep broadcasts its GPS position every 3 steps on channel 1 and
|
Each sheep emits its GPS position every 3 steps and listens for the
|
||||||
listens for the dog and peer sheep positions. The behavioural step is
|
dog's position and peer-sheep positions. The behavioural step is
|
||||||
delegated to ``herding.flocking_sim.compute_heading_speed`` so the
|
delegated to :func:`herding.world.flocking_sim.compute_heading_speed`
|
||||||
training environment and Webots run identical sheep dynamics.
|
so the env and Webots use identical sheep dynamics.
|
||||||
|
|
||||||
Pen behaviour: a sheep latches to ``penned`` the first time it crosses
|
A sheep latches penned the first time it crosses the gate plane south;
|
||||||
the south-wall gate plane into the gate corridor. Once latched it turns
|
the wool turns pink (via the exposed ``woolColor`` PROTO field) and
|
||||||
pink (via the exposed ``woolColor`` PROTO field) and the force stack
|
the dynamics switch to in-pen containment.
|
||||||
switches to in-pen containment.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
@@ -32,10 +31,7 @@ from herding.world.geometry import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# --- Devices ---
|
||||||
# Device setup
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
robot = Supervisor()
|
robot = Supervisor()
|
||||||
timestep = int(robot.getBasicTimeStep())
|
timestep = int(robot.getBasicTimeStep())
|
||||||
name = robot.getName()
|
name = robot.getName()
|
||||||
@@ -55,14 +51,10 @@ receiver = robot.getDevice("receiver"); receiver.enable(timestep)
|
|||||||
emitter = robot.getDevice("emitter")
|
emitter = robot.getDevice("emitter")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# --- Helpers ---
|
||||||
# Helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def bearing():
|
def bearing():
|
||||||
# Compass returns north direction in sensor frame; for this Z-up world
|
"""World-frame heading (0 = east, π/2 = north)."""
|
||||||
# with north = +Y, atan2(n[0], n[1]) gives the standard math angle
|
|
||||||
# (0 = east, π/2 = north) matching atan2(fy, fx) used for headings.
|
|
||||||
n = compass.getValues()
|
n = compass.getValues()
|
||||||
return math.atan2(n[0], n[1])
|
return math.atan2(n[0], n[1])
|
||||||
|
|
||||||
@@ -76,45 +68,36 @@ def drive(heading, speed_motor):
|
|||||||
|
|
||||||
|
|
||||||
def paint_pink():
|
def paint_pink():
|
||||||
# woolColor is declared as a PROTO field with IS binding to the DEF WOOL
|
"""Switch the sheep's wool to pink via the exposed PROTO field."""
|
||||||
# PBRAppearance baseColor; setting it propagates to every USE WOOL shape.
|
|
||||||
self_node.getField("woolColor").setSFColor([1.0, 0.55, 0.72])
|
self_node.getField("woolColor").setSFColor([1.0, 0.55, 0.72])
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# --- State ---
|
||||||
# State
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
wander_angle = random.uniform(-math.pi, math.pi)
|
wander_angle = random.uniform(-math.pi, math.pi)
|
||||||
step_count = 0
|
step_count = 0
|
||||||
dog_x, dog_y = None, None
|
dog_x, dog_y = None, None
|
||||||
peers = {} # name → (x, y), one entry per neighbour, cleared every 30 steps
|
peers = {} # name → (x, y); periodically pruned
|
||||||
penned = False
|
penned = False
|
||||||
|
|
||||||
# Stuck detection: differential-drive sheep can pin against a wall and need
|
# Safety net for differential-drive sheep pinned against a wall.
|
||||||
# a forced reverse-and-rotate to escape. If displacement < STUCK_DIST for
|
|
||||||
# STUCK_STEPS consecutive steps, drive toward field centre.
|
|
||||||
_prev_x, _prev_y = None, None
|
_prev_x, _prev_y = None, None
|
||||||
_stuck_count = 0
|
_stuck_count = 0
|
||||||
STUCK_STEPS = 20
|
STUCK_STEPS = 20
|
||||||
STUCK_DIST = 0.05
|
STUCK_DIST = 0.05
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# --- Main loop ---
|
||||||
# Main loop
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
while robot.step(timestep) != -1:
|
while robot.step(timestep) != -1:
|
||||||
step_count += 1
|
step_count += 1
|
||||||
pos = gps.getValues()
|
pos = gps.getValues()
|
||||||
x, y = pos[0], pos[1]
|
x, y = pos[0], pos[1]
|
||||||
|
|
||||||
# Pen entry: one-way latch. Penned sheep get pink wool and switch behaviour.
|
|
||||||
if not penned and is_penned_position(x, y):
|
if not penned and is_penned_position(x, y):
|
||||||
penned = True
|
penned = True
|
||||||
paint_pink()
|
paint_pink()
|
||||||
|
|
||||||
# Refresh peer table — clear before receiving so fresh data is never lost.
|
# Stale peers get dropped periodically so a peer that's gone silent
|
||||||
|
# doesn't permanently distort the local CoM.
|
||||||
if step_count % 30 == 0:
|
if step_count % 30 == 0:
|
||||||
peers.clear()
|
peers.clear()
|
||||||
while receiver.getQueueLength() > 0:
|
while receiver.getQueueLength() > 0:
|
||||||
@@ -132,12 +115,12 @@ while robot.step(timestep) != -1:
|
|||||||
wander_angle=wander_angle,
|
wander_angle=wander_angle,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Stuck detection — safety net for differential-drive wall pinning.
|
# Stuck-against-wall recovery: drive toward the field centre.
|
||||||
if _prev_x is not None:
|
if _prev_x is not None:
|
||||||
moved = math.hypot(x - _prev_x, y - _prev_y)
|
moved = math.hypot(x - _prev_x, y - _prev_y)
|
||||||
_stuck_count = _stuck_count + 1 if moved < STUCK_DIST else 0
|
_stuck_count = _stuck_count + 1 if moved < STUCK_DIST else 0
|
||||||
if _stuck_count >= STUCK_STEPS:
|
if _stuck_count >= STUCK_STEPS:
|
||||||
heading = math.atan2(-y, -x) # always points away from the boundary
|
heading = math.atan2(-y, -x)
|
||||||
speed = MAX_SPEED
|
speed = MAX_SPEED
|
||||||
_stuck_count = 0
|
_stuck_count = 0
|
||||||
_prev_x, _prev_y = x, y
|
_prev_x, _prev_y = x, y
|
||||||
|
|||||||
@@ -1,18 +1,13 @@
|
|||||||
"""Lazy loader for the SB3 PPO policy used by the dog controller.
|
"""Lazy SB3 policy loader for the dog controller.
|
||||||
|
|
||||||
Importing stable-baselines3 inside the Webots Python interpreter is only
|
SB3 is imported only when a learned policy is actually requested,
|
||||||
needed when ``HERDING_MODE=rl``; the Strömbom mode runs without it. This
|
so the analytic modes can run on installs without stable-baselines3
|
||||||
loader keeps SB3 out of the import path until you actually ask for the RL
|
or torch.
|
||||||
policy, so users without SB3 installed can still run the Strömbom
|
|
||||||
baseline.
|
|
||||||
|
|
||||||
The policy + VecNormalize statistics are saved together by
|
The handle auto-detects frame stacking from the policy's expected
|
||||||
``training/train_ppo.py``:
|
observation dimension: if it's a multiple of the single-frame
|
||||||
|
``OBS_DIM``, an internal buffer of the last K frames is maintained
|
||||||
runs/<name>/best/best_model.zip # SB3 PPO checkpoint
|
and concatenated on each ``predict`` call.
|
||||||
runs/<name>/best/vecnormalize.pkl # observation-normaliser stats
|
|
||||||
|
|
||||||
Pass either the directory or the explicit zip path.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@@ -20,20 +15,12 @@ from pathlib import Path
|
|||||||
|
|
||||||
|
|
||||||
class PolicyHandle:
|
class PolicyHandle:
|
||||||
"""Wrap a loaded PPO policy + VecNormalize so the controller can call
|
"""Wrap a loaded policy (+ optional VecNormalize) for ``predict(obs)``."""
|
||||||
``predict(obs)`` without thinking about either.
|
|
||||||
|
|
||||||
Frame stacking is auto-detected from the policy's expected obs dim:
|
|
||||||
if it's a multiple of the single-frame ``OBS_DIM``, the handle keeps
|
|
||||||
a deque of the last K frames and concatenates them on each predict.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, model, vecnorm):
|
def __init__(self, model, vecnorm):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.vecnorm = vecnorm
|
self.vecnorm = vecnorm
|
||||||
# Lazy import to avoid forcing herding/* into the import path
|
from herding.perception.obs import OBS_DIM
|
||||||
# when SB3 isn't being used.
|
|
||||||
from herding.obs import OBS_DIM
|
|
||||||
policy_dim = int(model.observation_space.shape[0])
|
policy_dim = int(model.observation_space.shape[0])
|
||||||
if policy_dim % OBS_DIM == 0 and policy_dim // OBS_DIM >= 1:
|
if policy_dim % OBS_DIM == 0 and policy_dim // OBS_DIM >= 1:
|
||||||
self.frame_stack = policy_dim // OBS_DIM
|
self.frame_stack = policy_dim // OBS_DIM
|
||||||
@@ -46,7 +33,7 @@ class PolicyHandle:
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
single = np.asarray(obs, dtype=np.float32).reshape(-1)
|
single = np.asarray(obs, dtype=np.float32).reshape(-1)
|
||||||
if single.shape[0] != self._single_dim:
|
if single.shape[0] != self._single_dim:
|
||||||
# Caller already passed a stacked obs — use as-is.
|
# Caller passed an already-stacked obs.
|
||||||
stacked = single
|
stacked = single
|
||||||
elif self.frame_stack > 1:
|
elif self.frame_stack > 1:
|
||||||
if not self._buffer:
|
if not self._buffer:
|
||||||
@@ -67,18 +54,19 @@ class PolicyHandle:
|
|||||||
|
|
||||||
|
|
||||||
def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
|
def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
|
||||||
"""Load a PPO model (and optional VecNormalize) from disk.
|
"""Load a policy zip (+ optional VecNormalize pickle) from disk.
|
||||||
|
|
||||||
``model_path`` may be the .zip checkpoint or a directory containing
|
``model_path`` may be a ``.zip`` file or a directory; in the
|
||||||
``best_model.zip`` (and optionally ``vecnormalize.pkl``).
|
latter case ``policy.zip`` is preferred, with ``final.zip`` as
|
||||||
|
a fallback for partially-completed RL runs.
|
||||||
"""
|
"""
|
||||||
p = Path(model_path)
|
p = Path(model_path)
|
||||||
if p.is_dir():
|
if p.is_dir():
|
||||||
zip_candidates = [p / "best_model.zip", p / "final.zip", p / "policy.zip"]
|
zip_candidates = [p / "policy.zip", p / "final.zip"]
|
||||||
zip_path = next((z for z in zip_candidates if z.exists()), None)
|
zip_path = next((z for z in zip_candidates if z.exists()), None)
|
||||||
if zip_path is None:
|
if zip_path is None:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"No PPO zip found in {p} (looked for best_model.zip, final.zip, policy.zip)"
|
f"No policy zip in {p} (looked for policy.zip, final.zip)"
|
||||||
)
|
)
|
||||||
if vecnorm_path is None:
|
if vecnorm_path is None:
|
||||||
vn = p / "vecnormalize.pkl"
|
vn = p / "vecnormalize.pkl"
|
||||||
@@ -87,15 +75,13 @@ def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
|
|||||||
else:
|
else:
|
||||||
zip_path = p
|
zip_path = p
|
||||||
|
|
||||||
# Imports deferred so the Strömbom path doesn't require SB3.
|
# Deferred imports so the analytic path doesn't require SB3.
|
||||||
from stable_baselines3 import PPO
|
from stable_baselines3 import PPO
|
||||||
from stable_baselines3.common.vec_env import VecNormalize
|
from stable_baselines3.common.vec_env import VecNormalize # noqa: F401
|
||||||
|
|
||||||
model = PPO.load(str(zip_path), device="auto")
|
model = PPO.load(str(zip_path), device="auto")
|
||||||
vecnorm = None
|
vecnorm = None
|
||||||
if vecnorm_path and os.path.exists(vecnorm_path):
|
if vecnorm_path and os.path.exists(vecnorm_path):
|
||||||
# VecNormalize.load needs a venv to attach to; we only need its stats
|
|
||||||
# at inference, so we reconstruct the wrapper manually.
|
|
||||||
import pickle
|
import pickle
|
||||||
with open(vecnorm_path, "rb") as f:
|
with open(vecnorm_path, "rb") as f:
|
||||||
vecnorm = pickle.load(f)
|
vecnorm = pickle.load(f)
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ from herding.control.active_scan import ActiveScanTeacher
|
|||||||
from herding.control.modulation import modulate_speed_near_sheep
|
from herding.control.modulation import modulate_speed_near_sheep
|
||||||
from herding.control.sequential import compute_action as sequential_action
|
from herding.control.sequential import compute_action as sequential_action
|
||||||
from herding.control.strombom import compute_action as strombom_action
|
from herding.control.strombom import compute_action as strombom_action
|
||||||
from herding.obs import build_obs
|
from herding.perception.obs import build_obs
|
||||||
from herding.perception.lidar_perception import detections_from_scan
|
from herding.perception.lidar_perception import detections_from_scan
|
||||||
from herding.perception.sheep_tracker import SheepTracker
|
from herding.perception.sheep_tracker import SheepTracker
|
||||||
from herding.world.diffdrive import velocity_to_wheels
|
from herding.world.diffdrive import velocity_to_wheels
|
||||||
|
|||||||
@@ -1,23 +1,19 @@
|
|||||||
"""Active-perception wrapper for the analytic shepherding teachers.
|
"""Active-perception wrapper for the analytic shepherd teachers.
|
||||||
|
|
||||||
Under LiDAR (partial observability), the tracker starts empty — the
|
Under partial-observability LiDAR perception the tracker starts empty
|
||||||
dog hasn't seen any sheep yet. A naive Strömbom call returns
|
— a naive analytic teacher returns ``(0, 0, "idle")`` and the dog
|
||||||
``(0, 0, "idle")`` and the dog stops. The student then learns "do
|
stops. This wrapper interleaves the underlying teacher with two
|
||||||
nothing when the tracker is empty," which is a fatal local optimum.
|
exploration behaviours:
|
||||||
|
|
||||||
This wrapper replaces the idle case with a **scan action**: a unit
|
* opening in-place rotation for the first ``INITIAL_SCAN_STEPS``,
|
||||||
vector 90° CCW from the dog's current forward direction. Passed
|
guaranteeing the LiDAR sweeps a full circle before driving;
|
||||||
through ``velocity_to_wheels`` it produces a fast in-place rotation
|
* walk-to-centre when the tracker has been empty for at least
|
||||||
(``cos(err)`` clamp drives forward speed to ~0 because the target is
|
``EMPTY_DEBOUNCE_STEPS`` consecutive frames (corners can sit
|
||||||
orthogonal to the heading). The dog spins for the first
|
beyond the 12 m LiDAR range).
|
||||||
``initial_scan_steps`` steps of every episode regardless of tracker
|
|
||||||
state, and re-enters scan whenever the tracker goes empty mid-episode.
|
|
||||||
|
|
||||||
Once enough sheep are tracked, control hands over to the underlying
|
When the tracker has detections the base teacher's action is used,
|
||||||
analytic teacher (Strömbom or Sequential), which now operates on a
|
post-processed by ``modulate_speed_near_sheep`` so the dog doesn't
|
||||||
populated tracker dict. Both teacher and student see the same
|
charge the flock.
|
||||||
LiDAR-perceived view — there's no information asymmetry, so the
|
|
||||||
student can in principle achieve the teacher's full performance.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -27,26 +23,17 @@ import math
|
|||||||
from herding.control.modulation import modulate_speed_near_sheep
|
from herding.control.modulation import modulate_speed_near_sheep
|
||||||
|
|
||||||
|
|
||||||
INITIAL_SCAN_STEPS = 80 # ≈1.3 s at dt=16 ms — full rotation at the +π turn target.
|
INITIAL_SCAN_STEPS = 80 # ≈1.3 s — covers one full rotation
|
||||||
EXPLORE_SPEED = 0.7 # m/s-ish unit (action norm) used when walking blind
|
EXPLORE_SPEED = 0.7 # action norm while walking blind
|
||||||
|
EMPTY_DEBOUNCE_STEPS = 8 # consecutive empty frames before exploring
|
||||||
# Debounce on tracker emptiness — a single empty frame between
|
|
||||||
# detections is not enough reason to abandon the drive and start
|
|
||||||
# scanning. Require this many consecutive empty frames first.
|
|
||||||
EMPTY_DEBOUNCE_STEPS = 8
|
|
||||||
|
|
||||||
|
|
||||||
class ActiveScanTeacher:
|
class ActiveScanTeacher:
|
||||||
"""Stateful wrapper. Construct one per episode; call ``reset()``
|
"""Stateful wrapper. Construct one per episode (or call ``reset``).
|
||||||
between episodes if reusing the instance.
|
|
||||||
|
|
||||||
Call signature::
|
Call signature::
|
||||||
|
|
||||||
vx, vy, mode = teacher(dog_xy, dog_heading, sheep_positions, pen_target)
|
vx, vy, mode = teacher(dog_xy, dog_heading, sheep_positions, pen_target)
|
||||||
|
|
||||||
Note the extra ``dog_heading`` arg — required to compute the
|
|
||||||
rotation direction. The base teachers (Strömbom, Sequential)
|
|
||||||
don't use heading; we strip it before passing them through.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, base_action_fn, initial_scan_steps: int = INITIAL_SCAN_STEPS):
|
def __init__(self, base_action_fn, initial_scan_steps: int = INITIAL_SCAN_STEPS):
|
||||||
@@ -61,27 +48,17 @@ class ActiveScanTeacher:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _scan_action(dog_heading: float) -> tuple[float, float]:
|
def _scan_action(dog_heading: float) -> tuple[float, float]:
|
||||||
# Target = current_heading + π. velocity_to_wheels gets err=π,
|
# Target opposite to current heading; velocity_to_wheels'
|
||||||
# so turn = k_turn·π = 4π ≈ 12.6 rad/s wheel angular vel and
|
# cos(err) clamp drives forward speed to ~0 → in-place rotation.
|
||||||
# cos(err) clamps the forward speed to ~0. Maximum in-place
|
|
||||||
# rotation under this controller; one full rotation in ~60 steps.
|
|
||||||
target = dog_heading + math.pi
|
target = dog_heading + math.pi
|
||||||
return math.cos(target), math.sin(target)
|
return math.cos(target), math.sin(target)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _explore_action(dog_xy) -> tuple[float, float]:
|
def _explore_action(dog_xy) -> tuple[float, float]:
|
||||||
"""Walk back toward the field centre when nothing is in view.
|
"""Walk toward (0, 0) while the LiDAR keeps sweeping."""
|
||||||
|
|
||||||
At difficulty=1 sheep can spawn up to ~18 m from origin while
|
|
||||||
the LiDAR has a 12 m range, so an in-place scan from a corner
|
|
||||||
can return zero hits. Walking toward (0, 0) shrinks the
|
|
||||||
max-distance-to-any-sheep and the scanner cone sweeps along
|
|
||||||
the path, eventually picking sheep up.
|
|
||||||
"""
|
|
||||||
dx, dy = -dog_xy[0], -dog_xy[1]
|
dx, dy = -dog_xy[0], -dog_xy[1]
|
||||||
d = math.hypot(dx, dy)
|
d = math.hypot(dx, dy)
|
||||||
if d < 0.5:
|
if d < 0.5:
|
||||||
# At the centre — fall through to a scan instead.
|
|
||||||
return 0.0, 0.0
|
return 0.0, 0.0
|
||||||
return EXPLORE_SPEED * dx / d, EXPLORE_SPEED * dy / d
|
return EXPLORE_SPEED * dx / d, EXPLORE_SPEED * dy / d
|
||||||
|
|
||||||
@@ -89,22 +66,18 @@ class ActiveScanTeacher:
|
|||||||
self.step += 1
|
self.step += 1
|
||||||
n_visible = len(sheep_positions)
|
n_visible = len(sheep_positions)
|
||||||
|
|
||||||
# Track empty-streak for the explore debounce.
|
|
||||||
if n_visible == 0:
|
if n_visible == 0:
|
||||||
self.empty_streak += 1
|
self.empty_streak += 1
|
||||||
else:
|
else:
|
||||||
self.empty_streak = 0
|
self.empty_streak = 0
|
||||||
|
|
||||||
# Phase 1: opening rotation, regardless of tracker state.
|
# Phase 1: opening rotation.
|
||||||
if self.step <= self.initial_scan:
|
if self.step <= self.initial_scan:
|
||||||
vx, vy = self._scan_action(dog_heading)
|
vx, vy = self._scan_action(dog_heading)
|
||||||
self.last_action = (vx, vy)
|
self.last_action = (vx, vy)
|
||||||
return vx, vy, "scan_initial"
|
return vx, vy, "scan_initial"
|
||||||
|
|
||||||
# Phase 2: tracker has been empty for a while — walk back to the
|
# Phase 2: walk-to-centre after a sustained empty tracker.
|
||||||
# centre while the LiDAR keeps sweeping. The debounce prevents
|
|
||||||
# this from firing every time the tracker briefly blinks to zero
|
|
||||||
# (which causes the "dog starts going away from sheep" symptom).
|
|
||||||
if self.empty_streak >= EMPTY_DEBOUNCE_STEPS:
|
if self.empty_streak >= EMPTY_DEBOUNCE_STEPS:
|
||||||
ex, ey = self._explore_action(dog_xy)
|
ex, ey = self._explore_action(dog_xy)
|
||||||
if ex == 0.0 and ey == 0.0:
|
if ex == 0.0 and ey == 0.0:
|
||||||
@@ -116,16 +89,13 @@ class ActiveScanTeacher:
|
|||||||
self.last_action = (vx, vy)
|
self.last_action = (vx, vy)
|
||||||
return vx, vy, mode
|
return vx, vy, mode
|
||||||
|
|
||||||
# Phase 2b: tracker just blinked empty for <DEBOUNCE frames —
|
# Phase 2b: brief tracker blink — hold the previous action.
|
||||||
# hold the previous action so the dog doesn't lurch.
|
|
||||||
if n_visible == 0:
|
if n_visible == 0:
|
||||||
vx, vy = self.last_action
|
vx, vy = self.last_action
|
||||||
return vx, vy, "hold"
|
return vx, vy, "hold"
|
||||||
|
|
||||||
# Phase 3: hand to the underlying analytic teacher, then apply
|
# Phase 3: hand off to the underlying analytic teacher, then
|
||||||
# the shared near-sheep speed modulation (centralised in
|
# apply the shared near-sheep speed modulation.
|
||||||
# herding.control so the BC student, Strömbom, Sequential and
|
|
||||||
# the DAgger teacher all behave identically near sheep).
|
|
||||||
vx, vy, mode = self.base(dog_xy, sheep_positions, pen_target)
|
vx, vy, mode = self.base(dog_xy, sheep_positions, pen_target)
|
||||||
vx, vy = modulate_speed_near_sheep(vx, vy, dog_xy, sheep_positions)
|
vx, vy = modulate_speed_near_sheep(vx, vy, dog_xy, sheep_positions)
|
||||||
self.last_action = (vx, vy)
|
self.last_action = (vx, vy)
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
"""Shared low-level control helpers used by every dog mode.
|
"""Shared action post-processing.
|
||||||
|
|
||||||
Centralised here so the BC student, Strömbom, Sequential, and the DAgger
|
Every dog mode routes its action through ``modulate_speed_near_sheep``
|
||||||
teacher all apply identical post-processing to their action outputs.
|
so the magnitude is reduced near sheep — direction (intent) is
|
||||||
The downstream wheel-velocity layer (``herding.diffdrive``) is unchanged.
|
preserved.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -10,12 +10,8 @@ from __future__ import annotations
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
|
|
||||||
# Speed-modulation: scale action magnitude down when close to the
|
SLOW_NEAR_SHEEP = 2.5 # m — distance below which action norm is scaled down
|
||||||
# nearest sheep. Stops the dog from charging in at full speed and
|
MIN_SPEED = 0.30 # action norm at zero distance
|
||||||
# scattering the flock. Action norm linearly ramps from MIN_SPEED at
|
|
||||||
# distance 0 to 1.0 at SLOW_NEAR_SHEEP.
|
|
||||||
SLOW_NEAR_SHEEP = 2.5
|
|
||||||
MIN_SPEED = 0.30
|
|
||||||
|
|
||||||
|
|
||||||
def modulate_speed_near_sheep(
|
def modulate_speed_near_sheep(
|
||||||
@@ -25,16 +21,9 @@ def modulate_speed_near_sheep(
|
|||||||
slow_dist: float = SLOW_NEAR_SHEEP,
|
slow_dist: float = SLOW_NEAR_SHEEP,
|
||||||
min_scale: float = MIN_SPEED,
|
min_scale: float = MIN_SPEED,
|
||||||
) -> tuple[float, float]:
|
) -> tuple[float, float]:
|
||||||
"""Scale (vx, vy) magnitude down when close to the nearest sheep.
|
"""Linearly ramp action magnitude from ``min_scale`` at distance 0
|
||||||
|
to 1.0 at ``slow_dist``. ``sheep_positions`` may be a
|
||||||
``sheep_positions`` accepts either a ``{name: (x, y)}`` dict
|
``{name: (x, y)}`` dict or an iterable of ``(x, y)`` tuples.
|
||||||
(matching what the trackers emit) or an iterable of ``(x, y)``
|
|
||||||
tuples. Empty input → action returned unchanged.
|
|
||||||
|
|
||||||
The intent direction is preserved; only magnitude is reduced. With
|
|
||||||
``slow_dist=2.5`` and ``min_scale=0.3``, an action that started at
|
|
||||||
norm 1 is multiplied by 0.3 right next to a sheep, by 0.58 at 1 m
|
|
||||||
away, and by 1.0 once the nearest sheep is ≥ 2.5 m off.
|
|
||||||
"""
|
"""
|
||||||
if not sheep_positions:
|
if not sheep_positions:
|
||||||
return vx, vy
|
return vx, vy
|
||||||
|
|||||||
@@ -1,25 +1,9 @@
|
|||||||
"""Sequential single-target shepherd dog algorithm.
|
"""Sequential "pin-and-push" shepherd-dog controller.
|
||||||
|
|
||||||
Strömbom drives the flock's centre of mass; with N sheep and a narrow
|
Single-target alternative to Strömbom: each step, target the sheep
|
||||||
3 m gate, this fails because the flock is wider than the gate and CoM
|
closest to the pen, park behind it, drive it through; once it latches
|
||||||
driving abandons stragglers. Real sheepdogs solve this differently:
|
penned the next-closest sheep becomes the target. Naturally queues
|
||||||
they pick *one* sheep at a time, drive it through, return for the next.
|
the flock through a narrow gate.
|
||||||
|
|
||||||
This module implements that "pin-and-push" approach.
|
|
||||||
|
|
||||||
Algorithm (one step):
|
|
||||||
1. Active sheep = those still in the field (not yet penned).
|
|
||||||
2. Target = the active sheep currently closest to the pen entry.
|
|
||||||
3. Drive position = ``target + Δ · unit(target − pen_entry)`` —
|
|
||||||
directly behind the target relative to the goal.
|
|
||||||
4. Output unit vector pointing the dog at the drive position.
|
|
||||||
|
|
||||||
Once the target crosses the gate it latches as penned and is removed
|
|
||||||
from the active set; the next-closest unpenned sheep becomes the
|
|
||||||
target. The algorithm naturally "queues" sheep through the gate.
|
|
||||||
|
|
||||||
Empirically (with our flocking dynamics) this scales linearly with
|
|
||||||
flock size and works up to at least n=10 within a 15 000-step budget.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
@@ -43,25 +27,17 @@ def _is_active(x, y) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def compute_action(dog_xy, sheep_positions, pen_target=PEN_ENTRY):
|
def compute_action(dog_xy, sheep_positions, pen_target=PEN_ENTRY):
|
||||||
"""Return ``(vx, vy, mode)`` where mode encodes the current target.
|
"""Return ``(vx, vy, mode)`` — same call signature as Strömbom."""
|
||||||
|
|
||||||
Compatible with the Strömbom call signature so it can be drop-in
|
|
||||||
swapped in the dog controller and the env's imitation reward.
|
|
||||||
"""
|
|
||||||
active = [(name, x, y) for name, (x, y) in sheep_positions.items()
|
active = [(name, x, y) for name, (x, y) in sheep_positions.items()
|
||||||
if _is_active(x, y)]
|
if _is_active(x, y)]
|
||||||
if not active:
|
if not active:
|
||||||
return 0.0, 0.0, "idle"
|
return 0.0, 0.0, "idle"
|
||||||
|
|
||||||
# Pick target = sheep closest to pen entry. Stable choice: as one
|
|
||||||
# sheep approaches and crosses the gate it stays the target until
|
|
||||||
# latched; then the next-closest takes over.
|
|
||||||
name, sx, sy = min(
|
name, sx, sy = min(
|
||||||
active,
|
active,
|
||||||
key=lambda s: math.hypot(s[1] - pen_target[0], s[2] - pen_target[1]),
|
key=lambda s: math.hypot(s[1] - pen_target[0], s[2] - pen_target[1]),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Drive position behind the target along the (target → pen) line.
|
|
||||||
ux, uy = _unit(sx - pen_target[0], sy - pen_target[1])
|
ux, uy = _unit(sx - pen_target[0], sy - pen_target[1])
|
||||||
tx = sx + DELTA_DRIVE * ux
|
tx = sx + DELTA_DRIVE * ux
|
||||||
ty = sy + DELTA_DRIVE * uy
|
ty = sy + DELTA_DRIVE * uy
|
||||||
@@ -71,7 +47,7 @@ def compute_action(dog_xy, sheep_positions, pen_target=PEN_ENTRY):
|
|||||||
|
|
||||||
|
|
||||||
def compute_action_debug(dog_xy, sheep_positions, pen_target=PEN_ENTRY):
|
def compute_action_debug(dog_xy, sheep_positions, pen_target=PEN_ENTRY):
|
||||||
"""Debug variant returning ``(vx, vy, mode, debug_dict)``."""
|
"""``compute_action`` plus a debug dict (target, drive point)."""
|
||||||
active = [(name, x, y) for name, (x, y) in sheep_positions.items()
|
active = [(name, x, y) for name, (x, y) in sheep_positions.items()
|
||||||
if _is_active(x, y)]
|
if _is_active(x, y)]
|
||||||
if not active:
|
if not active:
|
||||||
|
|||||||
+14
-33
@@ -1,30 +1,20 @@
|
|||||||
"""Strömbom collect/drive heuristic for the shepherd dog.
|
"""Strömbom (2014) collect/drive heuristic for the shepherd dog.
|
||||||
|
|
||||||
Adapted from the original ``controllers/shepherd_dog/strombom.py`` and
|
When the flock is scattered (max radius > F_FACTOR · √n) the dog moves
|
||||||
updated for the external pen layout. Used as a baseline controller and
|
to a point behind the furthest sheep and pushes it back toward the
|
||||||
as the fallback when the RL policy isn't available.
|
flock CoM. Otherwise it drives, parking behind the CoM relative to
|
||||||
|
the pen target. Returns a unit-vector intent ``(vx, vy, mode)``.
|
||||||
|
|
||||||
Reference: Strömbom et al. 2014, "Solving the shepherding problem".
|
Reference: Strömbom et al. 2014, "Solving the shepherding problem."
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
from herding.world.geometry import PEN_ENTRY, GATE_Y, in_pen
|
from herding.world.geometry import PEN_ENTRY, GATE_Y, in_pen
|
||||||
|
|
||||||
# Algorithm parameters. DELTA_DRIVE / DELTA_COLLECT were tightened from
|
F_FACTOR = 4.0 # collect/drive threshold scaled by √n
|
||||||
# the original (4.0 / 2.5) because the new external pen sits ~26 m from
|
DELTA_COLLECT = 1.5 # drive-position offset behind the furthest sheep
|
||||||
# typical sheep spawn locations — at the old 4 m standoff, the flee force
|
DELTA_DRIVE = 2.0 # drive-position offset behind the flock CoM
|
||||||
# (quadratic ramp, 3.7 at 4 m vs ~10 at 2 m) couldn't move sheep through
|
|
||||||
# the path inside the 3000-step episode budget.
|
|
||||||
#
|
|
||||||
# F_FACTOR was 2.0 in the original Strömbom paper; raised to 4.0 here so
|
|
||||||
# the dog stays in *drive* mode much longer. With our tighter cohesion
|
|
||||||
# (flocking_sim.py), partially-collected flocks consolidate naturally
|
|
||||||
# during a drive, and we don't waste 80% of the time budget on a slow
|
|
||||||
# "collect" pre-phase.
|
|
||||||
F_FACTOR = 4.0
|
|
||||||
DELTA_COLLECT = 1.5
|
|
||||||
DELTA_DRIVE = 2.0
|
|
||||||
|
|
||||||
|
|
||||||
def _unit(x, y):
|
def _unit(x, y):
|
||||||
@@ -35,18 +25,12 @@ def _unit(x, y):
|
|||||||
|
|
||||||
|
|
||||||
def _is_active(x, y) -> bool:
|
def _is_active(x, y) -> bool:
|
||||||
"""A sheep is "active" if it's still in the field — not in or below
|
"""A sheep still in the field counts; one south of the gate doesn't."""
|
||||||
the gate plane (we treat anything south of the gate as committed to
|
|
||||||
the pen and stop trying to herd it)."""
|
|
||||||
return (not in_pen(x, y)) and y > GATE_Y
|
return (not in_pen(x, y)) and y > GATE_Y
|
||||||
|
|
||||||
|
|
||||||
def compute_action(dog_xy, sheep_positions, pen_target=PEN_ENTRY):
|
def compute_action(dog_xy, sheep_positions, pen_target=PEN_ENTRY):
|
||||||
"""Return ``(vx, vy, mode)`` — mode in {idle, collect, drive}.
|
"""Return ``(vx, vy, mode)`` — mode in {idle, collect, drive}."""
|
||||||
|
|
||||||
``sheep_positions`` is a ``{name: (x, y)}`` mapping (matches the
|
|
||||||
Webots controller's representation).
|
|
||||||
"""
|
|
||||||
active = [(x, y) for (x, y) in sheep_positions.values() if _is_active(x, y)]
|
active = [(x, y) for (x, y) in sheep_positions.values() if _is_active(x, y)]
|
||||||
if not active:
|
if not active:
|
||||||
return 0.0, 0.0, "idle"
|
return 0.0, 0.0, "idle"
|
||||||
@@ -58,14 +42,14 @@ def compute_action(dog_xy, sheep_positions, pen_target=PEN_ENTRY):
|
|||||||
radius = max(dists)
|
radius = max(dists)
|
||||||
|
|
||||||
if radius > F_FACTOR * math.sqrt(n):
|
if radius > F_FACTOR * math.sqrt(n):
|
||||||
# Collect: aim at a point behind the furthest sheep, opposite the CoM.
|
# Collect: aim behind the furthest sheep, opposite the CoM.
|
||||||
idx = max(range(n), key=lambda i: dists[i])
|
idx = max(range(n), key=lambda i: dists[i])
|
||||||
sx, sy = active[idx]
|
sx, sy = active[idx]
|
||||||
ux, uy = _unit(sx - com_x, sy - com_y)
|
ux, uy = _unit(sx - com_x, sy - com_y)
|
||||||
tx, ty = sx + DELTA_COLLECT * ux, sy + DELTA_COLLECT * uy
|
tx, ty = sx + DELTA_COLLECT * ux, sy + DELTA_COLLECT * uy
|
||||||
mode = "collect"
|
mode = "collect"
|
||||||
else:
|
else:
|
||||||
# Drive: aim at a point behind the flock CoM relative to the goal.
|
# Drive: aim behind the CoM, opposite the pen.
|
||||||
ux, uy = _unit(com_x - pen_target[0], com_y - pen_target[1])
|
ux, uy = _unit(com_x - pen_target[0], com_y - pen_target[1])
|
||||||
tx, ty = com_x + DELTA_DRIVE * ux, com_y + DELTA_DRIVE * uy
|
tx, ty = com_x + DELTA_DRIVE * ux, com_y + DELTA_DRIVE * uy
|
||||||
mode = "drive"
|
mode = "drive"
|
||||||
@@ -75,10 +59,7 @@ def compute_action(dog_xy, sheep_positions, pen_target=PEN_ENTRY):
|
|||||||
|
|
||||||
|
|
||||||
def compute_action_debug(dog_xy, sheep_positions, pen_target=PEN_ENTRY):
|
def compute_action_debug(dog_xy, sheep_positions, pen_target=PEN_ENTRY):
|
||||||
"""Variant of compute_action that also returns a small debug dict.
|
"""``compute_action`` plus a small debug dict (CoM, target, radius)."""
|
||||||
|
|
||||||
Kept for parity with the legacy controller's CSV logger.
|
|
||||||
"""
|
|
||||||
active = [(x, y) for (x, y) in sheep_positions.values() if _is_active(x, y)]
|
active = [(x, y) for (x, y) in sheep_positions.values() if _is_active(x, y)]
|
||||||
if not active:
|
if not active:
|
||||||
return 0.0, 0.0, "idle", {
|
return 0.0, 0.0, "idle", {
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
"""Cluster a 2D LiDAR scan into world-frame sheep position estimates.
|
"""Cluster a 2D LiDAR scan into world-frame sheep position estimates.
|
||||||
|
|
||||||
Pipeline:
|
Pipeline:
|
||||||
ranges (N,) ─► hit mask ─► world-frame points
|
|
||||||
|
ranges (N,) → hit mask → world-frame points
|
||||||
│
|
│
|
||||||
▼
|
▼
|
||||||
adjacency clustering (gap > GAP_THRESHOLD
|
adjacency clustering (gap > GAP_THRESHOLD
|
||||||
@@ -9,18 +10,12 @@ Pipeline:
|
|||||||
angular order)
|
angular order)
|
||||||
│
|
│
|
||||||
▼
|
▼
|
||||||
centroid + span filter
|
centroid + span + region + structure filters
|
||||||
│
|
|
||||||
▼
|
|
||||||
field/pen-corridor filter
|
|
||||||
│
|
│
|
||||||
▼
|
▼
|
||||||
list of (x, y) detections
|
list of (x, y) detections
|
||||||
|
|
||||||
The clusterer is intentionally simple — for ≤10 sheep there is rarely
|
The downstream tracker handles association across frames.
|
||||||
any real ambiguity, and proper DBSCAN would only matter if rays from
|
|
||||||
two adjacent sheep merged. The downstream tracker handles association
|
|
||||||
across frames.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -35,23 +30,19 @@ from herding.perception.lidar_sim import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
GAP_THRESHOLD = 0.6 # m — adjacent ray-points farther apart start new cluster
|
GAP_THRESHOLD = 0.6 # m — adjacent ray-points farther apart start a new cluster
|
||||||
MAX_CLUSTER_SPAN = 1.5 # m — clusters wider than this are likely walls/structures
|
MAX_CLUSTER_SPAN = 1.5 # m — wider clusters are walls / structures
|
||||||
RANGE_HIT_EPS = 0.05 # m — hit if range < max_range - eps
|
RANGE_HIT_EPS = 0.05 # m — hit if range < max_range - eps
|
||||||
WALL_REJECT = 0.5 # m — drop detections this close to a known wall line
|
WALL_REJECT = 0.5 # m — drop detections this close to a known wall line
|
||||||
|
|
||||||
# Known sheep-sized static features. Detections within STATIC_REJECT
|
# Sheep-sized static features (gate posts, corner pillars). A cluster
|
||||||
# of any of these are discarded — these aren't sheep. Mid-pillars on
|
# centred within STATIC_REJECT of any of these is never a sheep.
|
||||||
# the field walls are NOT in this list because they're embedded in the
|
|
||||||
# wall (the wall's span filter handles them); listing them here would
|
|
||||||
# only reject real sheep that happened to be near the wall.
|
|
||||||
_STATIC_FEATURES = (
|
_STATIC_FEATURES = (
|
||||||
# Gate posts (sheep-sized boxes flanking the south-wall opening)
|
( 10.0, -15.0), ( 13.0, -15.0), # gate posts
|
||||||
( 10.0, -15.0), ( 13.0, -15.0),
|
( 15.0, 15.0), ( 15.0, -15.0),
|
||||||
# Field corner pillars
|
(-15.0, 15.0), (-15.0, -15.0), # field corners
|
||||||
( 15.0, 15.0), ( 15.0, -15.0), (-15.0, 15.0), (-15.0, -15.0),
|
|
||||||
)
|
)
|
||||||
STATIC_REJECT = 0.8 # m — detection within this of a static feature → drop
|
STATIC_REJECT = 0.8
|
||||||
|
|
||||||
|
|
||||||
def detections_from_scan(
|
def detections_from_scan(
|
||||||
@@ -71,6 +62,8 @@ def detections_from_scan(
|
|||||||
px = dog_x + ranges * np.cos(world_a)
|
px = dog_x + ranges * np.cos(world_a)
|
||||||
py = dog_y + ranges * np.sin(world_a)
|
py = dog_y + ranges * np.sin(world_a)
|
||||||
|
|
||||||
|
# Walk rays in angular order; a large jump between consecutive
|
||||||
|
# world-frame hit points closes the current cluster.
|
||||||
clusters: list[list[tuple[float, float]]] = []
|
clusters: list[list[tuple[float, float]]] = []
|
||||||
current: list[tuple[float, float]] = []
|
current: list[tuple[float, float]] = []
|
||||||
prev: tuple[float, float] | None = None
|
prev: tuple[float, float] | None = None
|
||||||
@@ -98,41 +91,30 @@ def detections_from_scan(
|
|||||||
span = math.hypot(max(xs) - min(xs), max(ys) - min(ys))
|
span = math.hypot(max(xs) - min(xs), max(ys) - min(ys))
|
||||||
if span > MAX_CLUSTER_SPAN:
|
if span > MAX_CLUSTER_SPAN:
|
||||||
continue
|
continue
|
||||||
# Surface-to-centre correction: rays hit the front of the sheep,
|
# Rays hit the front edge of the sheep; offset outward by
|
||||||
# so the cluster centroid is biased toward the dog by SHEEP_RADIUS.
|
# SHEEP_RADIUS along the dog→cluster direction to estimate the
|
||||||
# Push it outward along the dog→cluster direction.
|
# centre.
|
||||||
dx, dy = cx - dog_x, cy - dog_y
|
dx, dy = cx - dog_x, cy - dog_y
|
||||||
d = math.hypot(dx, dy)
|
d = math.hypot(dx, dy)
|
||||||
if d > 1e-3:
|
if d > 1e-3:
|
||||||
cx += SHEEP_RADIUS * dx / d
|
cx += SHEEP_RADIUS * dx / d
|
||||||
cy += SHEEP_RADIUS * dy / d
|
cy += SHEEP_RADIUS * dy / d
|
||||||
# Keep detections inside the field OR in the gate corridor /
|
# Region filter: in-field clusters, plus a narrow strip south of
|
||||||
# external pen — penned sheep are still worth tracking so the
|
# the gate so sheep mid-crossing get latched penned. Detections
|
||||||
# tracker can latch them as "penned" rather than spawn fresh
|
# deeper into the pen are dropped — pen posts and rails would
|
||||||
# tracks each scan.
|
# otherwise generate phantom penned tracks.
|
||||||
# Accept detections inside the field, plus a narrow strip
|
|
||||||
# immediately south of the gate to catch sheep mid-crossing
|
|
||||||
# (so they get marked penned via is_penned_position before the
|
|
||||||
# track goes stale). Detections deeper into the pen are
|
|
||||||
# dropped entirely — Webots's pen posts and rails would
|
|
||||||
# otherwise produce a torrent of phantom penned tracks that
|
|
||||||
# the tracker can't keep up with.
|
|
||||||
in_main = (FIELD_X[0] - 0.2 < cx < FIELD_X[1] + 0.2 and
|
in_main = (FIELD_X[0] - 0.2 < cx < FIELD_X[1] + 0.2 and
|
||||||
FIELD_Y[0] - 0.2 < cy < FIELD_Y[1] + 0.2)
|
FIELD_Y[0] - 0.2 < cy < FIELD_Y[1] + 0.2)
|
||||||
in_gate_strip = (PEN_X[0] - 0.2 < cx < PEN_X[1] + 0.2 and
|
in_gate_strip = (PEN_X[0] - 0.2 < cx < PEN_X[1] + 0.2 and
|
||||||
GATE_Y - 1.0 < cy < GATE_Y + 0.2)
|
GATE_Y - 1.0 < cy < GATE_Y + 0.2)
|
||||||
if not (in_main or in_gate_strip):
|
if not (in_main or in_gate_strip):
|
||||||
continue
|
continue
|
||||||
# Known-static-feature filter: gate posts and corner pillars
|
# Known sheep-sized static features.
|
||||||
# show up as sheep-sized clusters but are never sheep.
|
|
||||||
if any(math.hypot(cx - fx, cy - fy) < STATIC_REJECT
|
if any(math.hypot(cx - fx, cy - fy) < STATIC_REJECT
|
||||||
for fx, fy in _STATIC_FEATURES):
|
for fx, fy in _STATIC_FEATURES):
|
||||||
continue
|
continue
|
||||||
# Wall-proximity filter: at oblique scan angles, walls produce
|
# Wall-proximity filter — sheep can't get this close to a wall,
|
||||||
# multiple short clusters because adjacent ray returns are
|
# so detections right at the wall line are structure noise.
|
||||||
# spaced just above GAP_THRESHOLD. Sheep can't get within ~0.3 m
|
|
||||||
# of a wall (the env clips them to FIELD_INSIDE), so anything
|
|
||||||
# right at the wall line is structure noise.
|
|
||||||
near_field_wall = (
|
near_field_wall = (
|
||||||
cx > FIELD_X[1] - WALL_REJECT or cx < FIELD_X[0] + WALL_REJECT or
|
cx > FIELD_X[1] - WALL_REJECT or cx < FIELD_X[0] + WALL_REJECT or
|
||||||
cy > FIELD_Y[1] - WALL_REJECT or
|
cy > FIELD_Y[1] - WALL_REJECT or
|
||||||
|
|||||||
@@ -1,16 +1,12 @@
|
|||||||
"""Fast 2D LiDAR simulator for the Gymnasium env.
|
"""Fast 2D LiDAR simulator for the Gymnasium env.
|
||||||
|
|
||||||
Raycasts against:
|
Raycasts against sheep (discs) and static world geometry (axis-aligned
|
||||||
* **Sheep** — discs of radius ``SHEEP_RADIUS``.
|
walls + gate posts) so the env reproduces the false-positive cluster
|
||||||
* **Static world geometry** — axis-aligned wall segments and gate
|
distribution Webots produces from real 3D geometry.
|
||||||
posts taken from ``worlds/field.wbt``. Without these, demos
|
|
||||||
collected in-env would never include the false-positive clusters
|
|
||||||
Webots produces from the stone walls and gate-post boxes, and the
|
|
||||||
BC student trained on those demos collapses on deployment.
|
|
||||||
|
|
||||||
Returns a range array matching the Webots Lidar device on the dog
|
Returns a range array matching the Webots Lidar device:
|
||||||
(see ``protos/ShepherdDog.proto``: 180 rays, 140° FOV centred on
|
180 rays, 140° FOV centred on forward, 12 m max range, 5 mm noise.
|
||||||
forward, 12 m max range, 5 mm noise).
|
See ``protos/ShepherdDog.proto``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -26,19 +22,13 @@ LIDAR_FOV = 2.44 # rad ≈ 140°
|
|||||||
LIDAR_MAX_RANGE = 12.0
|
LIDAR_MAX_RANGE = 12.0
|
||||||
LIDAR_NOISE = 0.005 # m, gaussian std
|
LIDAR_NOISE = 0.005 # m, gaussian std
|
||||||
|
|
||||||
# Sheep modelled as a vertical cylinder; this is the horizontal-section
|
# Sheep radius in the LiDAR plane (horizontal section of a vertical-cylinder approx).
|
||||||
# radius the LiDAR plane intersects. Tuned to the proto sheep (~0.45 m
|
|
||||||
# body length). The exact value is not load-bearing — the perception
|
|
||||||
# clusterer is range-tolerant.
|
|
||||||
SHEEP_RADIUS = 0.30
|
SHEEP_RADIUS = 0.30
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# --- Static world geometry — mirrors worlds/field.wbt ---
|
||||||
# Static world geometry — must match worlds/field.wbt
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# Vertical walls: (x, y_min, y_max). Field east/west walls and the two
|
# Vertical walls: (x, y_min, y_max).
|
||||||
# pen side walls are visible through the open gate.
|
|
||||||
_VERTICAL_WALLS = (
|
_VERTICAL_WALLS = (
|
||||||
( 15.0, -15.0, 15.0), # field east
|
( 15.0, -15.0, 15.0), # field east
|
||||||
(-15.0, -15.0, 15.0), # field west
|
(-15.0, -15.0, 15.0), # field west
|
||||||
@@ -46,8 +36,7 @@ _VERTICAL_WALLS = (
|
|||||||
( 13.0, -22.0, -15.0), # pen east
|
( 13.0, -22.0, -15.0), # pen east
|
||||||
)
|
)
|
||||||
|
|
||||||
# Horizontal walls: (y, x_min, x_max). South wall is split by the 3 m
|
# Horizontal walls: (y, x_min, x_max). South wall has a 3 m gap at the gate.
|
||||||
# gate at x ∈ [10, 13]; the pen south wall closes the back of the pen.
|
|
||||||
_HORIZONTAL_WALLS = (
|
_HORIZONTAL_WALLS = (
|
||||||
( 15.0, -15.0, 15.0), # field north
|
( 15.0, -15.0, 15.0), # field north
|
||||||
(-15.0, -15.0, 10.0), # field south-west of gate
|
(-15.0, -15.0, 10.0), # field south-west of gate
|
||||||
@@ -55,31 +44,23 @@ _HORIZONTAL_WALLS = (
|
|||||||
(-22.0, 10.0, 13.0), # pen south
|
(-22.0, 10.0, 13.0), # pen south
|
||||||
)
|
)
|
||||||
|
|
||||||
# Gate posts and field corner pillars treated as vertical cylinders at
|
# Gate posts + field corner pillars, treated as discs at LiDAR height.
|
||||||
# LiDAR height. Radius 0.25 m comes from the 0.44 × 0.44 m boxes in the
|
|
||||||
# wbt — close enough to a circular cross-section for this purpose.
|
|
||||||
_POSTS_XY = np.array([
|
_POSTS_XY = np.array([
|
||||||
( 10.0, -15.0), # west gate post
|
( 10.0, -15.0), ( 13.0, -15.0),
|
||||||
( 13.0, -15.0), # east gate post
|
( 15.0, 15.0), ( 15.0, -15.0),
|
||||||
( 15.0, 15.0), # NE field corner
|
(-15.0, 15.0), (-15.0, -15.0),
|
||||||
( 15.0, -15.0), # SE field corner
|
|
||||||
(-15.0, 15.0), # NW field corner
|
|
||||||
(-15.0, -15.0), # SW field corner
|
|
||||||
], dtype=np.float64)
|
], dtype=np.float64)
|
||||||
POST_RADIUS = 0.25
|
POST_RADIUS = 0.25
|
||||||
|
|
||||||
|
|
||||||
def ray_angles(n: int = LIDAR_N_RAYS, fov: float = LIDAR_FOV) -> np.ndarray:
|
def ray_angles(n: int = LIDAR_N_RAYS, fov: float = LIDAR_FOV) -> np.ndarray:
|
||||||
"""Local-frame ray angles, sweeping from +fov/2 to -fov/2.
|
"""Local-frame ray angles, CCW from forward, sweeping +fov/2 → -fov/2.
|
||||||
|
|
||||||
Convention: angle is measured CCW from the dog's forward axis. Ray 0
|
Matches Webots' default Lidar sweep direction.
|
||||||
points to the dog's left, last ray to the right. Webots' default
|
|
||||||
Lidar sweep matches this.
|
|
||||||
"""
|
"""
|
||||||
return np.linspace(fov / 2.0, -fov / 2.0, n, dtype=np.float64)
|
return np.linspace(fov / 2.0, -fov / 2.0, n, dtype=np.float64)
|
||||||
|
|
||||||
|
|
||||||
# Cached so we don't rebuild every step.
|
|
||||||
_ANGLES = ray_angles()
|
_ANGLES = ray_angles()
|
||||||
_COS = np.cos(_ANGLES)
|
_COS = np.cos(_ANGLES)
|
||||||
_SIN = np.sin(_ANGLES)
|
_SIN = np.sin(_ANGLES)
|
||||||
@@ -88,13 +69,7 @@ _SIN = np.sin(_ANGLES)
|
|||||||
def _raycast_static(
|
def _raycast_static(
|
||||||
ox: float, oy: float, cos_w: np.ndarray, sin_w: np.ndarray,
|
ox: float, oy: float, cos_w: np.ndarray, sin_w: np.ndarray,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Per-ray distance to nearest wall or post hit (∞ if none).
|
"""Per-ray distance to the nearest wall or post hit (∞ if none)."""
|
||||||
|
|
||||||
Walls are axis-aligned line segments; for each ray we compute t at
|
|
||||||
which it crosses the wall's constant-coord plane and check the
|
|
||||||
other coord lies in the segment. Posts are circles; same disc
|
|
||||||
intersection as for sheep.
|
|
||||||
"""
|
|
||||||
n_rays = cos_w.shape[0]
|
n_rays = cos_w.shape[0]
|
||||||
best = np.full(n_rays, np.inf, dtype=np.float64)
|
best = np.full(n_rays, np.inf, dtype=np.float64)
|
||||||
|
|
||||||
@@ -144,10 +119,7 @@ def simulate_scan(
|
|||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Return a (N,) float32 range array. No-hit entries equal ``max_range``.
|
"""Return a (N,) float32 range array. No-hit entries equal ``max_range``.
|
||||||
|
|
||||||
``sheep_xy`` is the list of (x, y) world positions of every sheep in
|
``sheep_xy`` is every sheep (penned or active) in the scene.
|
||||||
the scene (penned and active). Static world geometry (walls and
|
|
||||||
posts) is also raycast so demos contain the same false-positive
|
|
||||||
clusters Webots produces.
|
|
||||||
"""
|
"""
|
||||||
n_rays = _ANGLES.shape[0]
|
n_rays = _ANGLES.shape[0]
|
||||||
|
|
||||||
@@ -172,8 +144,7 @@ def simulate_scan(
|
|||||||
nearest = candidate.min(axis=0)
|
nearest = candidate.min(axis=0)
|
||||||
np.minimum(best, nearest, out=best)
|
np.minimum(best, nearest, out=best)
|
||||||
|
|
||||||
# Clip to LIDAR_MAX_RANGE; entries that never got a hit stay at inf
|
# Entries with no hit stay at inf → clipped to max_range, matching Webots.
|
||||||
# → clipped down to max_range like the real Webots device.
|
|
||||||
ranges = np.minimum(best, max_range).astype(np.float32)
|
ranges = np.minimum(best, max_range).astype(np.float32)
|
||||||
return _add_noise(ranges, noise, rng, max_range)
|
return _add_noise(ranges, noise, rng, max_range)
|
||||||
|
|
||||||
|
|||||||
@@ -1,31 +1,25 @@
|
|||||||
"""Observation builder for the shepherd dog policy.
|
"""Observation builder for the shepherd-dog policy.
|
||||||
|
|
||||||
Order-invariant 32-D feature vector — the policy generalises across
|
Order-invariant 32-D feature vector. Sheep never appear by index in
|
||||||
flock sizes 1..MAX_SHEEP because individual sheep coordinates never
|
the observation, only via summary statistics, a polar histogram, and
|
||||||
appear in the observation by index, only summary statistics, a polar
|
two "named" channels (closest-to-pen, rearmost-from-pen) — so the
|
||||||
histogram, and two "named" sheep (closest-to-pen and rearmost-from-pen).
|
policy generalises across flock sizes 1..MAX_SHEEP.
|
||||||
|
|
||||||
The two named sheep matter for the sequential-driving teacher: it
|
|
||||||
targets the closest-to-pen sheep specifically, so the policy needs
|
|
||||||
that channel to mimic the teacher.
|
|
||||||
|
|
||||||
Layout (all components normalised so values stay roughly in [-1, 1]):
|
Layout (all components normalised so values stay roughly in [-1, 1]):
|
||||||
|
|
||||||
idx field
|
idx field
|
||||||
----- ----------------------------------------------------------
|
----- ----------------------------------------------------------
|
||||||
0..3 dog pose: x/15, y/15, cos(heading), sin(heading)
|
0..3 dog pose: x/15, y/15, cos(h), sin(h)
|
||||||
4..5 active-sheep CoM x/15, y/15
|
4..5 active-sheep CoM x/15, y/15
|
||||||
6..8 flock dispersion: max-radius/15, std_x/15, std_y/15
|
6..8 flock dispersion: max_radius/15, std_x/15, std_y/15
|
||||||
9..11 vector dog→CoM: dx/30, dy/30, dist/30
|
9..11 dog → CoM: dx/30, dy/30, dist/30
|
||||||
12..14 vector dog→pen-entry: dx/30, dy/30, dist/30
|
12..14 dog → pen entry: dx/30, dy/30, dist/30
|
||||||
15..16 vector furthest-sheep→CoM: dx/15, dy/15
|
15..16 furthest sheep → CoM: dx/15, dy/15
|
||||||
17..18 min sheep-to-wall, min dog-to-wall (both /15)
|
17..18 min sheep-to-wall, min dog-to-wall (both /15)
|
||||||
19 active-sheep count / MAX_SHEEP
|
19 active sheep count / MAX_SHEEP
|
||||||
20..27 8-bin polar histogram of active sheep around the dog,
|
20..27 8-bin polar histogram of active sheep in the dog's body frame
|
||||||
rotation-aware (binned in dog-relative frame), normalised
|
28..29 dog → closest-to-pen sheep: dx/15, dy/15
|
||||||
so the bins sum to 1.
|
30..31 dog → rearmost (furthest-from-pen) sheep: dx/15, dy/15
|
||||||
28..29 vector dog→closest-to-pen sheep: dx/15, dy/15
|
|
||||||
30..31 vector dog→rearmost (furthest-from-pen) sheep: dx/15, dy/15
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
@@ -68,7 +62,6 @@ def build_obs(dog_xy, dog_heading, sheep_xy_list, sheep_penned_list,
|
|||||||
obs[14] = math.hypot(pdx0, pdy0) / 30.0
|
obs[14] = math.hypot(pdx0, pdy0) / 30.0
|
||||||
|
|
||||||
if n == 0:
|
if n == 0:
|
||||||
# All sheep penned — terminal observation.
|
|
||||||
obs[19] = 0.0
|
obs[19] = 0.0
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
@@ -110,7 +103,7 @@ def build_obs(dog_xy, dog_heading, sheep_xy_list, sheep_penned_list,
|
|||||||
obs[18] = float(min_dog_wall) / 15.0
|
obs[18] = float(min_dog_wall) / 15.0
|
||||||
obs[19] = n / n_max
|
obs[19] = n / n_max
|
||||||
|
|
||||||
# 8-bin polar histogram in the dog's body frame.
|
# Polar histogram in the dog's body frame.
|
||||||
rel_dx = arr[:, 0] - dog_x
|
rel_dx = arr[:, 0] - dog_x
|
||||||
rel_dy = arr[:, 1] - dog_y
|
rel_dy = arr[:, 1] - dog_y
|
||||||
angles = np.arctan2(rel_dy, rel_dx) - dog_heading
|
angles = np.arctan2(rel_dy, rel_dx) - dog_heading
|
||||||
@@ -121,11 +114,9 @@ def build_obs(dog_xy, dog_heading, sheep_xy_list, sheep_penned_list,
|
|||||||
hist /= max(1, n)
|
hist /= max(1, n)
|
||||||
obs[20:28] = hist
|
obs[20:28] = hist
|
||||||
|
|
||||||
# Closest-to-pen sheep (the sequential teacher's target) and rearmost
|
# Closest-to-pen and rearmost (furthest-from-pen) sheep. Without
|
||||||
# (furthest-from-pen, the natural "next target" once the closest is
|
# these named channels the obs cannot uniquely identify which sheep
|
||||||
# penned). Both expressed as offset from dog. These two channels make
|
# the teacher is steering toward, and BC fails to mimic it.
|
||||||
# BC tractable — without them the obs doesn't uniquely identify which
|
|
||||||
# sheep the teacher is steering toward.
|
|
||||||
pen_dists = np.hypot(arr[:, 0] - PEN_ENTRY[0], arr[:, 1] - PEN_ENTRY[1])
|
pen_dists = np.hypot(arr[:, 0] - PEN_ENTRY[0], arr[:, 1] - PEN_ENTRY[1])
|
||||||
closest_idx = int(np.argmin(pen_dists))
|
closest_idx = int(np.argmin(pen_dists))
|
||||||
rearmost_idx = int(np.argmax(pen_dists))
|
rearmost_idx = int(np.argmax(pen_dists))
|
||||||
@@ -1,25 +1,14 @@
|
|||||||
"""Multi-target tracker for LiDAR-detected sheep.
|
"""Multi-target tracker for LiDAR-detected sheep.
|
||||||
|
|
||||||
Greedy nearest-neighbour data association (with a distance gate) across
|
Greedy nearest-neighbour data association across frames, with a wider
|
||||||
frames, plus a memory of last-seen positions for tracks that fall out
|
re-acquisition gate for stale tracks (sheep flee during occlusion and
|
||||||
of the dog's FOV. Output is a ``{name: (x, y)}`` dict shaped exactly
|
reappear off-position), plus memory of last-seen positions for sheep
|
||||||
like the receiver-based ``sheep_positions`` used previously by the
|
out of FOV. Output is ``{name: (x, y)}`` — Strömbom / Sequential
|
||||||
Webots controller and by the env, so Strömbom and Sequential can
|
consume it directly.
|
||||||
consume it unchanged.
|
|
||||||
|
|
||||||
Penned-detection heuristic
|
A track is marked penned once its estimated position crosses the gate
|
||||||
--------------------------
|
plane south (``is_penned_position``). Penned tracks are excluded from
|
||||||
Two ways a track is marked penned:
|
``get_positions`` and kept indefinitely.
|
||||||
1. Its current estimated position is south of the gate plane and
|
|
||||||
within the gate column (the ``is_penned_position`` test the env
|
|
||||||
already uses on ground truth).
|
|
||||||
2. It hasn't been observed for ``STALE_STEPS`` and its last-seen
|
|
||||||
position was inside the gate-approach band — the dog's LiDAR can
|
|
||||||
only see ~2 m into the pen through the open gate, so a sheep
|
|
||||||
that disappeared near the entry has almost certainly entered.
|
|
||||||
|
|
||||||
Tracks marked penned are excluded from ``get_positions()`` (which is
|
|
||||||
what Strömbom consumes), matching the prior receiver-based behaviour.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -29,26 +18,22 @@ import math
|
|||||||
from herding.world.geometry import MAX_SHEEP, in_pen, is_penned_position
|
from herding.world.geometry import MAX_SHEEP, in_pen, is_penned_position
|
||||||
|
|
||||||
|
|
||||||
GATE_M = 2.5 # m — primary NN gate (recent tracks)
|
GATE_M = 2.5 # m — primary NN gate (recently observed tracks)
|
||||||
REACQUIRE_GATE_M = 4.5 # m — wider gate for re-acquiring stale tracks (sheep moved during occlusion)
|
REACQUIRE_GATE_M = 4.5 # m — wider gate for re-binding stale tracks
|
||||||
REACQUIRE_MIN_AGE = 20 # steps — only rebind via the wide gate if the track has been stale for this long
|
REACQUIRE_MIN_AGE = 20 # steps — track must be this stale to use the wider gate
|
||||||
PENNED_GATE_M = 4.0 # m — wide gate for matching against already-penned tracks; the pen is small (3×7 m) so duplicates are easy without it
|
PENNED_GATE_M = 4.0 # m — gate for matching detections to existing penned tracks
|
||||||
FORGET_STEPS = 200 # ~3.2 s — delete stale active tracks; tighter than 5 s to limit phantoms but long enough to bridge typical FOV gaps
|
FORGET_STEPS = 200 # ~3.2 s — delete stale active tracks (penned ones kept forever)
|
||||||
MAX_ACTIVE_TRACKS = MAX_SHEEP # hard cap to the worst-case real flock size
|
MAX_ACTIVE_TRACKS = MAX_SHEEP
|
||||||
# Penned tracks are never forgotten: sheep don't leave the pen, and
|
|
||||||
# losing the track makes the counter oscillate as the same sheep gets
|
|
||||||
# re-detected and counted multiple times.
|
|
||||||
|
|
||||||
|
|
||||||
class SheepTracker:
|
class SheepTracker:
|
||||||
"""Online tracker with NN association and a forgetful memory.
|
"""Online tracker with NN association and forgetful memory.
|
||||||
|
|
||||||
Each track stores ``(x, y, last_seen_step, penned)``.
|
Each track stores ``(x, y, last_seen_step, penned)``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, gate: float = GATE_M):
|
def __init__(self, gate: float = GATE_M):
|
||||||
self.gate = gate
|
self.gate = gate
|
||||||
# tid → (x, y, last_seen_step, penned)
|
|
||||||
self._tracks: dict[int, tuple[float, float, int, bool]] = {}
|
self._tracks: dict[int, tuple[float, float, int, bool]] = {}
|
||||||
self._next_id = 0
|
self._next_id = 0
|
||||||
self.step = 0
|
self.step = 0
|
||||||
@@ -58,9 +43,6 @@ class SheepTracker:
|
|||||||
self._next_id = 0
|
self._next_id = 0
|
||||||
self.step = 0
|
self.step = 0
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Update
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
def update(self, detections: list[tuple[float, float]]) -> dict[str, tuple[float, float]]:
|
def update(self, detections: list[tuple[float, float]]) -> dict[str, tuple[float, float]]:
|
||||||
"""Fold a new set of detections in and return active positions."""
|
"""Fold a new set of detections in and return active positions."""
|
||||||
self.step += 1
|
self.step += 1
|
||||||
@@ -68,9 +50,9 @@ class SheepTracker:
|
|||||||
det_used: set[int] = set()
|
det_used: set[int] = set()
|
||||||
updated_tids: set[int] = set()
|
updated_tids: set[int] = set()
|
||||||
|
|
||||||
# Pass 1: match against ACTIVE tracks first (oldest-seen-first so
|
# Pass 1 — match active tracks within the primary gate. Oldest-
|
||||||
# a re-emerging long-lost sheep grabs its old ID before a fresh
|
# seen tracks bind first so a re-emerging long-lost sheep keeps
|
||||||
# neighbour does).
|
# its old ID instead of being grabbed by a fresh neighbour.
|
||||||
active_tids = [tid for tid, t in self._tracks.items() if not t[3]]
|
active_tids = [tid for tid, t in self._tracks.items() if not t[3]]
|
||||||
active_tids.sort(key=lambda tid: self._tracks[tid][2])
|
active_tids.sort(key=lambda tid: self._tracks[tid][2])
|
||||||
for tid in active_tids:
|
for tid in active_tids:
|
||||||
@@ -89,12 +71,10 @@ class SheepTracker:
|
|||||||
det_used.add(best_j)
|
det_used.add(best_j)
|
||||||
updated_tids.add(tid)
|
updated_tids.add(tid)
|
||||||
|
|
||||||
# Pass 1b: re-acquisition with a wider gate for tracks that have
|
# Pass 1b — re-acquisition. Sheep flee at ~0.6 m/s, so over a
|
||||||
# been stale for ≥ REACQUIRE_MIN_AGE steps. Sheep flee at
|
# 1–2 s occlusion the same sheep may reappear outside the primary
|
||||||
# ~0.6 m/s; over a 1–2 s occlusion (dog rotating or driving)
|
# gate. Allow rebinding within a wider gate for stale-enough
|
||||||
# they move enough that a fresh detection lies outside the
|
# tracks; otherwise phantom tracks accumulate and corrupt CoM.
|
||||||
# primary GATE_M but is still clearly the same sheep. Without
|
|
||||||
# this, phantom tracks accumulate and corrupt the CoM.
|
|
||||||
for tid in active_tids:
|
for tid in active_tids:
|
||||||
if tid in updated_tids:
|
if tid in updated_tids:
|
||||||
continue
|
continue
|
||||||
@@ -115,10 +95,7 @@ class SheepTracker:
|
|||||||
det_used.add(best_j)
|
det_used.add(best_j)
|
||||||
updated_tids.add(tid)
|
updated_tids.add(tid)
|
||||||
|
|
||||||
# Pass 2: match remaining detections against PENNED tracks with
|
# Pass 2 — match remaining detections to penned tracks.
|
||||||
# a tighter gate. Without this, every frame near the gate spawns
|
|
||||||
# a fresh penned track for the same sheep, which under a long
|
|
||||||
# Webots run leads to thousands of phantom penned tracks.
|
|
||||||
penned_tids = [tid for tid, t in self._tracks.items() if t[3]]
|
penned_tids = [tid for tid, t in self._tracks.items() if t[3]]
|
||||||
for tid in penned_tids:
|
for tid in penned_tids:
|
||||||
tx, ty, _, _ = self._tracks[tid]
|
tx, ty, _, _ = self._tracks[tid]
|
||||||
@@ -135,9 +112,8 @@ class SheepTracker:
|
|||||||
self._tracks[tid] = (dx, dy, self.step, True)
|
self._tracks[tid] = (dx, dy, self.step, True)
|
||||||
det_used.add(best_j)
|
det_used.add(best_j)
|
||||||
|
|
||||||
# Unmatched detections → new tracks. A detection that is already
|
# Spawn new tracks for unmatched detections. Born "penned" if
|
||||||
# inside the pen is born "penned" so we don't accumulate active
|
# the detection already sits inside the pen geometry.
|
||||||
# tracks for sheep that arrived in the pen during occlusion.
|
|
||||||
for j, (dx, dy) in enumerate(detections):
|
for j, (dx, dy) in enumerate(detections):
|
||||||
if j in det_used:
|
if j in det_used:
|
||||||
continue
|
continue
|
||||||
@@ -145,44 +121,32 @@ class SheepTracker:
|
|||||||
self._tracks[self._next_id] = (dx, dy, self.step, penned)
|
self._tracks[self._next_id] = (dx, dy, self.step, penned)
|
||||||
self._next_id += 1
|
self._next_id += 1
|
||||||
|
|
||||||
# Promote active tracks to penned ONLY by geometric position
|
# Promote active tracks whose current estimate crosses the gate.
|
||||||
# (sheep is in the pen column south of the gate). The previous
|
|
||||||
# "stale + near gate" heuristic was firing on ordinary occlusion
|
|
||||||
# near the gate and creating phantom penned tracks.
|
|
||||||
for tid, (tx, ty, last, penned) in list(self._tracks.items()):
|
for tid, (tx, ty, last, penned) in list(self._tracks.items()):
|
||||||
if penned:
|
if penned:
|
||||||
continue
|
continue
|
||||||
if is_penned_position(tx, ty):
|
if is_penned_position(tx, ty):
|
||||||
self._tracks[tid] = (tx, ty, last, True)
|
self._tracks[tid] = (tx, ty, last, True)
|
||||||
|
|
||||||
# Forget stale ACTIVE tracks after FORGET_STEPS. Penned tracks
|
# Forget stale active tracks; penned tracks live forever.
|
||||||
# are kept indefinitely — sheep can't escape the pen, so once a
|
|
||||||
# track is marked penned, that sheep is permanently penned.
|
|
||||||
for tid, (tx, ty, last, penned) in list(self._tracks.items()):
|
for tid, (tx, ty, last, penned) in list(self._tracks.items()):
|
||||||
if penned:
|
if penned:
|
||||||
continue
|
continue
|
||||||
if (self.step - last) > FORGET_STEPS:
|
if (self.step - last) > FORGET_STEPS:
|
||||||
del self._tracks[tid]
|
del self._tracks[tid]
|
||||||
|
|
||||||
# Hard cap on the active set. If we somehow have more than
|
# Hard cap on the active set — drop the oldest-seen overflow.
|
||||||
# MAX_ACTIVE_TRACKS active tracks, drop the oldest-seen ones
|
|
||||||
# first — they are most likely false positives from world
|
|
||||||
# geometry (walls, gate posts) the env's raycaster doesn't
|
|
||||||
# model, and a bloated active set wrecks the downstream CoM.
|
|
||||||
active = [(tid, last) for tid, (_, _, last, p) in self._tracks.items()
|
active = [(tid, last) for tid, (_, _, last, p) in self._tracks.items()
|
||||||
if not p]
|
if not p]
|
||||||
if len(active) > MAX_ACTIVE_TRACKS:
|
if len(active) > MAX_ACTIVE_TRACKS:
|
||||||
active.sort(key=lambda kv: kv[1]) # oldest-seen first
|
active.sort(key=lambda kv: kv[1])
|
||||||
for tid, _ in active[: len(active) - MAX_ACTIVE_TRACKS]:
|
for tid, _ in active[: len(active) - MAX_ACTIVE_TRACKS]:
|
||||||
del self._tracks[tid]
|
del self._tracks[tid]
|
||||||
|
|
||||||
return self.get_positions()
|
return self.get_positions()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Outputs
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
def get_positions(self) -> dict[str, tuple[float, float]]:
|
def get_positions(self) -> dict[str, tuple[float, float]]:
|
||||||
"""Active (not-yet-penned) tracks. Same shape as receiver dict."""
|
"""Active (not-penned) tracks as a ``{name: (x, y)}`` dict."""
|
||||||
return {f"t{tid}": (x, y)
|
return {f"t{tid}": (x, y)
|
||||||
for tid, (x, y, _, penned) in self._tracks.items()
|
for tid, (x, y, _, penned) in self._tracks.items()
|
||||||
if not penned}
|
if not penned}
|
||||||
|
|||||||
@@ -1,11 +1,8 @@
|
|||||||
"""Differential-drive kinematics matching the Webots robot specs.
|
"""Differential-drive kinematics, shared by the env and Webots controllers.
|
||||||
|
|
||||||
The Webots controllers and the training env both use these helpers so the
|
First-order rigid-body model — no slip, wheel-accel limits, or contact
|
||||||
sim and the real (Webots) physics agree to first order. They do not model
|
forces. Webots' ODE physics handles those at inference; the env stays
|
||||||
slip, wheel acceleration limits, or contact forces — Webots does that for
|
close enough to first order that a policy trained here transfers.
|
||||||
us at inference time. The training env has to be close enough that a
|
|
||||||
policy trained against this kinematic model still works when handed off
|
|
||||||
to ODE physics.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
@@ -34,10 +31,9 @@ def kinematics_step(x, y, h, w_left, w_right, wheel_radius, wheel_base, dt):
|
|||||||
|
|
||||||
def velocity_to_wheels(vx, vy, h, max_linear, wheel_radius, max_wheel_omega,
|
def velocity_to_wheels(vx, vy, h, max_linear, wheel_radius, max_wheel_omega,
|
||||||
k_turn=4.0):
|
k_turn=4.0):
|
||||||
"""Convert a desired (vx, vy) intent in [-1, 1]^2 to wheel speeds.
|
"""Convert a desired (vx, vy) intent in [-1, 1]² to wheel speeds.
|
||||||
|
|
||||||
Mirrors ``drive_action`` in controllers/shepherd_dog/shepherd_dog.py:
|
Forward speed scales by ``cos(err)`` (clamped to ±90°); a P
|
||||||
forward speed scales by ``cos(err)`` (clamped to ±90°), and a P
|
|
||||||
controller on heading error contributes the wheel-rate differential.
|
controller on heading error contributes the wheel-rate differential.
|
||||||
"""
|
"""
|
||||||
speed_ms = math.hypot(vx, vy) * max_linear
|
speed_ms = math.hypot(vx, vy) * max_linear
|
||||||
@@ -56,12 +52,7 @@ def velocity_to_wheels(vx, vy, h, max_linear, wheel_radius, max_wheel_omega,
|
|||||||
|
|
||||||
def heading_speed_to_wheels(heading, speed_motor, h, max_wheel_omega,
|
def heading_speed_to_wheels(heading, speed_motor, h, max_wheel_omega,
|
||||||
k_turn=4.0):
|
k_turn=4.0):
|
||||||
"""Sheep variant: speed already expressed in motor (wheel rad/s) units.
|
"""Sheep variant: speed in wheel rad/s, target as a heading angle."""
|
||||||
|
|
||||||
Matches the existing sheep controller (``controllers/sheep/sheep.py``)
|
|
||||||
where ``speed = max(WANDER_SPEED, min(FLEE_SPEED, mag * 3.0))`` and
|
|
||||||
these constants are wheel angular velocities, not linear m/s.
|
|
||||||
"""
|
|
||||||
err = math.atan2(math.sin(heading - h), math.cos(heading - h))
|
err = math.atan2(math.sin(heading - h), math.cos(heading - h))
|
||||||
fwd = max(0.0, math.cos(err)) * speed_motor
|
fwd = max(0.0, math.cos(err)) * speed_motor
|
||||||
turn = k_turn * err
|
turn = k_turn * err
|
||||||
|
|||||||
@@ -1,24 +1,19 @@
|
|||||||
"""Sheep flocking dynamics — Strömbom 2014 / Reynolds 1987 hybrid.
|
"""Sheep flocking dynamics — Strömbom 2014 / Reynolds 1987.
|
||||||
|
|
||||||
This is the per-sheep behavioural step used both by the Webots sheep
|
Per-sheep behavioural step used by both the Webots sheep controller
|
||||||
controller (scalar, one sheep at a time) and by the training environment
|
and the training environment. Each step a force stack is summed:
|
||||||
(loop over sheep).
|
|
||||||
|
|
||||||
Model
|
|
||||||
-----
|
|
||||||
The force stack each step (summed → heading + speed):
|
|
||||||
|
|
||||||
flee — quadratic ramp away from dog within FLEE_DIST
|
flee — quadratic ramp away from dog within FLEE_DIST
|
||||||
(Strömbom 2014 §2.1, term ρa)
|
(Strömbom 2014, term ρa)
|
||||||
cohesion — drift toward local centre of mass of peers within
|
cohesion — drift toward local centre of mass of peers within
|
||||||
COHESION_DIST (Strömbom 2014 §2.1, term c).
|
COHESION_DIST (Strömbom 2014, term c). Weight is
|
||||||
Weight is **higher when fleeing** — modelling the
|
higher while fleeing — fear-induced cohesion.
|
||||||
"safety in numbers" / predator-confusion effect
|
|
||||||
Strömbom 2014 describes as fear-induced cohesion.
|
|
||||||
separation — short-range inverse-distance repulsion from peers
|
separation — short-range inverse-distance repulsion from peers
|
||||||
(Strömbom 2014 §2.1, term α; Reynolds 1987)
|
(Strömbom 2014, term α; Reynolds 1987)
|
||||||
wander — small persistent drift for natural idle motion
|
wander — small persistent drift (Strömbom 2014, noise term ε)
|
||||||
(Strömbom 2014 §2.1, noise term ε)
|
|
||||||
|
Walls, the south-wall gate column, and in-pen containment are
|
||||||
|
environment-specific additions for the fenced Webots field.
|
||||||
|
|
||||||
References
|
References
|
||||||
----------
|
----------
|
||||||
@@ -26,26 +21,6 @@ References
|
|||||||
for herding autonomous, interacting agents." J R Soc Interface 11.
|
for herding autonomous, interacting agents." J R Soc Interface 11.
|
||||||
- Reynolds (1987). "Flocks, herds and schools: A distributed
|
- Reynolds (1987). "Flocks, herds and schools: A distributed
|
||||||
behavioural model." SIGGRAPH '87.
|
behavioural model." SIGGRAPH '87.
|
||||||
|
|
||||||
Environment-specific adaptations
|
|
||||||
--------------------------------
|
|
||||||
The original Strömbom model assumes an open field. Our scenario adds:
|
|
||||||
|
|
||||||
* Field walls — soft repulsion within ``WALL_MARGIN`` plus a hard
|
|
||||||
escape band when inside ``WALL_HARD_MARGIN``. Necessary because the
|
|
||||||
Webots field is fenced (30 m square enclosure).
|
|
||||||
* Gate column — the south wall has a 3 m gap at x ∈ [10, 13]; sheep
|
|
||||||
pass through it freely (no wall force inside the column).
|
|
||||||
* Penned containment — once a sheep crosses the gate plane south
|
|
||||||
(``geometry.is_penned_position``), the caller flags ``penned=True``
|
|
||||||
and we switch to in-pen wall-bounce + jitter. Sheep do not exit the
|
|
||||||
pen on their own. This is a hard sim constraint, not a behavioural
|
|
||||||
claim about real sheep.
|
|
||||||
|
|
||||||
Parameter tuning (cohesion weight 3× while fleeing) was chosen so the
|
|
||||||
flock survives passage through the 3 m gate without fragmenting — this
|
|
||||||
is a defensible engineering adaptation of Strömbom's qualitative
|
|
||||||
"fear-induced cohesion" to our gate width.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
@@ -57,9 +32,7 @@ from herding.world.geometry import (
|
|||||||
GATE_X,
|
GATE_X,
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- Speed and force constants ---
|
# Speeds are in wheel rad/s (motor units); m/s = speed * SHEEP_WHEEL_RADIUS.
|
||||||
# All speeds here are in wheel rad/s (motor units), matching the existing
|
|
||||||
# sheep controller. Conversion to m/s = speed * SHEEP_WHEEL_RADIUS.
|
|
||||||
MAX_SPEED = 22.0
|
MAX_SPEED = 22.0
|
||||||
FLEE_SPEED = 20.0
|
FLEE_SPEED = 20.0
|
||||||
WANDER_SPEED = 3.0
|
WANDER_SPEED = 3.0
|
||||||
@@ -70,7 +43,7 @@ WALL_HARD_GAIN = 50.0
|
|||||||
|
|
||||||
FLEE_DIST = 7.0
|
FLEE_DIST = 7.0
|
||||||
SEPARATION_DIST = 2.5
|
SEPARATION_DIST = 2.5
|
||||||
COHESION_DIST = 12.0 # was 8.0 — wider engagement so far-flung sheep are pulled in
|
COHESION_DIST = 12.0
|
||||||
|
|
||||||
PEN_MARGIN = 0.8
|
PEN_MARGIN = 0.8
|
||||||
|
|
||||||
@@ -85,21 +58,17 @@ def _peers_iter(peers):
|
|||||||
def compute_heading_speed(x, y, penned, dog_xy, peers, wander_angle, rng=None):
|
def compute_heading_speed(x, y, penned, dog_xy, peers, wander_angle, rng=None):
|
||||||
"""Return ``(heading, speed, new_wander_angle)`` for one sheep step.
|
"""Return ``(heading, speed, new_wander_angle)`` for one sheep step.
|
||||||
|
|
||||||
``speed`` is in wheel rad/s (motor units), bounded by ``[WANDER_SPEED,
|
``speed`` is in wheel rad/s, bounded by ``[WANDER_SPEED, FLEE_SPEED]``.
|
||||||
FLEE_SPEED]``. ``heading`` is the world-frame target heading the sheep
|
``heading`` is the world-frame target heading (atan2 convention).
|
||||||
should aim for (atan2 convention).
|
``rng`` is an optional ``random.Random`` used for wander jitter; if
|
||||||
|
``None``, the global ``random`` module is used instead.
|
||||||
``rng`` is an optional ``random.Random``-compatible object used for
|
|
||||||
the wander-jitter. If ``None``, falls back to Python's global module
|
|
||||||
(matches Webots controller usage). Pass an env-owned RNG to make
|
|
||||||
rollouts deterministic given a seed.
|
|
||||||
"""
|
"""
|
||||||
fx, fy = 0.0, 0.0
|
fx, fy = 0.0, 0.0
|
||||||
peer_list = _peers_iter(peers)
|
peer_list = _peers_iter(peers)
|
||||||
rnd = rng if rng is not None else random
|
rnd = rng if rng is not None else random
|
||||||
|
|
||||||
if penned:
|
if penned:
|
||||||
# --- Pen containment: bounce off the four pen walls ---
|
# Pen containment: bounce off all four pen walls.
|
||||||
pm = PEN_MARGIN
|
pm = PEN_MARGIN
|
||||||
if x < PEN_X[0] + pm:
|
if x < PEN_X[0] + pm:
|
||||||
fx += ((PEN_X[0] + pm - x) / pm) * 15.0
|
fx += ((PEN_X[0] + pm - x) / pm) * 15.0
|
||||||
@@ -110,7 +79,7 @@ def compute_heading_speed(x, y, penned, dog_xy, peers, wander_angle, rng=None):
|
|||||||
if y > PEN_Y[1] - pm:
|
if y > PEN_Y[1] - pm:
|
||||||
fy -= ((y - (PEN_Y[1] - pm)) / pm) * 15.0
|
fy -= ((y - (PEN_Y[1] - pm)) / pm) * 15.0
|
||||||
|
|
||||||
# Mild peer separation — penned sheep crowd the corner otherwise.
|
# Mild peer separation so penned sheep don't crowd one corner.
|
||||||
for px, py in peer_list:
|
for px, py in peer_list:
|
||||||
dx, dy = px - x, py - y
|
dx, dy = px - x, py - y
|
||||||
d = math.hypot(dx, dy)
|
d = math.hypot(dx, dy)
|
||||||
@@ -125,7 +94,7 @@ def compute_heading_speed(x, y, penned, dog_xy, peers, wander_angle, rng=None):
|
|||||||
fy += math.sin(wander_angle) * 0.5
|
fy += math.sin(wander_angle) * 0.5
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# --- Free-roaming sheep in the field ---
|
# Free-roaming sheep in the field.
|
||||||
fleeing = False
|
fleeing = False
|
||||||
if dog_xy is not None:
|
if dog_xy is not None:
|
||||||
ddx = dog_xy[0] - x
|
ddx = dog_xy[0] - x
|
||||||
@@ -138,11 +107,9 @@ def compute_heading_speed(x, y, penned, dog_xy, peers, wander_angle, rng=None):
|
|||||||
fx -= (ddx / dist) * s
|
fx -= (ddx / dist) * s
|
||||||
fy -= (ddy / dist) * s
|
fy -= (ddy / dist) * s
|
||||||
|
|
||||||
# Cohesion — drift toward flock CoM (peers within COHESION_DIST).
|
# Cohesion: drift toward the local CoM of peers within
|
||||||
# Cohesion is *stronger* under flee than at rest (the
|
# COHESION_DIST. Stronger while fleeing — fear-induced
|
||||||
# predator-confusion / safety-in-numbers effect — sheep huddle when
|
# cohesion keeps the flock together through the gate.
|
||||||
# threatened). This is what makes shepherding work: the flock stays
|
|
||||||
# as one unit through the narrow gate instead of fragmenting.
|
|
||||||
cx, cy, cn = 0.0, 0.0, 0
|
cx, cy, cn = 0.0, 0.0, 0
|
||||||
for px, py in peer_list:
|
for px, py in peer_list:
|
||||||
d = math.hypot(px - x, py - y)
|
d = math.hypot(px - x, py - y)
|
||||||
@@ -151,12 +118,6 @@ def compute_heading_speed(x, y, penned, dog_xy, peers, wander_angle, rng=None):
|
|||||||
cy += py
|
cy += py
|
||||||
cn += 1
|
cn += 1
|
||||||
if cn > 0:
|
if cn > 0:
|
||||||
# Cohesion needs to dominate flee at close range so the flock
|
|
||||||
# stays glued together when squeezing through the narrow gate.
|
|
||||||
# Flee at 2 m has magnitude ~10; cohesion of w=3.0 with the
|
|
||||||
# peer-CoM 4 m away contributes ~12, so the flock prefers
|
|
||||||
# bunching to dispersing under pressure. This is what makes
|
|
||||||
# canonical Strömbom drive work in our 3 m gate.
|
|
||||||
w = 3.0 if fleeing else 1.0
|
w = 3.0 if fleeing else 1.0
|
||||||
fx += (cx / cn - x) * w
|
fx += (cx / cn - x) * w
|
||||||
fy += (cy / cn - y) * w
|
fy += (cy / cn - y) * w
|
||||||
@@ -170,8 +131,7 @@ def compute_heading_speed(x, y, penned, dog_xy, peers, wander_angle, rng=None):
|
|||||||
fx -= (ddx / d) * push * 2.5
|
fx -= (ddx / d) * push * 2.5
|
||||||
fy -= (ddy / d) * push * 2.5
|
fy -= (ddy / d) * push * 2.5
|
||||||
|
|
||||||
# Wall soft repulsion. The south wall is absent inside the gate
|
# Wall soft repulsion (south wall absent inside the gate column).
|
||||||
# column so sheep can be driven through it by the dog.
|
|
||||||
if x < FIELD_X[0] + WALL_MARGIN:
|
if x < FIELD_X[0] + WALL_MARGIN:
|
||||||
fx += ((FIELD_X[0] + WALL_MARGIN - x) / WALL_MARGIN) * 6.0
|
fx += ((FIELD_X[0] + WALL_MARGIN - x) / WALL_MARGIN) * 6.0
|
||||||
if x > FIELD_X[1] - WALL_MARGIN:
|
if x > FIELD_X[1] - WALL_MARGIN:
|
||||||
@@ -187,7 +147,7 @@ def compute_heading_speed(x, y, penned, dog_xy, peers, wander_angle, rng=None):
|
|||||||
fx += math.cos(wander_angle) * 0.5
|
fx += math.cos(wander_angle) * 0.5
|
||||||
fy += math.sin(wander_angle) * 0.5
|
fy += math.sin(wander_angle) * 0.5
|
||||||
|
|
||||||
# --- Hard escape band — overrides everything when very close to a wall ---
|
# Hard escape band — overrides everything else near a wall.
|
||||||
m, g = WALL_HARD_MARGIN, WALL_HARD_GAIN
|
m, g = WALL_HARD_MARGIN, WALL_HARD_GAIN
|
||||||
if x - FIELD_X[0] < m:
|
if x - FIELD_X[0] < m:
|
||||||
fx = max(fx, g * (1.0 - (x - FIELD_X[0]) / m))
|
fx = max(fx, g * (1.0 - (x - FIELD_X[0]) / m))
|
||||||
@@ -195,7 +155,6 @@ def compute_heading_speed(x, y, penned, dog_xy, peers, wander_angle, rng=None):
|
|||||||
fx = min(fx, -g * (1.0 - (FIELD_X[1] - x) / m))
|
fx = min(fx, -g * (1.0 - (FIELD_X[1] - x) / m))
|
||||||
if FIELD_Y[1] - y < m:
|
if FIELD_Y[1] - y < m:
|
||||||
fy = min(fy, -g * (1.0 - (FIELD_Y[1] - y) / m))
|
fy = min(fy, -g * (1.0 - (FIELD_Y[1] - y) / m))
|
||||||
# South wall hard escape only when not in the gate column and not penned.
|
|
||||||
if (not penned) and (y - FIELD_Y[0] < m) and not (GATE_X[0] <= x <= GATE_X[1]):
|
if (not penned) and (y - FIELD_Y[0] < m) and not (GATE_X[0] <= x <= GATE_X[1]):
|
||||||
fy = max(fy, g * (1.0 - (y - FIELD_Y[0]) / m))
|
fy = max(fy, g * (1.0 - (y - FIELD_Y[0]) / m))
|
||||||
|
|
||||||
|
|||||||
+14
-35
@@ -1,23 +1,15 @@
|
|||||||
"""World geometry and robot specs.
|
"""World geometry and robot specs.
|
||||||
|
|
||||||
All coordinates are in meters. (0, 0) is the centre of the field, +x is
|
Coordinates are in metres; (0, 0) is the field centre, +x is east, +y is north.
|
||||||
east, +y is north. Z is up but unused here. These constants must match
|
These constants mirror ``worlds/field.wbt`` and the proto files — if
|
||||||
``worlds/field.wbt`` and the proto files; if the world changes, change
|
the world changes, this file is the single point of update.
|
||||||
this file and only this file.
|
|
||||||
|
|
||||||
Pen layout (post-refactor)
|
|
||||||
--------------------------
|
|
||||||
The pen is *external* to the field, accessed through a 3 m gate cut into
|
|
||||||
the south stone wall at y = -15. Sheep entering through the gate end up
|
|
||||||
in a fenced rectangle south of the field; the dog stays in the field
|
|
||||||
(soft-limited above DOG_SOUTH_LIMIT during training and inference).
|
|
||||||
|
|
||||||
field +y north
|
field +y north
|
||||||
+-----------+
|
+-----------+
|
||||||
| |
|
| |
|
||||||
| |
|
| |
|
||||||
| ...... |
|
| ...... |
|
||||||
+---||||----+ y = -15 (south wall, gate at x ∈ [10, 13])
|
+---||||----+ y = -15 (south wall, 3 m gate at x ∈ [10, 13])
|
||||||
||||
|
||||
|
||||||
|pen| y ∈ [-22, -15]
|
|pen| y ∈ [-22, -15]
|
||||||
+---+
|
+---+
|
||||||
@@ -25,46 +17,38 @@ in a fenced rectangle south of the field; the dog stays in the field
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
# --- Field (square, stone-walled) ---
|
# Field (square, stone-walled)
|
||||||
FIELD_X = (-15.0, 15.0)
|
FIELD_X = (-15.0, 15.0)
|
||||||
FIELD_Y = (-15.0, 15.0)
|
FIELD_Y = (-15.0, 15.0)
|
||||||
|
|
||||||
# Conservative inside bounds — sheep/dog should not graze the wall.
|
|
||||||
FIELD_INSIDE_MARGIN = 0.5
|
FIELD_INSIDE_MARGIN = 0.5
|
||||||
|
|
||||||
# --- Pen (external, south of the field) ---
|
# Pen (external, south of the field)
|
||||||
PEN_X = (10.0, 13.0)
|
PEN_X = (10.0, 13.0)
|
||||||
PEN_Y = (-22.0, -15.0)
|
PEN_Y = (-22.0, -15.0)
|
||||||
PEN_CENTER = (0.5 * (PEN_X[0] + PEN_X[1]), 0.5 * (PEN_Y[0] + PEN_Y[1]))
|
PEN_CENTER = (0.5 * (PEN_X[0] + PEN_X[1]), 0.5 * (PEN_Y[0] + PEN_Y[1]))
|
||||||
# The point the dog drives the flock toward: the gate centre on the field side.
|
|
||||||
PEN_ENTRY = (0.5 * (PEN_X[0] + PEN_X[1]), -15.0)
|
PEN_ENTRY = (0.5 * (PEN_X[0] + PEN_X[1]), -15.0)
|
||||||
|
|
||||||
# --- Gate (the hole in the south stone wall) ---
|
# Gate (hole in the south wall)
|
||||||
GATE_X = PEN_X
|
GATE_X = PEN_X
|
||||||
GATE_Y = -15.0
|
GATE_Y = -15.0
|
||||||
|
|
||||||
# --- Robot specs (must match proto files) ---
|
# Dog spec — protos/ShepherdDog.proto
|
||||||
# Dog (controllers/shepherd_dog/, protos/ShepherdDog.proto)
|
|
||||||
DOG_WHEEL_RADIUS = 0.038 # m
|
DOG_WHEEL_RADIUS = 0.038 # m
|
||||||
DOG_WHEEL_BASE = 0.28 # m, axle-to-axle
|
DOG_WHEEL_BASE = 0.28 # m, axle-to-axle
|
||||||
DOG_MAX_WHEEL_OMEGA = 70.0 # rad/s
|
DOG_MAX_WHEEL_OMEGA = 70.0 # rad/s
|
||||||
DOG_MAX_LINEAR = DOG_WHEEL_RADIUS * DOG_MAX_WHEEL_OMEGA # ~2.66 m/s
|
DOG_MAX_LINEAR = DOG_WHEEL_RADIUS * DOG_MAX_WHEEL_OMEGA # ≈ 2.66 m/s
|
||||||
|
|
||||||
# Sheep (controllers/sheep/, protos/Sheep.proto)
|
# Sheep spec — protos/Sheep.proto
|
||||||
SHEEP_WHEEL_RADIUS = 0.031 # m
|
SHEEP_WHEEL_RADIUS = 0.031 # m
|
||||||
SHEEP_WHEEL_BASE = 0.20 # m
|
SHEEP_WHEEL_BASE = 0.20 # m
|
||||||
SHEEP_MAX_WHEEL_OMEGA = 25.0 # rad/s
|
SHEEP_MAX_WHEEL_OMEGA = 25.0 # rad/s
|
||||||
SHEEP_MAX_LINEAR = SHEEP_WHEEL_RADIUS * SHEEP_MAX_WHEEL_OMEGA # ~0.78 m/s
|
SHEEP_MAX_LINEAR = SHEEP_WHEEL_RADIUS * SHEEP_MAX_WHEEL_OMEGA # ≈ 0.78 m/s
|
||||||
|
|
||||||
# --- Webots step ---
|
WEBOTS_DT = 0.016 # seconds (matches WorldInfo.basicTimeStep)
|
||||||
WEBOTS_DT = 0.016 # seconds, matches WorldInfo.basicTimeStep = 16 in field.wbt
|
|
||||||
|
|
||||||
# --- Dog "virtual south wall" (training keeps dog out of the pen) ---
|
# Virtual south wall — env and controller both keep the dog north of this.
|
||||||
# At inference the controller also clips to this so a slightly miscalibrated
|
|
||||||
# policy doesn't accidentally drive into the pen and trap the sheep.
|
|
||||||
DOG_SOUTH_LIMIT = -14.5
|
DOG_SOUTH_LIMIT = -14.5
|
||||||
|
|
||||||
# --- Maximum supported flock size ---
|
|
||||||
MAX_SHEEP = 10
|
MAX_SHEEP = 10
|
||||||
|
|
||||||
|
|
||||||
@@ -85,12 +69,7 @@ def in_gate_corridor(x: float, y: float, margin: float = 0.0) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def is_penned_position(x: float, y: float, latch_margin: float = 0.2) -> bool:
|
def is_penned_position(x: float, y: float, latch_margin: float = 0.2) -> bool:
|
||||||
"""A sheep latches to "penned" once it crosses the gate plane south.
|
"""True iff (x, y) is in the gate column and south of the gate line."""
|
||||||
|
|
||||||
True iff x is inside the gate column (with a small margin) AND
|
|
||||||
y has dipped below the gate line. Once latched, the sheep is held by
|
|
||||||
in-pen forces and will not exit on its own.
|
|
||||||
"""
|
|
||||||
return (PEN_X[0] - latch_margin <= x <= PEN_X[1] + latch_margin
|
return (PEN_X[0] - latch_margin <= x <= PEN_X[1] + latch_margin
|
||||||
and y <= GATE_Y)
|
and y <= GATE_Y)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,8 @@
|
|||||||
|
"""Pytest configuration — ensure the project root is on ``sys.path``."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Project root = parent of the directory holding this file. Resolve it to an
# absolute path so the sys.path entry stays valid even if the process later
# changes its working directory (abspath also normalises the "..").
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
if _PROJECT_ROOT not in sys.path:
    # Front of the list so project packages win over same-named installed ones.
    sys.path.insert(0, _PROJECT_ROOT)
|
||||||
@@ -1,96 +0,0 @@
|
|||||||
"""Parity smoke-test for the herding env.
|
|
||||||
|
|
||||||
Verifies (a) all imports resolve, (b) the env's reset/step contract is
|
|
||||||
correct, (c) deterministic seeds give deterministic trajectories, and
|
|
||||||
(d) the Strömbom baseline can drive the env without crashing.
|
|
||||||
|
|
||||||
Run::
|
|
||||||
|
|
||||||
python -m training.parity_test
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
_HERE = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
|
|
||||||
if _PROJECT_ROOT not in sys.path:
|
|
||||||
sys.path.insert(0, _PROJECT_ROOT)
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from herding.world.geometry import MAX_SHEEP, PEN_ENTRY
|
|
||||||
from herding.obs import OBS_DIM
|
|
||||||
from herding.control.strombom import compute_action
|
|
||||||
from training.herding_env import HerdingEnv
|
|
||||||
|
|
||||||
|
|
||||||
def test_obs_action_shapes():
    """Check observation shape/dtype after reset and the types returned by one step."""
    env = HerdingEnv(n_sheep=3, seed=0)
    first_obs, info = env.reset()
    assert first_obs.shape == (OBS_DIM,), first_obs.shape
    assert first_obs.dtype == np.float32

    action = np.array([0.5, 0.0], dtype=np.float32)
    obs2, r, term, trunc, info = env.step(action)
    assert obs2.shape == (OBS_DIM,)
    assert isinstance(r, float)
    assert isinstance(term, bool)
    assert isinstance(trunc, bool)
    print("[ok] shapes")
|
|
||||||
|
|
||||||
|
|
||||||
def test_reset_determinism():
    """Two envs built and reset with the same seed must start from the same observation.

    Only reset determinism is checked; step-by-step determinism is
    deliberately not required.
    """
    seed = 42
    env_pair = [HerdingEnv(n_sheep=3, seed=seed) for _ in range(2)]
    obs_a, _ = env_pair[0].reset(seed=seed)
    obs_b, _ = env_pair[1].reset(seed=seed)
    assert np.allclose(obs_a, obs_b), "Reset is non-deterministic for same seed"
    print("[ok] reset determinism")
|
|
||||||
|
|
||||||
|
|
||||||
def test_curriculum_n_sheep_varies():
    """Over 40 resets the reported flock size includes 1 and never exceeds MAX_SHEEP."""
    env = HerdingEnv(seed=0)
    sizes = set()
    for _episode in range(40):
        _obs, info = env.reset()
        sizes.add(info["n_sheep"])
    assert 1 in sizes
    assert max(sizes) <= MAX_SHEEP
    print(f"[ok] curriculum sampling — saw n_sheep in {sorted(sizes)}")
|
|
||||||
|
|
||||||
|
|
||||||
def test_strombom_drives_env():
    """Roll the Strömbom baseline through the env for up to 400 steps.

    Only checks that observations and rewards stay finite; this is not a
    success-rate test.
    """
    env = HerdingEnv(n_sheep=2, max_steps=400, seed=1)
    obs, _ = env.reset()
    for t in range(400):
        # Only sheep that are not yet penned are handed to the controller.
        positions = {}
        for i in range(env.n_sheep):
            if env.sheep_penned[i]:
                continue
            positions[f"s{i}"] = (float(env.sheep_x[i]), float(env.sheep_y[i]))
        if not positions:
            break
        dog_xy = (env.dog_x, env.dog_y)
        vx, vy, _mode = compute_action(dog_xy, positions, PEN_ENTRY)
        obs, r, term, trunc, info = env.step(np.array([vx, vy], dtype=np.float32))
        assert np.isfinite(obs).all(), f"NaN/Inf in obs at step {t}"
        assert np.isfinite(r), f"NaN reward at step {t}"
        if term or trunc:
            break
    print(f"[ok] strombom rollout — final n_penned={int(env.sheep_penned.sum())}/{env.n_sheep} after {env.steps} steps")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Run every parity check in order, then print a summary line."""
    checks = (
        test_obs_action_shapes,
        test_reset_determinism,
        test_curriculum_n_sheep_varies,
        test_strombom_drives_env,
    )
    for check in checks:
        check()
    print("\nAll parity checks passed.")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -0,0 +1,164 @@
|
|||||||
|
"""Control primitives: speed modulation, Strömbom, Sequential, ActiveScan."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from herding.control.active_scan import (
|
||||||
|
EMPTY_DEBOUNCE_STEPS, INITIAL_SCAN_STEPS, ActiveScanTeacher,
|
||||||
|
)
|
||||||
|
from herding.control.modulation import (
|
||||||
|
MIN_SPEED, SLOW_NEAR_SHEEP, modulate_speed_near_sheep,
|
||||||
|
)
|
||||||
|
from herding.control.sequential import compute_action as sequential_action
|
||||||
|
from herding.control.strombom import (
|
||||||
|
DELTA_DRIVE, F_FACTOR, compute_action as strombom_action,
|
||||||
|
)
|
||||||
|
from herding.world.geometry import PEN_ENTRY
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Modulation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_modulation_empty_input_passthrough():
    """With no sheep (empty list or empty dict) the command comes back unchanged."""
    for no_sheep in ([], {}):
        result = modulate_speed_near_sheep(1.0, 0.0, (0.0, 0.0), no_sheep)
        assert result == (1.0, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_modulation_far_sheep_passthrough():
    """A sheep 100 units from the dog leaves the commanded velocity untouched."""
    distant_flock = [(100.0, 0.0)]
    vx, vy = modulate_speed_near_sheep(1.0, 0.0, (0.0, 0.0), distant_flock)
    assert vx == 1.0 and vy == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_modulation_close_sheep_min_speed():
    """A sheep at the dog's own position scales a unit +x command down to MIN_SPEED."""
    dog_xy = (0.0, 0.0)
    vx, vy = modulate_speed_near_sheep(1.0, 0.0, dog_xy, [(0.0, 0.0)])
    assert vy == 0.0
    assert math.isclose(vx, MIN_SPEED)
|
||||||
|
|
||||||
|
|
||||||
|
def test_modulation_preserves_direction():
    """Slowing down near a sheep must rescale the command, not rotate it."""
    vx, vy = modulate_speed_near_sheep(0.6, 0.8, (0.0, 0.0), [(1.0, 0.0)])
    speed = math.hypot(vx, vy)
    # Guard the normalisation below: a zero-length output has no direction and
    # should fail as an assertion rather than raise ZeroDivisionError.
    assert speed > 0.0, "modulated velocity collapsed to zero"
    # (0.6, 0.8) is a unit vector, so the normalised output must equal it.
    assert math.isclose(vx / speed, 0.6, abs_tol=1e-6)
    assert math.isclose(vy / speed, 0.8, abs_tol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
def test_modulation_linear_ramp_midpoint():
    """At half of SLOW_NEAR_SHEEP the speed is midway between MIN_SPEED and 1."""
    halfway_flock = [(SLOW_NEAR_SHEEP / 2, 0.0)]
    vx, _ = modulate_speed_near_sheep(1.0, 0.0, (0.0, 0.0), halfway_flock)
    midpoint_speed = MIN_SPEED + (1.0 - MIN_SPEED) * 0.5
    assert math.isclose(vx, midpoint_speed, abs_tol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
def test_modulation_accepts_dict_input():
    """The same sheep given as a list or as a name-keyed dict yields the same x-speed."""
    origin = (0.0, 0.0)
    as_list = [(1.0, 0.0)]
    as_dict = {"t0": (1.0, 0.0)}
    from_list, _ = modulate_speed_near_sheep(1.0, 0.0, origin, as_list)
    from_dict, _ = modulate_speed_near_sheep(1.0, 0.0, origin, as_dict)
    assert math.isclose(from_list, from_dict)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Strömbom
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_strombom_empty_input_idle():
    """With no sheep, Strömbom returns zero velocity and the "idle" mode."""
    action = strombom_action((0.0, 0.0), {}, PEN_ENTRY)
    vx, vy, mode = action
    assert (vx, vy) == (0.0, 0.0)
    assert mode == "idle"
|
||||||
|
|
||||||
|
|
||||||
|
def test_strombom_tight_flock_drives():
    """A compact three-sheep cluster near (0, 8) yields "drive" at unit speed."""
    # The cluster is meant to sit inside the F_FACTOR·√3 radius threshold.
    flock = {"s0": (0.0, 8.0), "s1": (0.5, 8.5), "s2": (-0.5, 8.0)}
    vx, vy, mode = strombom_action((0.0, 0.0), flock, PEN_ENTRY)
    assert mode == "drive"
    speed = math.hypot(vx, vy)
    assert math.isclose(speed, 1.0, abs_tol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_strombom_scattered_flock_collects():
    """A widely spread flock makes Strömbom report the "collect" mode."""
    # The spread is meant to exceed the F_FACTOR·√n radius threshold.
    flock = {
        "s0": (10.0, 10.0),
        "s1": (-10.0, -10.0),
        "s2": (0.0, 0.0),
    }
    _vx, _vy, mode = strombom_action((0.0, 0.0), flock, PEN_ENTRY)
    assert mode == "collect"
|
||||||
|
|
||||||
|
|
||||||
|
def test_strombom_ignores_already_penned_sheep():
    """A sheep south of the gate plane must not count towards the active flock."""
    flock = {"s_active": (5.0, 5.0), "s_penned": (11.5, -20.0)}
    _vx, _vy, mode = strombom_action((0.0, 0.0), flock, PEN_ENTRY)
    # If only "s_active" counts, the flock radius is 0 and the mode is "drive".
    assert mode == "drive"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Sequential
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_sequential_empty_input_idle():
    """With no sheep, the sequential controller returns zero velocity and "idle"."""
    action = sequential_action((0.0, 0.0), {}, PEN_ENTRY)
    vx, vy, mode = action
    assert (vx, vy) == (0.0, 0.0)
    assert mode == "idle"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sequential_targets_closest_to_pen():
    """The reported mode names the sheep nearest the pen entry as the drive target."""
    flock = {
        "near": (10.0, -5.0),   # the closer of the two to the pen entry (11.5, -15)
        "far": (-10.0, 10.0),
    }
    _vx, _vy, mode = sequential_action((0.0, 0.0), flock, PEN_ENTRY)
    assert mode.startswith("drive:near")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ActiveScan wrapper
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_active_scan_initial_phase_rotates():
    """The first call reports "scan_initial" with a unit-magnitude command,
    even though a sheep is visible."""
    scan_teacher = ActiveScanTeacher(strombom_action)
    visible = {"s0": (5.0, 0.0)}
    vx, vy, mode = scan_teacher((0.0, 0.0), 0.0, visible, PEN_ENTRY)
    assert mode == "scan_initial"
    magnitude = math.hypot(vx, vy)
    assert math.isclose(magnitude, 1.0, abs_tol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
def test_active_scan_hands_off_to_base_after_opener():
    """Once the opening scan steps are used up, the mode is no longer a scan mode."""
    opener_steps = 2
    teacher = ActiveScanTeacher(strombom_action, initial_scan_steps=opener_steps)

    def step_with_one_sheep():
        # A fresh dict per call, as in a real perception frame.
        return teacher((0.0, 0.0), 0.0, {"s0": (0.0, 8.0)}, PEN_ENTRY)

    for _ in range(opener_steps):
        step_with_one_sheep()          # burn through the opener
    _vx, _vy, mode = step_with_one_sheep()
    # Any base-controller label is acceptable as long as it is not a scan label.
    assert "scan" not in mode
|
||||||
|
|
||||||
|
|
||||||
|
def test_active_scan_holds_last_action_on_brief_empty():
    """One empty frame after a sighting replays the previous action in "hold" mode."""
    teacher = ActiveScanTeacher(strombom_action, initial_scan_steps=1)
    dog_xy, heading = (0.0, 0.0), 0.0
    teacher(dog_xy, heading, {}, PEN_ENTRY)                   # opening scan step
    teacher(dog_xy, heading, {"s0": (0.0, 8.0)}, PEN_ENTRY)   # sighting sets last_action
    remembered = teacher.last_action
    vx, vy, mode = teacher(dog_xy, heading, {}, PEN_ENTRY)    # single empty frame
    assert mode == "hold"
    assert (vx, vy) == remembered
|
||||||
|
|
||||||
|
|
||||||
|
def test_active_scan_explores_after_sustained_empty():
    """After EMPTY_DEBOUNCE_STEPS consecutive empty frames the teacher is searching."""
    teacher = ActiveScanTeacher(strombom_action, initial_scan_steps=1)
    teacher((0.0, 0.0), 0.0, {}, PEN_ENTRY)  # opener
    # Pre-bind so a zero-length debounce fails the assertion below instead of
    # raising NameError on an unbound loop variable.
    mode = None
    for _ in range(EMPTY_DEBOUNCE_STEPS):
        _vx, _vy, mode = teacher((5.0, 5.0), 0.0, {}, PEN_ENTRY)
    assert mode in ("explore", "scan_at_centre")
|
||||||
|
|
||||||
|
|
||||||
|
def test_active_scan_reset_clears_state():
    """reset() zeroes both the step counter and the empty-frame streak."""
    teacher = ActiveScanTeacher(strombom_action, initial_scan_steps=5)
    calls = 3
    for _ in range(calls):
        teacher((0.0, 0.0), 0.0, {}, PEN_ENTRY)
    assert teacher.step == calls
    teacher.reset()
    assert teacher.step == 0
    assert teacher.empty_streak == 0
|
||||||
@@ -0,0 +1,84 @@
|
|||||||
|
"""Differential-drive kinematics and the (vx, vy) → wheel-speed map."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from herding.world.diffdrive import (
|
||||||
|
heading_speed_to_wheels, kinematics_step, velocity_to_wheels,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Drive parameters shared by every test in this module.
# NOTE(review): presumably the dog's chassis values — confirm against the proto.
WHEEL_R = 0.038                    # wheel radius
WHEEL_B = 0.28                     # wheel base
MAX_OMEGA = 70.0                   # wheel-speed limit
DT = 0.016                         # step duration handed to kinematics_step
MAX_LINEAR = WHEEL_R * MAX_OMEGA   # linear speed at the wheel-speed limit
|
||||||
|
|
||||||
|
|
||||||
|
def test_kinematics_zero_input_is_identity():
    """Zero wheel speeds leave the pose (x, y, heading) exactly unchanged."""
    start_pose = (1.0, 2.0, 0.5)
    x, y, h = kinematics_step(*start_pose, 0.0, 0.0, WHEEL_R, WHEEL_B, DT)
    assert (x, y, h) == start_pose
|
||||||
|
|
||||||
|
|
||||||
|
def test_kinematics_forward_motion():
    """Equal wheel speeds at heading 0 translate along +x only."""
    x, y, h = kinematics_step(0.0, 0.0, 0.0, 10.0, 10.0, WHEEL_R, WHEEL_B, DT)
    expected_x = 10.0 * WHEEL_R * DT
    assert h == 0.0
    assert y == 0.0
    assert math.isclose(x, expected_x)
|
||||||
|
|
||||||
|
|
||||||
|
def test_kinematics_pure_rotation():
    """Equal-and-opposite wheel speeds turn in place with a positive heading change."""
    x, y, h = kinematics_step(0.0, 0.0, 0.0, -5.0, 5.0, WHEEL_R, WHEEL_B, DT)
    assert x == 0.0 and y == 0.0
    assert h > 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_kinematics_heading_wrapped_to_pi():
    """A fast turn starting just below π still returns a heading in [-π, π]."""
    start_heading = math.pi - 0.01
    _, _, h = kinematics_step(0.0, 0.0, start_heading, 100.0, -100.0,
                              WHEEL_R, WHEEL_B, DT)
    assert -math.pi <= h <= math.pi
|
||||||
|
|
||||||
|
|
||||||
|
def test_velocity_to_wheels_zero_velocity():
    """A zero velocity command maps to zero on both wheels."""
    wheels = velocity_to_wheels(0.0, 0.0, 0.0, MAX_LINEAR, WHEEL_R, MAX_OMEGA)
    left, right = wheels
    assert left == 0.0 and right == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_velocity_to_wheels_aligned_forward():
    """A target along the current heading gives equal, positive wheel speeds."""
    left, right = velocity_to_wheels(
        1.0, 0.0, 0.0, MAX_LINEAR, WHEEL_R, MAX_OMEGA, k_turn=4.0)
    assert left > 0.0
    assert math.isclose(left, right, abs_tol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
def test_velocity_to_wheels_perpendicular_target_spins():
    """A target 90° off the heading yields equal-and-opposite wheel speeds."""
    left, right = velocity_to_wheels(
        0.0, 1.0, 0.0, MAX_LINEAR, WHEEL_R, MAX_OMEGA, k_turn=4.0)
    wheel_sum = left + right
    assert wheel_sum == pytest.approx(0.0, abs=1e-6)
    # Target is at +y with heading 0, so the right wheel must run forwards (CCW turn).
    assert right > 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_velocity_to_wheels_clamped_to_max_omega():
    """With a target directly behind and a large turn gain, both wheels stay in ±MAX_OMEGA."""
    wheels = velocity_to_wheels(
        -1.0, 0.0, 0.0, MAX_LINEAR, WHEEL_R, MAX_OMEGA, k_turn=20.0)
    left, right = wheels
    for wheel_omega in (left, right):
        assert -MAX_OMEGA <= wheel_omega <= MAX_OMEGA
|
||||||
|
|
||||||
|
|
||||||
|
def test_heading_speed_to_wheels_aligned():
    """Target heading equal to the current heading gives equal, positive wheel speeds."""
    wheels = heading_speed_to_wheels(0.0, 10.0, 0.0, MAX_OMEGA)
    left, right = wheels
    assert left > 0.0
    assert math.isclose(left, right, abs_tol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
def test_heading_speed_to_wheels_reverse_target_forwards_zero():
    """A target heading of π from heading 0 gives wheel speeds that sum to zero."""
    left, right = heading_speed_to_wheels(math.pi, 10.0, 0.0, MAX_OMEGA)
    # Expected: no forward component, only rotation.
    wheel_sum = left + right
    assert wheel_sum == pytest.approx(0.0, abs=1e-6)
|
||||||
@@ -0,0 +1,108 @@
|
|||||||
|
"""Gymnasium env: contract, determinism, reward components."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from herding.world.geometry import MAX_SHEEP, PEN_ENTRY
|
||||||
|
from herding.perception.obs import OBS_DIM
|
||||||
|
from herding.control.strombom import compute_action as strombom_action
|
||||||
|
from training.herding_env import HerdingEnv
|
||||||
|
|
||||||
|
|
||||||
|
def test_env_obs_action_shapes_single_frame():
    """reset() and step() return an (OBS_DIM,) float32 observation and plain-typed scalars."""
    env = HerdingEnv(n_sheep=3, seed=0, use_lidar=False)
    first_obs, _info = env.reset()
    assert first_obs.shape == (OBS_DIM,)
    assert first_obs.dtype == np.float32

    action = np.array([0.5, 0.0], dtype=np.float32)
    next_obs, reward, term, trunc, _info = env.step(action)
    assert next_obs.shape == (OBS_DIM,)
    assert isinstance(reward, float)
    assert isinstance(term, bool)
    assert isinstance(trunc, bool)
|
||||||
|
|
||||||
|
|
||||||
|
def test_env_observation_space_matches_frame_stack():
    """With frame_stack=4 both the observation and the declared space are 4×OBS_DIM."""
    env = HerdingEnv(n_sheep=2, seed=0, use_lidar=False, frame_stack=4)
    stacked_shape = (OBS_DIM * 4,)
    obs, _ = env.reset()
    assert obs.shape == stacked_shape
    assert env.observation_space.shape == stacked_shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_env_reset_determinism_same_seed():
    """Two envs built and reset with seed 42 produce matching first observations."""
    seed = 42
    first_env, second_env = (
        HerdingEnv(n_sheep=3, seed=seed, use_lidar=False) for _ in range(2)
    )
    obs_first, _ = first_env.reset(seed=seed)
    obs_second, _ = second_env.reset(seed=seed)
    assert np.allclose(obs_first, obs_second)
|
||||||
|
|
||||||
|
|
||||||
|
def test_env_curriculum_samples_full_range():
    """Over 40 resets the reported flock size includes 1 and never exceeds MAX_SHEEP."""
    env = HerdingEnv(seed=0, use_lidar=False)
    seen_sizes = set()
    for _episode in range(40):
        _obs, info = env.reset()
        seen_sizes.add(info["n_sheep"])
    assert 1 in seen_sizes
    assert max(seen_sizes) <= MAX_SHEEP
|
||||||
|
|
||||||
|
|
||||||
|
def test_env_step_returns_finite_values():
    """A constant (0.5, 0.5) action for up to 200 steps never yields NaN/Inf."""
    env = HerdingEnv(n_sheep=2, max_steps=200, seed=1, use_lidar=False)
    obs, _ = env.reset()
    for _step in range(200):
        # Fresh array each step, in case step() modifies its argument.
        constant_action = np.array([0.5, 0.5], dtype=np.float32)
        obs, reward, term, trunc, _ = env.step(constant_action)
        assert np.isfinite(obs).all()
        assert math.isfinite(reward)
        episode_over = term or trunc
        if episode_over:
            break
|
||||||
|
|
||||||
|
|
||||||
|
def test_env_options_n_sheep_overrides_curriculum():
    """Passing options={"n_sheep": 7} to reset() is reflected in the returned info."""
    requested = 7
    env = HerdingEnv(seed=0, use_lidar=False)
    _obs, info = env.reset(options={"n_sheep": requested})
    assert info["n_sheep"] == requested
|
||||||
|
|
||||||
|
|
||||||
|
def test_env_perceived_positions_lidar_vs_privileged():
    """Privileged mode reports all three sheep; LiDAR mode reports at most three."""
    def perceived_after_reset(use_lidar):
        env = HerdingEnv(n_sheep=3, seed=0, use_lidar=use_lidar)
        env.reset(seed=0)
        return env.perceived_positions()

    assert len(perceived_after_reset(False)) == 3
    # The tracker may see fewer sheep than exist, but must never report more.
    assert len(perceived_after_reset(True)) <= 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_env_set_time_weight_affects_reward():
    """After set_time_weight(-1.0) the same first step earns a lower reward."""
    env = HerdingEnv(n_sheep=1, seed=0, use_lidar=False)

    def first_step_reward():
        env.reset(seed=0)
        _, reward, *_ = env.step(np.array([0.0, 0.0], dtype=np.float32))
        return reward

    r_default = first_step_reward()
    env.set_time_weight(-1.0)
    r_penalised = first_step_reward()
    assert r_penalised < r_default
|
||||||
|
|
||||||
|
|
||||||
|
def test_env_strombom_rollout_moves_dog():
    """Driving the env with Strömbom for up to 400 steps displaces the dog by > 0.05."""
    env = HerdingEnv(n_sheep=2, max_steps=400, seed=1, use_lidar=False)
    env.reset()
    x0, y0 = env.dog_x, env.dog_y
    for _step in range(400):
        positions = env.perceived_positions()
        if not positions:
            break
        dog_xy = (env.dog_x, env.dog_y)
        vx, vy, _ = strombom_action(dog_xy, positions, PEN_ENTRY)
        action = np.array([vx, vy], dtype=np.float32)
        obs, _r, term, trunc, _ = env.step(action)
        if term or trunc:
            break
    travelled = math.hypot(env.dog_x - x0, env.dog_y - y0)
    assert travelled > 0.05
|
||||||
@@ -0,0 +1,75 @@
|
|||||||
|
"""Geometric predicates and constants."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
from herding.world.geometry import (
|
||||||
|
FIELD_X, FIELD_Y, GATE_X, GATE_Y, MAX_SHEEP, PEN_ENTRY, PEN_X, PEN_Y,
|
||||||
|
distance_to_pen_entry, in_field, in_gate_corridor, in_pen,
|
||||||
|
is_penned_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_field_dimensions():
    """Both field axes span -15 to 15."""
    expected_span = (-15.0, 15.0)
    for axis_span in (FIELD_X, FIELD_Y):
        assert axis_span == expected_span
|
||||||
|
|
||||||
|
|
||||||
|
def test_pen_geometry():
    """Pin the pen rectangle, its entry point, and the gate constants."""
    # Pen rectangle.
    assert PEN_X == (10.0, 13.0)
    assert PEN_Y == (-22.0, -15.0)
    # The gate shares the pen's x-extent and sits at y = -15.
    assert GATE_X == PEN_X
    assert GATE_Y == -15.0
    # Entry point on the gate line.
    assert PEN_ENTRY == (11.5, -15.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_in_pen_strict_interior():
    """in_pen accepts an interior point and rejects boundary and outside points."""
    assert in_pen(11.5, -18.0)
    rejected = [
        (10.0, -18.0),   # on the pen's x boundary
        (11.5, -15.0),   # on the gate plane
        (0.0, 0.0),      # well outside
    ]
    for x, y in rejected:
        assert not in_pen(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
def test_in_field_with_margin():
    """in_field with and without an inward margin."""
    # Default margin.
    for x, y in [(0.0, 0.0), (14.0, 14.0)]:
        assert in_field(x, y)
    assert not in_field(15.5, 0.0)
    # A 0.5 margin moves the accepted edge to 14.5.
    assert in_field(14.4, 0.0, margin=0.5)
    assert not in_field(14.6, 0.0, margin=0.5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_in_gate_corridor():
    """Points in the pen column at or south of the gate are in the corridor."""
    inside = [(11.5, -18.0), (10.0, -15.0)]
    outside = [(11.5, -10.0), (5.0, -18.0)]
    for x, y in inside:
        assert in_gate_corridor(x, y)
    for x, y in outside:
        assert not in_gate_corridor(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_penned_position_latches_below_gate():
    """Penned means inside the gate column AND on or south of the gate plane."""
    penned = [(11.5, -15.0), (10.5, -18.0), (12.5, -22.0)]
    not_penned = [
        (11.5, -14.9),   # just north of the gate plane
        (0.0, -16.0),    # south, but outside the column
        (14.0, -16.0),   # south, but outside the column
    ]
    for x, y in penned:
        assert is_penned_position(x, y)
    for x, y in not_penned:
        assert not is_penned_position(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_penned_position_latch_margin():
    """The default 0.2 latch margin widens the gate column on both sides.

    With PEN_X == (10.0, 13.0) the accepted x-range is [9.8, 13.2].
    """
    # Just inside the widened column on each side.
    assert is_penned_position(9.9, -15.5)
    assert is_penned_position(13.1, -15.5)
    # Just outside it on each side.
    assert not is_penned_position(9.7, -15.5)
    assert not is_penned_position(13.3, -15.5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_distance_to_pen_entry():
    """Distance is 0 at the entry itself and Euclidean elsewhere."""
    assert distance_to_pen_entry(*PEN_ENTRY) == 0.0
    cases = [
        ((11.5, -10.0), 5.0),
        ((0.0, 0.0), math.hypot(11.5, 15.0)),
    ]
    for (x, y), expected in cases:
        assert math.isclose(distance_to_pen_entry(x, y), expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_max_sheep_positive_int():
    """MAX_SHEEP is an int of at least 1."""
    is_int = isinstance(MAX_SHEEP, int)
    assert is_int
    assert MAX_SHEEP >= 1
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
"""Observation builder — shape, normalisation, order invariance."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from herding.perception.obs import OBS_DIM, build_obs
|
||||||
|
|
||||||
|
|
||||||
|
def test_obs_shape_and_dtype():
    """build_obs returns a float32 vector of length OBS_DIM."""
    one_sheep, none_penned = [(5.0, 5.0)], [False]
    obs = build_obs((0.0, 0.0), 0.0, one_sheep, none_penned)
    assert obs.dtype == np.float32
    assert obs.shape == (OBS_DIM,)
|
||||||
|
|
||||||
|
|
||||||
|
def test_obs_no_active_sheep_terminal():
    """With every sheep penned, the count channel and the aggregate channels are zero."""
    sheep = [(1.0, 1.0), (2.0, 2.0)]
    all_penned = [True, True]
    obs = build_obs((0.0, 0.0), 0.0, sheep, all_penned)
    # Channel 19 is the count field.
    assert obs[19] == 0.0
    # Channels 4..11 are the flock aggregates (CoM, radius, std, vectors).
    aggregates = obs[4:12]
    assert np.allclose(aggregates, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_obs_dog_pose_normalised():
    """Channels 0..3 hold the dog's position scaled to [-1, 1] and cos/sin of its heading."""
    heading = math.pi / 2
    obs = build_obs((15.0, -15.0), heading, [(0.0, 0.0)], [False])
    # The observation is float32, so the default rel_tol=1e-9 of math.isclose
    # would demand bit-exact values; use the same absolute tolerance for all
    # four channels.
    tol = 1e-6
    assert math.isclose(obs[0], 1.0, abs_tol=tol)
    assert math.isclose(obs[1], -1.0, abs_tol=tol)
    assert math.isclose(obs[2], math.cos(heading), abs_tol=tol)
    assert math.isclose(obs[3], math.sin(heading), abs_tol=tol)
|
||||||
|
|
||||||
|
|
||||||
|
def test_obs_order_invariance():
    """Reversing the sheep list (and its penned flags) must not change the observation."""
    sheep = [(3.0, 2.0), (-5.0, 1.0), (0.0, 8.0)]
    penned = [False] * 3
    forward = build_obs((0.0, 0.0), 0.0, sheep, penned)
    backward = build_obs((0.0, 0.0), 0.0, sheep[::-1], penned[::-1])
    assert np.allclose(forward, backward)
|
||||||
|
|
||||||
|
|
||||||
|
def test_obs_count_field_normalised_by_n_max():
    """Five active sheep with n_max=10 give a count channel of 0.5."""
    sheep = [(1.0, 1.0)] * 5
    p = [False] * 5
    obs = build_obs((0.0, 0.0), 0.0, sheep, p, n_max=10)
    # obs is float32: allow float32 rounding rather than relying on the default
    # rel_tol=1e-9, which only passes for a bit-exact 0.5.
    assert math.isclose(obs[19], 0.5, abs_tol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
def test_obs_polar_histogram_sums_to_one():
    """Channels 20..27 sum to 1 for four sheep placed on the axes around the dog."""
    ring = [(1.0, 0.0), (-1.0, 0.0), (0.0, 1.0), (0.0, -1.0)]
    obs = build_obs((0.0, 0.0), 0.0, ring, [False] * 4)
    histogram_total = float(obs[20:28].sum())
    assert math.isclose(histogram_total, 1.0, abs_tol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
def test_obs_named_channels_closest_rearmost():
    """Channels 28..31 hold dog-relative offsets scaled by 1/15.

    28..29 = (closest_to_pen - dog) / 15
    30..31 = (rearmost - dog) / 15
    """
    pen_x, pen_y = 11.5, -15.0
    near = (pen_x + 1.0, pen_y + 1.0)
    far = (-10.0, 10.0)
    obs = build_obs((0.0, 0.0), 0.0, [near, far], [False, False])
    # The dog is at the origin, so each offset equals the sheep position itself.
    expected = {28: near[0], 29: near[1], 30: far[0], 31: far[1]}
    for channel, coordinate in expected.items():
        assert math.isclose(obs[channel], coordinate / 15.0, abs_tol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_obs_pen_vector_zero_at_pen_entry():
    """With the dog at the pen entry, channel 14 (distance to pen) is zero."""
    obs = build_obs((11.5, -15.0), 0.0, [(0.0, 0.0)], [False])
    # math.isclose against 0.0 is an exact-equality test unless abs_tol is
    # given (a relative tolerance of zero is zero), so allow float rounding.
    assert math.isclose(obs[14], 0.0, abs_tol=1e-6)  # distance to pen
|
||||||
@@ -0,0 +1,163 @@
|
|||||||
|
"""LiDAR simulation + perception pipeline + multi-target tracker."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from herding.perception.lidar_perception import (
|
||||||
|
STATIC_REJECT, detections_from_scan,
|
||||||
|
)
|
||||||
|
from herding.perception.lidar_sim import (
|
||||||
|
LIDAR_MAX_RANGE, LIDAR_N_RAYS, SHEEP_RADIUS, ray_angles, simulate_scan,
|
||||||
|
)
|
||||||
|
from herding.perception.sheep_tracker import (
|
||||||
|
FORGET_STEPS, GATE_M, MAX_ACTIVE_TRACKS, REACQUIRE_GATE_M,
|
||||||
|
REACQUIRE_MIN_AGE, SheepTracker,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Sim
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_simulate_scan_shape_and_dtype():
    """A scan is a float32 vector with one range per ray."""
    scan = simulate_scan(0.0, 0.0, 0.0, [(5.0, 0.0)], noise=0.0)
    assert scan.dtype == np.float32
    assert scan.shape == (LIDAR_N_RAYS,)
|
||||||
|
|
||||||
|
|
||||||
|
def test_simulate_scan_no_sheep_far_from_walls():
    """With no sheep and the dog at the origin, every ray reads LIDAR_MAX_RANGE."""
    # From the origin the walls are meant to lie beyond sensor range, so nothing is hit.
    empty_scan = simulate_scan(0.0, 0.0, 0.0, [], noise=0.0)
    assert np.all(empty_scan == LIDAR_MAX_RANGE)
|
||||||
|
|
||||||
|
|
||||||
|
def test_simulate_scan_sheep_in_front_returns_centre_hit():
    """A sheep 5 units dead ahead is seen by the centre ray at 5 - SHEEP_RADIUS."""
    scan = simulate_scan(0.0, 0.0, 0.0, [(5.0, 0.0)], noise=0.0)
    centre_index = LIDAR_N_RAYS // 2
    expected_range = 5.0 - SHEEP_RADIUS
    assert math.isclose(float(scan[centre_index]), expected_range, abs_tol=0.01)
|
||||||
|
|
||||||
|
|
||||||
|
def test_simulate_scan_sheep_behind_dog_not_hit():
    """A sheep directly behind the dog produces no return on any ray."""
    sheep_behind = [(-5.0, 0.0)]
    scan = simulate_scan(0.0, 0.0, 0.0, sheep_behind, noise=0.0)
    assert np.all(scan == LIDAR_MAX_RANGE)
|
||||||
|
|
||||||
|
|
||||||
|
def test_simulate_scan_wall_hit():
    """Facing north from y=14, the centre ray reads about 1 (distance to the north wall)."""
    facing_north = math.pi / 2
    scan = simulate_scan(0.0, 14.0, facing_north, [], noise=0.0)
    centre_range = float(scan[LIDAR_N_RAYS // 2])
    assert math.isclose(centre_range, 1.0, abs_tol=0.01)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Perception
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_detections_recover_sheep_position():
    """Two simulated sheep give two detections, each true position matched within 0.1."""
    truths = [(5.0, 0.0), (3.0, 1.0)]
    scan = simulate_scan(0.0, 0.0, 0.0, truths, noise=0.0)
    detections = detections_from_scan(scan, 0.0, 0.0, 0.0)
    assert len(detections) == 2
    for true_x, true_y in truths:
        assert any(
            math.hypot(d[0] - true_x, d[1] - true_y) < 0.1
            for d in detections
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_detections_filter_gate_post():
    """In an empty scene no detection may lie within STATIC_REJECT of a gate post.

    The dog sits at (11.5, -10) with heading -π/2. An empty detection list
    passes too; only detections near the posts are rejected.
    """
    dog_x, dog_y, dog_heading = 11.5, -10.0, -math.pi / 2
    scan = simulate_scan(dog_x, dog_y, dog_heading, [], noise=0.0)
    detections = detections_from_scan(scan, dog_x, dog_y, dog_heading)
    gate_posts = [(10.0, -15.0), (13.0, -15.0)]
    for cx, cy in detections:
        for post_x, post_y in gate_posts:
            assert math.hypot(cx - post_x, cy - post_y) > STATIC_REJECT
|
||||||
|
|
||||||
|
|
||||||
|
def test_detections_empty_scan_returns_nothing():
    """A zero-length range array yields an empty detection list."""
    no_ranges = np.array([], dtype=np.float32)
    detections = detections_from_scan(no_ranges, 0.0, 0.0, 0.0)
    assert detections == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tracker
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_tracker_creates_track_for_new_detection():
    """A first detection results in exactly one active track."""
    tracker = SheepTracker()
    first_frame = [(5.0, 0.0)]
    tracker.update(first_frame)
    assert tracker.n_active() == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracker_associates_close_detections():
    """A 0.5-unit move between frames is matched to the existing track."""
    tracker = SheepTracker()
    for frame in ([(5.0, 0.0)], [(5.5, 0.0)]):
        tracker.update(frame)
    assert tracker.n_active() == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracker_spawns_new_track_far_detection():
    """A detection 10 units from the existing track starts a second track."""
    tracker = SheepTracker()
    tracker.update([(5.0, 0.0)])
    far_away = [(-5.0, 0.0)]   # meant to be well outside the association gate
    tracker.update(far_away)
    assert tracker.n_active() == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracker_reacquisition_for_stale_track():
    """A stale track rebinds to a detection between GATE_M and REACQUIRE_GATE_M away
    instead of spawning a duplicate."""
    tracker = SheepTracker()
    tracker.update([(0.0, 0.0)])
    # Age the track with REACQUIRE_MIN_AGE empty frames.
    for _ in range(REACQUIRE_MIN_AGE):
        tracker.update([])
    # Halfway between the primary gate and the re-acquisition gate.
    midway = (GATE_M + REACQUIRE_GATE_M) / 2.0
    tracker.update([(midway, 0.0)])
    assert tracker.n_active() == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracker_forgets_stale_tracks():
    """After FORGET_STEPS + 1 empty frames an unpenned track is dropped."""
    tracker = SheepTracker()
    tracker.update([(0.0, 0.0)])
    empty_frames = FORGET_STEPS + 1
    for _ in range(empty_frames):
        tracker.update([])
    assert tracker.n_active() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracker_penned_position_promotes_track():
    """A track spawned at a penned position counts as penned, not active."""
    tracker = SheepTracker()
    in_pen_column = (11.5, -16.0)   # is_penned_position is True for this point
    tracker.update([in_pen_column])
    assert tracker.n_penned() == 1
    assert tracker.n_active() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracker_penned_tracks_persist():
    """A penned track survives 2 × FORGET_STEPS empty frames."""
    tracker = SheepTracker()
    tracker.update([(11.5, -16.0)])
    empty_frames = FORGET_STEPS * 2
    for _ in range(empty_frames):
        tracker.update([])
    assert tracker.n_penned() == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracker_caps_active_set():
    """The active-track count never exceeds MAX_ACTIVE_TRACKS."""
    tracker = SheepTracker()

    # Spawn more than the cap, each well outside the others' gates: one
    # detection per scan, stepped along x by more than the gate radius.
    # NOTE(review): earlier tracks receive no detections while later ones
    # spawn; if FORGET_STEPS is small they may expire before the cap is
    # ever reached -- confirm this test exercises the cap itself.
    spacing = GATE_M + 1.0
    for index in range(MAX_ACTIVE_TRACKS + 5):
        tracker.update([(index * spacing, 0.0)])

    assert tracker.n_active() <= MAX_ACTIVE_TRACKS
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracker_reset_clears_state():
    """reset() leaves the tracker with no active tracks and step at zero."""
    tracker = SheepTracker()
    seed_scan = [(0.0, 0.0)]
    tracker.update(seed_scan)  # create some state for reset() to clear

    tracker.reset()

    assert tracker.n_active() == 0
    assert tracker.step == 0
|
||||||
+17
-9
@@ -6,7 +6,7 @@ Two stages, strictly sequential:
|
|||||||
sim demos (Strömbom on tracker output, K=4 frame stack)
|
sim demos (Strömbom on tracker output, K=4 frame stack)
|
||||||
│
|
│
|
||||||
▼
|
▼
|
||||||
bc_pretrain.py ──► runs/bc (Strömbom-imitated MLP)
|
bc/pretrain.py ──► runs/bc (Strömbom-imitated MLP)
|
||||||
│
|
│
|
||||||
▼ KL-regularised PPO fine-tune
|
▼ KL-regularised PPO fine-tune
|
||||||
│
|
│
|
||||||
@@ -17,10 +17,13 @@ runs/rl (deployed `rl` mode — beats BC and Strömbom)
|
|||||||
|
|
||||||
```
|
```
|
||||||
herding_env.py — Gymnasium env (LiDAR raycast + tracker by default)
|
herding_env.py — Gymnasium env (LiDAR raycast + tracker by default)
|
||||||
bc_pretrain.py — MSE + cosine BC of (obs, action) demos into MlpPolicy
|
bc/pretrain.py — MSE + cosine BC of (obs, action) demos into MlpPolicy
|
||||||
train_ppo.py — KL-regularised PPO fine-tune of a BC checkpoint
|
rl/train.py — KL-regularised PPO fine-tune of a BC checkpoint
|
||||||
eval.py — multi-seed analytic / learned policy comparison
|
eval.py — multi-seed analytic / learned policy comparison
|
||||||
runs/ — checkpoints (whitelisted entries in top-level .gitignore)
|
runs/ — checkpoints (whitelisted entries in top-level .gitignore)
|
||||||
|
|
||||||
|
(Unit + integration tests live in the top-level ``tests/`` directory;
|
||||||
|
run with ``python -m pytest tests/``.)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Setup
|
## Setup
|
||||||
@@ -35,18 +38,23 @@ rollout collection, not gradient compute.
|
|||||||
|
|
||||||
## End-to-end pipeline
|
## End-to-end pipeline
|
||||||
|
|
||||||
|
The simplest way to run everything is the Makefile at the project
|
||||||
|
root: ``make`` does the full chain, ``make rl`` rebuilds whatever's
|
||||||
|
needed up to that point, etc. The individual stages below are kept
|
||||||
|
explicit for cases where you want to tune a single step.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 1. Sim demos with the active-scan + Strömbom teacher under LiDAR
|
# 1. Sim demos with the active-scan + Strömbom teacher under LiDAR
|
||||||
# perception. K=4 frame stack so the MLP has temporal context.
|
# perception. K=4 frame stack so the MLP has temporal context.
|
||||||
python -m tools.collect_demos --teacher strombom \
|
python -m training.bc.collect --teacher strombom \
|
||||||
--out training/demos.npz --seeds-per-n 15 --subsample 3 --frame-stack 4
|
--out training/bc/demos.npz --seeds-per-n 15 --subsample 3 --frame-stack 4
|
||||||
|
|
||||||
# 2. Behaviour-clone.
|
# 2. Behaviour-clone.
|
||||||
python -m training.bc_pretrain --demos training/demos.npz \
|
python -m training.bc.pretrain --demos training/bc/demos.npz \
|
||||||
--out training/runs/bc --epochs 60 --net-arch 512,512
|
--out training/runs/bc --epochs 60 --net-arch 512,512
|
||||||
|
|
||||||
# 3. KL-regularised PPO fine-tune of bc.
|
# 3. KL-regularised PPO fine-tune of bc.
|
||||||
python -m training.train_ppo \
|
python -m training.rl.train \
|
||||||
--bc training/runs/bc --out training/runs/rl \
|
--bc training/runs/bc --out training/runs/rl \
|
||||||
--total-timesteps 1000000
|
--total-timesteps 1000000
|
||||||
|
|
||||||
@@ -55,11 +63,11 @@ python -m training.eval --policy training/runs/rl \
|
|||||||
--max-flock 10 --max-steps 15000 --n-seeds 10
|
--max-flock 10 --max-steps 15000 --n-seeds 10
|
||||||
```
|
```
|
||||||
|
|
||||||
`bc_pretrain.py` saves the **best-val_cos** snapshot, not the final
|
`bc/pretrain.py` saves the **best-val_cos** snapshot, not the final
|
||||||
epoch — multi-modal teachers make training noisy and the last epoch is
|
epoch — multi-modal teachers make training noisy and the last epoch is
|
||||||
often worse than an earlier one.
|
often worse than an earlier one.
|
||||||
|
|
||||||
`train_ppo.py` loads BC weights into both a trainable policy and a
|
`rl/train.py` loads BC weights into both a trainable policy and a
|
||||||
frozen reference, fixes `log_std` small, and adds `β · KL(π‖π_ref)` to
|
frozen reference, fixes `log_std` small, and adds `β · KL(π‖π_ref)` to
|
||||||
the loss so the policy can only move within a trust region around BC.
|
the loss so the policy can only move within a trust region around BC.
|
||||||
See the file header for hyperparameter rationale.
|
See the file header for hyperparameter rationale.
|
||||||
|
|||||||
@@ -1,29 +1,23 @@
|
|||||||
"""Collect (obs, action) demonstrations from the sequential teacher.
|
"""Collect (obs, action) demonstrations from an analytic teacher.
|
||||||
|
|
||||||
Runs the sequential algorithm across a grid of (n_sheep, seed) combos
|
Runs the chosen teacher across a grid of ``(n_sheep, seed)`` combos at
|
||||||
at full difficulty, logs the (observation, action) pair every Nth step,
|
full difficulty, logs every Nth ``(obs, action)`` pair, and saves
|
||||||
and saves successful trajectories to a numpy ``.npz`` for behavior
|
successful trajectories to ``.npz`` for behaviour cloning. The teacher
|
||||||
cloning. Failed trajectories are dropped by default — we only want to
|
is wrapped in :class:`ActiveScanTeacher` by default so it operates on
|
||||||
teach the policy from good examples.
|
the same partial-obs view the student will have at deployment.
|
||||||
|
|
||||||
Usage::
|
Usage::
|
||||||
|
|
||||||
python -m tools.collect_demos --out training/demos.npz
|
python -m training.bc.collect --teacher strombom \\
|
||||||
|
--out training/bc/demos.npz --frame-stack 4
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
_HERE = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
|
|
||||||
if _PROJECT_ROOT not in sys.path:
|
|
||||||
sys.path.insert(0, _PROJECT_ROOT)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from herding.control.active_scan import ActiveScanTeacher
|
from herding.control.active_scan import ActiveScanTeacher
|
||||||
@@ -33,9 +27,6 @@ from herding.control.strombom import compute_action as strombom_action
|
|||||||
from training.herding_env import HerdingEnv
|
from training.herding_env import HerdingEnv
|
||||||
|
|
||||||
|
|
||||||
# Base analytic teachers (no scanning). The default at demo-collection
|
|
||||||
# time wraps these in ActiveScanTeacher, which under LiDAR makes the
|
|
||||||
# teacher operate on the same partial obs as the student.
|
|
||||||
TEACHERS = {
|
TEACHERS = {
|
||||||
"sequential": sequential_action,
|
"sequential": sequential_action,
|
||||||
"strombom": strombom_action,
|
"strombom": strombom_action,
|
||||||
@@ -48,13 +39,13 @@ def collect_one(n_sheep: int, seed: int, max_steps: int, subsample: int,
|
|||||||
difficulty=1.0, seed=seed, frame_stack=frame_stack)
|
difficulty=1.0, seed=seed, frame_stack=frame_stack)
|
||||||
obs, _ = env.reset(seed=seed)
|
obs, _ = env.reset(seed=seed)
|
||||||
obs_list, action_list = [], []
|
obs_list, action_list = [], []
|
||||||
# Active-scan wrapper: scan first, then run the base teacher on the
|
# Wrap the base teacher so it opens with a rotation and walks to
|
||||||
# tracker dict. Reset state per episode so the opening scan kicks in.
|
# centre when the tracker briefly empties — matches the student.
|
||||||
scan_teacher = ActiveScanTeacher(teacher_fn)
|
scan_teacher = ActiveScanTeacher(teacher_fn)
|
||||||
for step in range(max_steps):
|
for step in range(max_steps):
|
||||||
if privileged:
|
if privileged:
|
||||||
# Asymmetric "learning by cheating": teacher reads GT, student
|
# Asymmetric variant: teacher reads ground truth while the
|
||||||
# gets LiDAR obs. Kept available for ablation; default off.
|
# student keeps the LiDAR obs. Default off.
|
||||||
positions = {f"s{i}": (float(env.sheep_x[i]), float(env.sheep_y[i]))
|
positions = {f"s{i}": (float(env.sheep_x[i]), float(env.sheep_y[i]))
|
||||||
for i in range(env.n_sheep) if not env.sheep_penned[i]}
|
for i in range(env.n_sheep) if not env.sheep_penned[i]}
|
||||||
if not positions:
|
if not positions:
|
||||||
@@ -63,9 +54,6 @@ def collect_one(n_sheep: int, seed: int, max_steps: int, subsample: int,
|
|||||||
(env.dog_x, env.dog_y), positions, PEN_ENTRY,
|
(env.dog_x, env.dog_y), positions, PEN_ENTRY,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Matched-perception teacher: it sees what the student sees
|
|
||||||
# (the tracker dict), with active scanning to fill the
|
|
||||||
# tracker before driving.
|
|
||||||
positions = env.perceived_positions()
|
positions = env.perceived_positions()
|
||||||
vx, vy, _mode = scan_teacher(
|
vx, vy, _mode = scan_teacher(
|
||||||
(env.dog_x, env.dog_y), env.dog_heading,
|
(env.dog_x, env.dog_y), env.dog_heading,
|
||||||
@@ -89,7 +77,7 @@ def collect_one(n_sheep: int, seed: int, max_steps: int, subsample: int,
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--out", default="training/demos.npz")
|
parser.add_argument("--out", default="training/bc/demos.npz")
|
||||||
parser.add_argument("--n-sheep-list", default="1,2,3,5,8,10")
|
parser.add_argument("--n-sheep-list", default="1,2,3,5,8,10")
|
||||||
parser.add_argument("--seeds-per-n", type=int, default=15)
|
parser.add_argument("--seeds-per-n", type=int, default=15)
|
||||||
parser.add_argument("--max-steps", type=int, default=30000)
|
parser.add_argument("--max-steps", type=int, default=30000)
|
||||||
@@ -101,13 +89,11 @@ def main():
|
|||||||
choices=list(TEACHERS.keys()),
|
choices=list(TEACHERS.keys()),
|
||||||
help="Which analytic teacher to demonstrate.")
|
help="Which analytic teacher to demonstrate.")
|
||||||
parser.add_argument("--frame-stack", type=int, default=1,
|
parser.add_argument("--frame-stack", type=int, default=1,
|
||||||
help="K — concatenate the last K env obs into a "
|
help="Concatenate the last K obs into a "
|
||||||
"single (32·K)-D vector. Lets a memoryless "
|
"(32·K)-D vector for the policy.")
|
||||||
"MLP recover temporal info under partial "
|
|
||||||
"LiDAR observability.")
|
|
||||||
parser.add_argument("--privileged", action="store_true",
|
parser.add_argument("--privileged", action="store_true",
|
||||||
help="Teacher reads ground truth (asymmetric BC). "
|
help="Teacher reads ground truth instead of "
|
||||||
"Default: matched-perception with active scan.")
|
"tracker output (asymmetric BC).")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
teacher_fn = TEACHERS[args.teacher]
|
teacher_fn = TEACHERS[args.teacher]
|
||||||
print(f"[demos] teacher: {args.teacher}")
|
print(f"[demos] teacher: {args.teacher}")
|
||||||
@@ -1,36 +1,27 @@
|
|||||||
"""Behavior cloning of an analytic teacher into an SB3-compatible policy.
|
"""Behaviour cloning of an analytic teacher into an SB3 MlpPolicy.
|
||||||
|
|
||||||
Trains the policy network (mean-action head) of an SB3 ``MlpPolicy``
|
Trains the mean-action head against ``(obs, action)`` demos from
|
||||||
to mimic the (obs, action) demonstrations produced by
|
``training.bc.collect`` using ``MSE + (1 − cos_sim)`` — the cosine
|
||||||
``tools.collect_demos``. The saved zip is loadable via ``PPO.load(...)``
|
term prevents collapse toward zero against unit-vector targets. The
|
||||||
and is what the Webots dog controller uses in ``HERDING_MODE=rl``.
|
best-by-val_cos snapshot is restored at the end of training because
|
||||||
|
multi-modal teachers make the last epoch unreliable.
|
||||||
|
|
||||||
Loss: MSE + (1 - cosine similarity). The cosine term is what stops
|
Output zip is loadable by ``PPO.load(...)`` and consumed by
|
||||||
the policy mean from collapsing toward zero against unit-vector
|
``HERDING_MODE=bc`` in the dog controller.
|
||||||
targets. Best-by-val_cos checkpoint is restored at the end of training
|
|
||||||
so noisy multi-modal teachers (e.g. Strömbom) don't lose progress when
|
|
||||||
the last epoch lands on a bad gradient step.
|
|
||||||
|
|
||||||
Usage::
|
Usage::
|
||||||
|
|
||||||
python -m training.bc_pretrain \\
|
python -m training.bc.pretrain \\
|
||||||
--demos training/demos.npz \\
|
--demos training/bc/demos.npz \\
|
||||||
--out training/runs/bc
|
--out training/runs/bc
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
_HERE = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
|
|
||||||
if _PROJECT_ROOT not in sys.path:
|
|
||||||
sys.path.insert(0, _PROJECT_ROOT)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -64,25 +55,21 @@ def build_model(net_arch_pi, net_arch_vf, log_std_init: float,
|
|||||||
|
|
||||||
|
|
||||||
def policy_forward_mean(policy, obs_batch):
|
def policy_forward_mean(policy, obs_batch):
|
||||||
"""Return the policy's deterministic mean action for a batch.
|
"""Return the deterministic mean action for an obs batch.
|
||||||
|
|
||||||
SB3's ActorCriticPolicy doesn't expose this directly — it goes
|
SB3's ActorCriticPolicy routes ``forward`` through a Distribution
|
||||||
through a Distribution wrapper. We replicate the forward path:
|
wrapper; we replicate the underlying chain
|
||||||
extract_features → mlp_extractor → action_net.
|
``extract_features → mlp_extractor → action_net``.
|
||||||
"""
|
"""
|
||||||
features = policy.extract_features(obs_batch)
|
features = policy.extract_features(obs_batch)
|
||||||
if isinstance(features, tuple):
|
pi_features = features[0] if isinstance(features, tuple) else features
|
||||||
# SB3 ≥ 2.0 sometimes returns (pi_features, vf_features)
|
latent_pi, _ = policy.mlp_extractor(pi_features)
|
||||||
pi_features = features[0]
|
|
||||||
else:
|
|
||||||
pi_features = features
|
|
||||||
latent_pi, _latent_vf = policy.mlp_extractor(pi_features)
|
|
||||||
return policy.action_net(latent_pi)
|
return policy.action_net(latent_pi)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--demos", default="training/demos.npz")
|
parser.add_argument("--demos", default="training/bc/demos.npz")
|
||||||
parser.add_argument("--out", default="training/runs/bc")
|
parser.add_argument("--out", default="training/runs/bc")
|
||||||
parser.add_argument("--epochs", type=int, default=60)
|
parser.add_argument("--epochs", type=int, default=60)
|
||||||
parser.add_argument("--batch-size", type=int, default=256)
|
parser.add_argument("--batch-size", type=int, default=256)
|
||||||
@@ -92,12 +79,8 @@ def main():
|
|||||||
help="Comma-separated hidden layer widths.")
|
help="Comma-separated hidden layer widths.")
|
||||||
parser.add_argument("--log-std-init", type=float, default=0.5)
|
parser.add_argument("--log-std-init", type=float, default=0.5)
|
||||||
parser.add_argument("--cos-weight", type=float, default=1.0,
|
parser.add_argument("--cos-weight", type=float, default=1.0,
|
||||||
help="Weight on (1 - cosine similarity) loss term. "
|
help="Weight of the (1 - cosine_similarity) loss "
|
||||||
"MSE alone shrinks policy output toward zero "
|
"term; balances against MSE.")
|
||||||
"(zero-magnitude action minimises mean squared "
|
|
||||||
"error against ±1 targets); cos loss keeps "
|
|
||||||
"the action pointed correctly even at small "
|
|
||||||
"magnitudes.")
|
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
parser.add_argument("--device", default="cpu")
|
parser.add_argument("--device", default="cpu")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -115,7 +98,6 @@ def main():
|
|||||||
if obs.size == 0:
|
if obs.size == 0:
|
||||||
raise RuntimeError("Empty demo file.")
|
raise RuntimeError("Empty demo file.")
|
||||||
|
|
||||||
# Action sanity check — sequential outputs unit vectors.
|
|
||||||
a_norms = np.linalg.norm(actions, axis=1)
|
a_norms = np.linalg.norm(actions, axis=1)
|
||||||
print(f"[bc] action L2 norm: mean={a_norms.mean():.3f} "
|
print(f"[bc] action L2 norm: mean={a_norms.mean():.3f} "
|
||||||
f"min={a_norms.min():.3f} max={a_norms.max():.3f}")
|
f"min={a_norms.min():.3f} max={a_norms.max():.3f}")
|
||||||
@@ -138,13 +120,11 @@ def main():
|
|||||||
batch_size=args.batch_size, shuffle=False,
|
batch_size=args.batch_size, shuffle=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- Build model ---
|
|
||||||
net_arch_pi = [int(x) for x in args.net_arch.split(",")]
|
net_arch_pi = [int(x) for x in args.net_arch.split(",")]
|
||||||
net_arch_vf = net_arch_pi[:]
|
net_arch_vf = net_arch_pi[:]
|
||||||
# Auto-detect frame stacking from the demo file so a stacked-obs
|
# Frame stack is inferred from the demo obs dim.
|
||||||
# demo trains a stacked-obs policy without an extra CLI flag.
|
|
||||||
obs_dim = obs.shape[1]
|
obs_dim = obs.shape[1]
|
||||||
from herding.obs import OBS_DIM as _SINGLE
|
from herding.perception.obs import OBS_DIM as _SINGLE
|
||||||
if obs_dim % _SINGLE != 0:
|
if obs_dim % _SINGLE != 0:
|
||||||
raise RuntimeError(f"demo obs dim {obs_dim} is not a multiple of {_SINGLE}")
|
raise RuntimeError(f"demo obs dim {obs_dim} is not a multiple of {_SINGLE}")
|
||||||
frame_stack = obs_dim // _SINGLE
|
frame_stack = obs_dim // _SINGLE
|
||||||
@@ -161,10 +141,7 @@ def main():
|
|||||||
t_start = time.time()
|
t_start = time.time()
|
||||||
best_val = float("inf")
|
best_val = float("inf")
|
||||||
best_cos = -1.0
|
best_cos = -1.0
|
||||||
# Snapshot the best-by-val_cos policy weights and restore at the end —
|
best_state = None # restored at the end so noisy last epochs don't win
|
||||||
# training is noisy on multi-modal teachers (e.g. Strömbom collect/drive),
|
|
||||||
# so the last epoch is often worse than an earlier one.
|
|
||||||
best_state = None
|
|
||||||
|
|
||||||
def combined_loss(pred, target):
|
def combined_loss(pred, target):
|
||||||
mse = nn.functional.mse_loss(pred, target)
|
mse = nn.functional.mse_loss(pred, target)
|
||||||
@@ -205,8 +182,6 @@ def main():
|
|||||||
val_total += nn.functional.mse_loss(
|
val_total += nn.functional.mse_loss(
|
||||||
mean_action, act_batch, reduction="sum",
|
mean_action, act_batch, reduction="sum",
|
||||||
).item()
|
).item()
|
||||||
# Cosine similarity in action space — useful sanity for
|
|
||||||
# "is the policy pointing the same way as the teacher?".
|
|
||||||
m_norm = mean_action.norm(dim=1).clamp_min(1e-6)
|
m_norm = mean_action.norm(dim=1).clamp_min(1e-6)
|
||||||
a_norm = act_batch.norm(dim=1).clamp_min(1e-6)
|
a_norm = act_batch.norm(dim=1).clamp_min(1e-6)
|
||||||
cos = (mean_action * act_batch).sum(dim=1) / (m_norm * a_norm)
|
cos = (mean_action * act_batch).sum(dim=1) / (m_norm * a_norm)
|
||||||
+26
-38
@@ -1,27 +1,19 @@
|
|||||||
"""Evaluate a trained PPO policy (or the Strömbom baseline) on the env.
|
"""Env-side evaluation of analytic or learned policies.
|
||||||
|
|
||||||
Reports success rate and time-to-pen across a fixed seed grid for each
|
Reports success rate, mean steps and mean penned per flock size for
|
||||||
flock size 1..MAX_SHEEP. Used to produce the M5 quantitative comparison
|
``n_sheep ∈ 1..max_flock`` across ``--n-seeds`` seeds each.
|
||||||
table mentioned in plan.md.
|
|
||||||
|
|
||||||
Usage::
|
Usage::
|
||||||
|
|
||||||
python -m training.eval --policy training/runs/latest/best
|
python -m training.eval --policy training/runs/rl --n-seeds 10
|
||||||
python -m training.eval --policy strombom
|
python -m training.eval --policy strombom
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from statistics import mean, stdev
|
from statistics import mean
|
||||||
|
|
||||||
_HERE = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
|
|
||||||
if _PROJECT_ROOT not in sys.path:
|
|
||||||
sys.path.insert(0, _PROJECT_ROOT)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -33,40 +25,38 @@ from training.herding_env import HerdingEnv
|
|||||||
|
|
||||||
def rollout(env: HerdingEnv, predict_fn, max_steps: int) -> dict:
|
def rollout(env: HerdingEnv, predict_fn, max_steps: int) -> dict:
|
||||||
obs, _ = env.reset()
|
obs, _ = env.reset()
|
||||||
success = False
|
|
||||||
for t in range(max_steps):
|
for t in range(max_steps):
|
||||||
action = predict_fn(env, obs)
|
action = predict_fn(env, obs)
|
||||||
obs, _r, terminated, truncated, info = env.step(action)
|
obs, _r, terminated, truncated, info = env.step(action)
|
||||||
if terminated or truncated:
|
if terminated or truncated:
|
||||||
success = bool(info.get("is_success", False))
|
return {
|
||||||
return {"success": success, "steps": info.get("steps", t + 1),
|
"success": bool(info.get("is_success", False)),
|
||||||
"n_penned": info.get("n_penned", 0)}
|
"steps": info.get("steps", t + 1),
|
||||||
return {"success": False, "steps": max_steps, "n_penned": int(env.sheep_penned.sum())}
|
"n_penned": info.get("n_penned", 0),
|
||||||
|
}
|
||||||
|
return {"success": False, "steps": max_steps,
|
||||||
|
"n_penned": int(env.sheep_penned.sum())}
|
||||||
|
|
||||||
|
|
||||||
def make_analytic_predictor(action_fn):
|
def make_analytic_predictor(action_fn):
|
||||||
|
"""Wrap an analytic teacher so it runs on the env's exposed
|
||||||
|
perception (tracker in LiDAR mode, GT in privileged mode)."""
|
||||||
def _predict(env, _obs):
|
def _predict(env, _obs):
|
||||||
# Use whatever perception the env exposes — tracker output in
|
|
||||||
# LiDAR mode, ground truth in privileged mode. This makes
|
|
||||||
# evaluation honest: the analytic teacher sees what the
|
|
||||||
# deployed controller would see.
|
|
||||||
positions = env.perceived_positions()
|
positions = env.perceived_positions()
|
||||||
vx, vy, _mode = action_fn((env.dog_x, env.dog_y), positions, PEN_ENTRY)
|
vx, vy, _mode = action_fn((env.dog_x, env.dog_y), positions, PEN_ENTRY)
|
||||||
return np.array([vx, vy], dtype=np.float32)
|
return np.array([vx, vy], dtype=np.float32)
|
||||||
return _predict
|
return _predict
|
||||||
|
|
||||||
|
|
||||||
# Backwards-compat alias.
|
|
||||||
def make_strombom_predictor():
|
def make_strombom_predictor():
|
||||||
return make_analytic_predictor(strombom_action)
|
return make_analytic_predictor(strombom_action)
|
||||||
|
|
||||||
|
|
||||||
def make_policy_predictor(model, vecnorm):
|
def make_policy_predictor(model, vecnorm):
|
||||||
def _predict(_env, obs):
|
def _predict(_env, obs):
|
||||||
if vecnorm is not None:
|
|
||||||
obs_b = vecnorm.normalize_obs(np.asarray(obs, dtype=np.float32).reshape(1, -1))
|
|
||||||
else:
|
|
||||||
obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1)
|
obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1)
|
||||||
|
if vecnorm is not None:
|
||||||
|
obs_b = vecnorm.normalize_obs(obs_b)
|
||||||
action, _ = model.predict(obs_b, deterministic=True)
|
action, _ = model.predict(obs_b, deterministic=True)
|
||||||
return action[0]
|
return action[0]
|
||||||
return _predict
|
return _predict
|
||||||
@@ -75,16 +65,17 @@ def make_policy_predictor(model, vecnorm):
|
|||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--policy", required=True,
|
parser.add_argument("--policy", required=True,
|
||||||
help="Either 'strombom' or path to an SB3 run directory.")
|
help="'strombom', 'sequential', or path to a "
|
||||||
|
"policy directory / zip.")
|
||||||
parser.add_argument("--n-seeds", type=int, default=10)
|
parser.add_argument("--n-seeds", type=int, default=10)
|
||||||
parser.add_argument("--max-steps", type=int, default=5000)
|
parser.add_argument("--max-steps", type=int, default=5000)
|
||||||
parser.add_argument("--max-flock", type=int, default=MAX_SHEEP)
|
parser.add_argument("--max-flock", type=int, default=MAX_SHEEP)
|
||||||
# 1.0 = deployment distribution (sheep anywhere in field).
|
parser.add_argument("--difficulty", type=float, default=1.0,
|
||||||
# Lower values use the training-curriculum spawn band (sheep near gate).
|
help="0 = sheep spawn near the gate (easy); "
|
||||||
parser.add_argument("--difficulty", type=float, default=1.0)
|
"1 = full field (deployment distribution).")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
frame_stack = 1 # default; analytic predictors don't use stacked obs
|
frame_stack = 1
|
||||||
if args.policy == "strombom":
|
if args.policy == "strombom":
|
||||||
predict = make_analytic_predictor(strombom_action)
|
predict = make_analytic_predictor(strombom_action)
|
||||||
elif args.policy == "sequential":
|
elif args.policy == "sequential":
|
||||||
@@ -92,23 +83,20 @@ def main():
|
|||||||
else:
|
else:
|
||||||
from stable_baselines3 import PPO
|
from stable_baselines3 import PPO
|
||||||
run = Path(args.policy)
|
run = Path(args.policy)
|
||||||
# Resolve to a zip: directory of checkpoints, or a direct zip path.
|
|
||||||
if run.is_file():
|
if run.is_file():
|
||||||
zip_path = run
|
zip_path = run
|
||||||
else:
|
else:
|
||||||
for name in ("best_model.zip", "policy.zip", "final.zip"):
|
for name in ("policy.zip", "final.zip"):
|
||||||
if (run / name).exists():
|
if (run / name).exists():
|
||||||
zip_path = run / name
|
zip_path = run / name
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"No checkpoint found in {run} (tried best_model.zip, "
|
f"No checkpoint found in {run} "
|
||||||
f"policy.zip, final.zip)"
|
f"(tried policy.zip, final.zip)"
|
||||||
)
|
)
|
||||||
model = PPO.load(str(zip_path), device="auto")
|
model = PPO.load(str(zip_path), device="auto")
|
||||||
# Auto-detect frame stacking from the policy's expected obs dim,
|
from herding.perception.obs import OBS_DIM as _SINGLE
|
||||||
# so eval runs with whatever stacking the policy was trained on.
|
|
||||||
from herding.obs import OBS_DIM as _SINGLE
|
|
||||||
policy_obs_dim = int(model.observation_space.shape[0])
|
policy_obs_dim = int(model.observation_space.shape[0])
|
||||||
if policy_obs_dim % _SINGLE == 0 and policy_obs_dim // _SINGLE >= 1:
|
if policy_obs_dim % _SINGLE == 0 and policy_obs_dim // _SINGLE >= 1:
|
||||||
frame_stack = policy_obs_dim // _SINGLE
|
frame_stack = policy_obs_dim // _SINGLE
|
||||||
|
|||||||
+67
-173
@@ -1,61 +1,30 @@
|
|||||||
"""Gymnasium environment for the shepherd-dog herding task.
|
"""Gymnasium environment for the shepherd-dog herding task.
|
||||||
|
|
||||||
Single-agent: the agent is the dog. Sheep are environment-controlled
|
Single-agent: the dog is the policy; sheep are env-controlled flocking
|
||||||
flocking agents whose dynamics are imported verbatim from
|
agents (``herding.world.flocking_sim``). Differential-drive kinematics
|
||||||
``herding.flocking_sim`` so a policy trained here transfers to Webots
|
match the proto specs (``herding.world.diffdrive``) so a policy trained
|
||||||
without re-tuning. Differential-drive kinematics for both dog and sheep
|
here transfers to Webots without re-tuning.
|
||||||
match the proto specs (wheel radius, base, max wheel ω) via
|
|
||||||
``herding.diffdrive``.
|
|
||||||
|
|
||||||
Action space
|
* **Action**: ``Box(-1, 1, (2,))`` — desired ``(vx, vy)`` intent.
|
||||||
------------
|
* **Observation**: ``Box(-inf, inf, (32·K,))`` from ``herding.perception.obs.build_obs``
|
||||||
Box(-1, 1, (2,)) — the dog's desired (vx, vy) velocity *intent*. This
|
with optional frame stacking K (concatenated oldest → newest).
|
||||||
matches the high-level action representation the Webots controller
|
* **Reset**: ``options["n_sheep"]`` overrides flock size; otherwise
|
||||||
already uses; the env converts (vx, vy) → wheel speeds with the same
|
sampled uniformly from ``[1, max_n_sheep]``.
|
||||||
formula.
|
* **Reward**: dense shaping (per-sheep distance progress, time
|
||||||
|
penalty, Strömbom-imitation cosine bonus) + sparse pen/done jackpots.
|
||||||
Observation space
|
Weights live as class attributes on :class:`HerdingEnv`.
|
||||||
-----------------
|
|
||||||
Box(-inf, inf, (28,)) — the order-invariant feature vector built by
|
|
||||||
``herding.obs.build_obs``. See ``herding/obs.py`` for the layout.
|
|
||||||
|
|
||||||
Reset
|
|
||||||
-----
|
|
||||||
``options["n_sheep"]`` (1..MAX_SHEEP) overrides the default flock size
|
|
||||||
for the next episode. If absent, flock size is sampled uniformly from
|
|
||||||
[1, max_n_sheep] each reset, where ``max_n_sheep`` can be raised over
|
|
||||||
training time by an outer callback.
|
|
||||||
|
|
||||||
Reward
|
|
||||||
------
|
|
||||||
Sparse + shaping (see :func:`HerdingEnv._compute_reward` for weights).
|
|
||||||
|
|
||||||
+2.0 per newly penned sheep
|
|
||||||
+0.5 · ΔCoM-distance-to-pen (positive when CoM moves closer)
|
|
||||||
+0.2 · ΔFlock-radius (positive when flock tightens)
|
|
||||||
-0.005 per step (encourages speed)
|
|
||||||
- wall and collision penalties
|
|
||||||
+10.0 terminal bonus when all sheep penned
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import os
|
|
||||||
import random
|
import random
|
||||||
import sys
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
|
|
||||||
# Make herding/ importable when run from anywhere.
|
|
||||||
_HERE = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
|
|
||||||
if _PROJECT_ROOT not in sys.path:
|
|
||||||
sys.path.insert(0, _PROJECT_ROOT)
|
|
||||||
|
|
||||||
from herding.world.diffdrive import (
|
from herding.world.diffdrive import (
|
||||||
heading_speed_to_wheels, kinematics_step, velocity_to_wheels,
|
heading_speed_to_wheels, kinematics_step, velocity_to_wheels,
|
||||||
)
|
)
|
||||||
@@ -71,7 +40,7 @@ from herding.world.geometry import (
|
|||||||
)
|
)
|
||||||
from herding.perception.lidar_perception import detections_from_scan
|
from herding.perception.lidar_perception import detections_from_scan
|
||||||
from herding.perception.lidar_sim import simulate_scan
|
from herding.perception.lidar_sim import simulate_scan
|
||||||
from herding.obs import OBS_DIM, build_obs
|
from herding.perception.obs import OBS_DIM, build_obs
|
||||||
from herding.perception.sheep_tracker import SheepTracker
|
from herding.perception.sheep_tracker import SheepTracker
|
||||||
from herding.control.strombom import compute_action as strombom_action
|
from herding.control.strombom import compute_action as strombom_action
|
||||||
|
|
||||||
@@ -85,45 +54,23 @@ class HerdingEnv(gym.Env):
|
|||||||
|
|
||||||
metadata = {"render_modes": []}
|
metadata = {"render_modes": []}
|
||||||
|
|
||||||
# Reward shaping weights. Re-tuned after the first run got stuck at
|
# Reward weights. Sparse jackpots (W_PEN_DELTA, W_DONE) dominate;
|
||||||
# 0% success: progress reward must dominate the time penalty by a
|
# dense shaping (W_PROGRESS on Δ mean-distance-to-pen) provides the
|
||||||
# large margin, and the pen-event bonus must be big enough that PPO's
|
# gradient; W_IMITATE adds a small cosine bonus toward the analytic
|
||||||
# advantage estimator can credit-assign across the long path that
|
# teacher's action; W_TIME is a per-step penalty (0 by default).
|
||||||
# leads to it. Per-step shaping is bounded by the clamps inside
|
|
||||||
# _compute_reward.
|
|
||||||
# Drastically simplified after two runs got stuck farming a position
|
|
||||||
# bonus instead of penning sheep. Reward now is essentially:
|
|
||||||
# • huge jackpot for actually penning sheep (+100 per pen, +500 done)
|
|
||||||
# • small dense gradient: per-sheep mean distance to pen
|
|
||||||
# No position shaping (gameable), no compactness shaping (gameable),
|
|
||||||
# no engagement bonus (gameable). The terminal per-unpenned penalty
|
|
||||||
# forbids "good enough" partial herds.
|
|
||||||
# We have a working analytic baseline (Strömbom, 100 % on easy mode).
|
|
||||||
# Use it as a teacher: per-step bonus proportional to the cosine
|
|
||||||
# similarity between the policy's action and what Strömbom would do
|
|
||||||
# at the same state. This drags the policy out of "do nothing" local
|
|
||||||
# optima without locking it to the teacher — PPO can still find
|
|
||||||
# improvements over Strömbom because pen jackpots dominate.
|
|
||||||
W_PEN_DELTA = 100.0
|
W_PEN_DELTA = 100.0
|
||||||
W_PROGRESS = 20.0
|
W_PROGRESS = 20.0
|
||||||
W_IMITATE = 0.5 # per-step max ±0.5 (action cosine sim, [-1, 1])
|
W_IMITATE = 0.5
|
||||||
W_TIME = 0.0
|
W_TIME = 0.0
|
||||||
W_WALL = 0.0
|
W_WALL = 0.0
|
||||||
W_COLLISION = 0.0
|
W_COLLISION = 0.0
|
||||||
W_DONE = 500.0
|
W_DONE = 500.0
|
||||||
|
|
||||||
# Action smoothing during training: 0 = none. The Webots controller
|
# In-env action EMA. 0 = none; the Webots controller applies its own
|
||||||
# still applies its own EMA at inference for actuator stability, so
|
# EMA at inference, so the policy needn't learn smoothness.
|
||||||
# the policy doesn't need to learn smoothness explicitly.
|
|
||||||
ACTION_SMOOTH = 0.0
|
ACTION_SMOOTH = 0.0
|
||||||
|
|
||||||
# Episode budget. ~80 s of sim time at dt=0.016. The new external-pen
|
|
||||||
# layout has paths up to ~28 m from spawn to pen entry; at sheep flee
|
|
||||||
# speed ~0.4 m/s, that's 70 s minimum. 3000 steps (48 s) was leaving
|
|
||||||
# the dog with no margin for collect-then-drive on multi-sheep cases.
|
|
||||||
DEFAULT_MAX_STEPS = 5000
|
DEFAULT_MAX_STEPS = 5000
|
||||||
|
|
||||||
# Distance under which the dog is considered "colliding" with a sheep.
|
|
||||||
COLLISION_DIST = 0.30
|
COLLISION_DIST = 0.30
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -137,19 +84,15 @@ class HerdingEnv(gym.Env):
|
|||||||
frame_stack: int = 1,
|
frame_stack: int = 1,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# When True (default), the obs and the imitation-reward teacher
|
# ``use_lidar=True`` (default): obs and imitation-reward teacher
|
||||||
# see only LiDAR-perceived sheep positions through a tracker —
|
# see only LiDAR-perceived positions via a tracker, matching the
|
||||||
# matching what the Webots controller has access to. When False,
|
# Webots controller. ``False`` exposes ground truth for ablation.
|
||||||
# both consume ground-truth positions (legacy "privileged" mode,
|
|
||||||
# kept for ablation).
|
|
||||||
self._use_lidar = bool(use_lidar)
|
self._use_lidar = bool(use_lidar)
|
||||||
self._tracker = SheepTracker() if self._use_lidar else None
|
self._tracker = SheepTracker() if self._use_lidar else None
|
||||||
self._np_rng_lidar: Optional[np.random.Generator] = None
|
self._np_rng_lidar: Optional[np.random.Generator] = None
|
||||||
|
|
||||||
# Frame stacking: the policy receives the last K single-frame
|
# Frame stacking: the policy receives the last K obs concatenated,
|
||||||
# observations concatenated. Lets a memoryless MLP integrate
|
# giving a memoryless MLP temporal context. K=1 → single frame.
|
||||||
# information across time, partly compensating for the limited
|
|
||||||
# LiDAR FOV. K=1 reproduces the legacy single-frame obs.
|
|
||||||
self._frame_stack = max(1, int(frame_stack))
|
self._frame_stack = max(1, int(frame_stack))
|
||||||
self._frame_buffer: list[np.ndarray] = []
|
self._frame_buffer: list[np.ndarray] = []
|
||||||
self.action_space = spaces.Box(-1.0, 1.0, shape=(2,), dtype=np.float32)
|
self.action_space = spaces.Box(-1.0, 1.0, shape=(2,), dtype=np.float32)
|
||||||
@@ -159,18 +102,16 @@ class HerdingEnv(gym.Env):
|
|||||||
shape=(OBS_DIM * self._frame_stack,), dtype=np.float32,
|
shape=(OBS_DIM * self._frame_stack,), dtype=np.float32,
|
||||||
)
|
)
|
||||||
|
|
||||||
# If n_sheep is None, env will sample uniformly from [1, max_n_sheep]
|
# n_sheep=None → sample uniformly from [1, max_n_sheep] each reset.
|
||||||
# on every reset — this is the default for curriculum-free training.
|
|
||||||
self._fixed_n_sheep = n_sheep
|
self._fixed_n_sheep = n_sheep
|
||||||
self._max_n_sheep = max_n_sheep
|
self._max_n_sheep = max_n_sheep
|
||||||
self.max_steps = max_steps
|
self.max_steps = max_steps
|
||||||
# difficulty ∈ [0, 1]: 0 = sheep spawn next to the gate (easy),
|
# difficulty ∈ [0, 1]: 0 = sheep spawn near the gate (easy);
|
||||||
# 1 = sheep spawn anywhere in the field (hard, the deployment
|
# 1 = sheep spawn anywhere in the field (deployment distribution).
|
||||||
# distribution). Curriculum bumps this from 0 → 1 over training.
|
|
||||||
self._difficulty = float(difficulty)
|
self._difficulty = float(difficulty)
|
||||||
self._initial_seed = seed
|
self._initial_seed = seed
|
||||||
|
|
||||||
# State (initialized in reset)
|
# State (initialised in reset)
|
||||||
self.dog_x = self.dog_y = self.dog_heading = 0.0
|
self.dog_x = self.dog_y = self.dog_heading = 0.0
|
||||||
self.sheep_x = np.zeros(0, dtype=np.float32)
|
self.sheep_x = np.zeros(0, dtype=np.float32)
|
||||||
self.sheep_y = np.zeros(0, dtype=np.float32)
|
self.sheep_y = np.zeros(0, dtype=np.float32)
|
||||||
@@ -186,12 +127,10 @@ class HerdingEnv(gym.Env):
|
|||||||
self.prev_d_pen = 0.0
|
self.prev_d_pen = 0.0
|
||||||
self.prev_radius = 0.0
|
self.prev_radius = 0.0
|
||||||
|
|
||||||
# Env-owned RNG for the flocking wander-jitter, seeded fresh on each
|
# Env-owned RNG for wander jitter, re-seeded from np_random on reset.
|
||||||
# reset so determinism is preserved without touching the global
|
|
||||||
# random module.
|
|
||||||
self._py_rng = random.Random()
|
self._py_rng = random.Random()
|
||||||
|
|
||||||
# ---- public knobs (used by curriculum callback) ----
|
# --- Public knobs ---
|
||||||
def set_max_n_sheep(self, value: int) -> None:
|
def set_max_n_sheep(self, value: int) -> None:
|
||||||
self._max_n_sheep = int(np.clip(value, 1, MAX_SHEEP))
|
self._max_n_sheep = int(np.clip(value, 1, MAX_SHEEP))
|
||||||
|
|
||||||
@@ -199,22 +138,18 @@ class HerdingEnv(gym.Env):
|
|||||||
self._difficulty = float(np.clip(value, 0.0, 1.0))
|
self._difficulty = float(np.clip(value, 0.0, 1.0))
|
||||||
|
|
||||||
def set_imitate_weight(self, value: float) -> None:
|
def set_imitate_weight(self, value: float) -> None:
|
||||||
"""Override W_IMITATE (instance-level) — used to disable the
|
"""Override the instance W_IMITATE — used to disable Strömbom
|
||||||
Strömbom imitation reward during BC fine-tuning, when the policy
|
imitation during PPO fine-tune."""
|
||||||
already mimics a stronger teacher (sequential)."""
|
|
||||||
self.W_IMITATE = float(value)
|
self.W_IMITATE = float(value)
|
||||||
|
|
||||||
def set_time_weight(self, value: float) -> None:
|
def set_time_weight(self, value: float) -> None:
|
||||||
"""Override W_TIME (instance-level). Default 0.0; a small
|
"""Override the instance W_TIME — set negative to penalise step
|
||||||
negative value (e.g. -0.1) adds a per-step penalty that
|
count and encourage faster time-to-pen during PPO fine-tune."""
|
||||||
explicitly rewards fast time-to-pen during PPO fine-tune."""
|
|
||||||
self.W_TIME = float(value)
|
self.W_TIME = float(value)
|
||||||
|
|
||||||
# ---- gym API ----
|
# --- gym API ---
|
||||||
def reset(self, *, seed=None, options=None):
|
def reset(self, *, seed=None, options=None):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
# Re-seed the flocking RNG from np_random so flocking jitter is
|
|
||||||
# reproducible alongside everything else the env samples.
|
|
||||||
self._py_rng.seed(int(self.np_random.integers(0, 2**31 - 1)))
|
self._py_rng.seed(int(self.np_random.integers(0, 2**31 - 1)))
|
||||||
opts = options or {}
|
opts = options or {}
|
||||||
|
|
||||||
@@ -230,28 +165,26 @@ class HerdingEnv(gym.Env):
|
|||||||
self.dog_y = float(self.np_random.uniform(-2.5, 2.5))
|
self.dog_y = float(self.np_random.uniform(-2.5, 2.5))
|
||||||
self.dog_heading = float(self.np_random.uniform(-math.pi, math.pi))
|
self.dog_heading = float(self.np_random.uniform(-math.pi, math.pi))
|
||||||
|
|
||||||
# Sheep spawn region scales with difficulty:
|
# Sheep spawn region linearly interpolates with difficulty:
|
||||||
# 0.0 → narrow box just north of the gate (x ∈ [7, 14], y ∈ [-12, -6])
|
# 0 → small box near the gate, 1 → full field.
|
||||||
# 1.0 → full field (x ∈ [-13, 13], y ∈ [-12, 13])
|
|
||||||
# Linear interpolation between the two for intermediate values.
|
|
||||||
d = self._difficulty
|
d = self._difficulty
|
||||||
sx_lo = 7.0 - d * 20.0 # → -13 at d=1
|
sx_lo = 7.0 - d * 20.0
|
||||||
sx_hi = 14.0 - d * 1.0 # → 13 at d=1
|
sx_hi = 14.0 - d * 1.0
|
||||||
sy_lo = -12.0 + d * 0.0 # → -12 at d=1
|
sy_lo = -12.0 + d * 0.0
|
||||||
sy_hi = -6.0 + d * 19.0 # → 13 at d=1
|
sy_hi = -6.0 + d * 19.0
|
||||||
|
|
||||||
sxs, sys_, shs, sws = [], [], [], []
|
sxs, sys_, shs, sws = [], [], [], []
|
||||||
for _ in range(self.n_sheep):
|
for _ in range(self.n_sheep):
|
||||||
for _try in range(100):
|
for _try in range(100):
|
||||||
sx = float(self.np_random.uniform(sx_lo, sx_hi))
|
sx = float(self.np_random.uniform(sx_lo, sx_hi))
|
||||||
sy = float(self.np_random.uniform(sy_lo, sy_hi))
|
sy = float(self.np_random.uniform(sy_lo, sy_hi))
|
||||||
# Reject too close to dog or to other sheep.
|
# Reject if too close to the dog or another sheep, or
|
||||||
|
# already in the gate column (would start "penned").
|
||||||
if math.hypot(sx - self.dog_x, sy - self.dog_y) < 3.0:
|
if math.hypot(sx - self.dog_x, sy - self.dog_y) < 3.0:
|
||||||
continue
|
continue
|
||||||
if any(math.hypot(sx - x, sy - y) < 1.5
|
if any(math.hypot(sx - x, sy - y) < 1.5
|
||||||
for x, y in zip(sxs, sys_)):
|
for x, y in zip(sxs, sys_)):
|
||||||
continue
|
continue
|
||||||
# Reject inside the gate column already (they'd start "penned").
|
|
||||||
if PEN_X[0] <= sx <= PEN_X[1] and sy < -8.0:
|
if PEN_X[0] <= sx <= PEN_X[1] and sy < -8.0:
|
||||||
continue
|
continue
|
||||||
break
|
break
|
||||||
@@ -275,10 +208,8 @@ class HerdingEnv(gym.Env):
|
|||||||
self._tracker.reset()
|
self._tracker.reset()
|
||||||
self._np_rng_lidar = np.random.default_rng(
|
self._np_rng_lidar = np.random.default_rng(
|
||||||
int(self.np_random.integers(0, 2**31 - 1)))
|
int(self.np_random.integers(0, 2**31 - 1)))
|
||||||
# Prime the tracker with one scan so the first obs isn't empty.
|
|
||||||
self._update_tracker()
|
self._update_tracker()
|
||||||
|
|
||||||
# Clear the frame stack — the next _build_obs will repopulate.
|
|
||||||
self._frame_buffer = []
|
self._frame_buffer = []
|
||||||
|
|
||||||
obs = self._build_obs()
|
obs = self._build_obs()
|
||||||
@@ -288,7 +219,6 @@ class HerdingEnv(gym.Env):
|
|||||||
def step(self, action):
|
def step(self, action):
|
||||||
action = np.clip(np.asarray(action, dtype=np.float32), -1.0, 1.0)
|
action = np.clip(np.asarray(action, dtype=np.float32), -1.0, 1.0)
|
||||||
|
|
||||||
# EMA smoothing — the Webots controller does this too.
|
|
||||||
self.smoothed_action = (
|
self.smoothed_action = (
|
||||||
self.ACTION_SMOOTH * self.prev_action
|
self.ACTION_SMOOTH * self.prev_action
|
||||||
+ (1.0 - self.ACTION_SMOOTH) * action
|
+ (1.0 - self.ACTION_SMOOTH) * action
|
||||||
@@ -296,12 +226,11 @@ class HerdingEnv(gym.Env):
|
|||||||
self.prev_action = self.smoothed_action.copy()
|
self.prev_action = self.smoothed_action.copy()
|
||||||
vx, vy = float(self.smoothed_action[0]), float(self.smoothed_action[1])
|
vx, vy = float(self.smoothed_action[0]), float(self.smoothed_action[1])
|
||||||
|
|
||||||
# Safety supervisor mirrored from the controller — keeps the dog
|
# Safety supervisor — dog stays north of the gate.
|
||||||
# north of the gate so the policy can't strand itself in the pen.
|
|
||||||
if self.dog_y < DOG_SOUTH_LIMIT and vy < 0.0:
|
if self.dog_y < DOG_SOUTH_LIMIT and vy < 0.0:
|
||||||
vx, vy = 0.0, 1.0
|
vx, vy = 0.0, 1.0
|
||||||
|
|
||||||
# --- Step the dog ---
|
# Step the dog.
|
||||||
wL, wR = velocity_to_wheels(
|
wL, wR = velocity_to_wheels(
|
||||||
vx, vy, self.dog_heading,
|
vx, vy, self.dog_heading,
|
||||||
max_linear=DOG_MAX_LINEAR,
|
max_linear=DOG_MAX_LINEAR,
|
||||||
@@ -313,27 +242,22 @@ class HerdingEnv(gym.Env):
|
|||||||
self.dog_x, self.dog_y, self.dog_heading,
|
self.dog_x, self.dog_y, self.dog_heading,
|
||||||
wL, wR, DOG_WHEEL_RADIUS, DOG_WHEEL_BASE, WEBOTS_DT,
|
wL, wR, DOG_WHEEL_RADIUS, DOG_WHEEL_BASE, WEBOTS_DT,
|
||||||
)
|
)
|
||||||
# Clip dog to field bounds and out of pen — same as the Webots stone walls.
|
|
||||||
self.dog_x = float(np.clip(self.dog_x, FIELD_X[0] + 0.3, FIELD_X[1] - 0.3))
|
self.dog_x = float(np.clip(self.dog_x, FIELD_X[0] + 0.3, FIELD_X[1] - 0.3))
|
||||||
self.dog_y = float(np.clip(self.dog_y, DOG_SOUTH_LIMIT, FIELD_Y[1] - 0.3))
|
self.dog_y = float(np.clip(self.dog_y, DOG_SOUTH_LIMIT, FIELD_Y[1] - 0.3))
|
||||||
|
|
||||||
# --- Step each sheep ---
|
# Step sheep and update penned flags (GT-based).
|
||||||
for i in range(self.n_sheep):
|
for i in range(self.n_sheep):
|
||||||
self._step_one_sheep(i)
|
self._step_one_sheep(i)
|
||||||
|
|
||||||
# --- Update penned state ---
|
|
||||||
for i in range(self.n_sheep):
|
for i in range(self.n_sheep):
|
||||||
if (not self.sheep_penned[i]
|
if (not self.sheep_penned[i]
|
||||||
and is_penned_position(self.sheep_x[i], self.sheep_y[i])):
|
and is_penned_position(self.sheep_x[i], self.sheep_y[i])):
|
||||||
self.sheep_penned[i] = True
|
self.sheep_penned[i] = True
|
||||||
|
|
||||||
# --- Run LiDAR perception on this step's state (after sheep have
|
# LiDAR perception runs after sheep move; feeds the obs and the
|
||||||
# moved). Updates the tracker that obs and the imitation-
|
# imitation reward. Reward/termination still use GT.
|
||||||
# reward teacher consume. Reward / termination still use GT. ---
|
|
||||||
if self._tracker is not None:
|
if self._tracker is not None:
|
||||||
self._update_tracker()
|
self._update_tracker()
|
||||||
|
|
||||||
# --- Reward, termination ---
|
|
||||||
d_pen, radius = self._flock_metrics()
|
d_pen, radius = self._flock_metrics()
|
||||||
reward = self._compute_reward(d_pen, radius, action=action)
|
reward = self._compute_reward(d_pen, radius, action=action)
|
||||||
self.prev_d_pen = d_pen
|
self.prev_d_pen = d_pen
|
||||||
@@ -346,12 +270,6 @@ class HerdingEnv(gym.Env):
|
|||||||
truncated = self.steps >= self.max_steps
|
truncated = self.steps >= self.max_steps
|
||||||
if all_penned:
|
if all_penned:
|
||||||
reward += self.W_DONE
|
reward += self.W_DONE
|
||||||
# No timeout penalty: a per-unpenned penalty made "do nothing"
|
|
||||||
# strictly preferable to noisy-random under reward-progress shaping
|
|
||||||
# (random sometimes pushes sheep away → negative progress, then
|
|
||||||
# always ate the timeout penalty), which collapsed exploration to
|
|
||||||
# tiny actions. The pen jackpot alone provides the directional
|
|
||||||
# signal once exploration is wide enough to find it.
|
|
||||||
|
|
||||||
obs = self._build_obs()
|
obs = self._build_obs()
|
||||||
info = {
|
info = {
|
||||||
@@ -362,7 +280,7 @@ class HerdingEnv(gym.Env):
|
|||||||
}
|
}
|
||||||
return obs, float(reward), terminated, truncated, info
|
return obs, float(reward), terminated, truncated, info
|
||||||
|
|
||||||
# ---- internals ----
|
# --- Internals ---
|
||||||
def _step_one_sheep(self, i: int) -> None:
|
def _step_one_sheep(self, i: int) -> None:
|
||||||
x, y = float(self.sheep_x[i]), float(self.sheep_y[i])
|
x, y = float(self.sheep_x[i]), float(self.sheep_y[i])
|
||||||
peers = [(float(self.sheep_x[j]), float(self.sheep_y[j]))
|
peers = [(float(self.sheep_x[j]), float(self.sheep_y[j]))
|
||||||
@@ -386,8 +304,7 @@ class HerdingEnv(gym.Env):
|
|||||||
SHEEP_WHEEL_RADIUS, SHEEP_WHEEL_BASE, WEBOTS_DT,
|
SHEEP_WHEEL_RADIUS, SHEEP_WHEEL_BASE, WEBOTS_DT,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wall clipping — matches Webots stone walls, except in the gate column
|
# Wall clipping (south wall absent inside the gate column).
|
||||||
# where the south wall is absent.
|
|
||||||
nx = float(np.clip(nx, FIELD_X[0] + 0.2, FIELD_X[1] - 0.2))
|
nx = float(np.clip(nx, FIELD_X[0] + 0.2, FIELD_X[1] - 0.2))
|
||||||
in_gate_col = PEN_X[0] <= nx <= PEN_X[1]
|
in_gate_col = PEN_X[0] <= nx <= PEN_X[1]
|
||||||
if in_gate_col:
|
if in_gate_col:
|
||||||
@@ -400,12 +317,11 @@ class HerdingEnv(gym.Env):
|
|||||||
self.sheep_h[i] = nh
|
self.sheep_h[i] = nh
|
||||||
|
|
||||||
def _flock_metrics(self):
|
def _flock_metrics(self):
|
||||||
"""(per-sheep mean distance to pen entry, max-radius).
|
"""Return (per-sheep mean distance to pen, max radius from CoM).
|
||||||
|
|
||||||
Using the per-sheep mean instead of CoM-distance ensures stragglers
|
The per-sheep mean (not CoM distance) makes the progress signal
|
||||||
keep contributing to the progress signal — the dog can't game the
|
sensitive to stragglers: the dog can't game it by herding most of
|
||||||
shaping by herding the bulk of the flock and abandoning one
|
the flock and abandoning one outlier.
|
||||||
outlier (CoM moves toward pen, but mean-distance doesn't).
|
|
||||||
"""
|
"""
|
||||||
active_mask = ~self.sheep_penned
|
active_mask = ~self.sheep_penned
|
||||||
if not active_mask.any():
|
if not active_mask.any():
|
||||||
@@ -422,24 +338,14 @@ class HerdingEnv(gym.Env):
|
|||||||
return d_pen, radius
|
return d_pen, radius
|
||||||
|
|
||||||
def _compute_reward(self, d_pen: float, radius: float, action=None) -> float:
|
def _compute_reward(self, d_pen: float, radius: float, action=None) -> float:
|
||||||
"""Sparse + per-sheep distance shaping + Strömbom imitation.
|
"""Sparse pen jackpot + dense progress shaping + Strömbom imitation."""
|
||||||
|
|
||||||
d_pen is the *mean* distance over active sheep, so progress only
|
|
||||||
accrues when ALL active sheep get closer to the pen on average —
|
|
||||||
the dog can't farm it by herding one sheep while ignoring others.
|
|
||||||
|
|
||||||
The imitation term is computed by querying Strömbom for the
|
|
||||||
recommended action at the *current* (post-step) state and
|
|
||||||
rewarding cosine similarity with what the policy actually did.
|
|
||||||
"""
|
|
||||||
n_penned = int(self.sheep_penned.sum())
|
n_penned = int(self.sheep_penned.sum())
|
||||||
delta_pen = n_penned - self.prev_n_penned
|
delta_pen = n_penned - self.prev_n_penned
|
||||||
|
|
||||||
d_progress = max(-5.0, min(5.0, self.prev_d_pen - d_pen))
|
d_progress = max(-5.0, min(5.0, self.prev_d_pen - d_pen))
|
||||||
r = self.W_PEN_DELTA * delta_pen + self.W_PROGRESS * d_progress
|
r = (self.W_PEN_DELTA * delta_pen
|
||||||
# Per-step time penalty (0 by default). When negative, encourages
|
+ self.W_PROGRESS * d_progress
|
||||||
# the policy to finish quickly — used during PPO fine-tune.
|
+ self.W_TIME)
|
||||||
r += self.W_TIME
|
|
||||||
|
|
||||||
if action is not None and self.W_IMITATE > 0.0:
|
if action is not None and self.W_IMITATE > 0.0:
|
||||||
positions = self._perceived_positions()
|
positions = self._perceived_positions()
|
||||||
@@ -457,10 +363,7 @@ class HerdingEnv(gym.Env):
|
|||||||
|
|
||||||
def _build_single_obs(self) -> np.ndarray:
|
def _build_single_obs(self) -> np.ndarray:
|
||||||
if self._tracker is not None:
|
if self._tracker is not None:
|
||||||
# Obs sees only the tracker's active set; penned tracks are
|
# LiDAR mode: the obs sees only the tracker's active set.
|
||||||
# intentionally excluded (matches the prior receiver-based
|
|
||||||
# behaviour where penned sheep stopped contributing to the
|
|
||||||
# symbolic obs).
|
|
||||||
active = self._tracker.get_positions()
|
active = self._tracker.get_positions()
|
||||||
sheep_xy_list = list(active.values())
|
sheep_xy_list = list(active.values())
|
||||||
sheep_penned_list = [False] * len(sheep_xy_list)
|
sheep_penned_list = [False] * len(sheep_xy_list)
|
||||||
@@ -477,22 +380,18 @@ class HerdingEnv(gym.Env):
|
|||||||
single = self._build_single_obs()
|
single = self._build_single_obs()
|
||||||
if self._frame_stack <= 1:
|
if self._frame_stack <= 1:
|
||||||
return single
|
return single
|
||||||
# On a fresh reset the buffer is empty — duplicate the first
|
# On reset the buffer is empty — pad with copies of the first frame.
|
||||||
# frame so the stack is always full-length.
|
|
||||||
if not self._frame_buffer:
|
if not self._frame_buffer:
|
||||||
self._frame_buffer = [single.copy() for _ in range(self._frame_stack)]
|
self._frame_buffer = [single.copy() for _ in range(self._frame_stack)]
|
||||||
else:
|
else:
|
||||||
self._frame_buffer.append(single)
|
self._frame_buffer.append(single)
|
||||||
if len(self._frame_buffer) > self._frame_stack:
|
if len(self._frame_buffer) > self._frame_stack:
|
||||||
self._frame_buffer = self._frame_buffer[-self._frame_stack:]
|
self._frame_buffer = self._frame_buffer[-self._frame_stack:]
|
||||||
# Concatenate oldest → newest.
|
|
||||||
return np.concatenate(self._frame_buffer, axis=0).astype(np.float32)
|
return np.concatenate(self._frame_buffer, axis=0).astype(np.float32)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# --- LiDAR perception ---
|
||||||
# LiDAR perception helpers
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
def _all_sheep_xy(self) -> list[tuple[float, float]]:
|
def _all_sheep_xy(self) -> list[tuple[float, float]]:
|
||||||
"""Every sheep, including penned ones (the LiDAR sees them)."""
|
"""Every sheep, including penned (the LiDAR sees them)."""
|
||||||
return [(float(self.sheep_x[i]), float(self.sheep_y[i]))
|
return [(float(self.sheep_x[i]), float(self.sheep_y[i]))
|
||||||
for i in range(self.n_sheep)]
|
for i in range(self.n_sheep)]
|
||||||
|
|
||||||
@@ -508,19 +407,14 @@ class HerdingEnv(gym.Env):
|
|||||||
self._tracker.update(detections)
|
self._tracker.update(detections)
|
||||||
|
|
||||||
def perceived_positions(self) -> dict[str, tuple[float, float]]:
|
def perceived_positions(self) -> dict[str, tuple[float, float]]:
|
||||||
"""Public accessor — what the controller would 'see' this step.
|
"""What the controller would "see" this step: tracker output in
|
||||||
|
LiDAR mode, ground truth in privileged mode. Used by demo
|
||||||
LiDAR mode → the tracker's active set.
|
collection and analytic-policy eval so the teacher runs on the
|
||||||
Privileged mode → ground-truth active sheep.
|
same perception the deployed controller has.
|
||||||
|
|
||||||
Used by ``training.eval`` and ``tools.collect_demos`` so analytic
|
|
||||||
teachers run on the same perception the deployed controller has.
|
|
||||||
"""
|
"""
|
||||||
if self._tracker is not None:
|
if self._tracker is not None:
|
||||||
return self._tracker.get_positions()
|
return self._tracker.get_positions()
|
||||||
return {f"s{i}": (float(self.sheep_x[i]), float(self.sheep_y[i]))
|
return {f"s{i}": (float(self.sheep_x[i]), float(self.sheep_y[i]))
|
||||||
for i in range(self.n_sheep) if not self.sheep_penned[i]}
|
for i in range(self.n_sheep) if not self.sheep_penned[i]}
|
||||||
|
|
||||||
# Internal alias so the imitation reward path doesn't need to know
|
|
||||||
# which mode it's in.
|
|
||||||
_perceived_positions = perceived_positions
|
_perceived_positions = perceived_positions
|
||||||
|
|||||||
@@ -6,3 +6,4 @@ numpy>=1.24
|
|||||||
pyyaml>=6.0
|
pyyaml>=6.0
|
||||||
tensorboard>=2.14
|
tensorboard>=2.14
|
||||||
tqdm>=4.66
|
tqdm>=4.66
|
||||||
|
pytest>=8.0
|
||||||
|
|||||||
@@ -1,30 +1,17 @@
|
|||||||
"""KL-regularised PPO fine-tune of a behaviour-cloned policy.
|
"""KL-regularised PPO fine-tune of a behaviour-cloned policy.
|
||||||
|
|
||||||
The PPO-from-scratch and unregularised PPO-fine-tune-of-BC versions
|
The trainable policy is initialised from ``runs/bc/policy.zip``. A
|
||||||
we tried earlier failed for the standard reasons (sparse pen reward,
|
frozen copy of the same weights becomes the reference; each PPO loss
|
||||||
long horizons, exploration noise destroying BC weights). The fix is
|
gets an extra ``β · KL(π ‖ π_ref)`` term so the policy can only move
|
||||||
to anchor the policy to its BC initialisation with a KL penalty in
|
within a trust region around BC. ``log_std`` is fixed small to keep
|
||||||
the loss — the policy is free to refine the BC mean within a
|
exploration tight.
|
||||||
trust-region-like ball around the reference, and the dense-enough
|
|
||||||
per-step reward signal does the rest.
|
|
||||||
|
|
||||||
Pipeline
|
Output: ``runs/rl/policy.zip`` — same SB3 format as the BC checkpoint,
|
||||||
--------
|
loadable by ``HERDING_MODE=rl`` in the dog controller.
|
||||||
1. Load ``bc`` weights into both the trainable policy and a frozen
|
|
||||||
reference ``ref_policy``.
|
|
||||||
2. Initialise the policy's log_std to a small fixed value (≈ −1.5)
|
|
||||||
and disable its gradient — exploration noise stays small so PPO
|
|
||||||
updates don't blow up the BC mean before reward can stabilise.
|
|
||||||
3. Override ``PPO.train()`` to add ``β · KL(π ‖ π_ref)`` to the loss
|
|
||||||
each minibatch.
|
|
||||||
4. Train for ~1–3 M timesteps with a low LR (5e-5).
|
|
||||||
|
|
||||||
Output: ``runs/rl/policy.zip`` — same SB3 format as bc, loadable
|
|
||||||
by the dog controller's ``HERDING_MODE=rl`` path.
|
|
||||||
|
|
||||||
Usage::
|
Usage::
|
||||||
|
|
||||||
python -m training.train_ppo \\
|
python -m training.rl.train \\
|
||||||
--bc training/runs/bc \\
|
--bc training/runs/bc \\
|
||||||
--out training/runs/rl \\
|
--out training/runs/rl \\
|
||||||
--total-timesteps 2000000
|
--total-timesteps 2000000
|
||||||
@@ -33,15 +20,8 @@ Usage::
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
_HERE = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
|
|
||||||
if _PROJECT_ROOT not in sys.path:
|
|
||||||
sys.path.insert(0, _PROJECT_ROOT)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -50,7 +30,7 @@ from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
|
|||||||
from stable_baselines3.common.monitor import Monitor
|
from stable_baselines3.common.monitor import Monitor
|
||||||
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
|
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
|
||||||
|
|
||||||
from herding.obs import OBS_DIM
|
from herding.perception.obs import OBS_DIM
|
||||||
from training.herding_env import HerdingEnv
|
from training.herding_env import HerdingEnv
|
||||||
|
|
||||||
|
|
||||||
@@ -73,15 +53,12 @@ def _make_env(rank: int, seed: int, frame_stack: int):
|
|||||||
class KLPPO(PPO):
|
class KLPPO(PPO):
|
||||||
"""PPO with an extra KL-to-reference penalty in the policy loss.
|
"""PPO with an extra KL-to-reference penalty in the policy loss.
|
||||||
|
|
||||||
Subclasses SB3's PPO and overrides ``train()`` only to add a single
|
Overrides only ``train()``; rollout buffer, clipped surrogate, value
|
||||||
line for the KL term — everything else (rollout buffer, clipped
|
loss and entropy bonus are unchanged from stock SB3 PPO.
|
||||||
surrogate, value loss, entropy bonus) is unchanged.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, ref_policy=None, kl_coef: float = 0.05, **kwargs):
|
def __init__(self, *args, ref_policy=None, kl_coef: float = 0.05, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
# ref_policy is set after construction (caller can build it
|
|
||||||
# from the BC checkpoint once `self.policy` exists).
|
|
||||||
self.ref_policy = ref_policy
|
self.ref_policy = ref_policy
|
||||||
if self.ref_policy is not None:
|
if self.ref_policy is not None:
|
||||||
self.ref_policy.set_training_mode(False)
|
self.ref_policy.set_training_mode(False)
|
||||||
@@ -90,9 +67,8 @@ class KLPPO(PPO):
|
|||||||
self.kl_coef = kl_coef
|
self.kl_coef = kl_coef
|
||||||
|
|
||||||
def train(self) -> None:
|
def train(self) -> None:
|
||||||
# Copied from stable_baselines3.ppo.PPO.train (v2.x), with the
|
# Stock SB3 PPO.train() structure with the KL-to-reference term
|
||||||
# KL-to-reference term added. Keeping the structure intact so
|
# added inside the inner minibatch loop.
|
||||||
# behavioural parity with stock PPO is obvious.
|
|
||||||
self.policy.set_training_mode(True)
|
self.policy.set_training_mode(True)
|
||||||
self._update_learning_rate(self.policy.optimizer)
|
self._update_learning_rate(self.policy.optimizer)
|
||||||
clip_range = self.clip_range(self._current_progress_remaining)
|
clip_range = self.clip_range(self._current_progress_remaining)
|
||||||
@@ -139,12 +115,8 @@ class KLPPO(PPO):
|
|||||||
entropy_loss = -th.mean(entropy)
|
entropy_loss = -th.mean(entropy)
|
||||||
entropy_losses.append(entropy_loss.item())
|
entropy_losses.append(entropy_loss.item())
|
||||||
|
|
||||||
# --- KL-to-reference term ----------------------------
|
# KL-to-reference: closed-form KL between two diagonal
|
||||||
# Both policies are diagonal Gaussian (ActorCriticPolicy).
|
# Gaussians, summed over the action axis, mean over batch.
|
||||||
# KL(π ‖ π_ref) per-action-dim; sum over the action axis
|
|
||||||
# to get total KL per sample, then mean over batch.
|
|
||||||
# Computed on the rollout's observations so the penalty
|
|
||||||
# reflects what the agent actually saw.
|
|
||||||
if self.ref_policy is None:
|
if self.ref_policy is None:
|
||||||
raise RuntimeError("KLPPO.train called without ref_policy")
|
raise RuntimeError("KLPPO.train called without ref_policy")
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
@@ -153,7 +125,6 @@ class KLPPO(PPO):
|
|||||||
kl_div = th.distributions.kl.kl_divergence(
|
kl_div = th.distributions.kl.kl_divergence(
|
||||||
pi_dist.distribution, ref_dist.distribution).sum(dim=-1).mean()
|
pi_dist.distribution, ref_dist.distribution).sum(dim=-1).mean()
|
||||||
kl_losses.append(kl_div.item())
|
kl_losses.append(kl_div.item())
|
||||||
# ----------------------------------------------------
|
|
||||||
|
|
||||||
loss = (policy_loss
|
loss = (policy_loss
|
||||||
+ self.ent_coef * entropy_loss
|
+ self.ent_coef * entropy_loss
|
||||||
@@ -192,7 +163,6 @@ class KLPPO(PPO):
|
|||||||
self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
|
self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
|
||||||
|
|
||||||
def _explained_variance(self) -> float:
|
def _explained_variance(self) -> float:
|
||||||
# SB3 doesn't expose this as a method; replicate the computation.
|
|
||||||
y_pred = self.rollout_buffer.values.flatten()
|
y_pred = self.rollout_buffer.values.flatten()
|
||||||
y_true = self.rollout_buffer.returns.flatten()
|
y_true = self.rollout_buffer.returns.flatten()
|
||||||
var_y = np.var(y_true)
|
var_y = np.var(y_true)
|
||||||
@@ -206,50 +176,41 @@ class KLPPO(PPO):
|
|||||||
def main() -> None:
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--bc", default="training/runs/bc",
|
parser.add_argument("--bc", default="training/runs/bc",
|
||||||
help="Directory containing the BC initialisation (policy.zip).")
|
help="Directory containing the BC initialisation.")
|
||||||
parser.add_argument("--out", default="training/runs/rl",
|
parser.add_argument("--out", default="training/runs/rl",
|
||||||
help="Where to save the fine-tuned policy.")
|
help="Where to save the fine-tuned policy.")
|
||||||
parser.add_argument("--total-timesteps", type=int, default=2_000_000)
|
parser.add_argument("--total-timesteps", type=int, default=2_000_000)
|
||||||
parser.add_argument("--n-envs", type=int, default=8)
|
parser.add_argument("--n-envs", type=int, default=8)
|
||||||
parser.add_argument("--learning-rate", type=float, default=5e-5,
|
parser.add_argument("--learning-rate", type=float, default=5e-5)
|
||||||
help="Low LR keeps PPO close to the BC mean.")
|
|
||||||
parser.add_argument("--kl-coef", type=float, default=0.05,
|
parser.add_argument("--kl-coef", type=float, default=0.05,
|
||||||
help="KL-to-reference penalty coefficient.")
|
help="Coefficient of the KL-to-reference penalty.")
|
||||||
parser.add_argument("--log-std", type=float, default=-1.5,
|
parser.add_argument("--log-std", type=float, default=-1.5,
|
||||||
help="Initial (and frozen) log_std. σ ≈ exp(-1.5) ≈ 0.22.")
|
help="Initial (and frozen) log_std for exploration.")
|
||||||
parser.add_argument("--freeze-log-std", action="store_true", default=True,
|
parser.add_argument("--freeze-log-std", action="store_true", default=True)
|
||||||
help="Keep log_std fixed; only the policy mean updates.")
|
parser.add_argument("--n-steps", type=int, default=2048)
|
||||||
parser.add_argument("--n-steps", type=int, default=2048,
|
|
||||||
help="Steps per rollout per env.")
|
|
||||||
parser.add_argument("--batch-size", type=int, default=256)
|
parser.add_argument("--batch-size", type=int, default=256)
|
||||||
parser.add_argument("--n-epochs", type=int, default=10)
|
parser.add_argument("--n-epochs", type=int, default=10)
|
||||||
parser.add_argument("--gamma", type=float, default=0.995)
|
parser.add_argument("--gamma", type=float, default=0.995)
|
||||||
parser.add_argument("--gae-lambda", type=float, default=0.95)
|
parser.add_argument("--gae-lambda", type=float, default=0.95)
|
||||||
parser.add_argument("--clip-range", type=float, default=0.1,
|
parser.add_argument("--clip-range", type=float, default=0.1)
|
||||||
help="Tight clip range — keep updates conservative.")
|
|
||||||
parser.add_argument("--ent-coef", type=float, default=0.0)
|
parser.add_argument("--ent-coef", type=float, default=0.0)
|
||||||
parser.add_argument("--target-kl", type=float, default=0.02,
|
parser.add_argument("--target-kl", type=float, default=0.02,
|
||||||
help="SB3's per-batch KL early stop; safety belt.")
|
help="SB3 per-batch KL early-stop guard.")
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
parser.add_argument("--device", default="cpu")
|
parser.add_argument("--device", default="cpu")
|
||||||
parser.add_argument("--imitate-weight", type=float, default=None,
|
parser.add_argument("--imitate-weight", type=float, default=None,
|
||||||
help="Override env.W_IMITATE for this training "
|
help="Override env.W_IMITATE (e.g. 0.0 to drop "
|
||||||
"run. Set to 0.0 to drop the Strömbom "
|
"Strömbom imitation during fine-tune).")
|
||||||
"cosine-imitation reward — useful during "
|
|
||||||
"PPO refinement where you want reward, "
|
|
||||||
"not teacher imitation, to drive updates.")
|
|
||||||
parser.add_argument("--time-weight", type=float, default=None,
|
parser.add_argument("--time-weight", type=float, default=None,
|
||||||
help="Override env.W_TIME. Default env value is "
|
help="Override env.W_TIME (e.g. -0.1 for a "
|
||||||
"0.0; setting e.g. -0.1 adds a small per-"
|
"per-step time penalty).")
|
||||||
"step penalty that explicitly rewards "
|
|
||||||
"fast time-to-pen.")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
bc_zip = Path(args.bc) / "policy.zip"
|
bc_zip = Path(args.bc) / "policy.zip"
|
||||||
if not bc_zip.exists():
|
if not bc_zip.exists():
|
||||||
raise SystemExit(
|
raise SystemExit(
|
||||||
f"BC checkpoint not found at {bc_zip}. Train bc first with "
|
f"BC checkpoint not found at {bc_zip}. Train bc first with "
|
||||||
f"`python -m training.bc_pretrain`."
|
f"`python -m training.bc.pretrain`."
|
||||||
)
|
)
|
||||||
|
|
||||||
out = Path(args.out)
|
out = Path(args.out)
|
||||||
@@ -257,7 +218,7 @@ def main() -> None:
|
|||||||
(out / "checkpoints").mkdir(exist_ok=True)
|
(out / "checkpoints").mkdir(exist_ok=True)
|
||||||
(out / "best").mkdir(exist_ok=True)
|
(out / "best").mkdir(exist_ok=True)
|
||||||
|
|
||||||
# --- Inspect BC obs dim → infer frame_stack ---
|
# Infer frame_stack from the BC checkpoint's obs space.
|
||||||
ref_only = PPO.load(str(bc_zip), device=args.device)
|
ref_only = PPO.load(str(bc_zip), device=args.device)
|
||||||
obs_dim = int(ref_only.observation_space.shape[0])
|
obs_dim = int(ref_only.observation_space.shape[0])
|
||||||
if obs_dim % OBS_DIM != 0:
|
if obs_dim % OBS_DIM != 0:
|
||||||
@@ -265,12 +226,11 @@ def main() -> None:
|
|||||||
frame_stack = obs_dim // OBS_DIM
|
frame_stack = obs_dim // OBS_DIM
|
||||||
print(f"[rl] BC obs dim {obs_dim} → frame_stack={frame_stack}")
|
print(f"[rl] BC obs dim {obs_dim} → frame_stack={frame_stack}")
|
||||||
|
|
||||||
# --- Vectorised envs (match BC obs space) ---
|
|
||||||
env_fns = [_make_env(i, args.seed, frame_stack) for i in range(args.n_envs)]
|
env_fns = [_make_env(i, args.seed, frame_stack) for i in range(args.n_envs)]
|
||||||
venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns)
|
venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns)
|
||||||
eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, frame_stack)])
|
eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, frame_stack)])
|
||||||
|
|
||||||
# --- Apply reward-shaping overrides to every env instance ---
|
# Reward-shaping overrides (broadcast to every env instance).
|
||||||
def _broadcast(method: str, value):
|
def _broadcast(method: str, value):
|
||||||
for v in (venv, eval_venv):
|
for v in (venv, eval_venv):
|
||||||
try:
|
try:
|
||||||
@@ -284,10 +244,8 @@ def main() -> None:
|
|||||||
_broadcast("set_time_weight", args.time_weight)
|
_broadcast("set_time_weight", args.time_weight)
|
||||||
print(f"[rl] W_TIME overridden to {args.time_weight}")
|
print(f"[rl] W_TIME overridden to {args.time_weight}")
|
||||||
|
|
||||||
# --- Trainable policy: load BC weights, then bolt onto PPO ---
|
# Build a fresh KLPPO at the right obs/action shape, then copy BC
|
||||||
# Trick: instantiate a PPO with the right env (so the policy
|
# weights into both the trainable policy and the frozen reference.
|
||||||
# network is constructed at the correct obs/action shape), then
|
|
||||||
# copy BC weights into it.
|
|
||||||
model = KLPPO(
|
model = KLPPO(
|
||||||
"MlpPolicy", venv,
|
"MlpPolicy", venv,
|
||||||
ref_policy=None, # filled in below
|
ref_policy=None, # filled in below
|
||||||
@@ -311,15 +269,11 @@ def main() -> None:
|
|||||||
tensorboard_log=str(out / "tb"),
|
tensorboard_log=str(out / "tb"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- Load BC weights into both `model.policy` and `ref_policy` ---
|
# strict=False — the BC value head wasn't trained; PPO trains it.
|
||||||
bc_state = ref_only.policy.state_dict()
|
bc_state = ref_only.policy.state_dict()
|
||||||
# Strict=False because the value head may not have been trained in
|
|
||||||
# BC — that's fine, PPO will train it from scratch.
|
|
||||||
missing, unexpected = model.policy.load_state_dict(bc_state, strict=False)
|
missing, unexpected = model.policy.load_state_dict(bc_state, strict=False)
|
||||||
print(f"[rl] BC → policy: missing={len(missing)} unexpected={len(unexpected)}")
|
print(f"[rl] BC → policy: missing={len(missing)} unexpected={len(unexpected)}")
|
||||||
|
|
||||||
# Build a separate reference policy with identical architecture and
|
|
||||||
# the BC weights, frozen.
|
|
||||||
ref_policy = type(model.policy)(
|
ref_policy = type(model.policy)(
|
||||||
observation_space=model.observation_space,
|
observation_space=model.observation_space,
|
||||||
action_space=model.action_space,
|
action_space=model.action_space,
|
||||||
@@ -333,11 +287,8 @@ def main() -> None:
|
|||||||
for p in model.ref_policy.parameters():
|
for p in model.ref_policy.parameters():
|
||||||
p.requires_grad = False
|
p.requires_grad = False
|
||||||
|
|
||||||
# Align both policies' log_std. BC was trained with log_std≈0.5
|
# Force both policies to the same log_std so the KL term measures
|
||||||
# (σ≈1.65), which would make the KL term huge from a std mismatch
|
# mean drift only, not a std mismatch carried over from BC.
|
||||||
# rather than the mean drift we actually care about. Force both to
|
|
||||||
# the same small value so KL measures only how far the policy mean
|
|
||||||
# has drifted from the BC mean.
|
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
model.policy.log_std.fill_(args.log_std)
|
model.policy.log_std.fill_(args.log_std)
|
||||||
model.ref_policy.log_std.fill_(args.log_std)
|
model.ref_policy.log_std.fill_(args.log_std)
|
||||||
@@ -345,15 +296,18 @@ def main() -> None:
|
|||||||
model.policy.log_std.requires_grad = False
|
model.policy.log_std.requires_grad = False
|
||||||
print(f"[rl] log_std frozen at {args.log_std} (σ ≈ {np.exp(args.log_std):.3f})")
|
print(f"[rl] log_std frozen at {args.log_std} (σ ≈ {np.exp(args.log_std):.3f})")
|
||||||
|
|
||||||
# --- Callbacks ---
|
|
||||||
ckpt_cb = CheckpointCallback(
|
ckpt_cb = CheckpointCallback(
|
||||||
save_freq=max(1, 50_000 // args.n_envs),
|
save_freq=max(1, 50_000 // args.n_envs),
|
||||||
save_path=str(out / "checkpoints"),
|
save_path=str(out / "checkpoints"),
|
||||||
name_prefix="ppo",
|
name_prefix="ppo",
|
||||||
)
|
)
|
||||||
|
# EvalCallback writes <save_path>/best_model.zip on every new best
|
||||||
|
# eval reward. We send it straight to ``out/`` and rename to
|
||||||
|
# ``policy.zip`` after training so the deployed file lives at the
|
||||||
|
# canonical path.
|
||||||
eval_cb = EvalCallback(
|
eval_cb = EvalCallback(
|
||||||
eval_venv,
|
eval_venv,
|
||||||
best_model_save_path=str(out / "best"),
|
best_model_save_path=str(out),
|
||||||
log_path=str(out / "evals"),
|
log_path=str(out / "evals"),
|
||||||
eval_freq=max(1, 20_000 // args.n_envs),
|
eval_freq=max(1, 20_000 // args.n_envs),
|
||||||
n_eval_episodes=5,
|
n_eval_episodes=5,
|
||||||
@@ -365,9 +319,23 @@ def main() -> None:
|
|||||||
model.learn(total_timesteps=args.total_timesteps,
|
model.learn(total_timesteps=args.total_timesteps,
|
||||||
callback=[ckpt_cb, eval_cb], progress_bar=True)
|
callback=[ckpt_cb, eval_cb], progress_bar=True)
|
||||||
|
|
||||||
# --- Save final checkpoint in the SB3 zip the controller expects ---
|
# Save the end-of-training state for debugging convergence behaviour.
|
||||||
model.save(out / "policy.zip")
|
model.save(out / "final.zip")
|
||||||
print(f"[rl] saved fine-tuned policy → {out/'policy.zip'}")
|
|
||||||
|
# Promote the EvalCallback's best-by-eval-reward snapshot to the
|
||||||
|
# canonical ``policy.zip`` (what the controller loads). Fall back
|
||||||
|
# to the final state if eval never recorded a "best".
|
||||||
|
import shutil
|
||||||
|
best_zip = out / "best_model.zip"
|
||||||
|
policy_zip = out / "policy.zip"
|
||||||
|
if best_zip.exists():
|
||||||
|
if policy_zip.exists():
|
||||||
|
policy_zip.unlink()
|
||||||
|
best_zip.rename(policy_zip)
|
||||||
|
print(f"[rl] best snapshot → {policy_zip} (final state kept at {out/'final.zip'})")
|
||||||
|
else:
|
||||||
|
shutil.copy(out / "final.zip", policy_zip)
|
||||||
|
print(f"[rl] no best snapshot recorded; using final → {policy_zip}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
Reference in New Issue
Block a user