""" Quick sanity check before committing to a full 15M-step training run. Trains 1 sheep for 500k steps (~5 min), then 3 sheep for 500k steps. If both pass, the obs/reward setup is sound and full training is worth running. If either fails, abort and fix before wasting 15M steps. Usage: python smoke_test.py # fresh run python smoke_test.py --render # watch episodes after each stage """ import argparse import os import sys import numpy as np from copy import deepcopy from stable_baselines3 import PPO from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecNormalize from herding_env import HerdingEnv COMPACT_RADIUS = 5.0 PASS_THRESHOLD = 0.60 # success rate required to pass each stage def make_env(n_sheep, seed, max_steps=2000): def _init(): env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps) env.reset(seed=seed) return env return _init def classify_failure(ep_radius, ep_com_dist, n_penned, n_sheep, success): if success: return "SUCCESS" if min(ep_radius) > COMPACT_RADIUS: return "NEVER_COMPACT" first_compact = next(i for i, r in enumerate(ep_radius) if r <= COMPACT_RADIUS) if min(ep_com_dist[first_compact:]) > 3.0: return "COMPACT_CANT_DRIVE" if n_penned == 0: return "DROVE_NO_SHEEP" return f"PARTIAL_{n_penned}of{n_sheep}" def run_episodes(model, eval_env, n_episodes=30, max_steps=2000, render=False): """Run N deterministic episodes; return failure mode counts and success rate.""" failure_counts = {} successes = 0 for ep in range(n_episodes): obs = eval_env.reset() done = False ep_radius, ep_com_dist = [], [] n_penned = 0 n_sheep = 1 while not done: action, _ = model.predict(obs, deterministic=True) obs, _, dones, infos = eval_env.step(action) done = dones[0] inner = eval_env.envs[0] com, radius, _ = inner._flock_stats() com_dist = float(np.linalg.norm(com - inner.PEN_CENTER)) ep_radius.append(radius) ep_com_dist.append(com_dist) if render and ep == 0: inner.render() info = infos[0] n_penned = info.get("n_penned", 0) n_sheep = info.get("n_sheep", 1) success = n_penned == n_sheep successes += int(success) mode = classify_failure(ep_radius, ep_com_dist, n_penned, n_sheep, success) failure_counts[mode] = failure_counts.get(mode, 0) + 1 success_rate = successes / n_episodes return success_rate, failure_counts def train_stage(n_sheep, steps, n_envs=4, prev_model=None, prev_vecnorm=None): """Train one stage; return (model, vecnorm).""" train_env = SubprocVecEnv([make_env(n_sheep, i) for i in range(n_envs)]) if prev_vecnorm is not None: vn = deepcopy(prev_vecnorm) vn.set_venv(train_env) vn.training = True vn.norm_reward = True else: vn = VecNormalize(train_env, norm_obs=True, norm_reward=True, clip_obs=10.0) if prev_model is not None: model = prev_model model.set_env(vn) else: model = PPO( "MlpPolicy", vn, learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10, gamma=0.995, gae_lambda=0.95, clip_range=0.2, ent_coef=0.005, vf_coef=0.5, max_grad_norm=0.5, policy_kwargs=dict(net_arch=[256, 256]), verbose=1, ) model.learn(total_timesteps=steps, reset_num_timesteps=(prev_model is None), tb_log_name="ppo_smoke") return model, vn def make_eval_env(model, vecnorm, n_sheep, max_steps=2000): raw = DummyVecEnv([make_env(n_sheep, seed=9999, max_steps=max_steps)]) vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False) vn.obs_rms = deepcopy(vecnorm.obs_rms) vn.ret_rms = deepcopy(vecnorm.ret_rms) return vn def report(n_sheep, success_rate, failure_counts, n_episodes, threshold=PASS_THRESHOLD): print(f"\n{'='*52}") print(f" Stage n_sheep={n_sheep} | success={success_rate*100:.0f}% ({int(success_rate*n_episodes)}/{n_episodes})") print(f" {'─'*48}") for mode, cnt in sorted(failure_counts.items(), key=lambda x: -x[1]): bar = "█" * cnt print(f" {mode:<26} {cnt:>3}/{n_episodes} {bar}") print(f"{'='*52}") passed = success_rate >= threshold if passed: print(f" ✓ PASS (threshold {threshold*100:.0f}%)") else: dominant = max(failure_counts, key=failure_counts.get) print(f" ✗ FAIL — dominant: {dominant}") if dominant == "NEVER_COMPACT": print(" Dog can't compact flock. Check W_COLLECT, obs contains straggler positions?") elif dominant == "COMPACT_CANT_DRIVE": print(" Flock compacts but dog doesn't drive to pen. Check alignment reward / W_DRIVE.") elif dominant.startswith("PARTIAL"): print(" Flock splits near pen. Dog loses stragglers at the end.") print() return passed def main(): p = argparse.ArgumentParser() p.add_argument("--steps", type=int, default=500_000, help="Steps per smoke-test stage (default 500k)") p.add_argument("--n-envs", type=int, default=4) p.add_argument("--episodes", type=int, default=30, help="Validation episodes per stage") p.add_argument("--render", action="store_true") args = p.parse_args() # Stage 1 (1 sheep, 500k): fast sanity check — obs/reward structurally correct? # Stage 2 (3 sheep, 1.5M): real test at curriculum pace — if it fails here, # there is a genuine problem worth fixing before committing to 15M steps. stages = [(1, args.steps, 0.60), (3, args.steps * 3, 0.40)] model, vn = None, None all_passed = True for n_sheep, steps, threshold in stages: print(f"\n{'#'*52}") print(f"# Smoke-test stage: n_sheep={n_sheep}, {steps:,} steps") print(f"{'#'*52}") model, vn = train_stage(n_sheep, steps, args.n_envs, model, vn) eval_env = make_eval_env(model, vn, n_sheep) success_rate, failure_counts = run_episodes( model, eval_env, args.episodes, render=args.render ) eval_env.close() save_dir = f"runs/smoke_stage{n_sheep}" os.makedirs(save_dir, exist_ok=True) model.save(os.path.join(save_dir, "model")) vn.save(os.path.join(save_dir, "vecnorm.pkl")) print(f" Model saved to {save_dir}/") passed = report(n_sheep, success_rate, failure_counts, args.episodes, threshold) if not passed: all_passed = False print(" Aborting smoke test — fix the issue above before full training.") sys.exit(1) if all_passed: print("\n All smoke-test stages passed.") print(" Ready for full curriculum training:") print() print(" python train.py --curriculum --steps-per-stage 1500000 \\") print(" --total-steps 15000000 --n-sheep 1 --max-sheep 10 \\") print(" --n-envs 8 --run-dir runs/ppo_v2") print() if __name__ == "__main__": main()