Files
TIR_PROJ/training/smoke_test.py
T
2026-04-24 23:51:47 +01:00

370 lines
15 KiB
Python

"""
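Quick sanity check before committing to a full 15M-step training run.
Runs three curriculum stages, each continuing from the previous stage's model:
1 sheep for --steps (default 500k, ~5 min), then 2 sheep for 2x and 3 sheep
for 3x that many steps. Each stage is scored against its own minimum success
threshold, and all stages run even if an earlier one falls short.
If every stage passes, the obs/reward setup is sound and full training is worth running.
If any stage fails, use the printed diagnostics to fix the setup before spending 15M steps.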
Usage:
python smoke_test.py # fresh run
python smoke_test.py --render # watch episodes after each stage
"""
import argparse
import os
import sys
import numpy as np
from copy import deepcopy
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.collections import LineCollection
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecNormalize
from herding_env import HerdingEnv
# Flock radius (m) at or below which the flock counts as "compact" when
# classifying failed episodes.
COMPACT_RADIUS: float = 5.0
# Default minimum success rate used by report(); main() passes its own
# per-stage thresholds instead of relying on this value.
PASS_THRESHOLD: float = 0.60
def make_env(n_sheep, seed, max_steps=2000):
    """Return a zero-argument factory that builds a seeded HerdingEnv.

    SubprocVecEnv / DummyVecEnv in this script are constructed from lists of
    such factories rather than from env instances.

    Args:
        n_sheep: number of sheep in the flock.
        seed: seed passed to the env's first ``reset()``.
        max_steps: episode step limit forwarded to HerdingEnv.
    """
    def _build_env():
        fresh_env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps)
        # Seed once at construction; later resets continue from this RNG state.
        fresh_env.reset(seed=seed)
        return fresh_env

    return _build_env
def classify_failure(ep_radius, ep_com_dist, n_penned, n_sheep, success, *,
                     compact_radius=None, drive_dist=3.0):
    """Label one evaluation episode with its failure mode.

    Checks run in order and the first match wins:

    * ``"SUCCESS"`` -- ``success`` is true.
    * ``"NEVER_COMPACT"`` -- the flock radius never dropped to
      ``compact_radius`` (also returned for an empty trajectory).
    * ``"COMPACT_CANT_DRIVE"`` -- from the first compact step onward, the
      flock-to-pen distance never came within ``drive_dist``.
    * ``"DROVE_NO_SHEEP"`` -- it did get that close, but no sheep were penned.
    * ``"PARTIAL_<k>of<n>"`` -- some, but not all, sheep were penned.

    Args:
        ep_radius: per-step flock radius for the episode.
        ep_com_dist: per-step distance from flock centre of mass to the pen
            centre, aligned index-for-index with ``ep_radius``.
        n_penned: number of sheep penned at the end of the episode.
        n_sheep: flock size.
        success: whether the episode counts as a success.
        compact_radius: radius threshold for "compact"; ``None`` means the
            module-level ``COMPACT_RADIUS``.
        drive_dist: flock-to-pen distance that counts as having driven the
            flock to the pen (default 3.0, the previously hard-coded value).

    Returns:
        The failure-mode label string.
    """
    if success:
        return "SUCCESS"
    if compact_radius is None:
        compact_radius = COMPACT_RADIUS

    first_compact = next(
        (i for i, r in enumerate(ep_radius) if r <= compact_radius), None
    )
    if first_compact is None:
        return "NEVER_COMPACT"

    # Only distances reached once the flock was compact count as "driving".
    after_compact = ep_com_dist[first_compact:]
    if not after_compact or min(after_compact) > drive_dist:
        return "COMPACT_CANT_DRIVE"
    if n_penned == 0:
        return "DROVE_NO_SHEEP"
    return f"PARTIAL_{n_penned}of{n_sheep}"
def _flock_snapshot(inner):
    """Return ``(radius, com_dist)`` for the inner env's current flock state.

    ``radius`` is the second value of ``inner._flock_stats()``; ``com_dist`` is
    the Euclidean distance from its first value (the flock centre of mass) to
    ``inner.PEN_CENTER``.
    """
    com, radius, _ = inner._flock_stats()
    return radius, float(np.linalg.norm(com - inner.PEN_CENTER))


def run_episodes(model, eval_env, n_episodes=30, max_steps=2000, render=False):
    """
    Run N deterministic episodes and collect failure-mode diagnostics.

    Flock state is sampled from ``eval_env.envs[0]`` BEFORE every step, so an
    episode of T steps yields T samples covering states 0..T-1. Sampling after
    ``step()`` is wrong on the terminal step: the VecEnv auto-resets the inner
    env inside ``step()``, so the sample would be a freshly spawned, unrelated
    flock. The final outcome (``n_penned``/``n_sheep``) comes from the last
    step's info dict instead.

    Args:
        model: policy exposing ``predict(obs, deterministic=True)``.
        eval_env: single-env VecEnv (e.g. VecNormalize over DummyVecEnv) whose
            ``envs[0]`` is the raw env used for flock diagnostics.
        n_episodes: number of episodes to run; must be >= 1.
        max_steps: currently unused; episode length is governed by the env
            itself. Kept for call-site compatibility.
        render: if True, call ``render()`` on the inner env during episode 0.

    Returns:
        (success_rate, failure_counts, diagnostics_dict). failure_counts maps a
        classify_failure() label to its episode count. diagnostics_dict holds
        aggregate action-magnitude, pen-approach and episode-length stats,
        useful for understanding WHY the policy is failing without assuming
        the cause.

    Raises:
        ValueError: if ``n_episodes < 1``.
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")

    inner = eval_env.envs[0]
    failure_counts = {}
    successes = 0
    all_action_mags = []   # action magnitude at every step of every episode
    ep_steps_list = []     # length of each episode
    ep_min_pen_list = []   # per-episode min distance from flock COM to pen centre

    for ep in range(n_episodes):
        obs = eval_env.reset()
        ep_radius, ep_com_dist, ep_action_mags = [], [], []
        info = {}
        done = False
        while not done:
            # Snapshot before stepping (see docstring: terminal steps auto-reset).
            radius, com_dist = _flock_snapshot(inner)
            ep_radius.append(radius)
            ep_com_dist.append(com_dist)

            action, _ = model.predict(obs, deterministic=True)
            obs, _, dones, infos = eval_env.step(action)
            done = bool(dones[0])
            info = infos[0]
            ep_action_mags.append(float(np.linalg.norm(action[0])))

            # After a terminal step the env has been reset; nothing useful to draw.
            if render and ep == 0 and not done:
                inner.render()

        n_penned = info.get("n_penned", 0)
        n_sheep = info.get("n_sheep", 1)
        success = n_penned == n_sheep
        successes += int(success)
        mode = classify_failure(ep_radius, ep_com_dist, n_penned, n_sheep, success)
        failure_counts[mode] = failure_counts.get(mode, 0) + 1

        min_pen = min(ep_com_dist)
        all_action_mags.extend(ep_action_mags)
        ep_steps_list.append(len(ep_action_mags))
        ep_min_pen_list.append(min_pen)

        # Per-episode one-liner for real-time feedback
        mean_act = float(np.mean(ep_action_mags))
        print(f" ep {ep+1:>3} steps={len(ep_action_mags):>5} "
              f"penned={n_penned}/{n_sheep} "
              f"act={mean_act:.2f} "
              f"min_pen={min_pen:.1f}m [{mode}]")

    success_rate = successes / n_episodes
    diag = {
        "mean_action_mag" : float(np.mean(all_action_mags)),
        "p10_action_mag" : float(np.percentile(all_action_mags, 10)),
        "p90_action_mag" : float(np.percentile(all_action_mags, 90)),
        "mean_min_pen_dist": float(np.mean(ep_min_pen_list)),
        "best_min_pen_dist": float(np.min(ep_min_pen_list)),
        "mean_ep_steps" : float(np.mean(ep_steps_list)),
    }
    print(f"\n Action magnitude mean={diag['mean_action_mag']:.3f} "
          f"p10={diag['p10_action_mag']:.3f} p90={diag['p90_action_mag']:.3f}"
          f" (0=stopped, 1=full speed)")
    print(f" Pen distance mean_min={diag['mean_min_pen_dist']:.1f}m "
          f"best_min={diag['best_min_pen_dist']:.1f}m "
          f"(how close sheep got to pen center)")
    return success_rate, failure_counts, diag
def train_stage(n_sheep, steps, n_envs=4, prev_model=None, prev_vecnorm=None):
    """Train one curriculum stage and return ``(model, vecnorm)``.

    A fresh SubprocVecEnv of ``n_envs`` workers (seeded 0..n_envs-1) is built
    for ``n_sheep`` sheep. When ``prev_vecnorm`` is given, its running
    statistics are copied onto the new env and left in training mode;
    otherwise a new VecNormalize is created. When ``prev_model`` is given,
    training continues from it for ``steps`` more timesteps; otherwise a new
    PPO model is created.
    """
    env_fns = [make_env(n_sheep, env_idx) for env_idx in range(n_envs)]
    workers = SubprocVecEnv(env_fns)

    if prev_vecnorm is None:
        norm_env = VecNormalize(workers, norm_obs=True, norm_reward=True, clip_obs=10.0)
    else:
        # Copy the previous stage's stats wrapper and attach it to the new workers.
        norm_env = deepcopy(prev_vecnorm)
        norm_env.set_venv(workers)
        norm_env.training = True
        norm_env.norm_reward = True

    first_stage = prev_model is None
    if first_stage:
        ppo_hparams = dict(
            learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10,
            gamma=0.995, gae_lambda=0.95, clip_range=0.2, ent_coef=0.02,
            vf_coef=0.5, max_grad_norm=0.5,
            policy_kwargs=dict(net_arch=[256, 256]),
            verbose=1,
        )
        model = PPO("MlpPolicy", norm_env, **ppo_hparams)
    else:
        model = prev_model
        model.set_env(norm_env)

    # Keep the timestep counter running across stages after the first.
    # NOTE(review): no tensorboard_log is set on PPO, so tb_log_name is
    # presumably unused -- confirm.
    model.learn(total_timesteps=steps, reset_num_timesteps=first_stage,
                tb_log_name="ppo_smoke")
    return model, norm_env
def make_eval_env(model, vecnorm, n_sheep, max_steps=2000):
    """Build a frozen, single-env evaluation wrapper reusing ``vecnorm``'s stats.

    The training VecNormalize is deep-copied (its pickling excludes the wrapped
    venv, the same pattern train_stage uses). Every normalization setting
    (clip_obs, epsilon, obs statistics) therefore matches training exactly,
    instead of relying on the VecNormalize constructor defaults happening to
    agree.

    Args:
        model: unused; kept for call-site compatibility.
        vecnorm: training VecNormalize whose observation statistics to apply.
        n_sheep: flock size for the evaluation env.
        max_steps: episode step limit passed to the env.

    Returns:
        VecNormalize over a DummyVecEnv seeded with 9999. Statistics are frozen
        (``training=False``) and reward normalization is off, so rewards are raw.
    """
    raw = DummyVecEnv([make_env(n_sheep, seed=9999, max_steps=max_steps)])
    eval_vn = deepcopy(vecnorm)
    eval_vn.set_venv(raw)
    eval_vn.training = False      # never update running stats during evaluation
    eval_vn.norm_reward = False   # report raw rewards
    return eval_vn
def report(n_sheep, success_rate, failure_counts, n_episodes, threshold=None):
    """Print a stage summary table and return whether the stage passed.

    Args:
        n_sheep: flock size of the stage (for the header).
        success_rate: fraction of successful episodes in [0, 1].
        failure_counts: mapping of classify_failure() label -> episode count.
        n_episodes: number of evaluation episodes.
        threshold: minimum success rate to pass; ``None`` means the
            module-level ``PASS_THRESHOLD``.

    Returns:
        True if ``success_rate >= threshold``.
    """
    if threshold is None:
        threshold = PASS_THRESHOLD
    # round(), not int(): e.g. (1/49)*49 == 0.999... would truncate to 0.
    n_success = round(success_rate * n_episodes)

    print(f"\n{'='*52}")
    print(f" Stage n_sheep={n_sheep} | success={success_rate*100:.0f}% ({n_success}/{n_episodes})")
    print(f" {'─'*48}")
    for mode, cnt in sorted(failure_counts.items(), key=lambda x: -x[1]):
        bar = "█" * cnt
        print(f" {mode:<26} {cnt:>3}/{n_episodes} {bar}")
    print(f"{'='*52}")

    passed = success_rate >= threshold
    if passed:
        print(f" ✓ PASS (threshold {threshold*100:.0f}%)")
    else:
        # SUCCESS is not a failure mode; it must never be reported as "dominant".
        failures = {m: c for m, c in failure_counts.items() if m != "SUCCESS"}
        if not failures:
            print(f" ✗ FAIL (threshold {threshold*100:.0f}%)")
        else:
            dominant = max(failures, key=failures.get)
            print(f" ✗ FAIL — dominant: {dominant}")
            hints = {
                "NEVER_COMPACT": " Dog can't compact flock. Check W_COLLECT, obs contains straggler positions?",
                "COMPACT_CANT_DRIVE": " Flock compacts but dog doesn't drive to pen. Check alignment reward / W_DRIVE.",
                "DROVE_NO_SHEEP": " Flock gets near the pen but no sheep are penned.",
            }
            if dominant.startswith("PARTIAL"):
                print(" Flock splits near pen. Dog loses stragglers at the end.")
            elif dominant in hints:
                print(hints[dominant])
    print()
    return passed
# Per-sheep line colours for the smoke-test plots, cycled by sheep index
# (modulo the list length) when there are more sheep than colours.
SHEEP_COLORS = [
    "#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00",
    "#a65628", "#f781bf", "#999999", "#66c2a5", "#fc8d62",
]
def _rollout_for_vis(model, vn, n_sheep, seed, max_steps):
    """Run one deterministic episode and return its per-step traces as a dict.

    State traces (dog/sheep positions, flock radius, per-sheep pen distance)
    are sampled BEFORE each step, so they cover states 0..T-1 and line up
    one-to-one with the action-magnitude and reward traces. Sampling after the
    terminal step would read the inner env after the VecEnv has already
    auto-reset it, i.e. a freshly spawned, unrelated flock.
    """
    raw = DummyVecEnv([make_env(n_sheep, seed=seed, max_steps=max_steps)])
    env = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
    env.obs_rms = deepcopy(vn.obs_rms)
    env.ret_rms = deepcopy(vn.ret_rms)
    tr = {
        "dog_x": [], "dog_y": [],
        "sheep_x": [[] for _ in range(n_sheep)],
        "sheep_y": [[] for _ in range(n_sheep)],
        "pen_dist": [[] for _ in range(n_sheep)],
        "radius": [], "action_mag": [], "reward": [],
    }
    try:
        obs = env.reset()
        inner = env.envs[0]
        done = False
        while not done:
            tr["dog_x"].append(float(inner.dog_pos[0]))
            tr["dog_y"].append(float(inner.dog_pos[1]))
            _, radius, _ = inner._flock_stats()
            tr["radius"].append(radius)
            for i in range(n_sheep):
                pos = inner.sheep_pos[i]
                tr["sheep_x"][i].append(float(pos[0]))
                tr["sheep_y"][i].append(float(pos[1]))
                tr["pen_dist"][i].append(float(np.linalg.norm(pos - inner.PEN_CENTER)))

            action, _ = model.predict(obs, deterministic=True)
            obs, reward, dones, _ = env.step(action)
            done = bool(dones[0])
            tr["reward"].append(float(reward[0]))  # raw: norm_reward=False
            tr["action_mag"].append(float(np.linalg.norm(action[0])))
    finally:
        env.close()
    return tr


def _plot_smoke_trajectory(tr, n_sheep, path):
    """Save a top-down plot of dog and sheep paths (start = o/square, last state = */diamond)."""
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        ax.set_xlim(-16, 16)
        ax.set_ylim(-16, 16)
        ax.set_aspect("equal")
        ax.set_facecolor("#dcedc8")
        # NOTE(review): the 30x30 arena centred on the origin and the pen
        # rectangle are hard-coded; presumably they mirror HerdingEnv's layout.
        ax.add_patch(mpatches.Rectangle((-15, -15), 30, 30, fill=False, edgecolor="#795548", lw=2))
        ax.add_patch(mpatches.Rectangle((10, -15), 3, 7, facecolor="#ffe082", edgecolor="#795548", lw=2))
        ax.text(11.5, -11.5, "pen", ha="center", va="center", fontsize=8, color="#795548")
        for i in range(n_sheep):
            c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
            xs, ys = tr["sheep_x"][i], tr["sheep_y"][i]
            ax.plot(xs, ys, color=c, lw=1, alpha=0.6, label=f"sheep {i+1}")
            ax.plot(xs[0], ys[0], "o", color=c, ms=7)
            ax.plot(xs[-1], ys[-1], "*", color=c, ms=10)
        dog_x, dog_y = tr["dog_x"], tr["dog_y"]
        ax.plot(dog_x, dog_y, color="#4e342e", lw=1.5, label="dog", alpha=0.8)
        ax.plot(dog_x[0], dog_y[0], "s", color="#4e342e", ms=9)
        ax.plot(dog_x[-1], dog_y[-1], "D", color="#4e342e", ms=9)
        ax.set_title(f"n_sheep={n_sheep} {len(dog_x)} steps min_r={min(tr['radius']):.1f}m")
        ax.legend(fontsize=7, loc="upper left")
        fig.tight_layout()
        fig.savefig(path, dpi=100)
    finally:
        plt.close(fig)


def _plot_smoke_timeseries(tr, n_sheep, path):
    """Save stacked per-step plots: flock radius, per-sheep pen distance, action magnitude, reward."""
    t = np.arange(len(tr["radius"]))
    fig, axes = plt.subplots(4, 1, figsize=(10, 8), sharex=True)
    try:
        axes[0].plot(t, tr["radius"], color="steelblue")
        axes[0].axhline(COMPACT_RADIUS, color="orange", ls="--", lw=1)
        axes[0].set_ylabel("radius (m)")
        axes[0].set_title(f"Flock radius (orange={COMPACT_RADIUS:g}m threshold)")
        for i in range(n_sheep):
            axes[1].plot(t, tr["pen_dist"][i], color=SHEEP_COLORS[i % len(SHEEP_COLORS)],
                         lw=1, label=f"sheep {i+1}")
        axes[1].set_ylabel("pen dist (m)")
        axes[1].set_title("Per-sheep distance to pen")
        axes[1].legend(fontsize=7)
        axes[2].plot(t, tr["action_mag"], color="tomato", lw=1, alpha=0.8)
        axes[2].axhline(1.0, color="gray", ls="--", lw=1)
        axes[2].set_ylim(0, 1.5)
        axes[2].set_ylabel("action mag")
        axes[2].set_title("Dog action magnitude (0=stopped)")
        axes[3].plot(t, tr["reward"], color="purple", lw=1, alpha=0.7)
        axes[3].axhline(0, color="black", lw=0.5)
        axes[3].set_ylabel("reward")
        axes[3].set_xlabel("step")
        axes[3].set_title("Reward per step")
        fig.suptitle(f"Smoke stage n_sheep={n_sheep}", fontsize=12)
        fig.tight_layout()
        fig.savefig(path, dpi=100)
    finally:
        plt.close(fig)


def _save_smoke_vis(model, vn, n_sheep, save_dir, seed=42, max_steps=2000):
    """Run one deterministic episode and save trajectory + timeseries PNGs.

    Args:
        model: trained policy exposing ``predict(obs, deterministic=True)``.
        vn: training VecNormalize whose obs statistics are reused (frozen).
        n_sheep: flock size for the visualised episode.
        save_dir: existing directory receiving trajectory.png and timeseries.png.
        seed: seed for the visualisation env.
        max_steps: episode step limit for the visualisation env.
    """
    traces = _rollout_for_vis(model, vn, n_sheep, seed, max_steps)
    _plot_smoke_trajectory(traces, n_sheep, os.path.join(save_dir, "trajectory.png"))
    _plot_smoke_timeseries(traces, n_sheep, os.path.join(save_dir, "timeseries.png"))
    print(f" Viz saved to {save_dir}/trajectory.png + timeseries.png")
def main():
    """Run the staged smoke test and print a pass/fail summary.

    Stages train 1, 2 and 3 sheep for 1x, 2x and 3x ``--steps`` respectively,
    each continuing from the previous stage's model and VecNormalize stats.
    After each stage the policy is evaluated for ``--episodes`` episodes. The
    model, vecnorm and plots are saved under ``runs/smoke_stage<n>/``, and the
    result is compared with that stage's own threshold. All stages run even
    when an earlier one falls below its threshold.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--steps", type=int, default=500_000,
                   help="Base steps per smoke-test stage; stage k trains k x this (default 500k)")
    p.add_argument("--n-envs", type=int, default=4)
    p.add_argument("--episodes", type=int, default=30,
                   help="Validation episodes per stage")
    p.add_argument("--render", action="store_true")
    args = p.parse_args()
    # Fail fast: these would otherwise only crash after a long training stage.
    if args.steps < 1:
        p.error("--steps must be >= 1")
    if args.n_envs < 1:
        p.error("--n-envs must be >= 1")
    if args.episodes < 1:
        p.error("--episodes must be >= 1")

    # 1 sheep (500k): hard check — obs/reward structurally correct?
    # Thresholds are MINIMUM bars — smoke test always runs ALL stages even on failure.
    # The per-episode diagnostics tell you WHY a stage failed.
    stages = [(1, args.steps, 0.10), (2, args.steps * 2, 0.20), (3, args.steps * 3, 0.10)]
    model, vn = None, None
    stage_results = []
    for n_sheep, steps, threshold in stages:
        print(f"\n{'#'*52}")
        print(f"# Smoke-test stage: n_sheep={n_sheep}, {steps:,} steps")
        print(f"{'#'*52}")
        model, vn = train_stage(n_sheep, steps, args.n_envs, model, vn)
        try:
            eval_env = make_eval_env(model, vn, n_sheep)
            try:
                success_rate, failure_counts, diag = run_episodes(
                    model, eval_env, args.episodes, render=args.render
                )
            finally:
                eval_env.close()
            save_dir = f"runs/smoke_stage{n_sheep}"
            os.makedirs(save_dir, exist_ok=True)
            model.save(os.path.join(save_dir, "model"))
            vn.save(os.path.join(save_dir, "vecnorm.pkl"))
            print(f" Model saved to {save_dir}/")
            _save_smoke_vis(model, vn, n_sheep, save_dir)
        finally:
            # Shut down this stage's SubprocVecEnv workers; previously they
            # leaked into every later stage. The next stage deep-copies only
            # vn's statistics (VecNormalize pickling excludes the wrapped venv)
            # and attaches a fresh env, so closing here is safe.
            vn.close()

        passed = report(n_sheep, success_rate, failure_counts, args.episodes, threshold)
        stage_results.append((n_sheep, success_rate, passed, diag))
        if not passed:
            print(f" ⚠ Stage {n_sheep} BELOW threshold — continuing to next stage.")
            print(f" mean_action={diag['mean_action_mag']:.3f} "
                  f"best_pen_approach={diag['best_min_pen_dist']:.1f}m")
            if diag['mean_action_mag'] < 0.05:
                print(" !! Dog is NOT moving (sit-still). "
                      "Check ent_coef / step_cost / alignment.")
            elif diag['best_min_pen_dist'] > 5.0:
                print(" !! Dog never gets sheep near pen. "
                      "Check reward direction / initialization.")
            else:
                print(" !! Dog moves and approaches pen but low success rate. "
                      "Likely needs more training time.")

    print(f"\n{'='*52}")
    print(" SMOKE TEST SUMMARY")
    print(f"{'='*52}")
    all_passed = True
    for n_sheep, sr, passed, diag in stage_results:
        status = "PASS" if passed else "FAIL"
        print(f" n_sheep={n_sheep} success={sr*100:.0f}% "
              f"act={diag['mean_action_mag']:.2f} "
              f"best_pen={diag['best_min_pen_dist']:.1f}m [{status}]")
        if not passed:
            all_passed = False
    if all_passed:
        print("\n All stages passed. Ready for full curriculum training:")
        print(" python train.py --curriculum --steps-per-stage 1500000 "
              "--total-steps 15000000 --n-sheep 1 --max-sheep 10 "
              "--n-envs 8 --run-dir runs/ppo_v3")
    else:
        print("\n Some stages below threshold — check diagnostics above.")
        print(" Key signals: act<0.05=sit-still, best_pen>5=wrong direction, "
              "else needs more training time.")
    print()
# Run the smoke test only when executed as a script, never on import.
if __name__ == "__main__":
    main()