diff --git a/training/herding_env.py b/training/herding_env.py index 4488fa2..1aa7356 100644 --- a/training/herding_env.py +++ b/training/herding_env.py @@ -55,8 +55,8 @@ class HerdingEnv(gym.Env): # ----------------------------------------------------------------------- W_DRIVE = 2.0 # progress: COM moved toward pen (only when compact) W_COLLECT = 4.0 # progress: radius shrank (2× stronger when scattered) - W_ALIGN = 0.5 # position: dog on anti-pen side of COM - W_COMPACT_BONUS = 0.0 # disabled: 0.1/step over 4000 steps = 400 >> W_COMPLETE=100 + W_APPROACH_FAR = 1.0 # progress: dog moved toward farthest straggler (scatter only) + W_ALIGN = 0.5 # position: dog on anti-pen side of COM (compact only) W_PEN_BONUS = 10.0 # per sheep penned W_COMPLETE = 100.0 # all sheep penned W_STEP_COST = 0.002 # time penalty @@ -85,10 +85,11 @@ class HerdingEnv(gym.Env): ) # Runtime state (populated by reset) - self._step_count = 0 - self._prev_penned = 0 - self._prev_com_dist = 0.0 # COM-to-pen distance at previous step - self._prev_radius = 0.0 # flock radius at previous step + self._step_count = 0 + self._prev_penned = 0 + self._prev_com_dist = 0.0 + self._prev_radius = 0.0 + self._prev_dog_to_far1 = 0.0 self.dog_pos = np.zeros(2, dtype=np.float32) self.sheep_pos = np.zeros((self.MAX_SHEEP, 2), dtype=np.float32) self.penned = np.ones(self.MAX_SHEEP, dtype=bool) @@ -155,6 +156,14 @@ class HerdingEnv(gym.Env): self._prev_com_dist = float(np.linalg.norm(com - self.PEN_CENTER)) self._prev_radius = radius + active_mask = ~self.penned[:self.n_sheep] + if active_mask.any(): + pts = self.sheep_pos[:self.n_sheep][active_mask] + far1 = pts[int(np.argmax(np.linalg.norm(pts - com, axis=1)))] + self._prev_dog_to_far1 = float(np.linalg.norm(self.dog_pos - far1)) + else: + self._prev_dog_to_far1 = 0.0 + return self._obs(), {} def step(self, action): @@ -293,17 +302,42 @@ class HerdingEnv(gym.Env): def _reward(self, n_penned: int, newly_penned: int) -> float: com, radius, _ = self._flock_stats() - com_dist = float(np.linalg.norm(com - self.PEN_CENTER)) + com_dist = float(np.linalg.norm(com - self.PEN_CENTER)) + scattered = radius > self.DRIVE_GATE_RADIUS drive_delta = self._prev_com_dist - com_dist collect_delta = self._prev_radius - radius - self._prev_com_dist = com_dist self._prev_radius = radius - # Alignment: dog on anti-pen side of COM, gated by proximity. + # Collect: always active, 2× stronger when scattered. + r_collect = collect_delta * self.W_COLLECT * (2.0 if scattered else 1.0) + + # Drive: only when compact — prevents rewarding COM movement while scattered. + r_drive = 0.0 if scattered else drive_delta * self.W_DRIVE + + # Approach-to-straggler: reward dog for closing on farthest sheep. + # Only in scatter phase so it doesn't override drive positioning. + # Gated on there being active sheep. + active_mask = ~self.penned[:self.n_sheep] + if scattered and active_mask.any(): + pts = self.sheep_pos[:self.n_sheep][active_mask] + far1 = pts[int(np.argmax(np.linalg.norm(pts - com, axis=1)))] + cur_dog_to_far1 = float(np.linalg.norm(self.dog_pos - far1)) + r_approach = (self._prev_dog_to_far1 - cur_dog_to_far1) * self.W_APPROACH_FAR + self._prev_dog_to_far1 = cur_dog_to_far1 + else: + r_approach = 0.0 + if active_mask.any(): + pts = self.sheep_pos[:self.n_sheep][active_mask] + far1 = pts[int(np.argmax(np.linalg.norm(pts - com, axis=1)))] + self._prev_dog_to_far1 = float(np.linalg.norm(self.dog_pos - far1)) + + # Alignment: dog on anti-pen side of COM — only in drive phase. + # Disabled when scattered: chasing a straggler on the pen side would be + # wrongly penalised otherwise. d_dog_com = float(np.linalg.norm(self.dog_pos - com)) - if d_dog_com > 0.1 and com_dist > 0.1: + if not scattered and d_dog_com > 0.1 and com_dist > 0.1: pen_dir = (self.PEN_CENTER - com) / com_dist dog_dir = (self.dog_pos - com) / d_dog_com cosine = -float(np.dot(pen_dir, dog_dir)) @@ -312,19 +346,7 @@ class HerdingEnv(gym.Env): else: alignment = 0.0 - scattered = radius > self.DRIVE_GATE_RADIUS - - # Collect always on; 2× scale when scattered to force collect-first. - r_collect = collect_delta * self.W_COLLECT * (2.0 if scattered else 1.0) - - # Drive only fires when flock is compact — prevents rewarding COM movement - # while sheep are spread across the field. - r_drive = 0.0 if scattered else drive_delta * self.W_DRIVE - - # Small sustained reward for maintaining a compact flock. - r_compact = 0.0 if scattered else self.W_COMPACT_BONUS - - reward = r_drive + r_collect + r_compact + alignment + reward = r_drive + r_collect + r_approach + alignment reward += newly_penned * self.W_PEN_BONUS reward -= self.W_STEP_COST if n_penned == self.n_sheep: