Files
TIR_PROJ/training/train.py
T
2026-04-26 14:55:13 +01:00

369 lines
13 KiB
Python

"""
PPO training for the herding task with curriculum learning.
Trains from scratch through a 1→max_sheep curriculum, evaluates after each
stage, and auto-generates trajectory/timeseries plots plus a summary chart.
Usage
-----
    python train.py    # built-in defaults, overridden by ./config.json if present
    python train.py --config my_config.json --max-sheep 5
    python train.py --max-sheep 3 --steps-per-stage 1000000

Outputs (in runs/<timestamp>/, or in --run-dir when given):
    config.json          resolved config
    final_model.zip      trained PPO model
    vecnorm.pkl          VecNormalize statistics
    stage_results.json   per-stage evaluation metrics
    success_rate.png     summary bar chart
    eval/                trajectory & timeseries plots (plus GIFs unless
                         --no-gif) per sheep count
"""
import argparse
import json
import os
import time
from collections import deque
from copy import deepcopy

import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import (
    DummyVecEnv,
    SubprocVecEnv,
    VecNormalize,
)

from herding_env import HerdingEnv
from viz import (
    run_and_record,
    plot_trajectory,
    plot_timeseries,
    plot_success_rate,
    save_episode_gif,
)
# ── Callbacks ────────────────────────────────────────────────────────────────
class ProgressCallback(BaseCallback):
    """Print a one-line progress summary every ``freq`` environment steps.

    The callback accumulates a running return per parallel environment.
    Whenever an environment reports ``done`` it records that episode's
    return and whether it was a success, i.e. whether the final ``info``
    dict reports ``n_penned == n_sheep``.  Two success rates are printed:
    one over the last ``window`` episodes and one over every episode seen
    by this callback instance.

    Args:
        stage_label: Text shown at the start of each progress line.
        freq: Minimum number of timesteps between two progress lines.
        window: Number of most recent episodes kept for the windowed
            return / success-rate figures.
    """

    def __init__(self, stage_label: str, freq: int = 100_000,
                 window: int = 50):
        super().__init__()
        self.stage_label = stage_label
        self.freq = freq
        self.window = window
        self._last = 0
        # deque(maxlen=...) drops the oldest entry on overflow, so both
        # windows stay bounded and aligned without manual trimming.
        self._ep_returns = deque(maxlen=window)
        self._ep_success = deque(maxlen=window)
        self._total_eps = 0
        self._total_success = 0
        self._cur_ret = None

    def _on_training_start(self) -> None:
        """Measure ``freq`` from the timestep at which this run starts.

        ``num_timesteps`` keeps counting across ``learn()`` calls made with
        ``reset_num_timesteps=False``.  Without this baseline a callback
        created for a later curriculum stage would print a summary on its
        very first step, before any episode has finished.
        """
        self._last = self.num_timesteps

    def _on_step(self) -> bool:
        """Update episode statistics; always returns True (never stops)."""
        rewards = self.locals.get("rewards")
        dones = self.locals.get("dones")
        infos = self.locals.get("infos", [])
        if rewards is None or dones is None:
            return True

        # (Re)allocate the per-env accumulator if the env count changed.
        if self._cur_ret is None or len(self._cur_ret) != len(rewards):
            self._cur_ret = np.zeros(len(rewards), dtype=np.float64)
        # NOTE(review): when the training env is a VecNormalize with
        # norm_reward=True these are presumably the normalised rewards, so
        # the printed return is not on the raw reward scale - confirm.
        self._cur_ret += np.asarray(rewards, dtype=np.float64)

        for i, d in enumerate(dones):
            if not d:
                continue
            self._ep_returns.append(float(self._cur_ret[i]))
            info = infos[i] if i < len(infos) else {}
            # Differing defaults (0 vs -1) make a missing key count as failure.
            success = int(info.get("n_penned", 0) == info.get("n_sheep", -1))
            self._ep_success.append(success)
            self._total_eps += 1
            self._total_success += success
            self._cur_ret[i] = 0.0

        if self.num_timesteps - self._last >= self.freq:
            self._last = self.num_timesteps
            n = len(self._ep_returns)
            mean_r = sum(self._ep_returns) / n if n else float("nan")
            win_sr = sum(self._ep_success) / n if n else float("nan")
            cum_sr = (self._total_success / self._total_eps
                      if self._total_eps else float("nan"))
            print(f" ... [{self.stage_label} | "
                  f"{self.num_timesteps:>7,} steps | "
                  f"ret(last {n})={mean_r:+.2f} "
                  f"win_sr={win_sr*100:.0f}% cum_sr={cum_sr*100:.0f}%]",
                  flush=True)
        return True
