Sheep training flock of 10 fix?
This commit is contained in:
+115
-2
@@ -19,6 +19,7 @@ Usage examples
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import PPO
|
||||
@@ -28,10 +29,25 @@ from stable_baselines3.common.callbacks import (
|
||||
CheckpointCallback,
|
||||
EvalCallback,
|
||||
)
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv, VecNormalize
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecNormalize
|
||||
|
||||
from herding_env import HerdingEnv
|
||||
|
||||
# Flock-radius threshold used by _classify: a sampled radius at or below this
# value counts as "compact". Taken from HerdingEnv.DRIVE_GATE_RADIUS.
COMPACT_RADIUS = HerdingEnv.DRIVE_GATE_RADIUS
|
||||
|
||||
|
||||
def _classify(ep_radius, ep_com_dist, n_penned, n_sheep, success,
              drive_dist=3.0):
    """Label one evaluation episode with its outcome / failure mode.

    Args:
        ep_radius: per-sample flock radius recorded during the episode.
        ep_com_dist: per-sample distance recorded alongside ``ep_radius``
            (same indexing); compared against ``drive_dist``.
        n_penned: number of sheep reported as penned at episode end.
        n_sheep: total number of sheep in the episode.
        success: whether the episode counts as a success.
        drive_dist: distance threshold for the "drive" check; defaults to
            the previously hard-coded 3.0.

    Returns:
        One of ``"SUCCESS"``, ``"NEVER_COMPACT"``, ``"COMPACT_CANT_DRIVE"``,
        ``"DROVE_NO_SHEEP"`` or ``"PARTIAL_<n_penned>of<n_sheep>"``.
    """
    if success:
        return "SUCCESS"

    # Index of the first sample at which the flock was compact. A single scan
    # with a None default also covers an empty trajectory, which previously
    # raised ValueError from min() on an empty sequence.
    first = next(
        (i for i, r in enumerate(ep_radius) if r <= COMPACT_RADIUS),
        None,
    )
    if first is None:
        return "NEVER_COMPACT"

    # Closest recorded distance from the first compact sample onwards; an
    # empty slice is treated as "never got close".
    closest = min(ep_com_dist[first:], default=float("inf"))
    if closest > drive_dist:
        return "COMPACT_CANT_DRIVE"

    if n_penned == 0:
        return "DROVE_NO_SHEEP"
    return f"PARTIAL_{n_penned}of{n_sheep}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Curriculum callback
