Shepherd Dog RL
This commit is contained in:
Binary file not shown.
@@ -0,0 +1,227 @@
|
||||
"""
|
||||
Shepherd Dog RL controller — runs a trained SB3 PPO policy inside Webots.
|
||||
|
||||
Setup
|
||||
-----
|
||||
1. Copy your trained files into this directory:
|
||||
controllers/shepherd_dog_rl/best_model.zip
|
||||
controllers/shepherd_dog_rl/vecnorm.pkl
|
||||
|
||||
2. In field.wbt, set the ShepherdDog robot's controller field to
|
||||
"shepherd_dog_rl". You can do this in the Webots GUI:
|
||||
click the robot → Controller → shepherd_dog_rl
|
||||
|
||||
3. Optional: set controllerArgs to ["5"] (number of sheep) if it differs
|
||||
from the default of 5.
|
||||
|
||||
The controller reads GPS (dog position) and Receiver (sheep broadcasts),
|
||||
builds the same 13-dim flock observation the training env used, normalises
|
||||
it with the saved VecNormalize stats, and converts the (vx, vy) policy
|
||||
output into differential wheel speeds.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import math
|
||||
import struct
|
||||
import numpy as np
|
||||
|
||||
# ── make training code importable ───────────────────────────────────────────
|
||||
_HERE = os.path.dirname(os.path.abspath(__file__))
|
||||
_TRAINING = os.path.join(_HERE, "..", "..", "training")
|
||||
sys.path.insert(0, _TRAINING)
|
||||
|
||||
from controller import Robot
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
|
||||
from herding_env import HerdingEnv
|
||||
|
||||
# ── constants (must match herding_env.py) ───────────────────────────────────
|
||||
FIELD = 15.0
|
||||
PEN_CENTER = np.array([11.5, -11.5], dtype=np.float32)
|
||||
PEN_X = (10.0, 13.0)
|
||||
PEN_Y = (-15.0, -8.0)
|
||||
DOG_SPEED = 2.5 # m/s
|
||||
WHEEL_R = 0.038 # wheel radius (metres) — from ShepherdDog.proto
|
||||
K_TURN = 4.0 # heading-error gain (rad/s per rad)
|
||||
EAR_AMPLITUDE = 0.35
|
||||
EAR_RATE = 8.0
|
||||
|
||||
# ── model paths ─────────────────────────────────────────────────────────────
|
||||
MODEL_PATH = os.path.join(_HERE, "best_model.zip")
|
||||
VECNORM_PATH = os.path.join(_HERE, "vecnorm.pkl")
|
||||
|
||||
|
||||
def norm_angle(a: float) -> float:
|
||||
while a > math.pi: a -= 2 * math.pi
|
||||
while a < -math.pi: a += 2 * math.pi
|
||||
return a
|
||||
|
||||
|
||||
def in_pen(x: float, y: float) -> bool:
|
||||
return PEN_X[0] < x < PEN_X[1] and PEN_Y[0] < y < PEN_Y[1]
|
||||
|
||||
|
||||
def build_obs(dog_pos: np.ndarray,
|
||||
sheep_dict: dict,
|
||||
n_sheep: int) -> np.ndarray:
|
||||
"""
|
||||
Build the 13-dim flock observation — identical to HerdingEnv._obs().
|
||||
|
||||
sheep_dict: {name: (x, y)} for ALL known sheep (penned or not).
|
||||
"""
|
||||
D = 2 * FIELD
|
||||
|
||||
# Split active vs penned
|
||||
active_pos = np.array(
|
||||
[v for v in sheep_dict.values() if not in_pen(*v)],
|
||||
dtype=np.float32
|
||||
)
|
||||
n_active = len(active_pos)
|
||||
|
||||
if n_active > 0:
|
||||
com = active_pos.mean(axis=0)
|
||||
d_from_com = np.linalg.norm(active_pos - com, axis=1)
|
||||
radius = float(d_from_com.max())
|
||||
mean_disp = float(d_from_com.mean())
|
||||
far = active_pos[int(np.argmax(d_from_com))]
|
||||
else:
|
||||
com = PEN_CENTER.copy()
|
||||
radius = mean_disp = 0.0
|
||||
far = PEN_CENTER.copy()
|
||||
|
||||
frac_active = n_active / max(n_sheep, 1)
|
||||
|
||||
return np.array([
|
||||
dog_pos[0] / FIELD, dog_pos[1] / FIELD,
|
||||
(com[0] - dog_pos[0]) / D, (com[1] - dog_pos[1]) / D,
|
||||
(far[0] - dog_pos[0]) / D, (far[1] - dog_pos[1]) / D,
|
||||
(PEN_CENTER[0] - com[0]) / D, (PEN_CENTER[1] - com[1]) / D,
|
||||
(PEN_CENTER[0] - far[0]) / D, (PEN_CENTER[1] - far[1]) / D,
|
||||
radius / D,
|
||||
mean_disp / D,
|
||||
frac_active,
|
||||
], dtype=np.float32)
|
||||
|
||||
|
||||
# ── Webots setup ─────────────────────────────────────────────────────────────
|
||||
robot = Robot()
|
||||
timestep = int(robot.getBasicTimeStep())
|
||||
|
||||
# Drive motors
|
||||
left_motor = robot.getDevice("left wheel motor")
|
||||
right_motor = robot.getDevice("right wheel motor")
|
||||
left_motor.setPosition(float("inf"))
|
||||
right_motor.setPosition(float("inf"))
|
||||
left_motor.setVelocity(0.0)
|
||||
right_motor.setVelocity(0.0)
|
||||
MOTOR_MAX = left_motor.getMaxVelocity()
|
||||
|
||||
# Sensors
|
||||
gps = robot.getDevice("gps"); gps.enable(timestep)
|
||||
compass = robot.getDevice("compass"); compass.enable(timestep)
|
||||
receiver = robot.getDevice("receiver"); receiver.enable(timestep)
|
||||
emitter = robot.getDevice("emitter")
|
||||
|
||||
# Cosmetic
|
||||
left_ear = robot.getDevice("left ear motor")
|
||||
right_ear = robot.getDevice("right ear motor")
|
||||
left_ear.setPosition(float("inf")); right_ear.setPosition(float("inf"))
|
||||
left_ear.setVelocity(0.0); right_ear.setVelocity(0.0)
|
||||
ear_phase = 0.0
|
||||
|
||||
# Number of sheep (from controllerArgs or default)
|
||||
try:
|
||||
n_sheep = int(sys.argv[1])
|
||||
except (IndexError, ValueError):
|
||||
n_sheep = 5
|
||||
|
||||
# ── Load model ───────────────────────────────────────────────────────────────
|
||||
print(f"[RL dog] Loading model from {MODEL_PATH}")
|
||||
print(f"[RL dog] Loading vecnorm from {VECNORM_PATH}")
|
||||
|
||||
dummy_env = DummyVecEnv([lambda: HerdingEnv(n_sheep=n_sheep)])
|
||||
vecnorm = VecNormalize.load(VECNORM_PATH, dummy_env)
|
||||
vecnorm.training = False
|
||||
vecnorm.norm_reward = False
|
||||
|
||||
model = PPO.load(MODEL_PATH)
|
||||
print(f"[RL dog] Model loaded — running with n_sheep={n_sheep}")
|
||||
|
||||
# ── Runtime state ─────────────────────────────────────────────────────────────
|
||||
sheep_positions: dict = {} # {name: (x, y)} — updated every step from receiver
|
||||
step_count = 0
|
||||
|
||||
|
||||
def bearing() -> float:
|
||||
"""Current robot heading in world frame (radians)."""
|
||||
n = compass.getValues()
|
||||
return math.atan2(n[0], n[1])
|
||||
|
||||
|
||||
def drive(action_vx: float, action_vy: float) -> None:
|
||||
"""Convert (vx, vy) policy action to differential wheel speeds."""
|
||||
speed_ms = math.sqrt(action_vx ** 2 + action_vy ** 2) * DOG_SPEED
|
||||
if speed_ms < 0.05:
|
||||
left_motor.setVelocity(0.0)
|
||||
right_motor.setVelocity(0.0)
|
||||
return
|
||||
|
||||
target_heading = math.atan2(action_vy, action_vx)
|
||||
err = norm_angle(target_heading - bearing())
|
||||
|
||||
fwd_ms = speed_ms * max(0.0, math.cos(err))
|
||||
fwd_rad = fwd_ms / WHEEL_R
|
||||
turn = K_TURN * err # rad/s correction
|
||||
|
||||
l = max(-MOTOR_MAX, min(MOTOR_MAX, fwd_rad - turn))
|
||||
r = max(-MOTOR_MAX, min(MOTOR_MAX, fwd_rad + turn))
|
||||
left_motor.setVelocity(l)
|
||||
right_motor.setVelocity(r)
|
||||
|
||||
|
||||
# ── Main loop ─────────────────────────────────────────────────────────────────
|
||||
while robot.step(timestep) != -1:
|
||||
step_count += 1
|
||||
|
||||
# 1. Drain receiver — update sheep position table
|
||||
while receiver.getQueueLength() > 0:
|
||||
try:
|
||||
msg = receiver.getString()
|
||||
parts = msg.split(":")
|
||||
if parts[0] == "sheep" and len(parts) == 4:
|
||||
sheep_positions[parts[1]] = (float(parts[2]), float(parts[3]))
|
||||
except Exception:
|
||||
pass
|
||||
receiver.nextPacket()
|
||||
|
||||
# 2. Dog GPS
|
||||
gps_vals = gps.getValues()
|
||||
dog_pos = np.array([gps_vals[0], gps_vals[1]], dtype=np.float32)
|
||||
|
||||
# 3. Build and normalise observation
|
||||
raw_obs = build_obs(dog_pos, sheep_positions, n_sheep)
|
||||
obs_norm = vecnorm.normalize_obs(raw_obs[np.newaxis]) # (1, 13)
|
||||
|
||||
# 4. Policy inference
|
||||
action, _ = model.predict(obs_norm, deterministic=True)
|
||||
vx, vy = float(action[0][0]), float(action[0][1])
|
||||
|
||||
# 5. Drive
|
||||
drive(vx, vy)
|
||||
|
||||
# 6. Broadcast dog position so sheep can compute flee forces
|
||||
emitter.send(f"dog:{dog_pos[0]:.4f}:{dog_pos[1]:.4f}")
|
||||
|
||||
# 7. Ear animation
|
||||
ear_phase += 0.12
|
||||
ep = EAR_AMPLITUDE * math.sin(ear_phase)
|
||||
left_ear.setVelocity(EAR_RATE); right_ear.setVelocity(EAR_RATE)
|
||||
left_ear.setPosition( ep); right_ear.setPosition(-ep)
|
||||
|
||||
# Periodic status
|
||||
if step_count % 100 == 0:
|
||||
n_in_pen = sum(1 for x, y in sheep_positions.values() if in_pen(x, y))
|
||||
print(f"[RL dog] step={step_count} known_sheep={len(sheep_positions)}"
|
||||
f" penned={n_in_pen}/{n_sheep}"
|
||||
f" action=({vx:.2f}, {vy:.2f})")
|
||||
Binary file not shown.
Reference in New Issue
Block a user