"""
PPO training script for the herding task.

Usage examples
--------------
# Proper 5-sheep curriculum, 1 M steps per stage:
python train.py --curriculum --steps-per-stage 1000000 --total-steps 5000000

# Success-rate curriculum (advances at >= 70 % success over the last 100 episodes):
python train.py --curriculum --threshold 0.70

# Resume from checkpoint at stage 3 (checkpoints are written to <run-dir>/checkpoints/).
# Note: when resuming, --total-steps is the number of *additional* steps to train.
python train.py --resume runs/ppo_herding/checkpoints/ckpt_3000000_steps.zip --n-sheep 3 \
    --curriculum --steps-per-stage 1000000 --total-steps 5000000

# Quick smoke-test:
python train.py --n-envs 1 --total-steps 50000
"""

import argparse
import os
import re
from collections import Counter, deque
from copy import deepcopy

import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import (
    BaseCallback,
    CallbackList,
    CheckpointCallback,
    EvalCallback,
)
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize

from herding_env import HerdingEnv

COMPACT_RADIUS = 5.0
|
|
|
|
|
|
def _classify(ep_radius, ep_com_dist, n_penned, n_sheep, success):
|
|
if success:
|
|
return "SUCCESS"
|
|
if min(ep_radius) > COMPACT_RADIUS:
|
|
return "NEVER_COMPACT"
|
|
first = next(i for i, r in enumerate(ep_radius) if r <= COMPACT_RADIUS)
|
|
if min(ep_com_dist[first:]) > 3.0:
|
|
return "COMPACT_CANT_DRIVE"
|
|
if n_penned == 0:
|
|
return "DROVE_NO_SHEEP"
|
|
return f"PARTIAL_{n_penned}of{n_sheep}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Curriculum callback
# ---------------------------------------------------------------------------

class CurriculumCallback(BaseCallback):
    """
    Advances n_sheep on both training and eval envs.

    Two modes (mutually exclusive; ``steps_per_stage`` wins if both are set):
      steps_per_stage - advance every N environment steps regardless of
                        success rate (recommended for reliability).
      threshold       - advance when the rolling success rate over the last
                        ``window`` finished episodes (at least
                        ``min_episodes`` of them) reaches this value
                        (requires the policy to actually reach the threshold).

    Stage timing is measured from the step count at which ``learn()`` starts,
    so continuing a loaded model (``reset_num_timesteps=False``) does not
    trigger an immediate advance.
    """

    def __init__(self, start_sheep: int, max_sheep: int,
                 eval_env=None,
                 steps_per_stage: int = None,
                 threshold: float = 0.75,
                 window: int = 100,
                 min_episodes: int = 50,
                 verbose: int = 1):
        super().__init__(verbose)
        self.max_sheep = max_sheep
        self.eval_env = eval_env
        self.steps_per_stage = steps_per_stage
        self.threshold = threshold
        self.window = window
        self.min_episodes = min_episodes
        self._cur_sheep = start_sheep
        # Rolling window of per-episode outcomes (1 = success, 0 = failure);
        # the bounded deque drops the oldest entry in O(1).
        self._successes = deque(maxlen=window)
        self._stage_start = 0

    def _on_training_start(self) -> None:
        # num_timesteps is already non-zero when learn() continues a loaded
        # model; start timing the current stage from here, not from step 0.
        self._stage_start = self.num_timesteps

    def _advance(self):
        """Move to the next sheep count on train (and eval) envs and reset stage stats."""
        self._cur_sheep += 1
        self.training_env.env_method("set_n_sheep", self._cur_sheep)
        if self.eval_env is not None:
            self.eval_env.env_method("set_n_sheep", self._cur_sheep)
        self._stage_start = self.num_timesteps
        self._successes.clear()
        if self.verbose:
            print(f"\n[Curriculum] → {self._cur_sheep} sheep "
                  f"at step {self.num_timesteps:,}\n")

    def _on_step(self) -> bool:
        if self._cur_sheep >= self.max_sheep:
            return True

        if self.steps_per_stage is not None:
            # Time-based: advance every steps_per_stage env steps
            if self.num_timesteps - self._stage_start >= self.steps_per_stage:
                self._advance()
            return True

        # Success-rate based. NOTE(review): every episode that ends without
        # time-limit truncation is counted as a success, which assumes
        # HerdingEnv only terminates early on success — confirm.
        for info, done in zip(self.locals["infos"], self.locals["dones"]):
            if done:
                truncated = info.get("TimeLimit.truncated", False)
                self._successes.append(0 if truncated else 1)

        if (len(self._successes) >= self.min_episodes
                and np.mean(self._successes) >= self.threshold):
            self._advance()

        return True

