Files
TIR_PROJ/training/sweep_reward.py
T
2026-04-25 13:39:49 +01:00

297 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Random-search sweep over reward-function hyperparameters.
Each trial trains a fresh PPO policy through a 1→2-sheep curriculum on a tight
budget, then evaluates at n=1,2,3 sheep. A composite score is computed and
written to a JSONL log. After all trials, a leaderboard is printed and the
best config is saved.
Sized to fit in ~4 hours wall-clock with default settings on 8 envs.
Usage
-----
python sweep_reward.py # 25 trials, default budget
python sweep_reward.py --n-trials 15
python sweep_reward.py --time-budget 6 # stop adding trials past 6h
python sweep_reward.py --resume runs/sweep_<timestamp> # continue logging
Per-trial budget (see TRAIN_*_STEPS below): ~1.0M training steps + 10 eval
episodes × 3 sheep counts (30 total). On this env that runs in ~8-12 min per trial.
"""
import argparse
import json
import os
import time
import traceback
from copy import deepcopy
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecNormalize
from herding_env import HerdingEnv
class ProgressCallback(BaseCallback):
    """Print a one-line trial-progress summary every `freq` env steps.

    Tracks per-env returns and success directly from rollout rewards/infos
    (no Monitor wrapper needed).

    Args:
        trial_id: Zero-based trial index; printed one-based.
        stage_label: Free-form label of the training stage shown in the output.
        freq: Minimum number of timesteps between two printed summaries.
        window: How many of the most recently finished episodes the printed
            mean return / success rate are computed over.
    """

    def __init__(self, trial_id: int, stage_label: str, freq: int = 50_000,
                 window: int = 50):
        super().__init__()
        self.trial_id = trial_id
        self.stage_label = stage_label
        self.freq = freq
        self.window = window
        self._last = 0
        self._ep_returns = []
        self._ep_success = []
        self._cur_ret = None  # per-env running return

    def _on_training_start(self) -> None:
        # When learn() is called with reset_num_timesteps=False the step
        # counter continues from the previous call. Anchor the print interval
        # at the current count so a fresh callback does not emit an empty
        # ("nan") summary on its very first step.
        self._last = self.num_timesteps

    def _on_step(self) -> bool:
        rewards = self.locals.get("rewards")
        dones = self.locals.get("dones")
        infos = self.locals.get("infos", [])
        if rewards is None or dones is None:
            return True

        # (Re)allocate the accumulators if the number of envs is not known yet
        # or has changed.
        if self._cur_ret is None or len(self._cur_ret) != len(rewards):
            self._cur_ret = np.zeros(len(rewards), dtype=np.float64)
        # NOTE(review): these are the rewards as seen by the algorithm; if the
        # env is wrapped with reward normalisation they are normalised values.
        self._cur_ret += np.asarray(rewards, dtype=np.float64)

        for i, d in enumerate(dones):
            if not d:
                continue
            self._ep_returns.append(float(self._cur_ret[i]))
            info = infos[i] if i < len(infos) else {}
            # Success = every sheep penned; the mismatched defaults (0 vs -1)
            # make an info dict without these keys count as a failure.
            self._ep_success.append(
                int(info.get("n_penned", 0) == info.get("n_sheep", -1))
            )
            self._cur_ret[i] = 0.0
            # Keep only the most recent `window` episodes.
            if len(self._ep_returns) > self.window:
                self._ep_returns.pop(0)
                self._ep_success.pop(0)

        if self.num_timesteps - self._last >= self.freq:
            self._last = self.num_timesteps
            n_eps = len(self._ep_returns)
            mean_r = float(np.mean(self._ep_returns)) if n_eps else float("nan")
            sr = float(np.mean(self._ep_success)) if n_eps else float("nan")
            print(f" ... [trial {self.trial_id+1} | {self.stage_label} | "
                  f"{self.num_timesteps:>7,} steps | "
                  f"ret(last {n_eps})={mean_r:+.2f} sr={sr*100:.0f}%]",
                  flush=True)
        return True
# ---------------------------------------------------------------------------
# Search space — reward weights + a couple of hyperparams
# ---------------------------------------------------------------------------
# Each key maps to the list of candidate values that sample_config() picks
# one entry from. Every key except "ent_coef" is forwarded to HerdingEnv via
# reward_cfg(); "ent_coef" is passed to PPO in run_trial().
SEARCH_SPACE = {
    "W_PER_SHEEP": [1.0, 2.0, 4.0, 6.0],
    "W_ALIGN": [0.0, 0.025, 0.05, 0.1],
    "W_PEN_BONUS": [5.0, 10.0, 20.0],
    "W_STEP_COST": [0.005, 0.02, 0.05],
    "W_COMPLETE": [50.0, 100.0, 200.0],
    "W_COMPACT": [0.0, 0.5, 1.5, 3.0],
    "ALIGN_SHAPE": ["standoff", "near"],
    "ALIGN_GATED": [True, False],
    # PPO entropy coefficient (not an env/reward setting).
    "ent_coef": [0.005, 0.01, 0.02, 0.05],
}
# Per-trial training budget — keep tight; total = sum + eval
TRAIN_STAGE1_STEPS = 400_000 # 1 sheep
TRAIN_STAGE2_STEPS = 600_000 # 2 sheep
# Evaluation episodes run for EACH entry of EVAL_NSHEEP (see run_trial).
EVAL_EPISODES = 10
# Sheep counts evaluated after training; run_trial's score reads keys 1, 2, 3.
EVAL_NSHEEP = (1, 2, 3)
# Passed as `max_steps` to HerdingEnv for both training and evaluation.
MAX_STEPS = 1500
# Number of training environments created in the SubprocVecEnv.
N_ENVS = 8
def sample_config(rng: np.random.Generator) -> dict:
    """Draw one random configuration from SEARCH_SPACE.

    For every key (in SEARCH_SPACE's insertion order) one random index into
    the candidate list is drawn from *rng* and the value at that index is
    taken. NumPy booleans are converted to plain ``bool`` so the result
    stays JSON-friendly.

    Returns:
        A dict with the same keys as SEARCH_SPACE and one chosen value each.
    """

    def _as_native(value):
        # np.bool_ is not JSON serialisable; plain bool is.
        if isinstance(value, np.bool_):
            return bool(value)
        return value

    picked = {}
    for name, options in SEARCH_SPACE.items():
        index = int(rng.integers(0, len(options)))
        picked[name] = _as_native(options[index])
    return picked
def reward_cfg(cfg: dict) -> dict:
    """Return a copy of *cfg* without the non-env "ent_coef" entry.

    The input dict is left untouched; key order of the remaining entries is
    preserved.
    """
    env_only = dict(cfg)
    env_only.pop("ent_coef", None)
    return env_only
def make_env(n_sheep, seed, max_steps, rcfg):
    """Build a zero-argument factory for a seeded HerdingEnv.

    Nothing is constructed here; the environment is created (and reset once
    with *seed*) only when the returned callable is invoked, which is the
    form the vectorised-env constructors in this file are given.
    """

    def _init():
        herding = HerdingEnv(reward_cfg=rcfg, max_steps=max_steps, n_sheep=n_sheep)
        herding.reset(seed=seed)
        return herding

    return _init
def evaluate(model, vn_template, n_sheep, n_episodes, max_steps, rcfg):
    """Roll out *model* deterministically and summarise the episodes.

    A single evaluation env is created and wrapped in a non-training
    VecNormalize whose running statistics are deep-copied from
    *vn_template*, so observations are normalised the same way as during
    training while rewards are left raw.

    Args:
        model: Policy exposing ``predict(obs, deterministic=True)``.
        vn_template: VecNormalize whose ``obs_rms`` / ``ret_rms`` are copied.
        n_sheep: Number of sheep in the evaluation env; also the value
            ``info["n_penned"]`` must equal for an episode to be a success.
        n_episodes: Number of episodes to run (must be positive).
        max_steps: Forwarded to the env as its ``max_steps``.
        rcfg: Reward configuration forwarded to the env.

    Returns:
        Dict with ``sr`` (success rate), ``mean_len`` (mean episode length in
        steps), ``mean_min_pen`` (mean over episodes of the smallest observed
        distance between the flock statistic returned first by
        ``_flock_stats()`` and ``PEN_CENTER``) and ``mean_act`` (mean action
        norm over all steps).

    Raises:
        ValueError: If *n_episodes* is not positive.
    """
    if n_episodes <= 0:
        raise ValueError(f"n_episodes must be positive, got {n_episodes}")

    raw = DummyVecEnv([make_env(n_sheep, 9999, max_steps, rcfg)])
    try:
        vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
        vn.obs_rms = deepcopy(vn_template.obs_rms)
        vn.ret_rms = deepcopy(vn_template.ret_rms)
        inner = raw.envs[0]

        successes = 0
        ep_lens, min_pen_list, action_mags = [], [], []
        for _ in range(n_episodes):
            obs = vn.reset()
            done = False
            steps, min_pen = 0, float("inf")
            while not done:
                # Sample the flock BEFORE stepping: the vectorised env resets
                # itself as soon as an episode ends, so a post-step read on
                # the terminal step would see the next episode's initial
                # state rather than this episode's.
                com, _, _ = inner._flock_stats()
                min_pen = min(
                    min_pen, float(np.linalg.norm(com - inner.PEN_CENTER))
                )
                action, _ = model.predict(obs, deterministic=True)
                obs, _, dones, infos = vn.step(action)
                done = dones[0]
                action_mags.append(float(np.linalg.norm(action[0])))
                steps += 1
            # `infos` still belongs to the terminal step of this episode.
            successes += int(infos[0].get("n_penned") == n_sheep)
            ep_lens.append(steps)
            min_pen_list.append(min_pen)
    finally:
        # Always release the env, even if prediction or stepping failed.
        raw.close()

    return {
        "sr": successes / n_episodes,
        "mean_len": float(np.mean(ep_lens)),
        "mean_min_pen": float(np.mean(min_pen_list)),
        "mean_act": float(np.mean(action_mags)),
    }
def run_trial(trial_id: int, cfg: dict, log_path: str) -> dict:
    """Train one PPO policy with *cfg* and evaluate it.

    Training is a two-stage curriculum on N_ENVS subprocess envs:
    TRAIN_STAGE1_STEPS with 1 sheep, then (after calling ``set_n_sheep(2)``
    on every env) TRAIN_STAGE2_STEPS more with 2 sheep. The policy is then
    evaluated for EVAL_EPISODES episodes at each sheep count in EVAL_NSHEEP.

    Args:
        trial_id: Zero-based trial index; also used to derive the env seeds.
        cfg: Sampled configuration; "ent_coef" goes to PPO, the remaining
            keys go to the env as its reward configuration.
        log_path: Currently unused; kept so existing callers keep working.

    Returns:
        Dict with keys ``trial``, ``config``, ``score`` (0.2/0.5/0.3 weighted
        success rate at 1/2/3 sheep), ``sr`` (success rate per sheep count)
        and ``details`` (full evaluate() output per sheep count).

    Raises:
        Exception: Whatever training or evaluation raises is propagated; the
            training envs are closed first.
    """
    rcfg = reward_cfg(cfg)
    train_env = SubprocVecEnv([
        make_env(1, seed=trial_id * 100 + i, max_steps=MAX_STEPS, rcfg=rcfg)
        for i in range(N_ENVS)
    ])
    vn = None
    # The worker processes exist from here on, so everything below (including
    # wrapper/model construction) is covered by the cleanup in `finally`.
    try:
        vn = VecNormalize(train_env, norm_obs=True, norm_reward=True, clip_obs=10.0)
        model = PPO(
            "MlpPolicy", vn,
            learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10,
            gamma=0.995, gae_lambda=0.95, clip_range=0.2,
            ent_coef=cfg["ent_coef"], vf_coef=0.5, max_grad_norm=0.5,
            policy_kwargs=dict(net_arch=[256, 256]),
            verbose=0,
        )
        model.learn(total_timesteps=TRAIN_STAGE1_STEPS,
                    reset_num_timesteps=True,
                    callback=ProgressCallback(trial_id, "1 sheep"))
        vn.env_method("set_n_sheep", 2)
        # Keep the step counter running so stage 2 continues the same run.
        model.learn(total_timesteps=TRAIN_STAGE2_STEPS,
                    reset_num_timesteps=False,
                    callback=ProgressCallback(trial_id, "2 sheep"))
        per_sheep = {}
        for n in EVAL_NSHEEP:
            print(f" ... [trial {trial_id+1} | eval n={n}]", flush=True)
            per_sheep[n] = evaluate(model, vn, n, EVAL_EPISODES, MAX_STEPS, rcfg)
    finally:
        # Best-effort cleanup: report a failing close() but never let it mask
        # the exception (or result) of the trial itself.
        try:
            (vn if vn is not None else train_env).close()
        except Exception as exc:
            print(f" ! [trial {trial_id+1}] env close failed: "
                  f"{type(exc).__name__}: {exc}", flush=True)

    sr = {n: per_sheep[n]["sr"] for n in EVAL_NSHEEP}
    score = 0.2 * sr[1] + 0.5 * sr[2] + 0.3 * sr[3]
    return {
        "trial": trial_id,
        "config": cfg,
        "score": score,
        "sr": sr,
        "details": per_sheep,
    }
def _int_keyed(mapping: dict) -> dict:
    """Return a copy of *mapping* with every key converted via ``int``."""
    return {int(k): v for k, v in mapping.items()}


def _load_prior_results(log_path: str) -> list:
    """Read the trial records of an earlier run from the JSONL log.

    JSON object keys are always strings, so the per-sheep-count dicts
    ("sr", "details") are converted back to integer keys to match records
    produced in memory by run_trial(). Blank lines are ignored; lines that
    cannot be decoded are reported and skipped.
    """
    records = []
    with open(log_path) as f:
        for lineno, line in enumerate(f, 1):
            if not line.strip():
                continue
            try:
                rec = json.loads(line)
                if not isinstance(rec, dict):
                    raise ValueError("record is not a JSON object")
                for key in ("sr", "details"):
                    if isinstance(rec.get(key), dict):
                        rec[key] = _int_keyed(rec[key])
            except ValueError as exc:  # includes json.JSONDecodeError
                print(f" ! skipping line {lineno} of {log_path}: {exc}")
                continue
            records.append(rec)
    return records


def _print_leaderboard(succ: list) -> None:
    """Print the top 15 entries of *succ* (already sorted best-first)."""
    print("\n" + "=" * 92)
    print(" LEADERBOARD")
    print("=" * 92)
    hdr = f" {'rank':>4} {'score':>6} {'sr1':>5} {'sr2':>5} {'sr3':>5} config"
    print(hdr)
    print(" " + "-" * 88)
    for i, r in enumerate(succ[:15], 1):
        sr = r["sr"]
        cfg_short = " ".join(f"{k}={v}" for k, v in r["config"].items())
        print(f" {i:>4d} {r['score']:>6.3f} {sr[1]:>5.2f} {sr[2]:>5.2f} {sr[3]:>5.2f} {cfg_short}")


def main():
    """Run the random-search sweep from the command line.

    Samples configurations, runs one trial per configuration until either
    ``--n-trials`` successful trials exist or ``--time-budget`` hours have
    elapsed in this session, appends every outcome (success or failure) to
    ``results.jsonl`` in the sweep directory, then prints a leaderboard and
    writes the best successful record to ``best.json``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--n-trials", type=int, default=25)
    p.add_argument("--time-budget", type=float, default=7.5,
                   help="Stop launching new trials past this many hours.")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--run-dir", type=str, default=None,
                   help="If unset, creates runs/sweep_<timestamp>/")
    p.add_argument("--resume", type=str, default=None,
                   help="Continue logging into an existing sweep dir")
    args = p.parse_args()

    run_dir = args.resume or args.run_dir or os.path.join(
        "runs", "sweep_" + time.strftime("%Y%m%d_%H%M%S")
    )
    os.makedirs(run_dir, exist_ok=True)
    log_path = os.path.join(run_dir, "results.jsonl")
    rng = np.random.default_rng(args.seed)
    start = time.time()
    results = []

    # If resuming, replay the existing log into memory
    if args.resume and os.path.exists(log_path):
        results = _load_prior_results(log_path)
        # Every logged trial consumed exactly one sample_config() draw.
        # Burn the same number of draws so the resumed sweep continues the
        # random sequence instead of re-running the configs already tried.
        for _ in results:
            sample_config(rng)
        print(f"Resumed sweep: {len(results)} prior trials loaded from {log_path}")

    print(f"Sweep dir: {run_dir}")
    print(f"Search space: {list(SEARCH_SPACE.keys())}")
    print(f"Per-trial: {TRAIN_STAGE1_STEPS+TRAIN_STAGE2_STEPS:,} steps train + "
          f"{EVAL_EPISODES * len(EVAL_NSHEEP)} eval eps")
    print(f"Time budget: {args.time_budget}h\n")

    # Only successful trials count towards --n-trials; failures are logged
    # but do not reduce the number of trials still to run.
    n_done = sum(1 for r in results if "error" not in r)
    trial_id = len(results)
    while n_done < args.n_trials:
        elapsed_h = (time.time() - start) / 3600
        if elapsed_h >= args.time_budget:
            print(f"\n[Sweep] time budget reached ({elapsed_h:.2f}h) — stopping.")
            break
        cfg = sample_config(rng)
        t0 = time.time()
        print(f"[Trial {trial_id+1:>3}] {cfg}")
        try:
            result = run_trial(trial_id, cfg, log_path)
            result["elapsed_s"] = time.time() - t0
            sr = result["sr"]
            print(f" → score={result['score']:.3f} "
                  f"sr1={sr[1]:.2f} sr2={sr[2]:.2f} sr3={sr[3]:.2f} "
                  f"[{result['elapsed_s']:.0f}s]")
            results.append(result)
            n_done += 1
        except Exception as e:
            # Top-level boundary: one failing trial must not abort the sweep.
            traceback.print_exc()
            err = {"trial": trial_id, "config": cfg,
                   "error": f"{type(e).__name__}: {e}",
                   "elapsed_s": time.time() - t0}
            results.append(err)
            print(f" ! FAILED: {err['error']}")
        # Append immediately so a crash later in the sweep loses nothing.
        with open(log_path, "a") as f:
            f.write(json.dumps(results[-1]) + "\n")
        trial_id += 1

    # Leaderboard
    succ = [r for r in results if "error" not in r]
    succ.sort(key=lambda r: -r["score"])
    _print_leaderboard(succ)
    if succ:
        best = succ[0]
        with open(os.path.join(run_dir, "best.json"), "w") as f:
            json.dump(best, f, indent=2)
        print(f"\n Best config saved to {run_dir}/best.json")
    print(f" Total trials: {len(results)} ({len(succ)} successful, "
          f"{len(results)-len(succ)} failed)")
    print(f" Total time: {(time.time()-start)/3600:.2f}h\n")
# Start the sweep only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()