Files
TIR_PROJ/training/smoke_test.py
T
2026-04-24 17:49:42 +01:00

290 lines
11 KiB
Python

"""
Quick sanity check before committing to a full 15M-step training run.
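Trains 1 sheep for 500k steps (~5 min), then 2 sheep for 1M steps, then 3 sheep
for 1.5M steps (stage k trains for k x --steps), reusing the policy between stages.
If every stage passes, the obs/reward setup is sound and full training is worth running.
If any stage fails, abort and fix before wasting 15M steps.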
Usage:
python smoke_test.py # fresh run
python smoke_test.py --render # watch episodes after each stage
"""
import argparse
import os
import sys
import numpy as np
from copy import deepcopy
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.collections import LineCollection
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecNormalize
from herding_env import HerdingEnv
# Flock radius (as reported by HerdingEnv._flock_stats()) at or below which
# classify_failure() considers the flock "compact".
COMPACT_RADIUS = 5.0
PASS_THRESHOLD = 0.60 # default pass mark for report(); main() supplies its own per-stage thresholds
def make_env(n_sheep, seed, max_steps=2000):
    """Return a zero-argument factory that builds one seeded HerdingEnv.

    The vectorised-env constructors used in this file are given callables
    rather than environment instances, so construction is deferred to the
    returned thunk.

    Args:
        n_sheep: Forwarded to ``HerdingEnv(n_sheep=...)``.
        seed: Passed to the environment's first ``reset(seed=...)`` call.
        max_steps: Forwarded to ``HerdingEnv(max_steps=...)``.
    """

    def _thunk():
        herding_env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps)
        # Reset once with the seed; the value returned by reset() is discarded.
        herding_env.reset(seed=seed)
        return herding_env

    return _thunk
def classify_failure(ep_radius, ep_com_dist, n_penned, n_sheep, success,
                     *, compact_radius=None, drive_dist=3.0):
    """Label one finished episode with a coarse outcome / failure mode.

    Args:
        ep_radius: Per-step flock radius samples for the episode.
        ep_com_dist: Per-step distance from the flock centre of mass to the
            pen centre, aligned index-for-index with ``ep_radius``.
        n_penned: Number of sheep penned when the episode ended.
        n_sheep: Total number of sheep in the episode.
        success: Whether the episode counts as a success.
        compact_radius: Radius at or below which the flock counts as compact.
            ``None`` (the default) uses the module-level ``COMPACT_RADIUS``.
        drive_dist: Centre-of-mass-to-pen distance the flock must get within,
            after first becoming compact, to count as having been driven.

    Returns:
        One of ``"SUCCESS"``, ``"NEVER_COMPACT"``, ``"COMPACT_CANT_DRIVE"``,
        ``"DROVE_NO_SHEEP"`` or ``"PARTIAL_<k>of<n>"``.
    """
    if success:
        return "SUCCESS"

    if compact_radius is None:
        compact_radius = COMPACT_RADIUS

    # Index of the first compact sample. None also covers an empty trace,
    # which previously crashed in min().
    first_compact = next(
        (i for i, r in enumerate(ep_radius) if r <= compact_radius),
        None,
    )
    if first_compact is None:
        return "NEVER_COMPACT"

    # Only approaches made *after* the flock was first compact count as driving.
    dist_after_compact = ep_com_dist[first_compact:]
    if not dist_after_compact or min(dist_after_compact) > drive_dist:
        return "COMPACT_CANT_DRIVE"

    if n_penned == 0:
        return "DROVE_NO_SHEEP"
    return f"PARTIAL_{n_penned}of{n_sheep}"
def run_episodes(model, eval_env, n_episodes=30, max_steps=2000, render=False):
    """Run N deterministic episodes; return (success_rate, failure_counts).

    Args:
        model: Policy exposing ``predict(obs, deterministic=True)``.
        eval_env: Single-env vectorised environment; ``eval_env.envs[0]`` must
            be the underlying env exposing ``_flock_stats()``, ``PEN_CENTER``
            and ``render()``.
        n_episodes: Number of evaluation episodes.
        max_steps: Unused here (episode length is governed by the env); kept
            so existing callers keep working.
        render: If true, render every step of the first episode.

    Returns:
        ``(success_rate, failure_counts)`` where ``failure_counts`` maps each
        ``classify_failure`` label to the number of episodes that got it.
        ``success_rate`` is 0.0 when ``n_episodes`` is not positive.
    """
    failure_counts = {}
    successes = 0
    for ep in range(n_episodes):
        obs = eval_env.reset()
        inner = eval_env.envs[0]
        done = False
        ep_radius, ep_com_dist = [], []
        n_penned = 0
        n_sheep = 1
        while not done:
            # Sample the flock BEFORE stepping. The vectorised env resets the
            # underlying env automatically inside step() when an episode ends,
            # so reading `inner` after a terminal step would return the next
            # episode's initial state and corrupt the last sample.
            com, radius, _ = inner._flock_stats()
            ep_radius.append(radius)
            ep_com_dist.append(float(np.linalg.norm(com - inner.PEN_CENTER)))
            if render and ep == 0:
                inner.render()

            action, _ = model.predict(obs, deterministic=True)
            obs, _, dones, infos = eval_env.step(action)
            done = dones[0]

            # The info dict still belongs to the step just taken, so the final
            # values seen here are the episode's terminal counts.
            info = infos[0]
            n_penned = info.get("n_penned", 0)
            n_sheep = info.get("n_sheep", 1)

        success = n_penned == n_sheep
        successes += int(success)
        mode = classify_failure(ep_radius, ep_com_dist, n_penned, n_sheep, success)
        failure_counts[mode] = failure_counts.get(mode, 0) + 1

    # Guard against ZeroDivisionError when no episodes were requested.
    success_rate = successes / n_episodes if n_episodes > 0 else 0.0
    return success_rate, failure_counts
def train_stage(n_sheep, steps, n_envs=4, prev_model=None, prev_vecnorm=None):
    """Train one smoke-test stage and return ``(model, vecnorm)``.

    Args:
        n_sheep: Flock size for every training env in this stage.
        steps: ``total_timesteps`` handed to ``model.learn``.
        n_envs: Number of subprocess environments (seeded 0..n_envs-1).
        prev_model: PPO model from the previous stage; when given it is
            re-pointed at the new env and training continues without
            resetting its timestep counter. ``None`` builds a fresh model.
        prev_vecnorm: VecNormalize wrapper from the previous stage; when
            given, a deep copy of it is attached to the new env instead of
            creating a new wrapper.

    NOTE(review): the env wrapped by ``prev_vecnorm`` is not closed here;
    callers own its lifetime.
    """
    env_fns = [make_env(n_sheep, rank) for rank in range(n_envs)]
    train_env = SubprocVecEnv(env_fns)

    if prev_vecnorm is None:
        vn = VecNormalize(train_env, norm_obs=True, norm_reward=True, clip_obs=10.0)
    else:
        # Copy the previous wrapper, attach it to the new env, and switch it
        # back to training mode with reward normalisation on.
        vn = deepcopy(prev_vecnorm)
        vn.set_venv(train_env)
        vn.training = True
        vn.norm_reward = True

    fresh_start = prev_model is None
    if fresh_start:
        ppo_kwargs = dict(
            learning_rate=3e-4,
            n_steps=2048,
            batch_size=256,
            n_epochs=10,
            gamma=0.995,
            gae_lambda=0.95,
            clip_range=0.2,
            ent_coef=0.02,
            vf_coef=0.5,
            max_grad_norm=0.5,
            policy_kwargs=dict(net_arch=[256, 256]),
            verbose=1,
        )
        model = PPO("MlpPolicy", vn, **ppo_kwargs)
    else:
        model = prev_model
        model.set_env(vn)

    model.learn(
        total_timesteps=steps,
        reset_num_timesteps=fresh_start,
        tb_log_name="ppo_smoke",
    )
    return model, vn
def make_eval_env(model, vecnorm, n_sheep, max_steps=2000):
    """Build a frozen, single-env evaluation wrapper for ``n_sheep`` sheep.

    The returned VecNormalize normalises observations only, has training
    switched off, and receives deep copies of ``vecnorm``'s ``obs_rms`` and
    ``ret_rms``. ``model`` is accepted for call-site symmetry but is not used.
    """
    env_factory = make_env(n_sheep, seed=9999, max_steps=max_steps)
    eval_vn = VecNormalize(
        DummyVecEnv([env_factory]),
        norm_obs=True,
        norm_reward=False,
        training=False,
    )
    # Copy the statistics so later training cannot mutate the evaluation view.
    for stat_name in ("obs_rms", "ret_rms"):
        setattr(eval_vn, stat_name, deepcopy(getattr(vecnorm, stat_name)))
    return eval_vn
def report(n_sheep, success_rate, failure_counts, n_episodes, threshold=PASS_THRESHOLD):
    """Print a per-stage summary and return whether the stage passed.

    Args:
        n_sheep: Flock size of the stage being reported.
        success_rate: Fraction of successful episodes, in [0, 1].
        failure_counts: Mapping of outcome label to episode count; may
            include a ``"SUCCESS"`` entry.
        n_episodes: Number of evaluation episodes that were run.
        threshold: Minimum ``success_rate`` required to pass.

    Returns:
        ``True`` if ``success_rate >= threshold``, else ``False``.
    """
    # round(), not int(): rate * n can land just below an integer because of
    # float error, and truncation would then under-report by one.
    n_success = round(success_rate * n_episodes)

    print(f"\n{'='*52}")
    print(f" Stage n_sheep={n_sheep} | success={success_rate*100:.0f}% ({n_success}/{n_episodes})")
    print(f" {'─'*48}")
    for mode, cnt in sorted(failure_counts.items(), key=lambda x: -x[1]):
        bar = "█" * cnt
        print(f" {mode:<26} {cnt:>3}/{n_episodes} {bar}")
    print(f"{'='*52}")

    passed = success_rate >= threshold
    if passed:
        print(f" ✓ PASS (threshold {threshold*100:.0f}%)")
    else:
        # Never report "SUCCESS" as the dominant failure mode: a stage can
        # fail its threshold while success is still the most common outcome.
        failures_only = {
            mode: cnt for mode, cnt in failure_counts.items() if mode != "SUCCESS"
        }
        if not failures_only:
            # E.g. zero episodes run; max() on an empty dict would raise.
            print(" ✗ FAIL — no failed episodes recorded")
        else:
            dominant = max(failures_only, key=failures_only.get)
            print(f" ✗ FAIL — dominant: {dominant}")
            if dominant == "NEVER_COMPACT":
                print(" Dog can't compact flock. Check W_COLLECT, obs contains straggler positions?")
            elif dominant == "COMPACT_CANT_DRIVE":
                print(" Flock compacts but dog doesn't drive to pen. Check alignment reward / W_DRIVE.")
            elif dominant.startswith("PARTIAL"):
                print(" Flock splits near pen. Dog loses stragglers at the end.")
    print()
    return passed
# Line colours for the per-sheep traces in the smoke-test plots. Indexed
# modulo its length, so flocks larger than ten sheep reuse colours.
SHEEP_COLORS = ["#e41a1c","#377eb8","#4daf4a","#984ea3","#ff7f00",
"#a65628","#f781bf","#999999","#66c2a5","#fc8d62"]
def _rollout_for_vis(model, env, n_sheep):
    """Play one deterministic episode on ``env`` and return per-step traces."""
    trace = {
        "dog_xs": [],
        "dog_ys": [],
        "sheep_xs": [[] for _ in range(n_sheep)],
        "sheep_ys": [[] for _ in range(n_sheep)],
        "pen_dists": [[] for _ in range(n_sheep)],
        "radii": [],
        "action_mags": [],
        "rewards": [],
    }
    obs = env.reset()
    inner = env.envs[0]
    done = False
    while not done:
        # Snapshot positions BEFORE stepping: the vectorised env resets the
        # underlying env inside step() when the episode ends, so a post-step
        # read on the final step would record the next episode's start state.
        trace["dog_xs"].append(float(inner.dog_pos[0]))
        trace["dog_ys"].append(float(inner.dog_pos[1]))
        _, radius, _ = inner._flock_stats()
        trace["radii"].append(radius)
        for i in range(n_sheep):
            pos = inner.sheep_pos[i]
            trace["sheep_xs"][i].append(float(pos[0]))
            trace["sheep_ys"][i].append(float(pos[1]))
            trace["pen_dists"][i].append(float(np.linalg.norm(pos - inner.PEN_CENTER)))

        action, _ = model.predict(obs, deterministic=True)
        obs, reward, dones, _ = env.step(action)
        done = dones[0]
        # Action and reward belong to the transition out of the state above.
        trace["action_mags"].append(float(np.linalg.norm(action[0])))
        trace["rewards"].append(float(reward[0]))
    return trace


def _plot_trajectory(trace, n_sheep, save_dir):
    """Write ``trajectory.png``: top-down dog and sheep paths for one episode."""
    dog_xs, dog_ys = trace["dog_xs"], trace["dog_ys"]
    steps = len(dog_xs)
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.set_xlim(-16, 16)
    ax.set_ylim(-16, 16)
    ax.set_aspect("equal")
    ax.set_facecolor("#dcedc8")
    # Outer rectangle and the labelled "pen" rectangle.
    ax.add_patch(mpatches.Rectangle((-15, -15), 30, 30, fill=False, edgecolor="#795548", lw=2))
    ax.add_patch(mpatches.Rectangle((10, -15), 3, 7, facecolor="#ffe082", edgecolor="#795548", lw=2))
    ax.text(11.5, -11.5, "pen", ha="center", va="center", fontsize=8, color="#795548")
    for i, (xs, ys) in enumerate(zip(trace["sheep_xs"], trace["sheep_ys"])):
        c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
        ax.plot(xs, ys, color=c, lw=1, alpha=0.6, label=f"sheep {i+1}")
        ax.plot(xs[0], ys[0], "o", color=c, ms=7)     # start marker
        ax.plot(xs[-1], ys[-1], "*", color=c, ms=10)  # end marker
    ax.plot(dog_xs, dog_ys, color="#4e342e", lw=1.5, label="dog", alpha=0.8)
    ax.plot(dog_xs[0], dog_ys[0], "s", color="#4e342e", ms=9)
    ax.plot(dog_xs[-1], dog_ys[-1], "D", color="#4e342e", ms=9)
    ax.set_title(f"n_sheep={n_sheep} {steps} steps min_r={min(trace['radii']):.1f}m")
    ax.legend(fontsize=7, loc="upper left")
    plt.tight_layout()
    fig.savefig(os.path.join(save_dir, "trajectory.png"), dpi=100)
    plt.close(fig)


def _plot_timeseries(trace, n_sheep, save_dir):
    """Write ``timeseries.png``: radius, pen distance, action size and reward."""
    t = np.arange(len(trace["dog_xs"]))
    fig, axes = plt.subplots(4, 1, figsize=(10, 8), sharex=True)

    axes[0].plot(t, trace["radii"], color="steelblue")
    # Draw the same compactness threshold that classify_failure() applies.
    axes[0].axhline(COMPACT_RADIUS, color="orange", ls="--", lw=1)
    axes[0].set_ylabel("radius (m)")
    axes[0].set_title(f"Flock radius (orange={COMPACT_RADIUS:g}m threshold)")

    for i in range(n_sheep):
        axes[1].plot(t, trace["pen_dists"][i], color=SHEEP_COLORS[i % len(SHEEP_COLORS)],
                     lw=1, label=f"sheep {i+1}")
    axes[1].set_ylabel("pen dist (m)")
    axes[1].set_title("Per-sheep distance to pen")
    axes[1].legend(fontsize=7)

    axes[2].plot(t, trace["action_mags"], color="tomato", lw=1, alpha=0.8)
    axes[2].axhline(1.0, color="gray", ls="--", lw=1)
    axes[2].set_ylim(0, 1.5)
    axes[2].set_ylabel("action mag")
    axes[2].set_title("Dog action magnitude (0=stopped)")

    axes[3].plot(t, trace["rewards"], color="purple", lw=1, alpha=0.7)
    axes[3].axhline(0, color="black", lw=0.5)
    axes[3].set_ylabel("reward")
    axes[3].set_xlabel("step")
    axes[3].set_title("Reward per step")

    fig.suptitle(f"Smoke stage n_sheep={n_sheep}", fontsize=12)
    plt.tight_layout()
    fig.savefig(os.path.join(save_dir, "timeseries.png"), dpi=100)
    plt.close(fig)


def _save_smoke_vis(model, vn, n_sheep, save_dir, seed=42, max_steps=2000):
    """Run one episode and save trajectory + timeseries PNGs.

    Args:
        model: Trained policy exposing ``predict``.
        vn: VecNormalize wrapper whose ``obs_rms`` / ``ret_rms`` are copied
            into the throwaway visualisation env.
        n_sheep: Flock size for the visualised episode.
        save_dir: Existing directory that receives ``trajectory.png`` and
            ``timeseries.png``.
        seed: Seed for the visualisation env.
        max_steps: Episode step limit for the visualisation env.
    """
    # Keyword arguments on purpose: make_env's positional order is
    # (n_sheep, seed, max_steps). Passing (n_sheep, max_steps, seed)
    # positionally swapped the two and capped the episode at `seed` steps.
    raw = DummyVecEnv([make_env(n_sheep, seed=seed, max_steps=max_steps)])
    env = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
    try:
        env.obs_rms = deepcopy(vn.obs_rms)
        env.ret_rms = deepcopy(vn.ret_rms)
        trace = _rollout_for_vis(model, env, n_sheep)
    finally:
        # Release the env even if the rollout raises.
        env.close()

    _plot_trajectory(trace, n_sheep, save_dir)
    _plot_timeseries(trace, n_sheep, save_dir)
    print(f" Viz saved to {save_dir}/trajectory.png + timeseries.png")
def main():
    """Run the staged smoke test; exit with status 1 at the first failed stage.

    Each stage trains (continuing from the previous stage's model and
    normalisation wrapper), evaluates, saves the model, normalisation stats
    and plots under ``runs/smoke_stage<n>``, and prints a report. Training
    environments are closed as soon as they are superseded and on every exit
    path, so subprocess workers are not left running.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--steps", type=int, default=500_000,
                   help="Steps per smoke-test stage (default 500k)")
    p.add_argument("--n-envs", type=int, default=4)
    p.add_argument("--episodes", type=int, default=30,
                   help="Validation episodes per stage")
    p.add_argument("--render", action="store_true")
    args = p.parse_args()

    # 1 sheep (500k): sanity check — obs/reward structurally correct?
    # 2 sheep (1M): first multi-agent step — gradual transfer
    # 3 sheep (1.5M): real multi-sheep test at curriculum pace
    stages = [(1, args.steps, 0.60), (2, args.steps * 2, 0.40), (3, args.steps * 3, 0.35)]

    model, vn = None, None
    try:
        for n_sheep, steps, threshold in stages:
            print(f"\n{'#'*52}")
            print(f"# Smoke-test stage: n_sheep={n_sheep}, {steps:,} steps")
            print(f"{'#'*52}")

            prev_vn = vn
            model, vn = train_stage(n_sheep, steps, args.n_envs, model, vn)
            if prev_vn is not None:
                # train_stage works on a copy of the old wrapper, so the old
                # wrapper still owns the previous stage's workers. The model
                # now points at the new env, so they can be shut down.
                prev_vn.close()

            eval_env = make_eval_env(model, vn, n_sheep)
            try:
                success_rate, failure_counts = run_episodes(
                    model, eval_env, args.episodes, render=args.render
                )
            finally:
                eval_env.close()

            save_dir = f"runs/smoke_stage{n_sheep}"
            os.makedirs(save_dir, exist_ok=True)
            model.save(os.path.join(save_dir, "model"))
            vn.save(os.path.join(save_dir, "vecnorm.pkl"))
            print(f" Model saved to {save_dir}/")
            _save_smoke_vis(model, vn, n_sheep, save_dir)

            passed = report(n_sheep, success_rate, failure_counts, args.episodes, threshold)
            if not passed:
                print(" Aborting smoke test — fix the issue above before full training.")
                sys.exit(1)
    finally:
        # Runs on normal completion, on sys.exit() and on errors alike.
        if vn is not None:
            vn.close()

    # Reaching this point means no stage triggered sys.exit above.
    print("\n All smoke-test stages passed.")
    print(" Ready for full curriculum training:")
    print()
    print(" python train.py --curriculum --steps-per-stage 1500000 \\")
    print(" --total-steps 15000000 --n-sheep 1 --max-sheep 10 \\")
    print(" --n-envs 8 --run-dir runs/ppo_v2")
    print()


if __name__ == "__main__":
    main()