Sheep training flock of 10 fix?
This commit is contained in:
@@ -123,7 +123,7 @@ def make_eval_env(model, vecnorm, n_sheep, max_steps=2000):
|
|||||||
return vn
|
return vn
|
||||||
|
|
||||||
|
|
||||||
def report(n_sheep, success_rate, failure_counts, n_episodes):
|
def report(n_sheep, success_rate, failure_counts, n_episodes, threshold=PASS_THRESHOLD):
|
||||||
print(f"\n{'='*52}")
|
print(f"\n{'='*52}")
|
||||||
print(f" Stage n_sheep={n_sheep} | success={success_rate*100:.0f}% ({int(success_rate*n_episodes)}/{n_episodes})")
|
print(f" Stage n_sheep={n_sheep} | success={success_rate*100:.0f}% ({int(success_rate*n_episodes)}/{n_episodes})")
|
||||||
print(f" {'─'*48}")
|
print(f" {'─'*48}")
|
||||||
@@ -132,9 +132,9 @@ def report(n_sheep, success_rate, failure_counts, n_episodes):
|
|||||||
print(f" {mode:<26} {cnt:>3}/{n_episodes} {bar}")
|
print(f" {mode:<26} {cnt:>3}/{n_episodes} {bar}")
|
||||||
print(f"{'='*52}")
|
print(f"{'='*52}")
|
||||||
|
|
||||||
passed = success_rate >= PASS_THRESHOLD
|
passed = success_rate >= threshold
|
||||||
if passed:
|
if passed:
|
||||||
print(f" ✓ PASS (threshold {PASS_THRESHOLD*100:.0f}%)")
|
print(f" ✓ PASS (threshold {threshold*100:.0f}%)")
|
||||||
else:
|
else:
|
||||||
dominant = max(failure_counts, key=failure_counts.get)
|
dominant = max(failure_counts, key=failure_counts.get)
|
||||||
print(f" ✗ FAIL — dominant: {dominant}")
|
print(f" ✗ FAIL — dominant: {dominant}")
|
||||||
@@ -158,12 +158,15 @@ def main():
|
|||||||
p.add_argument("--render", action="store_true")
|
p.add_argument("--render", action="store_true")
|
||||||
args = p.parse_args()
|
args = p.parse_args()
|
||||||
|
|
||||||
stages = [(1, args.steps), (3, args.steps)]
|
# Stage 1 (1 sheep, 500k): fast sanity check — obs/reward structurally correct?
|
||||||
|
# Stage 2 (3 sheep, 1.5M): real test at curriculum pace — if it fails here,
|
||||||
|
# there is a genuine problem worth fixing before committing to 15M steps.
|
||||||
|
stages = [(1, args.steps, 0.60), (3, args.steps * 3, 0.40)]
|
||||||
|
|
||||||
model, vn = None, None
|
model, vn = None, None
|
||||||
all_passed = True
|
all_passed = True
|
||||||
|
|
||||||
for n_sheep, steps in stages:
|
for n_sheep, steps, threshold in stages:
|
||||||
print(f"\n{'#'*52}")
|
print(f"\n{'#'*52}")
|
||||||
print(f"# Smoke-test stage: n_sheep={n_sheep}, {steps:,} steps")
|
print(f"# Smoke-test stage: n_sheep={n_sheep}, {steps:,} steps")
|
||||||
print(f"{'#'*52}")
|
print(f"{'#'*52}")
|
||||||
@@ -176,7 +179,7 @@ def main():
|
|||||||
)
|
)
|
||||||
eval_env.close()
|
eval_env.close()
|
||||||
|
|
||||||
passed = report(n_sheep, success_rate, failure_counts, args.episodes)
|
passed = report(n_sheep, success_rate, failure_counts, args.episodes, threshold)
|
||||||
if not passed:
|
if not passed:
|
||||||
all_passed = False
|
all_passed = False
|
||||||
print(" Aborting smoke test — fix the issue above before full training.")
|
print(" Aborting smoke test — fix the issue above before full training.")
|
||||||
|
|||||||
Reference in New Issue
Block a user