Files
TIR_PROJ/controllers/shepherd_dog_rl/shepherd_dog_rl.py
T
2026-04-24 17:49:42 +01:00

231 lines
8.4 KiB
Python

"""
Shepherd Dog RL controller — runs a trained SB3 PPO policy inside Webots.
Setup
-----
1. Copy your trained files into this directory:
controllers/shepherd_dog_rl/best_model.zip
controllers/shepherd_dog_rl/vecnorm.pkl
2. In field.wbt, set the ShepherdDog robot's controller field to
"shepherd_dog_rl". You can do this in the Webots GUI:
click the robot → Controller → shepherd_dog_rl
3. Optional: set controllerArgs to ["5"] (number of sheep) if it differs
from the default of 5.
The controller reads GPS (dog position) and Receiver (sheep broadcasts),
builds the same flock observation the training env used (16 values, as
assembled in build_obs), normalises it with the saved VecNormalize stats,
and converts the (vx, vy) policy output into differential wheel speeds.
"""
import sys
import os
import math
import struct
import numpy as np
# ── make training code importable ───────────────────────────────────────────
_HERE = os.path.dirname(os.path.abspath(__file__))
_TRAINING = os.path.join(_HERE, "..", "..", "training")
sys.path.insert(0, _TRAINING)
from controller import Robot
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from herding_env import HerdingEnv
# ── constants (must match herding_env.py) ───────────────────────────────────
# Position scale used by build_obs(): dog coordinates are divided by FIELD and
# relative offsets by 2 * FIELD. Presumably the field half-width in metres —
# confirm against herding_env.py.
FIELD = 15.0
# Midpoint of the PEN_X / PEN_Y ranges below; build_obs() uses it as the herding
# target and as the fallback flock position when no sheep is outside the pen.
PEN_CENTER = np.array([11.5, -11.5], dtype=np.float32)
# (min, max) bounds of the pen rectangle; in_pen() tests them with strict
# inequalities, so points exactly on the boundary count as outside.
PEN_X = (10.0, 13.0)
PEN_Y = (-15.0, -8.0)
DOG_SPEED = 2.5 # m/s — drive() multiplies the action magnitude by this
WHEEL_R = 0.038 # wheel radius (metres) — from ShepherdDog.proto; converts m/s to wheel rad/s
K_TURN = 4.0 # heading-error gain (rad/s per rad) — added to/subtracted from the wheel speeds
# Ear animation: the main loop sets the ear position targets to
# ±EAR_AMPLITUDE * sin(phase) and the ear motor velocity to EAR_RATE.
EAR_AMPLITUDE = 0.35
EAR_RATE = 8.0
# ── model paths ─────────────────────────────────────────────────────────────
# Both files are expected next to this controller script (_HERE).
MODEL_PATH = os.path.join(_HERE, "best_model.zip")
VECNORM_PATH = os.path.join(_HERE, "vecnorm.pkl")
def norm_angle(a: float) -> float:
    """
    Wrap an angle (radians) into the interval [-pi, pi].

    Values already inside the interval — including exactly +pi and -pi — are
    returned unchanged. The wrap is done in constant time, so very large
    magnitudes do not cost one loop iteration per revolution.

    Non-finite input (NaN or +/-inf) yields NaN: an infinite angle has no
    meaningful wrapped value, and repeatedly subtracting 2*pi from it would
    never terminate.
    """
    if not math.isfinite(a):
        return math.nan
    two_pi = 2.0 * math.pi
    # fmod keeps the sign of `a` and brings it into (-2*pi, 2*pi); at most one
    # further shift by a full turn is then needed.
    wrapped = math.fmod(a, two_pi)
    if wrapped > math.pi:
        wrapped -= two_pi
    elif wrapped < -math.pi:
        wrapped += two_pi
    return wrapped
def in_pen(x: float, y: float) -> bool:
    """
    Tell whether the point (x, y) lies strictly inside the pen rectangle.

    The rectangle is given by the PEN_X and PEN_Y (min, max) pairs; points
    exactly on an edge are reported as outside.
    """
    x_lo, x_hi = PEN_X[0], PEN_X[1]
    y_lo, y_hi = PEN_Y[0], PEN_Y[1]
    return (x_lo < x < x_hi) and (y_lo < y < y_hi)