# ── Environment factory ──────────────────────────────────────────────────────
def make_env(n_sheep, seed, max_steps, reward_cfg=None):
    """Return a zero-argument factory that builds one seeded ``HerdingEnv``.

    Nothing is constructed until the returned callable is invoked; the
    vectorised-env constructors used in this file take a list of such
    callables.

    Args:
        n_sheep: Forwarded to ``HerdingEnv`` as ``n_sheep``.
        seed: Seed passed to the environment's first ``reset``.
        max_steps: Forwarded to ``HerdingEnv`` as ``max_steps``.
        reward_cfg: Forwarded to ``HerdingEnv`` as ``reward_cfg``.

    Returns:
        A callable that creates the environment, resets it once with
        ``seed`` and returns it.
    """
    def _build():
        herding_env = HerdingEnv(
            n_sheep=n_sheep,
            max_steps=max_steps,
            reward_cfg=reward_cfg,
        )
        # One seeded reset up front, before the env is handed to the caller.
        herding_env.reset(seed=seed)
        return herding_env

    return _build
# ── Failure-mode classification ──────────────────────────────────────────────
COMPACT_RADIUS = 5.0
def _classify(ep_radii, ep_com_dists, n_penned, n_sheep):
if n_penned == n_sheep:
return "SUCCESS"
if min(ep_radii) > COMPACT_RADIUS:
return "NEVER_COMPACT"
first = next(i for i, r in enumerate(ep_radii) if r <= COMPACT_RADIUS)
if min(ep_com_dists[first:]) > 3.0:
return "COMPACT_CANT_DRIVE"
if n_penned == 0:
return "DROVE_NO_SHEEP"
return f"PARTIAL_{n_penned}of{n_sheep}"
# ── Evaluation ───────────────────────────────────────────────────────────────
def evaluate(model, vn_template, n_sheep, n_episodes, max_steps,
             reward_cfg=None):
    """Roll out ``model`` deterministically and summarise the episodes.

    A fresh single-environment ``DummyVecEnv`` (seed 9999) is wrapped in a
    non-training ``VecNormalize`` that receives copies of ``vn_template``'s
    ``obs_rms`` / ``ret_rms``, so evaluation uses the training statistics
    without updating them.

    Args:
        model: Policy exposing ``predict(obs, deterministic=True)``.
        vn_template: ``VecNormalize`` whose running statistics are copied.
        n_sheep: Sheep count for the evaluation environment.
        n_episodes: Number of episodes to run; must be at least 1.
        max_steps: Forwarded to the environment factory.
        reward_cfg: Forwarded to the environment factory.

    Returns:
        Dict with ``sr`` (success rate), ``mean_len`` (steps per episode),
        ``mean_min_pen`` (mean over episodes of the smallest flock-COM to
        pen-centre distance), ``mean_act`` (mean action norm over all
        steps), ``failure_modes`` (label -> episode count) and, only when
        the environment reports ``rcomps`` in ``info``, ``reward_per_step``
        (mean of each reward component per step).

    Raises:
        ValueError: If ``n_episodes`` is smaller than 1.
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")

    raw = DummyVecEnv([make_env(n_sheep, 9999, max_steps, reward_cfg)])
    vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)

    successes = 0
    ep_lens = []
    min_pen_list = []
    action_mags = []
    failure_counts = {}
    rc_sums = {}
    rc_n = 0
    try:
        vn.obs_rms = deepcopy(vn_template.obs_rms)
        vn.ret_rms = deepcopy(vn_template.ret_rms)
        inner = vn.envs[0]

        for _ in range(n_episodes):
            obs = vn.reset()
            done = False
            steps = 0
            min_pen = float("inf")
            ep_radii = []
            ep_com_dists = []
            while not done:
                # Sample the flock BEFORE stepping: SB3 vectorised envs
                # reset a sub-env as soon as it reports done, so after the
                # terminal step `inner` already holds the next episode's
                # initial state.  Reading here keeps every sample inside
                # the current episode.
                com, radius, _ = inner._flock_stats()
                pen_dist = float(np.linalg.norm(com - inner.PEN_CENTER))
                min_pen = min(min_pen, pen_dist)
                ep_radii.append(radius)
                ep_com_dists.append(pen_dist)

                action, _ = model.predict(obs, deterministic=True)
                obs, _, dones, infos = vn.step(action)
                done = dones[0]
                action_mags.append(float(np.linalg.norm(action[0])))
                steps += 1

                rc = infos[0].get("rcomps")
                if rc:
                    for k, v in rc.items():
                        rc_sums[k] = rc_sums.get(k, 0.0) + v
                    rc_n += 1

            # `infos` is the info of the terminal step at this point.
            n_penned = infos[0].get("n_penned", 0)
            successes += int(n_penned == n_sheep)
            ep_lens.append(steps)
            min_pen_list.append(min_pen)
            mode = _classify(ep_radii, ep_com_dists, n_penned, n_sheep)
            failure_counts[mode] = failure_counts.get(mode, 0) + 1
    finally:
        # Release the evaluation env even if a rollout raised.
        vn.close()

    result = {
        "sr": successes / n_episodes,
        "mean_len": float(np.mean(ep_lens)),
        "mean_min_pen": float(np.mean(min_pen_list)),
        "mean_act": float(np.mean(action_mags)) if action_mags else 0.0,
        "failure_modes": failure_counts,
    }
    if rc_n > 0:
        result["reward_per_step"] = {k: v / rc_n for k, v in rc_sums.items()}
    return result
# ── CLI ──────────────────────────────────────────────────────────────────────
# Built-in configuration used when no JSON config overrides it.  In main(),
# entries whose name is also an attribute of HerdingEnv are forwarded to the
# environment as reward_cfg; `ent_coef` is read separately for PPO.
DEFAULT_CONFIG = dict(
    W_PER_SHEEP=2.0,
    W_ALIGN=0.05,
    W_PEN_BONUS=10.0,
    W_COMPLETE=100.0,
    W_STEP_COST=0.02,
    W_SOUTH=0.01,
    W_COMPACT=0.0,
    W_WALL_TOUCH=0.04,
    WALL_TOUCH_BUFFER=0.3,
    ALIGN_SHAPE="standoff",
    ALIGN_GATED=True,
    ENTRY_AWARE=True,
    ent_coef=0.02,
)
def _positive_int(text):
    """argparse ``type=`` converter that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {value}")
    return value


