Sheep training flock of 10 fix?
This commit is contained in:
+8
-3
@@ -105,9 +105,10 @@ class CurriculumCallback(BaseCallback):
|
||||
# Environment factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_env(n_sheep: int, seed: int, max_steps: int):
|
||||
def make_env(n_sheep: int, seed: int, max_steps: int, random_n_sheep: bool = False):
|
||||
def _init():
|
||||
env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps)
|
||||
env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps,
|
||||
random_n_sheep=random_n_sheep)
|
||||
env.reset(seed=seed)
|
||||
return env
|
||||
return _init
|
||||
@@ -140,6 +141,9 @@ def parse_args():
|
||||
p.add_argument("--save-freq", type=int, default=100_000)
|
||||
p.add_argument("--eval-freq", type=int, default=50_000)
|
||||
p.add_argument("--eval-eps", type=int, default=20)
|
||||
p.add_argument("--mixed", action="store_true",
|
||||
help="Randomise n_sheep each episode (consolidation pass, "
|
||||
"use with --resume after curriculum training)")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
@@ -153,7 +157,8 @@ def main():
|
||||
|
||||
# Training envs
|
||||
train_env = SubprocVecEnv([
|
||||
make_env(args.n_sheep, seed=i, max_steps=args.max_steps)
|
||||
make_env(args.n_sheep, seed=i, max_steps=args.max_steps,
|
||||
random_n_sheep=args.mixed)
|
||||
for i in range(args.n_envs)
|
||||
])
|
||||
if args.resume and os.path.exists(norm_path):
|
||||
|
||||
Reference in New Issue
Block a user