Checkpoint 7

This commit is contained in:
Johnny Fernandes
2026-05-11 12:21:51 +01:00
parent fce0e0c786
commit a01a5c9cef
34 changed files with 1266 additions and 1038 deletions
+1 -1
View File
@@ -1,4 +1,4 @@
"""Backwards-compat shim — flocking logic now lives in ``herding.flocking_sim``.
"""Backwards-compat shim — flocking logic now lives in ``herding.world.flocking_sim``.
Kept so any external reference still resolves.
"""
+19 -36
View File
@@ -1,14 +1,13 @@
"""Sheep flocking controller (Webots).
Each sheep broadcasts its GPS position every 3 steps on channel 1 and
listens for the dog and peer sheep positions. The behavioural step is
delegated to ``herding.flocking_sim.compute_heading_speed`` so the
training environment and Webots run identical sheep dynamics.
Each sheep emits its GPS position every 3 steps and listens for the
dog's position and peer-sheep positions. The behavioural step is
delegated to :func:`herding.world.flocking_sim.compute_heading_speed`
so the env and Webots use identical sheep dynamics.
Pen behaviour: a sheep latches to ``penned`` the first time it crosses
the south-wall gate plane into the gate corridor. Once latched it turns
pink (via the exposed ``woolColor`` PROTO field) and the force stack
switches to in-pen containment.
A sheep latches to ``penned`` the first time it crosses the south-wall gate
plane into the gate corridor; its wool then turns pink (via the exposed
``woolColor`` PROTO field) and the dynamics switch to in-pen containment.
"""
import math
@@ -32,10 +31,7 @@ from herding.world.geometry import (
)
# ---------------------------------------------------------------------------
# Device setup
# ---------------------------------------------------------------------------
# --- Devices ---
robot = Supervisor()
timestep = int(robot.getBasicTimeStep())
name = robot.getName()
@@ -55,14 +51,10 @@ receiver = robot.getDevice("receiver"); receiver.enable(timestep)
emitter = robot.getDevice("emitter")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# --- Helpers ---
def bearing():
    """Return the world-frame heading in radians (0 = east, π/2 = north).

    The compass reports the north direction in the sensor frame. For this
    Z-up world with north = +Y, ``atan2(n[0], n[1])`` gives the standard
    math angle, matching the ``atan2(fy, fx)`` convention used for headings.
    """
    north = compass.getValues()
    north_x, north_y = north[0], north[1]
    # Argument order is deliberately (x, y): see the docstring.
    return math.atan2(north_x, north_y)
@@ -76,45 +68,36 @@ def drive(heading, speed_motor):
def paint_pink():
    """Switch the sheep's wool to pink via the exposed ``woolColor`` PROTO field.

    ``woolColor`` is declared as a PROTO field with an IS binding to the
    DEF WOOL PBRAppearance ``baseColor``, so a single write propagates to
    every USE WOOL shape.
    """
    wool_field = self_node.getField("woolColor")
    wool_field.setSFColor([1.0, 0.55, 0.72])
# ---------------------------------------------------------------------------
# State
# ---------------------------------------------------------------------------
# --- State ---
# Wander direction, initialised to a random angle in [-π, π].
wander_angle = random.uniform(-math.pi, math.pi)
# Simulation step counter; incremented at the top of every main-loop pass.
step_count = 0
# Dog position; starts unknown (None).
dog_x, dog_y = None, None
peers = {} # name → (x, y), one entry per neighbour; the table is cleared every 30 steps
peers = {} # name → (x, y); periodically pruned (whole table cleared in the main loop)
# One-way latch: set the first time is_penned_position(x, y) is true.
penned = False
# Stuck detection: differential-drive sheep can pin against a wall and need
# a forced escape. If displacement < STUCK_DIST for STUCK_STEPS consecutive
# steps, the main loop overrides the heading to point at the field centre
# (the world origin) at MAX_SPEED.
# Safety net for differential-drive sheep pinned against a wall.
# Previous-step position; None until the first main-loop pass has run.
_prev_x, _prev_y = None, None
# Consecutive steps with displacement below STUCK_DIST.
_stuck_count = 0
# Number of consecutive low-displacement steps that triggers recovery.
STUCK_STEPS = 20
# Per-step displacement threshold (same units as the GPS readings — presumably metres; confirm).
STUCK_DIST = 0.05
# ---------------------------------------------------------------------------
# Main loop
# ---------------------------------------------------------------------------
# --- Main loop ---
while robot.step(timestep) != -1:
step_count += 1
pos = gps.getValues()
x, y = pos[0], pos[1]
# Pen entry: one-way latch. Penned sheep get pink wool and switch behaviour.
if not penned and is_penned_position(x, y):
penned = True
paint_pink()
# Refresh peer table — clear before receiving so fresh data is never lost.
# The whole peer table is cleared every 30 steps so a peer that has gone
# silent doesn't permanently distort the local centre of mass (CoM).
if step_count % 30 == 0:
peers.clear()
while receiver.getQueueLength() > 0:
@@ -132,12 +115,12 @@ while robot.step(timestep) != -1:
wander_angle=wander_angle,
)
# Stuck detection — safety net for differential-drive wall pinning.
# Stuck-against-wall recovery: drive toward the field centre.
if _prev_x is not None:
moved = math.hypot(x - _prev_x, y - _prev_y)
_stuck_count = _stuck_count + 1 if moved < STUCK_DIST else 0
if _stuck_count >= STUCK_STEPS:
heading = math.atan2(-y, -x) # always points away from the boundary
heading = math.atan2(-y, -x)
speed = MAX_SPEED
_stuck_count = 0
_prev_x, _prev_y = x, y
+19 -33
View File
@@ -1,18 +1,13 @@
"""Lazy loader for the SB3 PPO policy used by the dog controller.
"""Lazy SB3 policy loader for the dog controller.
Importing stable-baselines3 inside the Webots Python interpreter is only
needed when ``HERDING_MODE=rl``; the Strömbom mode runs without it. This
loader keeps SB3 out of the import path until you actually ask for the RL
policy, so users without SB3 installed can still run the Strömbom
baseline.
SB3 is imported only when a learned policy is actually requested,
so the analytic modes can run on installs without stable-baselines3
or torch.
The policy + VecNormalize statistics are saved together by
``training/train_ppo.py``:
runs/<name>/best/best_model.zip # SB3 PPO checkpoint
runs/<name>/best/vecnormalize.pkl # observation-normaliser stats
Pass either the directory or the explicit zip path.
The handle auto-detects frame stacking from the policy's expected
observation dimension: if that dimension is an integer multiple K of the
single-frame ``OBS_DIM``, an internal buffer of the last K frames is
maintained and concatenated on each ``predict`` call.
"""
import os
@@ -20,20 +15,12 @@ from pathlib import Path
class PolicyHandle:
"""Wrap a loaded PPO policy + VecNormalize so the controller can call
``predict(obs)`` without thinking about either.
Frame stacking is auto-detected from the policy's expected obs dim:
if it's a multiple of the single-frame ``OBS_DIM``, the handle keeps
a deque of the last K frames and concatenates them on each predict.
"""
"""Wrap a loaded policy (+ optional VecNormalize) for ``predict(obs)``."""
def __init__(self, model, vecnorm):
self.model = model
self.vecnorm = vecnorm
# Lazy import to avoid forcing herding/* into the import path
# when SB3 isn't being used.
from herding.obs import OBS_DIM
from herding.perception.obs import OBS_DIM
policy_dim = int(model.observation_space.shape[0])
if policy_dim % OBS_DIM == 0 and policy_dim // OBS_DIM >= 1:
self.frame_stack = policy_dim // OBS_DIM
@@ -46,7 +33,7 @@ class PolicyHandle:
import numpy as np
single = np.asarray(obs, dtype=np.float32).reshape(-1)
if single.shape[0] != self._single_dim:
# Caller already passed a stacked obs — use as-is.
# Caller passed an already-stacked obs.
stacked = single
elif self.frame_stack > 1:
if not self._buffer:
@@ -67,18 +54,19 @@ class PolicyHandle:
def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
"""Load a PPO model (and optional VecNormalize) from disk.
"""Load a policy zip (+ optional VecNormalize pickle) from disk.
``model_path`` may be the .zip checkpoint or a directory containing
``best_model.zip`` (and optionally ``vecnormalize.pkl``).
``model_path`` may be a ``.zip`` file or a directory; in the
latter case ``policy.zip`` is preferred, with ``final.zip`` as
a fallback for partially completed RL runs.
"""
p = Path(model_path)
if p.is_dir():
zip_candidates = [p / "best_model.zip", p / "final.zip", p / "policy.zip"]
zip_candidates = [p / "policy.zip", p / "final.zip"]
zip_path = next((z for z in zip_candidates if z.exists()), None)
if zip_path is None:
raise FileNotFoundError(
f"No PPO zip found in {p} (looked for best_model.zip, final.zip, policy.zip)"
f"No policy zip in {p} (looked for policy.zip, final.zip)"
)
if vecnorm_path is None:
vn = p / "vecnormalize.pkl"
@@ -87,15 +75,13 @@ def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
else:
zip_path = p
# Imports deferred so the Strömbom path doesn't require SB3.
# Deferred imports so the analytic path doesn't require SB3.
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.vec_env import VecNormalize # noqa: F401
model = PPO.load(str(zip_path), device="auto")
vecnorm = None
if vecnorm_path and os.path.exists(vecnorm_path):
# VecNormalize.load needs a venv to attach to; we only need its stats
# at inference, so we reconstruct the wrapper manually.
import pickle
with open(vecnorm_path, "rb") as f:
vecnorm = pickle.load(f)
+1 -1
View File
@@ -57,7 +57,7 @@ from herding.control.active_scan import ActiveScanTeacher
from herding.control.modulation import modulate_speed_near_sheep
from herding.control.sequential import compute_action as sequential_action
from herding.control.strombom import compute_action as strombom_action
from herding.obs import build_obs
from herding.perception.obs import build_obs
from herding.perception.lidar_perception import detections_from_scan
from herding.perception.sheep_tracker import SheepTracker
from herding.world.diffdrive import velocity_to_wheels