def build_obs(dog_pos: np.ndarray,
              sheep_dict: dict,
              n_sheep: int) -> np.ndarray:
    """
    Assemble the flock observation fed to the policy.

    Intended to mirror HerdingEnv._obs() (not visible from this file —
    confirm the layout against the training code).

    dog_pos:    dog (x, y) position.
    sheep_dict: {name: (x, y)} for ALL known sheep (penned or not).
    n_sheep:    total flock size, used for the active-fraction feature.

    Returns a float32 vector of 16 values, in order:
      dog position / FIELD                          (2)
      flock centre relative to the dog              (2)
      three farthest-from-centre sheep, relative
      to the flock centre                           (3 x 2)
      pen centre relative to the flock centre       (2)
      pen centre relative to the farthest sheep     (2)
      flock radius                                  (1)
      fraction of sheep still outside the pen       (1)
    All relative offsets and the radius are divided by 2 * FIELD.
    """
    span = 2 * FIELD

    def scaled_offset(target, origin):
        # Per-component offset target - origin, scaled by the field span.
        return ((target[0] - origin[0]) / span,
                (target[1] - origin[1]) / span)

    # Only sheep that are still outside the pen shape the flock features.
    roaming = [xy for xy in sheep_dict.values() if not in_pen(*xy)]
    n_active = len(roaming)

    if n_active == 0:
        # Nothing left to herd: collapse every flock feature onto the pen.
        centroid = PEN_CENTER.copy()
        spread = 0.0
        far1 = far2 = far3 = PEN_CENTER.copy()
    else:
        points = np.array(roaming, dtype=np.float32)
        centroid = points.mean(axis=0)
        dist = np.linalg.norm(points - centroid, axis=1)
        order = np.argsort(dist)[::-1]          # farthest first
        spread = float(dist[order[0]])
        # Up to three outliers; missing slots fall back to the flock centre.
        outliers = [points[i] for i in order[:3]]
        outliers.extend([centroid] * (3 - len(outliers)))
        far1, far2, far3 = outliers

    features = [dog_pos[0] / FIELD, dog_pos[1] / FIELD]
    features.extend(scaled_offset(centroid, dog_pos))
    features.extend(scaled_offset(far1, centroid))
    features.extend(scaled_offset(far2, centroid))
    features.extend(scaled_offset(far3, centroid))
    features.extend(scaled_offset(PEN_CENTER, centroid))
    features.extend(scaled_offset(PEN_CENTER, far1))
    features.append(spread / span)
    features.append(n_active / max(n_sheep, 1))
    return np.array(features, dtype=np.float32)
# ── Webots setup ─────────────────────────────────────────────────────────────
robot = Robot()
timestep = int(robot.getBasicTimeStep())

# Drive motors: an infinite position target puts them in velocity control,
# and they start at rest.
left_motor = robot.getDevice("left wheel motor")
right_motor = robot.getDevice("right wheel motor")
for _wheel in (left_motor, right_motor):
    _wheel.setPosition(math.inf)
for _wheel in (left_motor, right_motor):
    _wheel.setVelocity(0.0)
MOTOR_MAX = left_motor.getMaxVelocity()

# Sensors (all sampled once per basic time step) and the outgoing radio
gps = robot.getDevice("gps")
gps.enable(timestep)
compass = robot.getDevice("compass")
compass.enable(timestep)
receiver = robot.getDevice("receiver")
receiver.enable(timestep)
emitter = robot.getDevice("emitter")

# Cosmetic ear motors, initialised the same way as the wheels
left_ear = robot.getDevice("left ear motor")
right_ear = robot.getDevice("right ear motor")
for _ear in (left_ear, right_ear):
    _ear.setPosition(math.inf)
for _ear in (left_ear, right_ear):
    _ear.setVelocity(0.0)
ear_phase = 0.0

# Number of sheep: first controller argument when it parses as an int,
# otherwise the default of 5.
n_sheep = 5
if len(sys.argv) > 1:
    try:
        n_sheep = int(sys.argv[1])
    except ValueError:
        n_sheep = 5

# ── Load model ───────────────────────────────────────────────────────────────
print(f"[RL dog] Loading model from {MODEL_PATH}")
print(f"[RL dog] Loading vecnorm from {VECNORM_PATH}")


def _make_training_env():
    # VecNormalize.load() is given a wrapped env; this builds the one it wraps.
    return HerdingEnv(n_sheep=n_sheep)


dummy_env = DummyVecEnv([_make_training_env])
vecnorm = VecNormalize.load(VECNORM_PATH, dummy_env)
# Inference only: freeze the running statistics and leave rewards untouched.
vecnorm.training = False
vecnorm.norm_reward = False
model = PPO.load(MODEL_PATH)
print(f"[RL dog] Model loaded — running with n_sheep={n_sheep}")

