Files
TIR_PROJ/training/train.py
T
2026-04-26 01:50:01 +01:00

531 lines
19 KiB
Python

"""
PPO training for the herding task with curriculum learning.

Trains from scratch through a 1 → max_sheep curriculum, evaluates after each
stage, and auto-generates trajectory/timeseries plots plus a summary chart.

Usage
-----
    python train.py                       # built-in defaults (DEFAULT_CONFIG)
    python train.py --config my_config.json --max-sheep 5
    python train.py --max-sheep 3 --steps-per-stage 1000000

Outputs (in runs/<timestamp>/, or in --run-dir if given):
    config.json          resolved config
    final_model.zip      trained PPO model
    vecnorm.pkl          VecNormalize statistics
    stage_results.json   per-stage evaluation metrics
    success_rate.png     summary bar chart
    eval/                trajectory & timeseries plots per sheep count
"""
import argparse
import json
import os
import time
from collections import deque
from copy import deepcopy

import matplotlib
matplotlib.use("Agg")  # non-interactive backend: figures are only saved to files
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
from matplotlib.collections import LineCollection
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import (
    DummyVecEnv,
    SubprocVecEnv,
    VecNormalize,
)

from herding_env import HerdingEnv
# ── Colours ──────────────────────────────────────────────────────────────────
# Per-sheep colours; plotting code indexes them modulo the list length, so
# they repeat when there are more than ten sheep.
SHEEP_COLORS = [
    "#e41a1c",
    "#377eb8",
    "#4daf4a",
    "#984ea3",
    "#ff7f00",
    "#a65628",
    "#f781bf",
    "#999999",
    "#66c2a5",
    "#fc8d62",
]
# Colour of the dog's path and its start/end markers.
DOG_COLOR = "#4e342e"
# ── Callbacks ────────────────────────────────────────────────────────────────
class ProgressCallback(BaseCallback):
    """Print a one-line progress summary every ``freq`` env steps.

    Each line shows the mean episode return and the success rate over the
    last ``window`` finished episodes, plus the cumulative success rate since
    this callback was created. Returns are summed from the ``rewards`` the
    learner receives, i.e. after any reward normalisation done by a
    ``VecNormalize`` wrapper. An episode counts as a success when the info of
    its final step reports ``n_penned == n_sheep``.

    Args:
        stage_label: Text shown at the start of every progress line.
        freq: Minimum number of env steps between two progress lines.
        window: Number of most recent episodes in the rolling statistics.
    """

    def __init__(self, stage_label: str, freq: int = 100_000,
                 window: int = 50):
        super().__init__()
        self.stage_label = stage_label
        self.freq = freq
        # Step count at the last printed line. Re-synced in
        # _on_training_start, because later curriculum stages continue the
        # model's timestep counter instead of restarting at zero.
        self._last = 0
        self._ep_returns = deque(maxlen=window)
        self._ep_success = deque(maxlen=window)
        self._total_eps = 0
        self._total_success = 0
        # Running return of the in-progress episode, one entry per env.
        self._cur_ret = None

    def _on_training_start(self) -> None:
        # With learn(reset_num_timesteps=False) the counter already holds the
        # steps of earlier stages; measure `freq` from here, otherwise the
        # first step of the stage would immediately print an empty summary.
        self._last = self.model.num_timesteps

    def _on_step(self) -> bool:
        rewards = self.locals.get("rewards")
        dones = self.locals.get("dones")
        infos = self.locals.get("infos", [])
        if rewards is None or dones is None:
            return True

        if self._cur_ret is None or len(self._cur_ret) != len(rewards):
            self._cur_ret = np.zeros(len(rewards), dtype=np.float64)
        self._cur_ret += np.asarray(rewards, dtype=np.float64)

        for i, d in enumerate(dones):
            if not d:
                continue
            info = infos[i] if i < len(infos) else {}
            success = int(info.get("n_penned", 0) == info.get("n_sheep", -1))
            # deque(maxlen=window) drops the oldest entry automatically.
            self._ep_returns.append(float(self._cur_ret[i]))
            self._ep_success.append(success)
            self._total_eps += 1
            self._total_success += success
            self._cur_ret[i] = 0.0

        if self.num_timesteps - self._last >= self.freq:
            self._last = self.num_timesteps
            n = len(self._ep_returns)
            mean_r = float(np.mean(self._ep_returns)) if n else float("nan")
            win_sr = float(np.mean(self._ep_success)) if n else float("nan")
            cum_sr = (self._total_success / self._total_eps
                      if self._total_eps else float("nan"))
            print(f" ... [{self.stage_label} | "
                  f"{self.num_timesteps:>7,} steps | "
                  f"ret(last {n})={mean_r:+.2f} "
                  f"win_sr={win_sr*100:.0f}% cum_sr={cum_sr*100:.0f}%]",
                  flush=True)
        return True
