Sheep training: flock-of-10 fix?
This commit is contained in:
+4
-3
@@ -191,11 +191,12 @@ class DiagnosticCallback(BaseCallback):
|
|||||||
for m, c in sorted(failure_counts.items(), key=lambda x: -x[1]):
|
for m, c in sorted(failure_counts.items(), key=lambda x: -x[1]):
|
||||||
print(f" {m:<26} {c}/{self.n_episodes}")
|
print(f" {m:<26} {c}/{self.n_episodes}")
|
||||||
|
|
||||||
# Stall detection: same dominant failure at same n_sheep twice in a row
|
# Stall detection: same dominant failure at same n_sheep 5 checks in a row,
|
||||||
|
# and only after 3M total steps (give early stages time to warm up).
|
||||||
key = (n_sheep, dominant)
|
key = (n_sheep, dominant)
|
||||||
if key == self._prev_dominant and dominant != "SUCCESS":
|
if key == self._prev_dominant and dominant != "SUCCESS":
|
||||||
self._stall_count += 1
|
self._stall_count += 1
|
||||||
if self._stall_count >= 2:
|
if self._stall_count >= 5 and self.num_timesteps >= 3_000_000:
|
||||||
print(f"\n[Diag] STALL DETECTED — '{dominant}' on {n_sheep} sheep "
|
print(f"\n[Diag] STALL DETECTED — '{dominant}' on {n_sheep} sheep "
|
||||||
f"for {self._stall_count} consecutive checks. "
|
f"for {self._stall_count} consecutive checks. "
|
||||||
f"Aborting training early.")
|
f"Aborting training early.")
|
||||||
@@ -302,7 +303,7 @@ def main():
|
|||||||
verbose=1,
|
verbose=1,
|
||||||
)
|
)
|
||||||
diag_cb = DiagnosticCallback(
|
diag_cb = DiagnosticCallback(
|
||||||
diag_freq=max(args.diag_freq // args.n_envs, 1),
|
diag_freq=args.diag_freq,
|
||||||
n_episodes=20,
|
n_episodes=20,
|
||||||
max_steps=args.max_steps,
|
max_steps=args.max_steps,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user