RL training ready to test
This commit is contained in:
@@ -0,0 +1,143 @@
|
|||||||
|
"""
|
||||||
|
Evaluation script for a trained herding policy.
|
||||||
|
|
||||||
|
Runs N episodes and reports the three project metrics:
|
||||||
|
1. Success rate — fraction of episodes where all sheep are penned
|
||||||
|
2. Time-to-pen — mean steps across successful episodes (per sheep)
|
||||||
|
3. Flock dispersion — mean pairwise distance among active sheep, averaged
|
||||||
|
over all timesteps (lower = tighter herding)
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
python evaluate.py --model runs/ppo_herding/best_model/best_model.zip \
|
||||||
|
--vecnorm runs/ppo_herding/vecnorm.pkl \
|
||||||
|
--n-sheep 5 --episodes 100
|
||||||
|
|
||||||
|
Add --render to watch the first episode in a matplotlib window.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from stable_baselines3 import PPO
|
||||||
|
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
|
||||||
|
|
||||||
|
from herding_env import HerdingEnv
|
||||||
|
|
||||||
|
|
||||||
|
def make_single_env(n_sheep: int, max_steps: int, render_mode: str = None):
|
||||||
|
def _init():
|
||||||
|
return HerdingEnv(n_sheep=n_sheep, max_steps=max_steps,
|
||||||
|
render_mode=render_mode)
|
||||||
|
return _init
|
||||||
|
|
||||||
|
|
||||||
|
def pairwise_mean(positions: np.ndarray, n_active: int) -> float:
|
||||||
|
"""Mean pairwise distance among the first n_active sheep."""
|
||||||
|
if n_active < 2:
|
||||||
|
return 0.0
|
||||||
|
pts = positions[:n_active]
|
||||||
|
dists = []
|
||||||
|
for i in range(n_active):
|
||||||
|
for j in range(i + 1, n_active):
|
||||||
|
dists.append(float(np.linalg.norm(pts[i] - pts[j])))
|
||||||
|
return float(np.mean(dists))
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
p = argparse.ArgumentParser()
|
||||||
|
p.add_argument("--model", required=True,
|
||||||
|
help="Path to saved model .zip")
|
||||||
|
p.add_argument("--vecnorm", default=None,
|
||||||
|
help="Path to VecNormalize stats .pkl (optional)")
|
||||||
|
p.add_argument("--n-sheep", type=int, default=1)
|
||||||
|
p.add_argument("--episodes", type=int, default=50)
|
||||||
|
p.add_argument("--max-steps", type=int, default=2000)
|
||||||
|
p.add_argument("--render", action="store_true",
|
||||||
|
help="Render first episode in matplotlib")
|
||||||
|
p.add_argument("--seed", type=int, default=42)
|
||||||
|
return p.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
render_mode = "human" if args.render else None
|
||||||
|
raw_env = DummyVecEnv([make_single_env(args.n_sheep, args.max_steps,
|
||||||
|
render_mode)])
|
||||||
|
if args.vecnorm:
|
||||||
|
env = VecNormalize.load(args.vecnorm, raw_env)
|
||||||
|
env.training = False
|
||||||
|
env.norm_reward = False
|
||||||
|
else:
|
||||||
|
env = raw_env
|
||||||
|
|
||||||
|
model = PPO.load(args.model, env=env)
|
||||||
|
|
||||||
|
successes = []
|
||||||
|
steps_to_pen = [] # steps for successful episodes
|
||||||
|
dispersions = [] # per-episode mean flock dispersion
|
||||||
|
|
||||||
|
for ep in range(args.episodes):
|
||||||
|
obs = env.reset()
|
||||||
|
done = False
|
||||||
|
ep_steps = 0
|
||||||
|
ep_dispersion = []
|
||||||
|
first_ep = ep == 0
|
||||||
|
|
||||||
|
while not done:
|
||||||
|
action, _ = model.predict(obs, deterministic=True)
|
||||||
|
obs, _, dones, infos = env.step(action)
|
||||||
|
done = dones[0]
|
||||||
|
ep_steps += 1
|
||||||
|
|
||||||
|
# Access the underlying HerdingEnv for dispersion calculation
|
||||||
|
inner = env.envs[0] if hasattr(env, "envs") else env.venv.envs[0]
|
||||||
|
if not inner.penned[:inner.n_sheep].all():
|
||||||
|
ep_dispersion.append(
|
||||||
|
pairwise_mean(inner.sheep_pos, inner.n_sheep)
|
||||||
|
)
|
||||||
|
|
||||||
|
if first_ep and render_mode == "human":
|
||||||
|
pass # render() is called inside step()
|
||||||
|
|
||||||
|
info = infos[0]
|
||||||
|
n_penned = info.get("n_penned", 0)
|
||||||
|
n_sheep = info.get("n_sheep", args.n_sheep)
|
||||||
|
success = n_penned == n_sheep
|
||||||
|
|
||||||
|
successes.append(int(success))
|
||||||
|
if success:
|
||||||
|
steps_to_pen.append(ep_steps / n_sheep)
|
||||||
|
if ep_dispersion:
|
||||||
|
dispersions.append(float(np.mean(ep_dispersion)))
|
||||||
|
|
||||||
|
if (ep + 1) % 10 == 0:
|
||||||
|
print(f" Episode {ep + 1:>4}/{args.episodes} "
|
||||||
|
f"success={int(success)} steps={ep_steps}")
|
||||||
|
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------
|
||||||
|
# Report
|
||||||
|
# -----------------------------------------------------------------------
|
||||||
|
success_rate = float(np.mean(successes))
|
||||||
|
mean_ttp = float(np.mean(steps_to_pen)) if steps_to_pen else float("nan")
|
||||||
|
mean_disp = float(np.mean(dispersions)) if dispersions else float("nan")
|
||||||
|
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print(f" Model : {args.model}")
|
||||||
|
print(f" Sheep : {args.n_sheep}")
|
||||||
|
print(f" Episodes : {args.episodes}")
|
||||||
|
print("-" * 50)
|
||||||
|
print(f" Success rate : {success_rate * 100:.1f}%"
|
||||||
|
f" ({sum(successes)}/{args.episodes})")
|
||||||
|
print(f" Time-to-pen : {mean_ttp:.1f} steps/sheep"
|
||||||
|
f" (successful episodes only)")
|
||||||
|
print(f" Flock dispersion: {mean_disp:.2f} m"
|
||||||
|
f" (mean pairwise distance while active)")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,319 @@
|
|||||||
|
"""
|
||||||
|
2D herding environment for PPO training (Gymnasium-compatible).
|
||||||
|
|
||||||
|
The dog agent (action: 2D velocity vector) must herd n_sheep into the
|
||||||
|
quarantine pen. Sheep dynamics mirror the Webots controller exactly:
|
||||||
|
flee (quadratic ramp), separation (inverse-distance), cohesion, wall
|
||||||
|
avoidance, and wander.
|
||||||
|
|
||||||
|
Coordinate system matches the Webots world file:
|
||||||
|
field : x ∈ [-15, 15], y ∈ [-15, 15]
|
||||||
|
pen : x ∈ [10, 13], y ∈ [-15, -8] (SE corner, open north)
|
||||||
|
|
||||||
|
Observation is always sized for MAX_SHEEP (currently 5) regardless of
|
||||||
|
how many sheep are active. Inactive slots are pre-penned at the pen
|
||||||
|
centre with flag=1. This keeps the model input dimension fixed across
|
||||||
|
curriculum stages so VecNormalize statistics are preserved throughout.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium import spaces
|
||||||
|
|
||||||
|
|
||||||
|
class HerdingEnv(gym.Env):
|
||||||
|
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30}
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------
|
||||||
|
# World constants — must match Webots world file
|
||||||
|
# -----------------------------------------------------------------------
|
||||||
|
MAX_SHEEP = 5
|
||||||
|
FIELD = 15.0 # half-size; positions ∈ [-FIELD, FIELD]
|
||||||
|
PEN_X = (10.0, 13.0) # quarantine pen x bounds
|
||||||
|
PEN_Y = (-15.0, -8.0) # quarantine pen y bounds
|
||||||
|
PEN_CENTER = np.array([11.5, -11.5], dtype=np.float32)
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------
|
||||||
|
# Dynamics — calibrated to match Webots robot specs
|
||||||
|
# wheel radius 0.031 m; sheep FLEE_SPEED 20 rad/s → 0.62 m/s
|
||||||
|
# wheel radius 0.038 m; dog maxVelocity 70 rad/s → 2.66 m/s
|
||||||
|
# -----------------------------------------------------------------------
|
||||||
|
DOG_SPEED = 2.5 # m/s
|
||||||
|
SHEEP_FLEE_V = 0.65 # m/s
|
||||||
|
SHEEP_WANDER_V = 0.20 # m/s
|
||||||
|
DT = 0.1 # seconds per step
|
||||||
|
|
||||||
|
# Boid parameters — identical to sheep.py
|
||||||
|
FLEE_DIST = 7.0
|
||||||
|
SEPARATION_DIST = 2.5
|
||||||
|
COHESION_DIST = 8.0
|
||||||
|
WALL_MARGIN = 3.5
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------
|
||||||
|
# Reward weights
|
||||||
|
# -----------------------------------------------------------------------
|
||||||
|
W_APPROACH = 0.3 # dense: dog distance to nearest active sheep
|
||||||
|
W_SHAPING = 0.5 # dense: mean sheep distance to pen (was 0.01)
|
||||||
|
W_PEN_BONUS = 5.0 # sparse: per sheep successfully penned
|
||||||
|
W_COMPLETE = 20.0 # bonus when ALL active sheep are penned
|
||||||
|
W_STEP_COST = 0.002 # penalty per step (encourages efficiency)
|
||||||
|
|
||||||
|
def __init__(self, n_sheep: int = 1, max_steps: int = 2000,
|
||||||
|
render_mode: str = None):
|
||||||
|
super().__init__()
|
||||||
|
assert 1 <= n_sheep <= self.MAX_SHEEP
|
||||||
|
self.n_sheep = n_sheep
|
||||||
|
self.max_steps = max_steps
|
||||||
|
self.render_mode = render_mode
|
||||||
|
|
||||||
|
# Observation: dog(x,y) + MAX_SHEEP×sheep(x,y) + MAX_SHEEP×penned
|
||||||
|
# Fixed size across all curriculum stages.
|
||||||
|
obs_dim = 2 + 2 * self.MAX_SHEEP + self.MAX_SHEEP
|
||||||
|
self.observation_space = spaces.Box(
|
||||||
|
low=-1.0, high=1.0, shape=(obs_dim,), dtype=np.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
# Action: desired velocity (vx, vy) ∈ [-1, 1]², scaled by DOG_SPEED
|
||||||
|
self.action_space = spaces.Box(
|
||||||
|
low=-1.0, high=1.0, shape=(2,), dtype=np.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
# Runtime state (populated by reset)
|
||||||
|
self._step_count = 0
|
||||||
|
self._prev_penned = 0
|
||||||
|
self.dog_pos = np.zeros(2, dtype=np.float32)
|
||||||
|
self.sheep_pos = np.zeros((self.MAX_SHEEP, 2), dtype=np.float32)
|
||||||
|
self.penned = np.ones(self.MAX_SHEEP, dtype=bool)
|
||||||
|
self.wander_ang = np.zeros(self.MAX_SHEEP, dtype=np.float32)
|
||||||
|
|
||||||
|
self._fig = None # lazy matplotlib figure
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Curriculum interface
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def set_n_sheep(self, n: int):
|
||||||
|
"""Advance curriculum difficulty; takes effect on next reset()."""
|
||||||
|
assert 1 <= n <= self.MAX_SHEEP
|
||||||
|
self.n_sheep = n
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Gymnasium API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def reset(self, seed=None, options=None):
|
||||||
|
super().reset(seed=seed)
|
||||||
|
self._step_count = 0
|
||||||
|
self._prev_penned = 0
|
||||||
|
|
||||||
|
# Dog: random start in the open field (not near the pen)
|
||||||
|
self.dog_pos = self.np_random.uniform(-8.0, 5.0, size=(2,)).astype(np.float32)
|
||||||
|
|
||||||
|
# Active sheep (0 .. n_sheep-1): random non-pen positions
|
||||||
|
self.sheep_pos[:] = self.PEN_CENTER # default all to pen centre
|
||||||
|
self.penned[:] = True
|
||||||
|
|
||||||
|
placed = 0
|
||||||
|
while placed < self.n_sheep:
|
||||||
|
p = self.np_random.uniform(-12.0, 12.0, size=(2,)).astype(np.float32)
|
||||||
|
if not self._in_pen(p):
|
||||||
|
self.sheep_pos[placed] = p
|
||||||
|
self.penned[placed] = False
|
||||||
|
placed += 1
|
||||||
|
|
||||||
|
# Inactive slots (n_sheep .. MAX_SHEEP-1): already at pen centre, penned=True
|
||||||
|
|
||||||
|
self.wander_ang = self.np_random.uniform(
|
||||||
|
-np.pi, np.pi, size=(self.MAX_SHEEP,)
|
||||||
|
).astype(np.float32)
|
||||||
|
|
||||||
|
return self._obs(), {}
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
self._step_count += 1
|
||||||
|
|
||||||
|
# Move dog — clip each axis independently so the agent can idle
|
||||||
|
act = np.clip(np.asarray(action, dtype=np.float32), -1.0, 1.0)
|
||||||
|
self.dog_pos = np.clip(
|
||||||
|
self.dog_pos + act * self.DOG_SPEED * self.DT,
|
||||||
|
-self.FIELD, self.FIELD
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step sheep dynamics
|
||||||
|
for i in range(self.n_sheep):
|
||||||
|
if self.penned[i]:
|
||||||
|
continue
|
||||||
|
self.sheep_pos[i] = self._step_sheep(i)
|
||||||
|
if self._in_pen(self.sheep_pos[i]):
|
||||||
|
self.penned[i] = True
|
||||||
|
|
||||||
|
n_penned = int(self.penned[:self.n_sheep].sum())
|
||||||
|
newly_penned = n_penned - self._prev_penned
|
||||||
|
self._prev_penned = n_penned
|
||||||
|
|
||||||
|
reward = self._reward(n_penned, newly_penned)
|
||||||
|
terminated = n_penned == self.n_sheep
|
||||||
|
truncated = self._step_count >= self.max_steps
|
||||||
|
info = {"n_penned": n_penned, "n_sheep": self.n_sheep}
|
||||||
|
|
||||||
|
if self.render_mode == "human":
|
||||||
|
self.render()
|
||||||
|
|
||||||
|
return self._obs(), float(reward), terminated, truncated, info
|
||||||
|
|
||||||
|
def render(self):
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import matplotlib.patches as mpatches
|
||||||
|
|
||||||
|
if self._fig is None:
|
||||||
|
plt.ion()
|
||||||
|
self._fig, self._ax = plt.subplots(figsize=(6, 6))
|
||||||
|
|
||||||
|
ax = self._ax
|
||||||
|
ax.clear()
|
||||||
|
ax.set_xlim(-16, 16)
|
||||||
|
ax.set_ylim(-16, 16)
|
||||||
|
ax.set_aspect("equal")
|
||||||
|
ax.set_facecolor("#dcedc8")
|
||||||
|
|
||||||
|
# Field boundary
|
||||||
|
ax.add_patch(mpatches.Rectangle(
|
||||||
|
(-15, -15), 30, 30, fill=False, edgecolor="#795548", linewidth=2
|
||||||
|
))
|
||||||
|
# Pen
|
||||||
|
pw = self.PEN_X[1] - self.PEN_X[0]
|
||||||
|
ph = self.PEN_Y[1] - self.PEN_Y[0]
|
||||||
|
ax.add_patch(mpatches.Rectangle(
|
||||||
|
(self.PEN_X[0], self.PEN_Y[0]), pw, ph,
|
||||||
|
facecolor="#ffe082", edgecolor="#795548", linewidth=2
|
||||||
|
))
|
||||||
|
ax.text(11.5, -11.5, "pen", ha="center", va="center",
|
||||||
|
fontsize=8, color="#795548")
|
||||||
|
|
||||||
|
# Sheep
|
||||||
|
for i in range(self.MAX_SHEEP):
|
||||||
|
if i >= self.n_sheep:
|
||||||
|
continue # inactive slot — not shown
|
||||||
|
color = "deeppink" if self.penned[i] else "white"
|
||||||
|
ax.plot(*self.sheep_pos[i], "o", color=color, markersize=11,
|
||||||
|
markeredgecolor="#555", markeredgewidth=1.5)
|
||||||
|
|
||||||
|
# Dog
|
||||||
|
ax.plot(*self.dog_pos, "s", color="#4e342e", markersize=13,
|
||||||
|
markeredgecolor="black", markeredgewidth=1.5)
|
||||||
|
|
||||||
|
ax.set_title(
|
||||||
|
f"step {self._step_count} | "
|
||||||
|
f"penned {int(self.penned[:self.n_sheep].sum())}/{self.n_sheep}",
|
||||||
|
fontsize=11
|
||||||
|
)
|
||||||
|
self._fig.canvas.draw()
|
||||||
|
self._fig.canvas.flush_events()
|
||||||
|
plt.pause(0.001)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if self._fig is not None:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
plt.close(self._fig)
|
||||||
|
self._fig = None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Internals
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _in_pen(self, pos: np.ndarray) -> bool:
|
||||||
|
return (self.PEN_X[0] < pos[0] < self.PEN_X[1] and
|
||||||
|
self.PEN_Y[0] < pos[1] < self.PEN_Y[1])
|
||||||
|
|
||||||
|
def _obs(self) -> np.ndarray:
|
||||||
|
scale = 1.0 / self.FIELD
|
||||||
|
return np.concatenate([
|
||||||
|
self.dog_pos * scale, # 2
|
||||||
|
(self.sheep_pos * scale).flatten(), # 2 * MAX_SHEEP
|
||||||
|
self.penned.astype(np.float32), # MAX_SHEEP
|
||||||
|
]).astype(np.float32)
|
||||||
|
|
||||||
|
def _reward(self, n_penned: int, newly_penned: int) -> float:
|
||||||
|
active_mask = ~self.penned[:self.n_sheep]
|
||||||
|
if active_mask.any():
|
||||||
|
active_pos = self.sheep_pos[:self.n_sheep][active_mask]
|
||||||
|
|
||||||
|
# Sheep-to-pen shaping: encourages moving sheep toward pen
|
||||||
|
dists_pen = np.linalg.norm(active_pos - self.PEN_CENTER, axis=1)
|
||||||
|
shaping = -(dists_pen.mean() / (2 * self.FIELD)) # ∈ [-1, 0]
|
||||||
|
|
||||||
|
# Dog-to-nearest-sheep approach: incentivises the dog to stay
|
||||||
|
# within flee range (FLEE_DIST=7m) rather than wandering away
|
||||||
|
dists_dog = np.linalg.norm(active_pos - self.dog_pos, axis=1)
|
||||||
|
approach = -(dists_dog.min() / (2 * self.FIELD)) # ∈ [-1, 0]
|
||||||
|
else:
|
||||||
|
shaping = approach = 0.0
|
||||||
|
|
||||||
|
reward = shaping * self.W_SHAPING
|
||||||
|
reward += approach * self.W_APPROACH
|
||||||
|
reward += newly_penned * self.W_PEN_BONUS
|
||||||
|
reward -= self.W_STEP_COST
|
||||||
|
if n_penned == self.n_sheep:
|
||||||
|
reward += self.W_COMPLETE
|
||||||
|
return reward
|
||||||
|
|
||||||
|
def _step_sheep(self, i: int) -> np.ndarray:
|
||||||
|
"""Apply one timestep of boid dynamics to sheep i."""
|
||||||
|
pos = self.sheep_pos[i].copy()
|
||||||
|
fx, fy = 0.0, 0.0
|
||||||
|
fleeing = False
|
||||||
|
|
||||||
|
# Flee from dog — quadratic ramp (mirrors sheep.py)
|
||||||
|
diff = self.dog_pos - pos
|
||||||
|
dist = float(np.linalg.norm(diff))
|
||||||
|
if 0.01 < dist < self.FLEE_DIST:
|
||||||
|
t = 1.0 - dist / self.FLEE_DIST
|
||||||
|
s = t * t * 5.0
|
||||||
|
fx -= (diff[0] / dist) * s
|
||||||
|
fy -= (diff[1] / dist) * s
|
||||||
|
fleeing = True
|
||||||
|
|
||||||
|
# Separation (inverse-distance) + Cohesion
|
||||||
|
cx, cy, cn = 0.0, 0.0, 0
|
||||||
|
for j in range(self.n_sheep):
|
||||||
|
if j == i or self.penned[j]:
|
||||||
|
continue
|
||||||
|
dv = self.sheep_pos[j] - pos
|
||||||
|
dj = float(np.linalg.norm(dv))
|
||||||
|
if 0.3 < dj < self.COHESION_DIST:
|
||||||
|
cx += self.sheep_pos[j][0]
|
||||||
|
cy += self.sheep_pos[j][1]
|
||||||
|
cn += 1
|
||||||
|
if 0.05 < dj < self.SEPARATION_DIST:
|
||||||
|
push = (self.SEPARATION_DIST - dj) / dj
|
||||||
|
fx -= (dv[0] / dj) * push * 2.5
|
||||||
|
fy -= (dv[1] / dj) * push * 2.5
|
||||||
|
if cn > 0:
|
||||||
|
w = 0.08 if fleeing else 0.15
|
||||||
|
fx += (cx / cn - pos[0]) * w
|
||||||
|
fy += (cy / cn - pos[1]) * w
|
||||||
|
|
||||||
|
# Wall avoidance
|
||||||
|
m, F = self.WALL_MARGIN, self.FIELD
|
||||||
|
if pos[0] < -F + m: fx += ((-F + m - pos[0]) / m) * 6.0
|
||||||
|
if pos[0] > F - m: fx -= ((pos[0] - (F - m)) / m) * 6.0
|
||||||
|
if pos[1] < -F + m: fy += ((-F + m - pos[1]) / m) * 6.0
|
||||||
|
if pos[1] > F - m: fy -= ((pos[1] - (F - m)) / m) * 6.0
|
||||||
|
|
||||||
|
# Wander — suppressed while fleeing
|
||||||
|
if not fleeing:
|
||||||
|
if self.np_random.random() < 0.02:
|
||||||
|
self.wander_ang[i] += float(self.np_random.uniform(-0.6, 0.6))
|
||||||
|
fx += float(np.cos(self.wander_ang[i])) * 0.5
|
||||||
|
fy += float(np.sin(self.wander_ang[i])) * 0.5
|
||||||
|
|
||||||
|
# Integrate
|
||||||
|
force = np.array([fx, fy])
|
||||||
|
mag = float(np.linalg.norm(force))
|
||||||
|
if mag > 0.01:
|
||||||
|
top_speed = self.SHEEP_FLEE_V if fleeing else self.SHEEP_WANDER_V
|
||||||
|
speed = min(top_speed, mag * 0.3)
|
||||||
|
pos = np.clip(pos + (force / mag) * speed * self.DT,
|
||||||
|
-self.FIELD, self.FIELD)
|
||||||
|
|
||||||
|
return pos.astype(np.float32)
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
gymnasium>=0.29
|
||||||
|
stable-baselines3>=2.3
|
||||||
|
torch>=2.2
|
||||||
|
numpy>=1.26
|
||||||
|
matplotlib>=3.8
|
||||||
|
tensorboard>=2.16
|
||||||
@@ -0,0 +1,211 @@
|
|||||||
|
"""
|
||||||
|
PPO training script for the herding task.
|
||||||
|
|
||||||
|
Usage examples
|
||||||
|
--------------
|
||||||
|
# Start fresh with curriculum (1 → 5 sheep):
|
||||||
|
python train.py --curriculum
|
||||||
|
|
||||||
|
# Resume from checkpoint, skip directly to 3 sheep:
|
||||||
|
python train.py --resume runs/ppo_herding/ckpt_200000_steps.zip --n-sheep 3
|
||||||
|
|
||||||
|
# Quick smoke-test (no curriculum, single env):
|
||||||
|
python train.py --n-envs 1 --total-steps 50000
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from stable_baselines3 import PPO
|
||||||
|
from stable_baselines3.common.callbacks import (
|
||||||
|
BaseCallback,
|
||||||
|
CallbackList,
|
||||||
|
CheckpointCallback,
|
||||||
|
EvalCallback,
|
||||||
|
)
|
||||||
|
from stable_baselines3.common.vec_env import SubprocVecEnv, VecNormalize
|
||||||
|
|
||||||
|
from herding_env import HerdingEnv
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Curriculum callback
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class CurriculumCallback(BaseCallback):
|
||||||
|
"""
|
||||||
|
Advances the curriculum (number of active sheep) when the rolling mean
|
||||||
|
episode success rate exceeds a threshold.
|
||||||
|
|
||||||
|
Success = episode terminated (all sheep penned) rather than truncated.
|
||||||
|
"""
|
||||||
|
|
||||||
|
THRESHOLD = 0.75 # success rate to graduate
|
||||||
|
WINDOW = 100 # episodes to average over
|
||||||
|
MIN_EPISODES = 50 # don't graduate before seeing this many episodes
|
||||||
|
|
||||||
|
def __init__(self, start_sheep: int, max_sheep: int, verbose: int = 1):
|
||||||
|
super().__init__(verbose)
|
||||||
|
self.max_sheep = max_sheep
|
||||||
|
self._successes = []
|
||||||
|
self._cur_sheep = start_sheep
|
||||||
|
|
||||||
|
def _on_step(self) -> bool:
|
||||||
|
for info, done in zip(self.locals["infos"], self.locals["dones"]):
|
||||||
|
if done:
|
||||||
|
truncated = info.get("TimeLimit.truncated", False)
|
||||||
|
self._successes.append(0 if truncated else 1)
|
||||||
|
if len(self._successes) > self.WINDOW:
|
||||||
|
self._successes.pop(0)
|
||||||
|
|
||||||
|
if (self._cur_sheep < self.max_sheep
|
||||||
|
and len(self._successes) >= self.MIN_EPISODES
|
||||||
|
and np.mean(self._successes) >= self.THRESHOLD):
|
||||||
|
self._cur_sheep += 1
|
||||||
|
self.training_env.env_method("set_n_sheep", self._cur_sheep)
|
||||||
|
self._successes.clear()
|
||||||
|
if self.verbose:
|
||||||
|
print(f"\n[Curriculum] Advanced to {self._cur_sheep} sheep "
|
||||||
|
f"at step {self.num_timesteps}\n")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Environment factory
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def make_env(n_sheep: int, seed: int, max_steps: int):
|
||||||
|
def _init():
|
||||||
|
env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps)
|
||||||
|
env.reset(seed=seed)
|
||||||
|
return env
|
||||||
|
return _init
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Main
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
p = argparse.ArgumentParser()
|
||||||
|
p.add_argument("--n-sheep", type=int, default=1,
|
||||||
|
help="Starting number of sheep (or fixed count if no curriculum)")
|
||||||
|
p.add_argument("--max-sheep", type=int, default=5,
|
||||||
|
help="Maximum sheep for curriculum (ignored without --curriculum)")
|
||||||
|
p.add_argument("--n-envs", type=int, default=8,
|
||||||
|
help="Number of parallel environments")
|
||||||
|
p.add_argument("--total-steps", type=int, default=5_000_000,
|
||||||
|
help="Total environment steps to train for")
|
||||||
|
p.add_argument("--max-steps", type=int, default=2000,
|
||||||
|
help="Episode step limit inside each env")
|
||||||
|
p.add_argument("--curriculum", action="store_true",
|
||||||
|
help="Enable automatic curriculum advancement")
|
||||||
|
p.add_argument("--resume", type=str, default=None,
|
||||||
|
help="Path to a .zip checkpoint to resume training from")
|
||||||
|
p.add_argument("--run-dir", type=str, default="runs/ppo_herding",
|
||||||
|
help="Output directory for checkpoints and logs")
|
||||||
|
p.add_argument("--save-freq", type=int, default=100_000,
|
||||||
|
help="Checkpoint every N steps (per-env, not total)")
|
||||||
|
p.add_argument("--eval-freq", type=int, default=50_000,
|
||||||
|
help="Evaluate every N steps")
|
||||||
|
p.add_argument("--eval-eps", type=int, default=20,
|
||||||
|
help="Episodes per evaluation run")
|
||||||
|
return p.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
os.makedirs(args.run_dir, exist_ok=True)
|
||||||
|
ckpt_dir = os.path.join(args.run_dir, "checkpoints")
|
||||||
|
best_dir = os.path.join(args.run_dir, "best_model")
|
||||||
|
norm_path = os.path.join(args.run_dir, "vecnorm.pkl")
|
||||||
|
os.makedirs(ckpt_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Training envs
|
||||||
|
train_env = SubprocVecEnv([
|
||||||
|
make_env(args.n_sheep, seed=i, max_steps=args.max_steps)
|
||||||
|
for i in range(args.n_envs)
|
||||||
|
])
|
||||||
|
if args.resume and os.path.exists(norm_path):
|
||||||
|
train_env = VecNormalize.load(norm_path, train_env)
|
||||||
|
train_env.training = True
|
||||||
|
train_env.norm_reward = True
|
||||||
|
else:
|
||||||
|
train_env = VecNormalize(train_env, norm_obs=True, norm_reward=True,
|
||||||
|
clip_obs=10.0)
|
||||||
|
|
||||||
|
# Eval env (no reward normalisation, deterministic)
|
||||||
|
eval_env = SubprocVecEnv([
|
||||||
|
make_env(args.n_sheep, seed=1000 + i, max_steps=args.max_steps)
|
||||||
|
for i in range(2)
|
||||||
|
])
|
||||||
|
eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False,
|
||||||
|
clip_obs=10.0, training=False)
|
||||||
|
|
||||||
|
# Callbacks
|
||||||
|
checkpoint_cb = CheckpointCallback(
|
||||||
|
save_freq=max(args.save_freq // args.n_envs, 1),
|
||||||
|
save_path=ckpt_dir,
|
||||||
|
name_prefix="ckpt",
|
||||||
|
save_vecnormalize=True,
|
||||||
|
)
|
||||||
|
eval_cb = EvalCallback(
|
||||||
|
eval_env,
|
||||||
|
best_model_save_path=best_dir,
|
||||||
|
log_path=args.run_dir,
|
||||||
|
eval_freq=max(args.eval_freq // args.n_envs, 1),
|
||||||
|
n_eval_episodes=args.eval_eps,
|
||||||
|
deterministic=True,
|
||||||
|
verbose=1,
|
||||||
|
)
|
||||||
|
callbacks = [checkpoint_cb, eval_cb]
|
||||||
|
if args.curriculum:
|
||||||
|
callbacks.append(CurriculumCallback(start_sheep=args.n_sheep,
|
||||||
|
max_sheep=args.max_sheep))
|
||||||
|
callback_list = CallbackList(callbacks)
|
||||||
|
|
||||||
|
# Model
|
||||||
|
ppo_kwargs = dict(
|
||||||
|
policy = "MlpPolicy",
|
||||||
|
env = train_env,
|
||||||
|
learning_rate = 3e-4,
|
||||||
|
n_steps = 2048,
|
||||||
|
batch_size = 256,
|
||||||
|
n_epochs = 10,
|
||||||
|
gamma = 0.995,
|
||||||
|
gae_lambda = 0.95,
|
||||||
|
clip_range = 0.2,
|
||||||
|
ent_coef = 0.005,
|
||||||
|
vf_coef = 0.5,
|
||||||
|
max_grad_norm = 0.5,
|
||||||
|
policy_kwargs = dict(net_arch=[256, 256]),
|
||||||
|
tensorboard_log = args.run_dir,
|
||||||
|
verbose = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.resume:
|
||||||
|
print(f"Resuming from {args.resume}")
|
||||||
|
model = PPO.load(args.resume, env=train_env, **{
|
||||||
|
k: v for k, v in ppo_kwargs.items()
|
||||||
|
if k not in ("policy", "env")
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
model = PPO(**ppo_kwargs)
|
||||||
|
|
||||||
|
model.learn(
|
||||||
|
total_timesteps=args.total_steps,
|
||||||
|
callback=callback_list,
|
||||||
|
reset_num_timesteps=args.resume is None,
|
||||||
|
tb_log_name="ppo",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save final artefacts
|
||||||
|
model.save(os.path.join(args.run_dir, "final_model"))
|
||||||
|
train_env.save(norm_path)
|
||||||
|
print(f"\nTraining complete. Artefacts saved to {args.run_dir}/")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
+3
-2
@@ -364,6 +364,7 @@ Solid {
|
|||||||
# ==================== SCARECROW (east side, outside fence) ====================
|
# ==================== SCARECROW (east side, outside fence) ====================
|
||||||
Solid {
|
Solid {
|
||||||
translation 20 -10 0
|
translation 20 -10 0
|
||||||
|
rotation 0 0 1 2.61799
|
||||||
children [
|
children [
|
||||||
Transform { translation 0 0 1.22 children [ Shape { appearance USE TRUNK geometry Cylinder { height 2.44 radius 0.045 subdivision 8 } } ] }
|
Transform { translation 0 0 1.22 children [ Shape { appearance USE TRUNK geometry Cylinder { height 2.44 radius 0.045 subdivision 8 } } ] }
|
||||||
Transform { translation 0 0 2.02 rotation 1 0 0 1.5708 children [ Shape { appearance USE TRUNK geometry Cylinder { height 1.60 radius 0.032 subdivision 8 } } ] }
|
Transform { translation 0 0 2.02 rotation 1 0 0 1.5708 children [ Shape { appearance USE TRUNK geometry Cylinder { height 1.60 radius 0.032 subdivision 8 } } ] }
|
||||||
@@ -391,12 +392,12 @@ Solid {
|
|||||||
|
|
||||||
# ==================== HAY BALES (near barn) ====================
|
# ==================== HAY BALES (near barn) ====================
|
||||||
Solid { translation 25.75 13.76 0.62 children [ Transform { rotation 1 0 0 1.5708 children [ Shape { appearance USE HAY geometry Cylinder { height 1.30 radius 0.62 subdivision 14 } } ] } ] boundingObject Box { size 1.30 1.24 1.24 } }
|
Solid { translation 25.75 13.76 0.62 children [ Transform { rotation 1 0 0 1.5708 children [ Shape { appearance USE HAY geometry Cylinder { height 1.30 radius 0.62 subdivision 14 } } ] } ] boundingObject Box { size 1.30 1.24 1.24 } }
|
||||||
Solid { translation 24.34 12.32 0.62 children [ Transform { rotation 1 0 0 1.5708 children [ Shape { appearance USE HAY geometry Cylinder { height 1.30 radius 0.62 subdivision 14 } } ] } ] boundingObject Box { size 1.30 1.24 1.24 } }
|
Solid { translation 24.34 12.32 0.62 rotation -1 0 0 1.5708 children [ Transform { rotation 1 0 0 1.5708 children [ Shape { appearance USE HAY geometry Cylinder { height 1.30 radius 0.62 subdivision 14 } } ] } ] boundingObject Box { size 1.30 1.24 1.24 } }
|
||||||
Solid { translation 24.28 13.79 0.62 children [ Transform { rotation 1 0 0 1.5708 children [ Shape { appearance USE HAY geometry Cylinder { height 1.30 radius 0.62 subdivision 14 } } ] } ] boundingObject Box { size 1.30 1.24 1.24 } }
|
Solid { translation 24.28 13.79 0.62 children [ Transform { rotation 1 0 0 1.5708 children [ Shape { appearance USE HAY geometry Cylinder { height 1.30 radius 0.62 subdivision 14 } } ] } ] boundingObject Box { size 1.30 1.24 1.24 } }
|
||||||
|
|
||||||
# ==================== TRACTOR (near barn) ====================
|
# ==================== TRACTOR (near barn) ====================
|
||||||
Solid {
|
Solid {
|
||||||
translation 17 19 0
|
translation 17 19 0.18
|
||||||
rotation 0 0 1 1.9
|
rotation 0 0 1 1.9
|
||||||
children [
|
children [
|
||||||
# Chassis
|
# Chassis
|
||||||
|
|||||||
Reference in New Issue
Block a user