Sheep training flock of 10 fix?
This commit is contained in:
@@ -84,23 +84,25 @@ def build_obs(dog_pos: np.ndarray,
|
|||||||
d_from_com = np.linalg.norm(active_pos - com, axis=1)
|
d_from_com = np.linalg.norm(active_pos - com, axis=1)
|
||||||
sorted_idx = np.argsort(d_from_com)[::-1]
|
sorted_idx = np.argsort(d_from_com)[::-1]
|
||||||
radius = float(d_from_com[sorted_idx[0]])
|
radius = float(d_from_com[sorted_idx[0]])
|
||||||
far = active_pos[sorted_idx[0]]
|
def nth(n):
|
||||||
second_far_dist = float(d_from_com[sorted_idx[1]]) if len(sorted_idx) > 1 else 0.0
|
return active_pos[sorted_idx[n]] if len(sorted_idx) > n else com
|
||||||
|
far1, far2, far3 = nth(0), nth(1), nth(2)
|
||||||
else:
|
else:
|
||||||
com = PEN_CENTER.copy()
|
com = PEN_CENTER.copy()
|
||||||
radius = second_far_dist = 0.0
|
radius = 0.0
|
||||||
far = PEN_CENTER.copy()
|
far1 = far2 = far3 = PEN_CENTER.copy()
|
||||||
|
|
||||||
frac_active = n_active / max(n_sheep, 1)
|
frac_active = n_active / max(n_sheep, 1)
|
||||||
|
|
||||||
return np.array([
|
return np.array([
|
||||||
dog_pos[0] / FIELD, dog_pos[1] / FIELD,
|
dog_pos[0] / FIELD, dog_pos[1] / FIELD,
|
||||||
(com[0] - dog_pos[0]) / D, (com[1] - dog_pos[1]) / D,
|
(com[0] - dog_pos[0]) / D, (com[1] - dog_pos[1]) / D,
|
||||||
(far[0] - dog_pos[0]) / D, (far[1] - dog_pos[1]) / D,
|
(far1[0] - dog_pos[0]) / D, (far1[1] - dog_pos[1]) / D,
|
||||||
(PEN_CENTER[0] - com[0]) / D, (PEN_CENTER[1] - com[1]) / D,
|
(far2[0] - dog_pos[0]) / D, (far2[1] - dog_pos[1]) / D,
|
||||||
(PEN_CENTER[0] - far[0]) / D, (PEN_CENTER[1] - far[1]) / D,
|
(far3[0] - dog_pos[0]) / D, (far3[1] - dog_pos[1]) / D,
|
||||||
radius / D,
|
(PEN_CENTER[0] - com[0]) / D, (PEN_CENTER[1] - com[1]) / D,
|
||||||
second_far_dist / D,
|
(PEN_CENTER[0] - far1[0]) / D, (PEN_CENTER[1] - far1[1]) / D,
|
||||||
|
radius / D,
|
||||||
frac_active,
|
frac_active,
|
||||||
], dtype=np.float32)
|
], dtype=np.float32)
|
||||||
|
|
||||||
|
|||||||
+17
-21
@@ -56,7 +56,7 @@ class HerdingEnv(gym.Env):
|
|||||||
W_DRIVE = 2.0 # progress: COM moved toward pen (only when compact)
|
W_DRIVE = 2.0 # progress: COM moved toward pen (only when compact)
|
||||||
W_COLLECT = 4.0 # progress: radius shrank (2× stronger when scattered)
|
W_COLLECT = 4.0 # progress: radius shrank (2× stronger when scattered)
|
||||||
W_ALIGN = 0.5 # position: dog on anti-pen side of COM
|
W_ALIGN = 0.5 # position: dog on anti-pen side of COM
|
||||||
W_COMPACT_BONUS = 0.1 # per-step bonus for staying compact (sustained signal)
|
W_COMPACT_BONUS = 0.0 # disabled: 0.1/step over 4000 steps = 400 >> W_COMPLETE=100
|
||||||
W_PEN_BONUS = 10.0 # per sheep penned
|
W_PEN_BONUS = 10.0 # per sheep penned
|
||||||
W_COMPLETE = 100.0 # all sheep penned
|
W_COMPLETE = 100.0 # all sheep penned
|
||||||
W_STEP_COST = 0.002 # time penalty
|
W_STEP_COST = 0.002 # time penalty
|
||||||
@@ -72,11 +72,11 @@ class HerdingEnv(gym.Env):
|
|||||||
self.render_mode = render_mode
|
self.render_mode = render_mode
|
||||||
self.random_n_sheep = random_n_sheep # if True, randomise n_sheep each reset
|
self.random_n_sheep = random_n_sheep # if True, randomise n_sheep each reset
|
||||||
|
|
||||||
# Fixed 13-dim observation regardless of n_sheep:
|
# Fixed 17-dim observation regardless of n_sheep:
|
||||||
# dog_pos(2) + rel_com(2) + rel_far(2) + com_to_pen(2)
|
# dog_pos(2) + rel_com(2) + rel_far1(2) + rel_far2(2) + rel_far3(2)
|
||||||
# + far_to_pen(2) + radius(1) + second_far_dist(1) + frac_penned(1)
|
# + com_to_pen(2) + far1_to_pen(2) + radius(1) + frac_penned(1)
|
||||||
self.observation_space = spaces.Box(
|
self.observation_space = spaces.Box(
|
||||||
low=-np.inf, high=np.inf, shape=(13,), dtype=np.float32
|
low=-np.inf, high=np.inf, shape=(17,), dtype=np.float32
|
||||||
)
|
)
|
||||||
|
|
||||||
# Action: desired velocity (vx, vy) ∈ [-1, 1]², scaled by DOG_SPEED
|
# Action: desired velocity (vx, vy) ∈ [-1, 1]², scaled by DOG_SPEED
|
||||||
@@ -269,29 +269,25 @@ class HerdingEnv(gym.Env):
|
|||||||
pts = self.sheep_pos[:self.n_sheep][active_mask]
|
pts = self.sheep_pos[:self.n_sheep][active_mask]
|
||||||
dists = np.linalg.norm(pts - com, axis=1)
|
dists = np.linalg.norm(pts - com, axis=1)
|
||||||
sorted_idx = np.argsort(dists)[::-1] # farthest first
|
sorted_idx = np.argsort(dists)[::-1] # farthest first
|
||||||
far = pts[sorted_idx[0]]
|
# Top-3 stragglers; pad with COM when fewer active sheep exist
|
||||||
# 2nd farthest — if only 1 active sheep, reuse the same position
|
def nth(n):
|
||||||
far2 = pts[sorted_idx[1]] if len(sorted_idx) > 1 else far
|
return pts[sorted_idx[n]] if len(sorted_idx) > n else com
|
||||||
second_far_dist = float(dists[sorted_idx[1]]) if len(sorted_idx) > 1 else 0.0
|
far1, far2, far3 = nth(0), nth(1), nth(2)
|
||||||
else:
|
else:
|
||||||
far = far2 = self.PEN_CENTER.copy()
|
far1 = far2 = far3 = self.PEN_CENTER.copy()
|
||||||
second_far_dist = 0.0
|
|
||||||
|
|
||||||
S = self.FIELD
|
S = self.FIELD
|
||||||
D = 2 * self.FIELD
|
D = 2 * self.FIELD
|
||||||
|
|
||||||
return np.array([
|
return np.array([
|
||||||
self.dog_pos[0] / S, self.dog_pos[1] / S,
|
self.dog_pos[0] / S, self.dog_pos[1] / S,
|
||||||
(com[0] - self.dog_pos[0]) / D,
|
(com[0] - self.dog_pos[0]) / D, (com[1] - self.dog_pos[1]) / D,
|
||||||
(com[1] - self.dog_pos[1]) / D,
|
(far1[0] - self.dog_pos[0]) / D, (far1[1] - self.dog_pos[1]) / D,
|
||||||
(far[0] - self.dog_pos[0]) / D,
|
(far2[0] - self.dog_pos[0]) / D, (far2[1] - self.dog_pos[1]) / D,
|
||||||
(far[1] - self.dog_pos[1]) / D,
|
(far3[0] - self.dog_pos[0]) / D, (far3[1] - self.dog_pos[1]) / D,
|
||||||
(self.PEN_CENTER[0] - com[0]) / D,
|
(self.PEN_CENTER[0] - com[0]) / D, (self.PEN_CENTER[1] - com[1]) / D,
|
||||||
(self.PEN_CENTER[1] - com[1]) / D,
|
(self.PEN_CENTER[0] - far1[0]) / D, (self.PEN_CENTER[1] - far1[1]) / D,
|
||||||
(self.PEN_CENTER[0] - far[0]) / D,
|
radius / D,
|
||||||
(self.PEN_CENTER[1] - far[1]) / D,
|
|
||||||
radius / D,
|
|
||||||
second_far_dist / D, # replaced mean_disp: 2nd farthest sheep from COM
|
|
||||||
active_mask.sum() / self.n_sheep,
|
active_mask.sum() / self.n_sheep,
|
||||||
], dtype=np.float32)
|
], dtype=np.float32)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,198 @@
|
|||||||
|
"""
|
||||||
|
Quick sanity check before committing to a full 15M-step training run.
|
||||||
|
|
||||||
|
Trains 1 sheep for 500k steps (~5 min), then 3 sheep for 500k steps.
|
||||||
|
If both pass, the obs/reward setup is sound and full training is worth running.
|
||||||
|
If either fails, abort and fix before wasting 15M steps.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python smoke_test.py # fresh run
|
||||||
|
python smoke_test.py --render # watch episodes after each stage
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
from stable_baselines3 import PPO
|
||||||
|
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecNormalize
|
||||||
|
|
||||||
|
from herding_env import HerdingEnv
|
||||||
|
|
||||||
|
|
||||||
|
COMPACT_RADIUS = 5.0
|
||||||
|
PASS_THRESHOLD = 0.60 # success rate required to pass each stage
|
||||||
|
|
||||||
|
|
||||||
|
def make_env(n_sheep, seed, max_steps=2000):
|
||||||
|
def _init():
|
||||||
|
env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps)
|
||||||
|
env.reset(seed=seed)
|
||||||
|
return env
|
||||||
|
return _init
|
||||||
|
|
||||||
|
|
||||||
|
def classify_failure(ep_radius, ep_com_dist, n_penned, n_sheep, success):
|
||||||
|
if success:
|
||||||
|
return "SUCCESS"
|
||||||
|
if min(ep_radius) > COMPACT_RADIUS:
|
||||||
|
return "NEVER_COMPACT"
|
||||||
|
first_compact = next(i for i, r in enumerate(ep_radius) if r <= COMPACT_RADIUS)
|
||||||
|
if min(ep_com_dist[first_compact:]) > 3.0:
|
||||||
|
return "COMPACT_CANT_DRIVE"
|
||||||
|
if n_penned == 0:
|
||||||
|
return "DROVE_NO_SHEEP"
|
||||||
|
return f"PARTIAL_{n_penned}of{n_sheep}"
|
||||||
|
|
||||||
|
|
||||||
|
def run_episodes(model, eval_env, n_episodes=30, max_steps=2000, render=False):
|
||||||
|
"""Run N deterministic episodes; return failure mode counts and success rate."""
|
||||||
|
failure_counts = {}
|
||||||
|
successes = 0
|
||||||
|
|
||||||
|
for ep in range(n_episodes):
|
||||||
|
obs = eval_env.reset()
|
||||||
|
done = False
|
||||||
|
ep_radius, ep_com_dist = [], []
|
||||||
|
n_penned = 0
|
||||||
|
n_sheep = 1
|
||||||
|
|
||||||
|
while not done:
|
||||||
|
action, _ = model.predict(obs, deterministic=True)
|
||||||
|
obs, _, dones, infos = eval_env.step(action)
|
||||||
|
done = dones[0]
|
||||||
|
|
||||||
|
inner = eval_env.envs[0]
|
||||||
|
com, radius, _ = inner._flock_stats()
|
||||||
|
com_dist = float(np.linalg.norm(com - inner.PEN_CENTER))
|
||||||
|
ep_radius.append(radius)
|
||||||
|
ep_com_dist.append(com_dist)
|
||||||
|
|
||||||
|
if render and ep == 0:
|
||||||
|
inner.render()
|
||||||
|
|
||||||
|
info = infos[0]
|
||||||
|
n_penned = info.get("n_penned", 0)
|
||||||
|
n_sheep = info.get("n_sheep", 1)
|
||||||
|
success = n_penned == n_sheep
|
||||||
|
successes += int(success)
|
||||||
|
mode = classify_failure(ep_radius, ep_com_dist, n_penned, n_sheep, success)
|
||||||
|
failure_counts[mode] = failure_counts.get(mode, 0) + 1
|
||||||
|
|
||||||
|
success_rate = successes / n_episodes
|
||||||
|
return success_rate, failure_counts
|
||||||
|
|
||||||
|
|
||||||
|
def train_stage(n_sheep, steps, n_envs=4, prev_model=None, prev_vecnorm=None):
|
||||||
|
"""Train one stage; return (model, vecnorm)."""
|
||||||
|
train_env = SubprocVecEnv([make_env(n_sheep, i) for i in range(n_envs)])
|
||||||
|
|
||||||
|
if prev_vecnorm is not None:
|
||||||
|
vn = deepcopy(prev_vecnorm)
|
||||||
|
vn.set_venv(train_env)
|
||||||
|
vn.training = True
|
||||||
|
vn.norm_reward = True
|
||||||
|
else:
|
||||||
|
vn = VecNormalize(train_env, norm_obs=True, norm_reward=True, clip_obs=10.0)
|
||||||
|
|
||||||
|
if prev_model is not None:
|
||||||
|
model = PPO.load(prev_model, env=vn,
|
||||||
|
learning_rate=3e-4, n_steps=2048, batch_size=256,
|
||||||
|
n_epochs=10, gamma=0.995, gae_lambda=0.95,
|
||||||
|
clip_range=0.2, ent_coef=0.005, vf_coef=0.5,
|
||||||
|
max_grad_norm=0.5)
|
||||||
|
else:
|
||||||
|
model = PPO(
|
||||||
|
"MlpPolicy", vn,
|
||||||
|
learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10,
|
||||||
|
gamma=0.995, gae_lambda=0.95, clip_range=0.2, ent_coef=0.005,
|
||||||
|
vf_coef=0.5, max_grad_norm=0.5,
|
||||||
|
policy_kwargs=dict(net_arch=[256, 256]),
|
||||||
|
verbose=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
model.learn(total_timesteps=steps, reset_num_timesteps=(prev_model is None))
|
||||||
|
return model, vn
|
||||||
|
|
||||||
|
|
||||||
|
def make_eval_env(model, vecnorm, n_sheep, max_steps=2000):
|
||||||
|
raw = DummyVecEnv([make_env(n_sheep, seed=9999, max_steps=max_steps)])
|
||||||
|
vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
|
||||||
|
vn.obs_rms = deepcopy(vecnorm.obs_rms)
|
||||||
|
vn.ret_rms = deepcopy(vecnorm.ret_rms)
|
||||||
|
return vn
|
||||||
|
|
||||||
|
|
||||||
|
def report(n_sheep, success_rate, failure_counts, n_episodes):
|
||||||
|
print(f"\n{'='*52}")
|
||||||
|
print(f" Stage n_sheep={n_sheep} | success={success_rate*100:.0f}% ({int(success_rate*n_episodes)}/{n_episodes})")
|
||||||
|
print(f" {'─'*48}")
|
||||||
|
for mode, cnt in sorted(failure_counts.items(), key=lambda x: -x[1]):
|
||||||
|
bar = "█" * cnt
|
||||||
|
print(f" {mode:<26} {cnt:>3}/{n_episodes} {bar}")
|
||||||
|
print(f"{'='*52}")
|
||||||
|
|
||||||
|
passed = success_rate >= PASS_THRESHOLD
|
||||||
|
if passed:
|
||||||
|
print(f" ✓ PASS (threshold {PASS_THRESHOLD*100:.0f}%)")
|
||||||
|
else:
|
||||||
|
dominant = max(failure_counts, key=failure_counts.get)
|
||||||
|
print(f" ✗ FAIL — dominant: {dominant}")
|
||||||
|
if dominant == "NEVER_COMPACT":
|
||||||
|
print(" Dog can't compact flock. Check W_COLLECT, obs contains straggler positions?")
|
||||||
|
elif dominant == "COMPACT_CANT_DRIVE":
|
||||||
|
print(" Flock compacts but dog doesn't drive to pen. Check alignment reward / W_DRIVE.")
|
||||||
|
elif dominant.startswith("PARTIAL"):
|
||||||
|
print(" Flock splits near pen. Dog loses stragglers at the end.")
|
||||||
|
print()
|
||||||
|
return passed
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
p = argparse.ArgumentParser()
|
||||||
|
p.add_argument("--steps", type=int, default=500_000,
|
||||||
|
help="Steps per smoke-test stage (default 500k)")
|
||||||
|
p.add_argument("--n-envs", type=int, default=4)
|
||||||
|
p.add_argument("--episodes", type=int, default=30,
|
||||||
|
help="Validation episodes per stage")
|
||||||
|
p.add_argument("--render", action="store_true")
|
||||||
|
args = p.parse_args()
|
||||||
|
|
||||||
|
stages = [(1, args.steps), (3, args.steps)]
|
||||||
|
|
||||||
|
model, vn = None, None
|
||||||
|
all_passed = True
|
||||||
|
|
||||||
|
for n_sheep, steps in stages:
|
||||||
|
print(f"\n{'#'*52}")
|
||||||
|
print(f"# Smoke-test stage: n_sheep={n_sheep}, {steps:,} steps")
|
||||||
|
print(f"{'#'*52}")
|
||||||
|
|
||||||
|
model, vn = train_stage(n_sheep, steps, args.n_envs, model, vn)
|
||||||
|
|
||||||
|
eval_env = make_eval_env(model, vn, n_sheep)
|
||||||
|
success_rate, failure_counts = run_episodes(
|
||||||
|
model, eval_env, args.episodes, render=args.render
|
||||||
|
)
|
||||||
|
eval_env.close()
|
||||||
|
|
||||||
|
passed = report(n_sheep, success_rate, failure_counts, args.episodes)
|
||||||
|
if not passed:
|
||||||
|
all_passed = False
|
||||||
|
print(" Aborting smoke test — fix the issue above before full training.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if all_passed:
|
||||||
|
print("\n All smoke-test stages passed.")
|
||||||
|
print(" Ready for full curriculum training:")
|
||||||
|
print()
|
||||||
|
print(" python train.py --curriculum --steps-per-stage 1500000 \\")
|
||||||
|
print(" --total-steps 15000000 --n-sheep 1 --max-sheep 10 \\")
|
||||||
|
print(" --n-envs 8 --run-dir runs/ppo_v2")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
+115
-2
@@ -19,6 +19,7 @@ Usage examples
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from stable_baselines3 import PPO
|
from stable_baselines3 import PPO
|
||||||
@@ -28,10 +29,25 @@ from stable_baselines3.common.callbacks import (
|
|||||||
CheckpointCallback,
|
CheckpointCallback,
|
||||||
EvalCallback,
|
EvalCallback,
|
||||||
)
|
)
|
||||||
from stable_baselines3.common.vec_env import SubprocVecEnv, VecNormalize
|
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecNormalize
|
||||||
|
|
||||||
from herding_env import HerdingEnv
|
from herding_env import HerdingEnv
|
||||||
|
|
||||||
|
COMPACT_RADIUS = HerdingEnv.DRIVE_GATE_RADIUS
|
||||||
|
|
||||||
|
|
||||||
|
def _classify(ep_radius, ep_com_dist, n_penned, n_sheep, success):
|
||||||
|
if success:
|
||||||
|
return "SUCCESS"
|
||||||
|
if min(ep_radius) > COMPACT_RADIUS:
|
||||||
|
return "NEVER_COMPACT"
|
||||||
|
first = next(i for i, r in enumerate(ep_radius) if r <= COMPACT_RADIUS)
|
||||||
|
if min(ep_com_dist[first:]) > 3.0:
|
||||||
|
return "COMPACT_CANT_DRIVE"
|
||||||
|
if n_penned == 0:
|
||||||
|
return "DROVE_NO_SHEEP"
|
||||||
|
return f"PARTIAL_{n_penned}of{n_sheep}"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Curriculum callback
|
# Curriculum callback
|
||||||
@@ -101,6 +117,96 @@ class CurriculumCallback(BaseCallback):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Diagnostic callback — failure-mode breakdown every diag_freq steps
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class DiagnosticCallback(BaseCallback):
|
||||||
|
"""
|
||||||
|
Every diag_freq env steps: spin up a temporary eval env, run n_episodes
|
||||||
|
deterministic episodes, and print a failure-mode breakdown.
|
||||||
|
Aborts training (returns False) if the dominant failure mode hasn't
|
||||||
|
changed after two consecutive checks at the same n_sheep — a sign that
|
||||||
|
training has stalled and further steps are wasted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, diag_freq: int = 500_000, n_episodes: int = 20,
|
||||||
|
max_steps: int = 2000, verbose: int = 1):
|
||||||
|
super().__init__(verbose)
|
||||||
|
self.diag_freq = diag_freq
|
||||||
|
self.n_episodes = n_episodes
|
||||||
|
self.max_steps = max_steps
|
||||||
|
self._last_diag = 0
|
||||||
|
self._prev_dominant = None # (n_sheep, mode) from last check
|
||||||
|
self._stall_count = 0
|
||||||
|
|
||||||
|
def _on_step(self) -> bool:
|
||||||
|
if self.num_timesteps - self._last_diag < self.diag_freq:
|
||||||
|
return True
|
||||||
|
self._last_diag = self.num_timesteps
|
||||||
|
|
||||||
|
n_sheep = self.training_env.get_attr("n_sheep")[0]
|
||||||
|
|
||||||
|
# Build a temporary single-env with copied VecNorm stats
|
||||||
|
raw = DummyVecEnv([lambda: HerdingEnv(n_sheep=n_sheep,
|
||||||
|
max_steps=self.max_steps)])
|
||||||
|
vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
|
||||||
|
vn.obs_rms = deepcopy(self.training_env.obs_rms)
|
||||||
|
vn.ret_rms = deepcopy(self.training_env.ret_rms)
|
||||||
|
|
||||||
|
failure_counts = {}
|
||||||
|
successes = 0
|
||||||
|
|
||||||
|
for _ in range(self.n_episodes):
|
||||||
|
obs = vn.reset()
|
||||||
|
done = False
|
||||||
|
ep_radius, ep_com_dist = [], []
|
||||||
|
n_penned = 0
|
||||||
|
|
||||||
|
while not done:
|
||||||
|
action, _ = self.model.predict(obs, deterministic=True)
|
||||||
|
obs, _, dones, infos = vn.step(action)
|
||||||
|
done = dones[0]
|
||||||
|
inner = vn.envs[0]
|
||||||
|
com, radius, _ = inner._flock_stats()
|
||||||
|
ep_radius.append(radius)
|
||||||
|
ep_com_dist.append(
|
||||||
|
float(np.linalg.norm(com - inner.PEN_CENTER))
|
||||||
|
)
|
||||||
|
|
||||||
|
n_penned = infos[0].get("n_penned", 0)
|
||||||
|
success = n_penned == n_sheep
|
||||||
|
successes += int(success)
|
||||||
|
mode = _classify(ep_radius, ep_com_dist, n_penned, n_sheep, success)
|
||||||
|
failure_counts[mode] = failure_counts.get(mode, 0) + 1
|
||||||
|
|
||||||
|
vn.close()
|
||||||
|
|
||||||
|
success_rate = successes / self.n_episodes
|
||||||
|
dominant = max(failure_counts, key=failure_counts.get)
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
print(f"\n[Diag @ {self.num_timesteps:,} | n_sheep={n_sheep} | "
|
||||||
|
f"success={success_rate*100:.0f}%]")
|
||||||
|
for m, c in sorted(failure_counts.items(), key=lambda x: -x[1]):
|
||||||
|
print(f" {m:<26} {c}/{self.n_episodes}")
|
||||||
|
|
||||||
|
# Stall detection: same dominant failure at same n_sheep twice in a row
|
||||||
|
key = (n_sheep, dominant)
|
||||||
|
if key == self._prev_dominant and dominant != "SUCCESS":
|
||||||
|
self._stall_count += 1
|
||||||
|
if self._stall_count >= 2:
|
||||||
|
print(f"\n[Diag] STALL DETECTED — '{dominant}' on {n_sheep} sheep "
|
||||||
|
f"for {self._stall_count} consecutive checks. "
|
||||||
|
f"Aborting training early.")
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
self._stall_count = 0
|
||||||
|
self._prev_dominant = key
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Environment factory
|
# Environment factory
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -141,6 +247,8 @@ def parse_args():
|
|||||||
p.add_argument("--save-freq", type=int, default=100_000)
|
p.add_argument("--save-freq", type=int, default=100_000)
|
||||||
p.add_argument("--eval-freq", type=int, default=50_000)
|
p.add_argument("--eval-freq", type=int, default=50_000)
|
||||||
p.add_argument("--eval-eps", type=int, default=20)
|
p.add_argument("--eval-eps", type=int, default=20)
|
||||||
|
p.add_argument("--diag-freq", type=int, default=500_000,
|
||||||
|
help="Run failure-mode diagnostics every N env steps")
|
||||||
p.add_argument("--mixed", action="store_true",
|
p.add_argument("--mixed", action="store_true",
|
||||||
help="Randomise n_sheep each episode (consolidation pass, "
|
help="Randomise n_sheep each episode (consolidation pass, "
|
||||||
"use with --resume after curriculum training)")
|
"use with --resume after curriculum training)")
|
||||||
@@ -193,7 +301,12 @@ def main():
|
|||||||
deterministic=True,
|
deterministic=True,
|
||||||
verbose=1,
|
verbose=1,
|
||||||
)
|
)
|
||||||
callbacks = [checkpoint_cb, eval_cb]
|
diag_cb = DiagnosticCallback(
|
||||||
|
diag_freq=max(args.diag_freq // args.n_envs, 1),
|
||||||
|
n_episodes=20,
|
||||||
|
max_steps=args.max_steps,
|
||||||
|
)
|
||||||
|
callbacks = [checkpoint_cb, eval_cb, diag_cb]
|
||||||
|
|
||||||
if args.curriculum:
|
if args.curriculum:
|
||||||
cur_cb = CurriculumCallback(
|
cur_cb = CurriculumCallback(
|
||||||
|
|||||||
Reference in New Issue
Block a user