RL training ready to test
This commit is contained in:
@@ -0,0 +1,143 @@
|
||||
"""
|
||||
Evaluation script for a trained herding policy.
|
||||
|
||||
Runs N episodes and reports the three project metrics:
|
||||
1. Success rate — fraction of episodes where all sheep are penned
|
||||
2. Time-to-pen — mean steps across successful episodes (per sheep)
|
||||
3. Flock dispersion — mean pairwise distance among active sheep, averaged
|
||||
over all timesteps (lower = tighter herding)
|
||||
|
||||
Usage
|
||||
-----
|
||||
python evaluate.py --model runs/ppo_herding/best_model/best_model.zip \
|
||||
--vecnorm runs/ppo_herding/vecnorm.pkl \
|
||||
--n-sheep 5 --episodes 100
|
||||
|
||||
Add --render to watch the first episode in a matplotlib window.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
|
||||
|
||||
from herding_env import HerdingEnv
|
||||
|
||||
|
||||
def make_single_env(n_sheep: int, max_steps: int, render_mode: str = None):
    """Return a zero-argument factory that builds one ``HerdingEnv``.

    The environment is not created here; it is created each time the
    returned callable is invoked, using the arguments captured by this call.
    """
    env_kwargs = {
        "n_sheep": n_sheep,
        "max_steps": max_steps,
        "render_mode": render_mode,
    }

    def _build():
        return HerdingEnv(**env_kwargs)

    return _build
|
||||
|
||||
|
||||
def pairwise_mean(positions: np.ndarray, n_active: int) -> float:
    """Average Euclidean distance over every unordered pair of the first
    ``n_active`` rows of ``positions``.

    Returns 0.0 when fewer than two rows are requested, since no pair exists.
    """
    if n_active < 2:
        return 0.0

    head = positions[:n_active]
    # One entry per unordered pair (a < b), in the same order a nested
    # index loop would visit them.
    pair_lengths = [
        float(np.linalg.norm(head[a] - head[b]))
        for a in range(n_active)
        for b in range(a + 1, n_active)
    ]
    return float(np.mean(pair_lengths))
|
||||
|
||||
|
||||
def parse_args():
    """Parse the evaluation script's command-line options from ``sys.argv``.

    ``--model`` is the only required option; every other option has a default.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ("--model", {"required": True,
                     "help": "Path to saved model .zip"}),
        ("--vecnorm", {"default": None,
                       "help": "Path to VecNormalize stats .pkl (optional)"}),
        ("--n-sheep", {"type": int, "default": 1}),
        ("--episodes", {"type": int, "default": 50}),
        ("--max-steps", {"type": int, "default": 2000}),
        ("--render", {"action": "store_true",
                      "help": "Render first episode in matplotlib"}),
        ("--seed", {"type": int, "default": 42}),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
