From 9bbef28515d273abc99c75f1323d79dcaa707a2f Mon Sep 17 00:00:00 2001 From: Johnny Fernandes Date: Sat, 25 Apr 2026 13:30:37 +0100 Subject: [PATCH] Sheep training flock _ improver --- training/sweep_reward.py | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/training/sweep_reward.py b/training/sweep_reward.py index e11c2e0..b90c42a 100644 --- a/training/sweep_reward.py +++ b/training/sweep_reward.py @@ -27,10 +27,40 @@ from copy import deepcopy import numpy as np from stable_baselines3 import PPO +from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecNormalize from herding_env import HerdingEnv + +class ProgressCallback(BaseCallback): + """Print a one-line trial-progress summary every `freq` env steps.""" + def __init__(self, trial_id: int, stage_label: str, freq: int = 50_000): + super().__init__() + self.trial_id = trial_id + self.stage_label = stage_label + self.freq = freq + self._last = 0 + self._ep_returns = [] # rolling list of completed-episode returns + + def _on_step(self) -> bool: + for info, done in zip(self.locals.get("infos", []), + self.locals.get("dones", [])): + if done and "episode" in info: + self._ep_returns.append(info["episode"]["r"]) + if len(self._ep_returns) > 50: + self._ep_returns.pop(0) + if self.num_timesteps - self._last >= self.freq: + self._last = self.num_timesteps + mean_r = (float(np.mean(self._ep_returns)) + if self._ep_returns else float("nan")) + n_eps = len(self._ep_returns) + print(f" ... [trial {self.trial_id+1} | {self.stage_label} | " + f"{self.num_timesteps:>7,} steps | " + f"ep_return(last {n_eps})={mean_r:+.2f}]", + flush=True) + return True + # --------------------------------------------------------------------------- # Search space — reward weights + a couple of hyperparams # --------------------------------------------------------------------------- @@ -128,12 +158,17 @@ def run_trial(trial_id: int, cfg: dict, log_path: str) -> dict: ) try: - model.learn(total_timesteps=TRAIN_STAGE1_STEPS, reset_num_timesteps=True) + model.learn(total_timesteps=TRAIN_STAGE1_STEPS, + reset_num_timesteps=True, + callback=ProgressCallback(trial_id, "1 sheep")) vn.env_method("set_n_sheep", 2) - model.learn(total_timesteps=TRAIN_STAGE2_STEPS, reset_num_timesteps=False) + model.learn(total_timesteps=TRAIN_STAGE2_STEPS, + reset_num_timesteps=False, + callback=ProgressCallback(trial_id, "2 sheep")) per_sheep = {} for n in EVAL_NSHEEP: + print(f" ... [trial {trial_id+1} | eval n={n}]", flush=True) per_sheep[n] = evaluate(model, vn, n, EVAL_EPISODES, MAX_STEPS, rcfg) finally: try: vn.close()