# ---------------------------------------------------------------------------
# Diagnostic callback — failure-mode breakdown every diag_freq steps
# ---------------------------------------------------------------------------

class DiagnosticCallback(BaseCallback):
    """
    Every diag_freq env steps: spin up a temporary eval env, run n_episodes
    deterministic episodes, and print a failure-mode breakdown.

    Aborts training (returns False) when the same dominant outcome, other than
    "SUCCESS", is seen at the same n_sheep on ``stall_checks`` consecutive
    repeat checks (i.e. stall_checks + 1 checks in a row) and at least
    ``min_abort_steps`` env steps have elapsed. This indicates that training
    has stalled and further steps are wasted. The defaults (5 repeats, 3M
    steps) give early stages time to warm up.
    """

    def __init__(self, diag_freq: int = 500_000, n_episodes: int = 20,
                 max_steps: int = 2000, verbose: int = 1,
                 stall_checks: int = 5, min_abort_steps: int = 3_000_000):
        super().__init__(verbose)
        self.diag_freq = diag_freq
        self.n_episodes = n_episodes
        self.max_steps = max_steps
        self.stall_checks = stall_checks
        self.min_abort_steps = min_abort_steps
        self._last_diag = 0
        self._prev_dominant = None  # (n_sheep, mode) from last check
        self._stall_count = 0

    def _on_training_start(self) -> None:
        # Measure the interval from where learn() starts, so a resumed run
        # (non-zero num_timesteps) does not fire a diagnostic on its first step.
        self._last_diag = self.num_timesteps

    def _run_diagnostics(self, n_sheep):
        """Run n_episodes deterministic episodes; return (n_successes, Counter of modes)."""
        raw = DummyVecEnv([lambda: HerdingEnv(n_sheep=n_sheep,
                                              max_steps=self.max_steps)])
        vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
        try:
            # Frozen copies of the training stats, so observations are
            # normalised the same way the policy sees them during training.
            vn.obs_rms = deepcopy(self.training_env.obs_rms)
            vn.ret_rms = deepcopy(self.training_env.ret_rms)
            inner = raw.envs[0]

            successes = 0
            failure_counts = Counter()
            for _ in range(self.n_episodes):
                obs = vn.reset()
                ep_radius, ep_com_dist = [], []
                done = False
                while not done:
                    # Record the state the policy is about to act on. Stats are
                    # taken before stepping because DummyVecEnv auto-resets on
                    # done: after the final step `inner` already holds the next
                    # episode's initial state.
                    com, radius, _ = inner._flock_stats()
                    ep_radius.append(radius)
                    ep_com_dist.append(
                        float(np.linalg.norm(com - inner.PEN_CENTER))
                    )
                    action, _ = self.model.predict(obs, deterministic=True)
                    obs, _, dones, infos = vn.step(action)
                    done = dones[0]

                # infos is the terminal step's info dict.
                n_penned = infos[0].get("n_penned", 0)
                success = n_penned == n_sheep
                successes += int(success)
                failure_counts[
                    _classify(ep_radius, ep_com_dist, n_penned, n_sheep, success)
                ] += 1
        finally:
            vn.close()
        return successes, failure_counts

    def _on_step(self) -> bool:
        if self.num_timesteps - self._last_diag < self.diag_freq:
            return True
        self._last_diag = self.num_timesteps

        # NOTE(review): with --mixed (random_n_sheep) this is only env 0's
        # current count — confirm that is the intended diagnostic target.
        n_sheep = self.training_env.get_attr("n_sheep")[0]
        successes, failure_counts = self._run_diagnostics(n_sheep)
        if not failure_counts:  # n_episodes <= 0: nothing to report
            return True

        success_rate = successes / self.n_episodes
        ranked = failure_counts.most_common()  # count-descending, ties stable
        dominant = ranked[0][0]

        if self.verbose:
            print(f"\n[Diag @ {self.num_timesteps:,} | n_sheep={n_sheep} | "
                  f"success={success_rate*100:.0f}%]")
            for m, c in ranked:
                print(f" {m:<26} {c}/{self.n_episodes}")

        # Stall detection: same non-success dominant mode at the same n_sheep.
        key = (n_sheep, dominant)
        if key == self._prev_dominant and dominant != "SUCCESS":
            self._stall_count += 1
            if (self._stall_count >= self.stall_checks
                    and self.num_timesteps >= self.min_abort_steps):
                print(f"\n[Diag] STALL DETECTED — '{dominant}' on {n_sheep} sheep "
                      f"for {self._stall_count} consecutive checks. "
                      f"Aborting training early.")
                return False
        else:
            self._stall_count = 0
        self._prev_dominant = key

        return True