|
||||
|
||||
|
||||
def main():
    """Run the evaluation episodes and print the three project metrics.

    Builds a single-env ``DummyVecEnv`` (optionally wrapped with saved
    ``VecNormalize`` statistics in inference mode), loads the PPO model, runs
    ``--episodes`` deterministic episodes and reports success rate,
    time-to-pen and flock dispersion.
    """
    args = parse_args()

    render_mode = "human" if args.render else None
    raw_env = DummyVecEnv([make_single_env(args.n_sheep, args.max_steps,
                                           render_mode)])
    if args.vecnorm:
        env = VecNormalize.load(args.vecnorm, raw_env)
        # Inference mode: freeze the running statistics and report raw rewards.
        env.training = False
        env.norm_reward = False
    else:
        env = raw_env

    # Apply --seed so that evaluation runs are reproducible; the seed is
    # consumed by the next reset().
    env.seed(args.seed)

    # The underlying HerdingEnv, used for the dispersion calculation.
    inner = raw_env.envs[0]

    successes = []
    steps_to_pen = []   # steps for successful episodes
    dispersions = []    # per-episode mean flock dispersion

    try:
        model = PPO.load(args.model, env=env)

        for ep in range(args.episodes):
            obs = env.reset()
            done = False
            ep_steps = 0
            ep_dispersion = []

            while not done:
                action, _ = model.predict(obs, deterministic=True)
                obs, _, dones, infos = env.step(action)
                done = bool(dones[0])
                ep_steps += 1

                # The vectorised env resets itself as soon as an episode
                # ends, so on the final step `inner` already holds the NEXT
                # episode's state; only sample dispersion while running.
                # NOTE(review): the sample covers all n_sheep slots, including
                # sheep that are already penned — confirm this is the intended
                # definition of "active".
                if not done:
                    ep_dispersion.append(
                        pairwise_mean(inner.sheep_pos, inner.n_sheep)
                    )

            # Terminal info comes from the last real step of the episode.
            info = infos[0]
            n_penned = info.get("n_penned", 0)
            n_sheep = info.get("n_sheep", args.n_sheep)
            success = n_penned == n_sheep

            successes.append(int(success))
            if success:
                steps_to_pen.append(ep_steps / n_sheep)
            if ep_dispersion:
                dispersions.append(float(np.mean(ep_dispersion)))

            # --render shows the first episode only: HerdingEnv.step() draws
            # whenever render_mode == "human", so switch it off afterwards.
            if ep == 0 and render_mode == "human":
                inner.render_mode = None
                inner.close()

            if (ep + 1) % 10 == 0:
                print(f" Episode {ep + 1:>4}/{args.episodes} "
                      f"success={int(success)} steps={ep_steps}")
    finally:
        # Release the env (and any matplotlib figure) even on failure.
        env.close()

    # -----------------------------------------------------------------------
    # Report
    # -----------------------------------------------------------------------
    success_rate = float(np.mean(successes)) if successes else float("nan")
    mean_ttp = float(np.mean(steps_to_pen)) if steps_to_pen else float("nan")
    mean_disp = float(np.mean(dispersions)) if dispersions else float("nan")

    print("\n" + "=" * 50)
    print(f" Model : {args.model}")
    print(f" Sheep : {args.n_sheep}")
    print(f" Episodes : {args.episodes}")
    print("-" * 50)
    print(f" Success rate : {success_rate * 100:.1f}%"
          f" ({sum(successes)}/{args.episodes})")
    print(f" Time-to-pen : {mean_ttp:.1f} steps/sheep"
          f" (successful episodes only)")
    print(f" Flock dispersion: {mean_disp:.2f} m"
          f" (mean pairwise distance while active)")
    print("=" * 50)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,319 @@
