Checkpoint 2

This commit is contained in:
Johnny Fernandes
2026-05-07 22:00:10 +01:00
parent 90aa3bbcb4
commit 1bb9415414
37 changed files with 3068 additions and 2912 deletions
+78
View File
@@ -0,0 +1,78 @@
"""Lazy loader for the SB3 PPO policy used by the dog controller.
Importing stable-baselines3 inside the Webots Python interpreter is only
needed when ``HERDING_MODE=rl``; the Strömbom mode runs without it. This
loader keeps SB3 out of the import path until you actually ask for the RL
policy, so users without SB3 installed can still run the Strömbom
baseline.
The policy + VecNormalize statistics are saved together by
``training/train_ppo.py``:
runs/<name>/best/best_model.zip # SB3 PPO checkpoint
runs/<name>/best/vecnormalize.pkl # observation-normaliser stats
Pass either the directory or the explicit zip path.
"""
import os
from pathlib import Path
class PolicyHandle:
    """Bundle a loaded PPO model with its optional observation normaliser.

    The controller calls ``predict(obs)`` and gets an action back without
    having to batch the observation or apply the normaliser itself.

    ``model`` must expose ``predict(obs, deterministic=...)`` returning an
    ``(actions, state)`` pair; ``vecnorm`` must expose ``normalize_obs(obs)``
    or be ``None`` when the policy runs on raw observations.
    """

    def __init__(self, model, vecnorm):
        self.model = model
        self.vecnorm = vecnorm

    def predict(self, obs):
        """Return the deterministic action for one observation.

        ``obs`` is converted to a float32 array and reshaped to a single-row
        batch, normalised when a normaliser is attached, and passed to the
        model. The first row of the returned action batch is handed back.
        """
        # Imported here rather than at module level so importing this module
        # never requires numpy.
        import numpy as np

        # VecNormalize expects a batched obs of shape (n_envs, obs_dim).
        obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1)
        if self.vecnorm is not None:
            obs_b = self.vecnorm.normalize_obs(obs_b)
        action, _ = self.model.predict(obs_b, deterministic=True)
        return action[0]
def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
"""Load a PPO model (and optional VecNormalize) from disk.
``model_path`` may be the .zip checkpoint or a directory containing
``best_model.zip`` (and optionally ``vecnormalize.pkl``).
"""
p = Path(model_path)
if p.is_dir():
zip_candidates = [p / "best_model.zip", p / "final.zip", p / "policy.zip"]
zip_path = next((z for z in zip_candidates if z.exists()), None)
if zip_path is None:
raise FileNotFoundError(
f"No PPO zip found in {p} (looked for best_model.zip, final.zip, policy.zip)"
)
if vecnorm_path is None:
vn = p / "vecnormalize.pkl"
if vn.exists():
vecnorm_path = str(vn)
else:
zip_path = p
# Imports deferred so the Strömbom path doesn't require SB3.
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecNormalize
model = PPO.load(str(zip_path), device="auto")
vecnorm = None
if vecnorm_path and os.path.exists(vecnorm_path):
# VecNormalize.load needs a venv to attach to; we only need its stats
# at inference, so we reconstruct the wrapper manually.
import pickle
with open(vecnorm_path, "rb") as f:
vecnorm = pickle.load(f)
vecnorm.training = False
vecnorm.norm_reward = False
return PolicyHandle(model=model, vecnorm=vecnorm)
+249 -54
View File
@@ -1,88 +1,283 @@
"""
Shepherd Dog controller (Webots).

GPS position is broadcast every step on channel 1 so sheep controllers
can compute flee forces. Ears wag continuously via sinusoidal position
targets — purely cosmetic.

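Runs in one of three modes selected by the ``HERDING_MODE`` environment
variable (or, if that is unset, by ``herding_runtime.cfg`` in the project
root). The default is ``rl``:

    HERDING_MODE=rl          → load an SB3 PPO policy from
                               HERDING_POLICY_DIR (default: the first
                               existing directory among
                               training/runs/bc_pretrained,
                               training/runs/bc_ppo/best and
                               training/runs/latest/best) and use its
                               (vx, vy) action each step.
    HERDING_MODE=sequential  → use the analytic controller from
                               ``herding.sequential``.
    HERDING_MODE=strombom    → use the analytic Strömbom collect/drive
                               heuristic. This is also the fallback if the
                               RL policy can't be loaded (e.g. SB3 not
                               installed in the Webots Python env, or no
                               checkpoint yet) or if the mode name is
                               not recognised.
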
All modes share the same low-level differential-drive controller
(``herding.diffdrive.velocity_to_wheels`` + clamped forward speed), so
switching modes does not retune the actuation layer.

A safety supervisor enforces the "dog stays out of the pen" invariant:
if the action would push the dog past ``DOG_SOUTH_LIMIT`` it is
overridden with a north-driving correction. This is a hard guarantee
the policy cannot escape.
"""
import math
import os
import sys

# Imported before sys.path is modified below, so resolving ``controller`` is
# unaffected by the project-root entry.
from controller import Keyboard, Robot

# --- Make the shared herding/ package importable from this controller dir ---
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, "..", ".."))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

from herding.diffdrive import velocity_to_wheels
from herding.geometry import (
    DOG_MAX_LINEAR, DOG_MAX_WHEEL_OMEGA,
    DOG_SOUTH_LIMIT, DOG_WHEEL_RADIUS,
    PEN_ENTRY,
)
from herding.obs import build_obs
from herding.sequential import compute_action_debug as sequential_action_debug
from herding.strombom import compute_action_debug as strombom_action_debug

# The Robot instance is created once, in the "Webots devices" section further
# down; it is deliberately not instantiated here as well.
# ---------------------------------------------------------------------------
# Mode selection
# ---------------------------------------------------------------------------
def _load_runtime_config() -> dict:
    """Read mode + policy_dir overrides from a runtime config file.

    Webots strips HERDING_* env vars in some configurations, so the
    launcher writes a tiny ``herding_runtime.cfg`` (key=value lines)
    in the project root and the controller reads it here. Env vars
    win if both are present; the file is the fallback.

    Blank lines, lines starting with ``#`` and lines without ``=`` are
    skipped. Keys are upper-cased; keys and values are stripped of
    surrounding whitespace. Returns an empty dict when the file is missing
    or cannot be read — this function never raises for a bad config file.
    """
    cfg_path = os.path.join(_PROJECT_ROOT, "herding_runtime.cfg")
    out = {}
    try:
        # NOTE(review): the file is read with the platform default encoding,
        # which presumably matches how the launcher writes it — confirm
        # before pinning an explicit encoding.
        with open(cfg_path) as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith("#") or "=" not in line:
                    continue
                k, _, v = line.partition("=")
                out[k.strip().upper()] = v.strip()
    except FileNotFoundError:
        # No config file is the normal case when env vars are used.
        return {}
    except (OSError, UnicodeError) as exc:
        # An unreadable or badly encoded file must not take the controller
        # down at import time; report it and carry on without overrides.
        print(f"[dog] could not read {cfg_path} ({exc!r}); ignoring runtime config.")
        return {}
    return out
# Parsed once at import time; consulted wherever an env var is unset or empty.
_runtime_cfg = _load_runtime_config()

# Precedence: environment variable, then the config file, then "rl".
_requested_mode = os.environ.get("HERDING_MODE")
if not _requested_mode:
    _requested_mode = _runtime_cfg.get("HERDING_MODE")
if not _requested_mode:
    _requested_mode = "rl"
MODE = _requested_mode.lower()
def _resolve_policy_dir() -> str:
    """Pick the directory the trained policy is loaded from.

    Priority:
      1. ``HERDING_POLICY_DIR`` — taken from the environment, or from the
         runtime config file when the env var is unset — if it names an
         existing directory. The value is tried as given, then with ``~``
         expanded, then (when relative) relative to the project root, so a
         path written relative to the project does not silently fall
         through to a different checkpoint.
      2. training/runs/bc_pretrained/   (BC-only checkpoint)
      3. training/runs/bc_ppo/best/     (PPO fine-tuned best)
      4. training/runs/latest/best/     (legacy default)

    If nothing exists, the configured value is returned unchanged (or the
    first default when nothing was configured) so that the caller's error
    message names the path the user actually asked for.
    """
    configured = (os.environ.get("HERDING_POLICY_DIR")
                  or _runtime_cfg.get("HERDING_POLICY_DIR"))
    if configured:
        expanded = os.path.expanduser(configured)
        attempts = [configured, expanded]
        if not os.path.isabs(expanded):
            attempts.append(os.path.join(_PROJECT_ROOT, expanded))
        for attempt in attempts:
            if os.path.isdir(attempt):
                return attempt

    candidates = [
        os.path.join(_PROJECT_ROOT, "training", "runs", "bc_pretrained"),
        os.path.join(_PROJECT_ROOT, "training", "runs", "bc_ppo", "best"),
        os.path.join(_PROJECT_ROOT, "training", "runs", "latest", "best"),
    ]
    for c in candidates:
        if os.path.isdir(c):
            return c
    # Last resort — return the configured value anyway so the error message
    # is informative.
    return configured or candidates[0]
POLICY_DIR = _resolve_policy_dir()
policy_handle = None  # stays None unless the RL policy loads successfully

_KNOWN_MODES = ("rl", "strombom", "sequential")

if MODE == "rl":
    # Diagnostics first, so a failed load can be traced from the console.
    _env_policy_dir = os.environ.get('HERDING_POLICY_DIR', '<unset>')
    print(f"[dog] HERDING_MODE={MODE} HERDING_POLICY_DIR(env)={_env_policy_dir}")
    print(f"[dog] resolved POLICY_DIR={POLICY_DIR} exists={os.path.isdir(POLICY_DIR)}")
    if os.path.isdir(POLICY_DIR):
        try:
            _dir_entries = sorted(os.listdir(POLICY_DIR))
        except OSError:
            _dir_entries = []
        print(f"[dog] dir contents: {_dir_entries}")

    # Any failure here (missing SB3, missing checkpoint, bad file) downgrades
    # the controller to the Strömbom heuristic instead of aborting.
    try:
        from policy_loader import load as _load_policy
        policy_handle = _load_policy(POLICY_DIR)
        print(f"[dog] RL policy loaded from {POLICY_DIR}")
    except Exception as exc:
        print(f"[dog] RL policy load failed ({exc!r}); falling back to Strömbom.")
        MODE = "strombom"

if MODE not in _KNOWN_MODES:
    print(f"[dog] unknown HERDING_MODE={MODE!r}; defaulting to strombom.")
    MODE = "strombom"

print(f"[dog] running in mode={MODE}")
# ---------------------------------------------------------------------------
# Action smoothing + safety supervisor
# ---------------------------------------------------------------------------
# Weight given to the previous action in the exponential moving average the
# main loop applies to (vx, vy); the new action gets 1 - ACTION_SMOOTH.
ACTION_SMOOTH = 0.35
# Last action handed to the drive layer (after smoothing and the safety
# clamp); starts at rest.
prev_action = (0.0, 0.0)
def safety_clamp(vx: float, vy: float, dog_x: float, dog_y: float) -> tuple:
    """Keep the dog from being driven south into the pen.

    Two tiers, checked in order:

    * South of ``DOG_SOUTH_LIMIT`` with a southward ``vy``: the action is
      replaced by a pure northward command ``(0.0, 1.0)``.
    * Within 0.5 of the limit with ``vy`` below -0.2: ``vx`` is halved and
      ``vy`` is raised by 0.5, floored at zero.

    Any other action is returned unchanged. ``dog_x`` is accepted for a
    uniform call signature but does not affect the result.
    """
    moving_south = vy < 0.0
    if moving_south and dog_y < DOG_SOUTH_LIMIT:
        return (0.0, 1.0)

    inside_buffer = dog_y < DOG_SOUTH_LIMIT + 0.5
    if inside_buffer and vy < -0.2:
        softened_vy = max(0.0, vy + 0.5)
        return (vx * 0.5, softened_vy)

    return (vx, vy)
# ---------------------------------------------------------------------------
# Driving
# ---------------------------------------------------------------------------
def drive(vx: float, vy: float, left_motor, right_motor, compass, motor_max: float):
    """Turn a planar velocity command into wheel-motor velocities.

    A command with magnitude below 1e-3 stops both wheels without reading
    the compass. Otherwise the heading is derived from the compass vector
    as ``atan2(n[0], n[1])`` and the wheel speeds come from
    ``velocity_to_wheels``, limited to ``motor_max``.
    """
    if math.hypot(vx, vy) < 1e-3:
        for motor in (left_motor, right_motor):
            motor.setVelocity(0.0)
        return

    north = compass.getValues()
    heading = math.atan2(north[0], north[1])
    wheel_speeds = velocity_to_wheels(
        vx, vy, heading,
        max_linear=DOG_MAX_LINEAR,
        wheel_radius=DOG_WHEEL_RADIUS,
        max_wheel_omega=motor_max,
        k_turn=4.0,
    )
    omega_left, omega_right = wheel_speeds
    left_motor.setVelocity(omega_left)
    right_motor.setVelocity(omega_right)
# ---------------------------------------------------------------------------
# Webots devices
# ---------------------------------------------------------------------------
robot = Robot()
timestep = int(robot.getBasicTimeStep())

left_motor = robot.getDevice("left wheel motor")
right_motor = robot.getDevice("right wheel motor")
# An infinite position target puts the wheel motors under velocity control.
left_motor.setPosition(float("inf"))
right_motor.setPosition(float("inf"))
left_motor.setVelocity(0.0)
right_motor.setVelocity(0.0)
# Never command more than the shared geometry constant allows, even if the
# simulated motor could spin faster.
MOTOR_MAX = min(left_motor.getMaxVelocity(), DOG_MAX_WHEEL_OMEGA)

# NOTE(review): the lidar is enabled but nothing in this loop reads it —
# confirm it is still needed.
lidar = robot.getDevice("lidar")
lidar.enable(timestep)
lidar.enablePointCloud()

gps = robot.getDevice("gps")
gps.enable(timestep)
compass = robot.getDevice("compass")
compass.enable(timestep)
receiver = robot.getDevice("receiver")
receiver.enable(timestep)
emitter = robot.getDevice("emitter")

# Cosmetic ear motors — ignored by control logic but keep them animated.
left_ear = robot.getDevice("left ear motor")
right_ear = robot.getDevice("right ear motor")
left_ear.setPosition(float("inf"))
right_ear.setPosition(float("inf"))
left_ear.setVelocity(0.0)
right_ear.setVelocity(0.0)
ear_phase = 0.0
EAR_AMPLITUDE = 0.35   # rad, peak ear deflection
EAR_RATE = 8.0         # rad/s, how fast the ears are driven

# ---------------------------------------------------------------------------
# Main loop
# ---------------------------------------------------------------------------

# {name: (x, y)} — kept across all sheep ever heard from. Sheep that drift
# into the pen are tracked by ``penned_set`` so observations and Strömbom
# agree on which ones still need herding.
sheep_positions: dict = {}
penned_set: set = set()
step_count = 0

from herding.geometry import is_penned_position

while robot.step(timestep) != -1:
    step_count += 1

    # ---- Drain incoming "sheep:<name>:<x>:<y>" messages ----
    while receiver.getQueueLength() > 0:
        msg = receiver.getString()
        receiver.nextPacket()
        parts = msg.split(":")
        if len(parts) == 4 and parts[0] == "sheep":
            try:
                x, y = float(parts[2]), float(parts[3])
            except ValueError:
                # Malformed coordinates: drop this packet, keep draining.
                continue
            sheep_positions[parts[1]] = (x, y)
            # Once a sheep has been seen inside the pen it stays marked.
            if parts[1] not in penned_set and is_penned_position(x, y):
                penned_set.add(parts[1])

    pos = gps.getValues()
    dog_xy = (pos[0], pos[1])
    n = compass.getValues()
    dog_heading = math.atan2(n[0], n[1])

    # ---- Action selection ----
    if MODE == "rl" and policy_handle is not None:
        # Names and positions are taken from the same dict, so the penned
        # flags line up index-for-index with the positions.
        sheep_xy_list = list(sheep_positions.values())
        sheep_names = list(sheep_positions.keys())
        sheep_penned_list = [s in penned_set for s in sheep_names]
        obs = build_obs(dog_xy, dog_heading, sheep_xy_list, sheep_penned_list)
        action = policy_handle.predict(obs)
        vx, vy = float(action[0]), float(action[1])
    elif MODE == "sequential":
        vx, vy, _mode_str, _dbg = sequential_action_debug(
            dog_xy, sheep_positions, PEN_ENTRY,
        )
    else:
        # Strömbom (canonical baseline).
        vx, vy, _mode_str, _dbg = strombom_action_debug(
            dog_xy, sheep_positions, PEN_ENTRY,
        )

    # EMA smoothing — reduces oscillation from policy or Strömbom flips.
    vx = ACTION_SMOOTH * prev_action[0] + (1.0 - ACTION_SMOOTH) * vx
    vy = ACTION_SMOOTH * prev_action[1] + (1.0 - ACTION_SMOOTH) * vy

    # Safety: dog must never enter the pen.
    vx, vy = safety_clamp(vx, vy, dog_xy[0], dog_xy[1])
    prev_action = (vx, vy)

    # drive() is the only place that commands the wheel motors.
    drive(vx, vy, left_motor, right_motor, compass, MOTOR_MAX)

    # Broadcast the dog position exactly once per step.
    emitter.send(f"dog:{dog_xy[0]:.4f}:{dog_xy[1]:.4f}")

    # Cosmetic ear wiggle — purely visual.
    ear_phase += 0.12
    ear_pos = EAR_AMPLITUDE * math.sin(ear_phase)
    left_ear.setVelocity(EAR_RATE)
    right_ear.setVelocity(EAR_RATE)
    left_ear.setPosition(ear_pos)
    right_ear.setPosition(-ear_pos)

    # Periodic status line.
    if step_count % 200 == 0:
        n_active = sum(1 for s in sheep_positions if s not in penned_set)
        print(f"[dog mode={MODE}] step={step_count} known={len(sheep_positions)} "
              f"penned={len(penned_set)} active={n_active} action=({vx:+.2f}, {vy:+.2f})")