Sheep training flock _ improver
This commit is contained in:
+100
-20
@@ -54,32 +54,58 @@ def classify_failure(ep_radius, ep_com_dist, n_penned, n_sheep, success):
|
|||||||
|
|
||||||
|
|
||||||
def run_episodes(model, eval_env, n_episodes=30, max_steps=2000, render=False):
|
def run_episodes(model, eval_env, n_episodes=30, max_steps=2000, render=False):
|
||||||
"""Run N deterministic episodes; return failure mode counts and success rate."""
|
"""
|
||||||
|
Run N deterministic episodes.
|
||||||
|
Returns (success_rate, failure_counts, diagnostics_dict).
|
||||||
|
diagnostics_dict contains per-episode and aggregate stats useful for
|
||||||
|
understanding WHY the policy is failing without assuming the cause.
|
||||||
|
"""
|
||||||
failure_counts = {}
|
failure_counts = {}
|
||||||
successes = 0
|
successes = 0
|
||||||
|
|
||||||
|
all_action_mags = [] # action magnitude every step across all episodes
|
||||||
|
all_pen_progress = [] # per-episode: total pen-dist reduction (positive = good)
|
||||||
|
ep_steps_list = []
|
||||||
|
ep_min_pen_list = [] # min pen dist reached in each episode
|
||||||
|
|
||||||
for ep in range(n_episodes):
|
for ep in range(n_episodes):
|
||||||
obs = eval_env.reset()
|
obs = eval_env.reset()
|
||||||
done = False
|
done = False
|
||||||
ep_radius, ep_com_dist = [], []
|
ep_radius, ep_com_dist = [], []
|
||||||
|
ep_action_mags = []
|
||||||
n_penned = 0
|
n_penned = 0
|
||||||
n_sheep = 1
|
n_sheep = 1
|
||||||
|
prev_pen_dist = None
|
||||||
|
|
||||||
while not done:
|
while not done:
|
||||||
action, _ = model.predict(obs, deterministic=True)
|
action, _ = model.predict(obs, deterministic=True)
|
||||||
obs, _, dones, infos = eval_env.step(action)
|
obs, _, dones, infos = eval_env.step(action)
|
||||||
done = dones[0]
|
done = dones[0]
|
||||||
|
|
||||||
inner = eval_env.envs[0]
|
inner = eval_env.envs[0]
|
||||||
com, radius, _ = inner._flock_stats()
|
com, radius, _ = inner._flock_stats()
|
||||||
com_dist = float(np.linalg.norm(com - inner.PEN_CENTER))
|
com_dist = float(np.linalg.norm(com - inner.PEN_CENTER))
|
||||||
ep_radius.append(radius)
|
ep_radius.append(radius)
|
||||||
ep_com_dist.append(com_dist)
|
ep_com_dist.append(com_dist)
|
||||||
|
|
||||||
|
act_mag = float(np.linalg.norm(action[0]))
|
||||||
|
ep_action_mags.append(act_mag)
|
||||||
|
|
||||||
|
active = ~inner.penned[:inner.n_sheep]
|
||||||
|
if active.any():
|
||||||
|
pen_dist = float(np.linalg.norm(
|
||||||
|
inner.sheep_pos[:inner.n_sheep][active] - inner.PEN_CENTER, axis=1
|
||||||
|
).sum())
|
||||||
|
else:
|
||||||
|
pen_dist = 0.0
|
||||||
|
if prev_pen_dist is None:
|
||||||
|
prev_pen_dist = pen_dist
|
||||||
|
prev_pen_dist = pen_dist
|
||||||
|
|
||||||
if render and ep == 0:
|
if render and ep == 0:
|
||||||
inner.render()
|
inner.render()
|
||||||
|
|
||||||
info = infos[0]
|
info = infos[0]
|
||||||
n_penned = info.get("n_penned", 0)
|
n_penned = info.get("n_penned", 0)
|
||||||
n_sheep = info.get("n_sheep", 1)
|
n_sheep = info.get("n_sheep", 1)
|
||||||
success = n_penned == n_sheep
|
success = n_penned == n_sheep
|
||||||
@@ -87,8 +113,37 @@ def run_episodes(model, eval_env, n_episodes=30, max_steps=2000, render=False):
|
|||||||
mode = classify_failure(ep_radius, ep_com_dist, n_penned, n_sheep, success)
|
mode = classify_failure(ep_radius, ep_com_dist, n_penned, n_sheep, success)
|
||||||
failure_counts[mode] = failure_counts.get(mode, 0) + 1
|
failure_counts[mode] = failure_counts.get(mode, 0) + 1
|
||||||
|
|
||||||
|
all_action_mags.extend(ep_action_mags)
|
||||||
|
ep_steps_list.append(len(ep_action_mags))
|
||||||
|
ep_min_pen_list.append(min(ep_com_dist))
|
||||||
|
|
||||||
|
# Per-episode one-liner for real-time feedback
|
||||||
|
mean_act = float(np.mean(ep_action_mags))
|
||||||
|
min_pen = min(ep_com_dist)
|
||||||
|
print(f" ep {ep+1:>3} steps={len(ep_action_mags):>5} "
|
||||||
|
f"penned={n_penned}/{n_sheep} "
|
||||||
|
f"act={mean_act:.2f} "
|
||||||
|
f"min_pen={min_pen:.1f}m [{mode}]")
|
||||||
|
|
||||||
success_rate = successes / n_episodes
|
success_rate = successes / n_episodes
|
||||||
return success_rate, failure_counts
|
|
||||||
|
diag = {
|
||||||
|
"mean_action_mag" : float(np.mean(all_action_mags)),
|
||||||
|
"p10_action_mag" : float(np.percentile(all_action_mags, 10)),
|
||||||
|
"p90_action_mag" : float(np.percentile(all_action_mags, 90)),
|
||||||
|
"mean_min_pen_dist": float(np.mean(ep_min_pen_list)),
|
||||||
|
"best_min_pen_dist": float(np.min(ep_min_pen_list)),
|
||||||
|
"mean_ep_steps" : float(np.mean(ep_steps_list)),
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"\n Action magnitude mean={diag['mean_action_mag']:.3f} "
|
||||||
|
f"p10={diag['p10_action_mag']:.3f} p90={diag['p90_action_mag']:.3f}"
|
||||||
|
f" (0=stopped, 1=full speed)")
|
||||||
|
print(f" Pen distance mean_min={diag['mean_min_pen_dist']:.1f}m "
|
||||||
|
f"best_min={diag['best_min_pen_dist']:.1f}m "
|
||||||
|
f"(how close sheep got to pen center)")
|
||||||
|
|
||||||
|
return success_rate, failure_counts, diag
|
||||||
|
|
||||||
|
|
||||||
def train_stage(n_sheep, steps, n_envs=4, prev_model=None, prev_vecnorm=None):
|
def train_stage(n_sheep, steps, n_envs=4, prev_model=None, prev_vecnorm=None):
|
||||||
@@ -110,7 +165,7 @@ def train_stage(n_sheep, steps, n_envs=4, prev_model=None, prev_vecnorm=None):
|
|||||||
model = PPO(
|
model = PPO(
|
||||||
"MlpPolicy", vn,
|
"MlpPolicy", vn,
|
||||||
learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10,
|
learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10,
|
||||||
gamma=0.995, gae_lambda=0.95, clip_range=0.2, ent_coef=0.005,
|
gamma=0.995, gae_lambda=0.95, clip_range=0.2, ent_coef=0.02,
|
||||||
vf_coef=0.5, max_grad_norm=0.5,
|
vf_coef=0.5, max_grad_norm=0.5,
|
||||||
policy_kwargs=dict(net_arch=[256, 256]),
|
policy_kwargs=dict(net_arch=[256, 256]),
|
||||||
verbose=1,
|
verbose=1,
|
||||||
@@ -242,12 +297,12 @@ def main():
|
|||||||
args = p.parse_args()
|
args = p.parse_args()
|
||||||
|
|
||||||
# 1 sheep (500k): hard check — obs/reward structurally correct?
|
# 1 sheep (500k): hard check — obs/reward structurally correct?
|
||||||
# 2 sheep (1M): soft check — proves multi-sheep learning has started
|
# Thresholds are MINIMUM bars — smoke test always runs ALL stages even on failure.
|
||||||
# 3 sheep (1.5M): directional check — not expected to fully converge here
|
# The per-episode diagnostics tell you WHY a stage failed.
|
||||||
stages = [(1, args.steps, 0.60), (2, args.steps * 2, 0.20), (3, args.steps * 3, 0.10)]
|
stages = [(1, args.steps, 0.10), (2, args.steps * 2, 0.20), (3, args.steps * 3, 0.10)]
|
||||||
|
|
||||||
model, vn = None, None
|
model, vn = None, None
|
||||||
all_passed = True
|
stage_results = []
|
||||||
|
|
||||||
for n_sheep, steps, threshold in stages:
|
for n_sheep, steps, threshold in stages:
|
||||||
print(f"\n{'#'*52}")
|
print(f"\n{'#'*52}")
|
||||||
@@ -257,7 +312,7 @@ def main():
|
|||||||
model, vn = train_stage(n_sheep, steps, args.n_envs, model, vn)
|
model, vn = train_stage(n_sheep, steps, args.n_envs, model, vn)
|
||||||
|
|
||||||
eval_env = make_eval_env(model, vn, n_sheep)
|
eval_env = make_eval_env(model, vn, n_sheep)
|
||||||
success_rate, failure_counts = run_episodes(
|
success_rate, failure_counts, diag = run_episodes(
|
||||||
model, eval_env, args.episodes, render=args.render
|
model, eval_env, args.episodes, render=args.render
|
||||||
)
|
)
|
||||||
eval_env.close()
|
eval_env.close()
|
||||||
@@ -270,19 +325,44 @@ def main():
|
|||||||
_save_smoke_vis(model, vn, n_sheep, save_dir)
|
_save_smoke_vis(model, vn, n_sheep, save_dir)
|
||||||
|
|
||||||
passed = report(n_sheep, success_rate, failure_counts, args.episodes, threshold)
|
passed = report(n_sheep, success_rate, failure_counts, args.episodes, threshold)
|
||||||
|
stage_results.append((n_sheep, success_rate, passed, diag))
|
||||||
|
|
||||||
|
if not passed:
|
||||||
|
print(f" ⚠ Stage {n_sheep} BELOW threshold — continuing to next stage.")
|
||||||
|
print(f" mean_action={diag['mean_action_mag']:.3f} "
|
||||||
|
f"best_pen_approach={diag['best_min_pen_dist']:.1f}m")
|
||||||
|
if diag['mean_action_mag'] < 0.05:
|
||||||
|
print(" !! Dog is NOT moving (sit-still). "
|
||||||
|
"Check ent_coef / step_cost / alignment.")
|
||||||
|
elif diag['best_min_pen_dist'] > 5.0:
|
||||||
|
print(" !! Dog never gets sheep near pen. "
|
||||||
|
"Check reward direction / initialization.")
|
||||||
|
else:
|
||||||
|
print(" !! Dog moves and approaches pen but low success rate. "
|
||||||
|
"Likely needs more training time.")
|
||||||
|
|
||||||
|
print(f"\n{'='*52}")
|
||||||
|
print(" SMOKE TEST SUMMARY")
|
||||||
|
print(f"{'='*52}")
|
||||||
|
all_passed = True
|
||||||
|
for n_sheep, sr, passed, diag in stage_results:
|
||||||
|
status = "PASS" if passed else "FAIL"
|
||||||
|
print(f" n_sheep={n_sheep} success={sr*100:.0f}% "
|
||||||
|
f"act={diag['mean_action_mag']:.2f} "
|
||||||
|
f"best_pen={diag['best_min_pen_dist']:.1f}m [{status}]")
|
||||||
if not passed:
|
if not passed:
|
||||||
all_passed = False
|
all_passed = False
|
||||||
print(" Aborting smoke test — fix the issue above before full training.")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if all_passed:
|
if all_passed:
|
||||||
print("\n All smoke-test stages passed.")
|
print("\n All stages passed. Ready for full curriculum training:")
|
||||||
print(" Ready for full curriculum training:")
|
print(" python train.py --curriculum --steps-per-stage 1500000 "
|
||||||
print()
|
"--total-steps 15000000 --n-sheep 1 --max-sheep 10 "
|
||||||
print(" python train.py --curriculum --steps-per-stage 1500000 \\")
|
"--n-envs 8 --run-dir runs/ppo_v3")
|
||||||
print(" --total-steps 15000000 --n-sheep 1 --max-sheep 10 \\")
|
else:
|
||||||
print(" --n-envs 8 --run-dir runs/ppo_v2")
|
print("\n Some stages below threshold — check diagnostics above.")
|
||||||
print()
|
print(" Key signals: act<0.05=sit-still, best_pen>5=wrong direction, "
|
||||||
|
"else needs more training time.")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
+1
-1
@@ -332,7 +332,7 @@ def main():
|
|||||||
gamma = 0.995,
|
gamma = 0.995,
|
||||||
gae_lambda = 0.95,
|
gae_lambda = 0.95,
|
||||||
clip_range = 0.2,
|
clip_range = 0.2,
|
||||||
ent_coef = 0.005,
|
ent_coef = 0.01,
|
||||||
vf_coef = 0.5,
|
vf_coef = 0.5,
|
||||||
max_grad_norm = 0.5,
|
max_grad_norm = 0.5,
|
||||||
policy_kwargs = dict(net_arch=[256, 256]),
|
policy_kwargs = dict(net_arch=[256, 256]),
|
||||||
|
|||||||
Reference in New Issue
Block a user