# ── Environment factory ──────────────────────────────────────────────────────
def make_env(n_sheep, seed, max_steps, reward_cfg=None):
    """Return a zero-argument factory that builds one seeded ``HerdingEnv``.

    The factory form is what ``DummyVecEnv`` / ``SubprocVecEnv`` expect: a
    list of callables, each creating one environment.

    Args:
        n_sheep: Sheep count passed to ``HerdingEnv``.
        seed: Seed used for the environment's first ``reset``.
        max_steps: Forwarded to ``HerdingEnv`` as ``max_steps``.
        reward_cfg: Optional dict forwarded unchanged as ``reward_cfg``.

    Returns:
        A callable that constructs, resets (with ``seed``) and returns the env.
    """
    env_kwargs = dict(n_sheep=n_sheep, max_steps=max_steps,
                      reward_cfg=reward_cfg)

    def _thunk():
        new_env = HerdingEnv(**env_kwargs)
        new_env.reset(seed=seed)
        return new_env

    return _thunk
# ── Failure-mode classification ──────────────────────────────────────────────
COMPACT_RADIUS = 5.0
def _classify(ep_radii, ep_com_dists, n_penned, n_sheep):
if n_penned == n_sheep:
return "SUCCESS"
if min(ep_radii) > COMPACT_RADIUS:
return "NEVER_COMPACT"
first = next(i for i, r in enumerate(ep_radii) if r <= COMPACT_RADIUS)
if min(ep_com_dists[first:]) > 3.0:
return "COMPACT_CANT_DRIVE"
if n_penned == 0:
return "DROVE_NO_SHEEP"
return f"PARTIAL_{n_penned}of{n_sheep}"
# ── Evaluation ───────────────────────────────────────────────────────────────
def evaluate(model, vn_template, n_sheep, n_episodes, max_steps,
             reward_cfg=None):
    """Evaluate the deterministic policy at a given sheep count.

    A fresh single-env ``VecNormalize`` (no reward normalisation, no
    statistic updates) is built around an env seeded with 9999, using copies
    of ``vn_template``'s normalisation statistics.

    The flock is sampled *before* each step. ``DummyVecEnv`` resets the env
    inside ``step()`` when an episode ends, so reading the env afterwards on
    the last step would describe the next episode's initial state. Per-episode
    histories therefore cover the initial state through the state before the
    final action.

    Args:
        model: Policy with ``predict(obs, deterministic=True)``.
        vn_template: VecNormalize whose ``obs_rms`` / ``ret_rms`` are copied.
        n_sheep: Sheep count to evaluate at.
        n_episodes: Number of evaluation episodes (must be >= 1).
        max_steps: Forwarded to the env.
        reward_cfg: Optional reward config forwarded to the env.

    Returns:
        dict with ``sr`` (success rate), ``mean_len`` (mean episode length in
        steps), ``mean_min_pen`` (mean over episodes of the closest
        flock-centre-to-pen distance), ``mean_act`` (mean action norm),
        ``failure_modes`` (label -> count, see ``_classify``) and, if the env
        reports ``rcomps`` in its infos, ``reward_per_step`` (mean of each
        component per reporting step).

    Raises:
        ValueError: if ``n_episodes`` < 1.
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")

    raw = DummyVecEnv([make_env(n_sheep, 9999, max_steps, reward_cfg)])
    vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)

    successes = 0
    ep_lens = []
    min_pen_list = []
    action_mags = []
    failure_counts = {}
    rc_sums = {}
    rc_n = 0
    try:
        vn.obs_rms = deepcopy(vn_template.obs_rms)
        vn.ret_rms = deepcopy(vn_template.ret_rms)
        inner = vn.envs[0]
        for _ in range(n_episodes):
            obs = vn.reset()
            done = False
            steps = 0
            mags = []
            ep_radii = []
            ep_com_dists = []
            while not done:
                # Pre-step sample: still belongs to the current episode.
                com, radius, _ = inner._flock_stats()
                ep_radii.append(radius)
                ep_com_dists.append(
                    float(np.linalg.norm(com - inner.PEN_CENTER)))

                action, _ = model.predict(obs, deterministic=True)
                obs, _, dones, infos = vn.step(action)
                done = dones[0]
                mags.append(float(np.linalg.norm(action[0])))
                steps += 1

                rc = infos[0].get("rcomps")
                if rc:
                    for k, v in rc.items():
                        # float() keeps the metrics JSON-serialisable even if
                        # the env reports NumPy scalars.
                        rc_sums[k] = rc_sums.get(k, 0.0) + float(v)
                    rc_n += 1

            # The terminal step's info still describes the finished episode.
            n_penned = infos[0].get("n_penned", 0)
            successes += int(n_penned == n_sheep)
            ep_lens.append(steps)
            min_pen_list.append(min(ep_com_dists))
            action_mags.extend(mags)
            mode = _classify(ep_radii, ep_com_dists, n_penned, n_sheep)
            failure_counts[mode] = failure_counts.get(mode, 0) + 1
    finally:
        vn.close()

    result = {
        "sr": successes / n_episodes,
        "mean_len": float(np.mean(ep_lens)),
        "mean_min_pen": float(np.mean(min_pen_list)),
        "mean_act": float(np.mean(action_mags)) if action_mags else 0.0,
        "failure_modes": failure_counts,
    }
    if rc_n > 0:
        result["reward_per_step"] = {k: v / rc_n for k, v in rc_sums.items()}
    return result
# ── Visualization helpers ────────────────────────────────────────────────────
def _draw_field(ax):
    """Draw the static arena onto ``ax``.

    The view is fixed to [-16, 16] on both axes with equal aspect over a
    light-green background. The 30 x 30 field outline is centred on the
    origin. The pen is a filled 3 x 7 rectangle anchored at (10, -15), with
    a "pen" label at its centre.
    """
    for set_limits in (ax.set_xlim, ax.set_ylim):
        set_limits(-16, 16)
    ax.set_aspect("equal")
    ax.set_facecolor("#dcedc8")

    wall_colour = "#795548"
    field_outline = mpatches.Rectangle((-15, -15), 30, 30,
                                       fill=False, edgecolor=wall_colour, lw=2)
    pen_area = mpatches.Rectangle((10, -15), 3, 7,
                                  facecolor="#ffe082", edgecolor=wall_colour,
                                  lw=2)
    ax.add_patch(field_outline)
    ax.add_patch(pen_area)
    ax.text(11.5, -11.5, "pen", ha="center", va="center",
            fontsize=8, color=wall_colour)
def _faded_path(ax, xs, ys, color, lw=1.5, label=None):
    """Draw a polyline whose opacity ramps from faint (start) to opaque (end).

    The path is split into ``len(xs) - 1`` segments. Their alpha rises
    linearly from 0.15 to 1.0, so the direction of travel is visible. Paths
    with fewer than two points are skipped. When ``label`` is given, an empty
    proxy line is plotted so the path gets a legend entry.
    """
    if len(xs) < 2:
        return
    vertices = np.column_stack([xs, ys]).reshape(-1, 1, 2)
    segments = np.concatenate((vertices[:-1], vertices[1:]), axis=1)
    rgb = matplotlib.colors.to_rgb(color)
    seg_colours = [(*rgb, alpha)
                   for alpha in np.linspace(0.15, 1.0, len(segments))]
    ax.add_collection(LineCollection(segments, colors=seg_colours,
                                     linewidth=lw))
    if label:
        # LineCollection is not picked up with this label; use a proxy line.
        ax.plot([], [], color=color, lw=lw, label=label)
def run_and_record(model, vn_template, n_sheep, max_steps,
                   reward_cfg=None, seed=42):
    """Run one deterministic episode and return its full per-step history.

    Per-step entries hold the state the action was taken from. Index ``k`` is
    the state after ``k`` steps (index 0 is the state right after reset),
    paired with the action magnitude and reward of step ``k + 1``.

    The state is sampled *before* stepping. ``DummyVecEnv`` resets the env
    inside ``step()`` when the episode ends, so reading it afterwards on the
    final step would record the next episode's spawn positions. As a
    consequence, the terminal state itself is not recorded.

    Returns:
        dict with ``dog_xs``/``dog_ys``, per-sheep ``sheep_xs``/``sheep_ys``,
        ``radii``, per-sheep ``pen_dists``, ``action_mags``, ``rewards``
        (all of length ``steps``), ``penned_at``, ``n_penned``, ``n_sheep``,
        ``success`` and ``steps``.

        ``penned_at[i]`` is the step count after which sheep ``i`` was first
        seen penned; it is also the history index of that state. It is
        ``None`` if the sheep was not seen penned before the final step.
    """
    raw = DummyVecEnv([make_env(n_sheep, seed, max_steps, reward_cfg)])
    vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)

    dog_xs, dog_ys = [], []
    sheep_xs = [[] for _ in range(n_sheep)]
    sheep_ys = [[] for _ in range(n_sheep)]
    radii = []
    pen_dists = [[] for _ in range(n_sheep)]
    action_mags = []
    rewards = []
    penned_at = [None] * n_sheep
    step = 0
    try:
        vn.obs_rms = deepcopy(vn_template.obs_rms)
        vn.ret_rms = deepcopy(vn_template.ret_rms)
        obs = vn.reset()
        inner = vn.envs[0]
        done = False
        while not done:
            # Pre-step snapshot: always belongs to the current episode.
            dog_xs.append(float(inner.dog_pos[0]))
            dog_ys.append(float(inner.dog_pos[1]))
            _, radius, _ = inner._flock_stats()
            radii.append(radius)
            for i in range(n_sheep):
                sheep_xs[i].append(float(inner.sheep_pos[i][0]))
                sheep_ys[i].append(float(inner.sheep_pos[i][1]))
                pen_dists[i].append(
                    float(np.linalg.norm(inner.sheep_pos[i] - inner.PEN_CENTER)))

            action, _ = model.predict(obs, deterministic=True)
            obs, reward, dones, infos = vn.step(action)
            done = dones[0]
            step += 1
            rewards.append(float(reward[0]))
            action_mags.append(float(np.linalg.norm(action[0])))

            if not done:
                # After the final step the env has been reset, so its penned
                # flags no longer belong to this episode.
                for i in range(n_sheep):
                    if inner.penned[i] and penned_at[i] is None:
                        penned_at[i] = step
        # The terminal step's info still describes the finished episode.
        n_penned = infos[0].get("n_penned", 0)
    finally:
        vn.close()

    return dict(
        dog_xs=dog_xs, dog_ys=dog_ys,
        sheep_xs=sheep_xs, sheep_ys=sheep_ys,
        radii=radii, pen_dists=pen_dists,
        action_mags=action_mags, rewards=rewards,
        penned_at=penned_at,
        n_penned=n_penned, n_sheep=n_sheep,
        success=n_penned == n_sheep, steps=step,
    )
def plot_trajectory(hist, out_path):
    """Save a top-down plot of one recorded episode to ``out_path``.

    ``hist`` is a history dict as returned by ``run_and_record``. Each sheep
    is drawn as a faded path, with a circle at its first recorded position
    and a star at the entry indexed by its ``penned_at`` value (its last
    entry if that is None). The dog's path has a square at the start and a
    diamond at the end. The title reports the outcome and episode length.
    """
    n_sheep = hist["n_sheep"]
    fig, ax = plt.subplots(figsize=(7, 7))
    _draw_field(ax)

    for idx in range(n_sheep):
        colour = SHEEP_COLORS[idx % len(SHEEP_COLORS)]
        xs = hist["sheep_xs"][idx]
        ys = hist["sheep_ys"][idx]
        _faded_path(ax, xs, ys, colour, lw=1.2, label=f"sheep {idx+1}")
        ax.plot(xs[0], ys[0], "o", color=colour, ms=7, zorder=4)
        pen_step = hist["penned_at"][idx]
        marker_idx = -1 if pen_step is None else pen_step
        ax.plot(xs[marker_idx], ys[marker_idx], "*", color=colour, ms=11,
                zorder=5)

    dog_xs, dog_ys = hist["dog_xs"], hist["dog_ys"]
    _faded_path(ax, dog_xs, dog_ys, DOG_COLOR, lw=2.0, label="dog")
    ax.plot(dog_xs[0], dog_ys[0], "s", color=DOG_COLOR, ms=10, zorder=5)
    ax.plot(dog_xs[-1], dog_ys[-1], "D", color=DOG_COLOR, ms=10, zorder=5)

    if hist["success"]:
        outcome = "SUCCESS"
    else:
        outcome = f"FAIL ({hist['n_penned']}/{hist['n_sheep']})"
    ax.set_title(f"n={hist['n_sheep']} {outcome} {hist['steps']} steps",
                 fontsize=12)
    ax.legend(loc="upper left", fontsize=8)
    plt.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)
def plot_timeseries(hist, out_path):
    """Save a four-panel time-series figure of one episode to ``out_path``.

    The panels share the step axis and show:
    1. the flock radius, with the ``COMPACT_RADIUS`` threshold used by
       failure-mode classification;
    2. each sheep's distance to the pen, with a dotted vertical line at its
       ``penned_at`` step;
    3. the dog's action magnitude;
    4. the per-step reward.

    ``hist`` is a history dict as returned by ``run_and_record``. Every
    per-step series must have ``hist["steps"]`` entries.
    """
    t = np.arange(hist["steps"])
    fig, axes = plt.subplots(4, 1, figsize=(12, 10), sharex=True)

    axes[0].plot(t, hist["radii"], color="steelblue")
    # Same threshold as _classify, so the plot agrees with the failure labels.
    axes[0].axhline(COMPACT_RADIUS, color="orange", ls="--", lw=1,
                    label=f"compact ({COMPACT_RADIUS:g}m)")
    axes[0].set_ylabel("flock radius (m)")
    axes[0].legend(fontsize=8)
    axes[0].set_title("Flock radius")

    for i in range(hist["n_sheep"]):
        c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
        axes[1].plot(t, hist["pen_dists"][i], color=c, lw=1,
                     label=f"sheep {i+1}")
        if hist["penned_at"][i] is not None:
            axes[1].axvline(hist["penned_at"][i], color=c, ls=":", lw=1)
    axes[1].set_ylabel("dist to pen (m)")
    axes[1].legend(fontsize=7, ncol=min(hist["n_sheep"], 5))
    axes[1].set_title("Per-sheep distance to pen")

    axes[2].plot(t, hist["action_mags"], color="tomato", lw=1)
    axes[2].axhline(1.0, color="gray", ls="--", lw=1, label="max")
    axes[2].set_ylabel("action ||(vx,vy)||")
    axes[2].set_ylim(0, 1.5)
    axes[2].set_title("Dog action magnitude")
    axes[2].legend(fontsize=8)

    axes[3].plot(t, hist["rewards"], color="purple", lw=1, alpha=0.7)
    axes[3].axhline(0, color="black", lw=0.5)
    axes[3].set_ylabel("reward")
    axes[3].set_xlabel("step")
    axes[3].set_title("Reward per step")

    result = ("SUCCESS" if hist["success"]
              else f"FAIL ({hist['n_penned']}/{hist['n_sheep']})")
    fig.suptitle(f"n_sheep={hist['n_sheep']} {result} {hist['steps']} steps",
                 fontsize=13)
    plt.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)
def plot_success_rate(stage_results, out_path):
    """Save a bar chart of evaluation success rate per sheep count.

    ``stage_results`` is a list of metric dicts, each with at least
    ``n_sheep`` and ``sr`` (success fraction in [0, 1]). Each bar is labelled
    with its percentage, and a dashed line marks the 90 % target.
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    counts = []
    percents = []
    for stage in stage_results:
        counts.append(stage["n_sheep"])
        percents.append(stage["sr"] * 100)

    bar_container = ax.bar(counts, percents, color="steelblue",
                           edgecolor="white")
    ax.set_xlabel("Sheep count")
    ax.set_ylabel("Success rate (%)")
    ax.set_ylim(0, 105)
    ax.axhline(90, color="orange", ls="--", lw=1, label="90% target")
    for rect, pct in zip(bar_container, percents):
        x_mid = rect.get_x() + rect.get_width() / 2
        ax.text(x_mid, rect.get_height() + 1, f"{pct:.0f}%",
                ha="center", fontsize=9)
    ax.legend()
    ax.set_title("Evaluation success rate per sheep count")
    plt.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)