|
||||
@@ -101,6 +117,96 @@ class CurriculumCallback(BaseCallback):
|
||||
return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Diagnostic callback — failure-mode breakdown every diag_freq steps
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class DiagnosticCallback(BaseCallback):
    """
    Periodic failure-mode diagnostics with early abort on stalled training.

    Whenever at least ``diag_freq`` timesteps (as counted by
    ``self.num_timesteps``) have elapsed since the previous check, a temporary
    single-env evaluation environment is built, ``n_episodes`` deterministic
    episodes are run, and a per-mode breakdown is printed.

    Stall detection: ``_stall_count`` is incremented each time a check yields
    the same non-SUCCESS ``(n_sheep, dominant mode)`` pair as the previous
    check and is reset otherwise. Training is aborted (``_on_step`` returns
    False) once the counter reaches 2, i.e. after three identical checks in a
    row.

    NOTE(review): in Stable-Baselines3 ``num_timesteps`` is summed over all
    parallel envs, so ``diag_freq`` should be a total-step interval — confirm
    that call sites do not pre-divide it by the number of envs.
    """

    def __init__(self, diag_freq: int = 500_000, n_episodes: int = 20,
                 max_steps: int = 2000, verbose: int = 1):
        """
        Args:
            diag_freq: minimum number of timesteps between two checks.
            n_episodes: evaluation episodes per check (must be >= 1).
            max_steps: ``max_steps`` passed to the temporary HerdingEnv.
            verbose: non-zero prints the per-mode breakdown after each check.

        Raises:
            ValueError: if ``n_episodes`` is smaller than 1.
        """
        super().__init__(verbose)
        if n_episodes < 1:
            # Fail fast: a zero-episode check would otherwise only blow up at
            # the first diagnostic (division by zero / max() of empty dict).
            raise ValueError("n_episodes must be at least 1")
        self.diag_freq = diag_freq
        self.n_episodes = n_episodes
        self.max_steps = max_steps
        self._last_diag = 0
        self._prev_dominant = None  # (n_sheep, mode) from last check
        self._stall_count = 0

    def _run_episode(self, vn, n_sheep):
        """Play one deterministic episode on ``vn`` and return its mode label."""
        obs = vn.reset()
        inner = vn.envs[0]
        ep_radius, ep_com_dist = [], []
        infos = [{}]
        done = False

        while not done:
            # Sample the flock *before* stepping. DummyVecEnv resets the inner
            # env inside step() when an episode ends, so reading the stats
            # after the final step would record the next episode's initial
            # flock instead of a state from this episode.
            com, radius, _ = inner._flock_stats()
            ep_radius.append(radius)
            ep_com_dist.append(
                float(np.linalg.norm(com - inner.PEN_CENTER))
            )

            action, _ = self.model.predict(obs, deterministic=True)
            obs, _, dones, infos = vn.step(action)
            done = dones[0]

        # The info dict of the terminating step carries the final pen count.
        n_penned = infos[0].get("n_penned", 0)
        success = n_penned == n_sheep
        return _classify(ep_radius, ep_com_dist, n_penned, n_sheep, success)

    def _evaluate(self, n_sheep):
        """Run ``n_episodes`` episodes on a temporary env; return {mode: count}."""
        # Build a temporary single-env with copied VecNorm stats
        raw = DummyVecEnv([lambda: HerdingEnv(n_sheep=n_sheep,
                                              max_steps=self.max_steps)])
        vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)

        counts = {}
        try:
            vn.obs_rms = deepcopy(self.training_env.obs_rms)
            vn.ret_rms = deepcopy(self.training_env.ret_rms)
            for _ in range(self.n_episodes):
                mode = self._run_episode(vn, n_sheep)
                counts[mode] = counts.get(mode, 0) + 1
        finally:
            # Always release the temporary env, even if evaluation raised.
            vn.close()
        return counts

    def _on_step(self) -> bool:
        """Run a diagnostic check when due; return False to abort training."""
        if self.num_timesteps - self._last_diag < self.diag_freq:
            return True
        self._last_diag = self.num_timesteps

        n_sheep = self.training_env.get_attr("n_sheep")[0]
        failure_counts = self._evaluate(n_sheep)

        # _classify labels an episode "SUCCESS" exactly when it succeeded.
        successes = failure_counts.get("SUCCESS", 0)
        success_rate = successes / self.n_episodes
        dominant = max(failure_counts, key=failure_counts.get)

        if self.verbose:
            print(f"\n[Diag @ {self.num_timesteps:,} | n_sheep={n_sheep} | "
                  f"success={success_rate*100:.0f}%]")
            for m, c in sorted(failure_counts.items(), key=lambda x: -x[1]):
                print(f" {m:<26} {c}/{self.n_episodes}")

        # Stall detection: same dominant failure at same n_sheep as last check
        key = (n_sheep, dominant)
        if key == self._prev_dominant and dominant != "SUCCESS":
            self._stall_count += 1
            if self._stall_count >= 2:
                print(f"\n[Diag] STALL DETECTED — '{dominant}' on {n_sheep} sheep "
                      f"for {self._stall_count} consecutive checks. "
                      f"Aborting training early.")
                return False
        else:
            self._stall_count = 0
        self._prev_dominant = key

        return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Environment factory
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -141,6 +247,8 @@ def parse_args():
|
||||
p.add_argument("--save-freq", type=int, default=100_000)
|
||||
p.add_argument("--eval-freq", type=int, default=50_000)
|
||||
p.add_argument("--eval-eps", type=int, default=20)
|
||||
p.add_argument("--diag-freq", type=int, default=500_000,
|
||||
help="Run failure-mode diagnostics every N env steps")
|
||||
p.add_argument("--mixed", action="store_true",
|
||||
help="Randomise n_sheep each episode (consolidation pass, "
|
||||
"use with --resume after curriculum training)")
|
||||
@@ -193,7 +301,12 @@ def main():
|
||||
deterministic=True,
|
||||
verbose=1,
|
||||
)
|
||||
callbacks = [checkpoint_cb, eval_cb]
|
||||
diag_cb = DiagnosticCallback(
|
||||
diag_freq=max(args.diag_freq // args.n_envs, 1),
|
||||
n_episodes=20,
|
||||
max_steps=args.max_steps,
|
||||
)
|
||||
callbacks = [checkpoint_cb, eval_cb, diag_cb]
|
||||
|
||||
if args.curriculum:
|
||||
cur_cb = CurriculumCallback(
|
||||
|
||||
Reference in New Issue
Block a user