Sheep training flock _ improver

This commit is contained in:
Johnny Fernandes
2026-04-25 18:46:41 +01:00
parent 5005128c07
commit 7bfb7d3aae
2 changed files with 18 additions and 8 deletions
+1 -1
View File
@@ -44,7 +44,7 @@ def main():
help="Train with n_sheep randomized per episode (no curriculum). " help="Train with n_sheep randomized per episode (no curriculum). "
"Total train steps = steps-per-stage * max_sheep.") "Total train steps = steps-per-stage * max_sheep.")
p.add_argument("--n-envs", type=int, default=8) p.add_argument("--n-envs", type=int, default=8)
p.add_argument("--max-steps", type=int, default=1500) p.add_argument("--max-steps", type=int, default=2500)
p.add_argument("--eval-episodes", type=int, default=30) p.add_argument("--eval-episodes", type=int, default=30)
p.add_argument("--run-dir", type=str, default=None) p.add_argument("--run-dir", type=str, default=None)
args = p.parse_args() args = p.parse_args()
+17 -7
View File
@@ -36,7 +36,9 @@ from herding_env import HerdingEnv
class ProgressCallback(BaseCallback): class ProgressCallback(BaseCallback):
"""Print a one-line trial-progress summary every `freq` env steps. """Print a one-line trial-progress summary every `freq` env steps.
Tracks per-env returns and success directly from rollout rewards/infos Tracks per-env returns and success directly from rollout rewards/infos
(no Monitor wrapper needed).""" (no Monitor wrapper needed). The success window is COUNT-BASED, not
time-based, so successful episodes (which finish faster) don't oversample
the window vs truncated episodes (which take max_steps)."""
def __init__(self, trial_id: int, stage_label: str, freq: int = 50_000): def __init__(self, trial_id: int, stage_label: str, freq: int = 50_000):
super().__init__() super().__init__()
self.trial_id = trial_id self.trial_id = trial_id
@@ -45,6 +47,8 @@ class ProgressCallback(BaseCallback):
self._last = 0 self._last = 0
self._ep_returns = [] self._ep_returns = []
self._ep_success = [] self._ep_success = []
self._completed_count = 0 # total completed episodes since callback start
self._success_count = 0 # total successful episodes since callback start
self._cur_ret = None # per-env running return self._cur_ret = None # per-env running return
def _on_step(self) -> bool: def _on_step(self) -> bool:
@@ -60,20 +64,26 @@ class ProgressCallback(BaseCallback):
if not d: continue if not d: continue
self._ep_returns.append(float(self._cur_ret[i])) self._ep_returns.append(float(self._cur_ret[i]))
info = infos[i] if i < len(infos) else {} info = infos[i] if i < len(infos) else {}
self._ep_success.append( success = int(info.get("n_penned", 0) == info.get("n_sheep", -1))
int(info.get("n_penned", 0) == info.get("n_sheep", -1)) self._ep_success.append(success)
) self._completed_count += 1
self._success_count += success
self._cur_ret[i] = 0.0 self._cur_ret[i] = 0.0
if len(self._ep_returns) > 50: if len(self._ep_returns) > 50:
self._ep_returns.pop(0); self._ep_success.pop(0) self._ep_returns.pop(0); self._ep_success.pop(0)
if self.num_timesteps - self._last >= self.freq: if self.num_timesteps - self._last >= self.freq:
self._last = self.num_timesteps self._last = self.num_timesteps
n_eps = len(self._ep_returns) n_eps = len(self._ep_returns)
mean_r = float(np.mean(self._ep_returns)) if n_eps else float("nan") mean_r = float(np.mean(self._ep_returns)) if n_eps else float("nan")
sr = float(np.mean(self._ep_success)) if n_eps else float("nan") # Window sr (biased: short eps over-represented), and cumulative sr
# (unbiased over the whole stage).
win_sr = float(np.mean(self._ep_success)) if n_eps else float("nan")
cum_sr = (self._success_count / self._completed_count
if self._completed_count else float("nan"))
print(f" ... [trial {self.trial_id+1} | {self.stage_label} | " print(f" ... [trial {self.trial_id+1} | {self.stage_label} | "
f"{self.num_timesteps:>7,} steps | " f"{self.num_timesteps:>7,} steps | "
f"ret(last {n_eps})={mean_r:+.2f} sr={sr*100:.0f}%]", f"ret(last {n_eps})={mean_r:+.2f} "
f"win_sr={win_sr*100:.0f}% cum_sr={cum_sr*100:.0f}%]",
flush=True) flush=True)
return True return True