Sheep training: fix for flock of 10?

This commit is contained in:
Johnny Fernandes
2026-04-24 01:59:15 +01:00
parent 1e3b67d194
commit 4189cc8dba
2 changed files with 20 additions and 11 deletions
+8 -4
View File
@@ -54,19 +54,20 @@ class HerdingEnv(gym.Env):
# Reward weights (progress-based potential shaping + sparse bonuses)
# -----------------------------------------------------------------------
W_DRIVE = 2.0 # progress: flock COM moved toward pen
W_COLLECT = 0.5 # progress: flock radius shrank
W_COLLECT = 2.0 # progress: flock radius shrank (was 0.5 — must match W_DRIVE)
W_ALIGN = 0.5 # position: dog on anti-pen side of flock COM
W_PEN_BONUS = 5.0 # per sheep penned
W_COMPLETE = 20.0 # all sheep penned
W_PEN_BONUS = 10.0 # per sheep penned (was 5.0)
W_COMPLETE = 100.0 # all sheep penned (was 20.0 — must dominate dense rewards)
W_STEP_COST = 0.002 # time penalty
def __init__(self, n_sheep: int = 1, max_steps: int = 2000,
render_mode: str = None):
render_mode: str = None, random_n_sheep: bool = False):
super().__init__()
assert 1 <= n_sheep <= self.MAX_SHEEP
self.n_sheep = n_sheep
self.max_steps = max_steps
self.render_mode = render_mode
self.random_n_sheep = random_n_sheep # if True, randomise n_sheep each reset
# Fixed 13-dim observation regardless of n_sheep:
# dog_pos(2) + rel_com(2) + rel_far(2) + com_to_pen(2)
@@ -110,6 +111,9 @@ class HerdingEnv(gym.Env):
self._step_count = 0
self._prev_penned = 0
if self.random_n_sheep:
self.n_sheep = int(self.np_random.integers(1, self.MAX_SHEEP + 1))
# Active sheep (0 .. n_sheep-1): random non-pen positions
self.sheep_pos[:] = self.PEN_CENTER
self.penned[:] = True
+8 -3
View File
@@ -105,9 +105,10 @@ class CurriculumCallback(BaseCallback):
# Environment factory
# ---------------------------------------------------------------------------
def make_env(n_sheep: int, seed: int, max_steps: int):
def make_env(n_sheep: int, seed: int, max_steps: int, random_n_sheep: bool = False):
def _init():
env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps)
env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps,
random_n_sheep=random_n_sheep)
env.reset(seed=seed)
return env
return _init
@@ -140,6 +141,9 @@ def parse_args():
p.add_argument("--save-freq", type=int, default=100_000)
p.add_argument("--eval-freq", type=int, default=50_000)
p.add_argument("--eval-eps", type=int, default=20)
p.add_argument("--mixed", action="store_true",
help="Randomise n_sheep each episode (consolidation pass, "
"use with --resume after curriculum training)")
return p.parse_args()
@@ -153,7 +157,8 @@ def main():
# Training envs
train_env = SubprocVecEnv([
make_env(args.n_sheep, seed=i, max_steps=args.max_steps)
make_env(args.n_sheep, seed=i, max_steps=args.max_steps,
random_n_sheep=args.mixed)
for i in range(args.n_envs)
])
if args.resume and os.path.exists(norm_path):