Cleanup and new approach

This commit is contained in:
Johnny Fernandes
2026-04-26 01:50:01 +01:00
parent b031473758
commit 61f8a7db15
139 changed files with 510 additions and 16170 deletions
+469 -354
View File
@@ -1,414 +1,529 @@
"""
PPO training script for the herding task.
PPO training for the herding task with curriculum learning.
Usage examples
--------------
# Proper 5-sheep curriculum, 1 M steps per stage:
python train.py --curriculum --steps-per-stage 1000000 --total-steps 5000000
Trains from scratch through a 1→max_sheep curriculum, evaluates after each
stage, and auto-generates trajectory/timeseries plots plus a summary chart.
# Success-rate curriculum (advances when 70 % success over 100 episodes):
python train.py --curriculum --threshold 0.70
Usage
-----
python train.py # defaults from config.json
python train.py --config my_config.json --max-sheep 5
python train.py --max-sheep 3 --steps-per-stage 1000000
# Resume from checkpoint at stage 3:
python train.py --resume runs/ppo_herding/ckpt_3000000_steps.zip --n-sheep 3 \
--curriculum --steps-per-stage 1000000 --total-steps 5000000
# Quick smoke-test:
python train.py --n-envs 1 --total-steps 50000
Outputs (in runs/<timestamp>/):
config.json resolved config
final_model.zip trained PPO model
vecnorm.pkl VecNormalize statistics
stage_results.json per-stage evaluation metrics
success_rate.png summary bar chart
eval/ trajectory & timeseries plots per sheep count
"""
import argparse
import json
import os
import time
from copy import deepcopy
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
from matplotlib.collections import LineCollection
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import (
BaseCallback,
CallbackList,
CheckpointCallback,
EvalCallback,
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import (
DummyVecEnv,
SubprocVecEnv,
VecNormalize,
)
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecNormalize
from herding_env import HerdingEnv
# ── Colours ──────────────────────────────────────────────────────────────────
SHEEP_COLORS = [
"#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00",
"#a65628", "#f781bf", "#999999", "#66c2a5", "#fc8d62",
]
DOG_COLOR = "#4e342e"
# ── Callbacks ────────────────────────────────────────────────────────────────
class ProgressCallback(BaseCallback):
"""One-line progress summary every `freq` env steps."""
def __init__(self, stage_label: str, freq: int = 100_000):
super().__init__()
self.stage_label = stage_label
self.freq = freq
self._last = 0
self._ep_returns = []
self._ep_success = []
self._total_eps = 0
self._total_success = 0
self._cur_ret = None
def _on_step(self) -> bool:
rewards = self.locals.get("rewards")
dones = self.locals.get("dones")
infos = self.locals.get("infos", [])
if rewards is None or dones is None:
return True
if self._cur_ret is None or len(self._cur_ret) != len(rewards):
self._cur_ret = np.zeros(len(rewards), dtype=np.float64)
self._cur_ret += np.asarray(rewards, dtype=np.float64)
for i, d in enumerate(dones):
if not d:
continue
self._ep_returns.append(float(self._cur_ret[i]))
info = infos[i] if i < len(infos) else {}
success = int(info.get("n_penned", 0) == info.get("n_sheep", -1))
self._ep_success.append(success)
self._total_eps += 1
self._total_success += success
self._cur_ret[i] = 0.0
if len(self._ep_returns) > 50:
self._ep_returns.pop(0)
self._ep_success.pop(0)
if self.num_timesteps - self._last >= self.freq:
self._last = self.num_timesteps
n = len(self._ep_returns)
mean_r = float(np.mean(self._ep_returns)) if n else float("nan")
win_sr = float(np.mean(self._ep_success)) if n else float("nan")
cum_sr = (self._total_success / self._total_eps
if self._total_eps else float("nan"))
print(f" ... [{self.stage_label} | "
f"{self.num_timesteps:>7,} steps | "
f"ret(last {n})={mean_r:+.2f} "
f"win_sr={win_sr*100:.0f}% cum_sr={cum_sr*100:.0f}%]",
flush=True)
return True
# ── Environment factory ──────────────────────────────────────────────────────
def make_env(n_sheep, seed, max_steps, reward_cfg=None):
def _init():
env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps,
reward_cfg=reward_cfg)
env.reset(seed=seed)
return env
return _init
# ── Failure-mode classification ──────────────────────────────────────────────
COMPACT_RADIUS = 5.0
def _classify(ep_radius, ep_com_dist, n_penned, n_sheep, success):
if success:
def _classify(ep_radii, ep_com_dists, n_penned, n_sheep):
if n_penned == n_sheep:
return "SUCCESS"
if min(ep_radius) > COMPACT_RADIUS:
if min(ep_radii) > COMPACT_RADIUS:
return "NEVER_COMPACT"
first = next(i for i, r in enumerate(ep_radius) if r <= COMPACT_RADIUS)
if min(ep_com_dist[first:]) > 3.0:
first = next(i for i, r in enumerate(ep_radii) if r <= COMPACT_RADIUS)
if min(ep_com_dists[first:]) > 3.0:
return "COMPACT_CANT_DRIVE"
if n_penned == 0:
return "DROVE_NO_SHEEP"
return f"PARTIAL_{n_penned}of{n_sheep}"
# ---------------------------------------------------------------------------
# Curriculum callback
# ---------------------------------------------------------------------------
# ── Evaluation ───────────────────────────────────────────────────────────────
class CurriculumCallback(BaseCallback):
"""
Advances n_sheep on both training and eval envs.
def evaluate(model, vn_template, n_sheep, n_episodes, max_steps,
reward_cfg=None):
"""Evaluate at a given sheep count; returns metrics dict."""
raw = DummyVecEnv([make_env(n_sheep, 9999, max_steps, reward_cfg)])
vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
vn.obs_rms = deepcopy(vn_template.obs_rms)
vn.ret_rms = deepcopy(vn_template.ret_rms)
Two modes (mutually exclusive):
steps_per_stage — advance every N environment steps regardless of
success rate (recommended for reliability).
threshold — advance when rolling success rate exceeds this value
(requires the policy to actually reach the threshold).
"""
successes = 0
ep_lens = []
min_pen_list = []
action_mags = []
failure_counts = {}
rc_sums = {}
rc_n = 0
def __init__(self, start_sheep: int, max_sheep: int,
eval_env=None,
steps_per_stage: int = None,
threshold: float = 0.75,
window: int = 100,
min_episodes: int = 50,
verbose: int = 1):
super().__init__(verbose)
self.max_sheep = max_sheep
self.eval_env = eval_env
self.steps_per_stage = steps_per_stage
self.threshold = threshold
self.window = window
self.min_episodes = min_episodes
self._cur_sheep = start_sheep
self._successes = []
self._stage_start = 0
for _ in range(n_episodes):
obs = vn.reset()
done = False
steps = 0
min_pen = float("inf")
mags = []
ep_radii = []
ep_com_dists = []
while not done:
action, _ = model.predict(obs, deterministic=True)
obs, _, dones, infos = vn.step(action)
done = dones[0]
inner = vn.envs[0]
com, radius, _ = inner._flock_stats()
min_pen = min(min_pen, float(np.linalg.norm(com - inner.PEN_CENTER)))
mags.append(float(np.linalg.norm(action[0])))
ep_radii.append(radius)
ep_com_dists.append(float(np.linalg.norm(com - inner.PEN_CENTER)))
steps += 1
rc = infos[0].get("rcomps")
if rc:
for k, v in rc.items():
rc_sums[k] = rc_sums.get(k, 0.0) + v
rc_n += 1
n_penned = infos[0].get("n_penned", 0)
success = n_penned == n_sheep
successes += int(success)
ep_lens.append(steps)
min_pen_list.append(min_pen)
action_mags.extend(mags)
mode = _classify(ep_radii, ep_com_dists, n_penned, n_sheep)
failure_counts[mode] = failure_counts.get(mode, 0) + 1
def _advance(self):
prev_sheep = self._cur_sheep
recent_sr = (np.mean(self._successes) if self._successes else float("nan"))
if self.verbose:
print(f"\n[Curriculum] leaving stage n_sheep={prev_sheep} "
f"after {self.num_timesteps - self._stage_start:,} steps "
f"| training success rate (last {len(self._successes)} eps) = "
f"{recent_sr*100:.0f}%")
self._cur_sheep += 1
self.training_env.env_method("set_n_sheep", self._cur_sheep)
if self.eval_env is not None:
self.eval_env.env_method("set_n_sheep", self._cur_sheep)
self._stage_start = self.num_timesteps
self._successes.clear()
if self.verbose:
print(f"[Curriculum] → {self._cur_sheep} sheep "
f"at step {self.num_timesteps:,}\n")
vn.close()
def _on_step(self) -> bool:
if self._cur_sheep >= self.max_sheep:
return True
# Always track training-side success (success = sheep all penned, not truncated)
for info, done in zip(self.locals["infos"], self.locals["dones"]):
if done:
npen = info.get("n_penned", 0)
nshp = info.get("n_sheep", self._cur_sheep)
self._successes.append(1 if npen == nshp else 0)
if len(self._successes) > self.window:
self._successes.pop(0)
if self.steps_per_stage is not None:
if self.num_timesteps - self._stage_start >= self.steps_per_stage:
self._advance()
else:
if (len(self._successes) >= self.min_episodes
and np.mean(self._successes) >= self.threshold):
self._advance()
return True
result = {
"sr": successes / n_episodes,
"mean_len": float(np.mean(ep_lens)),
"mean_min_pen": float(np.mean(min_pen_list)),
"mean_act": float(np.mean(action_mags)) if action_mags else 0.0,
"failure_modes": failure_counts,
}
if rc_n > 0:
result["reward_per_step"] = {k: v / rc_n for k, v in rc_sums.items()}
return result
# ---------------------------------------------------------------------------
# Diagnostic callback — failure-mode breakdown every diag_freq steps
# ---------------------------------------------------------------------------
# ── Visualization helpers ────────────────────────────────────────────────────
class DiagnosticCallback(BaseCallback):
"""
Every diag_freq env steps: spin up a temporary eval env, run n_episodes
deterministic episodes, and print a failure-mode breakdown.
Aborts training (returns False) if the dominant failure mode hasn't
changed after two consecutive checks at the same n_sheep — a sign that
training has stalled and further steps are wasted.
"""
def __init__(self, diag_freq: int = 500_000, n_episodes: int = 20,
max_steps: int = 2000, abort_on_stall: bool = True,
verbose: int = 1):
super().__init__(verbose)
self.diag_freq = diag_freq
self.n_episodes = n_episodes
self.max_steps = max_steps
self.abort_on_stall = abort_on_stall
self._last_diag = 0
self._prev_dominant = None # (n_sheep, mode) from last check
self._stall_count = 0
def _on_step(self) -> bool:
if self.num_timesteps - self._last_diag < self.diag_freq:
return True
self._last_diag = self.num_timesteps
n_sheep = self.training_env.get_attr("n_sheep")[0]
# Build a temporary single-env with copied VecNorm stats
raw = DummyVecEnv([lambda: HerdingEnv(n_sheep=n_sheep,
max_steps=self.max_steps)])
vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
vn.obs_rms = deepcopy(self.training_env.obs_rms)
vn.ret_rms = deepcopy(self.training_env.ret_rms)
failure_counts = {}
successes = 0
all_action_mags = []
ep_min_radii = []
ep_min_dog_com = [] # closest the dog ever got to flock COM
ep_min_pen_dists = [] # closest COM ever got to pen
rcomp_sums = {"progress":0.0,"alignment":0.0,"pen_bonus":0.0,
"step_cost":0.0,"complete":0.0}
rcomp_n = 0
for _ in range(self.n_episodes):
obs = vn.reset()
done = False
ep_radius, ep_com_dist, ep_dog_com = [], [], []
ep_actions = []
n_penned = 0
while not done:
action, _ = self.model.predict(obs, deterministic=True)
obs, _, dones, infos = vn.step(action)
done = dones[0]
inner = vn.envs[0]
com, radius, _ = inner._flock_stats()
ep_radius.append(radius)
ep_com_dist.append(
float(np.linalg.norm(com - inner.PEN_CENTER))
)
ep_dog_com.append(
float(np.linalg.norm(inner.dog_pos - com))
)
ep_actions.append(float(np.linalg.norm(action[0])))
rc = infos[0].get("rcomps")
if rc is not None:
for k in rcomp_sums: rcomp_sums[k] += rc[k]
rcomp_n += 1
n_penned = infos[0].get("n_penned", 0)
success = n_penned == n_sheep
successes += int(success)
mode = _classify(ep_radius, ep_com_dist, n_penned, n_sheep, success)
failure_counts[mode] = failure_counts.get(mode, 0) + 1
all_action_mags.extend(ep_actions)
ep_min_radii.append(min(ep_radius))
ep_min_dog_com.append(min(ep_dog_com))
ep_min_pen_dists.append(min(ep_com_dist))
vn.close()
success_rate = successes / self.n_episodes
dominant = max(failure_counts, key=failure_counts.get)
if self.verbose:
print(f"\n[Diag @ {self.num_timesteps:,} | n_sheep={n_sheep} | "
f"success={success_rate*100:.0f}%]")
for m, c in sorted(failure_counts.items(), key=lambda x: -x[1]):
print(f" {m:<26} {c}/{self.n_episodes}")
mean_act = float(np.mean(all_action_mags)) if all_action_mags else 0.0
p10 = float(np.percentile(all_action_mags, 10)) if all_action_mags else 0.0
p90 = float(np.percentile(all_action_mags, 90)) if all_action_mags else 0.0
print(f" action_mag mean={mean_act:.3f} p10={p10:.3f} p90={p90:.3f} "
f"(0=stopped, 1=full speed)")
print(f" min_flock_radius mean={np.mean(ep_min_radii):.2f}m "
f"best={np.min(ep_min_radii):.2f}m (target <5m to compact)")
print(f" min_dog_to_com mean={np.mean(ep_min_dog_com):.2f}m "
f"best={np.min(ep_min_dog_com):.2f}m (FLEE_DIST=7m)")
print(f" min_com_to_pen mean={np.mean(ep_min_pen_dists):.2f}m "
f"best={np.min(ep_min_pen_dists):.2f}m")
if rcomp_n > 0:
print(f" reward/step (mean): " + " ".join(
f"{k}={rcomp_sums[k]/rcomp_n:+.4f}" for k in
("progress","alignment","pen_bonus","step_cost","complete")
))
# Stall detection — disabled when --no-stall-abort or when we've never
# seen any stage succeed (we want full visibility into what's happening).
key = (n_sheep, dominant)
if key == self._prev_dominant and dominant != "SUCCESS":
self._stall_count += 1
if (self.abort_on_stall and self._stall_count >= 5
and self.num_timesteps >= 3_000_000):
print(f"\n[Diag] STALL DETECTED — '{dominant}' on {n_sheep} sheep "
f"for {self._stall_count} consecutive checks. "
f"Aborting training early.")
return False
else:
self._stall_count = 0
self._prev_dominant = key
return True
def _draw_field(ax):
ax.set_xlim(-16, 16)
ax.set_ylim(-16, 16)
ax.set_aspect("equal")
ax.set_facecolor("#dcedc8")
ax.add_patch(mpatches.Rectangle((-15, -15), 30, 30,
fill=False, edgecolor="#795548", lw=2))
ax.add_patch(mpatches.Rectangle((10, -15), 3, 7,
facecolor="#ffe082", edgecolor="#795548", lw=2))
ax.text(11.5, -11.5, "pen", ha="center", va="center",
fontsize=8, color="#795548")
# ---------------------------------------------------------------------------
# Environment factory
# ---------------------------------------------------------------------------
def make_env(n_sheep: int, seed: int, max_steps: int, random_n_sheep: bool = False):
def _init():
env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps,
random_n_sheep=random_n_sheep)
env.reset(seed=seed)
return env
return _init
def _faded_path(ax, xs, ys, color, lw=1.5, label=None):
n = len(xs)
if n < 2:
return
points = np.array([xs, ys]).T.reshape(-1, 1, 2)
segs = np.concatenate([points[:-1], points[1:]], axis=1)
alphas = np.linspace(0.15, 1.0, len(segs))
colors = [(*matplotlib.colors.to_rgb(color), a) for a in alphas]
ax.add_collection(LineCollection(segs, colors=colors, linewidth=lw))
if label:
ax.plot([], [], color=color, lw=lw, label=label)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def run_and_record(model, vn_template, n_sheep, max_steps,
reward_cfg=None, seed=42):
"""Run one deterministic episode and return full history."""
raw = DummyVecEnv([make_env(n_sheep, seed, max_steps, reward_cfg)])
vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
vn.obs_rms = deepcopy(vn_template.obs_rms)
vn.ret_rms = deepcopy(vn_template.ret_rms)
obs = vn.reset()
inner = vn.envs[0]
done = False
dog_xs, dog_ys = [], []
sheep_xs = [[] for _ in range(n_sheep)]
sheep_ys = [[] for _ in range(n_sheep)]
radii = []
pen_dists = [[] for _ in range(n_sheep)]
action_mags = []
rewards = []
penned_at = [None] * n_sheep
step = 0
while not done:
action, _ = model.predict(obs, deterministic=True)
obs, reward, dones, infos = vn.step(action)
done = dones[0]
step += 1
dog_xs.append(float(inner.dog_pos[0]))
dog_ys.append(float(inner.dog_pos[1]))
com, radius, _ = inner._flock_stats()
radii.append(radius)
rewards.append(float(reward[0]))
action_mags.append(float(np.linalg.norm(action[0])))
for i in range(n_sheep):
sheep_xs[i].append(float(inner.sheep_pos[i][0]))
sheep_ys[i].append(float(inner.sheep_pos[i][1]))
pen_dists[i].append(
float(np.linalg.norm(inner.sheep_pos[i] - inner.PEN_CENTER)))
if inner.penned[i] and penned_at[i] is None:
penned_at[i] = step
n_penned = infos[0].get("n_penned", 0)
vn.close()
return dict(
dog_xs=dog_xs, dog_ys=dog_ys,
sheep_xs=sheep_xs, sheep_ys=sheep_ys,
radii=radii, pen_dists=pen_dists,
action_mags=action_mags, rewards=rewards,
penned_at=penned_at,
n_penned=n_penned, n_sheep=n_sheep,
success=n_penned == n_sheep, steps=step,
)
def plot_trajectory(hist, out_path):
fig, ax = plt.subplots(figsize=(7, 7))
_draw_field(ax)
for i in range(hist["n_sheep"]):
c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
xs, ys = hist["sheep_xs"][i], hist["sheep_ys"][i]
_faded_path(ax, xs, ys, c, lw=1.2, label=f"sheep {i+1}")
ax.plot(xs[0], ys[0], "o", color=c, ms=7, zorder=4)
end = hist["penned_at"][i] if hist["penned_at"][i] is not None else -1
ax.plot(xs[end], ys[end], "*", color=c, ms=11, zorder=5)
_faded_path(ax, hist["dog_xs"], hist["dog_ys"], DOG_COLOR, lw=2.0,
label="dog")
ax.plot(hist["dog_xs"][0], hist["dog_ys"][0], "s", color=DOG_COLOR,
ms=10, zorder=5)
ax.plot(hist["dog_xs"][-1], hist["dog_ys"][-1], "D", color=DOG_COLOR,
ms=10, zorder=5)
result = ("SUCCESS" if hist["success"]
else f"FAIL ({hist['n_penned']}/{hist['n_sheep']})")
ax.set_title(f"n={hist['n_sheep']} {result} {hist['steps']} steps",
fontsize=12)
ax.legend(loc="upper left", fontsize=8)
plt.tight_layout()
fig.savefig(out_path, dpi=120)
plt.close(fig)
def plot_timeseries(hist, out_path):
t = np.arange(hist["steps"])
fig, axes = plt.subplots(4, 1, figsize=(12, 10), sharex=True)
axes[0].plot(t, hist["radii"], color="steelblue")
axes[0].axhline(5.0, color="orange", ls="--", lw=1, label="compact (5m)")
axes[0].set_ylabel("flock radius (m)")
axes[0].legend(fontsize=8)
axes[0].set_title("Flock radius")
for i in range(hist["n_sheep"]):
c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
axes[1].plot(t, hist["pen_dists"][i], color=c, lw=1,
label=f"sheep {i+1}")
if hist["penned_at"][i] is not None:
axes[1].axvline(hist["penned_at"][i], color=c, ls=":", lw=1)
axes[1].set_ylabel("dist to pen (m)")
axes[1].legend(fontsize=7, ncol=min(hist["n_sheep"], 5))
axes[1].set_title("Per-sheep distance to pen")
axes[2].plot(t, hist["action_mags"], color="tomato", lw=1)
axes[2].axhline(1.0, color="gray", ls="--", lw=1, label="max")
axes[2].set_ylabel("action ||(vx,vy)||")
axes[2].set_ylim(0, 1.5)
axes[2].set_title("Dog action magnitude")
axes[2].legend(fontsize=8)
axes[3].plot(t, hist["rewards"], color="purple", lw=1, alpha=0.7)
axes[3].axhline(0, color="black", lw=0.5)
axes[3].set_ylabel("reward")
axes[3].set_xlabel("step")
axes[3].set_title("Reward per step")
result = ("SUCCESS" if hist["success"]
else f"FAIL ({hist['n_penned']}/{hist['n_sheep']})")
fig.suptitle(f"n_sheep={hist['n_sheep']} {result} {hist['steps']} steps",
fontsize=13)
plt.tight_layout()
fig.savefig(out_path, dpi=120)
plt.close(fig)
def plot_success_rate(stage_results, out_path):
fig, ax = plt.subplots(figsize=(8, 4))
ns = [r["n_sheep"] for r in stage_results]
srs = [r["sr"] * 100 for r in stage_results]
bars = ax.bar(ns, srs, color="steelblue", edgecolor="white")
ax.set_xlabel("Sheep count")
ax.set_ylabel("Success rate (%)")
ax.set_ylim(0, 105)
ax.axhline(90, color="orange", ls="--", lw=1, label="90% target")
for bar, sr in zip(bars, srs):
ax.text(bar.get_x() + bar.get_width() / 2,
bar.get_height() + 1, f"{sr:.0f}%",
ha="center", fontsize=9)
ax.legend()
ax.set_title("Evaluation success rate per sheep count")
plt.tight_layout()
fig.savefig(out_path, dpi=120)
plt.close(fig)
# ── CLI ──────────────────────────────────────────────────────────────────────
DEFAULT_CONFIG = {
"W_PER_SHEEP": 2.0,
"W_ALIGN": 0.05,
"W_PEN_BONUS": 10.0,
"W_COMPLETE": 100.0,
"W_STEP_COST": 0.02,
"W_COMPACT": 0.0,
"W_WALL_TOUCH": 0.15,
"WALL_TOUCH_BUFFER": 0.8,
"ALIGN_SHAPE": "standoff",
"ALIGN_GATED": True,
"ENTRY_AWARE": False,
"ent_coef": 0.02,
}
def parse_args():
p = argparse.ArgumentParser()
p.add_argument("--n-sheep", type=int, default=1,
help="Starting sheep count")
p.add_argument("--max-sheep", type=int, default=5,
help="Final sheep count for curriculum")
p.add_argument("--n-envs", type=int, default=8,
help="Parallel training environments")
p.add_argument("--total-steps", type=int, default=5_000_000)
p.add_argument("--max-steps", type=int, default=2000,
help="Episode step limit")
p.add_argument("--curriculum", action="store_true",
help="Enable curriculum advancement")
p.add_argument("--steps-per-stage", type=int, default=None,
help="Advance curriculum every N steps (overrides --threshold)")
p.add_argument("--threshold", type=float, default=0.75,
help="Success-rate threshold to advance (used without --steps-per-stage)")
p.add_argument("--resume", type=str, default=None,
help="Checkpoint .zip to resume from")
p.add_argument("--run-dir", type=str, default="runs/ppo_herding")
p.add_argument("--save-freq", type=int, default=100_000)
p.add_argument("--eval-freq", type=int, default=50_000)
p.add_argument("--eval-eps", type=int, default=20)
p.add_argument("--diag-freq", type=int, default=500_000,
help="Run failure-mode diagnostics every N env steps")
p.add_argument("--no-stall-abort", action="store_true",
help="Disable early-abort on stall — run full --total-steps "
"for diagnostics")
p.add_argument("--mixed", action="store_true",
help="Randomise n_sheep each episode (consolidation pass, "
"use with --resume after curriculum training)")
p = argparse.ArgumentParser(
description="PPO training for herding task with curriculum learning")
p.add_argument("--config", type=str, default=None,
help="JSON config file (reward weights + ent_coef)")
p.add_argument("--max-sheep", type=int, default=10)
p.add_argument("--steps-per-stage", type=int, default=1_500_000)
p.add_argument("--n-envs", type=int, default=8)
p.add_argument("--max-steps", type=int, default=2500)
p.add_argument("--eval-episodes", type=int, default=30)
p.add_argument("--run-dir", type=str, default=None)
return p.parse_args()
# ── Main ─────────────────────────────────────────────────────────────────────
def main():
args = parse_args()
os.makedirs(args.run_dir, exist_ok=True)
ckpt_dir = os.path.join(args.run_dir, "checkpoints")
best_dir = os.path.join(args.run_dir, "best_model")
norm_path = os.path.join(args.run_dir, "vecnorm.pkl")
os.makedirs(ckpt_dir, exist_ok=True)
# Load config
cfg = dict(DEFAULT_CONFIG)
if args.config:
with open(args.config) as f:
cfg.update(json.load(f))
rcfg = {k: v for k, v in cfg.items() if hasattr(HerdingEnv, k)}
# Run directory
run_dir = args.run_dir or os.path.join(
"runs", time.strftime("%Y%m%d_%H%M%S"))
eval_dir = os.path.join(run_dir, "eval")
os.makedirs(eval_dir, exist_ok=True)
with open(os.path.join(run_dir, "config.json"), "w") as f:
json.dump(cfg, f, indent=2)
print(f"Config: {cfg}")
print(f"Run dir: {run_dir}")
print(f"Curriculum: 1 → {args.max_sheep} sheep, "
f"{args.steps_per_stage:,} steps/stage\n")
# Training envs
train_env = SubprocVecEnv([
make_env(args.n_sheep, seed=i, max_steps=args.max_steps,
random_n_sheep=args.mixed)
make_env(1, seed=i, max_steps=args.max_steps, reward_cfg=rcfg)
for i in range(args.n_envs)
])
if args.resume and os.path.exists(norm_path):
train_env = VecNormalize.load(norm_path, train_env)
train_env.training = True
train_env.norm_reward = True
else:
train_env = VecNormalize(train_env, norm_obs=True, norm_reward=True,
clip_obs=10.0)
# Eval env — starts at same difficulty, advances with curriculum callback
eval_env = SubprocVecEnv([
make_env(args.n_sheep, seed=1000 + i, max_steps=args.max_steps)
for i in range(2)
])
eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False,
clip_obs=10.0, training=False)
# Callbacks
checkpoint_cb = CheckpointCallback(
save_freq=max(args.save_freq // args.n_envs, 1),
save_path=ckpt_dir,
name_prefix="ckpt",
save_vecnormalize=True,
)
eval_cb = EvalCallback(
eval_env,
best_model_save_path=best_dir,
log_path=args.run_dir,
eval_freq=max(args.eval_freq // args.n_envs, 1),
n_eval_episodes=args.eval_eps,
deterministic=True,
verbose=1,
)
diag_cb = DiagnosticCallback(
diag_freq=args.diag_freq,
n_episodes=20,
max_steps=args.max_steps,
abort_on_stall=not args.no_stall_abort,
)
callbacks = [checkpoint_cb, eval_cb, diag_cb]
if args.curriculum:
cur_cb = CurriculumCallback(
start_sheep=args.n_sheep,
max_sheep=args.max_sheep,
eval_env=eval_env,
steps_per_stage=args.steps_per_stage,
threshold=args.threshold,
)
callbacks.append(cur_cb)
callback_list = CallbackList(callbacks)
vn = VecNormalize(train_env, norm_obs=True, norm_reward=True,
clip_obs=10.0)
# Model
ppo_kwargs = dict(
policy = "MlpPolicy",
env = train_env,
learning_rate = 3e-4,
n_steps = 2048,
batch_size = 256,
n_epochs = 10,
gamma = 0.995,
gae_lambda = 0.95,
clip_range = 0.2,
ent_coef = 0.01,
vf_coef = 0.5,
max_grad_norm = 0.5,
policy_kwargs = dict(net_arch=[256, 256]),
tensorboard_log = args.run_dir,
verbose = 1,
model = PPO(
"MlpPolicy", vn,
learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10,
gamma=0.995, gae_lambda=0.95, clip_range=0.2,
ent_coef=cfg.get("ent_coef", 0.02), vf_coef=0.5, max_grad_norm=0.5,
policy_kwargs=dict(net_arch=[256, 256]),
verbose=0,
)
if args.resume:
print(f"Resuming from {args.resume}")
model = PPO.load(args.resume, env=train_env, **{
k: v for k, v in ppo_kwargs.items()
if k not in ("policy", "env")
})
else:
model = PPO(**ppo_kwargs)
# Curriculum training
stage_results = []
t0 = time.time()
model.learn(
total_timesteps=args.total_steps,
callback=callback_list,
reset_num_timesteps=args.resume is None,
tb_log_name="ppo",
)
try:
for n in range(1, args.max_sheep + 1):
if n > 1:
vn.env_method("set_n_sheep", n)
model.save(os.path.join(args.run_dir, "final_model"))
train_env.save(norm_path)
print(f"\nTraining complete. Artefacts saved to {args.run_dir}/")
print(f"\n[Stage n_sheep={n}] training {args.steps_per_stage:,} steps")
model.learn(
total_timesteps=args.steps_per_stage,
reset_num_timesteps=(n == 1),
callback=ProgressCallback(f"{n} sheep", freq=100_000),
)
# Evaluate
print(f"[Stage n_sheep={n}] evaluating {args.eval_episodes} eps")
r = evaluate(model, vn, n, args.eval_episodes, args.max_steps, rcfg)
print(f"[Stage n_sheep={n}] sr={r['sr']*100:.0f}% "
f"mean_len={r['mean_len']:.0f} "
f"mean_min_pen={r['mean_min_pen']:.1f}m "
f"mean_act={r['mean_act']:.2f}")
# Failure-mode breakdown
if r["failure_modes"]:
modes = " ".join(
f"{k}={v}" for k, v in sorted(
r["failure_modes"].items(), key=lambda x: -x[1]))
print(f" failure modes: {modes}")
# Reward breakdown
if "reward_per_step" in r:
rps = r["reward_per_step"]
print(f" reward/step: " + " ".join(
f"{k}={v:+.4f}" for k, v in rps.items()))
# Episode visualization
hist = run_and_record(model, vn, n, args.max_steps, rcfg,
seed=1000 + n)
tag = "success" if hist["success"] else "fail"
plot_trajectory(
hist,
os.path.join(eval_dir, f"traj_{n}s_{tag}.png"))
plot_timeseries(
hist,
os.path.join(eval_dir, f"ts_{n}s_{tag}.png"))
r["n_sheep"] = n
stage_results.append(r)
# Save artefacts
model.save(os.path.join(run_dir, "final_model"))
vn.save(os.path.join(run_dir, "vecnorm.pkl"))
with open(os.path.join(run_dir, "stage_results.json"), "w") as f:
json.dump(stage_results, f, indent=2)
finally:
try:
vn.close()
except Exception:
pass
# Summary
elapsed = (time.time() - t0) / 60
print("\n" + "=" * 70)
print(" TRAINING SUMMARY")
print("=" * 70)
for r in stage_results:
print(f" n_sheep={r['n_sheep']} sr={r['sr']*100:>3.0f}% "
f"len={r['mean_len']:>5.0f} min_pen={r['mean_min_pen']:>5.1f}m "
f"act={r['mean_act']:.2f}")
print(f"\n Total time: {elapsed:.1f} min")
print(f" Artefacts: {run_dir}/")
plot_success_rate(stage_results, os.path.join(run_dir, "success_rate.png"))
print(f" Plots: {run_dir}/success_rate.png, {eval_dir}/")
if __name__ == "__main__":