# ---------------------------------------------------------------------------
# Environment factory
# ---------------------------------------------------------------------------

def make_env(n_sheep: int, seed: int, max_steps: int, random_n_sheep: bool = False):
    """
    Return a zero-argument factory that builds and seeds one HerdingEnv.

    The factory is what SubprocVecEnv expects as an ``env_fn``: each call
    constructs a fresh HerdingEnv with the given settings and calls
    ``reset(seed=seed)`` on it once before handing it back.
    """
    env_kwargs = dict(n_sheep=n_sheep, max_steps=max_steps,
                      random_n_sheep=random_n_sheep)

    def _build_env():
        herding_env = HerdingEnv(**env_kwargs)
        herding_env.reset(seed=seed)
        return herding_env

    return _build_env

# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def parse_args(argv=None):
    """
    Parse command-line options for the training script.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so existing ``parse_args()`` calls
            behave as before; passing a list allows use from tests/notebooks.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    p = argparse.ArgumentParser(
        description="PPO training script for the herding task.")
    p.add_argument("--n-sheep", type=int, default=1,
                   help="Starting sheep count")
    p.add_argument("--max-sheep", type=int, default=5,
                   help="Final sheep count for curriculum")
    p.add_argument("--n-envs", type=int, default=8,
                   help="Parallel training environments")
    p.add_argument("--total-steps", type=int, default=5_000_000,
                   help="Env steps passed to model.learn(); when resuming, "
                        "this many steps are trained beyond the checkpoint")
    p.add_argument("--max-steps", type=int, default=2000,
                   help="Episode step limit")
    p.add_argument("--curriculum", action="store_true",
                   help="Enable curriculum advancement")
    p.add_argument("--steps-per-stage", type=int, default=None,
                   help="Advance curriculum every N steps (overrides --threshold)")
    p.add_argument("--threshold", type=float, default=0.75,
                   help="Success-rate threshold to advance (used without --steps-per-stage)")
    p.add_argument("--resume", type=str, default=None,
                   help="Checkpoint .zip to resume from")
    p.add_argument("--run-dir", type=str, default="runs/ppo_herding")
    p.add_argument("--save-freq", type=int, default=100_000)
    p.add_argument("--eval-freq", type=int, default=50_000)
    p.add_argument("--eval-eps", type=int, default=20)
    p.add_argument("--diag-freq", type=int, default=500_000,
                   help="Run failure-mode diagnostics every N env steps")
    p.add_argument("--mixed", action="store_true",
                   help="Randomise n_sheep each episode (consolidation pass, "
                        "use with --resume after curriculum training)")
    return p.parse_args(argv)

def main():
|
|
args = parse_args()
|
|
os.makedirs(args.run_dir, exist_ok=True)
|
|
ckpt_dir = os.path.join(args.run_dir, "checkpoints")
|
|
best_dir = os.path.join(args.run_dir, "best_model")
|
|
norm_path = os.path.join(args.run_dir, "vecnorm.pkl")
|
|
os.makedirs(ckpt_dir, exist_ok=True)
|
|
|
|
# Training envs
|
|
train_env = SubprocVecEnv([
|
|
make_env(args.n_sheep, seed=i, max_steps=args.max_steps,
|
|
random_n_sheep=args.mixed)
|
|
for i in range(args.n_envs)
|
|
])
|
|
if args.resume and os.path.exists(norm_path):
|
|
train_env = VecNormalize.load(norm_path, train_env)
|
|
train_env.training = True
|
|
train_env.norm_reward = True
|
|
else:
|
|
train_env = VecNormalize(train_env, norm_obs=True, norm_reward=True,
|
|
clip_obs=10.0)
|
|
|
|
# Eval env — starts at same difficulty, advances with curriculum callback
|
|
eval_env = SubprocVecEnv([
|
|
make_env(args.n_sheep, seed=1000 + i, max_steps=args.max_steps)
|
|
for i in range(2)
|
|
])
|
|
eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False,
|
|
clip_obs=10.0, training=False)
|
|
|
|
# Callbacks
|
|
checkpoint_cb = CheckpointCallback(
|
|
save_freq=max(args.save_freq // args.n_envs, 1),
|
|
save_path=ckpt_dir,
|
|
name_prefix="ckpt",
|
|
save_vecnormalize=True,
|
|
)
|
|
eval_cb = EvalCallback(
|
|
eval_env,
|
|
best_model_save_path=best_dir,
|
|
log_path=args.run_dir,
|
|
eval_freq=max(args.eval_freq // args.n_envs, 1),
|
|
n_eval_episodes=args.eval_eps,
|
|
deterministic=True,
|
|
verbose=1,
|
|
)
|
|
diag_cb = DiagnosticCallback(
|
|
diag_freq=args.diag_freq,
|
|
n_episodes=20,
|
|
max_steps=args.max_steps,
|
|
)
|
|
callbacks = [checkpoint_cb, eval_cb, diag_cb]
|
|
|
|
if args.curriculum:
|
|
cur_cb = CurriculumCallback(
|
|
start_sheep=args.n_sheep,
|
|
max_sheep=args.max_sheep,
|
|
eval_env=eval_env,
|
|
steps_per_stage=args.steps_per_stage,
|
|
threshold=args.threshold,
|
|
)
|
|
callbacks.append(cur_cb)
|
|
|
|
callback_list = CallbackList(callbacks)
|
|
|
|
# Model
|
|
ppo_kwargs = dict(
|
|
policy = "MlpPolicy",
|
|
env = train_env,
|
|
learning_rate = 3e-4,
|
|
n_steps = 2048,
|
|
batch_size = 256,
|
|
n_epochs = 10,
|
|
gamma = 0.995,
|
|
gae_lambda = 0.95,
|
|
clip_range = 0.2,
|
|
ent_coef = 0.02,
|
|
vf_coef = 0.5,
|
|
max_grad_norm = 0.5,
|
|
policy_kwargs = dict(net_arch=[256, 256]),
|
|
tensorboard_log = args.run_dir,
|
|
verbose = 1,
|
|
)
|
|
|
|
if args.resume:
|
|
print(f"Resuming from {args.resume}")
|
|
model = PPO.load(args.resume, env=train_env, **{
|
|
k: v for k, v in ppo_kwargs.items()
|
|
if k not in ("policy", "env")
|
|
})
|
|
else:
|
|
model = PPO(**ppo_kwargs)
|
|
|
|
model.learn(
|
|
total_timesteps=args.total_steps,
|
|
callback=callback_list,
|
|
reset_num_timesteps=args.resume is None,
|
|
tb_log_name="ppo",
|
|
)
|
|
|
|
model.save(os.path.join(args.run_dir, "final_model"))
|
|
train_env.save(norm_path)
|
|
print(f"\nTraining complete. Artefacts saved to {args.run_dir}/")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|