Sheep training flock _ improver
This commit is contained in:
+26
-10
@@ -34,30 +34,46 @@ from herding_env import HerdingEnv
|
|||||||
|
|
||||||
|
|
||||||
class ProgressCallback(BaseCallback):
    """Print a one-line trial-progress summary every `freq` env steps.

    Tracks per-env returns and success directly from rollout rewards/infos
    (no Monitor wrapper needed).

    Args:
        trial_id: Zero-based trial index; printed as ``trial_id + 1``.
        stage_label: Free-form label included in every summary line.
        freq: Minimum number of timesteps between two summary lines.
        window: Maximum number of most recently completed episodes used for
            the rolling mean return and success rate.
    """

    def __init__(self, trial_id: int, stage_label: str, freq: int = 50_000,
                 window: int = 50):
        # Function-scope import so the class does not depend on the file's
        # top-level import block.
        from collections import deque

        super().__init__()
        self.trial_id = trial_id
        self.stage_label = stage_label
        self.freq = freq
        self._last = 0  # value of num_timesteps at the last printed summary
        # Bounded buffers: several envs can finish on the same step, so
        # trimming a single element per step would let the window overflow.
        self._ep_returns = deque(maxlen=window)
        self._ep_success = deque(maxlen=window)
        self._cur_ret = None  # per-env running return, allocated lazily

    def _on_step(self) -> bool:
        """Accumulate rewards, record finished episodes, maybe print.

        Returns:
            Always ``True``.
        """
        rewards = self.locals.get("rewards")
        dones = self.locals.get("dones")
        infos = self.locals.get("infos", [])
        if rewards is None or dones is None:
            return True

        # (Re)allocate the running returns when the number of envs changes.
        if self._cur_ret is None or len(self._cur_ret) != len(rewards):
            self._cur_ret = np.zeros(len(rewards), dtype=np.float64)
        self._cur_ret += np.asarray(rewards, dtype=np.float64)

        for i, done in enumerate(dones):
            if not done:
                continue
            self._ep_returns.append(float(self._cur_ret[i]))
            info = infos[i] if i < len(infos) else {}
            # Success means every sheep is penned; the differing defaults
            # (0 vs -1) make an info dict without these keys count as failure.
            self._ep_success.append(
                int(info.get("n_penned", 0) == info.get("n_sheep", -1))
            )
            self._cur_ret[i] = 0.0

        if self.num_timesteps - self._last >= self.freq:
            self._last = self.num_timesteps
            n_eps = len(self._ep_returns)
            mean_r = float(np.mean(self._ep_returns)) if n_eps else float("nan")
            sr = float(np.mean(self._ep_success)) if n_eps else float("nan")
            print(f" ... [trial {self.trial_id+1} | {self.stage_label} | "
                  f"{self.num_timesteps:>7,} steps | "
                  f"ret(last {n_eps})={mean_r:+.2f} sr={sr*100:.0f}%]",
                  flush=True)
        return True
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user