Cleanup and new approach
This commit is contained in:
+469
-354
@@ -1,414 +1,529 @@
|
||||
"""
|
||||
PPO training script for the herding task.
|
||||
PPO training for the herding task with curriculum learning.
|
||||
|
||||
Usage examples
|
||||
--------------
|
||||
# Proper 5-sheep curriculum, 1 M steps per stage:
|
||||
python train.py --curriculum --steps-per-stage 1000000 --total-steps 5000000
|
||||
Trains from scratch through a 1→max_sheep curriculum, evaluates after each
|
||||
stage, and auto-generates trajectory/timeseries plots plus a summary chart.
|
||||
|
||||
# Success-rate curriculum (advances when 70 % success over 100 episodes):
|
||||
python train.py --curriculum --threshold 0.70
|
||||
Usage
|
||||
-----
|
||||
python train.py # defaults from config.json
|
||||
python train.py --config my_config.json --max-sheep 5
|
||||
python train.py --max-sheep 3 --steps-per-stage 1000000
|
||||
|
||||
# Resume from checkpoint at stage 3:
|
||||
python train.py --resume runs/ppo_herding/ckpt_3000000_steps.zip --n-sheep 3 \
|
||||
--curriculum --steps-per-stage 1000000 --total-steps 5000000
|
||||
|
||||
# Quick smoke-test:
|
||||
python train.py --n-envs 1 --total-steps 50000
|
||||
Outputs (in runs/<timestamp>/):
|
||||
config.json resolved config
|
||||
final_model.zip trained PPO model
|
||||
vecnorm.pkl VecNormalize statistics
|
||||
stage_results.json per-stage evaluation metrics
|
||||
success_rate.png summary bar chart
|
||||
eval/ trajectory & timeseries plots per sheep count
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from copy import deepcopy
|
||||
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
import numpy as np
|
||||
from matplotlib.collections import LineCollection
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.callbacks import (
|
||||
BaseCallback,
|
||||
CallbackList,
|
||||
CheckpointCallback,
|
||||
EvalCallback,
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.vec_env import (
|
||||
DummyVecEnv,
|
||||
SubprocVecEnv,
|
||||
VecNormalize,
|
||||
)
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecNormalize
|
||||
|
||||
from herding_env import HerdingEnv
|
||||
|
||||
|
||||
# ── Colours ──────────────────────────────────────────────────────────────────
|
||||
|
||||
SHEEP_COLORS = [
|
||||
"#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00",
|
||||
"#a65628", "#f781bf", "#999999", "#66c2a5", "#fc8d62",
|
||||
]
|
||||
DOG_COLOR = "#4e342e"
|
||||
|
||||
|
||||
# ── Callbacks ────────────────────────────────────────────────────────────────
|
||||
|
||||
class ProgressCallback(BaseCallback):
    """Print a one-line training progress summary every ``freq`` env steps.

    Per-env running returns are accumulated from the rollout ``rewards`` and
    flushed whenever the matching ``dones`` flag is set.  An episode counts as
    a success when its final ``info`` dict reports ``n_penned == n_sheep``.

    NOTE(review): the rewards seen here are whatever the training VecEnv
    returns; if that env is a VecNormalize with ``norm_reward=True`` the
    reported return is presumably the normalised one -- confirm before
    comparing against raw environment returns.

    Args:
        stage_label: Text shown at the start of every progress line.
        freq: Minimum number of env steps between two progress lines.
    """

    # Number of most recent episodes kept for the rolling statistics.
    _WINDOW = 50

    def __init__(self, stage_label: str, freq: int = 100_000):
        from collections import deque

        super().__init__()
        self.stage_label = stage_label
        self.freq = freq
        self._last = 0
        # Bounded deques drop the oldest entry automatically.
        self._ep_returns = deque(maxlen=self._WINDOW)
        self._ep_success = deque(maxlen=self._WINDOW)
        self._total_eps = 0
        self._total_success = 0
        self._cur_ret = None

    def _on_training_start(self) -> None:
        # num_timesteps keeps counting across learn() calls made with
        # reset_num_timesteps=False.  Anchor the report timer at the current
        # count so a fresh callback does not print an empty ("nan") summary
        # on its very first step.
        self._last = self.num_timesteps

    def _on_step(self) -> bool:
        rewards = self.locals.get("rewards")
        dones = self.locals.get("dones")
        infos = self.locals.get("infos", [])
        if rewards is None or dones is None:
            return True

        # (Re)allocate the per-env accumulators if the env count changed.
        if self._cur_ret is None or len(self._cur_ret) != len(rewards):
            self._cur_ret = np.zeros(len(rewards), dtype=np.float64)
        self._cur_ret += np.asarray(rewards, dtype=np.float64)

        for i, d in enumerate(dones):
            if not d:
                continue
            self._ep_returns.append(float(self._cur_ret[i]))
            info = infos[i] if i < len(infos) else {}
            # The mismatching defaults (0 vs -1) make a missing key a failure.
            success = int(info.get("n_penned", 0) == info.get("n_sheep", -1))
            self._ep_success.append(success)
            self._total_eps += 1
            self._total_success += success
            self._cur_ret[i] = 0.0

        if self.num_timesteps - self._last >= self.freq:
            self._last = self.num_timesteps
            n = len(self._ep_returns)
            mean_r = float(np.mean(self._ep_returns)) if n else float("nan")
            win_sr = float(np.mean(self._ep_success)) if n else float("nan")
            cum_sr = (self._total_success / self._total_eps
                      if self._total_eps else float("nan"))
            print(f" ... [{self.stage_label} | "
                  f"{self.num_timesteps:>7,} steps | "
                  f"ret(last {n})={mean_r:+.2f} "
                  f"win_sr={win_sr*100:.0f}% cum_sr={cum_sr*100:.0f}%]",
                  flush=True)
        return True
|
||||
|
||||
|
||||
# ── Environment factory ──────────────────────────────────────────────────────
|
||||
|
||||
def make_env(n_sheep, seed, max_steps, reward_cfg=None):
    """Return a zero-argument factory that builds one seeded HerdingEnv.

    The factory is meant to be handed to a vectorised-env constructor, which
    calls it later (possibly in a subprocess).

    Args:
        n_sheep: Sheep count passed to ``HerdingEnv``.
        seed: Seed used for the environment's first ``reset``.
        max_steps: Episode step limit passed to ``HerdingEnv``.
        reward_cfg: Optional reward configuration passed through unchanged.

    Returns:
        A callable that creates the environment, resets it once with ``seed``
        and returns it.
    """
    def _init():
        herding_env = HerdingEnv(
            n_sheep=n_sheep,
            max_steps=max_steps,
            reward_cfg=reward_cfg,
        )
        # Reset once so the seed is applied before the env is handed over.
        herding_env.reset(seed=seed)
        return herding_env

    return _init
|
||||
|
||||
|
||||
# ── Failure-mode classification ──────────────────────────────────────────────
|
||||
|
||||
COMPACT_RADIUS = 5.0
|
||||
|
||||
|
||||
def _classify(ep_radius, ep_com_dist, n_penned, n_sheep, success):
|
||||
if success:
|
||||
def _classify(ep_radii, ep_com_dists, n_penned, n_sheep):
|
||||
if n_penned == n_sheep:
|
||||
return "SUCCESS"
|
||||
if min(ep_radius) > COMPACT_RADIUS:
|
||||
if min(ep_radii) > COMPACT_RADIUS:
|
||||
return "NEVER_COMPACT"
|
||||
first = next(i for i, r in enumerate(ep_radius) if r <= COMPACT_RADIUS)
|
||||
if min(ep_com_dist[first:]) > 3.0:
|
||||
first = next(i for i, r in enumerate(ep_radii) if r <= COMPACT_RADIUS)
|
||||
if min(ep_com_dists[first:]) > 3.0:
|
||||
return "COMPACT_CANT_DRIVE"
|
||||
if n_penned == 0:
|
||||
return "DROVE_NO_SHEEP"
|
||||
return f"PARTIAL_{n_penned}of{n_sheep}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Curriculum callback
|
||||
# ---------------------------------------------------------------------------
|
||||
# ── Evaluation ───────────────────────────────────────────────────────────────
|
||||
|
||||
class CurriculumCallback(BaseCallback):
|
||||
"""
|
||||
Advances n_sheep on both training and eval envs.
|
||||
def evaluate(model, vn_template, n_sheep, n_episodes, max_steps,
|
||||
reward_cfg=None):
|
||||
"""Evaluate at a given sheep count; returns metrics dict."""
|
||||
raw = DummyVecEnv([make_env(n_sheep, 9999, max_steps, reward_cfg)])
|
||||
vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
|
||||
vn.obs_rms = deepcopy(vn_template.obs_rms)
|
||||
vn.ret_rms = deepcopy(vn_template.ret_rms)
|
||||
|
||||
Two modes (mutually exclusive):
|
||||
steps_per_stage — advance every N environment steps regardless of
|
||||
success rate (recommended for reliability).
|
||||
threshold — advance when rolling success rate exceeds this value
|
||||
(requires the policy to actually reach the threshold).
|
||||
"""
|
||||
successes = 0
|
||||
ep_lens = []
|
||||
min_pen_list = []
|
||||
action_mags = []
|
||||
failure_counts = {}
|
||||
rc_sums = {}
|
||||
rc_n = 0
|
||||
|
||||
def __init__(self, start_sheep: int, max_sheep: int,
|
||||
eval_env=None,
|
||||
steps_per_stage: int = None,
|
||||
threshold: float = 0.75,
|
||||
window: int = 100,
|
||||
min_episodes: int = 50,
|
||||
verbose: int = 1):
|
||||
super().__init__(verbose)
|
||||
self.max_sheep = max_sheep
|
||||
self.eval_env = eval_env
|
||||
self.steps_per_stage = steps_per_stage
|
||||
self.threshold = threshold
|
||||
self.window = window
|
||||
self.min_episodes = min_episodes
|
||||
self._cur_sheep = start_sheep
|
||||
self._successes = []
|
||||
self._stage_start = 0
|
||||
for _ in range(n_episodes):
|
||||
obs = vn.reset()
|
||||
done = False
|
||||
steps = 0
|
||||
min_pen = float("inf")
|
||||
mags = []
|
||||
ep_radii = []
|
||||
ep_com_dists = []
|
||||
while not done:
|
||||
action, _ = model.predict(obs, deterministic=True)
|
||||
obs, _, dones, infos = vn.step(action)
|
||||
done = dones[0]
|
||||
inner = vn.envs[0]
|
||||
com, radius, _ = inner._flock_stats()
|
||||
min_pen = min(min_pen, float(np.linalg.norm(com - inner.PEN_CENTER)))
|
||||
mags.append(float(np.linalg.norm(action[0])))
|
||||
ep_radii.append(radius)
|
||||
ep_com_dists.append(float(np.linalg.norm(com - inner.PEN_CENTER)))
|
||||
steps += 1
|
||||
rc = infos[0].get("rcomps")
|
||||
if rc:
|
||||
for k, v in rc.items():
|
||||
rc_sums[k] = rc_sums.get(k, 0.0) + v
|
||||
rc_n += 1
|
||||
n_penned = infos[0].get("n_penned", 0)
|
||||
success = n_penned == n_sheep
|
||||
successes += int(success)
|
||||
ep_lens.append(steps)
|
||||
min_pen_list.append(min_pen)
|
||||
action_mags.extend(mags)
|
||||
mode = _classify(ep_radii, ep_com_dists, n_penned, n_sheep)
|
||||
failure_counts[mode] = failure_counts.get(mode, 0) + 1
|
||||
|
||||
def _advance(self):
|
||||
prev_sheep = self._cur_sheep
|
||||
recent_sr = (np.mean(self._successes) if self._successes else float("nan"))
|
||||
if self.verbose:
|
||||
print(f"\n[Curriculum] leaving stage n_sheep={prev_sheep} "
|
||||
f"after {self.num_timesteps - self._stage_start:,} steps "
|
||||
f"| training success rate (last {len(self._successes)} eps) = "
|
||||
f"{recent_sr*100:.0f}%")
|
||||
self._cur_sheep += 1
|
||||
self.training_env.env_method("set_n_sheep", self._cur_sheep)
|
||||
if self.eval_env is not None:
|
||||
self.eval_env.env_method("set_n_sheep", self._cur_sheep)
|
||||
self._stage_start = self.num_timesteps
|
||||
self._successes.clear()
|
||||
if self.verbose:
|
||||
print(f"[Curriculum] → {self._cur_sheep} sheep "
|
||||
f"at step {self.num_timesteps:,}\n")
|
||||
vn.close()
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
if self._cur_sheep >= self.max_sheep:
|
||||
return True
|
||||
|
||||
# Always track training-side success (success = sheep all penned, not truncated)
|
||||
for info, done in zip(self.locals["infos"], self.locals["dones"]):
|
||||
if done:
|
||||
npen = info.get("n_penned", 0)
|
||||
nshp = info.get("n_sheep", self._cur_sheep)
|
||||
self._successes.append(1 if npen == nshp else 0)
|
||||
if len(self._successes) > self.window:
|
||||
self._successes.pop(0)
|
||||
|
||||
if self.steps_per_stage is not None:
|
||||
if self.num_timesteps - self._stage_start >= self.steps_per_stage:
|
||||
self._advance()
|
||||
else:
|
||||
if (len(self._successes) >= self.min_episodes
|
||||
and np.mean(self._successes) >= self.threshold):
|
||||
self._advance()
|
||||
|
||||
return True
|
||||
result = {
|
||||
"sr": successes / n_episodes,
|
||||
"mean_len": float(np.mean(ep_lens)),
|
||||
"mean_min_pen": float(np.mean(min_pen_list)),
|
||||
"mean_act": float(np.mean(action_mags)) if action_mags else 0.0,
|
||||
"failure_modes": failure_counts,
|
||||
}
|
||||
if rc_n > 0:
|
||||
result["reward_per_step"] = {k: v / rc_n for k, v in rc_sums.items()}
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Diagnostic callback — failure-mode breakdown every diag_freq steps
|
||||
# ---------------------------------------------------------------------------
|
||||
# ── Visualization helpers ────────────────────────────────────────────────────
|
||||
|
||||
class DiagnosticCallback(BaseCallback):
    """Periodic failure-mode diagnostics with an optional early abort.

    Every ``diag_freq`` env steps a temporary single-env VecNormalize is built
    with a copy of the training normalisation statistics.  ``n_episodes``
    deterministic episodes are run on it and a failure-mode breakdown is
    printed.

    Training is aborted (``_on_step`` returns ``False``) only when all of the
    following hold:

    * ``abort_on_stall`` is true,
    * the same non-``SUCCESS`` dominant failure mode has been repeated at the
      same ``n_sheep`` for ``stall_patience`` consecutive checks, and
    * at least ``stall_min_steps`` env steps have elapsed.

    Args:
        diag_freq: Env steps between two diagnostic runs.
        n_episodes: Deterministic episodes per diagnostic run.
        max_steps: Episode step limit of the temporary env.
        abort_on_stall: Whether a detected stall stops training.
        verbose: Non-zero prints the breakdown.
        stall_patience: Consecutive repeats of the dominant failure mode
            required before a stall is declared.
        stall_min_steps: Env steps that must have elapsed before a stall may
            abort training.
    """

    # Reward components summed from info["rcomps"] and reported per step.
    _RCOMP_KEYS = ("progress", "alignment", "pen_bonus", "step_cost",
                   "complete")

    def __init__(self, diag_freq: int = 500_000, n_episodes: int = 20,
                 max_steps: int = 2000, abort_on_stall: bool = True,
                 verbose: int = 1, stall_patience: int = 5,
                 stall_min_steps: int = 3_000_000):
        super().__init__(verbose)
        self.diag_freq = diag_freq
        self.n_episodes = n_episodes
        self.max_steps = max_steps
        self.abort_on_stall = abort_on_stall
        self.stall_patience = stall_patience
        self.stall_min_steps = stall_min_steps
        self._last_diag = 0
        self._prev_dominant = None  # (n_sheep, mode) from last check
        self._stall_count = 0

    def _on_step(self) -> bool:
        if self.num_timesteps - self._last_diag < self.diag_freq:
            return True
        self._last_diag = self.num_timesteps
        if self.n_episodes <= 0:
            # Nothing to evaluate; avoids a division by zero and max() on an
            # empty dict further down.
            return True

        n_sheep = self.training_env.get_attr("n_sheep")[0]

        # Build a temporary single-env with copied VecNorm stats.
        # NOTE(review): the temporary env is created without any reward
        # configuration, so it uses HerdingEnv's defaults -- confirm this is
        # intended.
        raw = DummyVecEnv([lambda: HerdingEnv(n_sheep=n_sheep,
                                              max_steps=self.max_steps)])
        vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
        vn.obs_rms = deepcopy(self.training_env.obs_rms)
        vn.ret_rms = deepcopy(self.training_env.ret_rms)

        failure_counts = {}
        successes = 0
        all_action_mags = []
        ep_min_radii = []
        ep_min_dog_com = []    # closest the dog ever got to flock COM
        ep_min_pen_dists = []  # closest COM ever got to pen
        rcomp_sums = {k: 0.0 for k in self._RCOMP_KEYS}
        rcomp_n = 0

        try:
            inner = vn.envs[0]
            for _ in range(self.n_episodes):
                obs = vn.reset()
                done = False
                ep_radius, ep_com_dist, ep_dog_com = [], [], []
                ep_actions = []
                n_penned = 0

                while not done:
                    action, _ = self.model.predict(obs, deterministic=True)
                    obs, _, dones, infos = vn.step(action)
                    done = bool(dones[0])

                    # DummyVecEnv resets a finished env inside step(), so on
                    # the terminal step the inner env already holds the next
                    # episode's start state.  Sample geometry only while the
                    # episode is live; a one-step episode still takes a
                    # single sample so the min() calls below have data.
                    if not done or not ep_radius:
                        com, radius, _ = inner._flock_stats()
                        ep_radius.append(radius)
                        ep_com_dist.append(
                            float(np.linalg.norm(com - inner.PEN_CENTER))
                        )
                        ep_dog_com.append(
                            float(np.linalg.norm(inner.dog_pos - com))
                        )
                    ep_actions.append(float(np.linalg.norm(action[0])))
                    rc = infos[0].get("rcomps")
                    if rc is not None:
                        # .get() tolerates envs that omit a component.
                        for k in rcomp_sums:
                            rcomp_sums[k] += rc.get(k, 0.0)
                        rcomp_n += 1

                n_penned = infos[0].get("n_penned", 0)
                success = n_penned == n_sheep
                successes += int(success)
                mode = _classify(ep_radius, ep_com_dist, n_penned, n_sheep,
                                 success)
                failure_counts[mode] = failure_counts.get(mode, 0) + 1
                all_action_mags.extend(ep_actions)
                ep_min_radii.append(min(ep_radius))
                ep_min_dog_com.append(min(ep_dog_com))
                ep_min_pen_dists.append(min(ep_com_dist))
        finally:
            # Release the temporary env even if an episode raised.
            vn.close()

        success_rate = successes / self.n_episodes
        dominant = max(failure_counts, key=failure_counts.get)

        if self.verbose:
            print(f"\n[Diag @ {self.num_timesteps:,} | n_sheep={n_sheep} | "
                  f"success={success_rate*100:.0f}%]")
            for m, c in sorted(failure_counts.items(), key=lambda x: -x[1]):
                print(f" {m:<26} {c}/{self.n_episodes}")
            mean_act = float(np.mean(all_action_mags)) if all_action_mags else 0.0
            p10 = float(np.percentile(all_action_mags, 10)) if all_action_mags else 0.0
            p90 = float(np.percentile(all_action_mags, 90)) if all_action_mags else 0.0
            print(f" action_mag mean={mean_act:.3f} p10={p10:.3f} p90={p90:.3f} "
                  f"(0=stopped, 1=full speed)")
            print(f" min_flock_radius mean={np.mean(ep_min_radii):.2f}m "
                  f"best={np.min(ep_min_radii):.2f}m (target <5m to compact)")
            print(f" min_dog_to_com mean={np.mean(ep_min_dog_com):.2f}m "
                  f"best={np.min(ep_min_dog_com):.2f}m (FLEE_DIST=7m)")
            print(f" min_com_to_pen mean={np.mean(ep_min_pen_dists):.2f}m "
                  f"best={np.min(ep_min_pen_dists):.2f}m")
            if rcomp_n > 0:
                print(f" reward/step (mean): " + " ".join(
                    f"{k}={rcomp_sums[k]/rcomp_n:+.4f}" for k in
                    ("progress", "alignment", "pen_bonus", "step_cost", "complete")
                ))

        # Stall detection: count consecutive checks whose (n_sheep, dominant
        # mode) pair repeats the previous one.  The abort needs the
        # abort_on_stall flag, enough repeats and enough elapsed steps.
        key = (n_sheep, dominant)
        if key == self._prev_dominant and dominant != "SUCCESS":
            self._stall_count += 1
            if (self.abort_on_stall
                    and self._stall_count >= self.stall_patience
                    and self.num_timesteps >= self.stall_min_steps):
                print(f"\n[Diag] STALL DETECTED — '{dominant}' on {n_sheep} sheep "
                      f"for {self._stall_count} consecutive checks. "
                      f"Aborting training early.")
                return False
        else:
            self._stall_count = 0
        self._prev_dominant = key

        return True
|
||||
def _draw_field(ax):
    """Draw the static arena backdrop on ``ax``.

    Sets the view to [-16, 16] on both axes with equal aspect, paints the
    background, outlines the 30 x 30 field and draws the labelled 3 x 7 pen
    rectangle anchored at (10, -15).
    """
    fence_color = "#795548"

    ax.set_xlim(-16, 16)
    ax.set_ylim(-16, 16)
    ax.set_aspect("equal")
    ax.set_facecolor("#dcedc8")

    field_outline = mpatches.Rectangle(
        (-15, -15), 30, 30, fill=False, edgecolor=fence_color, lw=2)
    pen_area = mpatches.Rectangle(
        (10, -15), 3, 7, facecolor="#ffe082", edgecolor=fence_color, lw=2)
    for patch in (field_outline, pen_area):
        ax.add_patch(patch)

    ax.text(11.5, -11.5, "pen", ha="center", va="center",
            fontsize=8, color=fence_color)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Environment factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_env(n_sheep: int, seed: int, max_steps: int, random_n_sheep: bool = False):
    """Return a zero-argument factory that builds one seeded HerdingEnv.

    Args:
        n_sheep: Sheep count passed to ``HerdingEnv``.
        seed: Seed used for the environment's first ``reset``.
        max_steps: Episode step limit passed to ``HerdingEnv``.
        random_n_sheep: Passed through to ``HerdingEnv`` unchanged.

    Returns:
        A callable that creates the environment, resets it once with ``seed``
        and returns it.
    """
    def _init():
        herding_env = HerdingEnv(
            n_sheep=n_sheep,
            max_steps=max_steps,
            random_n_sheep=random_n_sheep,
        )
        # Reset once so the seed is applied before the env is handed over.
        herding_env.reset(seed=seed)
        return herding_env

    return _init
|
||||
def _faded_path(ax, xs, ys, color, lw=1.5, label=None):
|
||||
n = len(xs)
|
||||
if n < 2:
|
||||
return
|
||||
points = np.array([xs, ys]).T.reshape(-1, 1, 2)
|
||||
segs = np.concatenate([points[:-1], points[1:]], axis=1)
|
||||
alphas = np.linspace(0.15, 1.0, len(segs))
|
||||
colors = [(*matplotlib.colors.to_rgb(color), a) for a in alphas]
|
||||
ax.add_collection(LineCollection(segs, colors=colors, linewidth=lw))
|
||||
if label:
|
||||
ax.plot([], [], color=color, lw=lw, label=label)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
def _snapshot_state(inner, n_sheep):
    """Copy the parts of the env state that ``run_and_record`` logs."""
    _, radius, _ = inner._flock_stats()
    sheep = [inner.sheep_pos[i] for i in range(n_sheep)]
    return {
        "dog": (float(inner.dog_pos[0]), float(inner.dog_pos[1])),
        "radius": radius,
        "sheep": [(float(p[0]), float(p[1])) for p in sheep],
        "pen_dists": [float(np.linalg.norm(p - inner.PEN_CENTER))
                      for p in sheep],
        "penned": [bool(inner.penned[i]) for i in range(n_sheep)],
    }


def run_and_record(model, vn_template, n_sheep, max_steps,
                   reward_cfg=None, seed=42):
    """Run one deterministic episode and return its full history.

    A fresh single-env VecNormalize is built for the episode.  Normalisation
    statistics are copied from ``vn_template`` and frozen (``training=False``,
    rewards not normalised).

    Args:
        model: Policy exposing ``predict(obs, deterministic=True)``.
        vn_template: VecNormalize whose ``obs_rms`` / ``ret_rms`` are copied.
        n_sheep: Sheep count of the episode.
        max_steps: Episode step limit.
        reward_cfg: Optional reward configuration forwarded to ``make_env``.
        seed: Seed forwarded to ``make_env``.

    Returns:
        dict with per-step lists ``dog_xs``, ``dog_ys``, ``radii``,
        ``action_mags`` and ``rewards``; per-sheep lists of per-step values
        ``sheep_xs``, ``sheep_ys`` and ``pen_dists``; ``penned_at`` (1-based
        step at which each sheep was first seen penned, or ``None``); and the
        scalars ``n_penned``, ``n_sheep``, ``success`` and ``steps``.
        Every per-step list has exactly ``steps`` entries.
    """
    raw = DummyVecEnv([make_env(n_sheep, seed, max_steps, reward_cfg)])
    vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
    vn.obs_rms = deepcopy(vn_template.obs_rms)
    vn.ret_rms = deepcopy(vn_template.ret_rms)

    dog_xs, dog_ys = [], []
    sheep_xs = [[] for _ in range(n_sheep)]
    sheep_ys = [[] for _ in range(n_sheep)]
    radii = []
    pen_dists = [[] for _ in range(n_sheep)]
    action_mags = []
    rewards = []
    penned_at = [None] * n_sheep
    step = 0
    done = False

    try:
        obs = vn.reset()
        inner = vn.envs[0]
        # Start state; only used if the episode ends on its very first step.
        state = _snapshot_state(inner, n_sheep)

        while not done:
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, dones, infos = vn.step(action)
            done = bool(dones[0])
            step += 1

            # DummyVecEnv resets a finished env inside step(), so after the
            # terminal step ``inner`` already holds the next episode's start
            # state.  Hold the last live state for that final sample instead
            # of logging a jump to a fresh spawn position.
            if not done:
                state = _snapshot_state(inner, n_sheep)

            dog_xs.append(state["dog"][0])
            dog_ys.append(state["dog"][1])
            radii.append(state["radius"])
            rewards.append(float(reward[0]))
            action_mags.append(float(np.linalg.norm(action[0])))

            for i in range(n_sheep):
                sheep_xs[i].append(state["sheep"][i][0])
                sheep_ys[i].append(state["sheep"][i][1])
                pen_dists[i].append(state["pen_dists"][i])
                # NOTE(review): a sheep penned on the terminal step keeps
                # penned_at == None, because the env state for that step is
                # no longer observable after the auto-reset.
                if state["penned"][i] and penned_at[i] is None:
                    penned_at[i] = step

        n_penned = infos[0].get("n_penned", 0)
    finally:
        # Release the temporary env even if the rollout raised.
        vn.close()

    return dict(
        dog_xs=dog_xs, dog_ys=dog_ys,
        sheep_xs=sheep_xs, sheep_ys=sheep_ys,
        radii=radii, pen_dists=pen_dists,
        action_mags=action_mags, rewards=rewards,
        penned_at=penned_at,
        n_penned=n_penned, n_sheep=n_sheep,
        success=n_penned == n_sheep, steps=step,
    )
|
||||
|
||||
|
||||
def plot_trajectory(hist, out_path):
    """Save a top-down plot of one recorded episode.

    Each sheep path is drawn with a fading line, a circle at its start and a
    star at the sample where it was first seen penned (or at its last sample
    if it never was).  The dog path gets a square at the start and a diamond
    at the end.

    Args:
        hist: History dict as returned by ``run_and_record``.
        out_path: Destination image path.
    """
    fig, ax = plt.subplots(figsize=(7, 7))
    _draw_field(ax)

    for i in range(hist["n_sheep"]):
        c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
        xs, ys = hist["sheep_xs"][i], hist["sheep_ys"][i]
        _faded_path(ax, xs, ys, c, lw=1.2, label=f"sheep {i+1}")
        ax.plot(xs[0], ys[0], "o", color=c, ms=7, zorder=4)

        # penned_at is a 1-based step count while the sample lists are
        # 0-based, so the matching sample is penned_at - 1.  Clamp so a value
        # equal to the list length cannot raise IndexError.
        penned_step = hist["penned_at"][i]
        if penned_step is None:
            end = len(xs) - 1
        else:
            end = min(max(penned_step - 1, 0), len(xs) - 1)
        ax.plot(xs[end], ys[end], "*", color=c, ms=11, zorder=5)

    _faded_path(ax, hist["dog_xs"], hist["dog_ys"], DOG_COLOR, lw=2.0,
                label="dog")
    ax.plot(hist["dog_xs"][0], hist["dog_ys"][0], "s", color=DOG_COLOR,
            ms=10, zorder=5)
    ax.plot(hist["dog_xs"][-1], hist["dog_ys"][-1], "D", color=DOG_COLOR,
            ms=10, zorder=5)

    result = ("SUCCESS" if hist["success"]
              else f"FAIL ({hist['n_penned']}/{hist['n_sheep']})")
    ax.set_title(f"n={hist['n_sheep']} {result} {hist['steps']} steps",
                 fontsize=12)
    ax.legend(loc="upper left", fontsize=8)
    plt.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)
|
||||
|
||||
|
||||
def plot_timeseries(hist, out_path):
    """Save a four-panel time-series figure for one recorded episode.

    Panels, top to bottom: flock radius, per-sheep distance to the pen,
    dog action magnitude and reward per step.

    Args:
        hist: History dict as returned by ``run_and_record``.
        out_path: Destination image path.
    """
    step_axis = np.arange(hist["steps"])
    fig, panels = plt.subplots(4, 1, figsize=(12, 10), sharex=True)
    radius_ax, pen_ax, action_ax, reward_ax = panels

    # Panel 1: flock radius against the compactness threshold.
    radius_ax.plot(step_axis, hist["radii"], color="steelblue")
    radius_ax.axhline(5.0, color="orange", ls="--", lw=1, label="compact (5m)")
    radius_ax.set_ylabel("flock radius (m)")
    radius_ax.legend(fontsize=8)
    radius_ax.set_title("Flock radius")

    # Panel 2: one curve per sheep, dotted marker where it was penned.
    # NOTE(review): penned_at is a 1-based step count while step_axis is
    # 0-based, so the marker sits one sample right of the matching point.
    n_sheep = hist["n_sheep"]
    for sheep_idx in range(n_sheep):
        shade = SHEEP_COLORS[sheep_idx % len(SHEEP_COLORS)]
        pen_ax.plot(step_axis, hist["pen_dists"][sheep_idx], color=shade, lw=1,
                    label=f"sheep {sheep_idx+1}")
        pen_step = hist["penned_at"][sheep_idx]
        if pen_step is not None:
            pen_ax.axvline(pen_step, color=shade, ls=":", lw=1)
    pen_ax.set_ylabel("dist to pen (m)")
    pen_ax.legend(fontsize=7, ncol=min(n_sheep, 5))
    pen_ax.set_title("Per-sheep distance to pen")

    # Panel 3: action magnitude with the unit-speed reference line.
    action_ax.plot(step_axis, hist["action_mags"], color="tomato", lw=1)
    action_ax.axhline(1.0, color="gray", ls="--", lw=1, label="max")
    action_ax.set_ylabel("action ||(vx,vy)||")
    action_ax.set_ylim(0, 1.5)
    action_ax.set_title("Dog action magnitude")
    action_ax.legend(fontsize=8)

    # Panel 4: per-step reward around a zero baseline.
    reward_ax.plot(step_axis, hist["rewards"], color="purple", lw=1, alpha=0.7)
    reward_ax.axhline(0, color="black", lw=0.5)
    reward_ax.set_ylabel("reward")
    reward_ax.set_xlabel("step")
    reward_ax.set_title("Reward per step")

    if hist["success"]:
        outcome = "SUCCESS"
    else:
        outcome = f"FAIL ({hist['n_penned']}/{hist['n_sheep']})"
    fig.suptitle(f"n_sheep={hist['n_sheep']} {outcome} {hist['steps']} steps",
                 fontsize=13)
    plt.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)
|
||||
|
||||
|
||||
def plot_success_rate(stage_results, out_path):
    """Save a bar chart of evaluation success rate per sheep count.

    Args:
        stage_results: Sequence of dicts, each with an ``"n_sheep"`` entry and
            an ``"sr"`` entry holding the success rate as a fraction.
        out_path: Destination image path.
    """
    sheep_counts = [entry["n_sheep"] for entry in stage_results]
    percentages = [entry["sr"] * 100 for entry in stage_results]

    fig, ax = plt.subplots(figsize=(8, 4))
    bars = ax.bar(sheep_counts, percentages, color="steelblue",
                  edgecolor="white")
    ax.set_xlabel("Sheep count")
    ax.set_ylabel("Success rate (%)")
    ax.set_ylim(0, 105)
    ax.axhline(90, color="orange", ls="--", lw=1, label="90% target")

    # Print each percentage just above its bar.
    for bar, pct in zip(bars, percentages):
        label_x = bar.get_x() + bar.get_width() / 2
        label_y = bar.get_height() + 1
        ax.text(label_x, label_y, f"{pct:.0f}%", ha="center", fontsize=9)

    ax.legend()
    ax.set_title("Evaluation success rate per sheep count")
    plt.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)
|
||||
|
||||
|
||||
# ── CLI ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
# Built-in defaults; main() copies this dict and overlays the JSON file given
# with --config.  Entries whose key is also an attribute of HerdingEnv are
# forwarded to the environment as its reward configuration; "ent_coef" is
# read by main() and passed to PPO.
# NOTE(review): the per-key notes below are inferred from the key names --
# confirm against HerdingEnv.
DEFAULT_CONFIG = {
    "W_PER_SHEEP": 2.0,         # presumably per-sheep progress weight
    "W_ALIGN": 0.05,            # presumably alignment shaping weight
    "W_PEN_BONUS": 10.0,        # presumably bonus per penned sheep
    "W_COMPLETE": 100.0,        # presumably bonus for penning every sheep
    "W_STEP_COST": 0.02,        # presumably per-step penalty
    "W_COMPACT": 0.0,           # presumably compactness weight (0 = off)
    "W_WALL_TOUCH": 0.15,       # presumably wall-contact penalty weight
    "WALL_TOUCH_BUFFER": 0.8,   # presumably wall-contact distance margin
    "ALIGN_SHAPE": "standoff",  # presumably selects the alignment variant
    "ALIGN_GATED": True,
    "ENTRY_AWARE": False,
    "ent_coef": 0.02,           # PPO entropy coefficient
}
|
||||
|
||||
|
||||
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Only one parser is built.  The earlier parser (``--n-sheep``,
    ``--curriculum``, ``--resume``, ...) was discarded unused when ``p`` was
    rebound, so none of its options ever reached the caller.

    Args:
        argv: Optional list of argument strings.  ``None`` (the default)
            parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``config``, ``max_sheep``, ``steps_per_stage``,
        ``n_envs``, ``max_steps``, ``eval_episodes`` and ``run_dir``.
    """
    p = argparse.ArgumentParser(
        description="PPO training for herding task with curriculum learning")
    p.add_argument("--config", type=str, default=None,
                   help="JSON config file (reward weights + ent_coef)")
    p.add_argument("--max-sheep", type=int, default=10)
    p.add_argument("--steps-per-stage", type=int, default=1_500_000)
    p.add_argument("--n-envs", type=int, default=8)
    p.add_argument("--max-steps", type=int, default=2500)
    p.add_argument("--eval-episodes", type=int, default=30)
    p.add_argument("--run-dir", type=str, default=None)
    return p.parse_args(argv)
