Sheep training flock _ improver
This commit is contained in:
@@ -155,8 +155,12 @@ def evaluate(model, vn_template, n_sheep, n_episodes, max_steps, rcfg):
|
||||
}
|
||||
|
||||
|
||||
def run_trial(trial_id: int, cfg: dict, log_path: str) -> dict:
|
||||
def run_trial(trial_id: int, cfg: dict, log_path: str, run_dir: str) -> dict:
|
||||
rcfg = reward_cfg(cfg)
|
||||
trial_dir = os.path.join(run_dir, f"trial_{trial_id:03d}")
|
||||
os.makedirs(trial_dir, exist_ok=True)
|
||||
with open(os.path.join(trial_dir, "config.json"), "w") as f:
|
||||
json.dump(cfg, f, indent=2)
|
||||
|
||||
train_env = SubprocVecEnv([
|
||||
make_env(1, seed=trial_id * 100 + i, max_steps=MAX_STEPS, rcfg=rcfg)
|
||||
@@ -186,6 +190,9 @@ def run_trial(trial_id: int, cfg: dict, log_path: str) -> dict:
|
||||
for n in EVAL_NSHEEP:
|
||||
print(f" ... [trial {trial_id+1} | eval n={n}]", flush=True)
|
||||
per_sheep[n] = evaluate(model, vn, n, EVAL_EPISODES, MAX_STEPS, rcfg)
|
||||
|
||||
model.save(os.path.join(trial_dir, "model"))
|
||||
vn.save(os.path.join(trial_dir, "vecnorm.pkl"))
|
||||
finally:
|
||||
try: vn.close()
|
||||
except Exception: pass
|
||||
@@ -250,7 +257,7 @@ def main():
|
||||
t0 = time.time()
|
||||
print(f"[Trial {trial_id+1:>3}] {cfg}")
|
||||
try:
|
||||
result = run_trial(trial_id, cfg, log_path)
|
||||
result = run_trial(trial_id, cfg, log_path, run_dir)
|
||||
result["elapsed_s"] = time.time() - t0
|
||||
sr = result["sr"]
|
||||
print(f" → score={result['score']:.3f} "
|
||||
|
||||
Reference in New Issue
Block a user