Checkpoint 7
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""Backwards-compat shim — flocking logic now lives in ``herding.flocking_sim``.
|
||||
"""Backwards-compat shim — flocking logic now lives in ``herding.world.flocking_sim``.
|
||||
|
||||
Kept so any external reference still resolves.
|
||||
"""
|
||||
|
||||
+19
-36
@@ -1,14 +1,13 @@
|
||||
"""Sheep flocking controller (Webots).
|
||||
|
||||
Each sheep broadcasts its GPS position every 3 steps on channel 1 and
|
||||
listens for the dog and peer sheep positions. The behavioural step is
|
||||
delegated to ``herding.flocking_sim.compute_heading_speed`` so the
|
||||
training environment and Webots run identical sheep dynamics.
|
||||
Each sheep emits its GPS position every 3 steps and listens for the
|
||||
dog's position and peer-sheep positions. The behavioural step is
|
||||
delegated to :func:`herding.world.flocking_sim.compute_heading_speed`
|
||||
so the env and Webots use identical sheep dynamics.
|
||||
|
||||
Pen behaviour: a sheep latches to ``penned`` the first time it crosses
|
||||
the south-wall gate plane into the gate corridor. Once latched it turns
|
||||
pink (via the exposed ``woolColor`` PROTO field) and the force stack
|
||||
switches to in-pen containment.
|
||||
A sheep latches penned the first time it crosses the gate plane south;
|
||||
the wool turns pink (via the exposed ``woolColor`` PROTO field) and
|
||||
the dynamics switch to in-pen containment.
|
||||
"""
|
||||
|
||||
import math
|
||||
@@ -32,10 +31,7 @@ from herding.world.geometry import (
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Device setup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# --- Devices ---
|
||||
robot = Supervisor()
|
||||
timestep = int(robot.getBasicTimeStep())
|
||||
name = robot.getName()
|
||||
@@ -55,14 +51,10 @@ receiver = robot.getDevice("receiver"); receiver.enable(timestep)
|
||||
emitter = robot.getDevice("emitter")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
# --- Helpers ---
|
||||
|
||||
def bearing():
|
||||
# Compass returns north direction in sensor frame; for this Z-up world
|
||||
# with north = +Y, atan2(n[0], n[1]) gives the standard math angle
|
||||
# (0 = east, π/2 = north) matching atan2(fy, fx) used for headings.
|
||||
"""World-frame heading (0 = east, π/2 = north)."""
|
||||
n = compass.getValues()
|
||||
return math.atan2(n[0], n[1])
|
||||
|
||||
@@ -76,45 +68,36 @@ def drive(heading, speed_motor):
|
||||
|
||||
|
||||
def paint_pink():
|
||||
# woolColor is declared as a PROTO field with IS binding to the DEF WOOL
|
||||
# PBRAppearance baseColor; setting it propagates to every USE WOOL shape.
|
||||
"""Switch the sheep's wool to pink via the exposed PROTO field."""
|
||||
self_node.getField("woolColor").setSFColor([1.0, 0.55, 0.72])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# State
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# --- State ---
|
||||
wander_angle = random.uniform(-math.pi, math.pi)
|
||||
step_count = 0
|
||||
dog_x, dog_y = None, None
|
||||
peers = {} # name → (x, y), one entry per neighbour, cleared every 30 steps
|
||||
peers = {} # name → (x, y); periodically pruned
|
||||
penned = False
|
||||
|
||||
# Stuck detection: differential-drive sheep can pin against a wall and need
|
||||
# a forced reverse-and-rotate to escape. If displacement < STUCK_DIST for
|
||||
# STUCK_STEPS consecutive steps, drive toward field centre.
|
||||
# Safety net for differential-drive sheep pinned against a wall.
|
||||
_prev_x, _prev_y = None, None
|
||||
_stuck_count = 0
|
||||
STUCK_STEPS = 20
|
||||
STUCK_DIST = 0.05
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# --- Main loop ---
|
||||
while robot.step(timestep) != -1:
|
||||
step_count += 1
|
||||
pos = gps.getValues()
|
||||
x, y = pos[0], pos[1]
|
||||
|
||||
# Pen entry: one-way latch. Penned sheep get pink wool and switch behaviour.
|
||||
if not penned and is_penned_position(x, y):
|
||||
penned = True
|
||||
paint_pink()
|
||||
|
||||
# Refresh peer table — clear before receiving so fresh data is never lost.
|
||||
# Stale peers get dropped periodically so a peer that's gone silent
|
||||
# doesn't permanently distort the local CoM.
|
||||
if step_count % 30 == 0:
|
||||
peers.clear()
|
||||
while receiver.getQueueLength() > 0:
|
||||
@@ -132,12 +115,12 @@ while robot.step(timestep) != -1:
|
||||
wander_angle=wander_angle,
|
||||
)
|
||||
|
||||
# Stuck detection — safety net for differential-drive wall pinning.
|
||||
# Stuck-against-wall recovery: drive toward the field centre.
|
||||
if _prev_x is not None:
|
||||
moved = math.hypot(x - _prev_x, y - _prev_y)
|
||||
_stuck_count = _stuck_count + 1 if moved < STUCK_DIST else 0
|
||||
if _stuck_count >= STUCK_STEPS:
|
||||
heading = math.atan2(-y, -x) # always points away from the boundary
|
||||
heading = math.atan2(-y, -x)
|
||||
speed = MAX_SPEED
|
||||
_stuck_count = 0
|
||||
_prev_x, _prev_y = x, y
|
||||
|
||||
@@ -1,18 +1,13 @@
|
||||
"""Lazy loader for the SB3 PPO policy used by the dog controller.
|
||||
"""Lazy SB3 policy loader for the dog controller.
|
||||
|
||||
Importing stable-baselines3 inside the Webots Python interpreter is only
|
||||
needed when ``HERDING_MODE=rl``; the Strömbom mode runs without it. This
|
||||
loader keeps SB3 out of the import path until you actually ask for the RL
|
||||
policy, so users without SB3 installed can still run the Strömbom
|
||||
baseline.
|
||||
SB3 is imported only when a learned policy is actually requested,
|
||||
so the analytic modes can run on installs without stable-baselines3
|
||||
or torch.
|
||||
|
||||
The policy + VecNormalize statistics are saved together by
|
||||
``training/train_ppo.py``:
|
||||
|
||||
runs/<name>/best/best_model.zip # SB3 PPO checkpoint
|
||||
runs/<name>/best/vecnormalize.pkl # observation-normaliser stats
|
||||
|
||||
Pass either the directory or the explicit zip path.
|
||||
The handle auto-detects frame stacking from the policy's expected
|
||||
observation dimension: if it's a multiple of the single-frame
|
||||
``OBS_DIM``, an internal buffer of the last K frames is maintained
|
||||
and concatenated on each ``predict`` call.
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -20,20 +15,12 @@ from pathlib import Path
|
||||
|
||||
|
||||
class PolicyHandle:
|
||||
"""Wrap a loaded PPO policy + VecNormalize so the controller can call
|
||||
``predict(obs)`` without thinking about either.
|
||||
|
||||
Frame stacking is auto-detected from the policy's expected obs dim:
|
||||
if it's a multiple of the single-frame ``OBS_DIM``, the handle keeps
|
||||
a deque of the last K frames and concatenates them on each predict.
|
||||
"""
|
||||
"""Wrap a loaded policy (+ optional VecNormalize) for ``predict(obs)``."""
|
||||
|
||||
def __init__(self, model, vecnorm):
|
||||
self.model = model
|
||||
self.vecnorm = vecnorm
|
||||
# Lazy import to avoid forcing herding/* into the import path
|
||||
# when SB3 isn't being used.
|
||||
from herding.obs import OBS_DIM
|
||||
from herding.perception.obs import OBS_DIM
|
||||
policy_dim = int(model.observation_space.shape[0])
|
||||
if policy_dim % OBS_DIM == 0 and policy_dim // OBS_DIM >= 1:
|
||||
self.frame_stack = policy_dim // OBS_DIM
|
||||
@@ -46,7 +33,7 @@ class PolicyHandle:
|
||||
import numpy as np
|
||||
single = np.asarray(obs, dtype=np.float32).reshape(-1)
|
||||
if single.shape[0] != self._single_dim:
|
||||
# Caller already passed a stacked obs — use as-is.
|
||||
# Caller passed an already-stacked obs.
|
||||
stacked = single
|
||||
elif self.frame_stack > 1:
|
||||
if not self._buffer:
|
||||
@@ -67,18 +54,19 @@ class PolicyHandle:
|
||||
|
||||
|
||||
def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
|
||||
"""Load a PPO model (and optional VecNormalize) from disk.
|
||||
"""Load a policy zip (+ optional VecNormalize pickle) from disk.
|
||||
|
||||
``model_path`` may be the .zip checkpoint or a directory containing
|
||||
``best_model.zip`` (and optionally ``vecnormalize.pkl``).
|
||||
``model_path`` may be a ``.zip`` file or a directory; in the
|
||||
latter case ``policy.zip`` is preferred, with ``final.zip`` as
|
||||
a fallback for partially-completed RL runs.
|
||||
"""
|
||||
p = Path(model_path)
|
||||
if p.is_dir():
|
||||
zip_candidates = [p / "best_model.zip", p / "final.zip", p / "policy.zip"]
|
||||
zip_candidates = [p / "policy.zip", p / "final.zip"]
|
||||
zip_path = next((z for z in zip_candidates if z.exists()), None)
|
||||
if zip_path is None:
|
||||
raise FileNotFoundError(
|
||||
f"No PPO zip found in {p} (looked for best_model.zip, final.zip, policy.zip)"
|
||||
f"No policy zip in {p} (looked for policy.zip, final.zip)"
|
||||
)
|
||||
if vecnorm_path is None:
|
||||
vn = p / "vecnormalize.pkl"
|
||||
@@ -87,15 +75,13 @@ def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
|
||||
else:
|
||||
zip_path = p
|
||||
|
||||
# Imports deferred so the Strömbom path doesn't require SB3.
|
||||
# Deferred imports so the analytic path doesn't require SB3.
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import VecNormalize
|
||||
from stable_baselines3.common.vec_env import VecNormalize # noqa: F401
|
||||
|
||||
model = PPO.load(str(zip_path), device="auto")
|
||||
vecnorm = None
|
||||
if vecnorm_path and os.path.exists(vecnorm_path):
|
||||
# VecNormalize.load needs a venv to attach to; we only need its stats
|
||||
# at inference, so we reconstruct the wrapper manually.
|
||||
import pickle
|
||||
with open(vecnorm_path, "rb") as f:
|
||||
vecnorm = pickle.load(f)
|
||||
|
||||
@@ -57,7 +57,7 @@ from herding.control.active_scan import ActiveScanTeacher
|
||||
from herding.control.modulation import modulate_speed_near_sheep
|
||||
from herding.control.sequential import compute_action as sequential_action
|
||||
from herding.control.strombom import compute_action as strombom_action
|
||||
from herding.obs import build_obs
|
||||
from herding.perception.obs import build_obs
|
||||
from herding.perception.lidar_perception import detections_from_scan
|
||||
from herding.perception.sheep_tracker import SheepTracker
|
||||
from herding.world.diffdrive import velocity_to_wheels
|
||||
|
||||
Reference in New Issue
Block a user