diff --git a/training/herding_env.py b/training/herding_env.py index f568a51..b0778d2 100644 --- a/training/herding_env.py +++ b/training/herding_env.py @@ -53,20 +53,21 @@ class HerdingEnv(gym.Env): # ----------------------------------------------------------------------- # Reward weights (progress-based potential shaping + sparse bonuses) # ----------------------------------------------------------------------- - W_DRIVE = 2.0 # progress: flock COM moved toward pen - W_COLLECT = 0.5 # progress: flock radius shrank + W_DRIVE = 2.0 # progress: flock COM moved toward pen + W_COLLECT = 2.0 # progress: flock radius shrank (was 0.5 — must match W_DRIVE) W_ALIGN = 0.5 # position: dog on anti-pen side of flock COM - W_PEN_BONUS = 5.0 # per sheep penned - W_COMPLETE = 20.0 # all sheep penned + W_PEN_BONUS = 10.0 # per sheep penned (was 5.0) + W_COMPLETE = 100.0 # all sheep penned (was 20.0 — must dominate dense rewards) W_STEP_COST = 0.002 # time penalty def __init__(self, n_sheep: int = 1, max_steps: int = 2000, - render_mode: str = None): + render_mode: str = None, random_n_sheep: bool = False): super().__init__() assert 1 <= n_sheep <= self.MAX_SHEEP - self.n_sheep = n_sheep - self.max_steps = max_steps - self.render_mode = render_mode + self.n_sheep = n_sheep + self.max_steps = max_steps + self.render_mode = render_mode + self.random_n_sheep = random_n_sheep # if True, randomise n_sheep each reset # Fixed 13-dim observation regardless of n_sheep: # dog_pos(2) + rel_com(2) + rel_far(2) + com_to_pen(2) @@ -110,6 +111,9 @@ class HerdingEnv(gym.Env): self._step_count = 0 self._prev_penned = 0 + if self.random_n_sheep: + self.n_sheep = int(self.np_random.integers(1, self.MAX_SHEEP + 1)) + # Active sheep (0 .. n_sheep-1): random non-pen positions self.sheep_pos[:] = self.PEN_CENTER self.penned[:] = True diff --git a/training/train.py b/training/train.py index c5b05ab..bd52050 100644 --- a/training/train.py +++ b/training/train.py @@ -105,9 +105,10 @@ class CurriculumCallback(BaseCallback): # Environment factory # --------------------------------------------------------------------------- -def make_env(n_sheep: int, seed: int, max_steps: int): +def make_env(n_sheep: int, seed: int, max_steps: int, random_n_sheep: bool = False): def _init(): - env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps) + env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps, + random_n_sheep=random_n_sheep) env.reset(seed=seed) return env return _init @@ -140,6 +141,9 @@ def parse_args(): p.add_argument("--save-freq", type=int, default=100_000) p.add_argument("--eval-freq", type=int, default=50_000) p.add_argument("--eval-eps", type=int, default=20) + p.add_argument("--mixed", action="store_true", + help="Randomise n_sheep each episode (consolidation pass, " + "use with --resume after curriculum training)") return p.parse_args() @@ -153,7 +157,8 @@ def main(): # Training envs train_env = SubprocVecEnv([ - make_env(args.n_sheep, seed=i, max_steps=args.max_steps) + make_env(args.n_sheep, seed=i, max_steps=args.max_steps, + random_n_sheep=args.mixed) for i in range(args.n_envs) ]) if args.resume and os.path.exists(norm_path):