Files
TIR_PROJ/training/smoke_test.py
T
2026-04-24 17:31:11 +01:00

206 lines
7.1 KiB
Python

"""
Quick sanity check before committing to a full 15M-step training run.

Trains 1 sheep for 500k steps (~5 min), then continues training with 3 sheep
for 1.5M steps (3x the --steps budget). Stage 1 must reach a 60% success rate
and stage 2 a 40% success rate.
If both stages pass, the obs/reward setup is sound and full training is worth running.
If either fails, abort and fix the issue before wasting 15M steps.

Usage:
    python smoke_test.py            # fresh run
    python smoke_test.py --render   # render the first evaluation episode of each stage
"""
import argparse
import os
import sys
import numpy as np
from copy import deepcopy
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecNormalize
from herding_env import HerdingEnv
# A flock counts as "compact" once the radius reported by
# HerdingEnv._flock_stats() is at or below this value (see classify_failure).
COMPACT_RADIUS = 5.0
# Default minimum success rate used by report(). main() passes an explicit
# threshold for every stage, so individual stages may use a different value.
PASS_THRESHOLD = 0.60
def make_env(n_sheep, seed, max_steps=2000):
    """Return a zero-argument factory that builds a seeded HerdingEnv.

    SB3 vectorised envs take a list of env factories rather than env
    instances, so construction is deferred until the factory is called.

    Args:
        n_sheep: Number of sheep in the environment.
        seed: Seed passed to the environment's first reset.
        max_steps: Episode step limit forwarded to HerdingEnv.
    """
    def _build():
        herding = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps)
        # Seed once at construction so the env starts from a reproducible RNG state.
        herding.reset(seed=seed)
        return herding

    return _build
def classify_failure(ep_radius, ep_com_dist, n_penned, n_sheep, success,
                     compact_radius=None, drive_radius=3.0):
    """Label how an evaluation episode ended.

    Args:
        ep_radius: Per-step flock radius recorded during the episode.
        ep_com_dist: Per-step distance from the flock's center of mass to the
            pen center, aligned index-for-index with ``ep_radius``.
        n_penned: Number of sheep penned at the end of the episode.
        n_sheep: Number of sheep in the episode.
        success: Whether every sheep was penned.
        compact_radius: Radius at or below which the flock counts as compact.
            ``None`` (the default) uses the module-level ``COMPACT_RADIUS``.
        drive_radius: After first becoming compact, the flock counts as
            driven once its center of mass comes within this distance of the
            pen center.

    Returns:
        One of ``"SUCCESS"``, ``"NO_DATA"`` (empty trajectory),
        ``"NEVER_COMPACT"``, ``"COMPACT_CANT_DRIVE"``, ``"DROVE_NO_SHEEP"``
        or ``"PARTIAL_<penned>of<total>"``.
    """
    if success:
        return "SUCCESS"
    # Guard: min() over an empty trajectory would raise ValueError.
    if not ep_radius:
        return "NO_DATA"
    if compact_radius is None:
        compact_radius = COMPACT_RADIUS

    first_compact = next(
        (i for i, r in enumerate(ep_radius) if r <= compact_radius), None
    )
    if first_compact is None:
        return "NEVER_COMPACT"
    # Only judge driving from the first compact step onwards.
    if min(ep_com_dist[first_compact:], default=float("inf")) > drive_radius:
        return "COMPACT_CANT_DRIVE"
    if n_penned == 0:
        return "DROVE_NO_SHEEP"
    return f"PARTIAL_{n_penned}of{n_sheep}"
def run_episodes(model, eval_env, n_episodes=30, max_steps=2000, render=False):
    """Run N deterministic episodes; return (success_rate, failure_counts).

    Args:
        model: Policy exposing ``predict(obs, deterministic=True)``.
        eval_env: Single-environment VecEnv (as built by make_eval_env); its
            ``envs[0]`` must expose ``_flock_stats()``, ``PEN_CENTER`` and
            ``render()``.
        n_episodes: Number of episodes to run.
        max_steps: Unused here; episode length comes from the environment
            (make_eval_env passes its own max_steps to HerdingEnv). Kept for
            call-site compatibility.
        render: If True, render every step of the first episode.

    Returns:
        (success_rate, failure_counts): the fraction of episodes in which
        every sheep was penned, and a dict mapping classify_failure() labels
        to episode counts. Returns ``(0.0, {})`` when ``n_episodes <= 0``.
    """
    failure_counts = {}
    successes = 0
    if n_episodes <= 0:
        # Avoid dividing by zero below.
        return 0.0, failure_counts

    # The single underlying env; VecNormalize forwards ``envs`` to the
    # wrapped DummyVecEnv, whose env list does not change between episodes.
    inner = eval_env.envs[0]
    for ep in range(n_episodes):
        obs = eval_env.reset()
        done = False
        ep_radius, ep_com_dist = [], []
        n_penned, n_sheep = 0, 1
        while not done:
            # Record flock stats (and render) BEFORE stepping: when an episode
            # ends, the vec env auto-resets inside step(), so state read
            # afterwards would describe the next episode's initial layout.
            com, radius, _ = inner._flock_stats()
            ep_radius.append(radius)
            ep_com_dist.append(float(np.linalg.norm(com - inner.PEN_CENTER)))
            if render and ep == 0:
                inner.render()

            action, _ = model.predict(obs, deterministic=True)
            obs, _, dones, infos = eval_env.step(action)
            done = bool(dones[0])
            # The step info (unlike the env state) still describes the
            # episode that just ended, so it is safe to read after step().
            info = infos[0]
            n_penned = info.get("n_penned", 0)
            n_sheep = info.get("n_sheep", 1)

        success = n_penned == n_sheep
        successes += int(success)
        mode = classify_failure(ep_radius, ep_com_dist, n_penned, n_sheep, success)
        failure_counts[mode] = failure_counts.get(mode, 0) + 1

    return successes / n_episodes, failure_counts