# ── Runtime state ─────────────────────────────────────────────────────────────
sheep_positions: dict = {}  # {name: (x, y)} — updated every step from receiver
step_count = 0
def bearing() -> float:
    """
    Return the robot's current heading angle in radians.

    Computed as atan2 of the first two components of the compass reading,
    so the result lies in [-pi, pi]. Which world axis this is measured from
    depends on the compass/world convention — confirm against the world file.
    """
    reading = compass.getValues()
    first, second = reading[0], reading[1]
    return math.atan2(first, second)
def drive(action_vx: float, action_vy: float) -> None:
    """
    Turn a (vx, vy) policy action into left/right wheel velocities.

    The action magnitude times DOG_SPEED is the requested speed; below
    0.05 m/s both wheels are stopped. Otherwise the robot steers towards the
    action's direction: forward speed is scaled by the cosine of the heading
    error (never negative) and a proportional turn term is split between the
    wheels. Each wheel command is clamped to +/-MOTOR_MAX.
    """
    def clamp(wheel_speed):
        return max(-MOTOR_MAX, min(MOTOR_MAX, wheel_speed))

    speed = math.sqrt(action_vx ** 2 + action_vy ** 2) * DOG_SPEED

    left_cmd = right_cmd = 0.0
    if not (speed < 0.05):
        desired = math.atan2(action_vy, action_vx)
        heading_err = norm_angle(desired - bearing())
        # No forward motion while facing more than 90 degrees off target.
        forward = speed * max(0.0, math.cos(heading_err)) / WHEEL_R
        steer = K_TURN * heading_err            # rad/s correction
        left_cmd = clamp(forward - steer)
        right_cmd = clamp(forward + steer)

    left_motor.setVelocity(left_cmd)
    right_motor.setVelocity(right_cmd)
# ── Main loop ─────────────────────────────────────────────────────────────────
# Each simulation step: read sheep broadcasts, build and normalise the
# observation, query the policy, drive the wheels, broadcast the dog position
# and animate the ears. The loop ends when Webots signals termination (-1).
dropped_packets = 0  # receiver packets discarded as unreadable or invalid

while robot.step(timestep) != -1:
    step_count += 1

    # 1. Drain receiver — update sheep position table.
    #    Expected payload: "sheep:<name>:<x>:<y>". Other well-formed messages
    #    are ignored. Bad packets are still dropped (best effort), but they are
    #    now counted and the first one is reported instead of vanishing.
    while receiver.getQueueLength() > 0:
        try:
            parts = receiver.getString().split(":")
            if parts[0] == "sheep" and len(parts) == 4:
                sheep_x, sheep_y = float(parts[2]), float(parts[3])
                # "nan"/"inf" parse as floats but would poison the whole
                # observation vector, so treat them as a bad packet.
                if not (math.isfinite(sheep_x) and math.isfinite(sheep_y)):
                    raise ValueError(f"non-finite sheep position: {parts!r}")
                sheep_positions[parts[1]] = (sheep_x, sheep_y)
        except Exception as exc:
            dropped_packets += 1
            if dropped_packets == 1:
                print(f"[RL dog] WARNING: dropping bad receiver packet ({exc!r})")
        # Always advance, otherwise the queue would never empty.
        receiver.nextPacket()

    # 2. Dog GPS
    gps_vals = gps.getValues()
    dog_pos = np.array([gps_vals[0], gps_vals[1]], dtype=np.float32)

    # 3. Build and normalise observation
    raw_obs = build_obs(dog_pos, sheep_positions, n_sheep)
    obs_norm = vecnorm.normalize_obs(raw_obs[np.newaxis])  # batch of one: (1, 16)

    # 4. Policy inference
    action, _ = model.predict(obs_norm, deterministic=True)
    vx, vy = float(action[0][0]), float(action[0][1])

    # 5. Drive
    drive(vx, vy)

    # 6. Broadcast dog position so sheep can compute flee forces
    emitter.send(f"dog:{dog_pos[0]:.4f}:{dog_pos[1]:.4f}")

    # 7. Ear animation — the two ears swing in opposite directions
    ear_phase += 0.12
    ep = EAR_AMPLITUDE * math.sin(ear_phase)
    left_ear.setVelocity(EAR_RATE)
    right_ear.setVelocity(EAR_RATE)
    left_ear.setPosition(ep)
    right_ear.setPosition(-ep)

    # Periodic status
    if step_count % 100 == 0:
        n_in_pen = sum(1 for x, y in sheep_positions.values() if in_pen(x, y))
        print(f"[RL dog] step={step_count} known_sheep={len(sheep_positions)}"
              f" penned={n_in_pen}/{n_sheep}"
              f" action=({vx:.2f}, {vy:.2f})")
        if dropped_packets:
            print(f"[RL dog] dropped_packets={dropped_packets}")