Sheep training flock of 10 fix?
This commit is contained in:
@@ -123,7 +123,7 @@ def make_eval_env(model, vecnorm, n_sheep, max_steps=2000):
|
||||
return vn
|
||||
|
||||
|
||||
def report(n_sheep, success_rate, failure_counts, n_episodes):
|
||||
def report(n_sheep, success_rate, failure_counts, n_episodes, threshold=PASS_THRESHOLD):
|
||||
print(f"\n{'='*52}")
|
||||
print(f" Stage n_sheep={n_sheep} | success={success_rate*100:.0f}% ({int(success_rate*n_episodes)}/{n_episodes})")
|
||||
print(f" {'─'*48}")
|
||||
@@ -132,9 +132,9 @@ def report(n_sheep, success_rate, failure_counts, n_episodes):
|
||||
print(f" {mode:<26} {cnt:>3}/{n_episodes} {bar}")
|
||||
print(f"{'='*52}")
|
||||
|
||||
passed = success_rate >= PASS_THRESHOLD
|
||||
passed = success_rate >= threshold
|
||||
if passed:
|
||||
print(f" ✓ PASS (threshold {PASS_THRESHOLD*100:.0f}%)")
|
||||
print(f" ✓ PASS (threshold {threshold*100:.0f}%)")
|
||||
else:
|
||||
dominant = max(failure_counts, key=failure_counts.get)
|
||||
print(f" ✗ FAIL — dominant: {dominant}")
|
||||
@@ -158,12 +158,15 @@ def main():
|
||||
p.add_argument("--render", action="store_true")
|
||||
args = p.parse_args()
|
||||
|
||||
stages = [(1, args.steps), (3, args.steps)]
|
||||
# Stage 1 (1 sheep, 500k): fast sanity check — obs/reward structurally correct?
|
||||
# Stage 2 (3 sheep, 1.5M): real test at curriculum pace — if it fails here,
|
||||
# there is a genuine problem worth fixing before committing to 15M steps.
|
||||
stages = [(1, args.steps, 0.60), (3, args.steps * 3, 0.40)]
|
||||
|
||||
model, vn = None, None
|
||||
all_passed = True
|
||||
|
||||
for n_sheep, steps in stages:
|
||||
for n_sheep, steps, threshold in stages:
|
||||
print(f"\n{'#'*52}")
|
||||
print(f"# Smoke-test stage: n_sheep={n_sheep}, {steps:,} steps")
|
||||
print(f"{'#'*52}")
|
||||
@@ -176,7 +179,7 @@ def main():
|
||||
)
|
||||
eval_env.close()
|
||||
|
||||
passed = report(n_sheep, success_rate, failure_counts, args.episodes)
|
||||
passed = report(n_sheep, success_rate, failure_counts, args.episodes, threshold)
|
||||
if not passed:
|
||||
all_passed = False
|
||||
print(" Aborting smoke test — fix the issue above before full training.")
|
||||
|
||||
Reference in New Issue
Block a user