Behaviour refinement
This commit is contained in:
@@ -43,6 +43,10 @@ def main():
|
||||
p.add_argument("--mixed", action="store_true",
|
||||
help="Train with n_sheep randomized per episode (no curriculum). "
|
||||
"Total train steps = steps-per-stage * max_sheep.")
|
||||
p.add_argument("--final-mixed-steps", type=int, default=0,
|
||||
help="After the curriculum, train this many extra steps with "
|
||||
"random_n_sheep ∈ [1, max_sheep] to consolidate the policy "
|
||||
"across all flock sizes. Re-evaluates all n_sheep at the end.")
|
||||
p.add_argument("--n-envs", type=int, default=8)
|
||||
p.add_argument("--max-steps", type=int, default=2500)
|
||||
p.add_argument("--eval-episodes", type=int, default=30)
|
||||
@@ -123,6 +127,28 @@ def main():
|
||||
f"mean_act={r['mean_act']:.2f}")
|
||||
stage_results.append({"n_sheep": n, **r})
|
||||
|
||||
# Optional consolidation pass with mixed n_sheep — fixes specialization
|
||||
# imbalance from curriculum order (e.g. n=1 weakness after long n=10
|
||||
# training). Replaces stage_results with the post-consolidation eval.
|
||||
if args.final_mixed_steps > 0 and not args.mixed:
|
||||
print(f"\n[Consolidation] mixed n_sheep ∈ [1, {args.max_sheep}], "
|
||||
f"{args.final_mixed_steps:,} steps")
|
||||
vn.env_method("__setattr__", "random_n_sheep", True)
|
||||
model.learn(
|
||||
total_timesteps=args.final_mixed_steps,
|
||||
reset_num_timesteps=False,
|
||||
callback=ProgressCallback(0, "consolidate", freq=100_000),
|
||||
)
|
||||
print("[Consolidation] re-evaluating all sheep counts")
|
||||
stage_results = []
|
||||
for n in range(1, args.max_sheep + 1):
|
||||
r = evaluate(model, vn, n, args.eval_episodes, args.max_steps, rcfg)
|
||||
print(f"[Consolidation] n_sheep={n} sr={r['sr']*100:.0f}% "
|
||||
f"mean_len={r['mean_len']:.0f} "
|
||||
f"mean_min_pen={r['mean_min_pen']:.1f}m "
|
||||
f"mean_act={r['mean_act']:.2f}")
|
||||
stage_results.append({"n_sheep": n, **r})
|
||||
|
||||
model.save(os.path.join(run_dir, "final_model"))
|
||||
vn.save(os.path.join(run_dir, "vecnorm.pkl"))
|
||||
with open(os.path.join(run_dir, "stage_results.json"), "w") as f:
|
||||
|
||||
Reference in New Issue
Block a user