Checkpoint 2

This commit is contained in:
Johnny Fernandes
2026-05-07 22:00:10 +01:00
parent 90aa3bbcb4
commit 1bb9415414
37 changed files with 3068 additions and 2912 deletions
+115
View File
@@ -0,0 +1,115 @@
# Shepherd Herding — Training & Inference
This directory holds the Gymnasium environment, PPO training script, and
evaluation harness for the RL shepherd-dog policy. The Webots controller
in `controllers/shepherd_dog/` loads the resulting policy at inference
time when launched with `HERDING_MODE=rl`.
## Layout
```
training/
├── herding_env.py # gymnasium.Env — the dog is the agent
├── train_ppo.py # SB3 PPO entry point (vec envs, eval, curriculum)
├── eval.py # rollout success-rate / time-to-pen across flock sizes
├── parity_test.py # smoke test: shapes, determinism, baseline rollout
├── configs/ppo_default.yaml
├── runs/ # tensorboard + checkpoints (gitignored)
└── requirements.txt
```
## Setup
```bash
python -m venv .venv && source .venv/bin/activate
pip install -r training/requirements.txt
```
CPU is the default and also the recommended device — SB3's PPO with an
MLP policy of this size runs faster on CPU than on GPU because the
bottleneck is rollout collection, not gradient compute. The 16 SubprocVecEnv
workers saturate ~16 CPU cores. To force CUDA anyway, pass `--device cuda`.
## Train
```bash
# Full curriculum (1 → 10 sheep), ~5M steps, ~23h on a single GPU.
python -m training.train_ppo \
--config training/configs/ppo_default.yaml \
--out-dir training/runs/baseline
```
Outputs:
- `training/runs/baseline/best/best_model.zip` — best eval checkpoint
- `training/runs/baseline/best/vecnormalize.pkl` — observation stats
- `training/runs/baseline/checkpoints/ppo_*.zip` — periodic checkpoints
- `training/runs/baseline/tb/` — TensorBoard logs (`tensorboard --logdir`)
To resume:
```bash
python -m training.train_ppo --resume training/runs/baseline/checkpoints/ppo_500000_steps.zip
```
## Evaluate
```bash
# RL policy
python -m training.eval --policy training/runs/baseline/best
# Strömbom baseline
python -m training.eval --policy strombom
```
Prints success rate, mean steps, and mean penned-count per flock size.
Use the same `--n-seeds` for both to get a fair RL-vs-Strömbom A/B.
## Parity / smoke test
```bash
python -m training.parity_test
```
Checks observation/action shapes, deterministic seeding, the curriculum
sampler, and a 400-step Strömbom rollout. Run this before every long
training job — catches the boring class of bugs in seconds.
## Run the policy in Webots
1. Train (above) — produces `training/runs/<name>/best/`.
2. In Webots, set the dog controller's environment variables:
```bash
export HERDING_MODE=rl
export HERDING_POLICY_DIR=$(pwd)/training/runs/baseline/best
webots worlds/field.wbt
```
Or set them via Webots' controller args / a `.wbproj` if you prefer.
3. To force the Strömbom baseline (same world, same controller):
```bash
export HERDING_MODE=strombom
webots worlds/field.wbt
```
If `HERDING_MODE=rl` but the policy can't be loaded (SB3 not installed,
zip missing, etc.), the controller logs the error and falls back to
Strömbom automatically.
## Curriculum knobs
The default schedule in `configs/ppo_default.yaml` widens
`max_n_sheep` over training. Each reset samples `n_sheep ~ U[1,
max_n_sheep]`, so the final policy has seen every flock size from 1 to
10 in proportion. To pin a specific size, instantiate the env with
`HerdingEnv(n_sheep=N)` (see `eval.py`).
## Reward shaping
Weights live in class attributes on `HerdingEnv`. Tune from the 1-sheep
curriculum first — if the dog can't herd a single sheep cleanly, raising
`W_PROGRESS` or lowering `W_TIME` is usually the fix. For multi-sheep
collapse modes (dog spins between sheep), increase `W_COMPACT` so
tightening the flock pays.
View File
+218
View File
@@ -0,0 +1,218 @@
"""Behavior cloning of the sequential teacher into an SB3-compatible policy.
Trains the policy network (mean-action head) of an SB3 ``MlpPolicy`` to
mimic the demonstrations collected by ``tools.collect_demos``. The
saved zip is loadable via ``PPO.load(...)`` and can be passed to
``train_ppo.py --resume`` for fine-tuning.
Why this works: the teacher (sequential single-target driving) solves
n=10 at 80%+ in our env. BC gives the RL a competent starting policy,
so PPO doesn't have to discover behavior from scratch — it only has to
*refine* the teacher's strategy via the sparse pen reward.
Usage::
python -m training.bc_pretrain \\
--demos training/demos.npz \\
--out training/runs/bc_pretrained
"""
from __future__ import annotations
import argparse
import os
import sys
import time
from pathlib import Path
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from training.herding_env import HerdingEnv
def build_model(net_arch_pi, net_arch_vf, log_std_init: float):
    """Create an untrained SB3 PPO model whose policy will receive BC weights.

    Only the policy network is used by behavior cloning; PPO's own training
    loop is never invoked on the returned model here. The architecture is
    meant to mirror ``train_ppo`` (per the original author's note).

    Args:
        net_arch_pi: Hidden-layer widths for the actor branch.
        net_arch_vf: Hidden-layer widths for the critic branch.
        log_std_init: Initial log standard deviation passed to the policy.

    Returns:
        ``(model, env)``: the PPO instance and the single-env ``DummyVecEnv``
        it was constructed on.
    """
    def _make_env():
        return HerdingEnv()

    vec_env = DummyVecEnv([_make_env])
    policy_kwargs = {
        "net_arch": {"pi": net_arch_pi, "vf": net_arch_vf},
        "log_std_init": log_std_init,
    }
    ppo = PPO("MlpPolicy", vec_env, policy_kwargs=policy_kwargs, verbose=0)
    return ppo, vec_env
def policy_forward_mean(policy, obs_batch):
    """Compute the policy's deterministic (mean) action for a batch.

    The actor path is walked by hand instead of going through the
    distribution wrapper: ``extract_features`` -> ``mlp_extractor`` ->
    ``action_net``.

    Args:
        policy: Object exposing ``extract_features``, ``mlp_extractor`` and
            ``action_net`` (an SB3 actor-critic policy in this file).
        obs_batch: Batch of observations accepted by ``extract_features``.

    Returns:
        Whatever ``policy.action_net`` returns for the actor latent.
    """
    extracted = policy.extract_features(obs_batch)
    # A tuple carries separate (actor, critic) features; the actor's come first.
    actor_features = extracted[0] if isinstance(extracted, tuple) else extracted
    actor_latent, _critic_latent = policy.mlp_extractor(actor_features)
    return policy.action_net(actor_latent)
def main():
    """Command-line entry point: behavior-clone teacher demos into an SB3 policy.

    Steps: parse arguments, load ``obs`` / ``actions`` / ``meta`` from the
    demo ``.npz``, split into train/validation sets, fit the policy's
    mean-action output with an MSE + cosine loss, and save the model as
    ``policy.zip`` under ``--out``. The weights saved are those from the
    final epoch; ``best_val`` is only reported.

    Raises:
        RuntimeError: If the demo file is empty, or ``obs`` and ``actions``
            contain a different number of samples.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--demos", default="training/demos.npz")
    parser.add_argument("--out", default="training/runs/bc_pretrained")
    parser.add_argument("--epochs", type=int, default=60)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--val-split", type=float, default=0.1)
    parser.add_argument("--net-arch", default="256,256",
                        help="Comma-separated hidden layer widths.")
    parser.add_argument("--log-std-init", type=float, default=0.5)
    parser.add_argument("--cos-weight", type=float, default=1.0,
                        help="Weight on (1 - cosine similarity) loss term. "
                             "MSE alone shrinks policy output toward zero "
                             "(zero-magnitude action minimises mean squared "
                             "error against ±1 targets); cos loss keeps "
                             "the action pointed correctly even at small "
                             "magnitudes.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device", default="cpu")
    args = parser.parse_args()

    # A split of 1.0 leaves an empty training set (the shuffling DataLoader
    # rejects it) and a negative split mis-slices the permutation below.
    if not 0.0 <= args.val_split < 1.0:
        parser.error("--val-split must be in [0, 1)")

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # --- Load demos ---
    print(f"[bc] loading demos from {args.demos}")
    # NpzFile keeps the archive open; the context manager closes it once the
    # arrays have been materialised.
    with np.load(args.demos) as data:
        obs = data["obs"].astype(np.float32)
        actions = data["actions"].astype(np.float32)
        meta = data["meta"]
    print(f"[bc] obs={obs.shape} actions={actions.shape} trajectories={len(meta)}")
    if obs.size == 0:
        raise RuntimeError("Empty demo file.")
    if len(obs) != len(actions):
        raise RuntimeError(
            f"Demo file is inconsistent: {len(obs)} observations vs "
            f"{len(actions)} actions."
        )

    # Action sanity check — sequential outputs unit vectors.
    a_norms = np.linalg.norm(actions, axis=1)
    print(f"[bc] action L2 norm: mean={a_norms.mean():.3f} "
          f"min={a_norms.min():.3f} max={a_norms.max():.3f}")

    # --- Train/val split ---
    n = len(obs)
    perm = np.random.permutation(n)
    n_val = int(n * args.val_split)
    val_idx, train_idx = perm[:n_val], perm[n_val:]
    print(f"[bc] train={len(train_idx)} val={len(val_idx)}")
    obs_t = torch.from_numpy(obs)
    act_t = torch.from_numpy(actions)
    train_loader = DataLoader(
        TensorDataset(obs_t[train_idx], act_t[train_idx]),
        batch_size=args.batch_size, shuffle=True,
    )
    val_loader = DataLoader(
        TensorDataset(obs_t[val_idx], act_t[val_idx]),
        batch_size=args.batch_size, shuffle=False,
    )

    # --- Build model ---
    net_arch_pi = [int(x) for x in args.net_arch.split(",")]
    net_arch_vf = net_arch_pi[:]
    model, env = build_model(net_arch_pi, net_arch_vf, args.log_std_init)

    def combined_loss(pred, target):
        """Return (MSE + cos_weight * (1 - cos_sim), MSE value, mean cos_sim)."""
        mse = nn.functional.mse_loss(pred, target)
        p_norm = pred.norm(dim=1).clamp_min(1e-6)
        t_norm = target.norm(dim=1).clamp_min(1e-6)
        cos_sim = (pred * target).sum(dim=1) / (p_norm * t_norm)
        cos_loss = (1.0 - cos_sim).mean()
        return mse + args.cos_weight * cos_loss, mse.item(), cos_sim.mean().item()

    try:
        policy = model.policy.to(args.device)
        optimizer = optim.Adam(policy.parameters(), lr=args.lr)

        # --- Train ---
        print(f"[bc] training: epochs={args.epochs} batch={args.batch_size} "
              f"lr={args.lr} device={args.device}")
        t_start = time.time()
        best_val = float("inf")

        for epoch in range(args.epochs):
            policy.train()
            train_loss_total, train_mse_total, train_cos_total, train_count = 0.0, 0.0, 0.0, 0
            for ob_batch, act_batch in train_loader:
                ob_batch = ob_batch.to(args.device)
                act_batch = act_batch.to(args.device)
                optimizer.zero_grad()
                mean_action = policy_forward_mean(policy, ob_batch)
                loss, mse_val, cos_val = combined_loss(mean_action, act_batch)
                loss.backward()
                optimizer.step()
                bs = ob_batch.size(0)
                # Per-batch means are re-weighted by batch size so the epoch
                # figures are per-sample averages.
                train_loss_total += loss.item() * bs
                train_mse_total += mse_val * bs
                train_cos_total += cos_val * bs
                train_count += bs
            train_mse = train_mse_total / max(1, train_count)
            train_cos = train_cos_total / max(1, train_count)

            policy.eval()
            val_total, val_count = 0.0, 0
            cos_sim_total = 0.0
            with torch.no_grad():
                for ob_batch, act_batch in val_loader:
                    ob_batch = ob_batch.to(args.device)
                    act_batch = act_batch.to(args.device)
                    mean_action = policy_forward_mean(policy, ob_batch)
                    bs = ob_batch.size(0)
                    val_total += nn.functional.mse_loss(
                        mean_action, act_batch, reduction="sum",
                    ).item()
                    # Cosine similarity in action space — useful sanity for
                    # "is the policy pointing the same way as the teacher?".
                    m_norm = mean_action.norm(dim=1).clamp_min(1e-6)
                    a_norm = act_batch.norm(dim=1).clamp_min(1e-6)
                    cos = (mean_action * act_batch).sum(dim=1) / (m_norm * a_norm)
                    cos_sim_total += cos.sum().item()
                    val_count += bs
            # Summed squared error -> mean per sample and per action dimension.
            val_mse = val_total / max(1, val_count) / actions.shape[1]
            cos_sim = cos_sim_total / max(1, val_count)
            print(f" epoch {epoch+1:>2d}/{args.epochs} "
                  f"train_mse={train_mse:.4f} train_cos={train_cos:+.3f} "
                  f"val_mse={val_mse:.4f} val_cos={cos_sim:+.3f}")
            if val_mse < best_val:
                best_val = val_mse

        elapsed = time.time() - t_start
        print(f"[bc] done in {elapsed:.0f}s best_val_mse={best_val:.4f}")

        # --- Save ---
        out_dir = Path(args.out)
        out_dir.mkdir(parents=True, exist_ok=True)
        model.save(out_dir / "policy.zip")
    finally:
        # The env only exists so PPO can be constructed; release it either way.
        env.close()

    print(f"[bc] saved policy to {out_dir / 'policy.zip'}")
    print(f"\n[bc] verify with: "
          f"python -m training.eval --policy {out_dir}")
if __name__ == "__main__":
main()
-14
View File
@@ -1,14 +0,0 @@
{
"W_PER_SHEEP": 2.0,
"W_ALIGN": 0.05,
"W_PEN_BONUS": 10.0,
"W_COMPLETE": 100.0,
"W_STEP_COST": 0.02,
"W_COMPACT": 0.0,
"W_WALL_TOUCH": 0.0,
"WALL_TOUCH_BUFFER": 0.4,
"ALIGN_SHAPE": "standoff",
"ALIGN_GATED": true,
"ENTRY_AWARE": true,
"ent_coef": 0.02
}
View File
+52
View File
@@ -0,0 +1,52 @@
# PPO hyperparameters for the herding env. Tuned for a 28-D obs / 2-D
# continuous action space with 16 parallel envs on GPU. These are SB3
# defaults nudged toward longer credit assignment (gamma=0.995) and a
# slightly higher entropy bonus to keep exploration alive while curriculum
# expands the flock size.
# --- PPO ---
learning_rate: 3.0e-4
n_steps: 2048 # rollout length per env before each update
batch_size: 256
n_epochs: 10
gamma: 0.995
gae_lambda: 0.95
clip_range: 0.2
ent_coef: 0.05 # was 0.01 — earlier runs collapsed to ~0 actions
vf_coef: 0.5
max_grad_norm: 0.5
target_kl: null # disable early-stop on KL
# --- Network ---
policy: MlpPolicy
net_arch_pi: [128, 128]
net_arch_vf: [128, 128]
log_std_init: 0.5 # std≈1.6 instead of default 1.0 — more exploration
# --- Training schedule ---
total_timesteps: 10_000_000
n_envs: 16
checkpoint_freq: 500_000 # in env steps
eval_freq: 100_000 # in env steps
n_eval_episodes: 20
# --- Curriculum (max-n_sheep schedule, in env steps) ---
# Each entry: at step s, raise the env's max_n_sheep to k. The env samples
# uniformly from [1, max_n_sheep] each reset, so this widens the
# distribution gradually rather than swapping fixed sizes.
#
# State-space curriculum: difficulty controls sheep spawn area
# (0 = sheep spawn just north of gate, 1 = sheep spawn anywhere in field).
# Plus the existing flock-size curriculum.
#
# The two together let the policy first learn "what penning looks like"
# in a regime where random exploration reliably triggers it, then
# gradually generalise to the deployment distribution.
curriculum:
- { step: 0, max_n_sheep: 1, difficulty: 0.0 }
- { step: 1_000_000, max_n_sheep: 1, difficulty: 0.3 }
- { step: 2_000_000, max_n_sheep: 2, difficulty: 0.5 }
- { step: 4_000_000, max_n_sheep: 3, difficulty: 0.8 }
- { step: 6_000_000, max_n_sheep: 5, difficulty: 1.0 }
- { step: 8_000_000, max_n_sheep: 8, difficulty: 1.0 }
- { step: 9_000_000, max_n_sheep: 10, difficulty: 1.0 }
Binary file not shown.
+136
View File
@@ -0,0 +1,136 @@
"""Evaluate a trained PPO policy (or the Strömbom / sequential analytic baselines) on the env.
Reports success rate and time-to-pen across a fixed seed grid for each
flock size 1..MAX_SHEEP. Used to produce the M5 quantitative comparison
table mentioned in plan.md.
Usage::
python -m training.eval --policy training/runs/latest/best
python -m training.eval --policy strombom
"""
from __future__ import annotations
import argparse
import os
import sys
from pathlib import Path
from statistics import mean, stdev
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
import numpy as np
from herding.geometry import MAX_SHEEP, PEN_ENTRY
from herding.strombom import compute_action as strombom_action
from herding.sequential import compute_action as sequential_action
from training.herding_env import HerdingEnv
def rollout(env: HerdingEnv, predict_fn, max_steps: int) -> dict:
    """Play a single episode and summarise how it ended.

    Args:
        env: Environment to drive; it is reset once at the start.
        predict_fn: Callable ``(env, obs) -> action``.
        max_steps: Maximum number of ``env.step`` calls.

    Returns:
        Dict with ``success``, ``steps`` and ``n_penned``. When the episode
        terminates or truncates these are read from the final ``info`` (with
        the step count as fallback for ``steps``); if the step budget runs
        out first, the episode is a failure and ``n_penned`` is counted from
        ``env.sheep_penned``.
    """
    obs, _ = env.reset()
    for step_no in range(1, max_steps + 1):
        obs, _reward, terminated, truncated, info = env.step(predict_fn(env, obs))
        if not (terminated or truncated):
            continue
        return {
            "success": bool(info.get("is_success", False)),
            "steps": info.get("steps", step_no),
            "n_penned": info.get("n_penned", 0),
        }
    # Step budget exhausted without the env signalling an end.
    penned_now = int(env.sheep_penned.sum())
    return {"success": False, "steps": max_steps, "n_penned": penned_now}
def make_analytic_predictor(action_fn):
    """Adapt an analytic controller to the ``predict_fn(env, obs)`` interface.

    ``action_fn`` is called as ``action_fn((dog_x, dog_y), positions,
    PEN_ENTRY)`` and must return ``(vx, vy, mode)``. ``positions`` maps
    ``"s<i>"`` to the ``(x, y)`` of every sheep that is not yet penned. The
    observation is ignored; state is read straight from the env.

    Returns:
        A callable producing a float32 ``[vx, vy]`` action array.
    """
    def _predict(env, _obs):
        unpenned = {}
        for idx in range(env.n_sheep):
            if env.sheep_penned[idx]:
                continue
            unpenned[f"s{idx}"] = (float(env.sheep_x[idx]), float(env.sheep_y[idx]))
        dog_xy = (env.dog_x, env.dog_y)
        vx, vy, _mode = action_fn(dog_xy, unpenned, PEN_ENTRY)
        return np.array([vx, vy], dtype=np.float32)

    return _predict
# Backwards-compat alias.
def make_strombom_predictor():
    """Return the predictor built from the Strömbom baseline controller."""
    strombom_predict = make_analytic_predictor(strombom_action)
    return strombom_predict
def make_policy_predictor(model, vecnorm):
    """Adapt a trained model to the ``predict_fn(env, obs)`` interface.

    Args:
        model: Object with ``predict(obs, deterministic=True)`` returning
            ``(actions, state)``.
        vecnorm: Optional normaliser exposing ``normalize_obs``; ``None``
            feeds raw observations to the model.

    Returns:
        A callable that reshapes one observation into a float32 batch of
        size 1, normalises it when ``vecnorm`` is given, and returns the
        first (only) predicted action.
    """
    def _predict(_env, obs):
        batch = np.asarray(obs, dtype=np.float32).reshape(1, -1)
        if vecnorm is not None:
            batch = vecnorm.normalize_obs(batch)
        actions, _state = model.predict(batch, deterministic=True)
        return actions[0]

    return _predict
def main():
    """Command-line entry point: print a per-flock-size evaluation table.

    ``--policy`` selects the controller: ``strombom`` or ``sequential`` for
    the analytic baselines, otherwise a path to an SB3 run directory or a
    checkpoint zip. For every flock size ``1..--max-flock`` it runs
    ``--n-seeds`` episodes (seeds ``0..n-1``) and prints success rate, mean
    steps and mean penned count.

    Raises:
        FileNotFoundError: If a run directory contains none of the known
            checkpoint names.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--policy", required=True,
                        help="'strombom', 'sequential', or a path to an SB3 "
                             "run directory / checkpoint zip.")
    parser.add_argument("--n-seeds", type=int, default=10)
    parser.add_argument("--max-steps", type=int, default=5000)
    parser.add_argument("--max-flock", type=int, default=MAX_SHEEP)
    # 1.0 = deployment distribution (sheep anywhere in field).
    # Lower values use the training-curriculum spawn band (sheep near gate).
    parser.add_argument("--difficulty", type=float, default=1.0)
    args = parser.parse_args()

    # statistics.mean() raises on an empty list, so reject this up front.
    if args.n_seeds < 1:
        parser.error("--n-seeds must be at least 1")

    if args.policy == "strombom":
        predict = make_analytic_predictor(strombom_action)
    elif args.policy == "sequential":
        predict = make_analytic_predictor(sequential_action)
    else:
        from stable_baselines3 import PPO
        run = Path(args.policy)
        # Resolve to a zip: directory of checkpoints, or a direct zip path.
        if run.is_file():
            zip_path = run
        else:
            for name in ("best_model.zip", "policy.zip", "final.zip"):
                if (run / name).exists():
                    zip_path = run / name
                    break
            else:
                raise FileNotFoundError(
                    f"No checkpoint found in {run} (tried best_model.zip, "
                    f"policy.zip, final.zip)"
                )
        model = PPO.load(str(zip_path), device="auto")

        # Look for the normalisation stats next to the checkpoint first. For
        # a direct zip path that is the zip's directory; joining onto the
        # file itself can never exist and used to skip `best/vecnormalize.pkl`.
        model_dir = run.parent if run.is_file() else run
        vecnorm = None
        vn_path = model_dir / "vecnormalize.pkl"
        if not vn_path.exists() and model_dir.parent.name != "best":
            vn_path = model_dir.parent / "vecnormalize.pkl"
        if vn_path.exists():
            import pickle
            # NOTE: unpickling executes arbitrary code — only point --policy
            # at run directories you trust.
            with open(vn_path, "rb") as f:
                vecnorm = pickle.load(f)
            vecnorm.training = False
            vecnorm.norm_reward = False
        predict = make_policy_predictor(model, vecnorm)

    print(f"{'n_sheep':>8} {'success%':>10} {'mean_steps':>12} {'mean_penned':>12}")
    print("-" * 46)
    for n in range(1, args.max_flock + 1):
        successes, steps, penned = [], [], []
        for seed in range(args.n_seeds):
            env = HerdingEnv(n_sheep=n, max_steps=args.max_steps,
                             difficulty=args.difficulty, seed=seed)
            try:
                r = rollout(env, predict, args.max_steps)
            finally:
                # One env per (flock size, seed); release it before the next.
                env.close()
            successes.append(int(r["success"]))
            steps.append(r["steps"])
            penned.append(r["n_penned"])
        sr = 100.0 * mean(successes)
        ms = mean(steps)
        mp = mean(penned)
        print(f"{n:>8d} {sr:>9.1f}% {ms:>12.0f} {mp:>12.2f}")
if __name__ == "__main__":
main()
+355 -707
View File
File diff suppressed because it is too large Load Diff
+75 -297
View File
@@ -1,318 +1,96 @@
"""
Parity test: verify 2D training env matches Webots controller implementations.
"""Parity smoke-test for the herding env.
Tests:
1. Observation building: HerdingEnv._obs() vs shepherd_dog_rl.build_obs()
2. Dog drive: HerdingEnv._step_dog_substep() vs shepherd_dog_rl.drive() math
3. Sheep drive: HerdingEnv._sheep_drive() vs sheep.py drive() math
Verifies (a) all imports resolve, (b) the env's reset/step contract is
correct, (c) resetting with the same seed gives the same initial
observation, and (d) the Strömbom baseline can drive the env without crashing.
Run::
python -m training.parity_test
"""
import sys
from __future__ import annotations
import os
import math
import sys
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
import numpy as np
# Make imports work from project root
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__)))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "controllers", "shepherd_dog_rl"))
from herding_env import HerdingEnv
# Re-implement the Webots functions standalone (no Webots dependency)
FIELD = 15.0
PEN_CENTER = np.array([11.5, -11.5], dtype=np.float32)
PEN_ENTRY = np.array([11.5, -8.0], dtype=np.float32)
PEN_X = (10.0, 13.0)
PEN_Y = (-15.0, -8.0)
ENTRY_AWARE = True
from herding.geometry import MAX_SHEEP, PEN_ENTRY
from herding.obs import OBS_DIM
from herding.strombom import compute_action
from training.herding_env import HerdingEnv
def webots_build_obs(dog_pos, sheep_positions, n_sheep, dog_heading):
    """Standalone version of shepherd_dog_rl.py build_obs().

    Builds an 18-element float32 vector from the dog position, the sheep
    that are outside the pen rectangle (``PEN_X`` x ``PEN_Y``), and the dog
    heading: dog position scaled by ``FIELD``; offsets (scaled by
    ``2 * FIELD``) dog->centre of mass, centre of mass->three farthest
    sheep, centre of mass->pen reference, farthest sheep->pen reference;
    flock radius; fraction of sheep still active; cos/sin of the heading.
    With no active sheep the centre of mass and the far points fall back to
    ``PEN_CENTER``.
    """
    span = 2 * FIELD

    def _inside_pen(p):
        return PEN_X[0] < p[0] < PEN_X[1] and PEN_Y[0] < p[1] < PEN_Y[1]

    def _rel(a, b):
        # Offset from b to a, normalised by the field span.
        return [(a[0] - b[0]) / span, (a[1] - b[1]) / span]

    free = np.array(
        [p for p in sheep_positions if not _inside_pen(p)],
        dtype=np.float32,
    )
    n_free = len(free)

    if n_free == 0:
        com = PEN_CENTER.copy()
        radius = 0.0
        far1 = far2 = far3 = PEN_CENTER.copy()
    else:
        com = free.mean(axis=0)
        dist = np.linalg.norm(free - com, axis=1)
        order = np.argsort(dist)[::-1]  # farthest first
        radius = float(dist[order[0]])
        # Fewer than three active sheep: pad with the centre of mass.
        far1, far2, far3 = [
            free[order[rank]] if len(order) > rank else com
            for rank in range(3)
        ]

    pen_ref = PEN_ENTRY if ENTRY_AWARE else PEN_CENTER

    values = [dog_pos[0] / FIELD, dog_pos[1] / FIELD]
    values += _rel(com, dog_pos)
    for far in (far1, far2, far3):
        values += _rel(far, com)
    values += _rel(pen_ref, com)
    values += _rel(pen_ref, far1)
    values += [
        radius / span,
        n_free / max(n_sheep, 1),
        math.cos(dog_heading),
        math.sin(dog_heading),
    ]
    return np.array(values, dtype=np.float32)
def test_obs_action_shapes():
    """Check the reset/step return types and the observation shape/dtype."""
    env = HerdingEnv(n_sheep=3, seed=0)

    first_obs, _info = env.reset()
    assert first_obs.shape == (OBS_DIM,), first_obs.shape
    assert first_obs.dtype == np.float32

    action = np.array([0.5, 0.0], dtype=np.float32)
    next_obs, reward, terminated, truncated, _info = env.step(action)
    assert next_obs.shape == (OBS_DIM,)
    assert isinstance(reward, float)
    assert isinstance(terminated, bool) and isinstance(truncated, bool)
    print("[ok] shapes")
def webots_dog_drive(heading, speed_ms, wheel_r=0.038, k_turn=4.0,
motor_max=70.0, axle_track=0.28):
"""Standalone version of shepherd_dog_rl.py drive() kinematics.
def test_reset_determinism():
"""Reset with the same seed should give the same initial observation.
Returns (v_linear, omega, left_w, right_w).
We don't require step-determinism — PPO doesn't need it, and chasing
bit-exactness through the flocking jitter isn't worth the complexity.
"""
err = math.atan2(math.sin(heading), math.cos(heading))
fwd_ms = speed_ms * max(0.0, math.cos(err))
fwd_rad = fwd_ms / wheel_r
turn = k_turn * err
l = max(-motor_max, min(motor_max, fwd_rad - turn))
r = max(-motor_max, min(motor_max, fwd_rad + turn))
v = wheel_r * 0.5 * (r + l)
w = (wheel_r / axle_track) * (r - l)
return v, w, l, r
env_a = HerdingEnv(n_sheep=3, seed=42)
env_b = HerdingEnv(n_sheep=3, seed=42)
obs_a, _ = env_a.reset(seed=42)
obs_b, _ = env_b.reset(seed=42)
assert np.allclose(obs_a, obs_b), "Reset is non-deterministic for same seed"
print("[ok] reset determinism")
def webots_sheep_drive(heading, speed_rad, wheel_r=0.031, k_turn=4.0,
                       motor_max=22.0, axle_track=0.20):
    """Standalone version of sheep.py drive() kinematics.

    Args:
        heading: Heading error in radians (wrapped to [-pi, pi] internally).
        speed_rad: Commanded wheel speed; scaled by ``max(0, cos(err))`` so
            nothing is driven forward when facing away from the target.
        wheel_r: Wheel radius.
        k_turn: Proportional turn gain applied to the heading error.
        motor_max: Symmetric clamp on each wheel speed.
        axle_track: Distance between the wheels.

    Returns:
        ``(v, w, left, right)``: body linear and angular velocity derived
        from the clamped wheel speeds, then the wheel speeds themselves.
    """
    err = math.atan2(math.sin(heading), math.cos(heading))
    fwd = speed_rad * max(0.0, math.cos(err))
    # Use the caller's gain; the previous body hard-coded 4.0 and silently
    # ignored the k_turn argument.
    turn = k_turn * err
    l = max(-motor_max, min(motor_max, fwd - turn))
    r = max(-motor_max, min(motor_max, fwd + turn))
    v = wheel_r * 0.5 * (r + l)
    w = (wheel_r / axle_track) * (r - l)
    return v, w, l, r
def test_curriculum_n_sheep_varies():
    """Reset 40 times and check the flock sizes reported in ``info``.

    Size 1 must appear at least once and no size may exceed ``MAX_SHEEP``.
    """
    env = HerdingEnv(seed=0)
    seen = {env.reset()[1]["n_sheep"] for _ in range(40)}
    assert 1 in seen
    assert max(seen) <= MAX_SHEEP
    print(f"[ok] curriculum sampling — saw n_sheep in {sorted(seen)}")
def test_obs_parity():
    """Test that build_obs matches between 2D env and Webots controller.

    Places the dog and three sheep (the third inside the pen and flagged as
    penned) at fixed positions, then compares ``env._obs()`` element-wise
    against ``webots_build_obs``.

    Returns:
        True when the largest absolute difference is below 1e-6.
    """
    print("=== Test 1: Observation Parity ===")
    tol = 1e-6
    env = HerdingEnv(n_sheep=3)
    # Set ENTRY_AWARE to match our webots constant
    env.ENTRY_AWARE = ENTRY_AWARE
    env.reset(seed=42)

    # Controlled scenario: (sheep position, penned flag).
    env.dog_pos = np.array([5.0, 3.0], dtype=np.float32)
    env.dog_heading = 1.2
    scenario = [
        ([0.0, 0.0], False),
        ([2.0, -1.0], False),
        ([11.5, -11.5], True),  # penned
    ]
    for idx, (xy, _flag) in enumerate(scenario):
        env.sheep_pos[idx] = np.array(xy, dtype=np.float32)
    for idx, (_xy, flag) in enumerate(scenario):
        env.penned[idx] = flag

    obs_2d = env._obs()

    # Build equivalent Webots observation
    sheep_positions = [env.sheep_pos[idx].tolist() for idx in range(3)]
    obs_webots = webots_build_obs(
        env.dog_pos, sheep_positions, 3, env.dog_heading
    )

    max_diff = float(np.max(np.abs(obs_2d - obs_webots)))
    print(f" Max element-wise diff: {max_diff:.2e}")
    matched = max_diff < tol
    if matched:
        print(" PASS: Observations match")
        return matched
    print(" FAIL: Observations differ!")
    for i in range(18):
        if abs(obs_2d[i] - obs_webots[i]) > tol:
            print(f" dim {i}: 2d={obs_2d[i]:.6f} webots={obs_webots[i]:.6f}")
    return matched
def test_strombom_drives_env():
    """Quick functional check that the analytic baseline can play the env
    without exploding. Not a success-rate test — just no errors / NaNs."""
    horizon = 400
    env = HerdingEnv(n_sheep=2, max_steps=horizon, seed=1)
    obs, _ = env.reset()
    for t in range(horizon):
        # Only sheep that are still outside the pen are handed to the controller.
        unpenned = {}
        for i in range(env.n_sheep):
            if not env.sheep_penned[i]:
                unpenned[f"s{i}"] = (float(env.sheep_x[i]), float(env.sheep_y[i]))
        if not unpenned:
            break
        vx, vy, _mode = compute_action((env.dog_x, env.dog_y), unpenned, PEN_ENTRY)
        action = np.array([vx, vy], dtype=np.float32)
        obs, reward, terminated, truncated, info = env.step(action)
        assert np.isfinite(obs).all(), f"NaN/Inf in obs at step {t}"
        assert np.isfinite(reward), f"NaN reward at step {t}"
        if terminated or truncated:
            break
    print(f"[ok] strombom rollout — final n_penned={int(env.sheep_penned.sum())}/{env.n_sheep} after {env.steps} steps")
def test_dog_drive_parity():
    """Test that dog diff-drive matches Webots controller.

    For each (heading error, speed) case the env's ``_step_dog_substep``
    debug output (``v``, ``w``, ``left_w``, ``right_w``) is compared with
    ``webots_dog_drive``.

    Returns:
        True when every case agrees to within 1e-6.
    """
    print("\n=== Test 2: Dog Drive Parity ===")
    env = HerdingEnv(n_sheep=1)
    env.reset(seed=42)
    all_pass = True
    test_cases = [
        # (heading_error, speed_ms) — target_heading relative to current heading
        (0.0, 2.5),   # aligned, full speed
        (0.5, 2.5),   # 30deg error
        (1.5, 2.5),   # ~86deg error
        (3.14, 2.5),  # ~180deg error — should spin in place
        (0.0, 0.5),   # aligned, slow
        (0.3, 1.0),   # small error, medium speed
    ]
    for heading_err, speed_ms in test_cases:
        env.dog_heading = 0.0
        target_heading = heading_err
        action = np.array([
            math.cos(target_heading), math.sin(target_heading)
        ], dtype=np.float32) * (speed_ms / env.DOG_SPEED)

        # 2D env step
        dbg = env._step_dog_substep(action, 0.016)
        # Webots equivalent
        v_w, w_w, l_w, r_w = webots_dog_drive(heading_err, speed_ms)

        # Keep each (env, webots) pair under its report key. The old failure
        # report used eval(k + '_2d'), which raised NameError for "left" and
        # "right" because those locals were named l_2d / r_2d.
        compared = {
            "v": (dbg["v"], v_w),
            "w": (dbg["w"], w_w),
            "left": (dbg["left_w"], l_w),
            "right": (dbg["right_w"], r_w),
        }
        diffs = {k: abs(env_val - wb_val)
                 for k, (env_val, wb_val) in compared.items()}
        max_diff = max(diffs.values())
        ok = max_diff < 1e-6
        status = "PASS" if ok else "FAIL"
        print(f" err={heading_err:.2f} spd={speed_ms:.1f}: {status} (max_diff={max_diff:.2e})")
        if not ok:
            for k, d in diffs.items():
                if d > 1e-6:
                    env_val, wb_val = compared[k]
                    print(f" {k}: 2d={env_val:.6f} webots={wb_val:.6f}")
            all_pass = False
    return all_pass
def test_sheep_drive_parity():
    """Test that sheep diff-drive matches Webots sheep controller.

    NOTE(review): the "2D" wheel speeds compared here are re-derived inside
    this test from the same formula; they are not read back from
    ``env._sheep_drive``. That call is still made for each case, but its
    result does not feed the comparison.

    Returns:
        True when every case agrees to within 1e-6.
    """
    print("\n=== Test 3: Sheep Drive Parity ===")
    env = HerdingEnv(n_sheep=1)
    env.reset(seed=42)
    all_pass = True
    k_turn = 4.0       # turn gain used by the re-derived 2D formula
    motor_max = 22.0   # symmetric wheel-speed clamp
    test_cases = [
        # (heading_error, speed_rad)
        (0.0, 20.0),   # aligned, flee speed
        (0.0, 3.0),    # aligned, wander speed
        (0.5, 15.0),   # moderate error
        (1.57, 10.0),  # 90deg — should spin in place
        (3.14, 20.0),  # 180deg — should spin in place fast
        (0.2, 8.0),    # small error, medium speed
    ]
    for heading_err, speed_rad in test_cases:
        env.sheep_heading[0] = 0.0
        env.sheep_pos[0] = np.array([0.0, 0.0], dtype=np.float32)
        target_heading = heading_err

        # 2D env (exercised for errors/side effects; return value unused)
        env._sheep_drive(0, target_heading, speed_rad, 0.016)

        # Webots equivalent
        _v_w, _w_w, l_w, r_w = webots_sheep_drive(heading_err, speed_rad)

        # For 2D, compute the same intermediate values
        err_2d = (target_heading - 0.0 + np.pi) % (2 * np.pi) - np.pi
        fwd_2d = speed_rad * max(0.0, math.cos(err_2d))
        turn_2d = k_turn * err_2d
        l_2d = max(-motor_max, min(motor_max, fwd_2d - turn_2d))
        r_2d = max(-motor_max, min(motor_max, fwd_2d + turn_2d))

        compared = {"left": (l_2d, l_w), "right": (r_2d, r_w)}
        diffs = {k: abs(a - b) for k, (a, b) in compared.items()}
        max_diff = max(diffs.values())
        ok = max_diff < 1e-6
        status = "PASS" if ok else "FAIL"
        print(f" err={heading_err:.2f} spd={speed_rad:.1f}: {status} (max_diff={max_diff:.2e})")
        if not ok:
            for k, d in diffs.items():
                if d > 1e-6:
                    val_2d, val_w = compared[k]
                    print(f" {k}: 2d={val_2d:.6f} webots={val_w:.6f}")
            all_pass = False
    return all_pass
def test_full_trajectory_parity():
    """Test that running identical actions produces matching trajectories.

    Applies one fixed action for 50 sub-steps through
    ``env._step_dog_substep`` and, in parallel, integrates the same
    differential-drive formulas by hand, tracking the worst heading and
    position disagreement.

    Returns:
        True when the largest position difference stays below 1e-4.
    """
    print("\n=== Test 4: Full Trajectory Parity (dog only) ===")
    env = HerdingEnv(n_sheep=1)
    env.reset(seed=42)
    env.dog_pos = np.array([0.0, 0.0], dtype=np.float32)
    env.dog_heading = 0.0
    env.ENTRY_AWARE = ENTRY_AWARE
    action = np.array([0.8, -0.6], dtype=np.float32)  # magnitude 1.0
    dt = 0.016667  # sub_dt

    # Hand-integrated reference: constants and loop invariants.
    wheel_r, axle_track = 0.038, 0.28
    k_turn, motor_max = 4.0, 70.0
    speed_ms = 1.0 * 2.5
    target_heading = math.atan2(-0.6, 0.8)
    ref_heading, ref_x, ref_y = 0.0, 0.0, 0.0

    worst_heading = 0.0
    worst_pos = 0.0
    for _ in range(50):
        # 2D env sub-step
        env._step_dog_substep(action, dt)

        # Reference computation
        delta = target_heading - ref_heading
        err = math.atan2(math.sin(delta), math.cos(delta))
        fwd_rad = speed_ms * max(0.0, math.cos(err)) / wheel_r
        turn = k_turn * err
        left = max(-motor_max, min(motor_max, fwd_rad - turn))
        right = max(-motor_max, min(motor_max, fwd_rad + turn))
        v = wheel_r * 0.5 * (right + left)
        w = (wheel_r / axle_track) * (right - left)
        advanced = ref_heading + w * dt
        ref_heading = math.atan2(math.sin(advanced), math.cos(advanced))
        ref_x += math.cos(ref_heading) * v * dt
        ref_y += math.sin(ref_heading) * v * dt

        worst_heading = max(worst_heading, abs(env.dog_heading - ref_heading))
        worst_pos = max(
            worst_pos,
            math.hypot(env.dog_pos[0] - ref_x, env.dog_pos[1] - ref_y),
        )

    print(f" Max heading diff over 50 steps: {worst_heading:.2e} rad")
    print(f" Max position diff over 50 steps: {worst_pos:.2e} m")
    ok = worst_pos < 1e-4
    print(f" {'PASS' if ok else 'FAIL'}: Trajectories match")
    return ok
def main():
    """Run every smoke check in order; an assertion failure aborts the run."""
    checks = (
        test_obs_action_shapes,
        test_reset_determinism,
        test_curriculum_n_sheep_varies,
        test_strombom_drives_env,
    )
    for check in checks:
        check()
    print("\nAll parity checks passed.")
if __name__ == "__main__":
results = []
results.append(("Obs parity", test_obs_parity()))
results.append(("Dog drive parity", test_dog_drive_parity()))
results.append(("Sheep drive parity", test_sheep_drive_parity()))
results.append(("Trajectory parity", test_full_trajectory_parity()))
print("\n" + "=" * 50)
print("RESULTS")
print("=" * 50)
all_pass = True
for name, passed in results:
print(f" {name}: {'PASS' if passed else 'FAIL'}")
if not passed:
all_pass = False
print(f"\nOverall: {'ALL PASS' if all_pass else 'SOME FAILURES'}")
env.close()
main()
+8 -6
View File
@@ -1,6 +1,8 @@
gymnasium>=0.29
stable-baselines3>=2.3
torch>=2.2
numpy>=1.26
matplotlib>=3.8
tensorboard>=2.16
# Pin major versions; SB3 2.x requires gymnasium and torch >= 1.13.
gymnasium>=0.29,<2.0
stable-baselines3[extra]>=2.3,<3.0
torch>=2.1
numpy>=1.24
pyyaml>=6.0
tensorboard>=2.14
tqdm>=4.66
-1
View File
@@ -1 +0,0 @@
-392
View File
@@ -1,392 +0,0 @@
"""
PPO training for the herding task with curriculum learning.
Trains from scratch through a 1→max_sheep curriculum, evaluates after each
stage, and auto-generates trajectory/timeseries plots plus a summary chart.
Usage
-----
python train.py # defaults from config.json
python train.py --config my_config.json --max-sheep 5
python train.py --max-sheep 3 --steps-per-stage 1000000
Outputs (in runs/<timestamp>/):
config.json resolved config
final_model.zip trained PPO model
vecnorm.pkl VecNormalize statistics
stage_results.json per-stage evaluation metrics
success_rate.png summary bar chart
eval/ trajectory & timeseries plots per sheep count
"""
import argparse
import json
import os
import time
from copy import deepcopy
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import (
DummyVecEnv,
SubprocVecEnv,
VecNormalize,
)
from herding_env import HerdingEnv
from viz import (
run_and_record,
plot_trajectory,
plot_timeseries,
plot_success_rate,
save_episode_gif,
)
# ── Callbacks ────────────────────────────────────────────────────────────────
class ProgressCallback(BaseCallback):
    """Print a one-line progress summary every ``freq`` env steps.

    Keeps a running return per parallel env, a sliding window of the most
    recent 50 finished episodes (returns and success flags), and cumulative
    episode/success counters. An episode counts as a success when its final
    ``info`` reports ``n_penned == n_sheep``.
    """

    _WINDOW = 50  # finished episodes kept for the windowed statistics

    def __init__(self, stage_label: str, freq: int = 100_000):
        super().__init__()
        self.stage_label = stage_label
        self.freq = freq
        self._last = 0
        self._ep_returns = []
        self._ep_success = []
        self._total_eps = 0
        self._total_success = 0
        self._cur_ret = None

    def _close_episode(self, env_idx, info):
        """Record the episode that just ended in env ``env_idx``."""
        self._ep_returns.append(float(self._cur_ret[env_idx]))
        succeeded = int(info.get("n_penned", 0) == info.get("n_sheep", -1))
        self._ep_success.append(succeeded)
        self._total_eps += 1
        self._total_success += succeeded
        self._cur_ret[env_idx] = 0.0
        if len(self._ep_returns) > self._WINDOW:
            # Drop the oldest entry from both windows together.
            del self._ep_returns[0]
            del self._ep_success[0]

    def _print_summary(self):
        """Emit the progress line from the current window and totals."""
        n = len(self._ep_returns)
        nan = float("nan")
        mean_r = float(np.mean(self._ep_returns)) if n else nan
        win_sr = float(np.mean(self._ep_success)) if n else nan
        if self._total_eps:
            cum_sr = self._total_success / self._total_eps
        else:
            cum_sr = nan
        print(f" ... [{self.stage_label} | "
              f"{self.num_timesteps:>7,} steps | "
              f"ret(last {n})={mean_r:+.2f} "
              f"win_sr={win_sr*100:.0f}% cum_sr={cum_sr*100:.0f}%]",
              flush=True)

    def _on_step(self) -> bool:
        step_locals = self.locals
        rewards = step_locals.get("rewards")
        dones = step_locals.get("dones")
        infos = step_locals.get("infos", [])
        if rewards is None or dones is None:
            return True

        n_envs = len(rewards)
        # (Re)allocate the accumulator when the number of envs changes.
        if self._cur_ret is None or len(self._cur_ret) != n_envs:
            self._cur_ret = np.zeros(n_envs, dtype=np.float64)
        self._cur_ret += np.asarray(rewards, dtype=np.float64)

        for env_idx, done in enumerate(dones):
            if done:
                info = infos[env_idx] if env_idx < len(infos) else {}
                self._close_episode(env_idx, info)

        if self.num_timesteps - self._last >= self.freq:
            self._last = self.num_timesteps
            self._print_summary()
        return True
# ── Environment factory ──────────────────────────────────────────────────────
def make_env(n_sheep, seed, max_steps, reward_cfg=None):
    """Return a zero-argument factory that builds one seeded ``HerdingEnv``.

    The factory form is what the vectorised-env constructors expect: each
    call creates a fresh environment and resets it once with ``seed``.

    Args:
        n_sheep: Flock size passed to ``HerdingEnv``.
        seed: Seed used for the initial ``reset``.
        max_steps: Episode step limit passed to ``HerdingEnv``.
        reward_cfg: Optional reward configuration passed through unchanged.
    """
    def build_env():
        herding_env = HerdingEnv(
            n_sheep=n_sheep,
            max_steps=max_steps,
            reward_cfg=reward_cfg,
        )
        # Seed the environment's RNG; the returned observation is discarded.
        herding_env.reset(seed=seed)
        return herding_env

    return build_env
# ── Failure-mode classification ──────────────────────────────────────────────
COMPACT_RADIUS = 5.0
def _classify(ep_radii, ep_com_dists, n_penned, n_sheep):
if n_penned == n_sheep:
return "SUCCESS"
if min(ep_radii) > COMPACT_RADIUS:
return "NEVER_COMPACT"
first = next(i for i, r in enumerate(ep_radii) if r <= COMPACT_RADIUS)
if min(ep_com_dists[first:]) > 3.0:
return "COMPACT_CANT_DRIVE"
if n_penned == 0:
return "DROVE_NO_SHEEP"
return f"PARTIAL_{n_penned}of{n_sheep}"
# ── Evaluation ───────────────────────────────────────────────────────────────
def evaluate(model, vn_template, n_sheep, n_episodes, max_steps,
             reward_cfg=None):
    """Evaluate ``model`` deterministically at a fixed sheep count.

    A single-env ``DummyVecEnv`` is wrapped in a frozen ``VecNormalize``
    (``training=False``, rewards not normalised) that receives deep copies of
    the running statistics of ``vn_template``.

    Args:
        model: Policy exposing ``predict(obs, deterministic=True)``.
        vn_template: ``VecNormalize`` whose ``obs_rms``/``ret_rms`` are copied.
        n_sheep: Flock size to evaluate at.
        n_episodes: Number of episodes to roll out (must be positive).
        max_steps: Episode step limit passed to the environment.
        reward_cfg: Optional reward configuration passed to the environment.

    Returns:
        Dict with ``sr`` (success rate), ``mean_len``, ``mean_min_pen``,
        ``mean_act``, ``failure_modes`` (label -> count) and, when the env
        reports ``rcomps`` in ``info``, ``reward_per_step`` (per-key mean).

    Raises:
        ValueError: If ``n_episodes`` is not positive.
    """
    if n_episodes <= 0:
        raise ValueError(f"n_episodes must be positive, got {n_episodes}")

    raw = DummyVecEnv([make_env(n_sheep, 9999, max_steps, reward_cfg)])
    vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
    vn.obs_rms = deepcopy(vn_template.obs_rms)
    vn.ret_rms = deepcopy(vn_template.ret_rms)

    successes = 0
    ep_lens = []
    min_pen_list = []
    action_mags = []
    failure_counts = {}
    rc_sums = {}
    rc_n = 0

    try:
        inner = vn.envs[0]
        for _ in range(n_episodes):
            obs = vn.reset()
            done = False
            steps = 0
            min_pen = float("inf")
            mags = []
            ep_radii = []
            ep_com_dists = []
            while not done:
                # Sample the flock BEFORE stepping. DummyVecEnv resets the
                # env inside step() as soon as an episode ends, so reading
                # the env after the final step would record the freshly
                # reset state of the next episode, not this one.
                com, radius, _ = inner._flock_stats()
                pen_dist = float(np.linalg.norm(com - inner.PEN_CENTER))
                min_pen = min(min_pen, pen_dist)
                ep_radii.append(radius)
                ep_com_dists.append(pen_dist)

                action, _ = model.predict(obs, deterministic=True)
                obs, _, dones, infos = vn.step(action)
                done = dones[0]
                mags.append(float(np.linalg.norm(action[0])))
                steps += 1

                # Optional per-step reward component breakdown.
                rc = infos[0].get("rcomps")
                if rc:
                    for k, v in rc.items():
                        rc_sums[k] = rc_sums.get(k, 0.0) + v
                    rc_n += 1

            # ``infos`` still holds the info dict of the terminal step.
            n_penned = infos[0].get("n_penned", 0)
            success = n_penned == n_sheep
            successes += int(success)
            ep_lens.append(steps)
            min_pen_list.append(min_pen)
            action_mags.extend(mags)
            mode = _classify(ep_radii, ep_com_dists, n_penned, n_sheep)
            failure_counts[mode] = failure_counts.get(mode, 0) + 1
    finally:
        # Release the evaluation env even if a rollout raises.
        vn.close()

    result = {
        "sr": successes / n_episodes,
        "mean_len": float(np.mean(ep_lens)),
        "mean_min_pen": float(np.mean(min_pen_list)),
        "mean_act": float(np.mean(action_mags)) if action_mags else 0.0,
        "failure_modes": failure_counts,
    }
    if rc_n > 0:
        result["reward_per_step"] = {k: v / rc_n for k, v in rc_sums.items()}
    return result
# ── CLI ──────────────────────────────────────────────────────────────────────
# Built-in defaults for a training run. main() copies this dict, overlays the
# JSON config file on top of it, forwards every key that exists as an attribute
# on HerdingEnv to the environment as its reward configuration, and reads
# "ent_coef" for the PPO entropy coefficient.
DEFAULT_CONFIG = {
    "W_PER_SHEEP": 2.0,
    "W_ALIGN": 0.05,
    "W_PEN_BONUS": 10.0,
    "W_COMPLETE": 100.0,
    "W_STEP_COST": 0.02,
    "W_SOUTH": 0.01,
    "W_COMPACT": 0.0,
    "W_WALL_TOUCH": 0.04,
    "WALL_TOUCH_BUFFER": 0.3,
    "ALIGN_SHAPE": "standoff",
    "ALIGN_GATED": True,
    "ENTRY_AWARE": True,
    "ent_coef": 0.02,
}
def parse_args():
    """Build the command-line parser for curriculum training and parse argv.

    Returns:
        The ``argparse.Namespace`` with config path, curriculum sizes,
        evaluation settings, run directory and GIF rendering options.
    """
    parser = argparse.ArgumentParser(
        description="PPO training for herding task with curriculum learning")

    parser.add_argument(
        "--config", type=str, default=None,
        help="JSON config file (reward weights + ent_coef)")

    # Plain integer options without help text, kept in their original order.
    plain_int_options = (
        ("--max-sheep", 10),
        ("--steps-per-stage", 1_500_000),
        ("--n-envs", 8),
        ("--max-steps", 2500),
        ("--eval-episodes", 30),
    )
    for flag, default_value in plain_int_options:
        parser.add_argument(flag, type=int, default=default_value)

    parser.add_argument("--run-dir", type=str, default=None)

    # GIF rendering controls.
    parser.add_argument(
        "--no-gif", action="store_true",
        help="Skip per-stage GIF rendering (PNGs still produced).")
    parser.add_argument("--gif-fps", type=int, default=20)
    parser.add_argument(
        "--gif-skip", type=int, default=3,
        help="Keep every Nth frame (smaller GIF; default 3).")

    return parser.parse_args()
# ── Main ─────────────────────────────────────────────────────────────────────
def main():
    """Train PPO through the 1 -> ``--max-sheep`` curriculum and report.

    Steps, in order: resolve the config (defaults, then JSON overlay), create
    the run directory and dump the resolved config, build the normalised
    training envs and the PPO model, run every curriculum stage (train,
    evaluate, plot), save the model / normalisation stats / stage results,
    close the envs and print a summary plus the success-rate chart.
    """
    args = parse_args()
    # Load config: --config overrides, else auto-load config.json if present
    cfg = dict(DEFAULT_CONFIG)
    config_path = args.config
    if config_path is None and os.path.exists("config.json"):
        config_path = "config.json"
    if config_path:
        with open(config_path) as f:
            cfg.update(json.load(f))
        print(f"Config loaded from {config_path}")
    # Only keys that exist as attributes on HerdingEnv are passed to the env;
    # everything else (e.g. ent_coef) stays in cfg for the trainer.
    rcfg = {k: v for k, v in cfg.items() if hasattr(HerdingEnv, k)}
    # Run directory
    run_dir = args.run_dir or os.path.join(
        "runs", time.strftime("%Y%m%d_%H%M%S"))
    eval_dir = os.path.join(run_dir, "eval")
    os.makedirs(eval_dir, exist_ok=True)
    with open(os.path.join(run_dir, "config.json"), "w") as f:
        json.dump(cfg, f, indent=2)
    print(f"Config: {cfg}")
    print(f"Run dir: {run_dir}")
    print(f"Curriculum: 1 → {args.max_sheep} sheep, "
          f"{args.steps_per_stage:,} steps/stage\n")
    # Training envs
    train_env = SubprocVecEnv([
        make_env(1, seed=i, max_steps=args.max_steps, reward_cfg=rcfg)
        for i in range(args.n_envs)
    ])
    vn = VecNormalize(train_env, norm_obs=True, norm_reward=True,
                      clip_obs=10.0)
    # Model — force CPU (PPO with MLP runs faster on CPU than GPU; SB3 warns
    # about this otherwise).
    model = PPO(
        "MlpPolicy", vn,
        learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10,
        gamma=0.995, gae_lambda=0.95, clip_range=0.2,
        ent_coef=cfg.get("ent_coef", 0.02), vf_coef=0.5, max_grad_norm=0.5,
        policy_kwargs=dict(net_arch=[256, 256]),
        device="cpu",
        verbose=0,
    )
    # Curriculum training
    stage_results = []
    t0 = time.time()
    try:
        for n in range(1, args.max_sheep + 1):
            if n == 1:
                # First stage: every env already runs with one sheep.
                print(f"\n[Stage n_sheep=1] training {args.steps_per_stage:,} steps")
                model.learn(
                    total_timesteps=args.steps_per_stage,
                    reset_num_timesteps=True,
                    callback=ProgressCallback("1 sheep", freq=100_000),
                )
            else:
                # Mixed transition: half envs stay at n-1, half advance to n,
                # for the first half of the stage budget. This prevents the
                # n+1 task's noisy early gradients from destroying the n policy
                # (catastrophic forgetting) before it has a chance to adapt.
                half = max(1, args.n_envs // 2)
                for i in range(half):
                    vn.env_method("set_n_sheep", n - 1, indices=[i])
                for i in range(half, args.n_envs):
                    vn.env_method("set_n_sheep", n, indices=[i])
                mix_steps = args.steps_per_stage // 2
                full_steps = args.steps_per_stage - mix_steps
                print(f"\n[Stage n_sheep={n}] mixed ({n-1}/{n} sheep) "
                      f"{mix_steps:,} steps")
                model.learn(
                    total_timesteps=mix_steps,
                    reset_num_timesteps=False,
                    callback=ProgressCallback(f"{n-1}{n} mix", freq=100_000),
                )
                # Second half of the budget: every env at the new flock size.
                vn.env_method("set_n_sheep", n)
                print(f"[Stage n_sheep={n}] full ({n} sheep) {full_steps:,} steps")
                model.learn(
                    total_timesteps=full_steps,
                    reset_num_timesteps=False,
                    callback=ProgressCallback(f"{n} sheep", freq=100_000),
                )
            # Evaluate
            print(f"[Stage n_sheep={n}] evaluating {args.eval_episodes} eps")
            r = evaluate(model, vn, n, args.eval_episodes, args.max_steps, rcfg)
            print(f"[Stage n_sheep={n}] sr={r['sr']*100:.0f}% "
                  f"mean_len={r['mean_len']:.0f} "
                  f"mean_min_pen={r['mean_min_pen']:.1f}m "
                  f"mean_act={r['mean_act']:.2f}")
            # Failure-mode breakdown
            if r["failure_modes"]:
                # Most frequent outcome first.
                modes = " ".join(
                    f"{k}={v}" for k, v in sorted(
                        r["failure_modes"].items(), key=lambda x: -x[1]))
                print(f" failure modes: {modes}")
            # Reward breakdown
            if "reward_per_step" in r:
                rps = r["reward_per_step"]
                print(f" reward/step: " + " ".join(
                    f"{k}={v:+.4f}" for k, v in rps.items()))
            # Episode visualisation: trajectory + timeseries + animated GIF
            hist = run_and_record(model, vn, n, args.max_steps, rcfg,
                                  seed=1000 + n)
            tag = "success" if hist["success"] else "fail"
            plot_trajectory(
                hist,
                os.path.join(eval_dir, f"traj_{n}s_{tag}.png"))
            plot_timeseries(
                hist,
                os.path.join(eval_dir, f"ts_{n}s_{tag}.png"))
            if not args.no_gif:
                save_episode_gif(
                    hist,
                    os.path.join(eval_dir, f"ep_{n}s_{tag}.gif"),
                    fps=args.gif_fps, skip=args.gif_skip)
            r["n_sheep"] = n
            stage_results.append(r)
        # Save artefacts
        model.save(os.path.join(run_dir, "final_model"))
        vn.save(os.path.join(run_dir, "vecnorm.pkl"))
        with open(os.path.join(run_dir, "stage_results.json"), "w") as f:
            json.dump(stage_results, f, indent=2)
    finally:
        # Best-effort shutdown of the env workers; errors while closing are
        # deliberately ignored so they cannot mask a training failure.
        try:
            vn.close()
        except Exception:
            pass
    # Summary
    elapsed = (time.time() - t0) / 60
    print("\n" + "=" * 70)
    print(" TRAINING SUMMARY")
    print("=" * 70)
    for r in stage_results:
        print(f" n_sheep={r['n_sheep']} sr={r['sr']*100:>3.0f}% "
              f"len={r['mean_len']:>5.0f} min_pen={r['mean_min_pen']:>5.1f}m "
              f"act={r['mean_act']:.2f}")
    print(f"\n Total time: {elapsed:.1f} min")
    print(f" Artefacts: {run_dir}/")
    plot_success_rate(stage_results, os.path.join(run_dir, "success_rate.png"))
    print(f" Plots: {run_dir}/success_rate.png, {eval_dir}/")


if __name__ == "__main__":
    main()
-412
View File
@@ -1,412 +0,0 @@
"""
PPO training with attention-based policy (train_at.py).
Key difference from train.py
-----------------------------
- Observation exposes ALL sheep as individual per-sheep tokens rather than
only the top-3 farthest. The policy therefore has complete flock visibility
at any sheep count — no hidden sheep even at n=10.
- A TransformerFeaturesExtractor processes the sheep tokens with multi-head
self-attention (permutation-invariant), then mean-pools over valid tokens
and concatenates the result with global dog/pen features.
- Curriculum transition uses the same mixed-env approach as train.py: half
the envs stay at n-1 for the first half of each new stage to suppress
catastrophic forgetting.
Observation layout (7 + MAX_SHEEP*6 = 67 dims, fixed)
-------------------------------------------------------
Global (7):
dog_x / FIELD, dog_y / FIELD,
cos(heading), sin(heading),
(pen_x - dog_x) / D, (pen_y - dog_y) / D,
n_active / n_sheep
Per sheep i (6):
(sheep_x - dog_x) / D, (sheep_y - dog_y) / D, ← pos rel to dog
(pen_x - sheep_x) / D, (pen_y - sheep_y) / D, ← sheep-to-pen
is_active 1.0 if not penned, else 0.0
is_valid 1.0 if i < n_sheep, else 0.0 (padding sentinel)
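After VecNormalize, is_valid normalises to > 0 for real sheep and to ≤ 0
for padding tokens (the running mean of the flag lies in [0, 1]; a slot that
has only ever been padding normalises to exactly 0), so treating every token
with is_valid ≤ 0 as padding separates real from padded sheep without any
extra bookkeeping.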
Usage
-----
python train_at.py # defaults from config.json
python train_at.py --max-sheep 10 --steps-per-stage 2000000
python train_at.py --embed-dim 128 --n-heads 4 --n-layers 3
"""
import argparse
import json
import os
import time
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize
from herding_env import HerdingEnv
from train import ProgressCallback, _classify, COMPACT_RADIUS, DEFAULT_CONFIG
from viz import (
run_and_record, plot_trajectory, plot_timeseries,
plot_success_rate, save_episode_gif,
)
# ── Per-sheep token observation environment ───────────────────────────────────
class HerdingEnvAt(HerdingEnv):
    """
    HerdingEnv variant that exposes every sheep as its own observation token.

    Only the observation space and the observation builder are overridden;
    everything else is inherited from the parent class.
    """

    OBS_GLOBAL = 7  # global features at the front of the flat vector
    OBS_SHEEP = 6   # features per sheep token

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        flat_len = self.OBS_GLOBAL + self.MAX_SHEEP * self.OBS_SHEEP
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(flat_len,), dtype=np.float32,
        )

    def _obs(self) -> np.ndarray:
        """Return global features followed by ``MAX_SHEEP`` flattened tokens.

        Slots at index ``>= n_sheep`` stay all-zero, so their trailing
        ``is_valid`` feature is 0 and marks them as padding.
        """
        field_scale = self.FIELD
        span = 2.0 * self.FIELD
        # Target point: the pen entry when ENTRY_AWARE is set, else its centre.
        target = self.PEN_ENTRY if self.ENTRY_AWARE else self.PEN_CENTER
        dog_x, dog_y = self.dog_pos[0], self.dog_pos[1]

        not_penned = ~self.penned[:self.n_sheep]
        n_active = int(not_penned.sum())

        head = np.array(
            [
                dog_x / field_scale,
                dog_y / field_scale,
                float(np.cos(self.dog_heading)),
                float(np.sin(self.dog_heading)),
                (target[0] - dog_x) / span,
                (target[1] - dog_y) / span,
                n_active / max(self.n_sheep, 1),
            ],
            dtype=np.float32,
        )

        tokens = np.zeros((self.MAX_SHEEP, self.OBS_SHEEP), dtype=np.float32)
        for idx in range(self.n_sheep):
            sheep_xy = self.sheep_pos[idx]
            is_active = float(not self.penned[idx])
            tokens[idx] = [
                (sheep_xy[0] - dog_x) / span,   # sheep position relative to dog
                (sheep_xy[1] - dog_y) / span,
                (target[0] - sheep_xy[0]) / span,  # vector from sheep to pen
                (target[1] - sheep_xy[1]) / span,
                is_active,
                1.0,  # is_valid: this slot holds a real sheep
            ]

        return np.concatenate([head, tokens.ravel()])
# ── Attention features extractor ──────────────────────────────────────────────
class ShepherdAttentionExtractor(BaseFeaturesExtractor):
    """
    Multi-head self-attention over per-sheep tokens, mean-pooled over valid
    (non-padding) tokens and concatenated with the global dog/pen features.

    A token is treated as real when its ``is_valid`` feature is > 0 and as
    padding when it is <= 0; this is the convention the observation takes on
    after VecNormalize (real sheep normalise above zero, padding does not).

    Args:
        observation_space: Flat observation space of ``HerdingEnvAt``.
        embed_dim: Token embedding width (also the pooled feature width).
        n_heads: Number of attention heads.
        n_layers: Number of transformer encoder layers.
        ff_dim: Width of the encoder feed-forward sub-layer.
    """

    GLOBAL_DIM = HerdingEnvAt.OBS_GLOBAL  # 7
    SHEEP_DIM = HerdingEnvAt.OBS_SHEEP    # 6
    MAX_SHEEP = HerdingEnv.MAX_SHEEP      # 10
    VALID_IDX = 5  # index of is_valid within each token

    def __init__(self, observation_space, embed_dim: int = 64,
                 n_heads: int = 4, n_layers: int = 2, ff_dim: int = 128):
        super().__init__(observation_space,
                         features_dim=self.GLOBAL_DIM + embed_dim)
        self.sheep_embed = nn.Linear(self.SHEEP_DIM, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=n_heads, dim_feedforward=ff_dim,
            dropout=0.0, batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer,
                                                 num_layers=n_layers,
                                                 enable_nested_tensor=False)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Map a ``(B, GLOBAL_DIM + MAX_SHEEP*SHEEP_DIM)`` batch to features.

        Returns a ``(B, GLOBAL_DIM + embed_dim)`` tensor: the untouched global
        features followed by the mean of the encoded valid tokens (all zeros
        for a row that has no valid token).
        """
        B = obs.shape[0]
        global_feats = obs[:, :self.GLOBAL_DIM]                      # (B, 7)
        # reshape (not view) so a non-contiguous input cannot raise.
        tokens = obs[:, self.GLOBAL_DIM:].reshape(
            B, self.MAX_SHEEP, self.SHEEP_DIM)                       # (B, 10, 6)

        # is_valid after VecNorm: real > 0, padding <= 0
        is_valid_norm = tokens[:, :, self.VALID_IDX]                 # (B, 10)
        key_padding_mask = is_valid_norm <= 0.0                      # True -> ignore

        # A row whose keys are ALL masked makes the attention softmax run over
        # an empty set, which can yield NaNs. Unmask such rows for the
        # attention pass only; the pooling weights below are still zero for
        # them, so their pooled embedding stays exactly zero.
        all_padded = key_padding_mask.all(dim=1, keepdim=True)       # (B, 1)
        key_padding_mask = key_padding_mask & ~all_padded

        x = self.sheep_embed(tokens)                                 # (B, 10, E)
        x = self.transformer(x, src_key_padding_mask=key_padding_mask)

        valid_w = (is_valid_norm > 0.0).float().unsqueeze(-1)        # (B, 10, 1)
        # clamp avoids 0/0 when a row has no valid token.
        pooled = (x * valid_w).sum(1) / valid_w.sum(1).clamp(min=1.0)
        return torch.cat([global_feats, pooled], dim=1)              # (B, 7+E)
# ── Environment factory ───────────────────────────────────────────────────────
def make_env_at(n_sheep, seed, max_steps, reward_cfg=None):
    """Return a zero-argument factory that builds one seeded ``HerdingEnvAt``.

    Args:
        n_sheep: Flock size passed to the environment.
        seed: Seed used for the initial ``reset``.
        max_steps: Episode step limit passed to the environment.
        reward_cfg: Optional reward configuration passed through unchanged.
    """
    def build_env():
        token_env = HerdingEnvAt(
            n_sheep=n_sheep,
            max_steps=max_steps,
            reward_cfg=reward_cfg,
        )
        # Seed the environment's RNG; the returned observation is discarded.
        token_env.reset(seed=seed)
        return token_env

    return build_env
# ── Evaluation ────────────────────────────────────────────────────────────────
def evaluate_at(model, vn_template, n_sheep, n_episodes, max_steps,
                reward_cfg=None):
    """Evaluate the attention policy deterministically at a fixed sheep count.

    A single-env ``DummyVecEnv`` built with ``make_env_at`` is wrapped in a
    frozen ``VecNormalize`` (``training=False``, rewards not normalised) that
    receives deep copies of the running statistics of ``vn_template``.

    Args:
        model: Policy exposing ``predict(obs, deterministic=True)``.
        vn_template: ``VecNormalize`` whose ``obs_rms``/``ret_rms`` are copied.
        n_sheep: Flock size to evaluate at.
        n_episodes: Number of episodes to roll out (must be positive).
        max_steps: Episode step limit passed to the environment.
        reward_cfg: Optional reward configuration passed to the environment.

    Returns:
        Dict with ``sr`` (success rate), ``mean_len``, ``mean_min_pen``,
        ``mean_act``, ``failure_modes`` (label -> count) and, when the env
        reports ``rcomps`` in ``info``, ``reward_per_step`` (per-key mean).

    Raises:
        ValueError: If ``n_episodes`` is not positive.
    """
    if n_episodes <= 0:
        raise ValueError(f"n_episodes must be positive, got {n_episodes}")

    raw = DummyVecEnv([make_env_at(n_sheep, 9999, max_steps, reward_cfg)])
    vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
    vn.obs_rms = deepcopy(vn_template.obs_rms)
    vn.ret_rms = deepcopy(vn_template.ret_rms)

    successes = 0
    ep_lens, min_pen_list, action_mags = [], [], []
    failure_counts, rc_sums = {}, {}
    rc_n = 0

    try:
        inner = vn.envs[0]
        for _ in range(n_episodes):
            obs = vn.reset()
            done = False
            steps, min_pen = 0, float("inf")
            mags, ep_radii, ep_com_dists = [], [], []
            while not done:
                # Sample the flock BEFORE stepping. DummyVecEnv resets the
                # env inside step() as soon as an episode ends, so reading
                # the env after the final step would record the freshly
                # reset state of the next episode, not this one.
                com, radius, _ = inner._flock_stats()
                pen_dist = float(np.linalg.norm(com - inner.PEN_CENTER))
                min_pen = min(min_pen, pen_dist)
                ep_radii.append(radius)
                ep_com_dists.append(pen_dist)

                action, _ = model.predict(obs, deterministic=True)
                obs, _, dones, infos = vn.step(action)
                done = dones[0]
                mags.append(float(np.linalg.norm(action[0])))
                steps += 1

                # Optional per-step reward component breakdown.
                rc = infos[0].get("rcomps")
                if rc:
                    for k, v in rc.items():
                        rc_sums[k] = rc_sums.get(k, 0.0) + v
                    rc_n += 1

            # ``infos`` still holds the info dict of the terminal step.
            n_penned = infos[0].get("n_penned", 0)
            successes += int(n_penned == n_sheep)
            ep_lens.append(steps)
            min_pen_list.append(min_pen)
            action_mags.extend(mags)
            mode = _classify(ep_radii, ep_com_dists, n_penned, n_sheep)
            failure_counts[mode] = failure_counts.get(mode, 0) + 1
    finally:
        # Release the evaluation env even if a rollout raises.
        vn.close()

    result = {
        "sr": successes / n_episodes,
        "mean_len": float(np.mean(ep_lens)),
        "mean_min_pen": float(np.mean(min_pen_list)),
        "mean_act": float(np.mean(action_mags)) if action_mags else 0.0,
        "failure_modes": failure_counts,
    }
    if rc_n > 0:
        result["reward_per_step"] = {k: v / rc_n for k, v in rc_sums.items()}
    return result
# ── CLI ───────────────────────────────────────────────────────────────────────
def parse_args():
    """Build the command-line parser for attention-policy training and parse argv.

    Returns:
        The ``argparse.Namespace`` with the curriculum/evaluation options and
        the transformer architecture options.
    """
    parser = argparse.ArgumentParser(
        description="PPO + attention training for herding task")

    parser.add_argument("--config", type=str, default=None)

    # Plain integer options without help text, kept in their original order.
    for flag, default_value in (
        ("--max-sheep", 10),
        ("--steps-per-stage", 1_500_000),
        ("--n-envs", 8),
        ("--max-steps", 2500),
        ("--eval-episodes", 30),
    ):
        parser.add_argument(flag, type=int, default=default_value)

    parser.add_argument("--run-dir", type=str, default=None)
    parser.add_argument("--no-gif", action="store_true")
    parser.add_argument("--gif-fps", type=int, default=20)
    parser.add_argument("--gif-skip", type=int, default=3)

    # Attention architecture
    architecture_options = (
        ("--embed-dim", 64, "Transformer embedding dimension (default 64)"),
        ("--n-heads", 4, "Number of attention heads (default 4)"),
        ("--n-layers", 2,
         "Number of transformer encoder layers (default 2)"),
        ("--ff-dim", 128, "Transformer feed-forward dim (default 128)"),
    )
    for flag, default_value, help_text in architecture_options:
        parser.add_argument(flag, type=int, default=default_value,
                            help=help_text)

    return parser.parse_args()
# ── Main ──────────────────────────────────────────────────────────────────────
def main():
    """Train the attention-based PPO policy through the sheep curriculum.

    Steps, in order: resolve the config (defaults, then JSON overlay), create
    the ``at_``-prefixed run directory and dump the resolved config, build the
    normalised training envs and a PPO model that uses
    ``ShepherdAttentionExtractor``, run every curriculum stage (train,
    evaluate, plot), save the model / normalisation stats / stage results,
    close the envs and print a summary plus the success-rate chart.
    """
    args = parse_args()
    # --config wins; otherwise a config.json in the working directory is used.
    cfg = dict(DEFAULT_CONFIG)
    config_path = args.config
    if config_path is None and os.path.exists("config.json"):
        config_path = "config.json"
    if config_path:
        with open(config_path) as f:
            cfg.update(json.load(f))
        print(f"Config loaded from {config_path}")
    # Only keys that exist as attributes on HerdingEnv are passed to the env.
    rcfg = {k: v for k, v in cfg.items() if hasattr(HerdingEnv, k)}
    run_dir = args.run_dir or os.path.join(
        "runs", "at_" + time.strftime("%Y%m%d_%H%M%S"))
    eval_dir = os.path.join(run_dir, "eval")
    os.makedirs(eval_dir, exist_ok=True)
    with open(os.path.join(run_dir, "config.json"), "w") as f:
        json.dump(cfg, f, indent=2)
    print(f"Config: {cfg}")
    print(f"Run dir: {run_dir}")
    print(f"Curriculum: 1 → {args.max_sheep} sheep, "
          f"{args.steps_per_stage:,} steps/stage")
    print(f"Transformer: embed={args.embed_dim} heads={args.n_heads} "
          f"layers={args.n_layers} ff={args.ff_dim}\n")
    # Training envs all start at one sheep; the flock size is raised per stage.
    train_env = SubprocVecEnv([
        make_env_at(1, seed=i, max_steps=args.max_steps, reward_cfg=rcfg)
        for i in range(args.n_envs)
    ])
    vn = VecNormalize(train_env, norm_obs=True, norm_reward=True, clip_obs=10.0)
    # MlpPolicy with a custom features extractor: the transformer encodes the
    # sheep tokens, the [256, 256] MLP sits on top of its output.
    model = PPO(
        "MlpPolicy", vn,
        learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10,
        gamma=0.995, gae_lambda=0.95, clip_range=0.2,
        ent_coef=cfg.get("ent_coef", 0.02), vf_coef=0.5, max_grad_norm=0.5,
        policy_kwargs=dict(
            features_extractor_class=ShepherdAttentionExtractor,
            features_extractor_kwargs=dict(
                embed_dim=args.embed_dim,
                n_heads=args.n_heads,
                n_layers=args.n_layers,
                ff_dim=args.ff_dim,
            ),
            net_arch=[256, 256],
        ),
        device="cpu",
        verbose=0,
    )
    stage_results = []
    t0 = time.time()
    try:
        for n in range(1, args.max_sheep + 1):
            if n == 1:
                # First stage: every env already runs with one sheep.
                print(f"\n[Stage n_sheep=1] training {args.steps_per_stage:,} steps")
                model.learn(
                    total_timesteps=args.steps_per_stage,
                    reset_num_timesteps=True,
                    callback=ProgressCallback("1 sheep", freq=100_000),
                )
            else:
                # Mixed transition: the first half of the envs stays at n-1
                # sheep and the rest moves to n for the first half of the
                # stage budget, then every env moves to n.
                half = max(1, args.n_envs // 2)
                mix_steps = args.steps_per_stage // 2
                full_steps = args.steps_per_stage - mix_steps
                for i in range(half):
                    vn.env_method("set_n_sheep", n - 1, indices=[i])
                for i in range(half, args.n_envs):
                    vn.env_method("set_n_sheep", n, indices=[i])
                print(f"\n[Stage n_sheep={n}] mixed ({n-1}/{n} sheep) "
                      f"{mix_steps:,} steps")
                model.learn(
                    total_timesteps=mix_steps,
                    reset_num_timesteps=False,
                    callback=ProgressCallback(f"{n-1}{n} mix", freq=100_000),
                )
                vn.env_method("set_n_sheep", n)
                print(f"[Stage n_sheep={n}] full ({n} sheep) {full_steps:,} steps")
                model.learn(
                    total_timesteps=full_steps,
                    reset_num_timesteps=False,
                    callback=ProgressCallback(f"{n} sheep", freq=100_000),
                )
            print(f"[Stage n_sheep={n}] evaluating {args.eval_episodes} eps")
            r = evaluate_at(model, vn, n, args.eval_episodes,
                            args.max_steps, rcfg)
            print(f"[Stage n_sheep={n}] sr={r['sr']*100:.0f}% "
                  f"mean_len={r['mean_len']:.0f} "
                  f"mean_min_pen={r['mean_min_pen']:.1f}m "
                  f"mean_act={r['mean_act']:.2f}")
            if r["failure_modes"]:
                # Most frequent outcome first.
                modes = " ".join(
                    f"{k}={v}" for k, v in sorted(
                        r["failure_modes"].items(), key=lambda x: -x[1]))
                print(f" failure modes: {modes}")
            if "reward_per_step" in r:
                rps = r["reward_per_step"]
                print(" reward/step: " + " ".join(
                    f"{k}={v:+.4f}" for k, v in rps.items()))
            # Record one episode with the token-observation env for the plots.
            hist = run_and_record(
                model, vn, n, args.max_steps, rcfg,
                seed=1000 + n, make_env_fn=make_env_at,
            )
            tag = "success" if hist["success"] else "fail"
            plot_trajectory(hist, os.path.join(eval_dir, f"traj_{n}s_{tag}.png"))
            plot_timeseries(hist, os.path.join(eval_dir, f"ts_{n}s_{tag}.png"))
            if not args.no_gif:
                save_episode_gif(
                    hist,
                    os.path.join(eval_dir, f"ep_{n}s_{tag}.gif"),
                    fps=args.gif_fps, skip=args.gif_skip)
            r["n_sheep"] = n
            stage_results.append(r)
        # Persist the trained model, normalisation stats and stage metrics.
        model.save(os.path.join(run_dir, "final_model"))
        vn.save(os.path.join(run_dir, "vecnorm.pkl"))
        with open(os.path.join(run_dir, "stage_results.json"), "w") as f:
            json.dump(stage_results, f, indent=2)
    finally:
        # Best-effort shutdown of the env workers; errors while closing are
        # deliberately ignored so they cannot mask a training failure.
        try:
            vn.close()
        except Exception:
            pass
    elapsed = (time.time() - t0) / 60
    print("\n" + "=" * 70)
    print(" TRAINING SUMMARY (attention policy)")
    print("=" * 70)
    for r in stage_results:
        print(f" n_sheep={r['n_sheep']} sr={r['sr']*100:>3.0f}% "
              f"len={r['mean_len']:>5.0f} "
              f"min_pen={r['mean_min_pen']:>5.1f}m "
              f"act={r['mean_act']:.2f}")
    print(f"\n Total time: {elapsed:.1f} min")
    print(f" Artefacts: {run_dir}/")
    plot_success_rate(stage_results, os.path.join(run_dir, "success_rate.png"))
    print(f" Plots: {run_dir}/success_rate.png, {eval_dir}/")


if __name__ == "__main__":
    main()
+267
View File
@@ -0,0 +1,267 @@
"""Train a PPO shepherd-dog policy on ``HerdingEnv`` with curriculum.
Defaults to 16 parallel ``SubprocVecEnv`` workers feeding a CPU policy
(pass ``--device cuda`` to train on a GPU instead).
Saves checkpoints, the best-eval model, and the VecNormalize stats —
all three are needed at inference time by the Webots controller.
Usage::
python -m training.train_ppo \
--config training/configs/ppo_default.yaml \
--out-dir training/runs/baseline
To resume from a checkpoint::
python -m training.train_ppo --resume training/runs/baseline/checkpoints/ppo_500000_steps.zip
"""
from __future__ import annotations
import argparse
import os
import sys
from pathlib import Path
import yaml
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
import numpy as np
import torch as th
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import (
BaseCallback, CheckpointCallback, EvalCallback,
)
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import (
DummyVecEnv, SubprocVecEnv, VecNormalize,
)
from training.herding_env import HerdingEnv
# --------------------------------------------------------------------------
# Env factories
# --------------------------------------------------------------------------
def _make_env(rank: int, seed: int = 0):
    """Return a factory for one ``Monitor``-wrapped ``HerdingEnv``.

    Each worker is seeded with ``seed + rank`` so parallel envs differ. The
    monitor additionally records the ``is_success``, ``n_sheep`` and
    ``n_penned`` entries of the episode info.
    """
    worker_seed = seed + rank
    logged_keys = ("is_success", "n_sheep", "n_penned")

    def _thunk():
        return Monitor(HerdingEnv(seed=worker_seed), info_keywords=logged_keys)

    return _thunk
# --------------------------------------------------------------------------
# Curriculum callback
# --------------------------------------------------------------------------
class CurriculumCallback(BaseCallback):
    """Drive the env's flock-size + state-space difficulty curriculum.

    Schedule entries: {step, max_n_sheep, difficulty}. The largest entry
    whose step <= num_timesteps wins; both knobs update together. Before the
    first entry's step is reached, the first entry is used. ``difficulty``
    defaults to 1.0 when an entry omits it.

    Args:
        schedule: Non-empty iterable of schedule entries (dicts).
        vec_envs: One vec env or a list/tuple of them; every one receives the
            ``set_max_n_sheep`` / ``set_difficulty`` calls.
        verbose: When truthy, print a line each time a knob changes.

    Raises:
        ValueError: If ``schedule`` is empty.
    """

    def __init__(self, schedule, vec_envs, verbose: int = 0):
        super().__init__(verbose)
        self.schedule = sorted(schedule, key=lambda d: d["step"])
        # Fail at construction time rather than with an IndexError on the
        # first training step.
        if not self.schedule:
            raise ValueError("curriculum schedule must contain at least one entry")
        # Accept a list of envs so the eval env tracks training difficulty.
        self.vec_envs = vec_envs if isinstance(vec_envs, (list, tuple)) else [vec_envs]
        self._last_n = None
        self._last_d = None

    def _call(self, method, value):
        """Invoke ``method(value)`` on every env of every tracked vec env."""
        for v in self.vec_envs:
            try:
                v.env_method(method, value)
            except AttributeError:
                # Fall back to the wrapped vec env of a wrapper.
                v.venv.env_method(method, value)

    def _stage_at(self, t):
        """Return ``(max_n_sheep, difficulty)`` of the stage active at step ``t``."""
        active = self.schedule[0]
        for entry in self.schedule:
            # The schedule is sorted by step, so the last match is the latest.
            if t >= entry["step"]:
                active = entry
        return active["max_n_sheep"], active.get("difficulty", 1.0)

    def _on_step(self) -> bool:
        t = self.num_timesteps
        n, d = self._stage_at(t)

        changed = False
        if n != self._last_n:
            self._call("set_max_n_sheep", n)
            self._last_n = n
            changed = True
        if d != self._last_d:
            self._call("set_difficulty", d)
            self._last_d = d
            changed = True

        # Log once per change of either knob, never on unchanged steps.
        if changed and self.verbose:
            print(f"[curriculum] t={t} → max_n_sheep={n} difficulty={d}")
        return True
# --------------------------------------------------------------------------
# Main
# --------------------------------------------------------------------------
def main():
    """Parse the CLI, build the envs and PPO model, train, and save artefacts.

    Writes into ``--out-dir``: ``checkpoints/`` (periodic checkpoints),
    ``best/`` (best-eval model), ``evals/`` (evaluation logs), ``tb/``
    (TensorBoard), ``final.zip`` and, when VecNormalize is enabled,
    ``vecnormalize.pkl`` both at the top level and next to the best model.
    The vectorised envs are always closed on exit, also when training fails.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default=os.path.join(_HERE, "configs", "ppo_default.yaml"))
    parser.add_argument("--out-dir", default=os.path.join(_HERE, "runs", "latest"))
    parser.add_argument("--n-envs", type=int, default=None,
                        help="Override config n_envs.")
    parser.add_argument("--total-timesteps", type=int, default=None,
                        help="Override config total_timesteps.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--resume", type=str, default=None,
                        help="Path to a SB3 zip to resume from.")
    # SB3 recommends CPU for MlpPolicy — GPU helps CNN policies, not MLPs
    # of this size. Override with --device cuda if you really want it.
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--no-vecnorm", action="store_true",
                        help="Disable VecNormalize wrapper. Required when "
                             "resuming from a BC-pretrained policy that "
                             "wasn't trained under it.")
    parser.add_argument("--no-curriculum", action="store_true",
                        help="Skip curriculum callback (resumed policy is "
                             "already competent across the distribution).")
    parser.add_argument("--imitate-weight", type=float, default=None,
                        help="Override env W_IMITATE. Set to 0 to disable "
                             "Strömbom imitation reward.")
    parser.add_argument("--difficulty", type=float, default=None,
                        help="Override env difficulty (0=easy, 1=hard). "
                             "Used in BC fine-tune to skip easy curriculum.")
    parser.add_argument("--log-std", type=float, default=None,
                        help="Override the policy's log_std after load. "
                             "BC trained with std≈1.6 (log_std=0.5) which "
                             "is too noisy for fine-tune. Use -1.5 (std≈0.22) "
                             "to keep PPO close to the BC mean while still "
                             "exploring locally.")
    parser.add_argument("--learning-rate", type=float, default=None,
                        help="Override config learning rate. For BC "
                             "fine-tune, 5e-5 is much safer than the 3e-4 "
                             "default.")
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    n_envs = args.n_envs or cfg["n_envs"]
    total_timesteps = args.total_timesteps or cfg["total_timesteps"]
    use_curriculum = not args.no_curriculum and bool(cfg.get("curriculum"))

    out = Path(args.out_dir)
    out.mkdir(parents=True, exist_ok=True)
    (out / "checkpoints").mkdir(exist_ok=True)
    (out / "best").mkdir(exist_ok=True)
    (out / "evals").mkdir(exist_ok=True)

    print(f"[train] out={out} n_envs={n_envs} total={total_timesteps} device={args.device}")

    # --- Train env (vectorised, optionally normalised) ---
    env_fns = [_make_env(i, seed=args.seed) for i in range(n_envs)]
    venv = SubprocVecEnv(env_fns) if n_envs > 1 else DummyVecEnv(env_fns)
    eval_venv = None
    try:
        eval_venv = DummyVecEnv([_make_env(99, seed=args.seed + 999)])
        if not args.no_vecnorm:
            # NOTE(review): these statistics start fresh even with --resume;
            # the stats saved alongside a checkpoint are not restored here —
            # confirm whether a resumed run should load them.
            venv = VecNormalize(venv, norm_obs=True, norm_reward=False, clip_obs=10.0)
            eval_venv = VecNormalize(eval_venv, norm_obs=True, norm_reward=False,
                                     clip_obs=10.0, training=False)
            # Share (not copy) the running stats so eval follows training.
            eval_venv.obs_rms = venv.obs_rms
        else:
            print("[train] VecNormalize disabled (resumed policy was trained without it).")

        # Apply env-level overrides (used by BC fine-tune to disable Strömbom
        # imitation and start at full deployment difficulty).
        def _env_call(method, value):
            for v in (venv, eval_venv):
                try:
                    v.env_method(method, value)
                except AttributeError:
                    v.venv.env_method(method, value)

        if args.imitate_weight is not None:
            _env_call("set_imitate_weight", args.imitate_weight)
            print(f"[train] W_IMITATE overridden to {args.imitate_weight}")
        if args.difficulty is not None:
            _env_call("set_difficulty", args.difficulty)
            print(f"[train] difficulty pinned to {args.difficulty}")
            if use_curriculum:
                # The curriculum callback sets the difficulty on its first
                # step, so the pinned value would not survive.
                print("[train] warning: the curriculum schedule will override "
                      "--difficulty; pass --no-curriculum to keep it pinned.")

        # --- Model ---
        policy_kwargs = dict(
            net_arch=dict(pi=cfg["net_arch_pi"], vf=cfg["net_arch_vf"]),
            log_std_init=cfg.get("log_std_init", 0.0),
        )
        if args.resume:
            print(f"[train] resuming from {args.resume}")
            custom_objects = {}
            if args.learning_rate is not None:
                custom_objects["learning_rate"] = args.learning_rate
            model = PPO.load(args.resume, env=venv, device=args.device,
                             tensorboard_log=str(out / "tb"),
                             custom_objects=custom_objects or None)
            if args.log_std is not None:
                with th.no_grad():
                    model.policy.log_std.fill_(args.log_std)
                print(f"[train] log_std overridden to {args.log_std} "
                      f"(std≈{float(np.exp(args.log_std)):.2f})")
            if args.learning_rate is not None:
                print(f"[train] learning_rate overridden to {args.learning_rate}")
        else:
            # Honour --learning-rate for fresh runs too (it used to be applied
            # only when resuming, despite the flag's help text).
            learning_rate = cfg["learning_rate"]
            if args.learning_rate is not None:
                learning_rate = args.learning_rate
                print(f"[train] learning_rate overridden to {args.learning_rate}")
            model = PPO(
                cfg["policy"], venv,
                learning_rate=learning_rate,
                n_steps=cfg["n_steps"],
                batch_size=cfg["batch_size"],
                n_epochs=cfg["n_epochs"],
                gamma=cfg["gamma"],
                gae_lambda=cfg["gae_lambda"],
                clip_range=cfg["clip_range"],
                ent_coef=cfg["ent_coef"],
                vf_coef=cfg["vf_coef"],
                max_grad_norm=cfg["max_grad_norm"],
                target_kl=cfg.get("target_kl"),
                policy_kwargs=policy_kwargs,
                tensorboard_log=str(out / "tb"),
                seed=args.seed,
                device=args.device,
                verbose=1,
            )

        # --- Callbacks ---
        # Frequencies are divided by n_envs because callbacks count vec-env
        # steps, each of which advances n_envs timesteps.
        ckpt_cb = CheckpointCallback(
            save_freq=max(1, cfg["checkpoint_freq"] // n_envs),
            save_path=str(out / "checkpoints"), name_prefix="ppo",
            save_vecnormalize=True,
        )
        eval_cb = EvalCallback(
            eval_venv,
            best_model_save_path=str(out / "best"),
            log_path=str(out / "evals"),
            eval_freq=max(1, cfg["eval_freq"] // n_envs),
            n_eval_episodes=cfg["n_eval_episodes"],
            deterministic=True,
        )
        callbacks = [ckpt_cb, eval_cb]
        if use_curriculum:
            callbacks.append(CurriculumCallback(
                cfg["curriculum"], [venv, eval_venv], verbose=1,
            ))
        elif args.no_curriculum:
            print("[train] curriculum disabled — env knobs left at their current values.")

        # --- Train ---
        model.learn(total_timesteps=total_timesteps, callback=callbacks,
                    progress_bar=True)

        # --- Save final model + VecNormalize stats ---
        model.save(out / "final.zip")
        # With --no-vecnorm the env is a plain vec env without save(); calling
        # it unconditionally crashed at the very end of training.
        if isinstance(venv, VecNormalize):
            venv.save(str(out / "vecnormalize.pkl"))
            # The EvalCallback already wrote best_model.zip into out/best/ —
            # drop the VecNormalize stats next to it for the controller to
            # pick up.
            venv.save(str(out / "best" / "vecnormalize.pkl"))
        print(f"[train] done. saved to {out}")
    finally:
        # Shut down the subprocess workers even when training fails.
        venv.close()
        if eval_venv is not None:
            eval_venv.close()


if __name__ == "__main__":
    main()
-342
View File
@@ -1,342 +0,0 @@
"""
All visualization for the herding policy: trajectory plots, timeseries plots,
success-rate bar chart, and animated GIFs.
Used both by train.py (auto-rendered after each curriculum stage) and as a CLI
to render a fresh episode against a saved model.
CLI usage:
python viz.py --run-dir runs/v1 --n-sheep 5
python viz.py --run-dir runs/v1 --n-sheep 10 --no-gif
python viz.py --model runs/v1/final_model.zip --vecnorm runs/v1/vecnorm.pkl \\
--n-sheep 3 --out-dir vis_v1_3sheep
"""
import argparse
import os
import json
from copy import deepcopy
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.animation as animation
from matplotlib.collections import LineCollection
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from herding_env import HerdingEnv
# ── Palette ──────────────────────────────────────────────────────────────────
# Per-sheep colours as hex strings. The plotting functions in this module pick
# a colour with ``SHEEP_COLORS[i % len(SHEEP_COLORS)]``, so flocks with more
# sheep than entries here reuse colours from the start of the list.
SHEEP_COLORS = [
"#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00",
"#a65628", "#f781bf", "#999999", "#66c2a5", "#fc8d62",
]
# Colour used for the dog's path, start/end markers and GIF marker/trail.
DOG_COLOR = "#4e342e"
# ── Common drawing primitives ────────────────────────────────────────────────
def draw_field(ax):
    """Paint the static backdrop shared by every plot onto ``ax``.

    Sets a square viewport with equal axis scaling, a background colour, an
    unfilled outer boundary rectangle, and a filled rectangle labelled "pen".
    """
    outline = "#795548"

    # Viewport: fixed limits with a 1:1 aspect ratio.
    ax.set_xlim(-16, 16)
    ax.set_ylim(-16, 16)
    ax.set_aspect("equal")
    ax.set_facecolor("#dcedc8")

    # Outer boundary (outline only).
    boundary = mpatches.Rectangle(
        (-15, -15), 30, 30, fill=False, edgecolor=outline, lw=2)
    ax.add_patch(boundary)

    # Pen area, filled, with a text label at its centre.
    pen = mpatches.Rectangle(
        (10, -15), 3, 7, facecolor="#ffe082", edgecolor=outline, lw=2)
    ax.add_patch(pen)
    ax.text(11.5, -11.5, "pen", ha="center", va="center",
            fontsize=8, color=outline)
def faded_path(ax, xs, ys, color, lw=1.5, label=None):
    """Draw the polyline ``(xs, ys)`` on ``ax`` with opacity rising along it.

    Each segment gets its own alpha, linearly spaced from 0.15 (first segment)
    to 1.0 (last segment). Paths with fewer than two points draw nothing, and
    in that case no legend entry is added either.

    Args:
        ax: Matplotlib axes to draw on.
        xs, ys: Equal-length coordinate sequences.
        color: Any colour accepted by ``matplotlib.colors.to_rgb``.
        lw: Line width for the segments and the legend proxy.
        label: Optional legend label; a falsy value adds no legend entry.
    """
    if len(xs) < 2:
        return

    # (n, 2) vertices -> (n - 1, 2, 2) array of [start, end] segment pairs.
    vertices = np.column_stack([xs, ys])
    segments = np.stack([vertices[:-1], vertices[1:]], axis=1)

    base_rgb = matplotlib.colors.to_rgb(color)
    rgba = [(*base_rgb, alpha)
            for alpha in np.linspace(0.15, 1.0, len(segments))]
    ax.add_collection(LineCollection(segments, colors=rgba, linewidth=lw))

    if label:
        # A LineCollection has no legend entry of its own here; add an empty
        # proxy line carrying the label.
        ax.plot([], [], color=color, lw=lw, label=label)
# ── Episode rollout ──────────────────────────────────────────────────────────
def make_eval_env(n_sheep, seed, max_steps, reward_cfg=None):
    """Return a zero-argument factory that builds a seeded ``HerdingEnv``.

    The factory form is what ``DummyVecEnv`` expects. Nothing is constructed
    until the returned callable is invoked; at that point the environment is
    created and reset once with ``seed``.
    """
    env_kwargs = dict(n_sheep=n_sheep, max_steps=max_steps,
                      reward_cfg=reward_cfg)

    def _build():
        herding_env = HerdingEnv(**env_kwargs)
        herding_env.reset(seed=seed)
        return herding_env

    return _build
def run_and_record(model, vn_template, n_sheep, max_steps,
                   reward_cfg=None, seed=42, make_env_fn=None):
    """Run one deterministic episode and return its full trajectory history.

    Observations fed to the policy are normalised with a frozen copy of
    ``vn_template``'s running statistics. The underlying environment is
    stepped *directly* rather than through the vectorised wrapper:
    ``DummyVecEnv`` resets an environment inside ``step()`` as soon as its
    episode ends, so reading ``dog_pos`` / ``sheep_pos`` / ``penned`` after a
    wrapped step would record the first state of the next episode as the
    final frame (and would miss a sheep penned on the very last step).

    NOTE(review): assumes the env follows the Gymnasium API, i.e.
    ``step()`` returns ``(obs, reward, terminated, truncated, info)``, and
    that observations are plain arrays (not dicts) — confirm against
    ``HerdingEnv``.

    Args:
        model: Policy exposing ``predict(obs, deterministic=True)``.
        vn_template: ``VecNormalize`` whose ``obs_rms`` / ``ret_rms`` are
            deep-copied; the template itself is not stepped or modified.
        n_sheep: Number of sheep; sizes the per-sheep history lists.
        max_steps: Forwarded to the environment factory.
        reward_cfg: Forwarded to the environment factory.
        seed: Forwarded to the environment factory.
        make_env_fn: Optional replacement for ``make_eval_env`` with the same
            ``(n_sheep, seed, max_steps, reward_cfg)`` signature.

    Returns:
        dict with per-step lists ``dog_xs``, ``dog_ys``, ``radii``,
        ``action_mags``, ``rewards``; per-sheep lists of per-step values
        ``sheep_xs``, ``sheep_ys``, ``sheep_penned``, ``pen_dists``;
        ``penned_at`` (per sheep, the 0-based index into the per-step lists
        of the first frame in which that sheep is penned, or ``None``);
        and the scalars ``n_penned``, ``n_sheep``, ``success``, ``steps``.
    """
    _factory = make_env_fn or make_eval_env
    raw = DummyVecEnv([_factory(n_sheep, seed, max_steps, reward_cfg)])
    vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
    vn.obs_rms = deepcopy(vn_template.obs_rms)
    vn.ret_rms = deepcopy(vn_template.ret_rms)

    dog_xs, dog_ys = [], []
    sheep_xs = [[] for _ in range(n_sheep)]
    sheep_ys = [[] for _ in range(n_sheep)]
    sheep_penned = [[] for _ in range(n_sheep)]
    radii = []
    pen_dists = [[] for _ in range(n_sheep)]
    action_mags = []
    rewards = []
    penned_at = [None] * n_sheep
    step = 0
    n_penned = 0

    try:
        # Initial reset goes through the wrapper so the first observation is
        # normalised exactly as the policy expects.
        obs = vn.reset()
        inner = vn.envs[0]
        obs_dtype = vn.observation_space.dtype
        done = False
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            # Step the raw env ourselves: no auto-reset, so ``inner`` still
            # holds the true terminal state when ``done`` becomes True.
            raw_obs, reward, terminated, truncated, info = inner.step(action[0])
            done = bool(terminated or truncated)
            step += 1
            frame = step - 1  # index of the frame recorded below

            dog_xs.append(float(inner.dog_pos[0]))
            dog_ys.append(float(inner.dog_pos[1]))
            _, radius, _ = inner._flock_stats()
            radii.append(radius)
            rewards.append(float(reward))
            action_mags.append(float(np.linalg.norm(action[0])))
            for i in range(n_sheep):
                sheep_xs[i].append(float(inner.sheep_pos[i][0]))
                sheep_ys[i].append(float(inner.sheep_pos[i][1]))
                sheep_penned[i].append(bool(inner.penned[i]))
                pen_dists[i].append(
                    float(np.linalg.norm(inner.sheep_pos[i] - inner.PEN_CENTER)))
                if inner.penned[i] and penned_at[i] is None:
                    penned_at[i] = frame
            n_penned = info.get("n_penned", 0)

            if not done:
                # Re-create the batched, normalised observation the wrapper
                # would have produced for the next policy query.
                batched = np.asarray(raw_obs, dtype=obs_dtype)[None]
                obs = vn.normalize_obs(batched)
    finally:
        vn.close()

    return dict(
        dog_xs=dog_xs, dog_ys=dog_ys,
        sheep_xs=sheep_xs, sheep_ys=sheep_ys,
        sheep_penned=sheep_penned,
        radii=radii, pen_dists=pen_dists,
        action_mags=action_mags, rewards=rewards,
        penned_at=penned_at,
        n_penned=n_penned, n_sheep=n_sheep,
        success=n_penned == n_sheep, steps=step,
    )
# ── Static plots ─────────────────────────────────────────────────────────────
def plot_trajectory(hist, out_path):
    """Save a top-down PNG of every sheep path and the dog path.

    Each sheep gets a fading path, a circle at its first recorded position
    and a star at the position indexed by ``hist["penned_at"][i]`` (or at the
    last recorded position when that entry is ``None``). The dog gets a
    fading path with a square at the start and a diamond at the end.

    Args:
        hist: History dict as returned by ``run_and_record``.
        out_path: Destination image path.
    """
    fig, ax = plt.subplots(figsize=(7, 7))
    draw_field(ax)

    palette_size = len(SHEEP_COLORS)
    for idx in range(hist["n_sheep"]):
        colour = SHEEP_COLORS[idx % palette_size]
        path_x = hist["sheep_xs"][idx]
        path_y = hist["sheep_ys"][idx]
        faded_path(ax, path_x, path_y, colour, lw=1.2,
                   label=f"sheep {idx + 1}")
        ax.plot(path_x[0], path_y[0], "o", color=colour, ms=7, zorder=4)

        penned_idx = hist["penned_at"][idx]
        star_at = -1 if penned_idx is None else penned_idx
        ax.plot(path_x[star_at], path_y[star_at], "*", color=colour,
                ms=11, zorder=5)

    dog_x, dog_y = hist["dog_xs"], hist["dog_ys"]
    faded_path(ax, dog_x, dog_y, DOG_COLOR, lw=2.0, label="dog")
    ax.plot(dog_x[0], dog_y[0], "s", color=DOG_COLOR, ms=10, zorder=5)
    ax.plot(dog_x[-1], dog_y[-1], "D", color=DOG_COLOR, ms=10, zorder=5)

    if hist["success"]:
        result = "SUCCESS"
    else:
        result = f"FAIL ({hist['n_penned']}/{hist['n_sheep']})"
    ax.set_title(f"n={hist['n_sheep']} {result} {hist['steps']} steps",
                 fontsize=12)
    ax.legend(loc="upper left", fontsize=8)

    plt.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)
def plot_timeseries(hist, out_path):
    """Save a four-panel PNG of per-step episode metrics.

    Panels, top to bottom: flock radius, per-sheep distance to the pen (with
    a dotted vertical line at each sheep's ``penned_at`` value), dog action
    magnitude, and reward. All panels share the x axis, which runs over
    ``range(hist["steps"])``.

    Args:
        hist: History dict as returned by ``run_and_record``.
        out_path: Destination image path.
    """
    steps_axis = np.arange(hist["steps"])
    fig, axes = plt.subplots(4, 1, figsize=(12, 10), sharex=True)
    ax_radius, ax_dist, ax_action, ax_reward = axes

    # Panel 1: flock radius with a dashed reference line at 5.0.
    ax_radius.plot(steps_axis, hist["radii"], color="steelblue")
    ax_radius.axhline(5.0, color="orange", ls="--", lw=1, label="compact (5m)")
    ax_radius.set_ylabel("flock radius (m)")
    ax_radius.legend(fontsize=8)
    ax_radius.set_title("Flock radius")

    # Panel 2: one distance curve per sheep.
    sheep_count = hist["n_sheep"]
    for idx in range(sheep_count):
        colour = SHEEP_COLORS[idx % len(SHEEP_COLORS)]
        ax_dist.plot(steps_axis, hist["pen_dists"][idx], color=colour, lw=1,
                     label=f"sheep {idx + 1}")
        penned_idx = hist["penned_at"][idx]
        if penned_idx is not None:
            ax_dist.axvline(penned_idx, color=colour, ls=":", lw=1)
    ax_dist.set_ylabel("dist to pen (m)")
    ax_dist.legend(fontsize=7, ncol=min(sheep_count, 5))
    ax_dist.set_title("Per-sheep distance to pen")

    # Panel 3: norm of the action vector, with a dashed line at 1.0.
    ax_action.plot(steps_axis, hist["action_mags"], color="tomato", lw=1)
    ax_action.axhline(1.0, color="gray", ls="--", lw=1, label="max")
    ax_action.set_ylabel("action ||(vx,vy)||")
    ax_action.set_ylim(0, 1.5)
    ax_action.set_title("Dog action magnitude")
    ax_action.legend(fontsize=8)

    # Panel 4: per-step reward around a zero baseline.
    ax_reward.plot(steps_axis, hist["rewards"], color="purple", lw=1,
                   alpha=0.7)
    ax_reward.axhline(0, color="black", lw=0.5)
    ax_reward.set_ylabel("reward")
    ax_reward.set_xlabel("step")
    ax_reward.set_title("Reward per step")

    if hist["success"]:
        result = "SUCCESS"
    else:
        result = f"FAIL ({hist['n_penned']}/{hist['n_sheep']})"
    fig.suptitle(f"n_sheep={hist['n_sheep']} {result} {hist['steps']} steps",
                 fontsize=13)

    plt.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)
def plot_success_rate(stage_results, out_path):
    """Save a bar chart of success rate (in percent) per sheep count.

    Args:
        stage_results: Iterable of mappings with an ``"n_sheep"`` entry and an
            ``"sr"`` entry; ``"sr"`` is multiplied by 100 for display.
        out_path: Destination image path.
    """
    fig, ax = plt.subplots(figsize=(8, 4))

    sheep_counts = []
    percents = []
    for stage in stage_results:
        sheep_counts.append(stage["n_sheep"])
        percents.append(stage["sr"] * 100)

    bar_container = ax.bar(sheep_counts, percents, color="steelblue",
                           edgecolor="white")
    ax.set_xlabel("Sheep count")
    ax.set_ylabel("Success rate (%)")
    ax.set_ylim(0, 105)
    ax.axhline(90, color="orange", ls="--", lw=1, label="90% target")

    # Print the rounded percentage just above each bar.
    for rect, pct in zip(bar_container, percents):
        x_mid = rect.get_x() + rect.get_width() / 2
        y_top = rect.get_height() + 1
        ax.text(x_mid, y_top, f"{pct:.0f}%", ha="center", fontsize=9)

    ax.legend()
    ax.set_title("Evaluation success rate per sheep count")
    plt.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)
# ── Animated GIF ─────────────────────────────────────────────────────────────
def save_episode_gif(hist, out_path, fps=20, skip=3):
    """Render a recorded episode as an animated GIF.

    Args:
        hist: History dict as returned by ``run_and_record``.
        out_path: Destination GIF path.
        fps: Playback rate passed to the Pillow writer.
        skip: Keep every ``skip``-th recorded step (values below 1 are treated
            as 1). The last recorded step is always included.
    """
    n_sheep = hist["n_sheep"]
    total_steps = hist["steps"]
    last_index = total_steps - 1

    stride = max(1, skip)
    frames = list(range(0, total_steps, stride))
    if frames[-1] != last_index:
        # Make sure the animation ends on the true final frame.
        frames.append(last_index)

    fig, ax = plt.subplots(figsize=(6, 6))
    draw_field(ax)
    title = ax.text(0, 16.5, "", ha="center", fontsize=11)
    dog_marker, = ax.plot([], [], "s", color=DOG_COLOR, ms=12,
                          markeredgecolor="black", markeredgewidth=1.5,
                          zorder=5)

    base_colours = [SHEEP_COLORS[i % len(SHEEP_COLORS)]
                    for i in range(n_sheep)]
    sheep_markers = []
    for colour in base_colours:
        marker, = ax.plot([], [], "o", color=colour, ms=10,
                          markeredgecolor="#333", markeredgewidth=1, zorder=4)
        sheep_markers.append(marker)
    dog_trail, = ax.plot([], [], color=DOG_COLOR, lw=1.0, alpha=0.5)

    def update(k):
        """Move every artist to its state at recorded step index ``k``."""
        penned_now = sum(hist["sheep_penned"][i][k] for i in range(n_sheep))
        title.set_text(
            f"n={n_sheep} step {k+1}/{total_steps} "
            f"penned {penned_now}/{n_sheep}")
        dog_marker.set_data([hist["dog_xs"][k]], [hist["dog_ys"][k]])
        dog_trail.set_data(hist["dog_xs"][:k + 1], hist["dog_ys"][:k + 1])
        for i, marker in enumerate(sheep_markers):
            marker.set_data([hist["sheep_xs"][i][k]],
                            [hist["sheep_ys"][i][k]])
            # Penned sheep are recoloured so they stand out.
            if hist["sheep_penned"][i][k]:
                marker.set_color("deeppink")
            else:
                marker.set_color(base_colours[i])
        return [title, dog_marker, dog_trail, *sheep_markers]

    anim = animation.FuncAnimation(
        fig, update, frames=frames, interval=1000 / fps, blit=False)
    anim.save(out_path, writer=animation.PillowWriter(fps=fps), dpi=80)
    plt.close(fig)
# ── CLI ──────────────────────────────────────────────────────────────────────
def _resolve_paths(args):
    """Return ``(model_path, vecnorm_path, config_path)`` for the CLI args.

    When ``args.run_dir`` is truthy, all three paths are the fixed file names
    inside that directory and ``--model`` / ``--vecnorm`` / ``--config`` are
    ignored. Otherwise the three explicit arguments are returned as given
    (any of them may be ``None``).
    """
    run_dir = args.run_dir
    if not run_dir:
        return args.model, args.vecnorm, args.config

    fixed_names = ("final_model.zip", "vecnorm.pkl", "config.json")
    return tuple(os.path.join(run_dir, name) for name in fixed_names)
def main():
    """CLI entry point: roll out a saved policy once and render the results.

    Resolves the model / VecNormalize / config paths from ``--run-dir`` or the
    explicit flags, runs one episode via ``run_and_record``, then writes
    ``trajectory.png`` and ``timeseries.png`` (and ``episode.gif`` unless
    ``--no-gif`` is given) into the output directory. The output directory
    defaults to ``vis_<n_sheep>s`` next to the model file.
    """
    p = argparse.ArgumentParser(
        description="Render trajectory + timeseries + GIF for a saved policy.")
    p.add_argument("--run-dir", type=str, default=None,
                   help="Run directory containing final_model.zip + vecnorm.pkl + config.json")
    p.add_argument("--model", type=str, default=None)
    p.add_argument("--vecnorm", type=str, default=None)
    p.add_argument("--config", type=str, default=None)
    p.add_argument("--n-sheep", type=int, default=3)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--max-steps", type=int, default=2500)
    p.add_argument("--out-dir", type=str, default=None)
    p.add_argument("--no-gif", action="store_true",
                   help="Skip the animated GIF (PNG-only is faster).")
    p.add_argument("--gif-fps", type=int, default=20)
    p.add_argument("--gif-skip", type=int, default=3)
    args = p.parse_args()

    model_path, vn_path, cfg_path = _resolve_paths(args)
    if not (model_path and vn_path):
        p.error("either --run-dir or both --model and --vecnorm are required")

    # The config is optional; only keys that name an attribute of HerdingEnv
    # are forwarded as the reward configuration.
    rcfg = None
    if cfg_path and os.path.exists(cfg_path):
        with open(cfg_path, encoding="utf-8") as f:
            cfg = json.load(f)
        rcfg = {k: v for k, v in cfg.items() if hasattr(HerdingEnv, k)}

    out_dir = args.out_dir or os.path.join(
        os.path.dirname(os.path.abspath(model_path)),
        f"vis_{args.n_sheep}s")
    os.makedirs(out_dir, exist_ok=True)

    print(f"Loading model: {model_path}")
    print(f"Loading vecnorm: {vn_path}")
    model = PPO.load(model_path, device="cpu")
    # This VecNormalize is only a carrier for the saved running statistics;
    # run_and_record builds its own environment and copies the stats over.
    raw = DummyVecEnv([make_eval_env(args.n_sheep, args.seed, args.max_steps, rcfg)])
    vn = VecNormalize.load(vn_path, raw)
    try:
        print(f"Rolling out n_sheep={args.n_sheep} (seed={args.seed})...")
        hist = run_and_record(model, vn, args.n_sheep, args.max_steps,
                              reward_cfg=rcfg, seed=args.seed)
    finally:
        # Release the template environment even if the rollout fails.
        vn.close()

    result = "SUCCESS" if hist["success"] else f"FAIL ({hist['n_penned']}/{hist['n_sheep']})"
    print(f" {result} in {hist['steps']} steps")

    plot_trajectory(hist, os.path.join(out_dir, "trajectory.png"))
    plot_timeseries(hist, os.path.join(out_dir, "timeseries.png"))
    print(f" saved trajectory.png + timeseries.png to {out_dir}/")

    if not args.no_gif:
        gif_path = os.path.join(out_dir, "episode.gif")
        print(f" rendering GIF (fps={args.gif_fps}, skip={args.gif_skip})...")
        save_episode_gif(hist, gif_path, fps=args.gif_fps, skip=args.gif_skip)
        print(f" saved {gif_path}")
if __name__ == "__main__":
main()