Files
TIR_PROJ/controllers/shepherd_dog_rl/shepherd_dog_rl.py
T
Johnny Fernandes 1af7d03ce2 Mimic webots physics
2026-04-26 18:22:26 +01:00

286 lines
11 KiB
Python

"""
Shepherd Dog RL controller — runs a trained SB3 PPO policy inside Webots.
Setup
-----
1. Copy your trained files into this directory:
controllers/shepherd_dog_rl/final_model.zip
controllers/shepherd_dog_rl/vecnorm.pkl
2. In field.wbt, set the ShepherdDog robot's controller field to
"shepherd_dog_rl". You can do this in the Webots GUI:
click the robot → Controller → shepherd_dog_rl
3. Optional: set controllerArgs to the number of sheep, e.g. ["5"], if it
differs from the default of 3.
The controller reads GPS (dog position) and Receiver (sheep broadcasts),
builds the same 18-dim flock observation the training env used, normalises
it with the saved VecNormalize stats, and converts the (vx, vy) policy
output into differential wheel speeds.
Debug logging
-------------
Set DEBUG_ENABLED = True in this script to write a per-step CSV (dog pos,
heading, sheep positions, raw obs, normalised obs, action) to debug.csv
alongside this script. Use plot_debug.py to render trajectories from it.
"""
import sys
import os
import math
import struct
import numpy as np
# ── make training code importable ───────────────────────────────────────────
_HERE = os.path.dirname(os.path.abspath(__file__))
_TRAINING = os.path.join(_HERE, "..", "..", "training")
sys.path.insert(0, _TRAINING)
from controller import Robot
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from herding_env import HerdingEnv
# ── constants (must match herding_env.py) ───────────────────────────────────
# FIELD scales the dog position in the observation; 2 * FIELD scales every
# relative offset (presumably FIELD is the field half-extent — confirm against
# herding_env.py).
FIELD = 15.0
# Midpoint of the PEN_X / PEN_Y rectangle below; used as the herding target
# and as the fallback flock position when no sheep is active.
PEN_CENTER = np.array([11.5, -11.5], dtype=np.float32)
# Pen rectangle bounds; in_pen() treats both ranges as exclusive.
PEN_X = (10.0, 13.0)
PEN_Y = (-15.0, -8.0)
# Linear speed commanded when the policy action has unit magnitude.
DOG_SPEED = 2.5 # m/s
# Used by drive() to convert a linear speed (m/s) into a wheel speed (rad/s).
WHEEL_R = 0.038 # wheel radius (metres) — from ShepherdDog.proto
K_TURN = 4.0 # heading-error gain (rad/s per rad)
# Ear animation: amplitude of the sinusoidal target position and the motor
# velocity used to reach it.
EAR_AMPLITUDE = 0.35
EAR_RATE = 8.0
# ── model paths ─────────────────────────────────────────────────────────────
# All artefacts live next to this script (_HERE).
MODEL_PATH = os.path.join(_HERE, "final_model.zip")
VECNORM_PATH = os.path.join(_HERE, "vecnorm.pkl")
DEBUG_CSV = os.path.join(_HERE, "debug.csv")
DEBUG_ENABLED = True # set False to disable debug.csv logging
# ── action smoothing ─────────────────────────────────────────────────────────
# EMA on policy output to suppress the rapid oscillation (vx/vy flipping
# between -1 and +1 every step) that stalls the physical dog. 0 = no
# smoothing (raw policy), 1 = frozen. 0.3 keeps ~30% of previous action.
ACTION_SMOOTH = 0.3
# EMA state: the last smoothed (vx, vy); updated in place by the main loop.
prev_action = np.zeros(2, dtype=np.float32)
def norm_angle(a: float) -> float:
    """Wrap an angle in radians into the interval [-pi, pi].

    Uses the IEEE remainder so the result is computed in constant time and
    without accumulated rounding, however many turns ``a`` contains.  The
    previous subtract-in-a-loop form never terminated for infinite input or
    for magnitudes so large that adding/subtracting 2*pi no longer changes
    the float.

    Args:
        a: Angle in radians.

    Returns:
        The equivalent angle in [-pi, pi].  Values already inside that
        interval are returned unchanged.  Non-finite input (NaN or +/-inf)
        yields NaN, since no meaningful wrapped angle exists.
    """
    if not math.isfinite(a):
        # math.remainder raises ValueError on infinities; propagate NaN
        # instead so a bad sensor value cannot crash or hang the controller.
        return math.nan
    return math.remainder(a, 2.0 * math.pi)
