From 4d7f36535801c1383d7e7e2d82fb8a30d0f7d560 Mon Sep 17 00:00:00 2001 From: Johnny Fernandes Date: Fri, 24 Apr 2026 17:31:11 +0100 Subject: [PATCH] Sheep training flock of 10 fix? --- training/herding_env.py | 18 ++---------------- training/smoke_test.py | 2 +- training/train.py | 2 +- 3 files changed, 4 insertions(+), 18 deletions(-) diff --git a/training/herding_env.py b/training/herding_env.py index 8d43eb2..034bb17 100644 --- a/training/herding_env.py +++ b/training/herding_env.py @@ -54,10 +54,9 @@ class HerdingEnv(gym.Env): # Reward weights (simple per-sheep progress — no phases, no gating) # ----------------------------------------------------------------------- W_PER_SHEEP = 2.0 # progress: sum of per-sheep distance-to-pen reductions - W_ALIGN = 0.3 # position: dog on anti-pen side of COM (small, directional hint) W_PEN_BONUS = 10.0 # per sheep penned W_COMPLETE = 100.0 # all sheep penned - W_STEP_COST = 0.002 # time penalty + W_STEP_COST = 0.02 # time penalty — strong enough to punish doing nothing def __init__(self, n_sheep: int = 1, max_steps: int = 2000, render_mode: str = None, random_n_sheep: bool = False): @@ -309,20 +308,7 @@ class HerdingEnv(gym.Env): else: r_progress = 0.0 - # Small alignment hint: reward dog for being on anti-pen side of COM. - com, _, _ = self._flock_stats() - com_dist = float(np.linalg.norm(com - self.PEN_CENTER)) - d_dog_com = float(np.linalg.norm(self.dog_pos - com)) - if d_dog_com > 0.1 and com_dist > 0.1: - pen_dir = (self.PEN_CENTER - com) / com_dist - dog_dir = (self.dog_pos - com) / d_dog_com - cosine = -float(np.dot(pen_dir, dog_dir)) - proximity = max(0.0, 1.0 - d_dog_com / self.FLEE_DIST) - alignment = cosine * proximity * self.W_ALIGN - else: - alignment = 0.0 - - reward = r_progress + alignment + reward = r_progress reward += newly_penned * self.W_PEN_BONUS reward -= self.W_STEP_COST if n_penned == self.n_sheep: diff --git a/training/smoke_test.py b/training/smoke_test.py index 35ce0ba..99413fa 100644 --- a/training/smoke_test.py +++ b/training/smoke_test.py @@ -104,7 +104,7 @@ def train_stage(n_sheep, steps, n_envs=4, prev_model=None, prev_vecnorm=None): model = PPO( "MlpPolicy", vn, learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10, - gamma=0.995, gae_lambda=0.95, clip_range=0.2, ent_coef=0.005, + gamma=0.995, gae_lambda=0.95, clip_range=0.2, ent_coef=0.02, vf_coef=0.5, max_grad_norm=0.5, policy_kwargs=dict(net_arch=[256, 256]), verbose=1, diff --git a/training/train.py b/training/train.py index 9b6fd29..490c457 100644 --- a/training/train.py +++ b/training/train.py @@ -331,7 +331,7 @@ def main(): gamma = 0.995, gae_lambda = 0.95, clip_range = 0.2, - ent_coef = 0.005, + ent_coef = 0.02, vf_coef = 0.5, max_grad_norm = 0.5, policy_kwargs = dict(net_arch=[256, 256]),