Sheep training flock of 10 fix?

This commit is contained in:
Johnny Fernandes
2026-04-24 01:59:15 +01:00
parent 1e3b67d194
commit 4189cc8dba
2 changed files with 20 additions and 11 deletions
+8 -3
View File
@@ -105,9 +105,10 @@ class CurriculumCallback(BaseCallback):
# Environment factory
# ---------------------------------------------------------------------------
def make_env(n_sheep: int, seed: int, max_steps: int):
def make_env(n_sheep: int, seed: int, max_steps: int, random_n_sheep: bool = False):
def _init():
env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps)
env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps,
random_n_sheep=random_n_sheep)
env.reset(seed=seed)
return env
return _init
@@ -140,6 +141,9 @@ def parse_args():
p.add_argument("--save-freq", type=int, default=100_000)
p.add_argument("--eval-freq", type=int, default=50_000)
p.add_argument("--eval-eps", type=int, default=20)
p.add_argument("--mixed", action="store_true",
help="Randomise n_sheep each episode (consolidation pass, "
"use with --resume after curriculum training)")
return p.parse_args()
@@ -153,7 +157,8 @@ def main():
# Training envs
train_env = SubprocVecEnv([
make_env(args.n_sheep, seed=i, max_steps=args.max_steps)
make_env(args.n_sheep, seed=i, max_steps=args.max_steps,
random_n_sheep=args.mixed)
for i in range(args.n_envs)
])
if args.resume and os.path.exists(norm_path):