def in_pen(x: float, y: float) -> bool:
    """Return True when the point (x, y) lies strictly inside the pen.

    The pen is the axis-aligned rectangle described by the module constants
    PEN_X and PEN_Y; points exactly on an edge count as outside.
    """
    x_min, x_max = PEN_X[0], PEN_X[1]
    if not (x_min < x < x_max):
        return False
    y_min, y_max = PEN_Y[0], PEN_Y[1]
    return y_min < y < y_max
def build_obs(dog_pos: np.ndarray,
              sheep_dict: dict,
              n_sheep: int,
              dog_heading: float = 0.0) -> np.ndarray:
    """
    Assemble the 18-value flock observation fed to the policy.

    The layout is meant to mirror HerdingEnv._obs() (not visible from this
    file — keep the two in sync):

        0-1    dog position / FIELD
        2-3    flock centre of mass relative to the dog
        4-9    the three active sheep farthest from the centre of mass,
               relative to the centre of mass (farthest first)
        10-11  pen centre relative to the centre of mass
        12-13  pen centre relative to the farthest sheep
        14     flock radius (distance of the farthest sheep from the centre)
        15     fraction of sheep still outside the pen
        16-17  cos / sin of the dog heading

    All relative offsets and the radius are divided by 2 * FIELD.

    Args:
        dog_pos: (x, y) of the dog.
        sheep_dict: {name: (x, y)} for ALL known sheep (penned or not).
        n_sheep: total number of sheep, used as the denominator of the
            active fraction (clamped to at least 1).
        dog_heading: dog's current world-frame heading in radians.

    Returns:
        float32 array of length 18.
    """
    span = 2 * FIELD

    # Only sheep outside the pen contribute to the flock statistics.
    roaming = np.array(
        [xy for xy in sheep_dict.values() if not in_pen(*xy)],
        dtype=np.float32,
    )
    n_active = len(roaming)

    if n_active == 0:
        # Nothing left to herd: collapse every flock feature onto the pen.
        com = PEN_CENTER.copy()
        radius = 0.0
        far1 = far2 = far3 = PEN_CENTER.copy()
    else:
        com = roaming.mean(axis=0)
        spread = np.linalg.norm(roaming - com, axis=1)
        order = np.argsort(spread)[::-1]  # farthest from the centre first
        radius = float(spread[order[0]])
        # With fewer than three active sheep the missing slots fall back to
        # the centre of mass.
        far1, far2, far3 = [
            roaming[order[k]] if k < len(order) else com
            for k in range(3)
        ]

    def offset(target, origin):
        """Scaled (dx, dy) from origin to target."""
        return ((target[0] - origin[0]) / span,
                (target[1] - origin[1]) / span)

    features = [dog_pos[0] / FIELD, dog_pos[1] / FIELD]
    features.extend(offset(com, dog_pos))
    features.extend(offset(far1, com))
    features.extend(offset(far2, com))
    features.extend(offset(far3, com))
    features.extend(offset(PEN_CENTER, com))
    features.extend(offset(PEN_CENTER, far1))
    features.append(radius / span)
    features.append(n_active / max(n_sheep, 1))
    features.append(math.cos(dog_heading))
    features.append(math.sin(dog_heading))
    return np.array(features, dtype=np.float32)
# ── Webots setup ─────────────────────────────────────────────────────────────
# Everything below runs once at import time, before the main loop starts.
robot = Robot()
# Basic simulation step, reused as the sampling period for every sensor and
# as the argument to robot.step().
timestep = int(robot.getBasicTimeStep())
# Drive motors
left_motor = robot.getDevice("left wheel motor")
right_motor = robot.getDevice("right wheel motor")
# An infinite target position switches the motors to velocity control;
# start at rest until the first policy action arrives.
left_motor.setPosition(float("inf"))
right_motor.setPosition(float("inf"))
left_motor.setVelocity(0.0)
right_motor.setVelocity(0.0)
# Wheel speed limit used by drive() to clamp its commands.
MOTOR_MAX = left_motor.getMaxVelocity()
# Sensors
gps = robot.getDevice("gps"); gps.enable(timestep)
compass = robot.getDevice("compass"); compass.enable(timestep)
receiver = robot.getDevice("receiver"); receiver.enable(timestep)
# The emitter broadcasts the dog position each step (see the main loop).
emitter = robot.getDevice("emitter")
# Cosmetic
left_ear = robot.getDevice("left ear motor")
right_ear = robot.getDevice("right ear motor")
left_ear.setPosition(float("inf")); right_ear.setPosition(float("inf"))
left_ear.setVelocity(0.0); right_ear.setVelocity(0.0)
# Phase of the ear wiggle, advanced every step by the main loop.
ear_phase = 0.0
# Number of sheep (from controllerArgs or default)
# Falls back to 3 when no argument is given or it is not an integer.
try:
    n_sheep = int(sys.argv[1])