|
||||
"""
|
||||
2D herding environment for PPO training (Gymnasium-compatible).
|
||||
|
||||
The dog agent (action: 2D velocity vector) must herd n_sheep into the
|
||||
quarantine pen. Sheep dynamics mirror the Webots controller exactly:
|
||||
flee (quadratic ramp), separation (inverse-distance), cohesion, wall
|
||||
avoidance, and wander.
|
||||
|
||||
Coordinate system matches the Webots world file:
|
||||
field : x ∈ [-15, 15], y ∈ [-15, 15]
|
||||
pen : x ∈ [10, 13], y ∈ [-15, -8] (SE corner, open north)
|
||||
|
||||
Observation is always sized for MAX_SHEEP (currently 5) regardless of
|
||||
how many sheep are active. Inactive slots are pre-penned at the pen
|
||||
centre with flag=1. This keeps the model input dimension fixed across
|
||||
curriculum stages so VecNormalize statistics are preserved throughout.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
|
||||
|
||||
class HerdingEnv(gym.Env):
    """Dog-herds-sheep task on a square field (Gymnasium environment).

    The agent controls the dog with a 2D velocity action. An episode
    terminates when every one of the ``n_sheep`` active sheep has entered the
    pen rectangle, and is truncated after ``max_steps`` steps.

    The observation always has ``2 + 3 * MAX_SHEEP`` entries: dog (x, y),
    MAX_SHEEP sheep (x, y) and MAX_SHEEP penned flags. Slots at index
    ``n_sheep`` and above are parked at the pen centre with their flag set.
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30}
    # NOTE(review): "rgb_array" is advertised, but render() only draws into a
    # matplotlib window and returns None.

    # -----------------------------------------------------------------------
    # World constants — must match Webots world file
    # -----------------------------------------------------------------------
    MAX_SHEEP = 5
    FIELD = 15.0            # half-size; positions ∈ [-FIELD, FIELD]
    PEN_X = (10.0, 13.0)    # quarantine pen x bounds
    PEN_Y = (-15.0, -8.0)   # quarantine pen y bounds
    PEN_CENTER = np.array([11.5, -11.5], dtype=np.float32)

    # -----------------------------------------------------------------------
    # Dynamics — calibrated to match Webots robot specs
    #   wheel radius 0.031 m; sheep FLEE_SPEED 20 rad/s → 0.62 m/s
    #   wheel radius 0.038 m; dog maxVelocity 70 rad/s → 2.66 m/s
    # -----------------------------------------------------------------------
    DOG_SPEED = 2.5         # m/s
    SHEEP_FLEE_V = 0.65     # m/s
    SHEEP_WANDER_V = 0.20   # m/s
    DT = 0.1                # seconds per step

    # Boid parameters — identical to sheep.py
    FLEE_DIST = 7.0
    SEPARATION_DIST = 2.5
    COHESION_DIST = 8.0
    WALL_MARGIN = 3.5

    # -----------------------------------------------------------------------
    # Reward weights
    # -----------------------------------------------------------------------
    W_APPROACH = 0.3        # dense: dog distance to nearest active sheep
    W_SHAPING = 0.5         # dense: mean sheep distance to pen (was 0.01)
    W_PEN_BONUS = 5.0       # sparse: per sheep successfully penned
    W_COMPLETE = 20.0       # bonus when ALL active sheep are penned
    W_STEP_COST = 0.002     # penalty per step (encourages efficiency)

    def __init__(self, n_sheep: int = 1, max_steps: int = 2000,
                 render_mode: str = None):
        """Create the environment.

        Args:
            n_sheep: number of active sheep, between 1 and ``MAX_SHEEP``.
            max_steps: step count at which an episode is truncated.
            render_mode: ``"human"`` makes ``step()`` draw every step.

        Raises:
            ValueError: if ``n_sheep`` is outside ``[1, MAX_SHEEP]``.
        """
        super().__init__()
        # Explicit check rather than `assert`, which is stripped under -O.
        if not 1 <= n_sheep <= self.MAX_SHEEP:
            raise ValueError(
                f"n_sheep must be in [1, {self.MAX_SHEEP}], got {n_sheep}"
            )
        self.n_sheep = n_sheep
        self.max_steps = max_steps
        self.render_mode = render_mode

        # Observation: dog(x,y) + MAX_SHEEP×sheep(x,y) + MAX_SHEEP×penned
        # Fixed size across all curriculum stages.
        obs_dim = 2 + 2 * self.MAX_SHEEP + self.MAX_SHEEP
        self.observation_space = spaces.Box(
            low=-1.0, high=1.0, shape=(obs_dim,), dtype=np.float32
        )

        # Action: desired velocity (vx, vy) ∈ [-1, 1]², scaled by DOG_SPEED
        self.action_space = spaces.Box(
            low=-1.0, high=1.0, shape=(2,), dtype=np.float32
        )

        # Runtime state (populated by reset)
        self._step_count = 0
        self._prev_penned = 0
        self.dog_pos = np.zeros(2, dtype=np.float32)
        self.sheep_pos = np.zeros((self.MAX_SHEEP, 2), dtype=np.float32)
        self.penned = np.ones(self.MAX_SHEEP, dtype=bool)
        self.wander_ang = np.zeros(self.MAX_SHEEP, dtype=np.float32)

        # Sheep count requested via set_n_sheep(), applied by the next reset().
        self._pending_n_sheep = None

        self._fig = None    # lazy matplotlib figure

    # ------------------------------------------------------------------
    # Curriculum interface
    # ------------------------------------------------------------------

    def set_n_sheep(self, n: int):
        """Advance curriculum difficulty; takes effect on next reset().

        The new count is only stored here. Changing ``n_sheep`` in the middle
        of an episode would make the pre-penned inactive slot count as a
        freshly penned sheep in ``step()`` and pay out a pen bonus nobody
        earned.

        Raises:
            ValueError: if ``n`` is outside ``[1, MAX_SHEEP]``.
        """
        if not 1 <= n <= self.MAX_SHEEP:
            raise ValueError(
                f"n must be in [1, {self.MAX_SHEEP}], got {n}"
            )
        self._pending_n_sheep = int(n)

    # ------------------------------------------------------------------
    # Gymnasium API
    # ------------------------------------------------------------------

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        Applies any sheep count queued by ``set_n_sheep()``, then samples the
        dog position, the active sheep positions (rejecting points inside the
        pen) and the wander headings. ``options`` is accepted but unused.
        """
        super().reset(seed=seed)

        if self._pending_n_sheep is not None:
            self.n_sheep = self._pending_n_sheep
            self._pending_n_sheep = None

        self._step_count = 0
        self._prev_penned = 0

        # Dog: random start in the open field (not near the pen)
        self.dog_pos = self.np_random.uniform(
            -8.0, 5.0, size=(2,)
        ).astype(np.float32)

        # Default every slot to "parked in the pen"; the active slots
        # (0 .. n_sheep-1) are overwritten below.
        self.sheep_pos[:] = self.PEN_CENTER
        self.penned[:] = True

        placed = 0
        while placed < self.n_sheep:
            p = self.np_random.uniform(
                -12.0, 12.0, size=(2,)
            ).astype(np.float32)
            if not self._in_pen(p):
                self.sheep_pos[placed] = p
                self.penned[placed] = False
                placed += 1

        # Inactive slots (n_sheep .. MAX_SHEEP-1): already at pen centre,
        # penned=True

        self.wander_ang = self.np_random.uniform(
            -np.pi, np.pi, size=(self.MAX_SHEEP,)
        ).astype(np.float32)

        return self._obs(), {}

    def step(self, action):
        """Advance the simulation by one ``DT``.

        Returns ``(obs, reward, terminated, truncated, info)`` where ``info``
        carries ``n_penned`` and ``n_sheep``.
        """
        self._step_count += 1

        # Move dog — clip each axis independently so the agent can idle
        act = np.clip(np.asarray(action, dtype=np.float32), -1.0, 1.0)
        self.dog_pos = np.clip(
            self.dog_pos + act * self.DOG_SPEED * self.DT,
            -self.FIELD, self.FIELD
        )

        # Step sheep dynamics. Sheep are updated one after another, so sheep
        # i sees the already-updated positions of sheep 0 .. i-1.
        for i in range(self.n_sheep):
            if self.penned[i]:
                continue
            self.sheep_pos[i] = self._step_sheep(i)
            if self._in_pen(self.sheep_pos[i]):
                self.penned[i] = True

        n_penned = int(self.penned[:self.n_sheep].sum())
        newly_penned = n_penned - self._prev_penned
        self._prev_penned = n_penned

        reward = self._reward(n_penned, newly_penned)
        terminated = n_penned == self.n_sheep
        truncated = self._step_count >= self.max_steps
        info = {"n_penned": n_penned, "n_sheep": self.n_sheep}

        if self.render_mode == "human":
            self.render()

        return self._obs(), float(reward), terminated, truncated, info

    def render(self):
        """Redraw the field, pen, active sheep and dog in a matplotlib window.

        The figure is created on first use and reused afterwards.
        """
        # Imported lazily so matplotlib is only needed when rendering.
        import matplotlib.pyplot as plt
        import matplotlib.patches as mpatches

        if self._fig is None:
            plt.ion()
            self._fig, self._ax = plt.subplots(figsize=(6, 6))

        ax = self._ax
        ax.clear()
        ax.set_xlim(-16, 16)
        ax.set_ylim(-16, 16)
        ax.set_aspect("equal")
        ax.set_facecolor("#dcedc8")

        # Field boundary
        ax.add_patch(mpatches.Rectangle(
            (-15, -15), 30, 30, fill=False, edgecolor="#795548", linewidth=2
        ))
        # Pen
        pw = self.PEN_X[1] - self.PEN_X[0]
        ph = self.PEN_Y[1] - self.PEN_Y[0]
        ax.add_patch(mpatches.Rectangle(
            (self.PEN_X[0], self.PEN_Y[0]), pw, ph,
            facecolor="#ffe082", edgecolor="#795548", linewidth=2
        ))
        ax.text(float(self.PEN_CENTER[0]), float(self.PEN_CENTER[1]), "pen",
                ha="center", va="center", fontsize=8, color="#795548")

        # Sheep — inactive slots (index >= n_sheep) are not shown
        for i in range(self.n_sheep):
            color = "deeppink" if self.penned[i] else "white"
            ax.plot(*self.sheep_pos[i], "o", color=color, markersize=11,
                    markeredgecolor="#555", markeredgewidth=1.5)

        # Dog
        ax.plot(*self.dog_pos, "s", color="#4e342e", markersize=13,
                markeredgecolor="black", markeredgewidth=1.5)

        ax.set_title(
            f"step {self._step_count} | "
            f"penned {int(self.penned[:self.n_sheep].sum())}/{self.n_sheep}",
            fontsize=11
        )
        self._fig.canvas.draw()
        self._fig.canvas.flush_events()
        plt.pause(0.001)

    def close(self):
        """Close the matplotlib figure, if one was ever opened."""
        if self._fig is not None:
            import matplotlib.pyplot as plt
            plt.close(self._fig)
            self._fig = None

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _in_pen(self, pos: np.ndarray) -> bool:
        """Return True when ``pos`` lies strictly inside the pen rectangle."""
        return bool(self.PEN_X[0] < pos[0] < self.PEN_X[1] and
                    self.PEN_Y[0] < pos[1] < self.PEN_Y[1])

    def _obs(self) -> np.ndarray:
        """Build the observation, with positions divided by ``FIELD``."""
        scale = 1.0 / self.FIELD
        return np.concatenate([
            self.dog_pos * scale,                   # 2
            (self.sheep_pos * scale).flatten(),     # 2 * MAX_SHEEP
            self.penned.astype(np.float32),         # MAX_SHEEP
        ]).astype(np.float32)

    def _reward(self, n_penned: int, newly_penned: int) -> float:
        """Combine dense shaping, pen bonuses, step cost and completion bonus.

        Args:
            n_penned: active sheep currently penned.
            newly_penned: sheep that entered the pen during this step.
        """
        active_mask = ~self.penned[:self.n_sheep]
        if active_mask.any():
            active_pos = self.sheep_pos[:self.n_sheep][active_mask]

            # Sheep-to-pen shaping: encourages moving sheep toward pen.
            # Distances are divided by the field width (30 m); the far corner
            # is ~37.5 m from the pen centre, so this is roughly [-1.25, 0].
            dists_pen = np.linalg.norm(active_pos - self.PEN_CENTER, axis=1)
            shaping = -(dists_pen.mean() / (2 * self.FIELD))

            # Dog-to-nearest-sheep approach: incentivises the dog to stay
            # within flee range (FLEE_DIST=7m) rather than wandering away.
            # Bounded by the field diagonal, so roughly [-1.42, 0].
            dists_dog = np.linalg.norm(active_pos - self.dog_pos, axis=1)
            approach = -(dists_dog.min() / (2 * self.FIELD))
        else:
            shaping = approach = 0.0

        reward = shaping * self.W_SHAPING
        reward += approach * self.W_APPROACH
        reward += newly_penned * self.W_PEN_BONUS
        reward -= self.W_STEP_COST
        if n_penned == self.n_sheep:
            reward += self.W_COMPLETE
        return reward

    def _step_sheep(self, i: int) -> np.ndarray:
        """Apply one timestep of boid dynamics to sheep i.

        Sums flee, separation, cohesion, wall-avoidance and wander forces,
        then moves the sheep along the resulting direction at a capped speed.
        Returns the new position; ``sheep_pos`` itself is not modified, but
        ``wander_ang[i]`` may be.
        """
        pos = self.sheep_pos[i].copy()
        fx, fy = 0.0, 0.0
        fleeing = False

        # Flee from dog — quadratic ramp (mirrors sheep.py)
        diff = self.dog_pos - pos
        dist = float(np.linalg.norm(diff))
        if 0.01 < dist < self.FLEE_DIST:
            t = 1.0 - dist / self.FLEE_DIST
            s = t * t * 5.0
            fx -= (diff[0] / dist) * s
            fy -= (diff[1] / dist) * s
            fleeing = True

        # Separation (inverse-distance) + Cohesion, over unpenned flock-mates
        cx, cy, cn = 0.0, 0.0, 0
        for j in range(self.n_sheep):
            if j == i or self.penned[j]:
                continue
            dv = self.sheep_pos[j] - pos
            dj = float(np.linalg.norm(dv))
            if 0.3 < dj < self.COHESION_DIST:
                cx += self.sheep_pos[j][0]
                cy += self.sheep_pos[j][1]
                cn += 1
            if 0.05 < dj < self.SEPARATION_DIST:
                push = (self.SEPARATION_DIST - dj) / dj
                fx -= (dv[0] / dj) * push * 2.5
                fy -= (dv[1] / dj) * push * 2.5
        if cn > 0:
            # Pull toward the neighbours' centroid, weaker while fleeing.
            w = 0.08 if fleeing else 0.15
            fx += (cx / cn - pos[0]) * w
            fy += (cy / cn - pos[1]) * w

        # Wall avoidance — linear ramp inside the margin along each wall
        m, F = self.WALL_MARGIN, self.FIELD
        if pos[0] < -F + m:
            fx += ((-F + m - pos[0]) / m) * 6.0
        if pos[0] > F - m:
            fx -= ((pos[0] - (F - m)) / m) * 6.0
        if pos[1] < -F + m:
            fy += ((-F + m - pos[1]) / m) * 6.0
        if pos[1] > F - m:
            fy -= ((pos[1] - (F - m)) / m) * 6.0

        # Wander — suppressed while fleeing
        if not fleeing:
            if self.np_random.random() < 0.02:
                self.wander_ang[i] += float(self.np_random.uniform(-0.6, 0.6))
            fx += float(np.cos(self.wander_ang[i])) * 0.5
            fy += float(np.sin(self.wander_ang[i])) * 0.5

        # Integrate: move along the force direction at a capped speed
        force = np.array([fx, fy])
        mag = float(np.linalg.norm(force))
        if mag > 0.01:
            top_speed = self.SHEEP_FLEE_V if fleeing else self.SHEEP_WANDER_V
            speed = min(top_speed, mag * 0.3)
            pos = np.clip(pos + (force / mag) * speed * self.DT,
                          -self.FIELD, self.FIELD)

        return pos.astype(np.float32)
|
||||
@@ -0,0 +1,6 @@
|
||||
gymnasium>=0.29
|
||||
stable-baselines3>=2.3
|
||||
torch>=2.2
|
||||
numpy>=1.26
|
||||
matplotlib>=3.8
|
||||
tensorboard>=2.16
|
||||
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
PPO training script for the herding task.
|
||||
|
||||
Usage examples
|
||||
--------------
|
||||
# Start fresh with curriculum (1 → 5 sheep):
|
||||
python train.py --curriculum
|
||||
|
||||
# Resume from checkpoint, skip directly to 3 sheep:
|
||||
python train.py --resume runs/ppo_herding/checkpoints/ckpt_200000_steps.zip --n-sheep 3
|
||||
|
||||
# Quick smoke-test (no curriculum, single env):
|
||||
python train.py --n-envs 1 --total-steps 50000
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.callbacks import (
|
||||
BaseCallback,
|
||||
CallbackList,
|
||||
CheckpointCallback,
|
||||
EvalCallback,
|
||||
)
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv, VecNormalize
|
||||
|
||||
from herding_env import HerdingEnv
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Curriculum callback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class CurriculumCallback(BaseCallback):
    """
    Advances the curriculum (number of active sheep) when the rolling mean
    episode success rate exceeds a threshold.

    Success = episode terminated (all sheep penned) rather than truncated.

    Episodes that were already running when the curriculum advanced are not
    counted: they were started under the previous stage, so they say nothing
    about performance at the new sheep count.
    """

    THRESHOLD = 0.75        # success rate to graduate
    WINDOW = 100            # episodes to average over
    MIN_EPISODES = 50       # don't graduate before seeing this many episodes

    def __init__(self, start_sheep: int, max_sheep: int, verbose: int = 1):
        """
        Args:
            start_sheep: sheep count the training envs start with.
            max_sheep: sheep count at which the curriculum stops advancing.
            verbose: non-zero prints a message on every advancement.
        """
        super().__init__(verbose)
        self.max_sheep = max_sheep
        self._successes = []
        self._cur_sheep = start_sheep
        # Indices of envs whose current episode began before the most recent
        # advancement; their next finished episode is ignored.
        self._stale_envs = set()

    def _on_step(self) -> bool:
        """Record finished episodes and advance the stage when warranted.

        Always returns True, so training is never stopped by this callback.
        """
        dones = self.locals["dones"]
        for idx, (info, done) in enumerate(zip(self.locals["infos"], dones)):
            if not done:
                continue
            if idx in self._stale_envs:
                # Started under the previous stage — skip exactly one episode.
                self._stale_envs.discard(idx)
                continue
            truncated = info.get("TimeLimit.truncated", False)
            self._successes.append(0 if truncated else 1)

        # Keep only the most recent WINDOW outcomes.
        if len(self._successes) > self.WINDOW:
            del self._successes[:-self.WINDOW]

        if (self._cur_sheep < self.max_sheep
                and len(self._successes) >= self.MIN_EPISODES
                and np.mean(self._successes) >= self.THRESHOLD):
            self._cur_sheep += 1
            self.training_env.env_method("set_n_sheep", self._cur_sheep)
            self._successes.clear()
            # Every env now holds an episode that pre-dates the new stage.
            self._stale_envs = set(range(len(dones)))
            if self.verbose:
                print(f"\n[Curriculum] Advanced to {self._cur_sheep} sheep "
                      f"at step {self.num_timesteps}\n")

        return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Environment factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_env(n_sheep: int, seed: int, max_steps: int):
    """Return a zero-argument factory producing one seeded ``HerdingEnv``.

    Each call of the returned function constructs a fresh environment and
    immediately resets it with ``seed``.
    """
    def _build():
        herding = HerdingEnv(max_steps=max_steps, n_sheep=n_sheep)
        herding.reset(seed=seed)
        return herding

    return _build
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def parse_args():
    """Parse the training script's command-line options from ``sys.argv``.

    Every option has a default, so the script runs without arguments.
    ``--save-freq`` and ``--eval-freq`` are total environment steps: ``main``
    divides both by ``--n-envs`` before handing them to the callbacks.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--n-sheep", type=int, default=1,
                   help="Starting number of sheep (or fixed count if no curriculum)")
    p.add_argument("--max-sheep", type=int, default=5,
                   help="Maximum sheep for curriculum (ignored without --curriculum)")
    p.add_argument("--n-envs", type=int, default=8,
                   help="Number of parallel environments")
    p.add_argument("--total-steps", type=int, default=5_000_000,
                   help="Total environment steps to train for")
    p.add_argument("--max-steps", type=int, default=2000,
                   help="Episode step limit inside each env")
    p.add_argument("--curriculum", action="store_true",
                   help="Enable automatic curriculum advancement")
    p.add_argument("--resume", type=str, default=None,
                   help="Path to a .zip checkpoint to resume training from")
    p.add_argument("--run-dir", type=str, default="runs/ppo_herding",
                   help="Output directory for checkpoints and logs")
    # The old help text said "per-env, not total", which contradicted how
    # main() uses the value.
    p.add_argument("--save-freq", type=int, default=100_000,
                   help="Checkpoint every N total environment steps "
                        "(summed over all envs)")
    p.add_argument("--eval-freq", type=int, default=50_000,
                   help="Evaluate every N total environment steps "
                        "(summed over all envs)")
    p.add_argument("--eval-eps", type=int, default=20,
                   help="Episodes per evaluation run")
    return p.parse_args()
