Sheep training flock of 10 fix?

This commit is contained in:
Johnny Fernandes
2026-04-24 15:24:37 +01:00
parent 44b2788e78
commit 678d757fe8
+9 -6
View File
@@ -123,7 +123,7 @@ def make_eval_env(model, vecnorm, n_sheep, max_steps=2000):
return vn return vn
def report(n_sheep, success_rate, failure_counts, n_episodes): def report(n_sheep, success_rate, failure_counts, n_episodes, threshold=PASS_THRESHOLD):
print(f"\n{'='*52}") print(f"\n{'='*52}")
print(f" Stage n_sheep={n_sheep} | success={success_rate*100:.0f}% ({int(success_rate*n_episodes)}/{n_episodes})") print(f" Stage n_sheep={n_sheep} | success={success_rate*100:.0f}% ({int(success_rate*n_episodes)}/{n_episodes})")
print(f" {''*48}") print(f" {''*48}")
@@ -132,9 +132,9 @@ def report(n_sheep, success_rate, failure_counts, n_episodes):
print(f" {mode:<26} {cnt:>3}/{n_episodes} {bar}") print(f" {mode:<26} {cnt:>3}/{n_episodes} {bar}")
print(f"{'='*52}") print(f"{'='*52}")
passed = success_rate >= PASS_THRESHOLD passed = success_rate >= threshold
if passed: if passed:
print(f" ✓ PASS (threshold {PASS_THRESHOLD*100:.0f}%)") print(f" ✓ PASS (threshold {threshold*100:.0f}%)")
else: else:
dominant = max(failure_counts, key=failure_counts.get) dominant = max(failure_counts, key=failure_counts.get)
print(f" ✗ FAIL — dominant: {dominant}") print(f" ✗ FAIL — dominant: {dominant}")
@@ -158,12 +158,15 @@ def main():
p.add_argument("--render", action="store_true") p.add_argument("--render", action="store_true")
args = p.parse_args() args = p.parse_args()
stages = [(1, args.steps), (3, args.steps)] # Stage 1 (1 sheep, 500k): fast sanity check — obs/reward structurally correct?
# Stage 2 (3 sheep, 1.5M): real test at curriculum pace — if it fails here,
# there is a genuine problem worth fixing before committing to 15M steps.
stages = [(1, args.steps, 0.60), (3, args.steps * 3, 0.40)]
model, vn = None, None model, vn = None, None
all_passed = True all_passed = True
for n_sheep, steps in stages: for n_sheep, steps, threshold in stages:
print(f"\n{'#'*52}") print(f"\n{'#'*52}")
print(f"# Smoke-test stage: n_sheep={n_sheep}, {steps:,} steps") print(f"# Smoke-test stage: n_sheep={n_sheep}, {steps:,} steps")
print(f"{'#'*52}") print(f"{'#'*52}")
@@ -176,7 +179,7 @@ def main():
) )
eval_env.close() eval_env.close()
passed = report(n_sheep, success_rate, failure_counts, args.episodes) passed = report(n_sheep, success_rate, failure_counts, args.episodes, threshold)
if not passed: if not passed:
all_passed = False all_passed = False
print(" Aborting smoke test — fix the issue above before full training.") print(" Aborting smoke test — fix the issue above before full training.")