except (IndexError, ValueError):
    n_sheep = 3
# ── Load model ───────────────────────────────────────────────────────────────
print(f"[RL dog] Loading model from {MODEL_PATH}")
print(f"[RL dog] Loading vecnorm from {VECNORM_PATH}")
# The dummy env is only a carrier for VecNormalize.load(); this script never
# steps it — observations come from the Webots sensors instead.
dummy_env = DummyVecEnv([lambda: HerdingEnv(n_sheep=n_sheep)])
vecnorm = VecNormalize.load(VECNORM_PATH, dummy_env)
# Inference only: keep the saved running statistics frozen and leave rewards
# untouched.
vecnorm.training = False
vecnorm.norm_reward = False
model = PPO.load(MODEL_PATH, device="cpu")
print(f"[RL dog] Model loaded — running with n_sheep={n_sheep}")
# ── Runtime state ─────────────────────────────────────────────────────────────
sheep_positions: dict = {} # {name: (x, y)} — updated every step from receiver
step_count = 0
# Debug CSV — written every step when DEBUG_ENABLED is True.
# debug_file stays None when logging is disabled; the main loop checks it
# before touching debug_writer.
debug_file = None
if DEBUG_ENABLED:
    import csv
    debug_file = open(DEBUG_CSV, "w", newline="")
    debug_writer = csv.writer(debug_file)
    # Header row; list-valued columns (sheep_xs, sheep_ys, raw_obs, norm_obs)
    # are stored as ';'-joined strings by the main loop.
    debug_writer.writerow([
        "step", "dog_x", "dog_y", "heading",
        "sheep_xs", "sheep_ys", "n_active", "n_penned",
        "raw_obs", "norm_obs", "vx", "vy",
    ])
    print(f"[RL dog] DEBUG logging to {DEBUG_CSV}")
def bearing() -> float:
    """Return the robot's current world-frame heading in radians.

    Derived from the first two components of the compass reading as
    atan2(values[0], values[1]).
    """
    reading = compass.getValues()
    first, second = reading[0], reading[1]
    return math.atan2(first, second)
def drive(action_vx: float, action_vy: float) -> None:
    """Turn a (vx, vy) policy action into left/right wheel velocities.

    The action magnitude scaled by DOG_SPEED gives the requested linear
    speed; its direction is the requested heading.  Forward speed is reduced
    by the cosine of the heading error (and cut to zero when the target is
    behind the robot), while a proportional term K_TURN * error steers
    towards the target.  Both wheel commands are clamped to +/-MOTOR_MAX.

    Args:
        action_vx: x component of the (smoothed) policy action.
        action_vy: y component of the (smoothed) policy action.
    """
    def clamp(wheel_speed: float) -> float:
        """Limit a wheel command to the motor's velocity range."""
        return max(-MOTOR_MAX, min(MOTOR_MAX, wheel_speed))

    requested_ms = math.sqrt(action_vx ** 2 + action_vy ** 2) * DOG_SPEED

    # Dead-band: a near-zero action means "stand still".
    if requested_ms < 0.05:
        for motor in (left_motor, right_motor):
            motor.setVelocity(0.0)
        return

    desired = math.atan2(action_vy, action_vx)
    heading_error = norm_angle(desired - bearing())

    # Linear speed along the current heading, then converted to rad/s.
    forward_ms = requested_ms * max(0.0, math.cos(heading_error))
    forward_rad = forward_ms / WHEEL_R
    steer = K_TURN * heading_error  # rad/s correction

    left_motor.setVelocity(clamp(forward_rad - steer))
    right_motor.setVelocity(clamp(forward_rad + steer))
