Sheep training flock _ improver

This commit is contained in:
Johnny Fernandes
2026-04-25 11:31:39 +01:00
parent 062de676c9
commit fbe76a0d04
3 changed files with 190 additions and 24 deletions
+65 -15
View File
@@ -83,6 +83,13 @@ class CurriculumCallback(BaseCallback):
self._stage_start = 0
def _advance(self):
prev_sheep = self._cur_sheep
recent_sr = (np.mean(self._successes) if self._successes else float("nan"))
if self.verbose:
print(f"\n[Curriculum] leaving stage n_sheep={prev_sheep} "
f"after {self.num_timesteps - self._stage_start:,} steps "
f"| training success rate (last {len(self._successes)} eps) = "
f"{recent_sr*100:.0f}%")
self._cur_sheep += 1
self.training_env.env_method("set_n_sheep", self._cur_sheep)
if self.eval_env is not None:
@@ -90,26 +97,26 @@ class CurriculumCallback(BaseCallback):
self._stage_start = self.num_timesteps
self._successes.clear()
if self.verbose:
print(f"\n[Curriculum] → {self._cur_sheep} sheep "
print(f"[Curriculum] → {self._cur_sheep} sheep "
f"at step {self.num_timesteps:,}\n")
def _on_step(self) -> bool:
if self._cur_sheep >= self.max_sheep:
return True
# Always track training-side success (success = sheep all penned, not truncated)
for info, done in zip(self.locals["infos"], self.locals["dones"]):
if done:
npen = info.get("n_penned", 0)
nshp = info.get("n_sheep", self._cur_sheep)
self._successes.append(1 if npen == nshp else 0)
if len(self._successes) > self.window:
self._successes.pop(0)
if self.steps_per_stage is not None:
# Time-based: advance every steps_per_stage env steps
if self.num_timesteps - self._stage_start >= self.steps_per_stage:
self._advance()
else:
# Success-rate based
for info, done in zip(self.locals["infos"], self.locals["dones"]):
if done:
truncated = info.get("TimeLimit.truncated", False)
self._successes.append(0 if truncated else 1)
if len(self._successes) > self.window:
self._successes.pop(0)
if (len(self._successes) >= self.min_episodes
and np.mean(self._successes) >= self.threshold):
self._advance()
@@ -131,11 +138,13 @@ class DiagnosticCallback(BaseCallback):
"""
def __init__(self, diag_freq: int = 500_000, n_episodes: int = 20,
max_steps: int = 2000, verbose: int = 1):
max_steps: int = 2000, abort_on_stall: bool = True,
verbose: int = 1):
super().__init__(verbose)
self.diag_freq = diag_freq
self.n_episodes = n_episodes
self.max_steps = max_steps
self.abort_on_stall = abort_on_stall
self._last_diag = 0
self._prev_dominant = None # (n_sheep, mode) from last check
self._stall_count = 0
@@ -156,11 +165,19 @@ class DiagnosticCallback(BaseCallback):
failure_counts = {}
successes = 0
all_action_mags = []
ep_min_radii = []
ep_min_dog_com = [] # closest the dog ever got to flock COM
ep_min_pen_dists = [] # closest COM ever got to pen
rcomp_sums = {"progress":0.0,"alignment":0.0,"pen_bonus":0.0,
"step_cost":0.0,"complete":0.0}
rcomp_n = 0
for _ in range(self.n_episodes):
obs = vn.reset()
done = False
ep_radius, ep_com_dist = [], []
ep_radius, ep_com_dist, ep_dog_com = [], [], []
ep_actions = []
n_penned = 0
while not done:
@@ -173,12 +190,24 @@ class DiagnosticCallback(BaseCallback):
ep_com_dist.append(
float(np.linalg.norm(com - inner.PEN_CENTER))
)
ep_dog_com.append(
float(np.linalg.norm(inner.dog_pos - com))
)
ep_actions.append(float(np.linalg.norm(action[0])))
rc = infos[0].get("rcomps")
if rc is not None:
for k in rcomp_sums: rcomp_sums[k] += rc[k]
rcomp_n += 1
n_penned = infos[0].get("n_penned", 0)
success = n_penned == n_sheep
successes += int(success)
mode = _classify(ep_radius, ep_com_dist, n_penned, n_sheep, success)
failure_counts[mode] = failure_counts.get(mode, 0) + 1
all_action_mags.extend(ep_actions)
ep_min_radii.append(min(ep_radius))
ep_min_dog_com.append(min(ep_dog_com))
ep_min_pen_dists.append(min(ep_com_dist))
vn.close()
@@ -190,13 +219,30 @@ class DiagnosticCallback(BaseCallback):
f"success={success_rate*100:.0f}%]")
for m, c in sorted(failure_counts.items(), key=lambda x: -x[1]):
print(f" {m:<26} {c}/{self.n_episodes}")
mean_act = float(np.mean(all_action_mags)) if all_action_mags else 0.0
p10 = float(np.percentile(all_action_mags, 10)) if all_action_mags else 0.0
p90 = float(np.percentile(all_action_mags, 90)) if all_action_mags else 0.0
print(f" action_mag mean={mean_act:.3f} p10={p10:.3f} p90={p90:.3f} "
f"(0=stopped, 1=full speed)")
print(f" min_flock_radius mean={np.mean(ep_min_radii):.2f}m "
f"best={np.min(ep_min_radii):.2f}m (target <5m to compact)")
print(f" min_dog_to_com mean={np.mean(ep_min_dog_com):.2f}m "
f"best={np.min(ep_min_dog_com):.2f}m (FLEE_DIST=7m)")
print(f" min_com_to_pen mean={np.mean(ep_min_pen_dists):.2f}m "
f"best={np.min(ep_min_pen_dists):.2f}m")
if rcomp_n > 0:
print(f" reward/step (mean): " + " ".join(
f"{k}={rcomp_sums[k]/rcomp_n:+.4f}" for k in
("progress","alignment","pen_bonus","step_cost","complete")
))
# Stall detection: same dominant failure at same n_sheep 5 checks in a row,
# and only after 3M total steps (give early stages time to warm up).
# Stall detection — disabled when --no-stall-abort or when we've never
# seen any stage succeed (we want full visibility into what's happening).
key = (n_sheep, dominant)
if key == self._prev_dominant and dominant != "SUCCESS":
self._stall_count += 1
if self._stall_count >= 5 and self.num_timesteps >= 3_000_000:
if (self.abort_on_stall and self._stall_count >= 5
and self.num_timesteps >= 3_000_000):
print(f"\n[Diag] STALL DETECTED — '{dominant}' on {n_sheep} sheep "
f"for {self._stall_count} consecutive checks. "
f"Aborting training early.")
@@ -250,6 +296,9 @@ def parse_args():
p.add_argument("--eval-eps", type=int, default=20)
p.add_argument("--diag-freq", type=int, default=500_000,
help="Run failure-mode diagnostics every N env steps")
p.add_argument("--no-stall-abort", action="store_true",
help="Disable early-abort on stall — run full --total-steps "
"for diagnostics")
p.add_argument("--mixed", action="store_true",
help="Randomise n_sheep each episode (consolidation pass, "
"use with --resume after curriculum training)")
@@ -306,6 +355,7 @@ def main():
diag_freq=args.diag_freq,
n_episodes=20,
max_steps=args.max_steps,
abort_on_stall=not args.no_stall_abort,
)
callbacks = [checkpoint_cb, eval_cb, diag_cb]