diff --git a/training/replay_config.py b/training/replay_config.py index 7a79927..6903f7e 100644 --- a/training/replay_config.py +++ b/training/replay_config.py @@ -44,7 +44,7 @@ def main(): help="Train with n_sheep randomized per episode (no curriculum). " "Total train steps = steps-per-stage * max_sheep.") p.add_argument("--n-envs", type=int, default=8) - p.add_argument("--max-steps", type=int, default=1500) + p.add_argument("--max-steps", type=int, default=2500) p.add_argument("--eval-episodes", type=int, default=30) p.add_argument("--run-dir", type=str, default=None) args = p.parse_args() diff --git a/training/sweep_reward.py b/training/sweep_reward.py index db84412..03f318a 100644 --- a/training/sweep_reward.py +++ b/training/sweep_reward.py @@ -36,7 +36,9 @@ from herding_env import HerdingEnv class ProgressCallback(BaseCallback): """Print a one-line trial-progress summary every `freq` env steps. Tracks per-env returns and success directly from rollout rewards/infos - (no Monitor wrapper needed).""" + (no Monitor wrapper needed). 
The success window is COUNT-BASED, not + time-based; short successful episodes can still be over-represented in + it, which is why a cumulative success rate is reported as well.""" def __init__(self, trial_id: int, stage_label: str, freq: int = 50_000): super().__init__() self.trial_id = trial_id @@ -45,6 +47,8 @@ class ProgressCallback(BaseCallback): self._last = 0 self._ep_returns = [] self._ep_success = [] + self._completed_count = 0 # total completed episodes since callback start + self._success_count = 0 # total successful episodes since callback start self._cur_ret = None # per-env running return def _on_step(self) -> bool: @@ -60,20 +64,26 @@ class ProgressCallback(BaseCallback): if not d: continue self._ep_returns.append(float(self._cur_ret[i])) info = infos[i] if i < len(infos) else {} - self._ep_success.append( - int(info.get("n_penned", 0) == info.get("n_sheep", -1)) - ) + success = int(info.get("n_penned", 0) == info.get("n_sheep", -1)) + self._ep_success.append(success) + self._completed_count += 1 + self._success_count += success self._cur_ret[i] = 0.0 if len(self._ep_returns) > 50: self._ep_returns.pop(0); self._ep_success.pop(0) if self.num_timesteps - self._last >= self.freq: self._last = self.num_timesteps - n_eps = len(self._ep_returns) + n_eps = len(self._ep_returns) mean_r = float(np.mean(self._ep_returns)) if n_eps else float("nan") - sr = float(np.mean(self._ep_success)) if n_eps else float("nan") + # Window sr (biased: short eps over-represented), and cumulative sr + # (unbiased over the whole stage). + win_sr = float(np.mean(self._ep_success)) if n_eps else float("nan") + cum_sr = (self._success_count / self._completed_count + if self._completed_count else float("nan")) print(f" ... [trial {self.trial_id+1} | {self.stage_label} | " f"{self.num_timesteps:>7,} steps | " - f"ret(last {n_eps})={mean_r:+.2f} sr={sr*100:.0f}%]", + f"ret(last {n_eps})={mean_r:+.2f} " + f"win_sr={win_sr*100:.0f}% cum_sr={cum_sr*100:.0f}%]", flush=True) return True