def train_stage(n_sheep, steps, n_envs=4, prev_model=None, prev_vecnorm=None):
    """Train one stage; return (model, vecnorm).

    Args:
        n_sheep: Flock size for this stage's training environments.
        steps: Number of environment steps to train for.
        n_envs: Number of parallel training environments (SubprocVecEnv).
        prev_model: PPO model from the previous stage to keep training, or
            None to build a fresh one.
        prev_vecnorm: VecNormalize from the previous stage whose running
            statistics are carried over, or None to start fresh.

    Returns:
        (model, vecnorm): the trained model and the VecNormalize wrapping
        this stage's SubprocVecEnv. The caller owns that env and should
        close it when done.
    """
    gamma = 0.995  # shared by PPO and reward normalisation
    train_env = SubprocVecEnv([make_env(n_sheep, i) for i in range(n_envs)])
    if prev_vecnorm is not None:
        # Continue from the previous stage's running statistics. Work on a
        # copy so the caller's wrapper is left untouched, then attach it to
        # this stage's envs.
        # NOTE(review): set_venv requires the new envs' observation space to
        # match the old one; assumes HerdingEnv's obs shape does not depend
        # on n_sheep — confirm.
        vn = deepcopy(prev_vecnorm)
        vn.set_venv(train_env)
        vn.training = True
        vn.norm_reward = True
    else:
        # VecNormalize discounts returns with gamma=0.99 by default; reward
        # scaling should use the same discount that PPO optimises.
        vn = VecNormalize(train_env, norm_obs=True, norm_reward=True,
                          clip_obs=10.0, gamma=gamma)

    if prev_model is not None:
        model = prev_model
        model.set_env(vn)
    else:
        model = PPO(
            "MlpPolicy", vn,
            learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10,
            gamma=gamma, gae_lambda=0.95, clip_range=0.2, ent_coef=0.02,
            vf_coef=0.5, max_grad_norm=0.5,
            policy_kwargs=dict(net_arch=[256, 256]),
            verbose=1,
        )
    # Continue the timestep counter when resuming from a previous stage.
    # NOTE(review): PPO is built without tensorboard_log, so tb_log_name is
    # likely unused — confirm.
    model.learn(total_timesteps=steps, reset_num_timesteps=(prev_model is None),
                tb_log_name="ppo_smoke")
    return model, vn
def make_eval_env(model, vecnorm, n_sheep, max_steps=2000):
    """Build a frozen, single-env evaluation wrapper normalised like training.

    Args:
        model: Unused; kept for call-site compatibility.
        vecnorm: Training VecNormalize whose running statistics and
            observation clipping/epsilon are copied.
        n_sheep: Flock size for the evaluation environment.
        max_steps: Episode step limit forwarded to HerdingEnv.

    Returns:
        A VecNormalize (training=False, reward normalisation off) over one
        DummyVecEnv seeded with 9999.
    """
    raw = DummyVecEnv([make_env(n_sheep, seed=9999, max_steps=max_steps)])
    # Mirror the training wrapper's clipping and epsilon explicitly instead of
    # relying on both sides happening to use the library defaults.
    vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False,
                      clip_obs=vecnorm.clip_obs, epsilon=vecnorm.epsilon)
    # Copy (not share) the stats so evaluation can never mutate training's.
    vn.obs_rms = deepcopy(vecnorm.obs_rms)
    vn.ret_rms = deepcopy(vecnorm.ret_rms)
    return vn
