Sheep training flock _ improver

This commit is contained in:
Johnny Fernandes
2026-04-25 16:28:15 +01:00
parent 4350c7d320
commit 75c5b7c014
2 changed files with 127 additions and 2 deletions
+9 -2
View File
@@ -155,8 +155,12 @@ def evaluate(model, vn_template, n_sheep, n_episodes, max_steps, rcfg):
}
def run_trial(trial_id: int, cfg: dict, log_path: str) -> dict:
def run_trial(trial_id: int, cfg: dict, log_path: str, run_dir: str) -> dict:
rcfg = reward_cfg(cfg)
trial_dir = os.path.join(run_dir, f"trial_{trial_id:03d}")
os.makedirs(trial_dir, exist_ok=True)
with open(os.path.join(trial_dir, "config.json"), "w") as f:
json.dump(cfg, f, indent=2)
train_env = SubprocVecEnv([
make_env(1, seed=trial_id * 100 + i, max_steps=MAX_STEPS, rcfg=rcfg)
@@ -186,6 +190,9 @@ def run_trial(trial_id: int, cfg: dict, log_path: str) -> dict:
for n in EVAL_NSHEEP:
print(f" ... [trial {trial_id+1} | eval n={n}]", flush=True)
per_sheep[n] = evaluate(model, vn, n, EVAL_EPISODES, MAX_STEPS, rcfg)
model.save(os.path.join(trial_dir, "model"))
vn.save(os.path.join(trial_dir, "vecnorm.pkl"))
finally:
try: vn.close()
except Exception: pass
@@ -250,7 +257,7 @@ def main():
t0 = time.time()
print(f"[Trial {trial_id+1:>3}] {cfg}")
try:
result = run_trial(trial_id, cfg, log_path)
result = run_trial(trial_id, cfg, log_path, run_dir)
result["elapsed_s"] = time.time() - t0
sr = result["sr"]
print(f" → score={result['score']:.3f} "