329 lines
12 KiB
Python
329 lines
12 KiB
Python
"""Shepherd Dog controller (Webots).
|
|
|
|
Mode is selected by ``HERDING_MODE`` (env var, or via the
|
|
``herding_runtime.cfg`` file the launcher writes since Webots strips
|
|
env vars on some setups):
|
|
|
|
strombom → canonical Strömbom (2014) collect/drive heuristic
|
|
wrapped in ActiveScanTeacher (opening rotation +
|
|
walk-to-centre when the tracker briefly empties).
|
|
sequential → single-target "pin-and-push", same wrapper.
|
|
bc → behaviour-cloned MLP, trained on Strömbom demos.
|
|
Default policy: training/runs/bc/policy.zip.
|
|
rl → KL-regularised PPO fine-tune of bc. Same obs/action
|
|
space as bc; refines time-to-pen via reward while
|
|
staying anchored to bc.
|
|
Default policy: training/runs/rl/policy.zip.
|
|
|
|
Sheep perception
|
|
----------------
|
|
The dog perceives sheep through its **front-mounted 140° LiDAR**
|
|
(``protos/ShepherdDog.proto``: 180 rays, 12 m max range). Each step:
|
|
|
|
1. Reads ``lidar.getRangeImage()``.
|
|
2. Runs ``herding.perception.lidar_perception.detections_from_scan``
|
|
to cluster returns into world-frame ``(x, y)`` sheep estimates.
|
|
3. Folds those into a ``SheepTracker`` which maintains last-seen
|
|
positions for sheep currently out of FOV and latches "penned"
|
|
once a track crosses the gate plane south.
|
|
|
|
Sheep ``emitter`` messages are read **for diagnostic logging only**
|
|
(GT_penned counter + auto-finish sentinel); they are never used to
|
|
drive the policy. Perception for control comes entirely from LiDAR.
|
|
|
|
Auto-finish
|
|
-----------
|
|
When the dog observes (via GT, read off the receiver) that all sheep
|
|
are penned, it writes ``training/.run_done`` and the launcher
|
|
(``tools/run_webots.sh``) detects it and closes Webots. This keeps
|
|
batch evaluation runs bounded.
|
|
"""
|
|
|
|
import math
|
|
import os
|
|
import sys
|
|
|
|
# --- Expose the repository root (two directories above this controller)
# on sys.path so the shared ``herding`` package can be imported. ---
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, os.pardir, os.pardir))
if _PROJECT_ROOT not in sys.path:
    # Prepend (index 0) so the project copy is found before anything else.
    sys.path.insert(0, _PROJECT_ROOT)