# ── CLI ──────────────────────────────────────────────────────────────────────
# Built-in run configuration; a JSON file passed via --config overrides
# individual keys. In main, keys that name an attribute of HerdingEnv are
# forwarded to the env as its reward config, and ``ent_coef`` is passed to PPO.
DEFAULT_CONFIG = dict(
    # Reward weights and shaping switches (interpreted by HerdingEnv when it
    # defines the matching attribute).
    W_PER_SHEEP=2.0,
    W_ALIGN=0.05,
    W_PEN_BONUS=10.0,
    W_COMPLETE=100.0,
    W_STEP_COST=0.02,
    W_COMPACT=0.0,
    W_WALL_TOUCH=0.15,
    WALL_TOUCH_BUFFER=0.8,
    ALIGN_SHAPE="standoff",
    ALIGN_GATED=True,
    ENTRY_AWARE=False,
    # PPO entropy coefficient.
    ent_coef=0.02,
)
def parse_args(argv=None):
    """Parse the command-line options.

    Count-like options must be positive integers and are rejected at parse
    time. Without this check, e.g. ``--eval-episodes 0`` would only fail, with
    a division by zero, after a whole training stage had run.

    Args:
        argv: Argument list to parse; ``None`` (default) means
            ``sys.argv[1:]``, as with ``argparse``.

    Returns:
        argparse.Namespace with ``config``, ``max_sheep``,
        ``steps_per_stage``, ``n_envs``, ``max_steps``, ``eval_episodes``
        and ``run_dir``.
    """
    def positive_int(text):
        # int() raising ValueError makes argparse report "invalid value".
        value = int(text)
        if value < 1:
            raise argparse.ArgumentTypeError(
                f"must be a positive integer, got {value}")
        return value

    p = argparse.ArgumentParser(
        description="PPO training for herding task with curriculum learning")
    p.add_argument("--config", type=str, default=None,
                   help="JSON config file (reward weights + ent_coef)")
    p.add_argument("--max-sheep", type=positive_int, default=10,
                   help="train stages for 1..MAX_SHEEP sheep "
                        "(default: %(default)s)")
    p.add_argument("--steps-per-stage", type=positive_int, default=1_500_000,
                   help="env steps per curriculum stage "
                        "(default: %(default)s)")
    p.add_argument("--n-envs", type=positive_int, default=8,
                   help="parallel training environments "
                        "(default: %(default)s)")
    p.add_argument("--max-steps", type=positive_int, default=2500,
                   help="max_steps passed to the env (default: %(default)s)")
    p.add_argument("--eval-episodes", type=positive_int, default=30,
                   help="evaluation episodes after each stage "
                        "(default: %(default)s)")
    p.add_argument("--run-dir", type=str, default=None,
                   help="output directory (default: runs/<timestamp>)")
    return p.parse_args(argv)