def report(n_sheep, success_rate, failure_counts, n_episodes, threshold=PASS_THRESHOLD):
    """Print a stage summary and return whether the stage passed.

    Args:
        n_sheep: Flock size evaluated in this stage (shown in the header).
        success_rate: Fraction of evaluation episodes that succeeded.
        failure_counts: Mapping of outcome label (as produced by
            classify_failure) to episode count; may include "SUCCESS".
        n_episodes: Number of evaluation episodes that were run.
        threshold: Minimum success_rate needed to pass.

    Returns:
        True if ``success_rate >= threshold``, else False.
    """
    # round(), not int(): rate * n can land just below a whole number
    # (e.g. 1/49 * 49 == 0.9999999999999999), which int() would truncate.
    n_success = round(success_rate * n_episodes)
    print(f"\n{'='*52}")
    print(f" Stage n_sheep={n_sheep} | success={success_rate*100:.0f}% ({n_success}/{n_episodes})")
    print(f" {'─'*48}")
    for mode, cnt in sorted(failure_counts.items(), key=lambda x: -x[1]):
        bar = "█" * cnt
        print(f" {mode:<26} {cnt:>3}/{n_episodes} {bar}")
    print(f"{'='*52}")
    passed = success_rate >= threshold
    if passed:
        print(f" ✓ PASS (threshold {threshold*100:.0f}%)")
    else:
        # Diagnose from actual failures only: "SUCCESS" can be the single most
        # common outcome even when the stage misses its threshold.
        failures = {mode: cnt for mode, cnt in failure_counts.items() if mode != "SUCCESS"}
        dominant = max(failures, key=failures.get) if failures else "NONE"
        print(f" ✗ FAIL — dominant: {dominant}")
        if dominant == "NEVER_COMPACT":
            print(" Dog can't compact flock. Check W_COLLECT, obs contains straggler positions?")
        elif dominant == "COMPACT_CANT_DRIVE":
            print(" Flock compacts but dog doesn't drive to pen. Check alignment reward / W_DRIVE.")
        elif dominant == "DROVE_NO_SHEEP":
            print(" Flock reaches the pen but no sheep are penned. Check the pen-entry condition.")
        elif dominant.startswith("PARTIAL"):
            print(" Flock splits near pen. Dog loses stragglers at the end.")
    print()
    return passed
def main():
    """Run the staged smoke test; exit with status 1 on the first failing stage.

    Each stage trains (continuing from the previous stage's model and
    normalisation stats), evaluates deterministically, saves the model and
    VecNormalize stats under runs/smoke_stage<n_sheep>/, and prints a report.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--steps", type=int, default=500_000,
                   help="Stage-1 training steps; stage 2 trains for 3x this (default 500k)")
    p.add_argument("--n-envs", type=int, default=4,
                   help="Parallel training environments")
    p.add_argument("--episodes", type=int, default=30,
                   help="Validation episodes per stage")
    p.add_argument("--render", action="store_true")
    args = p.parse_args()

    # Stage 1 (1 sheep, --steps): fast sanity check — obs/reward structurally correct?
    # Stage 2 (3 sheep, 3x --steps): real test at curriculum pace — if it fails
    # here, there is a genuine problem worth fixing before committing to 15M steps.
    stages = [(1, args.steps, PASS_THRESHOLD), (3, args.steps * 3, 0.40)]

    model, vn = None, None
    for n_sheep, steps, threshold in stages:
        print(f"\n{'#'*52}")
        print(f"# Smoke-test stage: n_sheep={n_sheep}, {steps:,} steps")
        print(f"{'#'*52}")
        model, vn = train_stage(n_sheep, steps, args.n_envs, model, vn)

        eval_env = make_eval_env(model, vn, n_sheep)
        try:
            success_rate, failure_counts = run_episodes(
                model, eval_env, args.episodes, render=args.render
            )
        finally:
            eval_env.close()

        save_dir = f"runs/smoke_stage{n_sheep}"
        os.makedirs(save_dir, exist_ok=True)
        model.save(os.path.join(save_dir, "model"))
        vn.save(os.path.join(save_dir, "vecnorm.pkl"))
        print(f" Model saved to {save_dir}/")

        # Shut down this stage's training worker processes; otherwise they stay
        # alive for the rest of the run. train_stage deep-copies vn and attaches
        # fresh envs via set_venv, so the old envs are not needed any more.
        vn.close()

        if not report(n_sheep, success_rate, failure_counts, args.episodes, threshold):
            print(" Aborting smoke test — fix the issue above before full training.")
            sys.exit(1)

    # Reaching here means every stage passed (a failure exits above).
    print("\n All smoke-test stages passed.")
    print(" Ready for full curriculum training:")
    print()
    print(" python train.py --curriculum --steps-per-stage 1500000 \\")
    print(" --total-steps 15000000 --n-sheep 1 --max-sheep 10 \\")
    print(" --n-envs 8 --run-dir runs/ppo_v2")
    print()


if __name__ == "__main__":
    main()