|
||||
|
||||
|
||||
def _resume_stats_path(resume_path, fallback_path):
    """Locate VecNormalize statistics that belong to a resumed checkpoint.

    Looks first for the file saved next to the checkpoint
    (``<prefix>_<N>_steps.zip`` → ``<prefix>_vecnormalize_<N>_steps.pkl``),
    then for ``fallback_path``. Returns None when neither exists.
    """
    ckpt_dir, ckpt_file = os.path.split(resume_path)
    stem, _ = os.path.splitext(ckpt_file)
    prefix, sep, rest = stem.partition("_")

    candidates = []
    if sep:
        candidates.append(
            os.path.join(ckpt_dir, f"{prefix}_vecnormalize_{rest}.pkl")
        )
    candidates.append(fallback_path)

    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return None


def main():
    """Train (or resume training) a PPO herding policy.

    Creates the run directory, builds normalised training and evaluation
    envs, wires up checkpoint / evaluation / optional curriculum callbacks,
    runs ``model.learn`` and saves the final model and normalisation
    statistics. Both vectorised envs are closed on exit, including on error.
    """
    args = parse_args()
    os.makedirs(args.run_dir, exist_ok=True)
    ckpt_dir = os.path.join(args.run_dir, "checkpoints")
    best_dir = os.path.join(args.run_dir, "best_model")
    norm_path = os.path.join(args.run_dir, "vecnorm.pkl")
    os.makedirs(ckpt_dir, exist_ok=True)

    # Training envs
    train_env = SubprocVecEnv([
        make_env(args.n_sheep, seed=i, max_steps=args.max_steps)
        for i in range(args.n_envs)
    ])
    eval_env = None

    try:
        # A resumed policy expects the observation statistics it was trained
        # with. Prefer the stats saved alongside the checkpoint; vecnorm.pkl
        # in the run dir is only written when a run finishes.
        stats_path = None
        if args.resume:
            stats_path = _resume_stats_path(args.resume, norm_path)

        if stats_path is not None:
            print(f"Loading VecNormalize stats from {stats_path}")
            train_env = VecNormalize.load(stats_path, train_env)
            train_env.training = True
            train_env.norm_reward = True
        else:
            if args.resume:
                print("WARNING: no VecNormalize stats found for the resumed "
                      "checkpoint; starting from fresh normalisation stats.")
            train_env = VecNormalize(train_env, norm_obs=True,
                                     norm_reward=True, clip_obs=10.0)

        # Eval env (no reward normalisation, deterministic)
        # NOTE(review): the eval env keeps --n-sheep for the whole run even
        # when the curriculum advances the training envs.
        eval_env = SubprocVecEnv([
            make_env(args.n_sheep, seed=1000 + i, max_steps=args.max_steps)
            for i in range(2)
        ])
        eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False,
                                clip_obs=10.0, training=False)

        # Callbacks. Frequencies are divided by n_envs so that the
        # command-line values are expressed in total environment steps.
        checkpoint_cb = CheckpointCallback(
            save_freq=max(args.save_freq // args.n_envs, 1),
            save_path=ckpt_dir,
            name_prefix="ckpt",
            save_vecnormalize=True,
        )
        eval_cb = EvalCallback(
            eval_env,
            best_model_save_path=best_dir,
            log_path=args.run_dir,
            eval_freq=max(args.eval_freq // args.n_envs, 1),
            n_eval_episodes=args.eval_eps,
            deterministic=True,
            verbose=1,
        )
        callbacks = [checkpoint_cb, eval_cb]
        if args.curriculum:
            callbacks.append(CurriculumCallback(start_sheep=args.n_sheep,
                                                max_sheep=args.max_sheep))
        callback_list = CallbackList(callbacks)

        # Model
        ppo_kwargs = dict(
            policy="MlpPolicy",
            env=train_env,
            learning_rate=3e-4,
            n_steps=2048,
            batch_size=256,
            n_epochs=10,
            gamma=0.995,
            gae_lambda=0.95,
            clip_range=0.2,
            ent_coef=0.005,
            vf_coef=0.5,
            max_grad_norm=0.5,
            policy_kwargs=dict(net_arch=[256, 256]),
            tensorboard_log=args.run_dir,
            verbose=1,
        )

        if args.resume:
            print(f"Resuming from {args.resume}")
            # "policy" and "env" are not overridable hyper-parameters; env is
            # passed explicitly.
            model = PPO.load(args.resume, env=train_env, **{
                k: v for k, v in ppo_kwargs.items()
                if k not in ("policy", "env")
            })
        else:
            model = PPO(**ppo_kwargs)

        model.learn(
            total_timesteps=args.total_steps,
            callback=callback_list,
            # Keep the step counter running when continuing a previous run.
            reset_num_timesteps=args.resume is None,
            tb_log_name="ppo",
        )

        # Save final artefacts
        model.save(os.path.join(args.run_dir, "final_model"))
        train_env.save(norm_path)
        print(f"\nTraining complete. Artefacts saved to {args.run_dir}/")
    finally:
        # Shut down the worker subprocesses even when training fails.
        train_env.close()
        if eval_env is not None:
            eval_env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+3
-2
@@ -364,6 +364,7 @@ Solid {
|
||||
# ==================== SCARECROW (east side, outside fence) ====================
|
||||
Solid {
|
||||
translation 20 -10 0
|
||||
rotation 0 0 1 2.61799
|
||||
children [
|
||||
Transform { translation 0 0 1.22 children [ Shape { appearance USE TRUNK geometry Cylinder { height 2.44 radius 0.045 subdivision 8 } } ] }
|
||||
Transform { translation 0 0 2.02 rotation 1 0 0 1.5708 children [ Shape { appearance USE TRUNK geometry Cylinder { height 1.60 radius 0.032 subdivision 8 } } ] }
|
||||
@@ -391,12 +392,12 @@ Solid {
|
||||
|
||||
# ==================== HAY BALES (near barn) ====================
|
||||
Solid { translation 25.75 13.76 0.62 children [ Transform { rotation 1 0 0 1.5708 children [ Shape { appearance USE HAY geometry Cylinder { height 1.30 radius 0.62 subdivision 14 } } ] } ] boundingObject Box { size 1.30 1.24 1.24 } }
|
||||
Solid { translation 24.34 12.32 0.62 children [ Transform { rotation 1 0 0 1.5708 children [ Shape { appearance USE HAY geometry Cylinder { height 1.30 radius 0.62 subdivision 14 } } ] } ] boundingObject Box { size 1.30 1.24 1.24 } }
|
||||
Solid { translation 24.34 12.32 0.62 rotation -1 0 0 1.5708 children [ Transform { rotation 1 0 0 1.5708 children [ Shape { appearance USE HAY geometry Cylinder { height 1.30 radius 0.62 subdivision 14 } } ] } ] boundingObject Box { size 1.30 1.24 1.24 } }
|
||||
Solid { translation 24.28 13.79 0.62 children [ Transform { rotation 1 0 0 1.5708 children [ Shape { appearance USE HAY geometry Cylinder { height 1.30 radius 0.62 subdivision 14 } } ] } ] boundingObject Box { size 1.30 1.24 1.24 } }
|
||||
|
||||
# ==================== TRACTOR (near barn) ====================
|
||||
Solid {
|
||||
translation 17 19 0
|
||||
translation 17 19 0.18
|
||||
rotation 0 0 1 1.9
|
||||
children [
|
||||
# Chassis
|
||||
|
||||
Reference in New Issue
Block a user