Checkpoint 2
This commit is contained in:
@@ -0,0 +1,78 @@
|
||||
"""Lazy loader for the SB3 PPO policy used by the dog controller.
|
||||
|
||||
Importing stable-baselines3 inside the Webots Python interpreter is only
|
||||
needed when ``HERDING_MODE=rl``; the Strömbom mode runs without it. This
|
||||
loader keeps SB3 out of the import path until you actually ask for the RL
|
||||
policy, so users without SB3 installed can still run the Strömbom
|
||||
baseline.
|
||||
|
||||
The policy + VecNormalize statistics are saved together by
|
||||
``training/train_ppo.py``:
|
||||
|
||||
runs/<name>/best/best_model.zip # SB3 PPO checkpoint
|
||||
runs/<name>/best/vecnormalize.pkl # observation-normaliser stats
|
||||
|
||||
Pass either the directory or the explicit zip path.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class PolicyHandle:
|
||||
"""Wrap a loaded PPO policy + VecNormalize so the controller can call
|
||||
``predict(obs)`` without thinking about either."""
|
||||
|
||||
def __init__(self, model, vecnorm):
|
||||
self.model = model
|
||||
self.vecnorm = vecnorm
|
||||
|
||||
def predict(self, obs):
|
||||
# VecNormalize expects a batched obs of shape (n_envs, obs_dim).
|
||||
if self.vecnorm is not None:
|
||||
import numpy as np
|
||||
obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1)
|
||||
obs_b = self.vecnorm.normalize_obs(obs_b)
|
||||
else:
|
||||
import numpy as np
|
||||
obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1)
|
||||
action, _ = self.model.predict(obs_b, deterministic=True)
|
||||
return action[0]
|
||||
|
||||
|
||||
def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
|
||||
"""Load a PPO model (and optional VecNormalize) from disk.
|
||||
|
||||
``model_path`` may be the .zip checkpoint or a directory containing
|
||||
``best_model.zip`` (and optionally ``vecnormalize.pkl``).
|
||||
"""
|
||||
p = Path(model_path)
|
||||
if p.is_dir():
|
||||
zip_candidates = [p / "best_model.zip", p / "final.zip", p / "policy.zip"]
|
||||
zip_path = next((z for z in zip_candidates if z.exists()), None)
|
||||
if zip_path is None:
|
||||
raise FileNotFoundError(
|
||||
f"No PPO zip found in {p} (looked for best_model.zip, final.zip, policy.zip)"
|
||||
)
|
||||
if vecnorm_path is None:
|
||||
vn = p / "vecnormalize.pkl"
|
||||
if vn.exists():
|
||||
vecnorm_path = str(vn)
|
||||
else:
|
||||
zip_path = p
|
||||
|
||||
# Imports deferred so the Strömbom path doesn't require SB3.
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import VecNormalize
|
||||
|
||||
model = PPO.load(str(zip_path), device="auto")
|
||||
vecnorm = None
|
||||
if vecnorm_path and os.path.exists(vecnorm_path):
|
||||
# VecNormalize.load needs a venv to attach to; we only need its stats
|
||||
# at inference, so we reconstruct the wrapper manually.
|
||||
import pickle
|
||||
with open(vecnorm_path, "rb") as f:
|
||||
vecnorm = pickle.load(f)
|
||||
vecnorm.training = False
|
||||
vecnorm.norm_reward = False
|
||||
return PolicyHandle(model=model, vecnorm=vecnorm)
|
||||
@@ -1,88 +1,283 @@
|
||||
"""
|
||||
Shepherd Dog controller (Webots, manual keyboard control).
|
||||
"""Shepherd Dog controller (Webots).
|
||||
|
||||
WASD / arrow keys drive the robot. +/- adjust speed in 10 % increments.
|
||||
GPS position is broadcast every step on channel 1 so sheep controllers
|
||||
can compute flee forces. Ears wag continuously via sinusoidal position
|
||||
targets — purely cosmetic.
|
||||
Runs in one of two modes selected by the ``HERDING_MODE`` environment
|
||||
variable:
|
||||
|
||||
HERDING_MODE=rl → load an SB3 PPO policy from
|
||||
HERDING_POLICY_DIR (default
|
||||
training/runs/latest/best) and use its
|
||||
(vx, vy) action each step.
|
||||
HERDING_MODE=strombom → use the analytic Strömbom collect/drive
|
||||
heuristic. This is the fallback if the RL
|
||||
policy can't be loaded (e.g. SB3 not
|
||||
installed in the Webots Python env, or no
|
||||
checkpoint yet).
|
||||
|
||||
Both modes share the same low-level differential-drive controller
|
||||
(``herding.diffdrive.velocity_to_wheels`` + clamped forward speed), so
|
||||
switching modes does not retune the actuation layer.
|
||||
|
||||
A safety supervisor enforces the "dog stays out of the pen" invariant:
|
||||
if the action would push the dog past ``DOG_SOUTH_LIMIT`` it is
|
||||
overridden with a north-driving correction. This is a hard guarantee
|
||||
the policy cannot escape.
|
||||
"""
|
||||
|
||||
import math
|
||||
from controller import Robot, Keyboard
|
||||
import os
|
||||
import sys
|
||||
|
||||
robot = Robot()
|
||||
# --- Make the shared herding/ package importable from this controller dir ---
|
||||
_HERE = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, "..", ".."))
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
from controller import Robot
|
||||
|
||||
from herding.diffdrive import velocity_to_wheels
|
||||
from herding.geometry import (
|
||||
DOG_MAX_LINEAR, DOG_MAX_WHEEL_OMEGA,
|
||||
DOG_SOUTH_LIMIT, DOG_WHEEL_RADIUS,
|
||||
PEN_ENTRY,
|
||||
)
|
||||
from herding.obs import build_obs
|
||||
from herding.sequential import compute_action_debug as sequential_action_debug
|
||||
from herding.strombom import compute_action_debug as strombom_action_debug
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mode selection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _load_runtime_config():
|
||||
"""Read mode + policy_dir overrides from a runtime config file.
|
||||
|
||||
Webots strips HERDING_* env vars in some configurations, so the
|
||||
launcher writes a tiny ``herding_runtime.cfg`` (key=value lines)
|
||||
in the project root and the controller reads it here. Env vars
|
||||
win if both are present; the file is the fallback.
|
||||
"""
|
||||
cfg_path = os.path.join(_PROJECT_ROOT, "herding_runtime.cfg")
|
||||
if not os.path.exists(cfg_path):
|
||||
return {}
|
||||
out = {}
|
||||
try:
|
||||
with open(cfg_path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#") or "=" not in line:
|
||||
continue
|
||||
k, _, v = line.partition("=")
|
||||
out[k.strip().upper()] = v.strip()
|
||||
except OSError:
|
||||
return {}
|
||||
return out
|
||||
|
||||
|
||||
_runtime_cfg = _load_runtime_config()
|
||||
MODE = (os.environ.get("HERDING_MODE")
|
||||
or _runtime_cfg.get("HERDING_MODE")
|
||||
or "rl").lower()
|
||||
|
||||
|
||||
def _resolve_policy_dir() -> str:
|
||||
"""Where to look for the trained policy.
|
||||
|
||||
Priority:
|
||||
1. HERDING_POLICY_DIR env var (if set and points to a real dir)
|
||||
2. training/runs/bc_pretrained/ (BC-only checkpoint)
|
||||
3. training/runs/bc_ppo/best/ (PPO fine-tuned best)
|
||||
4. training/runs/latest/best/ (legacy default)
|
||||
"""
|
||||
env_dir = (os.environ.get("HERDING_POLICY_DIR")
|
||||
or _runtime_cfg.get("HERDING_POLICY_DIR"))
|
||||
if env_dir and os.path.isdir(env_dir):
|
||||
return env_dir
|
||||
candidates = [
|
||||
os.path.join(_PROJECT_ROOT, "training", "runs", "bc_pretrained"),
|
||||
os.path.join(_PROJECT_ROOT, "training", "runs", "bc_ppo", "best"),
|
||||
os.path.join(_PROJECT_ROOT, "training", "runs", "latest", "best"),
|
||||
]
|
||||
for c in candidates:
|
||||
if os.path.isdir(c):
|
||||
return c
|
||||
# Last resort — return env var anyway so error message is informative.
|
||||
return env_dir or candidates[0]
|
||||
|
||||
|
||||
POLICY_DIR = _resolve_policy_dir()
|
||||
|
||||
policy_handle = None
|
||||
if MODE == "rl":
|
||||
print(f"[dog] HERDING_MODE={MODE} HERDING_POLICY_DIR(env)="
|
||||
f"{os.environ.get('HERDING_POLICY_DIR', '<unset>')}")
|
||||
print(f"[dog] resolved POLICY_DIR={POLICY_DIR} exists="
|
||||
f"{os.path.isdir(POLICY_DIR)}")
|
||||
if os.path.isdir(POLICY_DIR):
|
||||
try:
|
||||
entries = sorted(os.listdir(POLICY_DIR))
|
||||
except OSError:
|
||||
entries = []
|
||||
print(f"[dog] dir contents: {entries}")
|
||||
try:
|
||||
from policy_loader import load as _load_policy
|
||||
policy_handle = _load_policy(POLICY_DIR)
|
||||
print(f"[dog] RL policy loaded from {POLICY_DIR}")
|
||||
except Exception as exc:
|
||||
print(f"[dog] RL policy load failed ({exc!r}); falling back to Strömbom.")
|
||||
MODE = "strombom"
|
||||
if MODE not in ("rl", "strombom", "sequential"):
|
||||
print(f"[dog] unknown HERDING_MODE={MODE!r}; defaulting to strombom.")
|
||||
MODE = "strombom"
|
||||
print(f"[dog] running in mode={MODE}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Action smoothing + safety supervisor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
ACTION_SMOOTH = 0.35
|
||||
prev_action = (0.0, 0.0)
|
||||
|
||||
|
||||
def safety_clamp(vx: float, vy: float, dog_x: float, dog_y: float) -> tuple:
|
||||
"""If the dog is near the south barrier and the action would push it
|
||||
further south, override with a northward action. Hard invariant: the
|
||||
dog never enters the pen."""
|
||||
if dog_y < DOG_SOUTH_LIMIT and vy < 0.0:
|
||||
return (0.0, 1.0)
|
||||
if dog_y < DOG_SOUTH_LIMIT + 0.5 and vy < -0.2:
|
||||
return (vx * 0.5, max(0.0, vy + 0.5))
|
||||
return (vx, vy)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Driving
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def drive(vx: float, vy: float, left_motor, right_motor, compass, motor_max: float):
|
||||
if math.hypot(vx, vy) < 1e-3:
|
||||
left_motor.setVelocity(0.0)
|
||||
right_motor.setVelocity(0.0)
|
||||
return
|
||||
n = compass.getValues()
|
||||
h = math.atan2(n[0], n[1])
|
||||
left, right = velocity_to_wheels(
|
||||
vx, vy, h,
|
||||
max_linear=DOG_MAX_LINEAR,
|
||||
wheel_radius=DOG_WHEEL_RADIUS,
|
||||
max_wheel_omega=motor_max,
|
||||
k_turn=4.0,
|
||||
)
|
||||
left_motor.setVelocity(left)
|
||||
right_motor.setVelocity(right)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Webots devices
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
robot = Robot()
|
||||
timestep = int(robot.getBasicTimeStep())
|
||||
|
||||
left_motor = robot.getDevice("left wheel motor")
|
||||
left_motor = robot.getDevice("left wheel motor")
|
||||
right_motor = robot.getDevice("right wheel motor")
|
||||
left_motor.setPosition(float("inf"))
|
||||
right_motor.setPosition(float("inf"))
|
||||
left_motor.setVelocity(0.0)
|
||||
right_motor.setVelocity(0.0)
|
||||
MOTOR_MAX = min(left_motor.getMaxVelocity(), DOG_MAX_WHEEL_OMEGA)
|
||||
|
||||
lidar = robot.getDevice("lidar")
|
||||
lidar.enable(timestep)
|
||||
lidar.enablePointCloud()
|
||||
|
||||
gps = robot.getDevice("gps"); gps.enable(timestep)
|
||||
compass = robot.getDevice("compass"); compass.enable(timestep)
|
||||
emitter = robot.getDevice("emitter")
|
||||
gps = robot.getDevice("gps"); gps.enable(timestep)
|
||||
compass = robot.getDevice("compass"); compass.enable(timestep)
|
||||
receiver = robot.getDevice("receiver"); receiver.enable(timestep)
|
||||
emitter = robot.getDevice("emitter")
|
||||
|
||||
left_ear = robot.getDevice("left ear motor")
|
||||
# Cosmetic ear motors — ignored by control logic but keep them animated.
|
||||
left_ear = robot.getDevice("left ear motor")
|
||||
right_ear = robot.getDevice("right ear motor")
|
||||
left_ear.setPosition(float("inf"))
|
||||
right_ear.setPosition(float("inf"))
|
||||
left_ear.setVelocity(0.0)
|
||||
right_ear.setVelocity(0.0)
|
||||
ear_phase = 0.0
|
||||
EAR_AMPLITUDE = 0.35
|
||||
EAR_RATE = 8.0
|
||||
|
||||
keyboard = robot.getKeyboard()
|
||||
keyboard.enable(timestep)
|
||||
|
||||
MOTOR_MAX = left_motor.getMaxVelocity()
|
||||
speed_level = 0.5 # fraction of MOTOR_MAX; adjusted by +/-
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
EAR_AMPLITUDE = 0.35 # rad, peak ear deflection
|
||||
EAR_RATE = 8.0 # rad/s, how fast the ears are driven
|
||||
ear_phase = 0.0
|
||||
# {name: (x, y)} — kept across all sheep ever heard from. Sheep that drift
|
||||
# into the pen are tracked by ``penned`` so observations and Strömbom
|
||||
# agree on which ones still need herding.
|
||||
sheep_positions: dict = {}
|
||||
penned_set: set = set()
|
||||
step_count = 0
|
||||
|
||||
from herding.geometry import is_penned_position
|
||||
|
||||
while robot.step(timestep) != -1:
|
||||
speed = MOTOR_MAX * speed_level
|
||||
turn = speed * 0.6 # differential turn radius
|
||||
step_count += 1
|
||||
|
||||
left_vel = 0.0
|
||||
right_vel = 0.0
|
||||
key = keyboard.getKey()
|
||||
while key > 0:
|
||||
if key in (ord('W'), Keyboard.UP):
|
||||
left_vel = speed
|
||||
right_vel = speed
|
||||
elif key in (ord('S'), Keyboard.DOWN):
|
||||
left_vel = -speed
|
||||
right_vel = -speed
|
||||
elif key in (ord('A'), Keyboard.LEFT):
|
||||
left_vel = -turn
|
||||
right_vel = turn
|
||||
elif key in (ord('D'), Keyboard.RIGHT):
|
||||
left_vel = turn
|
||||
right_vel = -turn
|
||||
elif key in (ord('+'), ord('=')):
|
||||
speed_level = min(1.0, speed_level + 0.1)
|
||||
print(f"Speed: {speed_level:.0%} ({MOTOR_MAX * speed_level:.1f} rad/s)")
|
||||
elif key in (ord('-'), ord('_')):
|
||||
speed_level = max(0.1, speed_level - 0.1)
|
||||
print(f"Speed: {speed_level:.0%} ({MOTOR_MAX * speed_level:.1f} rad/s)")
|
||||
key = keyboard.getKey()
|
||||
|
||||
left_motor.setVelocity(left_vel)
|
||||
right_motor.setVelocity(right_vel)
|
||||
while receiver.getQueueLength() > 0:
|
||||
msg = receiver.getString()
|
||||
receiver.nextPacket()
|
||||
parts = msg.split(":")
|
||||
if len(parts) == 4 and parts[0] == "sheep":
|
||||
try:
|
||||
x, y = float(parts[2]), float(parts[3])
|
||||
except ValueError:
|
||||
continue
|
||||
sheep_positions[parts[1]] = (x, y)
|
||||
if parts[1] not in penned_set and is_penned_position(x, y):
|
||||
penned_set.add(parts[1])
|
||||
|
||||
pos = gps.getValues()
|
||||
emitter.send(f"dog:{pos[0]}:{pos[1]}")
|
||||
dog_xy = (pos[0], pos[1])
|
||||
n = compass.getValues()
|
||||
dog_heading = math.atan2(n[0], n[1])
|
||||
|
||||
# ---- Action selection ----
|
||||
if MODE == "rl" and policy_handle is not None:
|
||||
sheep_xy_list = list(sheep_positions.values())
|
||||
sheep_names = list(sheep_positions.keys())
|
||||
sheep_penned_list = [s in penned_set for s in sheep_names]
|
||||
obs = build_obs(dog_xy, dog_heading, sheep_xy_list, sheep_penned_list)
|
||||
action = policy_handle.predict(obs)
|
||||
vx, vy = float(action[0]), float(action[1])
|
||||
elif MODE == "sequential":
|
||||
vx, vy, _mode_str, _dbg = sequential_action_debug(
|
||||
dog_xy, sheep_positions, PEN_ENTRY,
|
||||
)
|
||||
else:
|
||||
# Strömbom (canonical baseline).
|
||||
vx, vy, _mode_str, _dbg = strombom_action_debug(
|
||||
dog_xy, sheep_positions, PEN_ENTRY,
|
||||
)
|
||||
|
||||
# EMA smoothing — reduces oscillation from policy or Strömbom flips.
|
||||
vx = ACTION_SMOOTH * prev_action[0] + (1.0 - ACTION_SMOOTH) * vx
|
||||
vy = ACTION_SMOOTH * prev_action[1] + (1.0 - ACTION_SMOOTH) * vy
|
||||
|
||||
# Safety: dog must never enter the pen.
|
||||
vx, vy = safety_clamp(vx, vy, dog_xy[0], dog_xy[1])
|
||||
prev_action = (vx, vy)
|
||||
|
||||
drive(vx, vy, left_motor, right_motor, compass, MOTOR_MAX)
|
||||
emitter.send(f"dog:{dog_xy[0]:.4f}:{dog_xy[1]:.4f}")
|
||||
|
||||
# Cosmetic ear wiggle — purely visual.
|
||||
ear_phase += 0.12
|
||||
ear_pos = EAR_AMPLITUDE * math.sin(ear_phase)
|
||||
left_ear.setVelocity(EAR_RATE)
|
||||
right_ear.setVelocity(EAR_RATE)
|
||||
left_ear.setPosition( ear_pos)
|
||||
left_ear.setPosition(ear_pos)
|
||||
right_ear.setPosition(-ear_pos)
|
||||
|
||||
if step_count % 200 == 0:
|
||||
n_active = sum(1 for s in sheep_positions if s not in penned_set)
|
||||
print(f"[dog mode={MODE}] step={step_count} known={len(sheep_positions)} "
|
||||
f"penned={len(penned_set)} active={n_active} action=({vx:+.2f}, {vy:+.2f})")
|
||||
|
||||
Reference in New Issue
Block a user