# ── Main loop ─────────────────────────────────────────────────────────────────
# One iteration per simulation step: read sheep broadcasts, sense, run the
# policy, drive, broadcast the dog position, animate the ears and log.
# The loop is wrapped in try/finally so the debug CSV is flushed and closed
# when Webots ends the simulation (robot.step() returns -1) or an exception
# escapes; previously the file was never closed and rows written since the
# last periodic flush could be lost.
try:
    while robot.step(timestep) != -1:
        step_count += 1

        # 1. Drain receiver — update sheep position table
        while receiver.getQueueLength() > 0:
            try:
                msg = receiver.getString()
                parts = msg.split(":")
                # Expected packet format: "sheep:<name>:<x>:<y>"
                if parts[0] == "sheep" and len(parts) == 4:
                    sheep_positions[parts[1]] = (float(parts[2]), float(parts[3]))
            except Exception:
                # Deliberately best-effort: a malformed or undecodable packet
                # is dropped so it cannot stall the control loop.
                pass
            # Always advance, otherwise a bad packet would be re-read forever.
            receiver.nextPacket()

        # 2. Dog GPS
        gps_vals = gps.getValues()
        dog_pos = np.array([gps_vals[0], gps_vals[1]], dtype=np.float32)

        # 3. Build and normalise observation (heading from compass).
        # The compass reading is sampled once and reused for the log row.
        heading = bearing()
        raw_obs = build_obs(dog_pos, sheep_positions, n_sheep,
                            dog_heading=heading)
        # Add a batch axis: the 18-value observation becomes shape (1, 18).
        obs_norm = vecnorm.normalize_obs(raw_obs[np.newaxis])

        # 4. Policy inference + smoothing
        action, _ = model.predict(obs_norm, deterministic=True)
        raw_a = np.array([float(action[0][0]), float(action[0][1])], dtype=np.float32)
        if ACTION_SMOOTH > 0:
            # Exponential moving average; prev_action is updated in place so
            # the state carries over to the next step.
            smoothed = ACTION_SMOOTH * prev_action + (1.0 - ACTION_SMOOTH) * raw_a
            prev_action[:] = smoothed
            vx, vy = float(smoothed[0]), float(smoothed[1])
        else:
            vx, vy = float(raw_a[0]), float(raw_a[1])

        # 5. Drive
        drive(vx, vy)

        # 6. Broadcast dog position so sheep can compute flee forces
        emitter.send(f"dog:{dog_pos[0]:.4f}:{dog_pos[1]:.4f}")

        # 7. Ear animation — the two ears swing in opposite directions.
        ear_phase += 0.12
        ep = EAR_AMPLITUDE * math.sin(ear_phase)
        left_ear.setVelocity(EAR_RATE)
        right_ear.setVelocity(EAR_RATE)
        left_ear.setPosition(ep)
        right_ear.setPosition(-ep)

        # Periodic status
        if step_count % 100 == 0:
            n_in_pen = sum(1 for x, y in sheep_positions.values() if in_pen(x, y))
            print(f"[RL dog] step={step_count} known_sheep={len(sheep_positions)}"
                  f" penned={n_in_pen}/{n_sheep} dog=({dog_pos[0]:.2f},{dog_pos[1]:.2f})"
                  f" action=({vx:.2f}, {vy:.2f})")

        # Debug CSV row
        if debug_file is not None:
            n_active = sum(1 for x, y in sheep_positions.values() if not in_pen(x, y))
            n_in_pen = len(sheep_positions) - n_active
            debug_writer.writerow([
                step_count, f"{dog_pos[0]:.4f}", f"{dog_pos[1]:.4f}",
                f"{heading:.4f}",
                ";".join(f"{v[0]:.3f}" for v in sheep_positions.values()),
                ";".join(f"{v[1]:.3f}" for v in sheep_positions.values()),
                n_active, n_in_pen,
                ";".join(f"{x:.4f}" for x in raw_obs),
                ";".join(f"{x:.4f}" for x in obs_norm[0]),
                f"{vx:.4f}", f"{vy:.4f}",
            ])
            # Periodic flush so the CSV can be inspected while running.
            if step_count % 200 == 0:
                debug_file.flush()
finally:
    # Flush and release the debug log however the loop ended.
    if debug_file is not None:
        debug_file.close()