|
|
|
|
import numpy as np
|
|
|
|
from controller import Robot
|
|
|
|
from herding.control.active_scan import ActiveScanTeacher
|
|
from herding.control.modulation import modulate_speed_near_sheep
|
|
from herding.control.sequential import compute_action as sequential_action
|
|
from herding.control.strombom import compute_action as strombom_action
|
|
from herding.perception.obs import build_obs
|
|
from herding.perception.lidar_perception import detections_from_scan
|
|
from herding.perception.sheep_tracker import SheepTracker
|
|
from herding.world.diffdrive import velocity_to_wheels
|
|
from herding.world.geometry import (
|
|
DOG_MAX_LINEAR, DOG_MAX_WHEEL_OMEGA,
|
|
DOG_SOUTH_LIMIT, DOG_WHEEL_RADIUS,
|
|
PEN_ENTRY, is_penned_position,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Mode + policy resolution
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _load_runtime_config(cfg_path=None) -> dict:
    """Read mode + policy_dir overrides from a runtime config file.

    Webots strips HERDING_* env vars in some configurations, so the
    launcher writes a tiny ``herding_runtime.cfg`` (key=value lines)
    in the project root and the controller reads it here. Env vars
    win if both are present; the file is the fallback.

    Args:
        cfg_path: File to read. ``None`` (the default) means
            ``herding_runtime.cfg`` in the project root.

    Returns:
        Dict mapping upper-cased keys to stripped values. Blank lines,
        ``#`` comments, lines without ``=`` and lines with an empty key
        are skipped. Only the first ``=`` splits, so values may contain
        ``=``. A missing file yields ``{}`` silently; an unreadable or
        undecodable file yields ``{}`` after printing a warning.
    """
    if cfg_path is None:
        cfg_path = os.path.join(_PROJECT_ROOT, "herding_runtime.cfg")

    out = {}
    try:
        # "utf-8-sig" reads plain UTF-8 and also drops a leading BOM.
        # Without it, a BOM-prefixed file would corrupt the first key.
        with open(cfg_path, encoding="utf-8-sig") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line or line.startswith("#") or "=" not in line:
                    continue
                key, _, value = line.partition("=")
                key = key.strip().upper()
                if not key:  # "=value" carries no usable setting
                    continue
                out[key] = value.strip()
    except FileNotFoundError:
        # No runtime config is the normal case when env vars are used.
        return {}
    except (OSError, UnicodeDecodeError) as exc:
        # Best-effort: fall back to env vars / defaults, but say why.
        print(f"[dog] could not read {cfg_path} ({exc!r}); ignoring runtime config.")
        return {}
    return out
|
|
|
|
|
|
_runtime_cfg = _load_runtime_config()

# Selection order: HERDING_MODE env var, then the runtime-config file,
# then "bc".
#
# The env value is stripped first, so a stray trailing space or "\r"
# (e.g. from a CRLF-edited launcher script) does not turn a valid mode
# into the "unknown mode" fallback. A whitespace-only value counts as
# unset. Config-file values are already stripped by _load_runtime_config.
MODE = ((os.environ.get("HERDING_MODE") or "").strip()
        or _runtime_cfg.get("HERDING_MODE")
        or "bc").lower()
|
|
|
|
|
|
def _resolve_policy_dir(mode: str) -> str:
    """Where to look for the trained policy for the given mode.

    Priority:
      1. HERDING_POLICY_DIR env var or runtime-cfg entry, if it points
         to a real directory. ``~`` is expanded. A relative path is tried
         as given (relative to the cwd) and then relative to the project
         root, because Webots starts controllers in their own directory,
         not in the project root.
      2. Mode-specific default:
           bc → training/runs/bc (Strömbom-imitated MLP)
           rl → training/runs/rl (KL-PPO fine-tune of bc)
         Any other mode uses the bc default.
      3. Fall back to bc.

    If none of these exists, the raw override (when given) or the
    mode default is returned unchanged. The caller's load attempt then
    fails naming that path.

    All checkpoints are frame-stacked K = 4; ``policy_loader`` reads
    the stacking factor from the policy's observation space.
    """
    env_dir = (os.environ.get("HERDING_POLICY_DIR")
               or _runtime_cfg.get("HERDING_POLICY_DIR"))
    if env_dir:
        expanded = os.path.expanduser(env_dir)
        candidates = [expanded]
        if not os.path.isabs(expanded):
            candidates.append(os.path.join(_PROJECT_ROOT, expanded))
        for candidate in candidates:
            if os.path.isdir(candidate):
                return candidate
        # Previously this fell through silently; make the ignored override visible.
        print(f"[dog] HERDING_POLICY_DIR={env_dir!r} is not a directory "
              f"(also tried relative to {_PROJECT_ROOT}); using mode default.")

    mode_default = {
        "bc": os.path.join(_PROJECT_ROOT, "training", "runs", "bc"),
        "rl": os.path.join(_PROJECT_ROOT, "training", "runs", "rl"),
    }
    primary = mode_default.get(mode, mode_default["bc"])
    if os.path.isdir(primary):
        return primary
    fallback = mode_default["bc"]
    if os.path.isdir(fallback):
        return fallback
    return env_dir or primary
|
|
|
|
|
|
_VALID_MODES = ("bc", "rl", "strombom", "sequential")
if MODE not in _VALID_MODES:
    # Unrecognised modes run the analytic Strömbom teacher.
    print(f"[dog] unknown HERDING_MODE={MODE!r}; defaulting to strombom.")
    MODE = "strombom"

POLICY_DIR = _resolve_policy_dir(MODE)
policy_handle = None

if MODE in ("bc", "rl"):
    # Learned modes only: load the checkpoint. Any failure (import or load)
    # downgrades to the analytic teacher instead of stopping the controller.
    print(f"[dog] resolved POLICY_DIR={POLICY_DIR} exists={os.path.isdir(POLICY_DIR)}")
    try:
        from policy_loader import load as _load_policy
        policy_handle = _load_policy(POLICY_DIR)
    except Exception as exc:
        print(f"[dog] policy load failed ({exc!r}); falling back to strombom.")
        MODE = "strombom"
    else:
        print(f"[dog] policy loaded from {POLICY_DIR}")

print(f"[dog] running in mode={MODE}")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Control parameters
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Weight given to the previous (vx, vy) in the per-step exponential moving
# average. Larger values give smoother but laggier motion; the smoothing
# damps frame-to-frame jitter in the chosen action.
ACTION_SMOOTH = 0.55

# Empty sentinel file the launcher watches for; created once all sheep
# are penned.
RUN_DONE_FILE = os.path.join(os.path.join(_PROJECT_ROOT, "training"), ".run_done")
|
|
|
|
|
|
def safety_clamp(vx: float, vy: float, dog_x: float, dog_y: float) -> tuple:
    """Veto southward motion near the south barrier so the dog never enters the pen.

    Two zones are measured on the world y axis against ``DOG_SOUTH_LIMIT``
    (smaller y is further south):

    * South of the limit, any southward command (``vy < 0``) is replaced
      by the pure northward command ``(0.0, 1.0)``.
    * Less than 0.5 north of the limit, a strong southward command
      (``vy < -0.2``) has ``vx`` halved and ``vy`` raised by 0.5, floored
      at zero.

    Every other command passes through unchanged. ``dog_x`` is accepted
    but not used.

    Returns:
        The possibly overridden ``(vx, vy)`` tuple.
    """
    moving_south = vy < 0.0
    if dog_y < DOG_SOUTH_LIMIT and moving_south:
        return (0.0, 1.0)

    in_margin = dog_y < DOG_SOUTH_LIMIT + 0.5
    if in_margin and vy < -0.2:
        return (0.5 * vx, max(0.0, vy + 0.5))

    return (vx, vy)
|
|
|
|
|
|
def drive(vx: float, vy: float, left_motor, right_motor, compass, motor_max: float):
    """Turn a world-frame velocity command into left/right wheel speeds.

    A near-zero command (speed below 1e-3) stops both wheels without
    reading the compass. Otherwise:

    * the heading is computed from the compass as ``atan2(n[0], n[1])``,
      the same expression the main loop uses;
    * ``velocity_to_wheels`` converts ``(vx, vy, heading)`` to wheel
      speeds, with ``motor_max`` passed as the wheel-speed limit;
    * the result is applied to the motors.
    """
    if math.hypot(vx, vy) < 1e-3:
        left_speed, right_speed = 0.0, 0.0
    else:
        north = compass.getValues()
        heading = math.atan2(north[0], north[1])
        left_speed, right_speed = velocity_to_wheels(
            vx, vy, heading,
            max_linear=DOG_MAX_LINEAR,
            wheel_radius=DOG_WHEEL_RADIUS,
            max_wheel_omega=motor_max,
            k_turn=4.0,
        )
    left_motor.setVelocity(left_speed)
    right_motor.setVelocity(right_speed)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Webots devices
|
|
# ---------------------------------------------------------------------------
|
|
|
|
robot = Robot()
timestep = int(robot.getBasicTimeStep())

# Drive wheels: an infinite target position puts the Webots motors in
# velocity-control mode, and both start at rest.
left_motor = robot.getDevice("left wheel motor")
right_motor = robot.getDevice("right wheel motor")
for _motor in (left_motor, right_motor):
    _motor.setPosition(float("inf"))
for _motor in (left_motor, right_motor):
    _motor.setVelocity(0.0)
# Only the left motor's limit is queried.
# NOTE(review): assumes both wheels share the same max velocity — confirm in the proto.
MOTOR_MAX = min(left_motor.getMaxVelocity(), DOG_MAX_WHEEL_OMEGA)

# Sensors and radio, enabled at the basic time step. The emitter is only
# looked up here, not enabled.
gps = robot.getDevice("gps")
gps.enable(timestep)
compass = robot.getDevice("compass")
compass.enable(timestep)
receiver = robot.getDevice("receiver")
receiver.enable(timestep)
emitter = robot.getDevice("emitter")
lidar = robot.getDevice("lidar")
lidar.enable(timestep)

tracker = SheepTracker()

# Cosmetic ear motors — animated in the main loop; not used by control.
left_ear = robot.getDevice("left ear motor")
right_ear = robot.getDevice("right ear motor")
for _motor in (left_ear, right_ear):
    _motor.setPosition(float("inf"))
for _motor in (left_ear, right_ear):
    _motor.setVelocity(0.0)
del _motor
ear_phase = 0.0
EAR_AMPLITUDE = 0.35
EAR_RATE = 8.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Main loop
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Analytic-teacher wrapper. It is only constructed for the analytic modes;
# the modules themselves are imported at the top of the file regardless
# of mode. Both analytic modes share the same ActiveScanTeacher treatment
# (rotate-on-empty, walk-to-centre). Near-sheep speed modulation is
# applied separately in the main loop, for every mode.
analytic_teacher = None
if MODE in ("strombom", "sequential"):
    if MODE == "strombom":
        base_fn = strombom_action
    else:
        base_fn = sequential_action
    analytic_teacher = ActiveScanTeacher(base_fn)

# Ground-truth sheep positions keyed by sheep id, parsed from the sheep
# emitters. Used ONLY for the auto-finish sentinel and the GT_penned
# diagnostic; never fed into control.
_gt_sheep: dict = {}
_run_done = False  # becomes True once the sentinel has been handled

prev_action = (0.0, 0.0)  # last commanded (vx, vy); EMA smoothing state
step_count = 0
|
|
|
|
# Main control loop: one iteration per simulation step until Webots ends
# the run (robot.step returns -1).
while robot.step(timestep) != -1:
    step_count += 1

    # Drain sheep emitter messages → GT (diagnostic only).
    # Expected format is "sheep:<id>:<x>:<y>". Anything else, or
    # unparsable numbers, is ignored.
    while receiver.getQueueLength() > 0:
        msg = receiver.getString()
        receiver.nextPacket()
        parts = msg.split(":")
        if len(parts) == 4 and parts[0] == "sheep":
            try:
                _gt_sheep[parts[1]] = (float(parts[2]), float(parts[3]))
            except ValueError:
                pass

    # Dog pose: planar position from GPS (x, y), heading from the compass.
    pos = gps.getValues()
    dog_xy = (pos[0], pos[1])
    n = compass.getValues()
    dog_heading = math.atan2(n[0], n[1])

    # ---- LiDAR perception → tracker → active sheep positions ----
    ranges = np.asarray(lidar.getRangeImage(), dtype=np.float32)
    detections = detections_from_scan(ranges, dog_xy[0], dog_xy[1], dog_heading)
    sheep_positions = tracker.update(detections)

    # The observation marks every tracked sheep as not penned.
    sheep_xy_list = list(sheep_positions.values())
    sheep_penned_list = [False] * len(sheep_xy_list)
    single_obs = build_obs(dog_xy, dog_heading, sheep_xy_list, sheep_penned_list)

    # ---- Action selection ----
    if MODE in ("bc", "rl") and policy_handle is not None:
        action = policy_handle.predict(single_obs)
        vx, vy = float(action[0]), float(action[1])
    else:
        vx, vy, _mode_str = analytic_teacher(
            dog_xy, dog_heading, sheep_positions, PEN_ENTRY,
        )

    # Near-sheep speed modulation (shared by every mode).
    vx, vy = modulate_speed_near_sheep(vx, vy, dog_xy, sheep_positions)

    # EMA smoothing — kills frame-to-frame action jitter.
    vx = ACTION_SMOOTH * prev_action[0] + (1.0 - ACTION_SMOOTH) * vx
    vy = ACTION_SMOOTH * prev_action[1] + (1.0 - ACTION_SMOOTH) * vy

    # Safety: dog must never enter the pen. This runs after smoothing so
    # the EMA cannot reintroduce a vetoed southward command. The clamped
    # value is what seeds the next step's EMA.
    vx, vy = safety_clamp(vx, vy, dog_xy[0], dog_xy[1])
    prev_action = (vx, vy)

    drive(vx, vy, left_motor, right_motor, compass, MOTOR_MAX)
    emitter.send(f"dog:{dog_xy[0]:.4f}:{dog_xy[1]:.4f}")

    # Cosmetic ear wiggle: sinusoidal target position, mirrored left/right.
    ear_phase += 0.12
    ear_pos = EAR_AMPLITUDE * math.sin(ear_phase)
    left_ear.setVelocity(EAR_RATE)
    right_ear.setVelocity(EAR_RATE)
    left_ear.setPosition(ear_pos)
    right_ear.setPosition(-ear_pos)

    # Auto-finish: when all GT sheep are penned, write the sentinel.
    # The launcher polls for it and closes Webots so batch evals don't
    # hang after the task is done. The non-empty `_gt_sheep` guard stops
    # an empty dict (before any sheep has reported) from counting as
    # "all penned".
    if _gt_sheep and not _run_done:
        gt_active = sum(1 for x, y in _gt_sheep.values()
                        if not is_penned_position(x, y))
        if gt_active == 0:
            # Mark as handled before trying, so a failing write is reported
            # once instead of on every subsequent step.
            _run_done = True
            try:
                os.makedirs(os.path.dirname(RUN_DONE_FILE), exist_ok=True)
                with open(RUN_DONE_FILE, "w", encoding="utf-8"):
                    pass
            except OSError as exc:
                # Best-effort: an unwritable sentinel must not kill the
                # controller; the launcher will simply not auto-close.
                print(f"[dog] all {len(_gt_sheep)} sheep penned at step "
                      f"{step_count} but could not write {RUN_DONE_FILE} "
                      f"({exc!r}); Webots will not auto-close")
            else:
                print(f"[dog] all {len(_gt_sheep)} sheep penned at step "
                      f"{step_count} — wrote sentinel, launcher will close Webots")

    # Periodic diagnostic line.
    if step_count % 200 == 0:
        gt_penned = sum(1 for x, y in _gt_sheep.values()
                        if is_penned_position(x, y))
        gt_total = len(_gt_sheep)
        print(f"[dog mode={MODE}] step={step_count} "
              f"GT_penned={gt_penned}/{gt_total} "
              f"tracks_active={tracker.n_active()} "
              f"tracks_penned={tracker.n_penned()} "
              f"detections={len(detections)} action=({vx:+.2f}, {vy:+.2f})")
|