413 lines
17 KiB
Python
413 lines
17 KiB
Python
"""
|
|
PPO training with attention-based policy (train_at.py).
|
|
|
|
Key difference from train.py
|
|
-----------------------------
|
|
- Observation exposes ALL sheep as individual per-sheep tokens rather than
|
|
only the top-3 farthest. The policy therefore has complete flock visibility
|
|
at any sheep count — no hidden sheep even at n=10.
|
|
- ShepherdAttentionExtractor (a transformer encoder) processes the sheep tokens with multi-head
|
|
self-attention (permutation-invariant), then mean-pools over valid tokens
|
|
and concatenates the result with global dog/pen features.
|
|
- Curriculum transition uses the same mixed-env approach as train.py: half
|
|
the envs stay at n-1 for the first half of each new stage to suppress
|
|
catastrophic forgetting.
|
|
|
|
Observation layout (7 + MAX_SHEEP*6 = 67 dims, fixed)
|
|
-------------------------------------------------------
|
|
Global (7):
|
|
dog_x / FIELD, dog_y / FIELD,
|
|
cos(heading), sin(heading),
|
|
(pen_x - dog_x) / D, (pen_y - dog_y) / D,
|
|
n_active / n_sheep
|
|
|
|
Per sheep i (6):
|
|
(sheep_x - dog_x) / D, (sheep_y - dog_y) / D, ← pos rel to dog
|
|
(pen_x - sheep_x) / D, (pen_y - sheep_y) / D, ← sheep-to-pen
|
|
is_active 1.0 if not penned, else 0.0
|
|
is_valid 1.0 if i < n_sheep, else 0.0 (padding sentinel)
|
|
|
|
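After VecNormalize, is_valid for real sheep normalises to > 0 and for
padding tokens to ≤ 0 (exactly 0 for a slot that has never held a sheep,
because its running mean stays at 0), so a threshold of 0 separates real
from padded without any extra bookkeeping. This relies on the running mean
of each slot's is_valid flag staying strictly below 1.0.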
|
|
|
|
Usage
|
|
-----
|
|
python train_at.py # defaults from config.json
|
|
python train_at.py --max-sheep 10 --steps-per-stage 2000000
|
|
python train_at.py --embed-dim 128 --n-heads 4 --n-layers 3
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import time
|
|
from copy import deepcopy
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from gymnasium import spaces
|
|
from stable_baselines3 import PPO
|
|
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
|
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize
|
|
|
|
from herding_env import HerdingEnv
|
|
from train import ProgressCallback, _classify, COMPACT_RADIUS, DEFAULT_CONFIG
|
|
from viz import (
|
|
run_and_record, plot_trajectory, plot_timeseries,
|
|
plot_success_rate, save_episode_gif,
|
|
)
|
|
|
|
|
|
# ── Per-sheep token observation environment ───────────────────────────────────
|
|
|
|
class HerdingEnvAt(HerdingEnv):
    """
    Variant of HerdingEnv that emits one observation token per sheep slot.

    Only the observation (``observation_space`` and ``_obs``) is overridden;
    dynamics, reward and the curriculum interface come from HerdingEnv.

    Flat layout: OBS_GLOBAL global features followed by MAX_SHEEP tokens of
    OBS_SHEEP features each.  Slots at index >= n_sheep are all-zero padding
    (their is_valid flag is therefore 0).
    """

    OBS_GLOBAL = 7
    OBS_SHEEP = 6

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Width is fixed by MAX_SHEEP, not n_sheep, so a single policy network
        # can be reused across every curriculum stage.
        flat_dim = self.OBS_GLOBAL + self.MAX_SHEEP * self.OBS_SHEEP
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(flat_dim,), dtype=np.float32
        )

    def _obs(self) -> np.ndarray:
        """Return the float32 vector [global features, flattened sheep tokens]."""
        field = self.FIELD
        diag = 2.0 * self.FIELD  # normaliser for relative offsets
        target = self.PEN_ENTRY if self.ENTRY_AWARE else self.PEN_CENTER
        dog_x, dog_y = self.dog_pos[0], self.dog_pos[1]
        pen_x, pen_y = target[0], target[1]
        count = self.n_sheep

        remaining = int((~self.penned[:count]).sum())
        heading = self.dog_heading

        global_part = np.array(
            [
                dog_x / field,
                dog_y / field,
                float(np.cos(heading)),
                float(np.sin(heading)),
                (pen_x - dog_x) / diag,
                (pen_y - dog_y) / diag,
                remaining / max(count, 1),
            ],
            dtype=np.float32,
        )

        # Slots beyond `count` are left at zero: is_valid == 0 marks padding.
        tokens = np.zeros((self.MAX_SHEEP, self.OBS_SHEEP), dtype=np.float32)
        for slot in range(count):
            sx, sy = self.sheep_pos[slot][0], self.sheep_pos[slot][1]
            tokens[slot] = [
                (sx - dog_x) / diag,        # sheep position relative to dog
                (sy - dog_y) / diag,
                (pen_x - sx) / diag,        # sheep-to-pen offset
                (pen_y - sy) / diag,
                float(not self.penned[slot]),  # is_active
                1.0,                        # is_valid
            ]

        return np.concatenate([global_part, tokens.ravel()])
|
|
|
|
|
|
# ── Attention features extractor ──────────────────────────────────────────────
|
|
|
|
class ShepherdAttentionExtractor(BaseFeaturesExtractor):
    """
    Transformer-encoder features extractor over per-sheep tokens.

    The flat observation is split into GLOBAL_DIM global features and
    MAX_SHEEP tokens of SHEEP_DIM features.  Tokens are linearly embedded,
    passed through ``n_layers`` self-attention encoder layers with padding
    slots masked out, mean-pooled over valid slots, and concatenated with the
    raw global features.  Output size: GLOBAL_DIM + embed_dim.

    Padding detection: the raw is_valid flag is 1.0 for real sheep and 0.0 for
    padding.  After VecNormalize, real sheep map to > 0 and padding to <= 0 as
    long as the slot's running mean of the flag stays below 1.0, so 0 is used
    as the threshold.  Rows in which no token passes the threshold are handled
    explicitly (see ``forward``) instead of producing NaN features.

    Args:
        observation_space: flat Box space produced by HerdingEnvAt.
        embed_dim: token embedding width (must be divisible by n_heads for
            nn.MultiheadAttention).
        n_heads: attention heads per encoder layer.
        n_layers: number of encoder layers.
        ff_dim: hidden width of each encoder layer's feed-forward block.
    """

    GLOBAL_DIM = HerdingEnvAt.OBS_GLOBAL  # 7
    SHEEP_DIM = HerdingEnvAt.OBS_SHEEP    # 6
    MAX_SHEEP = HerdingEnv.MAX_SHEEP      # token slots per observation
    VALID_IDX = 5                         # index of is_valid within each token

    def __init__(self, observation_space, embed_dim: int = 64,
                 n_heads: int = 4, n_layers: int = 2, ff_dim: int = 128):
        super().__init__(observation_space,
                         features_dim=self.GLOBAL_DIM + embed_dim)
        self.sheep_embed = nn.Linear(self.SHEEP_DIM, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=n_heads, dim_feedforward=ff_dim,
            dropout=0.0, batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer,
                                                 num_layers=n_layers,
                                                 enable_nested_tensor=False)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Map (B, GLOBAL_DIM + MAX_SHEEP*SHEEP_DIM) obs to (B, GLOBAL_DIM + E)."""
        B = obs.shape[0]
        global_feats = obs[:, :self.GLOBAL_DIM]                # (B, 7)
        tokens = obs[:, self.GLOBAL_DIM:].reshape(
            B, self.MAX_SHEEP, self.SHEEP_DIM)                 # (B, 10, 6)

        # is_valid after VecNorm: real > 0, padding <= 0.
        is_valid = tokens[:, :, self.VALID_IDX] > 0.0          # (B, 10)

        # key_padding_mask: True -> ignore this key.  If a row has no valid
        # token at all (e.g. the normalised flag of an always-present sheep
        # collapses to exactly 0), masking every key makes the attention
        # softmax run over all -inf and return NaN, and NaN * 0 stays NaN in
        # the pooling below.  Unmask such rows for attention only; they still
        # pool to a zero vector because their pooling weights are all zero.
        key_padding_mask = ~is_valid
        no_valid = key_padding_mask.all(dim=1, keepdim=True)   # (B, 1)
        key_padding_mask = key_padding_mask & ~no_valid

        x = self.sheep_embed(tokens)                           # (B, 10, E)
        x = self.transformer(x, src_key_padding_mask=key_padding_mask)

        valid_w = is_valid.float().unsqueeze(-1)               # (B, 10, 1)
        # clamp avoids 0/0 for rows with no valid tokens.
        pooled = (x * valid_w).sum(1) / valid_w.sum(1).clamp(min=1.0)

        return torch.cat([global_feats, pooled], dim=1)        # (B, 7+E)
|
|
|
|
|
|
# ── Environment factory ───────────────────────────────────────────────────────
|
|
|
|
def make_env_at(n_sheep, seed, max_steps, reward_cfg=None):
    """
    Return a zero-argument factory that builds and seeds one HerdingEnvAt.

    The vectorised-env wrappers take factories rather than instances; the env
    is only constructed (and reset with ``seed``) when the factory is called.
    """
    def _thunk():
        instance = HerdingEnvAt(
            n_sheep=n_sheep,
            max_steps=max_steps,
            reward_cfg=reward_cfg,
        )
        instance.reset(seed=seed)
        return instance

    return _thunk
|
|
|
|
|
|
# ── Evaluation ────────────────────────────────────────────────────────────────
|
|
|
|
def evaluate_at(model, vn_template, n_sheep, n_episodes, max_steps,
                reward_cfg=None):
    """
    Run deterministic evaluation episodes with the attention env.

    A fresh single-env DummyVecEnv is wrapped in a non-training VecNormalize
    whose running statistics are deep-copied from ``vn_template``, so
    observations are normalised as during training (rewards are not).

    Args:
        model: trained SB3 model; ``model.predict`` is called deterministically.
        vn_template: training VecNormalize whose obs_rms / ret_rms are copied.
        n_sheep: number of sheep in every evaluation episode.
        n_episodes: number of episodes to run.
        max_steps: per-episode step limit passed to the env.
        reward_cfg: optional reward configuration passed to the env.

    Returns:
        dict with "sr" (fraction of episodes whose final info reports all
        sheep penned), "mean_len", "mean_min_pen" (mean over episodes of the
        smallest sampled flock-COM distance to PEN_CENTER), "mean_act" (mean
        action L2 norm), "failure_modes" (counts of ``_classify`` labels) and,
        if the env reports "rcomps" in its info, "reward_per_step".
    """
    raw = DummyVecEnv([make_env_at(n_sheep, 9999, max_steps, reward_cfg)])
    vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
    vn.obs_rms = deepcopy(vn_template.obs_rms)
    vn.ret_rms = deepcopy(vn_template.ret_rms)

    successes = 0
    ep_lens, min_pen_list, action_mags = [], [], []
    failure_counts, rc_sums = {}, {}
    rc_n = 0

    try:
        inner = vn.envs[0]

        def _sample_flock():
            # (flock radius, distance of flock centre of mass to the pen centre)
            com, radius, _ = inner._flock_stats()
            return radius, float(np.linalg.norm(com - inner.PEN_CENTER))

        for _ in range(n_episodes):
            obs = vn.reset()
            done = False
            steps = 0
            mags = []
            # DummyVecEnv resets an env automatically on the step that reports
            # done, so after the terminal step `inner` already holds the NEXT
            # episode's initial state.  Sample flock stats from the initial
            # state and from every non-terminal state only, so the reset state
            # never leaks into min_pen or the inputs to _classify.
            radius, com_dist = _sample_flock()
            ep_radii, ep_com_dists = [radius], [com_dist]
            while not done:
                action, _ = model.predict(obs, deterministic=True)
                obs, _, dones, infos = vn.step(action)
                done = dones[0]
                steps += 1
                mags.append(float(np.linalg.norm(action[0])))
                if not done:
                    radius, com_dist = _sample_flock()
                    ep_radii.append(radius)
                    ep_com_dists.append(com_dist)
                rc = infos[0].get("rcomps")
                if rc:
                    for k, v in rc.items():
                        rc_sums[k] = rc_sums.get(k, 0.0) + v
                    rc_n += 1

            # infos of the terminal step are preserved by the VecEnv.
            n_penned = infos[0].get("n_penned", 0)
            successes += int(n_penned == n_sheep)
            ep_lens.append(steps)
            min_pen_list.append(min(ep_com_dists))
            action_mags.extend(mags)
            mode = _classify(ep_radii, ep_com_dists, n_penned, n_sheep)
            failure_counts[mode] = failure_counts.get(mode, 0) + 1
    finally:
        vn.close()

    result = {
        "sr": successes / n_episodes,
        "mean_len": float(np.mean(ep_lens)),
        "mean_min_pen": float(np.mean(min_pen_list)),
        "mean_act": float(np.mean(action_mags)) if action_mags else 0.0,
        "failure_modes": failure_counts,
    }
    if rc_n > 0:
        result["reward_per_step"] = {k: v / rc_n for k, v in rc_sums.items()}
    return result
|
|
|
|
|
|
# ── CLI ───────────────────────────────────────────────────────────────────────
|
|
|
|
def parse_args(argv=None):
    """
    Parse command-line options for attention-policy training.

    Args:
        argv: argument list to parse; ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace with curriculum, evaluation, GIF and transformer
        architecture options.

    Exits via ``ArgumentParser.error`` when ``--embed-dim`` is not a positive
    multiple of a positive ``--n-heads``; nn.MultiheadAttention requires this
    and would otherwise fail only while the model is being constructed.
    """
    p = argparse.ArgumentParser(
        description="PPO + attention training for herding task")
    p.add_argument("--config", type=str, default=None)
    p.add_argument("--max-sheep", type=int, default=10)
    p.add_argument("--steps-per-stage", type=int, default=1_500_000)
    p.add_argument("--n-envs", type=int, default=8)
    p.add_argument("--max-steps", type=int, default=2500)
    p.add_argument("--eval-episodes", type=int, default=30)
    p.add_argument("--run-dir", type=str, default=None)
    p.add_argument("--no-gif", action="store_true")
    p.add_argument("--gif-fps", type=int, default=20)
    p.add_argument("--gif-skip", type=int, default=3)
    # Attention architecture
    p.add_argument("--embed-dim", type=int, default=64,
                   help="Transformer embedding dimension (default 64)")
    p.add_argument("--n-heads", type=int, default=4,
                   help="Number of attention heads (default 4)")
    p.add_argument("--n-layers", type=int, default=2,
                   help="Number of transformer encoder layers (default 2)")
    p.add_argument("--ff-dim", type=int, default=128,
                   help="Transformer feed-forward dim (default 128)")
    args = p.parse_args(argv)

    # Check n_heads first so the modulo below can never divide by zero.
    if (args.n_heads < 1 or args.embed_dim < 1
            or args.embed_dim % args.n_heads != 0):
        p.error(f"--embed-dim ({args.embed_dim}) must be a positive multiple "
                f"of --n-heads ({args.n_heads})")
    return args
|
|
|
|
|
|
# ── Main ──────────────────────────────────────────────────────────────────────
|
|
|
|
def _load_config(config_path):
    """Return DEFAULT_CONFIG overlaid with a JSON file (./config.json if path is None and it exists)."""
    cfg = dict(DEFAULT_CONFIG)
    if config_path is None and os.path.exists("config.json"):
        config_path = "config.json"
    if config_path:
        with open(config_path) as f:
            cfg.update(json.load(f))
        print(f"Config loaded from {config_path}")
    return cfg


def _build_model(vn, cfg, args):
    """Create the PPO model whose policy uses ShepherdAttentionExtractor."""
    return PPO(
        "MlpPolicy", vn,
        learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10,
        gamma=0.995, gae_lambda=0.95, clip_range=0.2,
        ent_coef=cfg.get("ent_coef", 0.02), vf_coef=0.5, max_grad_norm=0.5,
        policy_kwargs=dict(
            features_extractor_class=ShepherdAttentionExtractor,
            features_extractor_kwargs=dict(
                embed_dim=args.embed_dim,
                n_heads=args.n_heads,
                n_layers=args.n_layers,
                ff_dim=args.ff_dim,
            ),
            net_arch=[256, 256],
        ),
        device="cpu",
        verbose=0,
    )


def _train_stage(model, vn, n, args):
    """
    Train curriculum stage ``n`` for ``args.steps_per_stage`` timesteps.

    Stage 1 trains on one sheep with a fresh timestep counter.  Later stages
    first spend half the budget with the first ``max(1, n_envs // 2)`` envs
    at n-1 sheep and the rest at n (to limit catastrophic forgetting), then
    switch every env to n sheep for the remaining steps.
    """
    if n == 1:
        print(f"\n[Stage n_sheep=1] training {args.steps_per_stage:,} steps")
        model.learn(
            total_timesteps=args.steps_per_stage,
            reset_num_timesteps=True,
            callback=ProgressCallback("1 sheep", freq=100_000),
        )
        return

    half = max(1, args.n_envs // 2)
    mix_steps = args.steps_per_stage // 2
    full_steps = args.steps_per_stage - mix_steps

    for i in range(half):
        vn.env_method("set_n_sheep", n - 1, indices=[i])
    for i in range(half, args.n_envs):
        vn.env_method("set_n_sheep", n, indices=[i])

    print(f"\n[Stage n_sheep={n}] mixed ({n-1}/{n} sheep) "
          f"{mix_steps:,} steps")
    model.learn(
        total_timesteps=mix_steps,
        reset_num_timesteps=False,
        callback=ProgressCallback(f"{n-1}→{n} mix", freq=100_000),
    )

    vn.env_method("set_n_sheep", n)
    print(f"[Stage n_sheep={n}] full ({n} sheep) {full_steps:,} steps")
    model.learn(
        total_timesteps=full_steps,
        reset_num_timesteps=False,
        callback=ProgressCallback(f"{n} sheep", freq=100_000),
    )


def _evaluate_stage(model, vn, n, args, rcfg, eval_dir):
    """
    Evaluate the policy at ``n`` sheep, print metrics and save artefacts.

    Runs ``evaluate_at`` for ``args.eval_episodes`` episodes and prints its
    metrics, failure-mode counts and per-step reward components.  Then it
    records one episode (seed 1000 + n) and writes its trajectory/timeseries
    plots (and a GIF unless ``--no-gif``) to ``eval_dir``.  Returns the
    ``evaluate_at`` result dict.
    """
    print(f"[Stage n_sheep={n}] evaluating {args.eval_episodes} eps")
    r = evaluate_at(model, vn, n, args.eval_episodes,
                    args.max_steps, rcfg)
    print(f"[Stage n_sheep={n}] sr={r['sr']*100:.0f}% "
          f"mean_len={r['mean_len']:.0f} "
          f"mean_min_pen={r['mean_min_pen']:.1f}m "
          f"mean_act={r['mean_act']:.2f}")
    if r["failure_modes"]:
        modes = " ".join(
            f"{k}={v}" for k, v in sorted(
                r["failure_modes"].items(), key=lambda x: -x[1]))
        print(f" failure modes: {modes}")
    if "reward_per_step" in r:
        rps = r["reward_per_step"]
        print(" reward/step: " + " ".join(
            f"{k}={v:+.4f}" for k, v in rps.items()))

    hist = run_and_record(
        model, vn, n, args.max_steps, rcfg,
        seed=1000 + n, make_env_fn=make_env_at,
    )
    tag = "success" if hist["success"] else "fail"
    plot_trajectory(hist, os.path.join(eval_dir, f"traj_{n}s_{tag}.png"))
    plot_timeseries(hist, os.path.join(eval_dir, f"ts_{n}s_{tag}.png"))
    if not args.no_gif:
        save_episode_gif(
            hist,
            os.path.join(eval_dir, f"ep_{n}s_{tag}.gif"),
            fps=args.gif_fps, skip=args.gif_skip)
    return r


def _print_summary(stage_results, elapsed, run_dir, eval_dir):
    """Print the per-stage summary table and save the success-rate plot."""
    print("\n" + "=" * 70)
    print(" TRAINING SUMMARY (attention policy)")
    print("=" * 70)
    for r in stage_results:
        print(f" n_sheep={r['n_sheep']} sr={r['sr']*100:>3.0f}% "
              f"len={r['mean_len']:>5.0f} "
              f"min_pen={r['mean_min_pen']:>5.1f}m "
              f"act={r['mean_act']:.2f}")
    print(f"\n Total time: {elapsed:.1f} min")
    print(f" Artefacts: {run_dir}/")
    plot_success_rate(stage_results, os.path.join(run_dir, "success_rate.png"))
    print(f" Plots: {run_dir}/success_rate.png, {eval_dir}/")


def main():
    """
    Train the attention PPO policy through a 1 → --max-sheep curriculum.

    Each stage trains, evaluates and saves diagnostic plots to
    <run_dir>/eval.  Afterwards the model, VecNormalize statistics and
    per-stage results are saved to run_dir.  Finally a summary is printed and
    a success-rate plot is written.
    """
    args = parse_args()

    # The observation has a fixed number of sheep slots (HerdingEnv.MAX_SHEEP);
    # reject impossible curricula before spawning workers or training.
    if not 1 <= args.max_sheep <= HerdingEnv.MAX_SHEEP:
        raise SystemExit(
            f"--max-sheep must be between 1 and {HerdingEnv.MAX_SHEEP}, "
            f"got {args.max_sheep}")

    cfg = _load_config(args.config)
    # Only keys that name HerdingEnv attributes are forwarded to the env.
    rcfg = {k: v for k, v in cfg.items() if hasattr(HerdingEnv, k)}

    run_dir = args.run_dir or os.path.join(
        "runs", "at_" + time.strftime("%Y%m%d_%H%M%S"))
    eval_dir = os.path.join(run_dir, "eval")
    os.makedirs(eval_dir, exist_ok=True)
    with open(os.path.join(run_dir, "config.json"), "w") as f:
        json.dump(cfg, f, indent=2)

    print(f"Config: {cfg}")
    print(f"Run dir: {run_dir}")
    print(f"Curriculum: 1 → {args.max_sheep} sheep, "
          f"{args.steps_per_stage:,} steps/stage")
    print(f"Transformer: embed={args.embed_dim} heads={args.n_heads} "
          f"layers={args.n_layers} ff={args.ff_dim}\n")

    train_env = SubprocVecEnv([
        make_env_at(1, seed=i, max_steps=args.max_steps, reward_cfg=rcfg)
        for i in range(args.n_envs)
    ])
    vn = VecNormalize(train_env, norm_obs=True, norm_reward=True, clip_obs=10.0)

    stage_results = []
    try:
        # Built inside `try` so the worker processes are closed even if model
        # construction fails (e.g. an invalid architecture).
        model = _build_model(vn, cfg, args)
        t0 = time.time()

        for n in range(1, args.max_sheep + 1):
            _train_stage(model, vn, n, args)
            r = _evaluate_stage(model, vn, n, args, rcfg, eval_dir)
            r["n_sheep"] = n
            stage_results.append(r)

        model.save(os.path.join(run_dir, "final_model"))
        vn.save(os.path.join(run_dir, "vecnorm.pkl"))
        with open(os.path.join(run_dir, "stage_results.json"), "w") as f:
            json.dump(stage_results, f, indent=2)
    finally:
        try:
            vn.close()
        except Exception:
            # Best-effort cleanup: never let a close failure mask the
            # original exception or block the summary.
            pass

    elapsed = (time.time() - t0) / 60
    _print_summary(stage_results, elapsed, run_dir, eval_dir)
|
|
|
|
|
|
if __name__ == "__main__":
    # Start training only when run as a script, not when this module is imported.
    main()
|