Sheep training flock of 10 fix?

This commit is contained in:
Johnny Fernandes
2026-04-24 17:49:42 +01:00
parent 4d7f365358
commit bf9fe902d9
3 changed files with 98 additions and 10 deletions
+7 -3
View File
@@ -279,12 +279,16 @@ class HerdingEnv(gym.Env):
S = self.FIELD
D = 2 * self.FIELD
# far1/far2/far3 expressed relative to COM, not dog.
# For 1 sheep: far1-COM = far2-COM = far3-COM = [0,0] → cleanly ignorable.
# For 3+ sheep: non-zero vectors tell the dog where each straggler is
# within the group, without conflicting with weights trained on 1 sheep.
return np.array([
self.dog_pos[0] / S, self.dog_pos[1] / S,
(com[0] - self.dog_pos[0]) / D, (com[1] - self.dog_pos[1]) / D,
(far1[0] - self.dog_pos[0]) / D, (far1[1] - self.dog_pos[1]) / D,
(far2[0] - self.dog_pos[0]) / D, (far2[1] - self.dog_pos[1]) / D,
(far3[0] - self.dog_pos[0]) / D, (far3[1] - self.dog_pos[1]) / D,
(far1[0] - com[0]) / D, (far1[1] - com[1]) / D,
(far2[0] - com[0]) / D, (far2[1] - com[1]) / D,
(far3[0] - com[0]) / D, (far3[1] - com[1]) / D,
(self.PEN_CENTER[0] - com[0]) / D, (self.PEN_CENTER[1] - com[1]) / D,
(self.PEN_CENTER[0] - far1[0]) / D, (self.PEN_CENTER[1] - far1[1]) / D,
radius / D,
+88 -4
View File
@@ -16,6 +16,12 @@ import sys
import numpy as np
from copy import deepcopy
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.collections import LineCollection
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecNormalize
@@ -148,6 +154,83 @@ def report(n_sheep, success_rate, failure_counts, n_episodes, threshold=PASS_THR
return passed
# Hex colours for the per-sheep plot traces. _save_smoke_vis selects entry
# i % len(SHEEP_COLORS), so flocks larger than this list reuse colours.
SHEEP_COLORS = ["#e41a1c","#377eb8","#4daf4a","#984ea3","#ff7f00",
"#a65628","#f781bf","#999999","#66c2a5","#fc8d62"]
def _save_smoke_vis(model, vn, n_sheep, save_dir, seed=42, max_steps=2000):
    """Run one deterministic episode and save trajectory + timeseries PNGs.

    A single-env ``DummyVecEnv`` is wrapped in an evaluation-mode
    ``VecNormalize`` that receives copies of ``vn``'s running statistics, so
    the policy sees observations normalised the same way as in ``vn``.

    Args:
        model: Policy exposing ``predict(obs, deterministic=True)``.
        vn: ``VecNormalize`` whose ``obs_rms`` / ``ret_rms`` are deep-copied.
        n_sheep: Number of sheep; passed to ``make_env`` and used to size the
            per-sheep series.
        save_dir: Directory that receives ``trajectory.png`` and
            ``timeseries.png``. Created if it does not exist.
        seed: Forwarded to ``make_env``.
        max_steps: Forwarded to ``make_env``.

    Sampling note:
        State is read from the underlying env *before* each ``env.step``.
        Vectorised envs reset a finished sub-environment inside ``step``, so
        reading after the terminal step would record the first state of the
        next episode. As a consequence sample ``t`` pairs the pre-step state
        with the action taken from it and the reward it produced; the
        terminal state itself is not plotted.
    """
    from copy import deepcopy

    raw = DummyVecEnv([make_env(n_sheep, max_steps, seed)])
    env = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
    env.obs_rms = deepcopy(vn.obs_rms)
    env.ret_rms = deepcopy(vn.ret_rms)

    dog_xs, dog_ys = [], []
    sheep_xs = [[] for _ in range(n_sheep)]
    sheep_ys = [[] for _ in range(n_sheep)]
    radii, action_mags, rewards = [], [], []
    pen_dists = [[] for _ in range(n_sheep)]

    # try/finally so the env is closed even if the rollout raises.
    try:
        obs = env.reset()
        inner = env.envs[0]
        done = False
        while not done:
            # Sample BEFORE stepping: after a terminal step the vec env has
            # already auto-reset `inner`, which would add a bogus last point.
            dog_xs.append(float(inner.dog_pos[0]))
            dog_ys.append(float(inner.dog_pos[1]))
            _, radius, _ = inner._flock_stats()
            radii.append(radius)
            for i in range(n_sheep):
                pos = inner.sheep_pos[i]
                sheep_xs[i].append(float(pos[0]))
                sheep_ys[i].append(float(pos[1]))
                pen_dists[i].append(
                    float(np.linalg.norm(pos - inner.PEN_CENTER))
                )

            action, _ = model.predict(obs, deterministic=True)
            obs, reward, dones, _ = env.step(action)
            done = dones[0]
            rewards.append(float(reward[0]))
            action_mags.append(float(np.linalg.norm(action[0])))
    finally:
        env.close()

    steps = len(dog_xs)
    os.makedirs(save_dir, exist_ok=True)

    # Trajectory
    # NOTE(review): field and pen rectangles are hard-coded here; confirm
    # they match the environment's actual geometry.
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.set_xlim(-16, 16)
    ax.set_ylim(-16, 16)
    ax.set_aspect("equal")
    ax.set_facecolor("#dcedc8")
    ax.add_patch(mpatches.Rectangle(
        (-15, -15), 30, 30, fill=False, edgecolor="#795548", lw=2))
    ax.add_patch(mpatches.Rectangle(
        (10, -15), 3, 7, facecolor="#ffe082", edgecolor="#795548", lw=2))
    ax.text(11.5, -11.5, "pen", ha="center", va="center",
            fontsize=8, color="#795548")
    for i in range(n_sheep):
        c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
        ax.plot(sheep_xs[i], sheep_ys[i], color=c, lw=1, alpha=0.6,
                label=f"sheep {i+1}")
        # Circle = first sample, star = last sample.
        ax.plot(sheep_xs[i][0], sheep_ys[i][0], "o", color=c, ms=7)
        ax.plot(sheep_xs[i][-1], sheep_ys[i][-1], "*", color=c, ms=10)
    ax.plot(dog_xs, dog_ys, color="#4e342e", lw=1.5, label="dog", alpha=0.8)
    # Square = first dog sample, diamond = last dog sample.
    ax.plot(dog_xs[0], dog_ys[0], "s", color="#4e342e", ms=9)
    ax.plot(dog_xs[-1], dog_ys[-1], "D", color="#4e342e", ms=9)
    ax.set_title(f"n_sheep={n_sheep} {steps} steps min_r={min(radii):.1f}m")
    ax.legend(fontsize=7, loc="upper left")
    plt.tight_layout()
    fig.savefig(os.path.join(save_dir, "trajectory.png"), dpi=100)
    plt.close(fig)

    # Timeseries
    t = np.arange(steps)
    fig, axes = plt.subplots(4, 1, figsize=(10, 8), sharex=True)

    axes[0].plot(t, radii, color="steelblue")
    axes[0].axhline(5, color="orange", ls="--", lw=1)
    axes[0].set_ylabel("radius (m)")
    axes[0].set_title("Flock radius (orange=5m threshold)")

    for i in range(n_sheep):
        axes[1].plot(t, pen_dists[i],
                     color=SHEEP_COLORS[i % len(SHEEP_COLORS)],
                     lw=1, label=f"sheep {i+1}")
    axes[1].set_ylabel("pen dist (m)")
    axes[1].set_title("Per-sheep distance to pen")
    axes[1].legend(fontsize=7)

    axes[2].plot(t, action_mags, color="tomato", lw=1, alpha=0.8)
    axes[2].axhline(1.0, color="gray", ls="--", lw=1)
    axes[2].set_ylim(0, 1.5)
    axes[2].set_ylabel("action mag")
    axes[2].set_title("Dog action magnitude (0=stopped)")

    axes[3].plot(t, rewards, color="purple", lw=1, alpha=0.7)
    axes[3].axhline(0, color="black", lw=0.5)
    axes[3].set_ylabel("reward")
    axes[3].set_xlabel("step")
    axes[3].set_title("Reward per step")

    fig.suptitle(f"Smoke stage n_sheep={n_sheep}", fontsize=12)
    plt.tight_layout()
    fig.savefig(os.path.join(save_dir, "timeseries.png"), dpi=100)
    plt.close(fig)
    print(f" Viz saved to {save_dir}/trajectory.png + timeseries.png")
def main():
p = argparse.ArgumentParser()
p.add_argument("--steps", type=int, default=500_000,
@@ -158,10 +241,10 @@ def main():
p.add_argument("--render", action="store_true")
args = p.parse_args()
# Stage 1 (1 sheep, 500k): fast sanity check — obs/reward structurally correct?
# Stage 2 (3 sheep, 1.5M): real test at curriculum pace — if it fails here,
# there is a genuine problem worth fixing before committing to 15M steps.
stages = [(1, args.steps, 0.60), (3, args.steps * 3, 0.40)]
# 1 sheep (500k): sanity check — obs/reward structurally correct?
# 2 sheep (1M): first multi-agent step — gradual transfer
# 3 sheep (1.5M): real multi-sheep test at curriculum pace
stages = [(1, args.steps, 0.60), (2, args.steps * 2, 0.40), (3, args.steps * 3, 0.35)]
model, vn = None, None
all_passed = True
@@ -184,6 +267,7 @@ def main():
model.save(os.path.join(save_dir, "model"))
vn.save(os.path.join(save_dir, "vecnorm.pkl"))
print(f" Model saved to {save_dir}/")
_save_smoke_vis(model, vn, n_sheep, save_dir)
passed = report(n_sheep, success_rate, failure_counts, args.episodes, threshold)
if not passed: