diff --git a/training/herding_env.py b/training/herding_env.py index 5fe5744..4f53df6 100644 --- a/training/herding_env.py +++ b/training/herding_env.py @@ -54,9 +54,8 @@ class HerdingEnv(gym.Env): # Reward weights (simple per-sheep progress — no phases, no gating) # ----------------------------------------------------------------------- W_PER_SHEEP = 2.0 # progress: sum of per-sheep distance-to-pen reductions - W_ALIGN = 0.0 # disabled: created a sit-still trap from n_sheep≥2. - # Progress reward already encodes "be on anti-pen side" - # implicitly (sheep flee toward pen → positive progress). + W_ALIGN = 0.05 # gated on action magnitude — dog only earns it when moving. + # Without gating this created a sit-still trap from n_sheep≥2. W_PEN_BONUS = 10.0 # per sheep penned W_COMPLETE = 100.0 # all sheep penned W_STEP_COST = 0.02 # time penalty — strong enough to punish doing nothing @@ -180,7 +179,7 @@ class HerdingEnv(gym.Env): newly_penned = n_penned - self._prev_penned self._prev_penned = n_penned - reward, rcomps = self._reward(n_penned, newly_penned) + reward, rcomps = self._reward(n_penned, newly_penned, act) terminated = n_penned == self.n_sheep truncated = self._step_count >= self.max_steps info = {"n_penned": n_penned, "n_sheep": self.n_sheep, @@ -299,7 +298,7 @@ class HerdingEnv(gym.Env): active_mask.sum() / self.n_sheep, ], dtype=np.float32) - def _reward(self, n_penned: int, newly_penned: int): + def _reward(self, n_penned: int, newly_penned: int, action: np.ndarray): active = ~self.penned[:self.n_sheep] # Per-sheep progress toward pen: fires whenever any sheep moves closer. @@ -324,7 +323,11 @@ class HerdingEnv(gym.Env): dog_dir = (self.dog_pos - com) / d_dog_com cosine = -float(np.dot(pen_dir, dog_dir)) proximity = max(0.0, 1.0 - d_dog_com / self.FLEE_DIST) - alignment = cosine * proximity * self.W_ALIGN + # Gate on action magnitude: only paid when the dog is actually moving. + # Without this, parking on the anti-pen side farms +0.03/step against + # the -0.02 step_cost and the policy collapses to sit-still. + move_gate = min(1.0, float(np.linalg.norm(action))) + alignment = cosine * proximity * move_gate * self.W_ALIGN else: alignment = 0.0