Checkpoint 2
@@ -0,0 +1,115 @@
# Shepherd Herding: Training & Inference

This directory holds the Gymnasium environment, PPO training script, and
evaluation harness for the RL shepherd-dog policy. The Webots controller
in `controllers/shepherd_dog/` loads the resulting policy at inference
time when launched with `HERDING_MODE=rl`.

## Layout

```
training/
├── herding_env.py          # gymnasium.Env; the dog is the agent
├── train_ppo.py            # SB3 PPO entry point (vec envs, eval, curriculum)
├── eval.py                 # rollout success-rate / time-to-pen across flock sizes
├── parity_test.py          # smoke test: shapes, determinism, baseline rollout
├── configs/ppo_default.yaml
├── runs/                   # tensorboard + checkpoints (gitignored)
└── requirements.txt
```

## Setup

```bash
python -m venv .venv && source .venv/bin/activate
pip install -r training/requirements.txt
```

CPU is the default and also the recommended device: SB3's PPO with an
MLP policy of this size runs faster on CPU than on GPU because the
bottleneck is rollout collection, not gradient compute. The 16 SubprocVecEnv
workers saturate ~16 CPU cores. To force CUDA anyway, pass `--device cuda`.

## Train

```bash
# Full curriculum (1 → 10 sheep); the default config runs 10M steps.
python -m training.train_ppo \
    --config training/configs/ppo_default.yaml \
    --out-dir training/runs/baseline
```

Outputs:
- `training/runs/baseline/best/best_model.zip`: best eval checkpoint
- `training/runs/baseline/best/vecnormalize.pkl`: observation stats
- `training/runs/baseline/checkpoints/ppo_*.zip`: periodic checkpoints
- `training/runs/baseline/tb/`: TensorBoard logs (`tensorboard --logdir`)

To resume:

```bash
python -m training.train_ppo --resume training/runs/baseline/checkpoints/ppo_500000_steps.zip
```

## Evaluate

```bash
# RL policy
python -m training.eval --policy training/runs/baseline/best

# Strömbom baseline
python -m training.eval --policy strombom
```

Prints success rate, mean steps, and mean penned-count per flock size.
Use the same `--n-seeds` for both to get a fair RL-vs-Strömbom A/B.

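A fair A/B pairs the two policies on identical seeds. The sketch below is one way to summarise such a paired run; it assumes you have collected a per-seed success flag for each policy yourself (`eval.py` prints aggregates only). `paired_success_delta` is a hypothetical helper, and the flag lists are illustrative placeholders, not measured results.

```python
from statistics import mean


def paired_success_delta(rl_success, baseline_success):
    """Compare two policies evaluated on the same seeds, in the same order."""
    if len(rl_success) != len(baseline_success):
        raise ValueError("both policies must be evaluated on the same seeds")
    pairs = list(zip(rl_success, baseline_success))
    return {
        "rl_rate": 100.0 * mean(rl_success),
        "baseline_rate": 100.0 * mean(baseline_success),
        # Seeds where exactly one of the two policies succeeded.
        "rl_only_wins": sum(1 for a, b in pairs if a and not b),
        "baseline_only_wins": sum(1 for a, b in pairs if b and not a),
    }


# Placeholder flags for 5 shared seeds (illustrative values only).
result = paired_success_delta([1, 1, 0, 1, 1], [1, 0, 0, 1, 0])
print(result)
```

Counting the seeds where only one policy succeeds shows whether a gap in success rate comes from a few hard seeds or from a consistent difference.
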
## Parity / smoke test

```bash
python -m training.parity_test
```

Checks observation/action shapes, deterministic seeding, the curriculum
sampler, and a 400-step Strömbom rollout. Run this before every long
training job; it catches the boring class of bugs in seconds.

## Run the policy in Webots

1. Train (above), which produces `training/runs/<name>/best/`.
2. In Webots, set the dog controller's environment variables:

   ```bash
   export HERDING_MODE=rl
   export HERDING_POLICY_DIR=$(pwd)/training/runs/baseline/best
   webots worlds/field.wbt
   ```

   Or set them via Webots' controller args / a `.wbproj` if you prefer.

3. To force the Strömbom baseline (same world, same controller):

   ```bash
   export HERDING_MODE=strombom
   webots worlds/field.wbt
   ```

If `HERDING_MODE=rl` but the policy can't be loaded (SB3 not installed,
zip missing, etc.), the controller logs the error and falls back to
Strömbom automatically.

## Curriculum knobs

The default schedule in `configs/ppo_default.yaml` widens
`max_n_sheep` over training. Each reset samples `n_sheep ~ U[1,
max_n_sheep]`, so the final policy has seen every flock size from 1 to
10, with smaller flocks sampled more often overall because they are
eligible from earlier in training. Each schedule entry also sets
`difficulty` (0 = sheep spawn just north of the gate, 1 = sheep spawn
anywhere in the field). To pin a specific size, instantiate the env with
`HerdingEnv(n_sheep=N)` (see `eval.py`).

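As a rough sketch of how such a schedule resolves at reset time: the `(step, max_n_sheep)` pairs below are copied from the `curriculum` list in `configs/ppo_default.yaml`, while `max_n_sheep_at` and `sample_n_sheep` are hypothetical helpers written for illustration, not functions from `train_ppo.py` or `herding_env.py`.

```python
import random

# (step, max_n_sheep) pairs from the curriculum list in ppo_default.yaml.
CURRICULUM = [
    (0, 1), (1_000_000, 1), (2_000_000, 2), (4_000_000, 3),
    (6_000_000, 5), (8_000_000, 8), (9_000_000, 10),
]


def max_n_sheep_at(step):
    """Largest flock size unlocked at a given env step."""
    current = CURRICULUM[0][1]
    for start, max_n in CURRICULUM:
        if step >= start:
            current = max_n
    return current


def sample_n_sheep(step, rng):
    """Draw n_sheep ~ U[1, max_n_sheep] for a reset at this step."""
    return rng.randint(1, max_n_sheep_at(step))


rng = random.Random(0)
for step in (0, 2_500_000, 9_500_000):
    print(step, max_n_sheep_at(step), sample_n_sheep(step, rng))
```

Because the lower bound stays at 1 while the upper bound grows, a single sheep remains a possible draw at every stage, whereas a flock of 10 only becomes possible after 9M steps.
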
## Reward shaping

Weights live in class attributes on `HerdingEnv`. Tune from the 1-sheep
curriculum first: if the dog can't herd a single sheep cleanly, raising
`W_PROGRESS` or lowering `W_TIME` is usually the fix. For multi-sheep
collapse modes (dog spins between sheep), increase `W_COMPACT` so
tightening the flock pays.

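Because the weights are class attributes, one way to try a variant without editing the env is to subclass it and override a single weight. The sketch below uses a stand-in class so it runs on its own; `HerdingEnvStub` and `CompactHerdingEnv` are hypothetical, and the numeric values are placeholders rather than the project's defaults.

```python
class HerdingEnvStub:
    """Stand-in for HerdingEnv; the weight values here are placeholders."""
    W_PROGRESS = 1.0
    W_TIME = 0.01
    W_COMPACT = 0.0


class CompactHerdingEnv(HerdingEnvStub):
    # Override only the weight under test; everything else is inherited.
    W_COMPACT = 0.1


print(CompactHerdingEnv.W_PROGRESS, CompactHerdingEnv.W_COMPACT)
```

Keeping each experiment in its own subclass leaves the base weights untouched, so runs stay comparable.
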
@@ -0,0 +1,218 @@
"""Behavior cloning of the sequential teacher into an SB3-compatible policy.

Trains the policy network (mean-action head) of an SB3 ``MlpPolicy`` to
mimic the demonstrations collected by ``tools.collect_demos``. The
saved zip is loadable via ``PPO.load(...)`` and can be passed to
``train_ppo.py --resume`` for fine-tuning.

Why this works: the teacher (sequential single-target driving) solves
n=10 at 80%+ in our env. BC gives the RL a competent starting policy,
so PPO doesn't have to discover behavior from scratch; it only has to
*refine* the teacher's strategy via the sparse pen reward.

Usage::

    python -m training.bc_pretrain \\
        --demos training/demos.npz \\
        --out training/runs/bc_pretrained
"""

from __future__ import annotations

import argparse
import os
import sys
import time
from pathlib import Path

_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

from training.herding_env import HerdingEnv


def build_model(net_arch_pi, net_arch_vf, log_std_init: float):
    """Build a fresh SB3 PPO with the same architecture as train_ppo.

    We only need the policy to load weights into; PPO's training-loop
    plumbing isn't used during BC.
    """
    env = DummyVecEnv([lambda: HerdingEnv()])
    model = PPO(
        "MlpPolicy", env,
        policy_kwargs=dict(
            net_arch=dict(pi=net_arch_pi, vf=net_arch_vf),
            log_std_init=log_std_init,
        ),
        verbose=0,
    )
    return model, env


def policy_forward_mean(policy, obs_batch):
    """Return the policy's deterministic mean action for a batch.

    SB3's ActorCriticPolicy doesn't expose this directly; it goes
    through a Distribution wrapper. We replicate the forward path:
    extract_features → mlp_extractor → action_net.
    """
    features = policy.extract_features(obs_batch)
    if isinstance(features, tuple):
        # SB3 ≥ 2.0 sometimes returns (pi_features, vf_features)
        pi_features = features[0]
    else:
        pi_features = features
    latent_pi, _latent_vf = policy.mlp_extractor(pi_features)
    return policy.action_net(latent_pi)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--demos", default="training/demos.npz")
    parser.add_argument("--out", default="training/runs/bc_pretrained")
    parser.add_argument("--epochs", type=int, default=60)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--val-split", type=float, default=0.1)
    parser.add_argument("--net-arch", default="256,256",
                        help="Comma-separated hidden layer widths.")
    parser.add_argument("--log-std-init", type=float, default=0.5)
    parser.add_argument("--cos-weight", type=float, default=1.0,
                        help="Weight on (1 - cosine similarity) loss term. "
                             "MSE alone shrinks policy output toward zero "
                             "(zero-magnitude action minimises mean squared "
                             "error against ±1 targets); cos loss keeps "
                             "the action pointed correctly even at small "
                             "magnitudes.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device", default="cpu")
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # --- Load demos ---
    print(f"[bc] loading demos from {args.demos}")
    data = np.load(args.demos)
    obs = data["obs"].astype(np.float32)
    actions = data["actions"].astype(np.float32)
    meta = data["meta"]
    print(f"[bc] obs={obs.shape} actions={actions.shape} trajectories={len(meta)}")
    if obs.size == 0:
        raise RuntimeError("Empty demo file.")

    # Action sanity check: the sequential teacher outputs unit vectors.
    a_norms = np.linalg.norm(actions, axis=1)
    print(f"[bc] action L2 norm: mean={a_norms.mean():.3f} "
          f"min={a_norms.min():.3f} max={a_norms.max():.3f}")

    # --- Train/val split ---
    n = len(obs)
    perm = np.random.permutation(n)
    n_val = int(n * args.val_split)
    val_idx, train_idx = perm[:n_val], perm[n_val:]
    print(f"[bc] train={len(train_idx)} val={len(val_idx)}")

    obs_t = torch.from_numpy(obs)
    act_t = torch.from_numpy(actions)
    train_loader = DataLoader(
        TensorDataset(obs_t[train_idx], act_t[train_idx]),
        batch_size=args.batch_size, shuffle=True,
    )
    val_loader = DataLoader(
        TensorDataset(obs_t[val_idx], act_t[val_idx]),
        batch_size=args.batch_size, shuffle=False,
    )

    # --- Build model ---
    net_arch_pi = [int(x) for x in args.net_arch.split(",")]
    net_arch_vf = net_arch_pi[:]
    model, _env = build_model(net_arch_pi, net_arch_vf, args.log_std_init)
    policy = model.policy.to(args.device)
    optimizer = optim.Adam(policy.parameters(), lr=args.lr)

    # --- Train ---
    print(f"[bc] training: epochs={args.epochs} batch={args.batch_size} "
          f"lr={args.lr} device={args.device}")
    t_start = time.time()
    best_val = float("inf")

    def combined_loss(pred, target):
        mse = nn.functional.mse_loss(pred, target)
        p_norm = pred.norm(dim=1).clamp_min(1e-6)
        t_norm = target.norm(dim=1).clamp_min(1e-6)
        cos_sim = (pred * target).sum(dim=1) / (p_norm * t_norm)
        cos_loss = (1.0 - cos_sim).mean()
        return mse + args.cos_weight * cos_loss, mse.item(), cos_sim.mean().item()

    for epoch in range(args.epochs):
        policy.train()
        train_loss_total, train_mse_total, train_cos_total, train_count = 0.0, 0.0, 0.0, 0
        for ob_batch, act_batch in train_loader:
            ob_batch = ob_batch.to(args.device)
            act_batch = act_batch.to(args.device)
            optimizer.zero_grad()
            mean_action = policy_forward_mean(policy, ob_batch)
            loss, mse_val, cos_val = combined_loss(mean_action, act_batch)
            loss.backward()
            optimizer.step()
            bs = ob_batch.size(0)
            train_loss_total += loss.item() * bs
            train_mse_total += mse_val * bs
            train_cos_total += cos_val * bs
            train_count += bs
        train_mse = train_mse_total / max(1, train_count)
        train_cos = train_cos_total / max(1, train_count)

        policy.eval()
        val_total, val_count = 0.0, 0
        cos_sim_total = 0.0
        with torch.no_grad():
            for ob_batch, act_batch in val_loader:
                ob_batch = ob_batch.to(args.device)
                act_batch = act_batch.to(args.device)
                mean_action = policy_forward_mean(policy, ob_batch)
                bs = ob_batch.size(0)
                val_total += nn.functional.mse_loss(
                    mean_action, act_batch, reduction="sum",
                ).item()
                # Cosine similarity in action space: a useful sanity check for
                # "is the policy pointing the same way as the teacher?".
                m_norm = mean_action.norm(dim=1).clamp_min(1e-6)
                a_norm = act_batch.norm(dim=1).clamp_min(1e-6)
                cos = (mean_action * act_batch).sum(dim=1) / (m_norm * a_norm)
                cos_sim_total += cos.sum().item()
                val_count += bs
        val_mse = val_total / max(1, val_count) / actions.shape[1]
        cos_sim = cos_sim_total / max(1, val_count)
        print(f"  epoch {epoch+1:>2d}/{args.epochs} "
              f"train_mse={train_mse:.4f} train_cos={train_cos:+.3f} "
              f"val_mse={val_mse:.4f} val_cos={cos_sim:+.3f}")
        if val_mse < best_val:
            best_val = val_mse

    elapsed = time.time() - t_start
    print(f"[bc] done in {elapsed:.0f}s best_val_mse={best_val:.4f}")

    # --- Save ---
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)
    model.save(out_dir / "policy.zip")
    print(f"[bc] saved policy to {out_dir / 'policy.zip'}")
    print(f"\n[bc] verify with: "
          f"python -m training.eval --policy {out_dir}")


if __name__ == "__main__":
    main()
@@ -1,14 +0,0 @@
{
  "W_PER_SHEEP": 2.0,
  "W_ALIGN": 0.05,
  "W_PEN_BONUS": 10.0,
  "W_COMPLETE": 100.0,
  "W_STEP_COST": 0.02,
  "W_COMPACT": 0.0,
  "W_WALL_TOUCH": 0.0,
  "WALL_TOUCH_BUFFER": 0.4,
  "ALIGN_SHAPE": "standoff",
  "ALIGN_GATED": true,
  "ENTRY_AWARE": true,
  "ent_coef": 0.02
}
@@ -0,0 +1,52 @@
# PPO hyperparameters for the herding env. Tuned for a 28-D obs / 2-D
# continuous action space with 16 parallel envs on GPU. These are SB3
# defaults nudged toward longer credit assignment (gamma=0.995) and a
# slightly higher entropy bonus to keep exploration alive while curriculum
# expands the flock size.

# --- PPO ---
learning_rate: 3.0e-4
n_steps: 2048            # rollout length per env before each update
batch_size: 256
n_epochs: 10
gamma: 0.995
gae_lambda: 0.95
clip_range: 0.2
ent_coef: 0.05           # was 0.01; earlier runs collapsed to ~0 actions
vf_coef: 0.5
max_grad_norm: 0.5
target_kl: null          # disable early-stop on KL

# --- Network ---
policy: MlpPolicy
net_arch_pi: [128, 128]
net_arch_vf: [128, 128]
log_std_init: 0.5        # std≈1.6 instead of default 1.0; more exploration

# --- Training schedule ---
total_timesteps: 10_000_000
n_envs: 16
checkpoint_freq: 500_000 # in env steps
eval_freq: 100_000       # in env steps
n_eval_episodes: 20

# --- Curriculum (max-n_sheep schedule, in env steps) ---
# Each entry: at step s, raise the env's max_n_sheep to k. The env samples
# uniformly from [1, max_n_sheep] each reset, so this widens the
# distribution gradually rather than swapping fixed sizes.
#
# State-space curriculum: difficulty controls sheep spawn area
# (0 = sheep spawn just north of gate, 1 = sheep spawn anywhere in field).
# Plus the existing flock-size curriculum.
#
# The two together let the policy first learn "what penning looks like"
# in a regime where random exploration reliably triggers it, then
# gradually generalise to the deployment distribution.
curriculum:
  - { step: 0,         max_n_sheep: 1,  difficulty: 0.0 }
  - { step: 1_000_000, max_n_sheep: 1,  difficulty: 0.3 }
  - { step: 2_000_000, max_n_sheep: 2,  difficulty: 0.5 }
  - { step: 4_000_000, max_n_sheep: 3,  difficulty: 0.8 }
  - { step: 6_000_000, max_n_sheep: 5,  difficulty: 1.0 }
  - { step: 8_000_000, max_n_sheep: 8,  difficulty: 1.0 }
  - { step: 9_000_000, max_n_sheep: 10, difficulty: 1.0 }
Binary file not shown.
@@ -0,0 +1,136 @@
"""Evaluate a trained PPO policy (or the Strömbom baseline) on the env.

Reports success rate and time-to-pen across a fixed seed grid for each
flock size 1..MAX_SHEEP. Used to produce the M5 quantitative comparison
table mentioned in plan.md.

Usage::

    python -m training.eval --policy training/runs/latest/best
    python -m training.eval --policy strombom
"""

from __future__ import annotations

import argparse
import os
import sys
from pathlib import Path
from statistics import mean, stdev

_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

import numpy as np

from herding.geometry import MAX_SHEEP, PEN_ENTRY
from herding.strombom import compute_action as strombom_action
from herding.sequential import compute_action as sequential_action
from training.herding_env import HerdingEnv


def rollout(env: HerdingEnv, predict_fn, max_steps: int) -> dict:
    obs, _ = env.reset()
    success = False
    for t in range(max_steps):
        action = predict_fn(env, obs)
        obs, _r, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            success = bool(info.get("is_success", False))
            return {"success": success, "steps": info.get("steps", t + 1),
                    "n_penned": info.get("n_penned", 0)}
    return {"success": False, "steps": max_steps, "n_penned": int(env.sheep_penned.sum())}


def make_analytic_predictor(action_fn):
    def _predict(env, _obs):
        positions = {f"s{i}": (float(env.sheep_x[i]), float(env.sheep_y[i]))
                     for i in range(env.n_sheep)
                     if not env.sheep_penned[i]}
        vx, vy, _mode = action_fn((env.dog_x, env.dog_y), positions, PEN_ENTRY)
        return np.array([vx, vy], dtype=np.float32)
    return _predict


# Backwards-compat alias.
def make_strombom_predictor():
    return make_analytic_predictor(strombom_action)


def make_policy_predictor(model, vecnorm):
    def _predict(_env, obs):
        if vecnorm is not None:
            obs_b = vecnorm.normalize_obs(np.asarray(obs, dtype=np.float32).reshape(1, -1))
        else:
            obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1)
        action, _ = model.predict(obs_b, deterministic=True)
        return action[0]
    return _predict


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--policy", required=True,
                        help="Either 'strombom' or path to an SB3 run directory.")
    parser.add_argument("--n-seeds", type=int, default=10)
    parser.add_argument("--max-steps", type=int, default=5000)
    parser.add_argument("--max-flock", type=int, default=MAX_SHEEP)
    # 1.0 = deployment distribution (sheep anywhere in field).
    # Lower values use the training-curriculum spawn band (sheep near gate).
    parser.add_argument("--difficulty", type=float, default=1.0)
    args = parser.parse_args()

    if args.policy == "strombom":
        predict = make_analytic_predictor(strombom_action)
    elif args.policy == "sequential":
        predict = make_analytic_predictor(sequential_action)
    else:
        from stable_baselines3 import PPO
        run = Path(args.policy)
        # Resolve to a zip: directory of checkpoints, or a direct zip path.
        if run.is_file():
            zip_path = run
        else:
            for name in ("best_model.zip", "policy.zip", "final.zip"):
                if (run / name).exists():
                    zip_path = run / name
                    break
            else:
                raise FileNotFoundError(
                    f"No checkpoint found in {run} (tried best_model.zip, "
                    f"policy.zip, final.zip)"
                )
        model = PPO.load(str(zip_path), device="auto")
        vecnorm = None
        # Look for the observation stats next to the zip first, then one
        # directory up. Anchoring on zip_path (not run) also covers the case
        # where --policy points directly at best/best_model.zip.
        vn_path = zip_path.parent / "vecnormalize.pkl"
        if not vn_path.exists():
            vn_path = zip_path.parent.parent / "vecnormalize.pkl"
        if vn_path.exists():
            import pickle
            with open(vn_path, "rb") as f:
                vecnorm = pickle.load(f)
            vecnorm.training = False
            vecnorm.norm_reward = False
        predict = make_policy_predictor(model, vecnorm)

    print(f"{'n_sheep':>8} {'success%':>10} {'mean_steps':>12} {'mean_penned':>12}")
    print("-" * 46)
    for n in range(1, args.max_flock + 1):
        successes, steps, penned = [], [], []
        for seed in range(args.n_seeds):
            env = HerdingEnv(n_sheep=n, max_steps=args.max_steps,
                             difficulty=args.difficulty, seed=seed)
            r = rollout(env, predict, args.max_steps)
            successes.append(int(r["success"]))
            steps.append(r["steps"])
            penned.append(r["n_penned"])
        sr = 100.0 * mean(successes)
        ms = mean(steps)
        mp = mean(penned)
        print(f"{n:>8d} {sr:>9.1f}% {ms:>12.0f} {mp:>12.2f}")


if __name__ == "__main__":
    main()
+355
-707
File diff suppressed because it is too large
+75
-297
@@ -1,318 +1,96 @@
|
||||
"""
|
||||
Parity test: verify 2D training env matches Webots controller implementations.
|
||||
"""Parity smoke-test for the herding env.
|
||||
|
||||
Tests:
|
||||
1. Observation building: HerdingEnv._obs() vs shepherd_dog_rl.build_obs()
|
||||
2. Dog drive: HerdingEnv._step_dog_substep() vs shepherd_dog_rl.drive() math
|
||||
3. Sheep drive: HerdingEnv._sheep_drive() vs sheep.py drive() math
|
||||
Verifies (a) all imports resolve, (b) the env's reset/step contract is
|
||||
correct, (c) deterministic seeds give deterministic trajectories, and
|
||||
(d) the Strömbom baseline can drive the env without crashing.
|
||||
|
||||
Run::
|
||||
|
||||
python -m training.parity_test
|
||||
"""
|
||||
|
||||
import sys
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import math
|
||||
import sys
|
||||
|
||||
_HERE = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Make imports work from project root
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__)))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "controllers", "shepherd_dog_rl"))
|
||||
|
||||
from herding_env import HerdingEnv
|
||||
|
||||
# Re-implement the Webots functions standalone (no Webots dependency)
|
||||
FIELD = 15.0
|
||||
PEN_CENTER = np.array([11.5, -11.5], dtype=np.float32)
|
||||
PEN_ENTRY = np.array([11.5, -8.0], dtype=np.float32)
|
||||
PEN_X = (10.0, 13.0)
|
||||
PEN_Y = (-15.0, -8.0)
|
||||
ENTRY_AWARE = True
|
||||
from herding.geometry import MAX_SHEEP, PEN_ENTRY
|
||||
from herding.obs import OBS_DIM
|
||||
from herding.strombom import compute_action
|
||||
from training.herding_env import HerdingEnv
|
||||
|
||||
|
||||
def webots_build_obs(dog_pos, sheep_positions, n_sheep, dog_heading):
|
||||
"""Standalone version of shepherd_dog_rl.py build_obs()."""
|
||||
D = 2 * FIELD
|
||||
active_pos = np.array(
|
||||
[p for p in sheep_positions
|
||||
if not (PEN_X[0] < p[0] < PEN_X[1] and PEN_Y[0] < p[1] < PEN_Y[1])],
|
||||
dtype=np.float32
|
||||
)
|
||||
n_active = len(active_pos)
|
||||
if n_active > 0:
|
||||
com = active_pos.mean(axis=0)
|
||||
d_from_com = np.linalg.norm(active_pos - com, axis=1)
|
||||
sorted_idx = np.argsort(d_from_com)[::-1]
|
||||
radius = float(d_from_com[sorted_idx[0]])
|
||||
def nth(n):
|
||||
return active_pos[sorted_idx[n]] if len(sorted_idx) > n else com
|
||||
far1, far2, far3 = nth(0), nth(1), nth(2)
|
||||
else:
|
||||
com = PEN_CENTER.copy()
|
||||
radius = 0.0
|
||||
far1 = far2 = far3 = PEN_CENTER.copy()
|
||||
frac_active = n_active / max(n_sheep, 1)
|
||||
pen_ref = PEN_ENTRY if ENTRY_AWARE else PEN_CENTER
|
||||
return np.array([
|
||||
dog_pos[0] / FIELD, dog_pos[1] / FIELD,
|
||||
(com[0] - dog_pos[0]) / D, (com[1] - dog_pos[1]) / D,
|
||||
(far1[0] - com[0]) / D, (far1[1] - com[1]) / D,
|
||||
(far2[0] - com[0]) / D, (far2[1] - com[1]) / D,
|
||||
(far3[0] - com[0]) / D, (far3[1] - com[1]) / D,
|
||||
(pen_ref[0] - com[0]) / D, (pen_ref[1] - com[1]) / D,
|
||||
(pen_ref[0] - far1[0]) / D, (pen_ref[1] - far1[1]) / D,
|
||||
radius / D,
|
||||
frac_active,
|
||||
math.cos(dog_heading), math.sin(dog_heading),
|
||||
], dtype=np.float32)
|
||||
def test_obs_action_shapes():
|
||||
env = HerdingEnv(n_sheep=3, seed=0)
|
||||
obs, info = env.reset()
|
||||
assert obs.shape == (OBS_DIM,), obs.shape
|
||||
assert obs.dtype == np.float32
|
||||
obs2, r, term, trunc, info = env.step(np.array([0.5, 0.0], dtype=np.float32))
|
||||
assert obs2.shape == (OBS_DIM,)
|
||||
assert isinstance(r, float)
|
||||
assert isinstance(term, bool) and isinstance(trunc, bool)
|
||||
print("[ok] shapes")
|
||||
|
||||
|
||||
def webots_dog_drive(heading, speed_ms, wheel_r=0.038, k_turn=4.0,
|
||||
motor_max=70.0, axle_track=0.28):
|
||||
"""Standalone version of shepherd_dog_rl.py drive() kinematics.
|
||||
def test_reset_determinism():
|
||||
"""Reset with the same seed should give the same initial observation.
|
||||
|
||||
Returns (v_linear, omega, left_w, right_w).
|
||||
We don't require step-determinism — PPO doesn't need it, and chasing
|
||||
bit-exactness through the flocking jitter isn't worth the complexity.
|
||||
"""
|
||||
err = math.atan2(math.sin(heading), math.cos(heading))
|
||||
fwd_ms = speed_ms * max(0.0, math.cos(err))
|
||||
fwd_rad = fwd_ms / wheel_r
|
||||
turn = k_turn * err
|
||||
l = max(-motor_max, min(motor_max, fwd_rad - turn))
|
||||
r = max(-motor_max, min(motor_max, fwd_rad + turn))
|
||||
v = wheel_r * 0.5 * (r + l)
|
||||
w = (wheel_r / axle_track) * (r - l)
|
||||
return v, w, l, r
|
||||
env_a = HerdingEnv(n_sheep=3, seed=42)
|
||||
env_b = HerdingEnv(n_sheep=3, seed=42)
|
||||
obs_a, _ = env_a.reset(seed=42)
|
||||
obs_b, _ = env_b.reset(seed=42)
|
||||
assert np.allclose(obs_a, obs_b), "Reset is non-deterministic for same seed"
|
||||
print("[ok] reset determinism")
|
||||
|
||||
|
||||
def webots_sheep_drive(heading, speed_rad, wheel_r=0.031, k_turn=4.0,
|
||||
motor_max=22.0, axle_track=0.20):
|
||||
"""Standalone version of sheep.py drive() kinematics."""
|
||||
err = math.atan2(math.sin(heading), math.cos(heading))
|
||||
fwd = speed_rad * max(0.0, math.cos(err))
|
||||
k = 4.0
|
||||
l = max(-motor_max, min(motor_max, fwd - k * err))
|
||||
r = max(-motor_max, min(motor_max, fwd + k * err))
|
||||
v = wheel_r * 0.5 * (r + l)
|
||||
w = (wheel_r / axle_track) * (r - l)
|
||||
return v, w, l, r
|
||||
def test_curriculum_n_sheep_varies():
|
||||
env = HerdingEnv(seed=0)
|
||||
sizes = set()
|
||||
for _ in range(40):
|
||||
_, info = env.reset()
|
||||
sizes.add(info["n_sheep"])
|
||||
assert 1 in sizes
|
||||
assert max(sizes) <= MAX_SHEEP
|
||||
print(f"[ok] curriculum sampling — saw n_sheep in {sorted(sizes)}")
|
||||
|
||||
|
||||
def test_obs_parity():
|
||||
"""Test that build_obs matches between 2D env and Webots controller."""
|
||||
print("=== Test 1: Observation Parity ===")
|
||||
env = HerdingEnv(n_sheep=3)
|
||||
# Set ENTRY_AWARE to match our webots constant
|
||||
env.ENTRY_AWARE = ENTRY_AWARE
|
||||
env.reset(seed=42)
|
||||
|
||||
# Manually set positions for a controlled test
|
||||
env.dog_pos = np.array([5.0, 3.0], dtype=np.float32)
|
||||
env.dog_heading = 1.2
|
||||
env.sheep_pos[0] = np.array([0.0, 0.0], dtype=np.float32)
|
||||
env.sheep_pos[1] = np.array([2.0, -1.0], dtype=np.float32)
|
||||
env.sheep_pos[2] = np.array([11.5, -11.5], dtype=np.float32) # penned
|
||||
env.penned[0] = False
|
||||
env.penned[1] = False
|
||||
env.penned[2] = True
|
||||
|
||||
obs_2d = env._obs()
|
||||
|
||||
# Build equivalent Webots observation
|
||||
sheep_positions = [
|
||||
env.sheep_pos[0].tolist(),
|
||||
env.sheep_pos[1].tolist(),
|
||||
env.sheep_pos[2].tolist(),
|
||||
]
|
||||
obs_webots = webots_build_obs(
|
||||
env.dog_pos, sheep_positions, 3, env.dog_heading
|
||||
)
|
||||
|
||||
max_diff = float(np.max(np.abs(obs_2d - obs_webots)))
|
||||
print(f" Max element-wise diff: {max_diff:.2e}")
|
||||
if max_diff < 1e-6:
|
||||
print(" PASS: Observations match")
|
||||
else:
|
||||
print(" FAIL: Observations differ!")
|
||||
for i in range(18):
|
||||
if abs(obs_2d[i] - obs_webots[i]) > 1e-6:
|
||||
print(f" dim {i}: 2d={obs_2d[i]:.6f} webots={obs_webots[i]:.6f}")
|
||||
return max_diff < 1e-6
|
||||
def test_strombom_drives_env():
|
||||
"""Quick functional check that the analytic baseline can play the env
|
||||
without exploding. Not a success-rate test — just no errors / NaNs."""
|
||||
env = HerdingEnv(n_sheep=2, max_steps=400, seed=1)
|
||||
obs, _ = env.reset()
|
||||
for t in range(400):
|
||||
positions = {f"s{i}": (float(env.sheep_x[i]), float(env.sheep_y[i]))
|
||||
for i in range(env.n_sheep)
|
||||
if not env.sheep_penned[i]}
|
||||
if not positions:
|
||||
break
|
||||
vx, vy, _mode = compute_action((env.dog_x, env.dog_y), positions, PEN_ENTRY)
|
||||
obs, r, term, trunc, info = env.step(np.array([vx, vy], dtype=np.float32))
|
||||
assert np.isfinite(obs).all(), f"NaN/Inf in obs at step {t}"
|
||||
assert np.isfinite(r), f"NaN reward at step {t}"
|
||||
if term or trunc:
|
||||
break
|
||||
print(f"[ok] strombom rollout — final n_penned={int(env.sheep_penned.sum())}/{env.n_sheep} after {env.steps} steps")
|
||||
|
||||
|
||||
def test_dog_drive_parity():
|
||||
"""Test that dog diff-drive matches Webots controller."""
|
||||
print("\n=== Test 2: Dog Drive Parity ===")
|
||||
env = HerdingEnv(n_sheep=1)
|
||||
env.reset(seed=42)
|
||||
|
||||
all_pass = True
|
||||
test_cases = [
|
||||
# (heading_error, speed_ms) — target_heading relative to current heading
|
||||
(0.0, 2.5), # aligned, full speed
|
||||
(0.5, 2.5), # 30deg error
|
||||
(1.5, 2.5), # ~86deg error
|
||||
(3.14, 2.5), # ~180deg error — should spin in place
|
||||
(0.0, 0.5), # aligned, slow
|
||||
(0.3, 1.0), # small error, medium speed
|
||||
]
|
||||
|
||||
for heading_err, speed_ms in test_cases:
|
||||
env.dog_heading = 0.0
|
||||
target_heading = heading_err
|
||||
action = np.array([
|
||||
math.cos(target_heading), math.sin(target_heading)
|
||||
], dtype=np.float32) * (speed_ms / env.DOG_SPEED)
|
||||
|
||||
# 2D env step
|
||||
dbg = env._step_dog_substep(action, 0.016)
|
||||
v_2d = dbg["v"]
|
||||
w_2d = dbg["w"]
|
||||
l_2d = dbg["left_w"]
|
||||
r_2d = dbg["right_w"]
|
||||
|
||||
# Webots equivalent
|
||||
v_w, w_w, l_w, r_w = webots_dog_drive(heading_err, speed_ms)
|
||||
|
||||
diffs = {
|
||||
"v": abs(v_2d - v_w),
|
||||
"w": abs(w_2d - w_w),
|
||||
"left": abs(l_2d - l_w),
|
||||
"right": abs(r_2d - r_w),
|
||||
}
|
||||
max_diff = max(diffs.values())
|
||||
ok = max_diff < 1e-6
|
||||
status = "PASS" if ok else "FAIL"
|
||||
print(f" err={heading_err:.2f} spd={speed_ms:.1f}: {status} (max_diff={max_diff:.2e})")
|
||||
if not ok:
|
||||
for k, d in diffs.items():
|
||||
if d > 1e-6:
|
||||
print(f" {k}: 2d={eval(k+'_2d'):.6f} webots={eval(k+'_w'):.6f}")
|
||||
all_pass = False
|
||||
|
||||
return all_pass
|
||||
|
||||
|
||||
def test_sheep_drive_parity():
|
||||
"""Test that sheep diff-drive matches Webots sheep controller."""
|
||||
print("\n=== Test 3: Sheep Drive Parity ===")
|
||||
env = HerdingEnv(n_sheep=1)
|
||||
env.reset(seed=42)
|
||||
|
||||
all_pass = True
|
||||
test_cases = [
|
||||
# (heading_error, speed_rad)
|
||||
(0.0, 20.0), # aligned, flee speed
|
||||
(0.0, 3.0), # aligned, wander speed
|
||||
(0.5, 15.0), # moderate error
|
||||
(1.57, 10.0), # 90deg — should spin in place
|
||||
(3.14, 20.0), # 180deg — should spin in place fast
|
||||
(0.2, 8.0), # small error, medium speed
|
||||
]
|
||||
|
||||
for heading_err, speed_rad in test_cases:
|
||||
env.sheep_heading[0] = 0.0
|
||||
env.sheep_pos[0] = np.array([0.0, 0.0], dtype=np.float32)
|
||||
target_heading = heading_err
|
||||
|
||||
# 2D env
|
||||
new_pos = env._sheep_drive(0, target_heading, speed_rad, 0.016)
|
||||
v_2d_raw = float(np.linalg.norm(new_pos - np.array([0.0, 0.0]))) / 0.016
|
||||
# Re-derive v, w from the internal state
|
||||
heading_2d = env.sheep_heading[0]
|
||||
|
||||
# Webots equivalent
|
||||
v_w, w_w, l_w, r_w = webots_sheep_drive(heading_err, speed_rad)
|
||||
|
||||
# For 2D, compute the same intermediate values
|
||||
err_2d = (target_heading - 0.0 + np.pi) % (2 * np.pi) - np.pi
|
||||
fwd_2d = speed_rad * max(0.0, math.cos(err_2d))
|
||||
turn_2d = 4.0 * err_2d
|
||||
l_2d = max(-22.0, min(22.0, fwd_2d - turn_2d))
|
||||
r_2d = max(-22.0, min(22.0, fwd_2d + turn_2d))
|
||||
|
||||
diffs = {
|
||||
"left": abs(l_2d - l_w),
|
||||
"right": abs(r_2d - r_w),
|
||||
}
|
||||
max_diff = max(diffs.values())
|
||||
ok = max_diff < 1e-6
|
||||
status = "PASS" if ok else "FAIL"
|
||||
print(f" err={heading_err:.2f} spd={speed_rad:.1f}: {status} (max_diff={max_diff:.2e})")
|
||||
if not ok:
|
||||
for k, d in diffs.items():
|
||||
if d > 1e-6:
|
||||
print(f" {k}: 2d={l_2d if k=='left' else r_2d:.6f} webots={l_w if k=='left' else r_w:.6f}")
|
||||
all_pass = False
|
||||
|
||||
return all_pass
|
||||
|
||||
|
||||
def test_full_trajectory_parity():
|
||||
"""Test that running identical actions produces matching trajectories."""
|
||||
print("\n=== Test 4: Full Trajectory Parity (dog only) ===")
|
||||
# Run 50 steps with a fixed action, compare dog heading/position
|
||||
# at each step between 2D env kinematics and pure Webots kinematics.
|
||||
env = HerdingEnv(n_sheep=1)
|
||||
env.reset(seed=42)
|
||||
env.dog_pos = np.array([0.0, 0.0], dtype=np.float32)
|
||||
env.dog_heading = 0.0
|
||||
env.ENTRY_AWARE = ENTRY_AWARE
|
||||
|
||||
action = np.array([0.8, -0.6], dtype=np.float32) # magnitude 1.0
|
||||
dt = 0.016667 # sub_dt
|
||||
|
||||
# Webots-side tracking
|
||||
wb_heading = 0.0
|
||||
wb_x, wb_y = 0.0, 0.0
|
||||
|
||||
max_heading_diff = 0.0
|
||||
max_pos_diff = 0.0
|
||||
|
||||
for step in range(50):
|
||||
# 2D env sub-step
|
||||
env._step_dog_substep(action, dt)
|
||||
|
||||
# Webots-side computation
|
||||
speed_ms = 1.0 * 2.5
|
||||
target_heading = math.atan2(-0.6, 0.8)
|
||||
err = math.atan2(math.sin(target_heading - wb_heading),
|
||||
math.cos(target_heading - wb_heading))
|
||||
fwd_ms = speed_ms * max(0.0, math.cos(err))
|
||||
fwd_rad = fwd_ms / 0.038
|
||||
turn = 4.0 * err
|
||||
l = max(-70.0, min(70.0, fwd_rad - turn))
|
||||
r = max(-70.0, min(70.0, fwd_rad + turn))
|
||||
v = 0.038 * 0.5 * (r + l)
|
||||
w = (0.038 / 0.28) * (r - l)
|
||||
wb_heading = math.atan2(math.sin(wb_heading + w * dt),
|
||||
math.cos(wb_heading + w * dt))
|
||||
wb_x += math.cos(wb_heading) * v * dt
|
||||
wb_y += math.sin(wb_heading) * v * dt
|
||||
|
||||
heading_diff = abs(math.atan2(math.sin(env.dog_heading - wb_heading), math.cos(env.dog_heading - wb_heading)))
|
||||
pos_diff = math.hypot(env.dog_pos[0] - wb_x, env.dog_pos[1] - wb_y)
|
||||
max_heading_diff = max(max_heading_diff, heading_diff)
|
||||
max_pos_diff = max(max_pos_diff, pos_diff)
|
||||
|
||||
print(f" Max heading diff over 50 steps: {max_heading_diff:.2e} rad")
|
||||
print(f" Max position diff over 50 steps: {max_pos_diff:.2e} m")
|
||||
ok = max_pos_diff < 1e-4
|
||||
print(f" {'PASS' if ok else 'FAIL'}: Trajectories match")
|
||||
return ok
|
||||
def main():
|
||||
test_obs_action_shapes()
|
||||
test_reset_determinism()
|
||||
test_curriculum_n_sheep_varies()
|
||||
test_strombom_drives_env()
|
||||
print("\nAll parity checks passed.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
results = []
|
||||
results.append(("Obs parity", test_obs_parity()))
|
||||
results.append(("Dog drive parity", test_dog_drive_parity()))
|
||||
results.append(("Sheep drive parity", test_sheep_drive_parity()))
|
||||
results.append(("Trajectory parity", test_full_trajectory_parity()))
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("RESULTS")
|
||||
print("=" * 50)
|
||||
all_pass = True
|
||||
for name, passed in results:
|
||||
print(f" {name}: {'PASS' if passed else 'FAIL'}")
|
||||
if not passed:
|
||||
all_pass = False
|
||||
print(f"\nOverall: {'ALL PASS' if all_pass else 'SOME FAILURES'}")
|
||||
env.close()
|
||||
main()
|
||||
|
||||
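The dog-side computation that Test 4 replays can be read as a small differential-drive heading controller. A minimal Python sketch follows, assuming the constants that appear inline in the test (wheel radius 0.038 m, axle length 0.28 m, turn gain 4.0, wheel-speed clamp 70 rad/s). The function name `dog_wheel_command` and the constant names are hypothetical and are not taken from the controller source.

```python
import math

# Constants copied from the inline numbers in test_full_trajectory_parity;
# the names themselves are illustrative, not the controller's own.
WHEEL_RADIUS = 0.038   # m
AXLE_LENGTH = 0.28     # m
TURN_GAIN = 4.0        # wheel rad/s per rad of heading error
MAX_WHEEL = 70.0       # rad/s clamp on each wheel


def dog_wheel_command(heading, target_heading, speed_ms):
    """Return (v, w, left, right) for one control step."""
    # Wrap the heading error into (-pi, pi].
    err = math.atan2(math.sin(target_heading - heading),
                     math.cos(target_heading - heading))
    # Forward speed falls to zero once the error reaches 90 degrees.
    fwd_rad = speed_ms * max(0.0, math.cos(err)) / WHEEL_RADIUS
    turn = TURN_GAIN * err
    left = max(-MAX_WHEEL, min(MAX_WHEEL, fwd_rad - turn))
    right = max(-MAX_WHEEL, min(MAX_WHEEL, fwd_rad + turn))
    # Forward kinematics, written the same way as in the test.
    v = WHEEL_RADIUS * 0.5 * (right + left)
    w = (WHEEL_RADIUS / AXLE_LENGTH) * (right - left)
    return v, w, left, right


if __name__ == "__main__":
    print(dog_wheel_command(0.0, 0.0, 2.5))          # aligned, full speed
    print(dog_wheel_command(0.0, math.pi / 2, 2.5))  # 90 degree error
```

With zero heading error both wheels receive the same command, so the angular rate is zero; at a 90 degree error the forward term vanishes and the robot turns almost in place.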
@@ -1,6 +1,8 @@
|
||||
gymnasium>=0.29
|
||||
stable-baselines3>=2.3
|
||||
torch>=2.2
|
||||
numpy>=1.26
|
||||
matplotlib>=3.8
|
||||
tensorboard>=2.16
|
||||
# Pin major versions; SB3 2.x requires gymnasium and torch >= 1.13.
|
||||
gymnasium>=0.29,<2.0
|
||||
stable-baselines3[extra]>=2.3,<3.0
|
||||
torch>=2.1
|
||||
numpy>=1.24
|
||||
pyyaml>=6.0
|
||||
tensorboard>=2.14
|
||||
tqdm>=4.66
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
|
||||
@@ -1,392 +0,0 @@
|
||||
"""
|
||||
PPO training for the herding task with curriculum learning.
|
||||
|
||||
Trains from scratch through a 1→max_sheep curriculum, evaluates after each
|
||||
stage, and auto-generates trajectory/timeseries plots plus a summary chart.
|
||||
|
||||
Usage
|
||||
-----
|
||||
python train.py # defaults from config.json
|
||||
python train.py --config my_config.json --max-sheep 5
|
||||
python train.py --max-sheep 3 --steps-per-stage 1000000
|
||||
|
||||
Outputs (in runs/<timestamp>/):
|
||||
config.json resolved config
|
||||
final_model.zip trained PPO model
|
||||
vecnorm.pkl VecNormalize statistics
|
||||
stage_results.json per-stage evaluation metrics
|
||||
success_rate.png summary bar chart
|
||||
eval/ trajectory & timeseries plots per sheep count
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.vec_env import (
|
||||
DummyVecEnv,
|
||||
SubprocVecEnv,
|
||||
VecNormalize,
|
||||
)
|
||||
|
||||
from herding_env import HerdingEnv
|
||||
from viz import (
|
||||
run_and_record,
|
||||
plot_trajectory,
|
||||
plot_timeseries,
|
||||
plot_success_rate,
|
||||
save_episode_gif,
|
||||
)
|
||||
|
||||
|
||||
# ── Callbacks ────────────────────────────────────────────────────────────────
|
||||
|
||||
class ProgressCallback(BaseCallback):
|
||||
"""One-line progress summary every `freq` env steps."""
|
||||
|
||||
def __init__(self, stage_label: str, freq: int = 100_000):
|
||||
super().__init__()
|
||||
self.stage_label = stage_label
|
||||
self.freq = freq
|
||||
self._last = 0
|
||||
self._ep_returns = []
|
||||
self._ep_success = []
|
||||
self._total_eps = 0
|
||||
self._total_success = 0
|
||||
self._cur_ret = None
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
rewards = self.locals.get("rewards")
|
||||
dones = self.locals.get("dones")
|
||||
infos = self.locals.get("infos", [])
|
||||
if rewards is None or dones is None:
|
||||
return True
|
||||
if self._cur_ret is None or len(self._cur_ret) != len(rewards):
|
||||
self._cur_ret = np.zeros(len(rewards), dtype=np.float64)
|
||||
self._cur_ret += np.asarray(rewards, dtype=np.float64)
|
||||
for i, d in enumerate(dones):
|
||||
if not d:
|
||||
continue
|
||||
self._ep_returns.append(float(self._cur_ret[i]))
|
||||
info = infos[i] if i < len(infos) else {}
|
||||
success = int(info.get("n_penned", 0) == info.get("n_sheep", -1))
|
||||
self._ep_success.append(success)
|
||||
self._total_eps += 1
|
||||
self._total_success += success
|
||||
self._cur_ret[i] = 0.0
|
||||
if len(self._ep_returns) > 50:
|
||||
self._ep_returns.pop(0)
|
||||
self._ep_success.pop(0)
|
||||
if self.num_timesteps - self._last >= self.freq:
|
||||
self._last = self.num_timesteps
|
||||
n = len(self._ep_returns)
|
||||
mean_r = float(np.mean(self._ep_returns)) if n else float("nan")
|
||||
win_sr = float(np.mean(self._ep_success)) if n else float("nan")
|
||||
cum_sr = (self._total_success / self._total_eps
|
||||
if self._total_eps else float("nan"))
|
||||
print(f" ... [{self.stage_label} | "
|
||||
f"{self.num_timesteps:>7,} steps | "
|
||||
f"ret(last {n})={mean_r:+.2f} "
|
||||
f"win_sr={win_sr*100:.0f}% cum_sr={cum_sr*100:.0f}%]",
|
||||
flush=True)
|
||||
return True
|
||||
|
||||
|
||||
# ── Environment factory ──────────────────────────────────────────────────────
|
||||
|
||||
def make_env(n_sheep, seed, max_steps, reward_cfg=None):
|
||||
def _init():
|
||||
env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps,
|
||||
reward_cfg=reward_cfg)
|
||||
env.reset(seed=seed)
|
||||
return env
|
||||
return _init
|
||||
|
||||
|
||||
# ── Failure-mode classification ──────────────────────────────────────────────
|
||||
|
||||
COMPACT_RADIUS = 5.0
|
||||
|
||||
|
||||
def _classify(ep_radii, ep_com_dists, n_penned, n_sheep):
|
||||
if n_penned == n_sheep:
|
||||
return "SUCCESS"
|
||||
if min(ep_radii) > COMPACT_RADIUS:
|
||||
return "NEVER_COMPACT"
|
||||
first = next(i for i, r in enumerate(ep_radii) if r <= COMPACT_RADIUS)
|
||||
if min(ep_com_dists[first:]) > 3.0:
|
||||
return "COMPACT_CANT_DRIVE"
|
||||
if n_penned == 0:
|
||||
return "DROVE_NO_SHEEP"
|
||||
return f"PARTIAL_{n_penned}of{n_sheep}"
|
||||
|
||||
|
||||
# ── Evaluation ───────────────────────────────────────────────────────────────
|
||||
|
||||
def evaluate(model, vn_template, n_sheep, n_episodes, max_steps,
|
||||
reward_cfg=None):
|
||||
"""Evaluate at a given sheep count; returns metrics dict."""
|
||||
raw = DummyVecEnv([make_env(n_sheep, 9999, max_steps, reward_cfg)])
|
||||
vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
|
||||
vn.obs_rms = deepcopy(vn_template.obs_rms)
|
||||
vn.ret_rms = deepcopy(vn_template.ret_rms)
|
||||
|
||||
successes = 0
|
||||
ep_lens = []
|
||||
min_pen_list = []
|
||||
action_mags = []
|
||||
failure_counts = {}
|
||||
rc_sums = {}
|
||||
rc_n = 0
|
||||
|
||||
for _ in range(n_episodes):
|
||||
obs = vn.reset()
|
||||
done = False
|
||||
steps = 0
|
||||
min_pen = float("inf")
|
||||
mags = []
|
||||
ep_radii = []
|
||||
ep_com_dists = []
|
||||
while not done:
|
||||
action, _ = model.predict(obs, deterministic=True)
|
||||
obs, _, dones, infos = vn.step(action)
|
||||
done = dones[0]
|
||||
inner = vn.envs[0]
|
||||
com, radius, _ = inner._flock_stats()
|
||||
min_pen = min(min_pen, float(np.linalg.norm(com - inner.PEN_CENTER)))
|
||||
mags.append(float(np.linalg.norm(action[0])))
|
||||
ep_radii.append(radius)
|
||||
ep_com_dists.append(float(np.linalg.norm(com - inner.PEN_CENTER)))
|
||||
steps += 1
|
||||
rc = infos[0].get("rcomps")
|
||||
if rc:
|
||||
for k, v in rc.items():
|
||||
rc_sums[k] = rc_sums.get(k, 0.0) + v
|
||||
rc_n += 1
|
||||
n_penned = infos[0].get("n_penned", 0)
|
||||
success = n_penned == n_sheep
|
||||
successes += int(success)
|
||||
ep_lens.append(steps)
|
||||
min_pen_list.append(min_pen)
|
||||
action_mags.extend(mags)
|
||||
mode = _classify(ep_radii, ep_com_dists, n_penned, n_sheep)
|
||||
failure_counts[mode] = failure_counts.get(mode, 0) + 1
|
||||
|
||||
vn.close()
|
||||
|
||||
result = {
|
||||
"sr": successes / n_episodes,
|
||||
"mean_len": float(np.mean(ep_lens)),
|
||||
"mean_min_pen": float(np.mean(min_pen_list)),
|
||||
"mean_act": float(np.mean(action_mags)) if action_mags else 0.0,
|
||||
"failure_modes": failure_counts,
|
||||
}
|
||||
if rc_n > 0:
|
||||
result["reward_per_step"] = {k: v / rc_n for k, v in rc_sums.items()}
|
||||
return result
|
||||
|
||||
|
||||
|
||||
# ── CLI ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"W_PER_SHEEP": 2.0,
|
||||
"W_ALIGN": 0.05,
|
||||
"W_PEN_BONUS": 10.0,
|
||||
"W_COMPLETE": 100.0,
|
||||
"W_STEP_COST": 0.02,
|
||||
"W_SOUTH": 0.01,
|
||||
"W_COMPACT": 0.0,
|
||||
"W_WALL_TOUCH": 0.04,
|
||||
"WALL_TOUCH_BUFFER": 0.3,
|
||||
"ALIGN_SHAPE": "standoff",
|
||||
"ALIGN_GATED": True,
|
||||
"ENTRY_AWARE": True,
|
||||
"ent_coef": 0.02,
|
||||
}
|
||||
|
||||
|
||||
def parse_args():
|
||||
p = argparse.ArgumentParser(
|
||||
description="PPO training for herding task with curriculum learning")
|
||||
p.add_argument("--config", type=str, default=None,
|
||||
help="JSON config file (reward weights + ent_coef)")
|
||||
p.add_argument("--max-sheep", type=int, default=10)
|
||||
p.add_argument("--steps-per-stage", type=int, default=1_500_000)
|
||||
p.add_argument("--n-envs", type=int, default=8)
|
||||
p.add_argument("--max-steps", type=int, default=2500)
|
||||
p.add_argument("--eval-episodes", type=int, default=30)
|
||||
p.add_argument("--run-dir", type=str, default=None)
|
||||
p.add_argument("--no-gif", action="store_true",
|
||||
help="Skip per-stage GIF rendering (PNGs still produced).")
|
||||
p.add_argument("--gif-fps", type=int, default=20)
|
||||
p.add_argument("--gif-skip", type=int, default=3,
|
||||
help="Keep every Nth frame (smaller GIF; default 3).")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
# ── Main ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Load config: --config overrides, else auto-load config.json if present
|
||||
cfg = dict(DEFAULT_CONFIG)
|
||||
config_path = args.config
|
||||
if config_path is None and os.path.exists("config.json"):
|
||||
config_path = "config.json"
|
||||
if config_path:
|
||||
with open(config_path) as f:
|
||||
cfg.update(json.load(f))
|
||||
print(f"Config loaded from {config_path}")
|
||||
|
||||
rcfg = {k: v for k, v in cfg.items() if hasattr(HerdingEnv, k)}
|
||||
|
||||
# Run directory
|
||||
run_dir = args.run_dir or os.path.join(
|
||||
"runs", time.strftime("%Y%m%d_%H%M%S"))
|
||||
eval_dir = os.path.join(run_dir, "eval")
|
||||
os.makedirs(eval_dir, exist_ok=True)
|
||||
with open(os.path.join(run_dir, "config.json"), "w") as f:
|
||||
json.dump(cfg, f, indent=2)
|
||||
|
||||
print(f"Config: {cfg}")
|
||||
print(f"Run dir: {run_dir}")
|
||||
print(f"Curriculum: 1 → {args.max_sheep} sheep, "
|
||||
f"{args.steps_per_stage:,} steps/stage\n")
|
||||
|
||||
# Training envs
|
||||
train_env = SubprocVecEnv([
|
||||
make_env(1, seed=i, max_steps=args.max_steps, reward_cfg=rcfg)
|
||||
for i in range(args.n_envs)
|
||||
])
|
||||
vn = VecNormalize(train_env, norm_obs=True, norm_reward=True,
|
||||
clip_obs=10.0)
|
||||
|
||||
# Model: force CPU (PPO with an MLP policy is typically faster on CPU than GPU; SB3 warns
|
||||
# about this otherwise).
|
||||
model = PPO(
|
||||
"MlpPolicy", vn,
|
||||
learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10,
|
||||
gamma=0.995, gae_lambda=0.95, clip_range=0.2,
|
||||
ent_coef=cfg.get("ent_coef", 0.02), vf_coef=0.5, max_grad_norm=0.5,
|
||||
policy_kwargs=dict(net_arch=[256, 256]),
|
||||
device="cpu",
|
||||
verbose=0,
|
||||
)
|
||||
|
||||
# Curriculum training
|
||||
stage_results = []
|
||||
t0 = time.time()
|
||||
|
||||
try:
|
||||
for n in range(1, args.max_sheep + 1):
|
||||
if n == 1:
|
||||
print(f"\n[Stage n_sheep=1] training {args.steps_per_stage:,} steps")
|
||||
model.learn(
|
||||
total_timesteps=args.steps_per_stage,
|
||||
reset_num_timesteps=True,
|
||||
callback=ProgressCallback("1 sheep", freq=100_000),
|
||||
)
|
||||
else:
|
||||
# Mixed transition: half envs stay at n-1, half advance to n,
|
||||
# for the first half of the stage budget. This prevents the
|
||||
# n-sheep task's noisy early gradients from destroying the (n-1)-sheep policy
|
||||
# (catastrophic forgetting) before it has a chance to adapt.
|
||||
half = max(1, args.n_envs // 2)
|
||||
for i in range(half):
|
||||
vn.env_method("set_n_sheep", n - 1, indices=[i])
|
||||
for i in range(half, args.n_envs):
|
||||
vn.env_method("set_n_sheep", n, indices=[i])
|
||||
mix_steps = args.steps_per_stage // 2
|
||||
full_steps = args.steps_per_stage - mix_steps
|
||||
print(f"\n[Stage n_sheep={n}] mixed ({n-1}/{n} sheep) "
|
||||
f"{mix_steps:,} steps")
|
||||
model.learn(
|
||||
total_timesteps=mix_steps,
|
||||
reset_num_timesteps=False,
|
||||
callback=ProgressCallback(f"{n-1}→{n} mix", freq=100_000),
|
||||
)
|
||||
vn.env_method("set_n_sheep", n)
|
||||
print(f"[Stage n_sheep={n}] full ({n} sheep) {full_steps:,} steps")
|
||||
model.learn(
|
||||
total_timesteps=full_steps,
|
||||
reset_num_timesteps=False,
|
||||
callback=ProgressCallback(f"{n} sheep", freq=100_000),
|
||||
)
|
||||
|
||||
# Evaluate
|
||||
print(f"[Stage n_sheep={n}] evaluating {args.eval_episodes} eps")
|
||||
r = evaluate(model, vn, n, args.eval_episodes, args.max_steps, rcfg)
|
||||
print(f"[Stage n_sheep={n}] sr={r['sr']*100:.0f}% "
|
||||
f"mean_len={r['mean_len']:.0f} "
|
||||
f"mean_min_pen={r['mean_min_pen']:.1f}m "
|
||||
f"mean_act={r['mean_act']:.2f}")
|
||||
|
||||
# Failure-mode breakdown
|
||||
if r["failure_modes"]:
|
||||
modes = " ".join(
|
||||
f"{k}={v}" for k, v in sorted(
|
||||
r["failure_modes"].items(), key=lambda x: -x[1]))
|
||||
print(f" failure modes: {modes}")
|
||||
|
||||
# Reward breakdown
|
||||
if "reward_per_step" in r:
|
||||
rps = r["reward_per_step"]
|
||||
print(" reward/step: " + " ".join(
|
||||
f"{k}={v:+.4f}" for k, v in rps.items()))
|
||||
|
||||
# Episode visualisation: trajectory + timeseries + animated GIF
|
||||
hist = run_and_record(model, vn, n, args.max_steps, rcfg,
|
||||
seed=1000 + n)
|
||||
tag = "success" if hist["success"] else "fail"
|
||||
plot_trajectory(
|
||||
hist,
|
||||
os.path.join(eval_dir, f"traj_{n}s_{tag}.png"))
|
||||
plot_timeseries(
|
||||
hist,
|
||||
os.path.join(eval_dir, f"ts_{n}s_{tag}.png"))
|
||||
if not args.no_gif:
|
||||
save_episode_gif(
|
||||
hist,
|
||||
os.path.join(eval_dir, f"ep_{n}s_{tag}.gif"),
|
||||
fps=args.gif_fps, skip=args.gif_skip)
|
||||
|
||||
r["n_sheep"] = n
|
||||
stage_results.append(r)
|
||||
|
||||
# Save artefacts
|
||||
model.save(os.path.join(run_dir, "final_model"))
|
||||
vn.save(os.path.join(run_dir, "vecnorm.pkl"))
|
||||
with open(os.path.join(run_dir, "stage_results.json"), "w") as f:
|
||||
json.dump(stage_results, f, indent=2)
|
||||
|
||||
finally:
|
||||
try:
|
||||
vn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Summary
|
||||
elapsed = (time.time() - t0) / 60
|
||||
print("\n" + "=" * 70)
|
||||
print(" TRAINING SUMMARY")
|
||||
print("=" * 70)
|
||||
for r in stage_results:
|
||||
print(f" n_sheep={r['n_sheep']} sr={r['sr']*100:>3.0f}% "
|
||||
f"len={r['mean_len']:>5.0f} min_pen={r['mean_min_pen']:>5.1f}m "
|
||||
f"act={r['mean_act']:.2f}")
|
||||
print(f"\n Total time: {elapsed:.1f} min")
|
||||
print(f" Artefacts: {run_dir}/")
|
||||
|
||||
plot_success_rate(stage_results, os.path.join(run_dir, "success_rate.png"))
|
||||
print(f" Plots: {run_dir}/success_rate.png, {eval_dir}/")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
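The curriculum loop in `main()` above splits every stage after the first into a mixed phase and a full phase. A minimal sketch of that bookkeeping follows, assuming the same integer arithmetic as the loop (`half = max(1, n_envs // 2)`, `mix_steps = steps_per_stage // 2`). The helper name `stage_plan` is hypothetical and does not exist in the training script.

```python
def stage_plan(n, n_envs, steps_per_stage):
    """Return a list of (per-env sheep counts, step budget) phases for stage n."""
    if n == 1:
        # First stage: every env trains on one sheep for the whole budget.
        return [([1] * n_envs, steps_per_stage)]
    half = max(1, n_envs // 2)
    mix_steps = steps_per_stage // 2
    full_steps = steps_per_stage - mix_steps
    # Mixed phase: the first `half` envs stay at n-1, the rest advance to n.
    mixed = [n - 1] * half + [n] * (n_envs - half)
    return [(mixed, mix_steps), ([n] * n_envs, full_steps)]


if __name__ == "__main__":
    for counts, steps in stage_plan(3, n_envs=8, steps_per_stage=1_500_000):
        print(counts, steps)
```

One consequence of `max(1, n_envs // 2)` is that with a single env the mixed phase contains only the `n - 1` task, so the new sheep count is first seen in the full phase.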
@@ -1,412 +0,0 @@
|
||||
"""
|
||||
PPO training with attention-based policy (train_at.py).
|
||||
|
||||
Key difference from train.py
|
||||
-----------------------------
|
||||
- Observation exposes ALL sheep as individual per-sheep tokens rather than
|
||||
only the top-3 farthest. The policy therefore has complete flock visibility
|
||||
at any sheep count — no hidden sheep even at n=10.
|
||||
- A ShepherdAttentionExtractor processes the sheep tokens with multi-head
|
||||
self-attention (permutation-equivariant), then mean-pools over valid tokens (making the result permutation-invariant)
|
||||
and concatenates the result with global dog/pen features.
|
||||
- Curriculum transition uses the same mixed-env approach as train.py: half
|
||||
the envs stay at n-1 for the first half of each new stage to suppress
|
||||
catastrophic forgetting.
|
||||
|
||||
Observation layout (7 + MAX_SHEEP*6 = 67 dims, fixed)
|
||||
-------------------------------------------------------
|
||||
Global (7):
|
||||
dog_x / FIELD, dog_y / FIELD,
|
||||
cos(heading), sin(heading),
|
||||
(pen_x - dog_x) / D, (pen_y - dog_y) / D,
|
||||
n_active / n_sheep
|
||||
|
||||
Per sheep i (6):
|
||||
(sheep_x - dog_x) / D, (sheep_y - dog_y) / D, ← pos rel to dog
|
||||
(pen_x - sheep_x) / D, (pen_y - sheep_y) / D, ← sheep-to-pen
|
||||
is_active 1.0 if not penned, else 0.0
|
||||
is_valid 1.0 if i < n_sheep, else 0.0 (padding sentinel)
|
||||
|
||||
After VecNormalize, is_valid for real sheep normalises > 0 and for
|
||||
padding tokens < 0 (only when that slot's running mean ∈ (0,1)), so a threshold of 0 then
|
||||
separates real from padded without any extra bookkeeping.
|
||||
|
||||
Usage
|
||||
-----
|
||||
python train_at.py # defaults from config.json
|
||||
python train_at.py --max-sheep 10 --steps-per-stage 2000000
|
||||
python train_at.py --embed-dim 128 --n-heads 4 --n-layers 3
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize
|
||||
|
||||
from herding_env import HerdingEnv
|
||||
from train import ProgressCallback, _classify, COMPACT_RADIUS, DEFAULT_CONFIG
|
||||
from viz import (
|
||||
run_and_record, plot_trajectory, plot_timeseries,
|
||||
plot_success_rate, save_episode_gif,
|
||||
)
|
||||
|
||||
|
||||
# ── Per-sheep token observation environment ───────────────────────────────────
|
||||
|
||||
class HerdingEnvAt(HerdingEnv):
|
||||
"""
|
||||
HerdingEnv with a per-sheep token observation for the attention policy.
|
||||
Everything else (dynamics, reward, curriculum interface) is inherited.
|
||||
"""
|
||||
|
||||
OBS_GLOBAL = 7
|
||||
OBS_SHEEP = 6
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
obs_dim = self.OBS_GLOBAL + self.MAX_SHEEP * self.OBS_SHEEP
|
||||
self.observation_space = spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(obs_dim,), dtype=np.float32
|
||||
)
|
||||
|
||||
def _obs(self) -> np.ndarray:
|
||||
S = self.FIELD
|
||||
D = 2.0 * self.FIELD
|
||||
pen_ref = self.PEN_ENTRY if self.ENTRY_AWARE else self.PEN_CENTER
|
||||
active_mask = ~self.penned[:self.n_sheep]
|
||||
n_active = int(active_mask.sum())
|
||||
|
||||
global_feats = np.array([
|
||||
self.dog_pos[0] / S,
|
||||
self.dog_pos[1] / S,
|
||||
float(np.cos(self.dog_heading)),
|
||||
float(np.sin(self.dog_heading)),
|
||||
(pen_ref[0] - self.dog_pos[0]) / D,
|
||||
(pen_ref[1] - self.dog_pos[1]) / D,
|
||||
n_active / max(self.n_sheep, 1),
|
||||
], dtype=np.float32)
|
||||
|
||||
sheep_feats = np.zeros((self.MAX_SHEEP, self.OBS_SHEEP), dtype=np.float32)
|
||||
for i in range(self.n_sheep):
|
||||
pos = self.sheep_pos[i]
|
||||
sheep_feats[i] = [
|
||||
(pos[0] - self.dog_pos[0]) / D,
|
||||
(pos[1] - self.dog_pos[1]) / D,
|
||||
(pen_ref[0] - pos[0]) / D,
|
||||
(pen_ref[1] - pos[1]) / D,
|
||||
float(not self.penned[i]),
|
||||
1.0, # is_valid: this sheep exists
|
||||
]
|
||||
# i >= n_sheep: all zeros, is_valid=0 → masked out in attention
|
||||
|
||||
return np.concatenate([global_feats, sheep_feats.ravel()])
|
||||
|
||||
|
||||
# ── Attention features extractor ──────────────────────────────────────────────
|
||||
|
||||
class ShepherdAttentionExtractor(BaseFeaturesExtractor):
|
||||
"""
|
||||
Multi-head self-attention over per-sheep tokens, mean-pooled over valid
|
||||
(non-padding) tokens and concatenated with global dog/pen features.
|
||||
|
||||
After VecNormalize:
|
||||
real sheep → is_valid_norm > 0 (normalised from 1.0)
|
||||
padding → is_valid_norm ≤ 0 (normalised from 0.0)
|
||||
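so a threshold at 0 works whenever a slot's running mean lies strictly in (0, 1); a slot whose is_valid never changes normalises to roughly 0, where the test is fragile.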
|
||||
"""
|
||||
|
||||
GLOBAL_DIM = HerdingEnvAt.OBS_GLOBAL # 7
|
||||
SHEEP_DIM = HerdingEnvAt.OBS_SHEEP # 6
|
||||
MAX_SHEEP = HerdingEnv.MAX_SHEEP # 10
|
||||
VALID_IDX = 5 # index of is_valid within each token
|
||||
|
||||
def __init__(self, observation_space, embed_dim: int = 64,
|
||||
n_heads: int = 4, n_layers: int = 2, ff_dim: int = 128):
|
||||
super().__init__(observation_space,
|
||||
features_dim=self.GLOBAL_DIM + embed_dim)
|
||||
self.sheep_embed = nn.Linear(self.SHEEP_DIM, embed_dim)
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=embed_dim, nhead=n_heads, dim_feedforward=ff_dim,
|
||||
dropout=0.0, batch_first=True,
|
||||
)
|
||||
self.transformer = nn.TransformerEncoder(encoder_layer,
|
||||
num_layers=n_layers,
|
||||
enable_nested_tensor=False)
|
||||
|
||||
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
||||
B = obs.shape[0]
|
||||
global_feats = obs[:, :self.GLOBAL_DIM] # (B, 7)
|
||||
tokens = obs[:, self.GLOBAL_DIM:].view(
|
||||
B, self.MAX_SHEEP, self.SHEEP_DIM) # (B, 10, 6)
|
||||
|
||||
# is_valid after VecNorm: real > 0, padding ≤ 0
|
||||
is_valid_norm = tokens[:, :, self.VALID_IDX] # (B, 10)
|
||||
key_padding_mask = is_valid_norm <= 0.0 # True → ignore
|
||||
|
||||
x = self.sheep_embed(tokens) # (B, 10, E)
|
||||
x = self.transformer(x, src_key_padding_mask=key_padding_mask)
|
||||
|
||||
valid_w = (is_valid_norm > 0.0).float().unsqueeze(-1) # (B, 10, 1)
|
||||
pooled = (x * valid_w).sum(1) / valid_w.sum(1).clamp(min=1.0)
|
||||
|
||||
return torch.cat([global_feats, pooled], dim=1) # (B, 7+E)
|
||||
|
||||
|
||||
# ── Environment factory ───────────────────────────────────────────────────────
|
||||
|
||||
def make_env_at(n_sheep, seed, max_steps, reward_cfg=None):
|
||||
def _init():
|
||||
env = HerdingEnvAt(n_sheep=n_sheep, max_steps=max_steps,
|
||||
reward_cfg=reward_cfg)
|
||||
env.reset(seed=seed)
|
||||
return env
|
||||
return _init
|
||||
|
||||
|
||||
# ── Evaluation ────────────────────────────────────────────────────────────────
|
||||
|
||||
def evaluate_at(model, vn_template, n_sheep, n_episodes, max_steps,
|
||||
reward_cfg=None):
|
||||
raw = DummyVecEnv([make_env_at(n_sheep, 9999, max_steps, reward_cfg)])
|
||||
vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
|
||||
vn.obs_rms = deepcopy(vn_template.obs_rms)
|
||||
vn.ret_rms = deepcopy(vn_template.ret_rms)
|
||||
|
||||
successes = 0
|
||||
ep_lens, min_pen_list, action_mags = [], [], []
|
||||
failure_counts, rc_sums = {}, {}
|
||||
rc_n = 0
|
||||
|
||||
for _ in range(n_episodes):
|
||||
obs = vn.reset()
|
||||
done = False
|
||||
steps, min_pen = 0, float("inf")
|
||||
mags, ep_radii, ep_com_dists = [], [], []
|
||||
while not done:
|
||||
action, _ = model.predict(obs, deterministic=True)
|
||||
obs, _, dones, infos = vn.step(action)
|
||||
done = dones[0]
|
||||
inner = vn.envs[0]
|
||||
com, radius, _ = inner._flock_stats()
|
||||
min_pen = min(min_pen,
|
||||
float(np.linalg.norm(com - inner.PEN_CENTER)))
|
||||
mags.append(float(np.linalg.norm(action[0])))
|
||||
ep_radii.append(radius)
|
||||
ep_com_dists.append(float(np.linalg.norm(com - inner.PEN_CENTER)))
|
||||
steps += 1
|
||||
rc = infos[0].get("rcomps")
|
||||
if rc:
|
||||
for k, v in rc.items():
|
||||
rc_sums[k] = rc_sums.get(k, 0.0) + v
|
||||
rc_n += 1
|
||||
n_penned = infos[0].get("n_penned", 0)
|
||||
successes += int(n_penned == n_sheep)
|
||||
ep_lens.append(steps)
|
||||
min_pen_list.append(min_pen)
|
||||
action_mags.extend(mags)
|
||||
mode = _classify(ep_radii, ep_com_dists, n_penned, n_sheep)
|
||||
failure_counts[mode] = failure_counts.get(mode, 0) + 1
|
||||
|
||||
vn.close()
|
||||
result = {
|
||||
"sr": successes / n_episodes,
|
||||
"mean_len": float(np.mean(ep_lens)),
|
||||
"mean_min_pen": float(np.mean(min_pen_list)),
|
||||
"mean_act": float(np.mean(action_mags)) if action_mags else 0.0,
|
||||
"failure_modes": failure_counts,
|
||||
}
|
||||
if rc_n > 0:
|
||||
result["reward_per_step"] = {k: v / rc_n for k, v in rc_sums.items()}
|
||||
return result
|
||||
|
||||
|
||||
# ── CLI ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
def parse_args():
|
||||
p = argparse.ArgumentParser(
|
||||
description="PPO + attention training for herding task")
|
||||
p.add_argument("--config", type=str, default=None)
|
||||
p.add_argument("--max-sheep", type=int, default=10)
|
||||
p.add_argument("--steps-per-stage", type=int, default=1_500_000)
|
||||
p.add_argument("--n-envs", type=int, default=8)
|
||||
p.add_argument("--max-steps", type=int, default=2500)
|
||||
p.add_argument("--eval-episodes", type=int, default=30)
|
||||
p.add_argument("--run-dir", type=str, default=None)
|
||||
p.add_argument("--no-gif", action="store_true")
|
||||
p.add_argument("--gif-fps", type=int, default=20)
|
||||
p.add_argument("--gif-skip", type=int, default=3)
|
||||
# Attention architecture
|
||||
p.add_argument("--embed-dim", type=int, default=64,
|
||||
help="Transformer embedding dimension (default 64)")
|
||||
p.add_argument("--n-heads", type=int, default=4,
|
||||
help="Number of attention heads (default 4)")
|
||||
p.add_argument("--n-layers", type=int, default=2,
|
||||
help="Number of transformer encoder layers (default 2)")
|
||||
p.add_argument("--ff-dim", type=int, default=128,
|
||||
help="Transformer feed-forward dim (default 128)")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
# ── Main ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
cfg = dict(DEFAULT_CONFIG)
|
||||
config_path = args.config
|
||||
if config_path is None and os.path.exists("config.json"):
|
||||
config_path = "config.json"
|
||||
if config_path:
|
||||
with open(config_path) as f:
|
||||
cfg.update(json.load(f))
|
||||
print(f"Config loaded from {config_path}")
|
||||
|
||||
rcfg = {k: v for k, v in cfg.items() if hasattr(HerdingEnv, k)}
|
||||
|
||||
run_dir = args.run_dir or os.path.join(
|
||||
"runs", "at_" + time.strftime("%Y%m%d_%H%M%S"))
|
||||
eval_dir = os.path.join(run_dir, "eval")
|
||||
os.makedirs(eval_dir, exist_ok=True)
|
||||
with open(os.path.join(run_dir, "config.json"), "w") as f:
|
||||
json.dump(cfg, f, indent=2)
|
||||
|
||||
print(f"Config: {cfg}")
|
||||
print(f"Run dir: {run_dir}")
|
||||
print(f"Curriculum: 1 → {args.max_sheep} sheep, "
|
||||
f"{args.steps_per_stage:,} steps/stage")
|
||||
print(f"Transformer: embed={args.embed_dim} heads={args.n_heads} "
|
||||
f"layers={args.n_layers} ff={args.ff_dim}\n")
|
||||
|
||||
train_env = SubprocVecEnv([
|
||||
make_env_at(1, seed=i, max_steps=args.max_steps, reward_cfg=rcfg)
|
||||
for i in range(args.n_envs)
|
||||
])
|
||||
vn = VecNormalize(train_env, norm_obs=True, norm_reward=True, clip_obs=10.0)
|
||||
|
||||
model = PPO(
|
||||
"MlpPolicy", vn,
|
||||
learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10,
|
||||
gamma=0.995, gae_lambda=0.95, clip_range=0.2,
|
||||
ent_coef=cfg.get("ent_coef", 0.02), vf_coef=0.5, max_grad_norm=0.5,
|
||||
policy_kwargs=dict(
|
||||
features_extractor_class=ShepherdAttentionExtractor,
|
||||
features_extractor_kwargs=dict(
|
||||
embed_dim=args.embed_dim,
|
||||
n_heads=args.n_heads,
|
||||
n_layers=args.n_layers,
|
||||
ff_dim=args.ff_dim,
|
||||
),
|
||||
net_arch=[256, 256],
|
||||
),
|
||||
device="cpu",
|
||||
verbose=0,
|
||||
)
|
||||
|
||||
    stage_results = []
    t0 = time.time()

    try:
        for n in range(1, args.max_sheep + 1):
            if n == 1:
                print(f"\n[Stage n_sheep=1] training {args.steps_per_stage:,} steps")
                model.learn(
                    total_timesteps=args.steps_per_stage,
                    reset_num_timesteps=True,
                    callback=ProgressCallback("1 sheep", freq=100_000),
                )
            else:
                half = max(1, args.n_envs // 2)
                mix_steps = args.steps_per_stage // 2
                full_steps = args.steps_per_stage - mix_steps

                for i in range(half):
                    vn.env_method("set_n_sheep", n - 1, indices=[i])
                for i in range(half, args.n_envs):
                    vn.env_method("set_n_sheep", n, indices=[i])

                print(f"\n[Stage n_sheep={n}] mixed ({n-1}/{n} sheep) "
                      f"{mix_steps:,} steps")
                model.learn(
                    total_timesteps=mix_steps,
                    reset_num_timesteps=False,
                    callback=ProgressCallback(f"{n-1}→{n} mix", freq=100_000),
                )

                vn.env_method("set_n_sheep", n)
                print(f"[Stage n_sheep={n}] full ({n} sheep) {full_steps:,} steps")
                model.learn(
                    total_timesteps=full_steps,
                    reset_num_timesteps=False,
                    callback=ProgressCallback(f"{n} sheep", freq=100_000),
                )

            print(f"[Stage n_sheep={n}] evaluating {args.eval_episodes} eps")
            r = evaluate_at(model, vn, n, args.eval_episodes,
                            args.max_steps, rcfg)
            print(f"[Stage n_sheep={n}] sr={r['sr']*100:.0f}% "
                  f"mean_len={r['mean_len']:.0f} "
                  f"mean_min_pen={r['mean_min_pen']:.1f}m "
                  f"mean_act={r['mean_act']:.2f}")
            if r["failure_modes"]:
                modes = " ".join(
                    f"{k}={v}" for k, v in sorted(
                        r["failure_modes"].items(), key=lambda x: -x[1]))
                print(f" failure modes: {modes}")
            if "reward_per_step" in r:
                rps = r["reward_per_step"]
                print(" reward/step: " + " ".join(
                    f"{k}={v:+.4f}" for k, v in rps.items()))

            hist = run_and_record(
                model, vn, n, args.max_steps, rcfg,
                seed=1000 + n, make_env_fn=make_env_at,
            )
            tag = "success" if hist["success"] else "fail"
            plot_trajectory(hist, os.path.join(eval_dir, f"traj_{n}s_{tag}.png"))
            plot_timeseries(hist, os.path.join(eval_dir, f"ts_{n}s_{tag}.png"))
            if not args.no_gif:
                save_episode_gif(
                    hist,
                    os.path.join(eval_dir, f"ep_{n}s_{tag}.gif"),
                    fps=args.gif_fps, skip=args.gif_skip)

            r["n_sheep"] = n
            stage_results.append(r)

        model.save(os.path.join(run_dir, "final_model"))
        vn.save(os.path.join(run_dir, "vecnorm.pkl"))
        with open(os.path.join(run_dir, "stage_results.json"), "w") as f:
            json.dump(stage_results, f, indent=2)

    finally:
        try:
            vn.close()
        except Exception:
            pass

    elapsed = (time.time() - t0) / 60
    print("\n" + "=" * 70)
    print(" TRAINING SUMMARY (attention policy)")
    print("=" * 70)
    for r in stage_results:
        print(f" n_sheep={r['n_sheep']} sr={r['sr']*100:>3.0f}% "
              f"len={r['mean_len']:>5.0f} "
              f"min_pen={r['mean_min_pen']:>5.1f}m "
              f"act={r['mean_act']:.2f}")
    print(f"\n Total time: {elapsed:.1f} min")
    print(f" Artefacts: {run_dir}/")
    plot_success_rate(stage_results, os.path.join(run_dir, "success_rate.png"))
    print(f" Plots: {run_dir}/success_rate.png, {eval_dir}/")


if __name__ == "__main__":
    main()
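The stage loop above trains stage 1 on a single sheep, then splits every later stage into a mixed phase (roughly half the workers on `n - 1` sheep, the rest on `n`) followed by a full phase on `n` sheep. The sketch below reproduces only that bookkeeping so the split can be checked without launching any environments. `plan_stage` is a hypothetical helper written for this illustration and is not part of the repository.

```python
def plan_stage(n, n_envs, steps_per_stage):
    """Return (per-worker sheep counts for the first phase, mix_steps, full_steps).

    Hypothetical helper: mirrors the arithmetic of the stage loop, nothing more.
    """
    if n == 1:
        # Stage 1 has no mixed phase: every worker trains on one sheep.
        return [1] * n_envs, 0, steps_per_stage
    half = max(1, n_envs // 2)
    mix_steps = steps_per_stage // 2
    full_steps = steps_per_stage - mix_steps  # absorbs the odd leftover step
    counts = [n - 1] * half + [n] * (n_envs - half)
    return counts, mix_steps, full_steps


counts, mix_steps, full_steps = plan_stage(n=3, n_envs=4, steps_per_stage=500_001)
print(counts, mix_steps, full_steps)
```

With a single worker the mixed phase runs entirely on `n - 1` sheep, because `half` is clamped to 1 and no worker is left for the `n`-sheep group.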
@@ -0,0 +1,267 @@
"""Train a PPO shepherd-dog policy on ``HerdingEnv`` with curriculum.

Defaults to 16 parallel ``SubprocVecEnv`` workers feeding a CPU policy
(``--device`` defaults to ``cpu``; pass ``--device cuda`` to override).
Saves checkpoints, the best-eval model, and the VecNormalize stats —
all three are needed at inference time by the Webots controller.

Usage::

    python -m training.train_ppo \
        --config training/configs/ppo_default.yaml \
        --out-dir training/runs/baseline

To resume from a checkpoint::

    python -m training.train_ppo --resume training/runs/baseline/checkpoints/ppo_500000_steps.zip
"""

from __future__ import annotations

import argparse
import os
import sys
from pathlib import Path

import yaml

_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

import numpy as np
import torch as th
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import (
    BaseCallback, CheckpointCallback, EvalCallback,
)
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import (
    DummyVecEnv, SubprocVecEnv, VecNormalize,
)

from training.herding_env import HerdingEnv


# --------------------------------------------------------------------------
# Env factories
# --------------------------------------------------------------------------

def _make_env(rank: int, seed: int = 0):
    def _thunk():
        env = HerdingEnv(seed=seed + rank)
        env = Monitor(env, info_keywords=("is_success", "n_sheep", "n_penned"))
        return env
    return _thunk


# --------------------------------------------------------------------------
# Curriculum callback
# --------------------------------------------------------------------------

class CurriculumCallback(BaseCallback):
    """Drive the env's flock-size + state-space difficulty curriculum.

    Schedule entries: {step, max_n_sheep, difficulty}. The largest entry
    whose step <= num_timesteps wins; both knobs update together.
    """

    def __init__(self, schedule, vec_envs, verbose: int = 0):
        super().__init__(verbose)
        self.schedule = sorted(schedule, key=lambda d: d["step"])
        # Accept a list of envs so the eval env tracks training difficulty.
        self.vec_envs = vec_envs if isinstance(vec_envs, (list, tuple)) else [vec_envs]
        self._last_n = None
        self._last_d = None

    def _call(self, method, value):
        for v in self.vec_envs:
            try:
                v.env_method(method, value)
            except AttributeError:
                v.venv.env_method(method, value)

    def _on_step(self) -> bool:
        t = self.num_timesteps
        n = self.schedule[0]["max_n_sheep"]
        d = self.schedule[0].get("difficulty", 1.0)
        for entry in self.schedule:
            if t >= entry["step"]:
                n = entry["max_n_sheep"]
                d = entry.get("difficulty", 1.0)
        if n != self._last_n:
            self._call("set_max_n_sheep", n)
            self._last_n = n
        if d != self._last_d:
            self._call("set_difficulty", d)
            self._last_d = d
            if self.verbose:
                print(f"[curriculum] t={t} → max_n_sheep={n} difficulty={d}")
        return True


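The lookup rule in `CurriculumCallback._on_step` (the last schedule entry whose `step` has been reached wins, with `difficulty` defaulting to 1.0) can be exercised on its own. The sketch below is a stand-alone restatement of that loop. `resolve_schedule` and the schedule values are made up for illustration and do not come from `ppo_default.yaml`.

```python
def resolve_schedule(schedule, t):
    """Return the (max_n_sheep, difficulty) pair active at timestep t.

    Hypothetical stand-alone version of the lookup done in _on_step.
    """
    entries = sorted(schedule, key=lambda d: d["step"])
    # The first entry is the fallback, even if its step is still in the future.
    n = entries[0]["max_n_sheep"]
    d = entries[0].get("difficulty", 1.0)
    for entry in entries:
        if t >= entry["step"]:
            n = entry["max_n_sheep"]
            d = entry.get("difficulty", 1.0)
    return n, d


# Illustrative schedule only; the real values live in the YAML config.
schedule = [
    {"step": 0, "max_n_sheep": 1, "difficulty": 0.0},
    {"step": 1_000_000, "max_n_sheep": 3, "difficulty": 0.5},
    {"step": 3_000_000, "max_n_sheep": 10},
]
for t in (0, 999_999, 1_000_000, 5_000_000):
    print(t, resolve_schedule(schedule, t))
```

An entry that omits `difficulty` resets the knob to 1.0 rather than inheriting the previous stage's value, which is worth keeping in mind when writing a schedule.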
# --------------------------------------------------------------------------
# Main
# --------------------------------------------------------------------------

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default=os.path.join(_HERE, "configs", "ppo_default.yaml"))
    parser.add_argument("--out-dir", default=os.path.join(_HERE, "runs", "latest"))
    parser.add_argument("--n-envs", type=int, default=None,
                        help="Override config n_envs.")
    parser.add_argument("--total-timesteps", type=int, default=None,
                        help="Override config total_timesteps.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--resume", type=str, default=None,
                        help="Path to a SB3 zip to resume from.")
    # SB3 recommends CPU for MlpPolicy — GPU helps CNN policies, not MLPs
    # of this size. Override with --device cuda if you really want it.
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--no-vecnorm", action="store_true",
                        help="Disable VecNormalize wrapper. Required when "
                             "resuming from a BC-pretrained policy that "
                             "wasn't trained under it.")
    parser.add_argument("--no-curriculum", action="store_true",
                        help="Skip curriculum callback (resumed policy is "
                             "already competent across the distribution).")
    parser.add_argument("--imitate-weight", type=float, default=None,
                        help="Override env W_IMITATE. Set to 0 to disable "
                             "Strömbom imitation reward.")
    parser.add_argument("--difficulty", type=float, default=None,
                        help="Override env difficulty (0=easy, 1=hard). "
                             "Used in BC fine-tune to skip easy curriculum.")
    parser.add_argument("--log-std", type=float, default=None,
                        help="Override the policy's log_std after load. "
                             "BC trained with std≈1.6 (log_std=0.5) which "
                             "is too noisy for fine-tune. Use -1.5 (std≈0.22) "
                             "to keep PPO close to the BC mean while still "
                             "exploring locally.")
    parser.add_argument("--learning-rate", type=float, default=None,
                        help="Override config learning rate. For BC "
                             "fine-tune, 5e-5 is much safer than the 3e-4 "
                             "default.")
    args = parser.parse_args()

    with open(args.config) as f:
        cfg = yaml.safe_load(f)

    n_envs = args.n_envs or cfg["n_envs"]
    total_timesteps = args.total_timesteps or cfg["total_timesteps"]

    out = Path(args.out_dir)
    out.mkdir(parents=True, exist_ok=True)
    (out / "checkpoints").mkdir(exist_ok=True)
    (out / "best").mkdir(exist_ok=True)
    (out / "evals").mkdir(exist_ok=True)

    print(f"[train] out={out} n_envs={n_envs} total={total_timesteps} device={args.device}")

    # --- Train env (vectorised, optionally normalised) ---
    env_fns = [_make_env(i, seed=args.seed) for i in range(n_envs)]
    venv = SubprocVecEnv(env_fns) if n_envs > 1 else DummyVecEnv(env_fns)
    eval_venv = DummyVecEnv([_make_env(99, seed=args.seed + 999)])
    if not args.no_vecnorm:
        venv = VecNormalize(venv, norm_obs=True, norm_reward=False, clip_obs=10.0)
        eval_venv = VecNormalize(eval_venv, norm_obs=True, norm_reward=False,
                                 clip_obs=10.0, training=False)
        eval_venv.obs_rms = venv.obs_rms
    else:
        print("[train] VecNormalize disabled (resumed policy was trained without it).")

    # Apply env-level overrides (used by BC fine-tune to disable Strömbom
    # imitation and start at full deployment difficulty).
    def _env_call(method, value):
        for v in (venv, eval_venv):
            try:
                v.env_method(method, value)
            except AttributeError:
                v.venv.env_method(method, value)

    if args.imitate_weight is not None:
        _env_call("set_imitate_weight", args.imitate_weight)
        print(f"[train] W_IMITATE overridden to {args.imitate_weight}")
    if args.difficulty is not None:
        _env_call("set_difficulty", args.difficulty)
        print(f"[train] difficulty pinned to {args.difficulty}")

    # --- Model ---
    policy_kwargs = dict(
        net_arch=dict(pi=cfg["net_arch_pi"], vf=cfg["net_arch_vf"]),
        log_std_init=cfg.get("log_std_init", 0.0),
    )

    if args.resume:
        print(f"[train] resuming from {args.resume}")
        custom_objects = {}
        if args.learning_rate is not None:
            custom_objects["learning_rate"] = args.learning_rate
        model = PPO.load(args.resume, env=venv, device=args.device,
                         tensorboard_log=str(out / "tb"),
                         custom_objects=custom_objects or None)
        if args.log_std is not None:
            # torch is already imported as `th` at module level.
            with th.no_grad():
                model.policy.log_std.fill_(args.log_std)
            print(f"[train] log_std overridden to {args.log_std} "
                  f"(std≈{np.exp(args.log_std):.2f})")
        if args.learning_rate is not None:
            print(f"[train] learning_rate overridden to {args.learning_rate}")
    else:
        model = PPO(
            cfg["policy"], venv,
            # Honour --learning-rate for fresh runs too, not only on resume.
            learning_rate=(args.learning_rate if args.learning_rate is not None
                           else cfg["learning_rate"]),
            n_steps=cfg["n_steps"],
            batch_size=cfg["batch_size"],
            n_epochs=cfg["n_epochs"],
            gamma=cfg["gamma"],
            gae_lambda=cfg["gae_lambda"],
            clip_range=cfg["clip_range"],
            ent_coef=cfg["ent_coef"],
            vf_coef=cfg["vf_coef"],
            max_grad_norm=cfg["max_grad_norm"],
            target_kl=cfg.get("target_kl"),
            policy_kwargs=policy_kwargs,
            tensorboard_log=str(out / "tb"),
            seed=args.seed,
            device=args.device,
            verbose=1,
        )

    # --- Callbacks ---
    ckpt_cb = CheckpointCallback(
        save_freq=max(1, cfg["checkpoint_freq"] // n_envs),
        save_path=str(out / "checkpoints"), name_prefix="ppo",
        save_vecnormalize=True,
    )
    eval_cb = EvalCallback(
        eval_venv,
        best_model_save_path=str(out / "best"),
        log_path=str(out / "evals"),
        eval_freq=max(1, cfg["eval_freq"] // n_envs),
        n_eval_episodes=cfg["n_eval_episodes"],
        deterministic=True,
    )
    callbacks = [ckpt_cb, eval_cb]
    if not args.no_curriculum and "curriculum" in cfg and cfg["curriculum"]:
        callbacks.append(CurriculumCallback(
            cfg["curriculum"], [venv, eval_venv], verbose=1,
        ))
    elif args.no_curriculum:
        print("[train] curriculum disabled — env knobs left at their current values.")

    # --- Train ---
    model.learn(total_timesteps=total_timesteps, callback=callbacks,
                progress_bar=True)

    # --- Save final model + VecNormalize stats ---
    model.save(out / "final.zip")
    # With --no-vecnorm, venv is a plain VecEnv that has no save() method,
    # so only write the stats when the VecNormalize wrapper is present.
    if isinstance(venv, VecNormalize):
        venv.save(str(out / "vecnormalize.pkl"))
        # The EvalCallback already wrote best_model.zip into out/best/ — drop the
        # VecNormalize stats next to it for the controller to pick up.
        venv.save(str(out / "best" / "vecnormalize.pkl"))
    print(f"[train] done. saved to {out}")


if __name__ == "__main__":
    main()
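The `--log-std` help text quotes standard deviations for two `log_std` values. For SB3's diagonal Gaussian policy the action standard deviation is `exp(log_std)`, so those figures can be checked with a few lines of arithmetic. This is a stdlib-only sketch, and `std_from_log_std` is a hypothetical name used only here.

```python
import math


def std_from_log_std(log_std):
    """Action standard deviation implied by a Gaussian policy's log_std."""
    return math.exp(log_std)


# 0.5 is the BC value quoted in the help text, 0.0 is the config fallback
# for log_std_init, and -1.5 is the suggested fine-tune override.
for log_std in (0.5, 0.0, -1.5):
    print(f"log_std={log_std:+.1f} -> std={std_from_log_std(log_std):.2f}")
```

The override therefore shrinks the exploration noise by roughly a factor of seven relative to the BC value, which matches the intent stated in the help text of keeping PPO close to the BC mean.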
@@ -1,342 +0,0 @@
"""
All visualization for the herding policy: trajectory plots, timeseries plots,
success-rate bar chart, and animated GIFs.

Used both by train.py (auto-rendered after each curriculum stage) and as a CLI
to render a fresh episode against a saved model.

CLI usage:
    python viz.py --run-dir runs/v1 --n-sheep 5
    python viz.py --run-dir runs/v1 --n-sheep 10 --no-gif
    python viz.py --model runs/v1/final_model.zip --vecnorm runs/v1/vecnorm.pkl \\
        --n-sheep 3 --out-dir vis_v1_3sheep
"""
import argparse
import os
import json
from copy import deepcopy

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.animation as animation
from matplotlib.collections import LineCollection
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

from herding_env import HerdingEnv


# ── Palette ──────────────────────────────────────────────────────────────────

SHEEP_COLORS = [
    "#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00",
    "#a65628", "#f781bf", "#999999", "#66c2a5", "#fc8d62",
]
DOG_COLOR = "#4e342e"


# ── Common drawing primitives ────────────────────────────────────────────────

def draw_field(ax):
    ax.set_xlim(-16, 16)
    ax.set_ylim(-16, 16)
    ax.set_aspect("equal")
    ax.set_facecolor("#dcedc8")
    ax.add_patch(mpatches.Rectangle(
        (-15, -15), 30, 30, fill=False, edgecolor="#795548", lw=2))
    ax.add_patch(mpatches.Rectangle(
        (10, -15), 3, 7, facecolor="#ffe082", edgecolor="#795548", lw=2))
    ax.text(11.5, -11.5, "pen", ha="center", va="center",
            fontsize=8, color="#795548")


def faded_path(ax, xs, ys, color, lw=1.5, label=None):
    n = len(xs)
    if n < 2:
        return
    points = np.array([xs, ys]).T.reshape(-1, 1, 2)
    segs = np.concatenate([points[:-1], points[1:]], axis=1)
    alphas = np.linspace(0.15, 1.0, len(segs))
    colors = [(*matplotlib.colors.to_rgb(color), a) for a in alphas]
    ax.add_collection(LineCollection(segs, colors=colors, linewidth=lw))
    if label:
        ax.plot([], [], color=color, lw=lw, label=label)


# ── Episode rollout ──────────────────────────────────────────────────────────

def make_eval_env(n_sheep, seed, max_steps, reward_cfg=None):
    def _init():
        env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps,
                         reward_cfg=reward_cfg)
        env.reset(seed=seed)
        return env
    return _init


def run_and_record(model, vn_template, n_sheep, max_steps,
                   reward_cfg=None, seed=42, make_env_fn=None):
    """Run one deterministic episode and return full trajectory history."""
    _factory = make_env_fn or make_eval_env
    raw = DummyVecEnv([_factory(n_sheep, seed, max_steps, reward_cfg)])
    vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
    vn.obs_rms = deepcopy(vn_template.obs_rms)
    vn.ret_rms = deepcopy(vn_template.ret_rms)

    obs = vn.reset()
    inner = vn.envs[0]
    done = False

    dog_xs, dog_ys = [], []
    sheep_xs = [[] for _ in range(n_sheep)]
    sheep_ys = [[] for _ in range(n_sheep)]
    sheep_penned = [[] for _ in range(n_sheep)]
    radii = []
    pen_dists = [[] for _ in range(n_sheep)]
    action_mags = []
    rewards = []
    penned_at = [None] * n_sheep
    step = 0

    while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, dones, infos = vn.step(action)
        done = dones[0]
        step += 1

        dog_xs.append(float(inner.dog_pos[0]))
        dog_ys.append(float(inner.dog_pos[1]))
        com, radius, _ = inner._flock_stats()
        radii.append(radius)
        rewards.append(float(reward[0]))
        action_mags.append(float(np.linalg.norm(action[0])))
        for i in range(n_sheep):
            sheep_xs[i].append(float(inner.sheep_pos[i][0]))
            sheep_ys[i].append(float(inner.sheep_pos[i][1]))
            sheep_penned[i].append(bool(inner.penned[i]))
            pen_dists[i].append(
                float(np.linalg.norm(inner.sheep_pos[i] - inner.PEN_CENTER)))
            if inner.penned[i] and penned_at[i] is None:
                penned_at[i] = step

    n_penned = infos[0].get("n_penned", 0)
    vn.close()

    return dict(
        dog_xs=dog_xs, dog_ys=dog_ys,
        sheep_xs=sheep_xs, sheep_ys=sheep_ys,
        sheep_penned=sheep_penned,
        radii=radii, pen_dists=pen_dists,
        action_mags=action_mags, rewards=rewards,
        penned_at=penned_at,
        n_penned=n_penned, n_sheep=n_sheep,
        success=n_penned == n_sheep, steps=step,
    )


# ── Static plots ─────────────────────────────────────────────────────────────

def plot_trajectory(hist, out_path):
    fig, ax = plt.subplots(figsize=(7, 7))
    draw_field(ax)
    for i in range(hist["n_sheep"]):
        c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
        xs, ys = hist["sheep_xs"][i], hist["sheep_ys"][i]
        faded_path(ax, xs, ys, c, lw=1.2, label=f"sheep {i+1}")
        ax.plot(xs[0], ys[0], "o", color=c, ms=7, zorder=4)
        # penned_at holds a 1-based step count; the matching sample in the
        # 0-based position lists is therefore at index penned_at - 1.
        end = hist["penned_at"][i] - 1 if hist["penned_at"][i] is not None else -1
        ax.plot(xs[end], ys[end], "*", color=c, ms=11, zorder=5)
    faded_path(ax, hist["dog_xs"], hist["dog_ys"], DOG_COLOR, lw=2.0,
               label="dog")
    ax.plot(hist["dog_xs"][0], hist["dog_ys"][0], "s", color=DOG_COLOR,
            ms=10, zorder=5)
    ax.plot(hist["dog_xs"][-1], hist["dog_ys"][-1], "D", color=DOG_COLOR,
            ms=10, zorder=5)
    result = ("SUCCESS" if hist["success"]
              else f"FAIL ({hist['n_penned']}/{hist['n_sheep']})")
    ax.set_title(f"n={hist['n_sheep']} {result} {hist['steps']} steps",
                 fontsize=12)
    ax.legend(loc="upper left", fontsize=8)
    plt.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)


def plot_timeseries(hist, out_path):
    t = np.arange(hist["steps"])
    fig, axes = plt.subplots(4, 1, figsize=(12, 10), sharex=True)

    axes[0].plot(t, hist["radii"], color="steelblue")
    axes[0].axhline(5.0, color="orange", ls="--", lw=1, label="compact (5m)")
    axes[0].set_ylabel("flock radius (m)")
    axes[0].legend(fontsize=8)
    axes[0].set_title("Flock radius")

    for i in range(hist["n_sheep"]):
        c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
        axes[1].plot(t, hist["pen_dists"][i], color=c, lw=1,
                     label=f"sheep {i+1}")
        if hist["penned_at"][i] is not None:
            axes[1].axvline(hist["penned_at"][i], color=c, ls=":", lw=1)
    axes[1].set_ylabel("dist to pen (m)")
    axes[1].legend(fontsize=7, ncol=min(hist["n_sheep"], 5))
    axes[1].set_title("Per-sheep distance to pen")

    axes[2].plot(t, hist["action_mags"], color="tomato", lw=1)
    axes[2].axhline(1.0, color="gray", ls="--", lw=1, label="max")
    axes[2].set_ylabel("action ||(vx,vy)||")
    axes[2].set_ylim(0, 1.5)
    axes[2].set_title("Dog action magnitude")
    axes[2].legend(fontsize=8)

    axes[3].plot(t, hist["rewards"], color="purple", lw=1, alpha=0.7)
    axes[3].axhline(0, color="black", lw=0.5)
    axes[3].set_ylabel("reward")
    axes[3].set_xlabel("step")
    axes[3].set_title("Reward per step")

    result = ("SUCCESS" if hist["success"]
              else f"FAIL ({hist['n_penned']}/{hist['n_sheep']})")
    fig.suptitle(f"n_sheep={hist['n_sheep']} {result} {hist['steps']} steps",
                 fontsize=13)
    plt.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)


def plot_success_rate(stage_results, out_path):
    fig, ax = plt.subplots(figsize=(8, 4))
    ns = [r["n_sheep"] for r in stage_results]
    srs = [r["sr"] * 100 for r in stage_results]
    bars = ax.bar(ns, srs, color="steelblue", edgecolor="white")
    ax.set_xlabel("Sheep count")
    ax.set_ylabel("Success rate (%)")
    ax.set_ylim(0, 105)
    ax.axhline(90, color="orange", ls="--", lw=1, label="90% target")
    for bar, sr in zip(bars, srs):
        ax.text(bar.get_x() + bar.get_width() / 2,
                bar.get_height() + 1, f"{sr:.0f}%",
                ha="center", fontsize=9)
    ax.legend()
    ax.set_title("Evaluation success rate per sheep count")
    plt.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)


# ── Animated GIF ─────────────────────────────────────────────────────────────

def save_episode_gif(hist, out_path, fps=20, skip=3):
    """Render hist as an animated GIF. `skip` keeps every Nth frame (smaller file)."""
    n_sheep = hist["n_sheep"]
    frames = list(range(0, hist["steps"], max(1, skip)))
    if frames[-1] != hist["steps"] - 1:
        frames.append(hist["steps"] - 1)

    fig, ax = plt.subplots(figsize=(6, 6))
    draw_field(ax)
    title = ax.text(0, 16.5, "", ha="center", fontsize=11)
    dog_marker, = ax.plot([], [], "s", color=DOG_COLOR, ms=12,
                          markeredgecolor="black", markeredgewidth=1.5,
                          zorder=5)
    sheep_markers = []
    for i in range(n_sheep):
        c = SHEEP_COLORS[i % len(SHEEP_COLORS)]
        m, = ax.plot([], [], "o", color=c, ms=10,
                     markeredgecolor="#333", markeredgewidth=1, zorder=4)
        sheep_markers.append(m)
    dog_trail, = ax.plot([], [], color=DOG_COLOR, lw=1.0, alpha=0.5)

    def update(k):
        title.set_text(
            f"n={n_sheep} step {k+1}/{hist['steps']} "
            f"penned {sum(hist['sheep_penned'][i][k] for i in range(n_sheep))}/{n_sheep}")
        dog_marker.set_data([hist["dog_xs"][k]], [hist["dog_ys"][k]])
        dog_trail.set_data(hist["dog_xs"][:k+1], hist["dog_ys"][:k+1])
        for i, m in enumerate(sheep_markers):
            m.set_data([hist["sheep_xs"][i][k]], [hist["sheep_ys"][i][k]])
            penned = hist["sheep_penned"][i][k]
            m.set_color("deeppink" if penned else SHEEP_COLORS[i % len(SHEEP_COLORS)])
        return [title, dog_marker, dog_trail, *sheep_markers]

    anim = animation.FuncAnimation(
        fig, update, frames=frames, interval=1000 / fps, blit=False)
    anim.save(out_path, writer=animation.PillowWriter(fps=fps), dpi=80)
    plt.close(fig)


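`save_episode_gif` thins the episode by keeping every `skip`-th step and then appending the final step so the GIF always ends on the last recorded state. A minimal sketch of that frame selection follows, with no matplotlib dependency. `select_frames` is a hypothetical name, and unlike the original it also guards against an empty episode, which the rollout loop above should never produce.

```python
def select_frames(n_steps, skip):
    """Indices of the steps to render: every `skip`-th step plus the last one."""
    frames = list(range(0, n_steps, max(1, skip)))  # skip <= 0 behaves like 1
    # Extra guard (not in the original): an empty episode yields no frames.
    if frames and frames[-1] != n_steps - 1:
        frames.append(n_steps - 1)
    return frames


print(select_frames(10, 3))
print(select_frames(11, 3))
```

Because the last frame is appended unconditionally, the final gap can be shorter than `skip`, so the closing frame of the animation may appear to advance more slowly than the rest.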
# ── CLI ──────────────────────────────────────────────────────────────────────

def _resolve_paths(args):
    if args.run_dir:
        model_path = os.path.join(args.run_dir, "final_model.zip")
        vn_path = os.path.join(args.run_dir, "vecnorm.pkl")
        cfg_path = os.path.join(args.run_dir, "config.json")
    else:
        model_path = args.model
        vn_path = args.vecnorm
        cfg_path = args.config
    return model_path, vn_path, cfg_path


def main():
    p = argparse.ArgumentParser(
        description="Render trajectory + timeseries + GIF for a saved policy.")
    p.add_argument("--run-dir", type=str, default=None,
                   help="Run directory containing final_model.zip + vecnorm.pkl + config.json")
    p.add_argument("--model", type=str, default=None)
    p.add_argument("--vecnorm", type=str, default=None)
    p.add_argument("--config", type=str, default=None)
    p.add_argument("--n-sheep", type=int, default=3)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--max-steps", type=int, default=2500)
    p.add_argument("--out-dir", type=str, default=None)
    p.add_argument("--no-gif", action="store_true",
                   help="Skip the animated GIF (PNG-only is faster).")
    p.add_argument("--gif-fps", type=int, default=20)
    p.add_argument("--gif-skip", type=int, default=3)
    args = p.parse_args()

    model_path, vn_path, cfg_path = _resolve_paths(args)
    if not (model_path and vn_path):
        p.error("either --run-dir or both --model and --vecnorm are required")

    rcfg = None
    if cfg_path and os.path.exists(cfg_path):
        with open(cfg_path) as f:
            cfg = json.load(f)
        rcfg = {k: v for k, v in cfg.items() if hasattr(HerdingEnv, k)}

    out_dir = args.out_dir or os.path.join(
        os.path.dirname(os.path.abspath(model_path)),
        f"vis_{args.n_sheep}s")
    os.makedirs(out_dir, exist_ok=True)

    print(f"Loading model: {model_path}")
    print(f"Loading vecnorm: {vn_path}")
    model = PPO.load(model_path, device="cpu")

    raw = DummyVecEnv([make_eval_env(args.n_sheep, args.seed, args.max_steps, rcfg)])
    vn = VecNormalize.load(vn_path, raw)

    print(f"Rolling out n_sheep={args.n_sheep} (seed={args.seed})...")
    hist = run_and_record(model, vn, args.n_sheep, args.max_steps,
                          reward_cfg=rcfg, seed=args.seed)
    result = "SUCCESS" if hist["success"] else f"FAIL ({hist['n_penned']}/{hist['n_sheep']})"
    print(f" {result} in {hist['steps']} steps")

    plot_trajectory(hist, os.path.join(out_dir, "trajectory.png"))
    plot_timeseries(hist, os.path.join(out_dir, "timeseries.png"))
    print(f" saved trajectory.png + timeseries.png to {out_dir}/")
    if not args.no_gif:
        gif_path = os.path.join(out_dir, "episode.gif")
        print(f" rendering GIF (fps={args.gif_fps}, skip={args.gif_skip})...")
        save_episode_gif(hist, gif_path, fps=args.gif_fps, skip=args.gif_skip)
        print(f" saved {gif_path}")


if __name__ == "__main__":
    main()