|
||||
|
||||
|
||||
# ── Main ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
os.makedirs(args.run_dir, exist_ok=True)
|
||||
ckpt_dir = os.path.join(args.run_dir, "checkpoints")
|
||||
best_dir = os.path.join(args.run_dir, "best_model")
|
||||
norm_path = os.path.join(args.run_dir, "vecnorm.pkl")
|
||||
os.makedirs(ckpt_dir, exist_ok=True)
|
||||
|
||||
# Load config
|
||||
cfg = dict(DEFAULT_CONFIG)
|
||||
if args.config:
|
||||
with open(args.config) as f:
|
||||
cfg.update(json.load(f))
|
||||
|
||||
rcfg = {k: v for k, v in cfg.items() if hasattr(HerdingEnv, k)}
|
||||
|
||||
# Run directory
|
||||
run_dir = args.run_dir or os.path.join(
|
||||
"runs", time.strftime("%Y%m%d_%H%M%S"))
|
||||
eval_dir = os.path.join(run_dir, "eval")
|
||||
os.makedirs(eval_dir, exist_ok=True)
|
||||
with open(os.path.join(run_dir, "config.json"), "w") as f:
|
||||
json.dump(cfg, f, indent=2)
|
||||
|
||||
print(f"Config: {cfg}")
|
||||
print(f"Run dir: {run_dir}")
|
||||
print(f"Curriculum: 1 → {args.max_sheep} sheep, "
|
||||
f"{args.steps_per_stage:,} steps/stage\n")
|
||||
|
||||
# Training envs
|
||||
train_env = SubprocVecEnv([
|
||||
make_env(args.n_sheep, seed=i, max_steps=args.max_steps,
|
||||
random_n_sheep=args.mixed)
|
||||
make_env(1, seed=i, max_steps=args.max_steps, reward_cfg=rcfg)
|
||||
for i in range(args.n_envs)
|
||||
])
|
||||
if args.resume and os.path.exists(norm_path):
|
||||
train_env = VecNormalize.load(norm_path, train_env)
|
||||
train_env.training = True
|
||||
train_env.norm_reward = True
|
||||
else:
|
||||
train_env = VecNormalize(train_env, norm_obs=True, norm_reward=True,
|
||||
clip_obs=10.0)
|
||||
|
||||
# Eval env — starts at same difficulty, advances with curriculum callback
|
||||
eval_env = SubprocVecEnv([
|
||||
make_env(args.n_sheep, seed=1000 + i, max_steps=args.max_steps)
|
||||
for i in range(2)
|
||||
])
|
||||
eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False,
|
||||
clip_obs=10.0, training=False)
|
||||
|
||||
# Callbacks
|
||||
checkpoint_cb = CheckpointCallback(
|
||||
save_freq=max(args.save_freq // args.n_envs, 1),
|
||||
save_path=ckpt_dir,
|
||||
name_prefix="ckpt",
|
||||
save_vecnormalize=True,
|
||||
)
|
||||
eval_cb = EvalCallback(
|
||||
eval_env,
|
||||
best_model_save_path=best_dir,
|
||||
log_path=args.run_dir,
|
||||
eval_freq=max(args.eval_freq // args.n_envs, 1),
|
||||
n_eval_episodes=args.eval_eps,
|
||||
deterministic=True,
|
||||
verbose=1,
|
||||
)
|
||||
diag_cb = DiagnosticCallback(
|
||||
diag_freq=args.diag_freq,
|
||||
n_episodes=20,
|
||||
max_steps=args.max_steps,
|
||||
abort_on_stall=not args.no_stall_abort,
|
||||
)
|
||||
callbacks = [checkpoint_cb, eval_cb, diag_cb]
|
||||
|
||||
if args.curriculum:
|
||||
cur_cb = CurriculumCallback(
|
||||
start_sheep=args.n_sheep,
|
||||
max_sheep=args.max_sheep,
|
||||
eval_env=eval_env,
|
||||
steps_per_stage=args.steps_per_stage,
|
||||
threshold=args.threshold,
|
||||
)
|
||||
callbacks.append(cur_cb)
|
||||
|
||||
callback_list = CallbackList(callbacks)
|
||||
vn = VecNormalize(train_env, norm_obs=True, norm_reward=True,
|
||||
clip_obs=10.0)
|
||||
|
||||
# Model
|
||||
ppo_kwargs = dict(
|
||||
policy = "MlpPolicy",
|
||||
env = train_env,
|
||||
learning_rate = 3e-4,
|
||||
n_steps = 2048,
|
||||
batch_size = 256,
|
||||
n_epochs = 10,
|
||||
gamma = 0.995,
|
||||
gae_lambda = 0.95,
|
||||
clip_range = 0.2,
|
||||
ent_coef = 0.01,
|
||||
vf_coef = 0.5,
|
||||
max_grad_norm = 0.5,
|
||||
policy_kwargs = dict(net_arch=[256, 256]),
|
||||
tensorboard_log = args.run_dir,
|
||||
verbose = 1,
|
||||
model = PPO(
|
||||
"MlpPolicy", vn,
|
||||
learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10,
|
||||
gamma=0.995, gae_lambda=0.95, clip_range=0.2,
|
||||
ent_coef=cfg.get("ent_coef", 0.02), vf_coef=0.5, max_grad_norm=0.5,
|
||||
policy_kwargs=dict(net_arch=[256, 256]),
|
||||
verbose=0,
|
||||
)
|
||||
|
||||
if args.resume:
|
||||
print(f"Resuming from {args.resume}")
|
||||
model = PPO.load(args.resume, env=train_env, **{
|
||||
k: v for k, v in ppo_kwargs.items()
|
||||
if k not in ("policy", "env")
|
||||
})
|
||||
else:
|
||||
model = PPO(**ppo_kwargs)
|
||||
# Curriculum training
|
||||
stage_results = []
|
||||
t0 = time.time()
|
||||
|
||||
model.learn(
|
||||
total_timesteps=args.total_steps,
|
||||
callback=callback_list,
|
||||
reset_num_timesteps=args.resume is None,
|
||||
tb_log_name="ppo",
|
||||
)
|
||||
try:
|
||||
for n in range(1, args.max_sheep + 1):
|
||||
if n > 1:
|
||||
vn.env_method("set_n_sheep", n)
|
||||
|
||||
model.save(os.path.join(args.run_dir, "final_model"))
|
||||
train_env.save(norm_path)
|
||||
print(f"\nTraining complete. Artefacts saved to {args.run_dir}/")
|
||||
print(f"\n[Stage n_sheep={n}] training {args.steps_per_stage:,} steps")
|
||||
model.learn(
|
||||
total_timesteps=args.steps_per_stage,
|
||||
reset_num_timesteps=(n == 1),
|
||||
callback=ProgressCallback(f"{n} sheep", freq=100_000),
|
||||
)
|
||||
|
||||
# Evaluate
|
||||
print(f"[Stage n_sheep={n}] evaluating {args.eval_episodes} eps")
|
||||
r = evaluate(model, vn, n, args.eval_episodes, args.max_steps, rcfg)
|
||||
print(f"[Stage n_sheep={n}] sr={r['sr']*100:.0f}% "
|
||||
f"mean_len={r['mean_len']:.0f} "
|
||||
f"mean_min_pen={r['mean_min_pen']:.1f}m "
|
||||
f"mean_act={r['mean_act']:.2f}")
|
||||
|
||||
# Failure-mode breakdown
|
||||
if r["failure_modes"]:
|
||||
modes = " ".join(
|
||||
f"{k}={v}" for k, v in sorted(
|
||||
r["failure_modes"].items(), key=lambda x: -x[1]))
|
||||
print(f" failure modes: {modes}")
|
||||
|
||||
# Reward breakdown
|
||||
if "reward_per_step" in r:
|
||||
rps = r["reward_per_step"]
|
||||
print(f" reward/step: " + " ".join(
|
||||
f"{k}={v:+.4f}" for k, v in rps.items()))
|
||||
|
||||
# Episode visualization
|
||||
hist = run_and_record(model, vn, n, args.max_steps, rcfg,
|
||||
seed=1000 + n)
|
||||
tag = "success" if hist["success"] else "fail"
|
||||
plot_trajectory(
|
||||
hist,
|
||||
os.path.join(eval_dir, f"traj_{n}s_{tag}.png"))
|
||||
plot_timeseries(
|
||||
hist,
|
||||
os.path.join(eval_dir, f"ts_{n}s_{tag}.png"))
|
||||
|
||||
r["n_sheep"] = n
|
||||
stage_results.append(r)
|
||||
|
||||
# Save artefacts
|
||||
model.save(os.path.join(run_dir, "final_model"))
|
||||
vn.save(os.path.join(run_dir, "vecnorm.pkl"))
|
||||
with open(os.path.join(run_dir, "stage_results.json"), "w") as f:
|
||||
json.dump(stage_results, f, indent=2)
|
||||
|
||||
finally:
|
||||
try:
|
||||
vn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Summary
|
||||
elapsed = (time.time() - t0) / 60
|
||||
print("\n" + "=" * 70)
|
||||
print(" TRAINING SUMMARY")
|
||||
print("=" * 70)
|
||||
for r in stage_results:
|
||||
print(f" n_sheep={r['n_sheep']} sr={r['sr']*100:>3.0f}% "
|
||||
f"len={r['mean_len']:>5.0f} min_pen={r['mean_min_pen']:>5.1f}m "
|
||||
f"act={r['mean_act']:.2f}")
|
||||
print(f"\n Total time: {elapsed:.1f} min")
|
||||
print(f" Artefacts: {run_dir}/")
|
||||
|
||||
plot_success_rate(stage_results, os.path.join(run_dir, "success_rate.png"))
|
||||
print(f" Plots: {run_dir}/success_rate.png, {eval_dir}/")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user