diff --git a/training/smoke_test.py b/training/smoke_test.py index cd1ba94..11e582c 100644 --- a/training/smoke_test.py +++ b/training/smoke_test.py @@ -123,7 +123,7 @@ def make_eval_env(model, vecnorm, n_sheep, max_steps=2000): return vn -def report(n_sheep, success_rate, failure_counts, n_episodes): +def report(n_sheep, success_rate, failure_counts, n_episodes, threshold=PASS_THRESHOLD): print(f"\n{'='*52}") print(f" Stage n_sheep={n_sheep} | success={success_rate*100:.0f}% ({int(success_rate*n_episodes)}/{n_episodes})") print(f" {'─'*48}") @@ -132,9 +132,9 @@ def report(n_sheep, success_rate, failure_counts, n_episodes): print(f" {mode:<26} {cnt:>3}/{n_episodes} {bar}") print(f"{'='*52}") - passed = success_rate >= PASS_THRESHOLD + passed = success_rate >= threshold if passed: - print(f" ✓ PASS (threshold {PASS_THRESHOLD*100:.0f}%)") + print(f" ✓ PASS (threshold {threshold*100:.0f}%)") else: dominant = max(failure_counts, key=failure_counts.get) print(f" ✗ FAIL — dominant: {dominant}") @@ -158,12 +158,15 @@ def main(): p.add_argument("--render", action="store_true") args = p.parse_args() - stages = [(1, args.steps), (3, args.steps)] + # Stage 1 (1 sheep, 500k): fast sanity check — obs/reward structurally correct? + # Stage 2 (3 sheep, 1.5M): real test at curriculum pace — if it fails here, + # there is a genuine problem worth fixing before committing to 15M steps. + stages = [(1, args.steps, 0.60), (3, args.steps * 3, 0.40)] model, vn = None, None all_passed = True - for n_sheep, steps in stages: + for n_sheep, steps, threshold in stages: print(f"\n{'#'*52}") print(f"# Smoke-test stage: n_sheep={n_sheep}, {steps:,} steps") print(f"{'#'*52}") @@ -176,7 +179,7 @@ def main(): ) eval_env.close() - passed = report(n_sheep, success_rate, failure_counts, args.episodes) + passed = report(n_sheep, success_rate, failure_counts, args.episodes, threshold) if not passed: all_passed = False print(" Aborting smoke test — fix the issue above before full training.")