Sheep training flock of 10 fix?
This commit is contained in:
@@ -97,9 +97,9 @@ def build_obs(dog_pos: np.ndarray,
|
||||
return np.array([
|
||||
dog_pos[0] / FIELD, dog_pos[1] / FIELD,
|
||||
(com[0] - dog_pos[0]) / D, (com[1] - dog_pos[1]) / D,
|
||||
(far1[0] - dog_pos[0]) / D, (far1[1] - dog_pos[1]) / D,
|
||||
(far2[0] - dog_pos[0]) / D, (far2[1] - dog_pos[1]) / D,
|
||||
(far3[0] - dog_pos[0]) / D, (far3[1] - dog_pos[1]) / D,
|
||||
(far1[0] - com[0]) / D, (far1[1] - com[1]) / D,
|
||||
(far2[0] - com[0]) / D, (far2[1] - com[1]) / D,
|
||||
(far3[0] - com[0]) / D, (far3[1] - com[1]) / D,
|
||||
(PEN_CENTER[0] - com[0]) / D, (PEN_CENTER[1] - com[1]) / D,
|
||||
(PEN_CENTER[0] - far1[0]) / D, (PEN_CENTER[1] - far1[1]) / D,
|
||||
radius / D,
|
||||
|
||||
@@ -279,12 +279,16 @@ class HerdingEnv(gym.Env):
|
||||
S = self.FIELD
|
||||
D = 2 * self.FIELD
|
||||
|
||||
# far1/far2/far3 expressed relative to COM, not dog.
|
||||
# For 1 sheep: far1-COM = far2-COM = far3-COM = [0,0] → cleanly ignorable.
|
||||
# For 3+ sheep: non-zero vectors tell the dog where each straggler is
|
||||
# within the group, without conflicting with weights trained on 1 sheep.
|
||||
return np.array([
|
||||
self.dog_pos[0] / S, self.dog_pos[1] / S,
|
||||
(com[0] - self.dog_pos[0]) / D, (com[1] - self.dog_pos[1]) / D,
|
||||
(far1[0] - self.dog_pos[0]) / D, (far1[1] - self.dog_pos[1]) / D,
|
||||
(far2[0] - self.dog_pos[0]) / D, (far2[1] - self.dog_pos[1]) / D,
|
||||
(far3[0] - self.dog_pos[0]) / D, (far3[1] - self.dog_pos[1]) / D,
|
||||
(far1[0] - com[0]) / D, (far1[1] - com[1]) / D,
|
||||
(far2[0] - com[0]) / D, (far2[1] - com[1]) / D,
|
||||
(far3[0] - com[0]) / D, (far3[1] - com[1]) / D,
|
||||
(self.PEN_CENTER[0] - com[0]) / D, (self.PEN_CENTER[1] - com[1]) / D,
|
||||
(self.PEN_CENTER[0] - far1[0]) / D, (self.PEN_CENTER[1] - far1[1]) / D,
|
||||
radius / D,
|
||||
|
||||
+88
-4
@@ -16,6 +16,12 @@ import sys
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
from matplotlib.collections import LineCollection
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecNormalize
|
||||
|
||||
@@ -148,6 +154,83 @@ def report(n_sheep, success_rate, failure_counts, n_episodes, threshold=PASS_THR
|
||||
return passed
|
||||
|
||||
|
||||
# Hex colours used to draw individual sheep; plotting code indexes this list
# modulo its length, so flocks larger than ten simply wrap around.
SHEEP_COLORS = [
    "#e41a1c",
    "#377eb8",
    "#4daf4a",
    "#984ea3",
    "#ff7f00",
    "#a65628",
    "#f781bf",
    "#999999",
    "#66c2a5",
    "#fc8d62",
]
|
||||
|
||||
def _save_smoke_vis(model, vn, n_sheep, save_dir, seed=42, max_steps=2000):
    """Run one deterministic episode and save trajectory + timeseries PNGs.

    A fresh single-env ``DummyVecEnv`` is wrapped in an evaluation-mode
    ``VecNormalize`` whose running statistics are copied from ``vn``, so the
    policy sees observations normalised the same way as during training.

    Args:
        model: Trained policy exposing ``predict(obs, deterministic=True)``.
        vn: ``VecNormalize`` used in training; only ``obs_rms`` and
            ``ret_rms`` are read (deep-copied, never mutated).
        n_sheep: Number of sheep in the episode; also the number of per-sheep
            traces drawn.
        save_dir: Existing directory that receives ``trajectory.png`` and
            ``timeseries.png``.
        seed: Seed forwarded to ``make_env``.
        max_steps: Step limit forwarded to ``make_env``.

    Returns:
        None. Two PNG files are written and a confirmation line is printed.
    """
    from copy import deepcopy

    raw = DummyVecEnv([make_env(n_sheep, max_steps, seed)])
    env = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
    env.obs_rms = deepcopy(vn.obs_rms)
    env.ret_rms = deepcopy(vn.ret_rms)

    dog_xs, dog_ys = [], []
    sheep_xs = [[] for _ in range(n_sheep)]
    sheep_ys = [[] for _ in range(n_sheep)]
    radii, action_mags, rewards = [], [], []
    pen_dists = [[] for _ in range(n_sheep)]

    try:
        obs = env.reset()
        # NOTE(review): assumes attribute access on envs[0] reaches the
        # herding env (dog_pos, sheep_pos, PEN_CENTER, _flock_stats) —
        # confirm if make_env adds wrappers.
        inner = env.envs[0]
        done = False

        while not done:
            # Snapshot the state BEFORE stepping. DummyVecEnv resets the
            # wrapped env inside step() as soon as the episode ends, so
            # reading state after the final step would record the *next*
            # episode's start position and corrupt the end of every trace.
            # Sample k is therefore the state in which action k was chosen.
            dog_xs.append(float(inner.dog_pos[0]))
            dog_ys.append(float(inner.dog_pos[1]))
            _, radius, _ = inner._flock_stats()
            radii.append(radius)
            for i in range(n_sheep):
                sheep_xs[i].append(float(inner.sheep_pos[i][0]))
                sheep_ys[i].append(float(inner.sheep_pos[i][1]))
                pen_dists[i].append(
                    float(np.linalg.norm(inner.sheep_pos[i] - inner.PEN_CENTER))
                )

            action, _ = model.predict(obs, deterministic=True)
            obs, reward, dones, _ = env.step(action)
            done = dones[0]

            rewards.append(float(reward[0]))
            action_mags.append(float(np.linalg.norm(action[0])))
    finally:
        # Release the env even when the rollout raises.
        env.close()

    steps = len(dog_xs)

    # Trajectory
    # NOTE(review): field and pen geometry below are hard-coded plot values;
    # confirm they match the environment's constants.
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.set_xlim(-16, 16)
    ax.set_ylim(-16, 16)
    ax.set_aspect("equal")
    ax.set_facecolor("#dcedc8")
    ax.add_patch(mpatches.Rectangle((-15, -15), 30, 30, fill=False,
                                    edgecolor="#795548", lw=2))
    ax.add_patch(mpatches.Rectangle((10, -15), 3, 7, facecolor="#ffe082",
                                    edgecolor="#795548", lw=2))
    ax.text(11.5, -11.5, "pen", ha="center", va="center", fontsize=8,
            color="#795548")
    for i in range(n_sheep):
        c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
        ax.plot(sheep_xs[i], sheep_ys[i], color=c, lw=1, alpha=0.6,
                label=f"sheep {i+1}")
        ax.plot(sheep_xs[i][0], sheep_ys[i][0], "o", color=c, ms=7)    # start
        ax.plot(sheep_xs[i][-1], sheep_ys[i][-1], "*", color=c, ms=10)  # end
    ax.plot(dog_xs, dog_ys, color="#4e342e", lw=1.5, label="dog", alpha=0.8)
    ax.plot(dog_xs[0], dog_ys[0], "s", color="#4e342e", ms=9)    # start
    ax.plot(dog_xs[-1], dog_ys[-1], "D", color="#4e342e", ms=9)  # end
    ax.set_title(f"n_sheep={n_sheep} {steps} steps min_r={min(radii):.1f}m")
    ax.legend(fontsize=7, loc="upper left")
    plt.tight_layout()
    fig.savefig(os.path.join(save_dir, "trajectory.png"), dpi=100)
    plt.close(fig)

    # Timeseries
    t = np.arange(steps)
    fig, axes = plt.subplots(4, 1, figsize=(10, 8), sharex=True)

    axes[0].plot(t, radii, color="steelblue")
    axes[0].axhline(5, color="orange", ls="--", lw=1)
    axes[0].set_ylabel("radius (m)")
    axes[0].set_title("Flock radius (orange=5m threshold)")

    for i in range(n_sheep):
        axes[1].plot(t, pen_dists[i], color=SHEEP_COLORS[i % len(SHEEP_COLORS)],
                     lw=1, label=f"sheep {i+1}")
    axes[1].set_ylabel("pen dist (m)")
    axes[1].set_title("Per-sheep distance to pen")
    axes[1].legend(fontsize=7)

    axes[2].plot(t, action_mags, color="tomato", lw=1, alpha=0.8)
    axes[2].axhline(1.0, color="gray", ls="--", lw=1)
    axes[2].set_ylim(0, 1.5)
    axes[2].set_ylabel("action mag")
    axes[2].set_title("Dog action magnitude (0=stopped)")

    axes[3].plot(t, rewards, color="purple", lw=1, alpha=0.7)
    axes[3].axhline(0, color="black", lw=0.5)
    axes[3].set_ylabel("reward")
    axes[3].set_xlabel("step")
    axes[3].set_title("Reward per step")

    fig.suptitle(f"Smoke stage n_sheep={n_sheep}", fontsize=12)
    plt.tight_layout()
    fig.savefig(os.path.join(save_dir, "timeseries.png"), dpi=100)
    plt.close(fig)
    print(f" Viz saved to {save_dir}/trajectory.png + timeseries.png")
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--steps", type=int, default=500_000,
|
||||
@@ -158,10 +241,10 @@ def main():
|
||||
p.add_argument("--render", action="store_true")
|
||||
args = p.parse_args()
|
||||
|
||||
# Stage 1 (1 sheep, 500k): fast sanity check — obs/reward structurally correct?
|
||||
# Stage 2 (3 sheep, 1.5M): real test at curriculum pace — if it fails here,
|
||||
# there is a genuine problem worth fixing before committing to 15M steps.
|
||||
stages = [(1, args.steps, 0.60), (3, args.steps * 3, 0.40)]
|
||||
# 1 sheep (500k): sanity check — obs/reward structurally correct?
|
||||
# 2 sheep (1M): first multi-agent step — gradual transfer
|
||||
# 3 sheep (1.5M): real multi-sheep test at curriculum pace
|
||||
stages = [(1, args.steps, 0.60), (2, args.steps * 2, 0.40), (3, args.steps * 3, 0.35)]
|
||||
|
||||
model, vn = None, None
|
||||
all_passed = True
|
||||
@@ -184,6 +267,7 @@ def main():
|
||||
model.save(os.path.join(save_dir, "model"))
|
||||
vn.save(os.path.join(save_dir, "vecnorm.pkl"))
|
||||
print(f" Model saved to {save_dir}/")
|
||||
_save_smoke_vis(model, vn, n_sheep, save_dir)
|
||||
|
||||
passed = report(n_sheep, success_rate, failure_counts, args.episodes, threshold)
|
||||
if not passed:
|
||||
|
||||
Reference in New Issue
Block a user