# ── Main ─────────────────────────────────────────────────────────────────────
def _dump_stage_results(run_dir, stage_results):
    """Write the per-stage metrics gathered so far to stage_results.json."""
    with open(os.path.join(run_dir, "stage_results.json"), "w") as f:
        json.dump(stage_results, f, indent=2)


def main():
    """Run the full curriculum and write all artefacts.

    Trains one PPO model through stages of 1..max_sheep sheep. After each
    stage it evaluates, prints a report and saves example-episode plots.
    At the end it saves the model, VecNormalize statistics and a summary
    chart. stage_results.json is rewritten after every stage, so metrics of
    completed stages survive a crash in a later stage.
    """
    args = parse_args()

    # Load config: built-in defaults, optionally overridden by a JSON file.
    cfg = dict(DEFAULT_CONFIG)
    if args.config:
        with open(args.config) as f:
            cfg.update(json.load(f))
    # Only keys that HerdingEnv defines as attributes go to the env; the
    # full cfg is still saved below.
    rcfg = {k: v for k, v in cfg.items() if hasattr(HerdingEnv, k)}

    # Run directory
    run_dir = args.run_dir or os.path.join(
        "runs", time.strftime("%Y%m%d_%H%M%S"))
    eval_dir = os.path.join(run_dir, "eval")
    os.makedirs(eval_dir, exist_ok=True)
    with open(os.path.join(run_dir, "config.json"), "w") as f:
        json.dump(cfg, f, indent=2)
    print(f"Config: {cfg}")
    print(f"Run dir: {run_dir}")
    print(f"Curriculum: 1 → {args.max_sheep} sheep, "
          f"{args.steps_per_stage:,} steps/stage\n")

    # Training envs: created with one sheep and reused by every stage.
    train_env = SubprocVecEnv([
        make_env(1, seed=i, max_steps=args.max_steps, reward_cfg=rcfg)
        for i in range(args.n_envs)
    ])
    vn = VecNormalize(train_env, norm_obs=True, norm_reward=True,
                      clip_obs=10.0)

    # Model
    model = PPO(
        "MlpPolicy", vn,
        learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10,
        gamma=0.995, gae_lambda=0.95, clip_range=0.2,
        ent_coef=cfg.get("ent_coef", 0.02), vf_coef=0.5, max_grad_norm=0.5,
        policy_kwargs=dict(net_arch=[256, 256]),
        verbose=0,
    )

    # Curriculum training
    stage_results = []
    t0 = time.time()
    try:
        for n in range(1, args.max_sheep + 1):
            if n > 1:
                # Call HerdingEnv.set_n_sheep(n) in every training env.
                vn.env_method("set_n_sheep", n)
            print(f"\n[Stage n_sheep={n}] training {args.steps_per_stage:,} steps")
            model.learn(
                total_timesteps=args.steps_per_stage,
                # Later stages continue the timestep counter of stage 1.
                reset_num_timesteps=(n == 1),
                callback=ProgressCallback(f"{n} sheep", freq=100_000),
            )

            # Evaluate
            print(f"[Stage n_sheep={n}] evaluating {args.eval_episodes} eps")
            r = evaluate(model, vn, n, args.eval_episodes, args.max_steps, rcfg)
            print(f"[Stage n_sheep={n}] sr={r['sr']*100:.0f}% "
                  f"mean_len={r['mean_len']:.0f} "
                  f"mean_min_pen={r['mean_min_pen']:.1f}m "
                  f"mean_act={r['mean_act']:.2f}")

            # Failure-mode breakdown, most frequent first
            if r["failure_modes"]:
                modes = " ".join(
                    f"{k}={v}" for k, v in sorted(
                        r["failure_modes"].items(), key=lambda x: -x[1]))
                print(f" failure modes: {modes}")

            # Reward breakdown
            if "reward_per_step" in r:
                rps = r["reward_per_step"]
                print(" reward/step: " + " ".join(
                    f"{k}={v:+.4f}" for k, v in rps.items()))

            # Episode visualization
            hist = run_and_record(model, vn, n, args.max_steps, rcfg,
                                  seed=1000 + n)
            tag = "success" if hist["success"] else "fail"
            plot_trajectory(
                hist,
                os.path.join(eval_dir, f"traj_{n}s_{tag}.png"))
            plot_timeseries(
                hist,
                os.path.join(eval_dir, f"ts_{n}s_{tag}.png"))

            r["n_sheep"] = n
            stage_results.append(r)
            # Persist after every stage so a crash in a later stage keeps
            # the metrics of the stages already completed.
            _dump_stage_results(run_dir, stage_results)

        # Save artefacts
        model.save(os.path.join(run_dir, "final_model"))
        vn.save(os.path.join(run_dir, "vecnorm.pkl"))
        _dump_stage_results(run_dir, stage_results)
    finally:
        # Best-effort shutdown of the worker processes; must not mask an
        # exception raised during training.
        try:
            vn.close()
        except Exception:
            pass

    # Summary
    elapsed = (time.time() - t0) / 60
    print("\n" + "=" * 70)
    print(" TRAINING SUMMARY")
    print("=" * 70)
    for r in stage_results:
        print(f" n_sheep={r['n_sheep']} sr={r['sr']*100:>3.0f}% "
              f"len={r['mean_len']:>5.0f} min_pen={r['mean_min_pen']:>5.1f}m "
              f"act={r['mean_act']:.2f}")
    print(f"\n Total time: {elapsed:.1f} min")
    print(f" Artefacts: {run_dir}/")
    plot_success_rate(stage_results, os.path.join(run_dir, "success_rate.png"))
    print(f" Plots: {run_dir}/success_rate.png, {eval_dir}/")


if __name__ == "__main__":
    main()