def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings.  ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        ``argparse.Namespace`` with ``config``, ``max_sheep``,
        ``steps_per_stage``, ``n_envs``, ``max_steps``, ``eval_episodes``,
        ``run_dir``, ``no_gif``, ``gif_fps`` and ``gif_skip``.
    """
    p = argparse.ArgumentParser(
        description="PPO training for herding task with curriculum learning")
    p.add_argument("--config", type=str, default=None,
                   help="JSON config file (reward weights + ent_coef)")
    p.add_argument("--max-sheep", type=int, default=10)
    p.add_argument("--steps-per-stage", type=int, default=1_500_000)
    # Validated up front: an empty set of training envs cannot be built.
    p.add_argument("--n-envs", type=_positive_int, default=8)
    p.add_argument("--max-steps", type=int, default=2500)
    # Validated up front: evaluation divides by the episode count, so 0
    # would only fail after a whole training stage has already run.
    p.add_argument("--eval-episodes", type=_positive_int, default=30)
    p.add_argument("--run-dir", type=str, default=None)
    p.add_argument("--no-gif", action="store_true",
                   help="Skip per-stage GIF rendering (PNGs still produced).")
    p.add_argument("--gif-fps", type=int, default=20)
    p.add_argument("--gif-skip", type=int, default=3,
                   help="Keep every Nth frame (smaller GIF; default 3).")
    return p.parse_args(argv)
# ── Main ─────────────────────────────────────────────────────────────────────
def _load_config(config_path=None):
    """Return ``DEFAULT_CONFIG`` overlaid with the values of a JSON file.

    ``config_path`` is used when given; otherwise ``config.json`` in the
    current directory is loaded if it exists.  Without either, a copy of
    the defaults is returned.
    """
    cfg = dict(DEFAULT_CONFIG)
    if config_path is None and os.path.exists("config.json"):
        config_path = "config.json"
    if config_path:
        with open(config_path, encoding="utf-8") as f:
            cfg.update(json.load(f))
        print(f"Config loaded from {config_path}")
    return cfg


def _report_stage(n, r):
    """Print headline metrics, failure modes and reward breakdown of a stage."""
    print(f"[Stage n_sheep={n}] sr={r['sr']*100:.0f}% "
          f"mean_len={r['mean_len']:.0f} "
          f"mean_min_pen={r['mean_min_pen']:.1f}m "
          f"mean_act={r['mean_act']:.2f}")
    if r["failure_modes"]:
        # Most frequent outcome first.
        modes = " ".join(
            f"{k}={v}" for k, v in sorted(
                r["failure_modes"].items(), key=lambda x: -x[1]))
        print(f" failure modes: {modes}")
    if "reward_per_step" in r:
        rps = r["reward_per_step"]
        print(" reward/step: " + " ".join(
            f"{k}={v:+.4f}" for k, v in rps.items()))


def _save_artefacts(model, vn, stage_results, run_dir):
    """Write the model, the VecNormalize statistics and the stage metrics."""
    model.save(os.path.join(run_dir, "final_model"))
    vn.save(os.path.join(run_dir, "vecnorm.pkl"))
    results_path = os.path.join(run_dir, "stage_results.json")
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(stage_results, f, indent=2)


def main():
    """Train PPO through the 1 -> ``--max-sheep`` curriculum and report.

    Steps: resolve the config, create the run directory and store the
    resolved config there, build the normalised training envs and the PPO
    model, then for every sheep count train, evaluate, print the metrics
    and render the per-stage plots (and GIF unless ``--no-gif``).
    Artefacts are written after every stage, so an interrupted run keeps
    the progress of the stages it completed.  Finally a summary table is
    printed and the success-rate chart is plotted.
    """
    args = parse_args()

    cfg = _load_config(args.config)
    # Only keys that HerdingEnv exposes as attributes become reward overrides.
    rcfg = {k: v for k, v in cfg.items() if hasattr(HerdingEnv, k)}
    # Keys that are neither forwarded nor part of the built-in defaults are
    # most likely typos; say so instead of dropping them silently.
    unknown = sorted(
        k for k in cfg if k not in rcfg and k not in DEFAULT_CONFIG)
    if unknown:
        print(f"Warning: ignoring unrecognised config keys: {unknown}")

    # Run directory
    run_dir = args.run_dir or os.path.join(
        "runs", time.strftime("%Y%m%d_%H%M%S"))
    eval_dir = os.path.join(run_dir, "eval")
    os.makedirs(eval_dir, exist_ok=True)
    with open(os.path.join(run_dir, "config.json"), "w",
              encoding="utf-8") as f:
        json.dump(cfg, f, indent=2)
    print(f"Config: {cfg}")
    print(f"Run dir: {run_dir}")
    print(f"Curriculum: 1 → {args.max_sheep} sheep, "
          f"{args.steps_per_stage:,} steps/stage\n")

    # Training envs
    train_env = SubprocVecEnv([
        make_env(1, seed=i, max_steps=args.max_steps, reward_cfg=rcfg)
        for i in range(args.n_envs)
    ])
    vn = VecNormalize(train_env, norm_obs=True, norm_reward=True,
                      clip_obs=10.0)

    stage_results = []
    try:
        # Built inside the try so the worker processes are closed even if
        # model construction fails.  CPU is forced: PPO with an MLP runs
        # faster on CPU than GPU, and SB3 warns about this otherwise.
        model = PPO(
            "MlpPolicy", vn,
            learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10,
            gamma=0.995, gae_lambda=0.95, clip_range=0.2,
            ent_coef=cfg.get("ent_coef", 0.02), vf_coef=0.5,
            max_grad_norm=0.5,
            policy_kwargs=dict(net_arch=[256, 256]),
            device="cpu",
            verbose=0,
        )
        t0 = time.time()

        # Curriculum training
        for n in range(1, args.max_sheep + 1):
            if n > 1:
                vn.env_method("set_n_sheep", n)
            print(f"\n[Stage n_sheep={n}] "
                  f"training {args.steps_per_stage:,} steps")
            model.learn(
                total_timesteps=args.steps_per_stage,
                # Keep one global step counter across the curriculum.
                reset_num_timesteps=(n == 1),
                callback=ProgressCallback(f"{n} sheep", freq=100_000),
            )

            # Evaluate
            print(f"[Stage n_sheep={n}] evaluating {args.eval_episodes} eps")
            r = evaluate(model, vn, n, args.eval_episodes, args.max_steps,
                         rcfg)
            _report_stage(n, r)

            # Episode visualisation: trajectory + timeseries + animated GIF
            hist = run_and_record(model, vn, n, args.max_steps, rcfg,
                                  seed=1000 + n)
            tag = "success" if hist["success"] else "fail"
            plot_trajectory(
                hist, os.path.join(eval_dir, f"traj_{n}s_{tag}.png"))
            plot_timeseries(
                hist, os.path.join(eval_dir, f"ts_{n}s_{tag}.png"))
            if not args.no_gif:
                save_episode_gif(
                    hist,
                    os.path.join(eval_dir, f"ep_{n}s_{tag}.gif"),
                    fps=args.gif_fps, skip=args.gif_skip)

            r["n_sheep"] = n
            stage_results.append(r)

            # Checkpoint after each intermediate stage; the last stage is
            # covered by the unconditional save below.
            if n < args.max_sheep:
                _save_artefacts(model, vn, stage_results, run_dir)

        # Save artefacts
        _save_artefacts(model, vn, stage_results, run_dir)
    finally:
        # Best-effort cleanup: report a failure, but never let it mask an
        # exception that is already propagating.
        try:
            vn.close()
        except Exception as exc:
            print(f"Warning: closing the training envs failed: {exc!r}")

    # Summary
    elapsed = (time.time() - t0) / 60
    print("\n" + "=" * 70)
    print(" TRAINING SUMMARY")
    print("=" * 70)
    for r in stage_results:
        print(f" n_sheep={r['n_sheep']} sr={r['sr']*100:>3.0f}% "
              f"len={r['mean_len']:>5.0f} min_pen={r['mean_min_pen']:>5.1f}m "
              f"act={r['mean_act']:.2f}")
    print(f"\n Total time: {elapsed:.1f} min")
    print(f" Artefacts: {run_dir}/")
    plot_success_rate(stage_results, os.path.join(run_dir, "success_rate.png"))
    print(f" Plots: {run_dir}/success_rate.png, {eval_dir}/")


if __name__ == "__main__":
    main()