370 lines
15 KiB
Python
370 lines
15 KiB
Python
"""
|
|
Quick sanity check before committing to a full 15M-step training run.
|
|
|
|
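Trains 1 sheep for 500k steps (~5 min), then 2 sheep for 1M steps, then 3 sheep
for 1.5M steps (defaults; the stage lengths scale with --steps).
If every stage passes, the obs/reward setup is sound and full training is worth running.
If any stage fails, fix the setup before wasting 15M steps (the script itself
always runs all stages and reports which ones fell below their threshold).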
|
|
|
|
Usage:
|
|
python smoke_test.py # fresh run
|
|
python smoke_test.py --render # watch episodes after each stage
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
import numpy as np
|
|
from copy import deepcopy
|
|
|
|
import matplotlib
|
|
matplotlib.use("Agg")
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.patches as mpatches
|
|
from matplotlib.collections import LineCollection
|
|
|
|
from stable_baselines3 import PPO
|
|
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecNormalize
|
|
|
|
from herding_env import HerdingEnv
|
|
|
|
|
|
# Flock radius at or below which classify_failure() treats the flock as "compact".
COMPACT_RADIUS = 5.0
|
|
# Default success-rate bar used by report(); main() passes its own per-stage thresholds.
PASS_THRESHOLD = 0.60
|
|
|
|
|
|
def make_env(n_sheep, seed, max_steps=2000):
    """Return a zero-argument factory that builds one seeded HerdingEnv.

    The factory form is what DummyVecEnv / SubprocVecEnv expect: they call
    it themselves to construct each environment.

    Args:
        n_sheep: Forwarded to ``HerdingEnv(n_sheep=...)``.
        seed: Passed to the env's first ``reset(seed=...)`` call.
        max_steps: Forwarded to ``HerdingEnv(max_steps=...)``.

    Returns:
        A callable taking no arguments and returning the constructed env.
    """
    env_kwargs = {"n_sheep": n_sheep, "max_steps": max_steps}

    def _build():
        herding = HerdingEnv(**env_kwargs)
        # Reset once with the requested seed before handing the env over.
        herding.reset(seed=seed)
        return herding

    return _build
|
|
|
|
|
|
def classify_failure(ep_radius, ep_com_dist, n_penned, n_sheep, success,
                     compact_radius=None, drive_dist=3.0):
    """Label one finished episode with a coarse outcome / failure mode.

    Args:
        ep_radius: Per-step flock radius trace for the episode.
        ep_com_dist: Per-step distance from the flock centre of mass to the
            pen centre, aligned index-for-index with ``ep_radius``.
        n_penned: Number of sheep penned at episode end.
        n_sheep: Total number of sheep in the episode.
        success: True if the episode counted as a success.
        compact_radius: Radius at or below which the flock counts as compact.
            ``None`` (default) uses the module-level ``COMPACT_RADIUS``.
        drive_dist: After the flock first becomes compact, its centre of mass
            must come within this distance of the pen centre at least once
            to count as "driven" (default 3.0, the previous hard-coded value).

    Returns:
        One of ``"SUCCESS"``, ``"NEVER_COMPACT"``, ``"COMPACT_CANT_DRIVE"``,
        ``"DROVE_NO_SHEEP"`` or ``"PARTIAL_<penned>of<total>"``.
    """
    if success:
        return "SUCCESS"

    if compact_radius is None:
        compact_radius = COMPACT_RADIUS

    # A single scan replaces the old min()-then-next() pair: an empty trace or
    # NaN radii now fall through to NEVER_COMPACT instead of raising
    # ValueError / StopIteration.
    first_compact = next(
        (i for i, r in enumerate(ep_radius) if r <= compact_radius), None
    )
    if first_compact is None:
        return "NEVER_COMPACT"

    # default=inf covers a com-dist trace shorter than the radius trace.
    closest_after_compact = min(ep_com_dist[first_compact:], default=float("inf"))
    if closest_after_compact > drive_dist:
        return "COMPACT_CANT_DRIVE"

    if n_penned == 0:
        return "DROVE_NO_SHEEP"

    return f"PARTIAL_{n_penned}of{n_sheep}"
|
|
|
|
|
|
def run_episodes(model, eval_env, n_episodes=30, max_steps=2000, render=False):
    """Run N deterministic evaluation episodes and summarise the outcomes.

    Env-state diagnostics (flock radius, centre-of-mass distance to the pen,
    remaining pen distance) are sampled BEFORE each ``step()``. A vectorised
    env resets itself as soon as an episode ends, so reading the inner env
    after the terminal step would return the next episode's initial state.
    The true terminal state is therefore not sampled; ``n_penned`` /
    ``n_sheep`` still come from the terminal step's ``info`` dict.

    Args:
        model: Policy exposing ``predict(obs, deterministic=True)``.
        eval_env: Single-env vectorised eval env exposing ``reset()``,
            ``step()`` and ``envs[0]`` (the unwrapped herding env).
        n_episodes: Number of episodes to run; must be >= 1.
        max_steps: Currently unused; kept for interface compatibility.
            Episode length is governed by the env's own ``done`` signal.
        render: If True, render every step of the first episode.

    Returns:
        ``(success_rate, failure_counts, diag)`` where ``failure_counts`` maps
        each ``classify_failure`` label to its episode count and ``diag``
        holds aggregate stats: ``mean_action_mag``, ``p10_action_mag``,
        ``p90_action_mag``, ``mean_min_pen_dist``, ``best_min_pen_dist``,
        ``mean_ep_steps`` and ``mean_pen_progress``.

    Raises:
        ValueError: If ``n_episodes`` is less than 1.
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")

    failure_counts = {}
    successes = 0

    all_action_mags = []   # action magnitude every step across all episodes
    all_pen_progress = []  # per-episode: total pen-dist reduction (positive = good)
    ep_steps_list = []
    ep_min_pen_list = []   # min flock-COM-to-pen distance sampled in each episode

    # The inner env object is the same for the whole evaluation.
    inner = eval_env.envs[0]

    for ep in range(n_episodes):
        obs = eval_env.reset()
        done = False
        ep_radius, ep_com_dist = [], []
        ep_action_mags = []
        n_penned = 0
        n_sheep = 1
        first_pen_dist = None
        pen_dist = 0.0

        while not done:
            # --- sample the current state before it can be auto-reset ---
            com, radius, _ = inner._flock_stats()
            ep_radius.append(radius)
            ep_com_dist.append(float(np.linalg.norm(com - inner.PEN_CENTER)))

            # Summed distance-to-pen over the sheep that are not yet penned.
            active = ~inner.penned[:inner.n_sheep]
            if active.any():
                pen_dist = float(np.linalg.norm(
                    inner.sheep_pos[:inner.n_sheep][active] - inner.PEN_CENTER, axis=1
                ).sum())
            else:
                pen_dist = 0.0
            if first_pen_dist is None:
                first_pen_dist = pen_dist

            if render and ep == 0:
                inner.render()

            # --- act ---
            action, _ = model.predict(obs, deterministic=True)
            obs, _, dones, infos = eval_env.step(action)
            done = dones[0]
            ep_action_mags.append(float(np.linalg.norm(action[0])))

        info = infos[0]
        n_penned = info.get("n_penned", 0)
        n_sheep = info.get("n_sheep", 1)
        success = n_penned == n_sheep
        successes += int(success)
        mode = classify_failure(ep_radius, ep_com_dist, n_penned, n_sheep, success)
        failure_counts[mode] = failure_counts.get(mode, 0) + 1

        min_pen = min(ep_com_dist)
        all_action_mags.extend(ep_action_mags)
        all_pen_progress.append(first_pen_dist - pen_dist)
        ep_steps_list.append(len(ep_action_mags))
        ep_min_pen_list.append(min_pen)

        # Per-episode one-liner for real-time feedback
        mean_act = float(np.mean(ep_action_mags))
        print(f" ep {ep+1:>3} steps={len(ep_action_mags):>5} "
              f"penned={n_penned}/{n_sheep} "
              f"act={mean_act:.2f} "
              f"min_pen={min_pen:.1f}m [{mode}]")

    success_rate = successes / n_episodes

    diag = {
        "mean_action_mag": float(np.mean(all_action_mags)),
        "p10_action_mag": float(np.percentile(all_action_mags, 10)),
        "p90_action_mag": float(np.percentile(all_action_mags, 90)),
        "mean_min_pen_dist": float(np.mean(ep_min_pen_list)),
        "best_min_pen_dist": float(np.min(ep_min_pen_list)),
        "mean_ep_steps": float(np.mean(ep_steps_list)),
        "mean_pen_progress": float(np.mean(all_pen_progress)),
    }

    print(f"\n Action magnitude mean={diag['mean_action_mag']:.3f} "
          f"p10={diag['p10_action_mag']:.3f} p90={diag['p90_action_mag']:.3f}"
          f" (0=stopped, 1=full speed)")
    print(f" Pen distance mean_min={diag['mean_min_pen_dist']:.1f}m "
          f"best_min={diag['best_min_pen_dist']:.1f}m "
          f"(how close sheep got to pen center)")

    return success_rate, failure_counts, diag
|
|
|
|
|
|
def train_stage(n_sheep, steps, n_envs=4, prev_model=None, prev_vecnorm=None):
    """Train one stage; return (model, vecnorm).

    Args:
        n_sheep: Number of sheep in every training env of this stage.
        steps: ``total_timesteps`` passed to ``model.learn``.
        n_envs: Number of subprocess envs; env ``i`` is seeded with ``i``.
        prev_model: Model from the previous stage to keep training, or None
            to create a fresh PPO model.
        prev_vecnorm: VecNormalize from the previous stage. It is deep-copied
            and the copy is attached to the new envs, so the argument itself
            is left untouched. None creates fresh normalisation statistics.

    Returns:
        ``(model, vecnorm)``. The caller owns ``vecnorm`` and should
        ``close()`` it when done to stop the worker processes.
    """
    train_env = SubprocVecEnv([make_env(n_sheep, i) for i in range(n_envs)])

    try:
        if prev_vecnorm is not None:
            # Carry the running statistics over, but bind them to the new envs.
            vn = deepcopy(prev_vecnorm)
            vn.set_venv(train_env)
            vn.training = True
            vn.norm_reward = True
        else:
            vn = VecNormalize(train_env, norm_obs=True, norm_reward=True, clip_obs=10.0)

        if prev_model is not None:
            model = prev_model
            model.set_env(vn)
        else:
            model = PPO(
                "MlpPolicy", vn,
                learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10,
                gamma=0.995, gae_lambda=0.95, clip_range=0.2, ent_coef=0.02,
                vf_coef=0.5, max_grad_norm=0.5,
                policy_kwargs=dict(net_arch=[256, 256]),
                verbose=1,
            )

        # Keep the global step counter running across stages of one model.
        model.learn(total_timesteps=steps, reset_num_timesteps=(prev_model is None),
                    tb_log_name="ppo_smoke")
    except BaseException:
        # Includes KeyboardInterrupt: do not leave worker processes running
        # when a long training run fails or is aborted.
        train_env.close()
        raise

    return model, vn
|
|
|
|
|
|
def make_eval_env(model, vecnorm, n_sheep, max_steps=2000):
    """Build a single-env evaluation wrapper with frozen normalisation stats.

    Args:
        model: Unused here; accepted to keep the call signature stable.
        vecnorm: Training-time VecNormalize whose ``obs_rms`` / ``ret_rms``
            are deep-copied into the eval wrapper.
        n_sheep: Number of sheep in the eval env.
        max_steps: Forwarded to ``make_env``.

    Returns:
        A VecNormalize around one DummyVecEnv env (seed 9999), with
        ``training=False`` and reward normalisation disabled.
    """
    base_env = DummyVecEnv([make_env(n_sheep, seed=9999, max_steps=max_steps)])
    eval_norm = VecNormalize(base_env, norm_obs=True, norm_reward=False, training=False)
    # Copy the statistics so evaluation can never mutate the training ones.
    for stats_attr in ("obs_rms", "ret_rms"):
        setattr(eval_norm, stats_attr, deepcopy(getattr(vecnorm, stats_attr)))
    return eval_norm
|
|
|
|
|
|
def report(n_sheep, success_rate, failure_counts, n_episodes, threshold=PASS_THRESHOLD):
    """Print the outcome histogram for one stage and return whether it passed.

    Args:
        n_sheep: Stage identifier (number of sheep) shown in the header.
        success_rate: Fraction of successful episodes, in [0, 1].
        failure_counts: Mapping of outcome label (including ``"SUCCESS"``)
            to episode count.
        n_episodes: Number of evaluated episodes (denominator in the output).
        threshold: Minimum ``success_rate`` required to pass.

    Returns:
        True if ``success_rate >= threshold``, else False.
    """
    # round(), not int(): k/n*n can land just below k in floating point
    # (e.g. 1/49*49 == 0.9999999999999999) and int() would under-report.
    n_success = round(success_rate * n_episodes)

    print(f"\n{'='*52}")
    print(f" Stage n_sheep={n_sheep} | success={success_rate*100:.0f}% ({n_success}/{n_episodes})")
    print(f" {'─'*48}")
    for mode, cnt in sorted(failure_counts.items(), key=lambda x: -x[1]):
        bar = "█" * cnt
        print(f" {mode:<26} {cnt:>3}/{n_episodes} {bar}")
    print(f"{'='*52}")

    passed = success_rate >= threshold
    if passed:
        print(f" ✓ PASS (threshold {threshold*100:.0f}%)")
    else:
        # "SUCCESS" is not a failure mode: exclude it so a stage that fails
        # with successes as the plurality still names its real problem.
        failures = {m: c for m, c in failure_counts.items() if m != "SUCCESS"}
        if failures:
            dominant = max(failures, key=failures.get)
            print(f" ✗ FAIL — dominant: {dominant}")
            if dominant == "NEVER_COMPACT":
                print(" Dog can't compact flock. Check W_COLLECT, obs contains straggler positions?")
            elif dominant == "COMPACT_CANT_DRIVE":
                print(" Flock compacts but dog doesn't drive to pen. Check alignment reward / W_DRIVE.")
            elif dominant.startswith("PARTIAL"):
                print(" Flock splits near pen. Dog loses stragglers at the end.")
        else:
            print(" ✗ FAIL — no failure modes recorded")
    print()
    return passed
|
|
|
|
|
|
# Per-sheep line colours for the plots; indexed modulo the list length.
SHEEP_COLORS = [
    "#e41a1c",
    "#377eb8",
    "#4daf4a",
    "#984ea3",
    "#ff7f00",
    "#a65628",
    "#f781bf",
    "#999999",
    "#66c2a5",
    "#fc8d62",
]
|
|
|
|
def _save_smoke_vis(model, vn, n_sheep, save_dir, seed=42, max_steps=2000):
    """Run one episode and save trajectory + timeseries PNGs.

    Positions and flock radius are recorded BEFORE each ``step()``: the
    vectorised env resets itself when the episode ends, so sampling after the
    terminal step would append the next episode's spawn positions and draw a
    bogus jump (and wrong end markers) on the trajectory plot.

    Args:
        model: Policy exposing ``predict(obs, deterministic=True)``.
        vn: Training VecNormalize; its ``obs_rms`` / ``ret_rms`` are copied.
        n_sheep: Number of sheep in the visualised episode.
        save_dir: Output directory (created if missing); receives
            ``trajectory.png`` and ``timeseries.png``.
        seed: Seed for the visualisation env.
        max_steps: Forwarded to ``make_env``.
    """
    from copy import deepcopy

    os.makedirs(save_dir, exist_ok=True)

    raw = DummyVecEnv([make_env(n_sheep, seed=seed, max_steps=max_steps)])
    env = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
    env.obs_rms = deepcopy(vn.obs_rms)
    env.ret_rms = deepcopy(vn.ret_rms)

    dog_xs, dog_ys = [], []
    sheep_xs = [[] for _ in range(n_sheep)]
    sheep_ys = [[] for _ in range(n_sheep)]
    radii, action_mags, rewards = [], [], []
    pen_dists = [[] for _ in range(n_sheep)]

    try:
        obs = env.reset()
        inner = env.envs[0]
        done = False

        while not done:
            # State at decision time (still belongs to this episode).
            dog_xs.append(float(inner.dog_pos[0]))
            dog_ys.append(float(inner.dog_pos[1]))
            _, radius, _ = inner._flock_stats()
            radii.append(radius)
            for i in range(n_sheep):
                sheep_xs[i].append(float(inner.sheep_pos[i][0]))
                sheep_ys[i].append(float(inner.sheep_pos[i][1]))
                pen_dists[i].append(
                    float(np.linalg.norm(inner.sheep_pos[i] - inner.PEN_CENTER))
                )

            action, _ = model.predict(obs, deterministic=True)
            obs, reward, dones, _ = env.step(action)
            done = dones[0]
            rewards.append(float(reward[0]))
            action_mags.append(float(np.linalg.norm(action[0])))
    finally:
        # Release the env even if the rollout raises.
        env.close()

    steps = len(dog_xs)

    # Trajectory
    # NOTE(review): arena and pen rectangles are hard-coded here; confirm
    # they match HerdingEnv's actual geometry.
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.set_xlim(-16, 16)
    ax.set_ylim(-16, 16)
    ax.set_aspect("equal")
    ax.set_facecolor("#dcedc8")
    ax.add_patch(mpatches.Rectangle((-15, -15), 30, 30, fill=False, edgecolor="#795548", lw=2))
    ax.add_patch(mpatches.Rectangle((10, -15), 3, 7, facecolor="#ffe082", edgecolor="#795548", lw=2))
    ax.text(11.5, -11.5, "pen", ha="center", va="center", fontsize=8, color="#795548")
    for i in range(n_sheep):
        c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
        ax.plot(sheep_xs[i], sheep_ys[i], color=c, lw=1, alpha=0.6, label=f"sheep {i+1}")
        ax.plot(sheep_xs[i][0], sheep_ys[i][0], "o", color=c, ms=7)     # start
        ax.plot(sheep_xs[i][-1], sheep_ys[i][-1], "*", color=c, ms=10)  # end
    ax.plot(dog_xs, dog_ys, color="#4e342e", lw=1.5, label="dog", alpha=0.8)
    ax.plot(dog_xs[0], dog_ys[0], "s", color="#4e342e", ms=9)    # start
    ax.plot(dog_xs[-1], dog_ys[-1], "D", color="#4e342e", ms=9)  # end
    ax.set_title(f"n_sheep={n_sheep} {steps} steps min_r={min(radii):.1f}m")
    ax.legend(fontsize=7, loc="upper left")
    plt.tight_layout()
    fig.savefig(os.path.join(save_dir, "trajectory.png"), dpi=100)
    plt.close(fig)

    # Timeseries
    t = np.arange(steps)
    fig, axes = plt.subplots(4, 1, figsize=(10, 8), sharex=True)
    axes[0].plot(t, radii, color="steelblue")
    axes[0].axhline(5, color="orange", ls="--", lw=1)
    axes[0].set_ylabel("radius (m)")
    axes[0].set_title("Flock radius (orange=5m threshold)")
    for i in range(n_sheep):
        axes[1].plot(t, pen_dists[i], color=SHEEP_COLORS[i % len(SHEEP_COLORS)], lw=1, label=f"sheep {i+1}")
    axes[1].set_ylabel("pen dist (m)")
    axes[1].set_title("Per-sheep distance to pen")
    axes[1].legend(fontsize=7)
    axes[2].plot(t, action_mags, color="tomato", lw=1, alpha=0.8)
    axes[2].axhline(1.0, color="gray", ls="--", lw=1)
    axes[2].set_ylim(0, 1.5)
    axes[2].set_ylabel("action mag")
    axes[2].set_title("Dog action magnitude (0=stopped)")
    axes[3].plot(t, rewards, color="purple", lw=1, alpha=0.7)
    axes[3].axhline(0, color="black", lw=0.5)
    axes[3].set_ylabel("reward")
    axes[3].set_xlabel("step")
    axes[3].set_title("Reward per step")
    fig.suptitle(f"Smoke stage n_sheep={n_sheep}", fontsize=12)
    plt.tight_layout()
    fig.savefig(os.path.join(save_dir, "timeseries.png"), dpi=100)
    plt.close(fig)
    print(f" Viz saved to {save_dir}/trajectory.png + timeseries.png")
|
|
|
|
|
|
def main():
    """Run every smoke-test stage, print diagnostics and a final summary.

    Each stage trains (continuing from the previous stage's model and
    normalisation statistics), evaluates, saves the model / vecnorm / plots
    under ``runs/smoke_stage<n_sheep>``, and is reported against its own
    threshold. A failing stage does not stop the run.

    Returns:
        True if every stage met its threshold, otherwise False. The process
        exit status is not affected by this value.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--steps", type=int, default=500_000,
                   help="Steps per smoke-test stage (default 500k)")
    p.add_argument("--n-envs", type=int, default=4)
    p.add_argument("--episodes", type=int, default=30,
                   help="Validation episodes per stage")
    p.add_argument("--render", action="store_true")
    args = p.parse_args()

    # Fail fast on values that would otherwise crash deep inside a stage
    # (e.g. division by zero with zero evaluation episodes).
    if args.steps < 1 or args.n_envs < 1 or args.episodes < 1:
        p.error("--steps, --n-envs and --episodes must all be >= 1")

    # 1 sheep (500k): hard check — obs/reward structurally correct?
    # Thresholds are MINIMUM bars — smoke test always runs ALL stages even on failure.
    # The per-episode diagnostics tell you WHY a stage failed.
    # Each entry is (n_sheep, training steps, success-rate threshold).
    stages = [(1, args.steps, 0.10), (2, args.steps * 2, 0.20), (3, args.steps * 3, 0.10)]

    model, vn = None, None
    stage_results = []

    try:
        for n_sheep, steps, threshold in stages:
            print(f"\n{'#'*52}")
            print(f"# Smoke-test stage: n_sheep={n_sheep}, {steps:,} steps")
            print(f"{'#'*52}")

            prev_vn = vn
            model, vn = train_stage(n_sheep, steps, args.n_envs, model, prev_vn)
            # The previous stage's training envs are no longer attached to
            # the model; close them so their worker processes do not pile up.
            if prev_vn is not None and prev_vn is not vn:
                prev_vn.close()

            eval_env = make_eval_env(model, vn, n_sheep)
            try:
                success_rate, failure_counts, diag = run_episodes(
                    model, eval_env, args.episodes, render=args.render
                )
            finally:
                eval_env.close()

            save_dir = f"runs/smoke_stage{n_sheep}"
            os.makedirs(save_dir, exist_ok=True)
            model.save(os.path.join(save_dir, "model"))
            vn.save(os.path.join(save_dir, "vecnorm.pkl"))
            print(f" Model saved to {save_dir}/")
            _save_smoke_vis(model, vn, n_sheep, save_dir)

            passed = report(n_sheep, success_rate, failure_counts, args.episodes, threshold)
            stage_results.append((n_sheep, success_rate, passed, diag))

            if not passed:
                print(f" ⚠ Stage {n_sheep} BELOW threshold — continuing to next stage.")
                print(f" mean_action={diag['mean_action_mag']:.3f} "
                      f"best_pen_approach={diag['best_min_pen_dist']:.1f}m")
                if diag['mean_action_mag'] < 0.05:
                    print(" !! Dog is NOT moving (sit-still). "
                          "Check ent_coef / step_cost / alignment.")
                elif diag['best_min_pen_dist'] > 5.0:
                    print(" !! Dog never gets sheep near pen. "
                          "Check reward direction / initialization.")
                else:
                    print(" !! Dog moves and approaches pen but low success rate. "
                          "Likely needs more training time.")
    finally:
        # Stop the last stage's training workers, also on error or Ctrl-C.
        if vn is not None:
            vn.close()

    print(f"\n{'='*52}")
    print(" SMOKE TEST SUMMARY")
    print(f"{'='*52}")
    for n_sheep, sr, passed, diag in stage_results:
        status = "PASS" if passed else "FAIL"
        print(f" n_sheep={n_sheep} success={sr*100:.0f}% "
              f"act={diag['mean_action_mag']:.2f} "
              f"best_pen={diag['best_min_pen_dist']:.1f}m [{status}]")
    all_passed = all(passed for _, _, passed, _ in stage_results)

    if all_passed:
        print("\n All stages passed. Ready for full curriculum training:")
        print(" python train.py --curriculum --steps-per-stage 1500000 "
              "--total-steps 15000000 --n-sheep 1 --max-sheep 10 "
              "--n-envs 8 --run-dir runs/ppo_v3")
    else:
        print("\n Some stages below threshold — check diagnostics above.")
        print(" Key signals: act<0.05=sit-still, best_pen>5=wrong direction, "
              "else needs more training time.")
    print()
    return all_passed
|
|
|
|
|
|
# Script entry point: run the smoke test only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|