from herding_env import HerdingEnv


class ProgressCallback(BaseCallback):
    """Print a one-line trial-progress summary every `freq` env steps.

    Episode returns and success flags are accumulated directly from the
    rollout ``rewards`` / ``dones`` / ``infos`` found in ``self.locals``, so
    no Monitor wrapper is needed.

    Args:
        trial_id: Zero-based trial index; printed one-based.
        stage_label: Free-form label included in every summary line.
        freq: Minimum number of env steps between two summary lines.
        window: Number of most recent completed episodes kept for the
            rolling mean return and success rate (values below 1 are
            clamped to 1).
    """

    def __init__(self, trial_id: int, stage_label: str, freq: int = 50_000,
                 window: int = 50):
        super().__init__()
        self.trial_id = trial_id
        self.stage_label = stage_label
        self.freq = freq
        self.window = max(1, int(window))
        self._last = 0          # num_timesteps at the last printed summary
        self._ep_returns = []   # returns of the most recent episodes
        self._ep_success = []   # 0/1 success flags, parallel to _ep_returns
        self._cur_ret = None    # per-env running return of the open episode

    def _on_step(self) -> bool:
        """Accumulate rewards, record finished episodes, print periodically.

        Returns:
            Always ``True``; this callback never requests that training stop.
        """
        rewards = self.locals.get("rewards")
        dones = self.locals.get("dones")
        infos = self.locals.get("infos", [])
        if rewards is None or dones is None:
            return True

        # (Re)allocate the accumulators on first use or if the number of
        # parallel envs changed.
        if self._cur_ret is None or len(self._cur_ret) != len(rewards):
            self._cur_ret = np.zeros(len(rewards), dtype=np.float64)
        self._cur_ret += np.asarray(rewards, dtype=np.float64)

        for i, done in enumerate(dones):
            if not done:
                continue
            self._ep_returns.append(float(self._cur_ret[i]))
            info = infos[i] if i < len(infos) else {}
            # Success means every sheep was penned; the mismatched defaults
            # (0 vs -1) make an info dict lacking these keys count as failure.
            self._ep_success.append(
                int(info.get("n_penned", 0) == info.get("n_sheep", -1))
            )
            self._cur_ret[i] = 0.0

        # Trim both histories in one go so they stay bounded to `window`
        # even when several envs finish on the same step.
        del self._ep_returns[:-self.window]
        del self._ep_success[:-self.window]

        if self.num_timesteps - self._last >= self.freq:
            self._last = self.num_timesteps
            n_eps = len(self._ep_returns)
            mean_r = float(np.mean(self._ep_returns)) if n_eps else float("nan")
            sr = float(np.mean(self._ep_success)) if n_eps else float("nan")
            print(f" ... [trial {self.trial_id+1} | {self.stage_label} | "
                  f"{self.num_timesteps:>7,} steps | "
                  f"ret(last {n_eps})={mean_r:+.2f} sr={sr*100:.0f}%]",
                  flush=True)
        return True