"""
PPO training with attention-based policy (train_at.py).

Key difference from train.py
----------------------------
- Observation exposes ALL sheep as individual per-sheep tokens rather than
  only the top-3 farthest. The policy therefore has complete flock
  visibility at any sheep count — no hidden sheep even at n=10.
- A ShepherdAttentionExtractor processes the sheep tokens with multi-head
  self-attention (permutation-invariant), then mean-pools over valid tokens
  and concatenates the result with global dog/pen features.
- Curriculum transition uses the same mixed-env approach as train.py: half
  the envs stay at n-1 for the first half of each new stage to suppress
  catastrophic forgetting.

Observation layout (7 + MAX_SHEEP*6 dims, fixed; 67 when MAX_SHEEP = 10)
------------------------------------------------------------------------
Global (7):
    dog_x / FIELD, dog_y / FIELD,
    cos(heading), sin(heading),
    (pen_x - dog_x) / D, (pen_y - dog_y) / D,
    n_active / n_sheep
Per sheep i (6):
    (sheep_x - dog_x) / D, (sheep_y - dog_y) / D,    ← pos rel to dog
    (pen_x - sheep_x) / D, (pen_y - sheep_y) / D,    ← sheep-to-pen
    is_active   1.0 if not penned, else 0.0
    is_valid    1.0 if i < n_sheep, else 0.0 (padding sentinel)

Here D = 2 * FIELD, and "pen" is PEN_ENTRY when ENTRY_AWARE is set,
otherwise PEN_CENTER. Padding tokens (i >= n_sheep) are all zeros.

After VecNormalize each is_valid column becomes (v - mean) / std. As long as
the running mean of that column lies in [0, 1) — it is exactly 0 for a slot
that has never held a sheep — real sheep (v = 1) normalise to > 0 and padding
tokens (v = 0) normalise to <= 0, so a threshold of 0 separates real from
padded tokens without any extra bookkeeping. (The mean staying strictly below
1 for always-valid slots presumably relies on VecNormalize's running
statistics starting from a mean of 0 — re-check if the normaliser changes.)

Usage
-----
    python train_at.py                       # defaults from config.json
    python train_at.py --max-sheep 10 --steps-per-stage 2000000
    python train_at.py --embed-dim 128 --n-heads 4 --n-layers 3
"""

import argparse
import json
import os
import time
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize

from herding_env import HerdingEnv
from train import ProgressCallback, _classify, COMPACT_RADIUS, DEFAULT_CONFIG
from viz import (
    run_and_record,
    plot_trajectory,
    plot_timeseries,
    plot_success_rate,
    save_episode_gif,
)


# ── Per-sheep token observation environment ───────────────────────────────────

class HerdingEnvAt(HerdingEnv):
    """
    HerdingEnv with a per-sheep token observation for the attention policy.

    Everything else (dynamics, reward, curriculum interface) is inherited.
    """

    OBS_GLOBAL = 7  # global dog/pen features
    OBS_SHEEP = 6   # features per sheep token

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Fixed-size observation: MAX_SHEEP token slots regardless of n_sheep,
        # so the space does not change across curriculum stages.
        obs_dim = self.OBS_GLOBAL + self.MAX_SHEEP * self.OBS_SHEEP
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(obs_dim,), dtype=np.float32
        )

    def _obs(self) -> np.ndarray:
        """Build the flat observation: global features, then MAX_SHEEP tokens.

        Tokens for slots ``i >= n_sheep`` stay all-zero (is_valid = 0).
        """
        S = self.FIELD
        D = 2.0 * self.FIELD
        pen_ref = self.PEN_ENTRY if self.ENTRY_AWARE else self.PEN_CENTER

        active_mask = ~self.penned[:self.n_sheep]
        n_active = int(active_mask.sum())

        global_feats = np.array([
            self.dog_pos[0] / S,
            self.dog_pos[1] / S,
            float(np.cos(self.dog_heading)),
            float(np.sin(self.dog_heading)),
            (pen_ref[0] - self.dog_pos[0]) / D,
            (pen_ref[1] - self.dog_pos[1]) / D,
            n_active / max(self.n_sheep, 1),  # guard against n_sheep == 0
        ], dtype=np.float32)

        sheep_feats = np.zeros((self.MAX_SHEEP, self.OBS_SHEEP), dtype=np.float32)
        for i in range(self.n_sheep):
            pos = self.sheep_pos[i]
            sheep_feats[i] = [
                (pos[0] - self.dog_pos[0]) / D,
                (pos[1] - self.dog_pos[1]) / D,
                (pen_ref[0] - pos[0]) / D,
                (pen_ref[1] - pos[1]) / D,
                float(not self.penned[i]),
                1.0,  # is_valid: this sheep exists
            ]
        # i >= n_sheep: all zeros, is_valid=0 → masked out in attention

        return np.concatenate([global_feats, sheep_feats.ravel()])


# ── Attention features extractor ──────────────────────────────────────────────

class ShepherdAttentionExtractor(BaseFeaturesExtractor):
    """
    Multi-head self-attention over per-sheep tokens, mean-pooled over valid
    (non-padding) tokens and concatenated with global dog/pen features.

    After VecNormalize:
        real sheep → is_valid_norm > 0   (normalised from 1.0)
        padding    → is_valid_norm ≤ 0   (normalised from 0.0)
    so thresholding at 0 separates them at every curriculum stage, provided
    the running mean of each is_valid column stays in [0, 1) (see the module
    docstring).

    Output shape: ``(B, GLOBAL_DIM + embed_dim)`` — the (normalised) global
    features followed by the pooled sheep embedding.
    """

    GLOBAL_DIM = HerdingEnvAt.OBS_GLOBAL  # 7
    SHEEP_DIM = HerdingEnvAt.OBS_SHEEP    # 6
    MAX_SHEEP = HerdingEnv.MAX_SHEEP      # default token-slot count
    VALID_IDX = 5                         # index of is_valid within each token

    def __init__(self, observation_space, embed_dim: int = 64, n_heads: int = 4,
                 n_layers: int = 2, ff_dim: int = 128):
        super().__init__(observation_space,
                         features_dim=self.GLOBAL_DIM + embed_dim)

        # Derive the number of token slots from the actual observation space
        # rather than trusting the class constant, and fail fast with a clear
        # message if the layout does not match (otherwise forward() would fail
        # later with an opaque reshape error).
        obs_dim = int(np.prod(observation_space.shape))
        token_feats = obs_dim - self.GLOBAL_DIM
        if token_feats <= 0 or token_feats % self.SHEEP_DIM != 0:
            raise ValueError(
                f"observation dim {obs_dim} is not {self.GLOBAL_DIM} global "
                f"features + a positive multiple of {self.SHEEP_DIM} per-sheep "
                f"features")
        self.n_tokens = token_feats // self.SHEEP_DIM

        self.sheep_embed = nn.Linear(self.SHEEP_DIM, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=n_heads,
            dim_feedforward=ff_dim,
            dropout=0.0,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        B = obs.shape[0]
        global_feats = obs[:, :self.GLOBAL_DIM]                    # (B, 7)
        tokens = obs[:, self.GLOBAL_DIM:].reshape(
            B, self.n_tokens, self.SHEEP_DIM)                      # (B, T, 6)

        # is_valid after VecNorm: real > 0, padding ≤ 0
        is_valid_norm = tokens[:, :, self.VALID_IDX]               # (B, T)
        valid = is_valid_norm > 0.0                                # (B, T)

        # True → ignore this key. A row with no valid token would have every
        # key masked, which makes the attention softmax return NaN; such rows
        # attend over all tokens instead. They still pool to zeros below
        # because pooling is weighted by `valid`.
        key_padding_mask = ~valid & valid.any(dim=1, keepdim=True)

        x = self.sheep_embed(tokens)                               # (B, T, E)
        x = self.transformer(x, src_key_padding_mask=key_padding_mask)

        valid_w = valid.float().unsqueeze(-1)                      # (B, T, 1)
        pooled = (x * valid_w).sum(1) / valid_w.sum(1).clamp(min=1.0)

        return torch.cat([global_feats, pooled], dim=1)            # (B, 7+E)


# ── Environment factory ───────────────────────────────────────────────────────

def make_env_at(n_sheep, seed, max_steps, reward_cfg=None):
    """Return a thunk that builds a HerdingEnvAt and resets it with ``seed``."""
    def _init():
        env = HerdingEnvAt(n_sheep=n_sheep, max_steps=max_steps,
                           reward_cfg=reward_cfg)
        env.reset(seed=seed)
        return env
    return _init


# ── Evaluation ────────────────────────────────────────────────────────────────

def evaluate_at(model, vn_template, n_sheep, n_episodes, max_steps,
                reward_cfg=None):
    """Run ``n_episodes`` deterministic episodes and summarise the results.

    Observations are normalised with a frozen copy of ``vn_template``'s
    statistics (no updates during evaluation).

    Returns a dict with:
        sr              fraction of episodes in which all sheep were penned
        mean_len        mean episode length (steps)
        mean_min_pen    mean over episodes of the minimum flock-COM distance to
                        PEN_CENTER
        mean_act        mean L2 norm of the actions taken
        failure_modes   counts of ``_classify`` labels per episode
        reward_per_step mean of each ``info["rcomps"]`` component (only present
                        when the env reports it)
    """
    raw = DummyVecEnv([make_env_at(n_sheep, 9999, max_steps, reward_cfg)])
    vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
    vn.obs_rms = deepcopy(vn_template.obs_rms)
    vn.ret_rms = deepcopy(vn_template.ret_rms)
    inner = raw.envs[0]

    successes = 0
    ep_lens, min_pen_list, action_mags = [], [], []
    failure_counts, rc_sums = {}, {}
    rc_n = 0

    try:
        for _ in range(n_episodes):
            obs = vn.reset()
            done = False
            steps, min_pen = 0, float("inf")
            mags, ep_radii, ep_com_dists = [], [], []
            n_penned = 0

            while not done:
                # Sample flock stats BEFORE stepping: on the terminal step
                # DummyVecEnv auto-resets the env inside step(), so reading
                # them afterwards would record the NEXT episode's initial state.
                com, radius, _ = inner._flock_stats()
                com_dist = float(np.linalg.norm(com - inner.PEN_CENTER))
                min_pen = min(min_pen, com_dist)
                ep_radii.append(radius)
                ep_com_dists.append(com_dist)

                action, _ = model.predict(obs, deterministic=True)
                obs, _, dones, infos = vn.step(action)
                done = dones[0]
                mags.append(float(np.linalg.norm(action[0])))
                steps += 1

                rc = infos[0].get("rcomps")
                if rc:
                    for k, v in rc.items():
                        rc_sums[k] = rc_sums.get(k, 0.0) + v
                    rc_n += 1

                # The terminal step's info is preserved by DummyVecEnv, so
                # this holds the final count once the loop exits.
                n_penned = infos[0].get("n_penned", 0)

            successes += int(n_penned == n_sheep)
            ep_lens.append(steps)
            min_pen_list.append(min_pen)
            action_mags.extend(mags)
            mode = _classify(ep_radii, ep_com_dists, n_penned, n_sheep)
            failure_counts[mode] = failure_counts.get(mode, 0) + 1
    finally:
        vn.close()

    result = {
        "sr": successes / n_episodes,
        "mean_len": float(np.mean(ep_lens)),
        "mean_min_pen": float(np.mean(min_pen_list)),
        "mean_act": float(np.mean(action_mags)) if action_mags else 0.0,
        "failure_modes": failure_counts,
    }
    if rc_n > 0:
        result["reward_per_step"] = {k: v / rc_n for k, v in rc_sums.items()}
    return result


# ── CLI ───────────────────────────────────────────────────────────────────────

def parse_args():
    """Parse and validate command-line arguments."""
    p = argparse.ArgumentParser(
        description="PPO + attention training for herding task")
    p.add_argument("--config", type=str, default=None)
    p.add_argument("--max-sheep", type=int, default=10)
    p.add_argument("--steps-per-stage", type=int, default=1_500_000)
    p.add_argument("--n-envs", type=int, default=8)
    p.add_argument("--max-steps", type=int, default=2500)
    p.add_argument("--eval-episodes", type=int, default=30)
    p.add_argument("--run-dir", type=str, default=None)
    p.add_argument("--no-gif", action="store_true")
    p.add_argument("--gif-fps", type=int, default=20)
    p.add_argument("--gif-skip", type=int, default=3)
    # Attention architecture
    p.add_argument("--embed-dim", type=int, default=64,
                   help="Transformer embedding dimension (default 64)")
    p.add_argument("--n-heads", type=int, default=4,
                   help="Number of attention heads (default 4)")
    p.add_argument("--n-layers", type=int, default=2,
                   help="Number of transformer encoder layers (default 2)")
    p.add_argument("--ff-dim", type=int, default=128,
                   help="Transformer feed-forward dim (default 128)")
    args = p.parse_args()

    # Reject values that would otherwise crash deep inside training/eval.
    if not 1 <= args.max_sheep <= HerdingEnv.MAX_SHEEP:
        p.error(f"--max-sheep must be in [1, {HerdingEnv.MAX_SHEEP}]")
    if args.n_envs < 1:
        p.error("--n-envs must be >= 1")
    if args.eval_episodes < 1:
        p.error("--eval-episodes must be >= 1")
    if args.n_heads < 1 or args.embed_dim % args.n_heads != 0:
        p.error("--embed-dim must be divisible by --n-heads")
    return args


# ── Main ──────────────────────────────────────────────────────────────────────

def _write_json(path, obj):
    """Write ``obj`` to ``path`` as indented JSON."""
    with open(path, "w") as f:
        json.dump(obj, f, indent=2)


def main():
    """Train through the 1 → max_sheep curriculum, evaluating after each stage."""
    args = parse_args()

    cfg = dict(DEFAULT_CONFIG)
    config_path = args.config
    if config_path is None and os.path.exists("config.json"):
        config_path = "config.json"
    if config_path:
        with open(config_path) as f:
            cfg.update(json.load(f))
        print(f"Config loaded from {config_path}")
    # Only keys that name a HerdingEnv class attribute are forwarded to the env.
    rcfg = {k: v for k, v in cfg.items() if hasattr(HerdingEnv, k)}

    run_dir = args.run_dir or os.path.join(
        "runs", "at_" + time.strftime("%Y%m%d_%H%M%S"))
    eval_dir = os.path.join(run_dir, "eval")
    os.makedirs(eval_dir, exist_ok=True)
    _write_json(os.path.join(run_dir, "config.json"), cfg)
    results_path = os.path.join(run_dir, "stage_results.json")

    print(f"Config: {cfg}")
    print(f"Run dir: {run_dir}")
    print(f"Curriculum: 1 → {args.max_sheep} sheep, "
          f"{args.steps_per_stage:,} steps/stage")
    print(f"Transformer: embed={args.embed_dim} heads={args.n_heads} "
          f"layers={args.n_layers} ff={args.ff_dim}\n")

    train_env = SubprocVecEnv([
        make_env_at(1, seed=i, max_steps=args.max_steps, reward_cfg=rcfg)
        for i in range(args.n_envs)
    ])
    vn = VecNormalize(train_env, norm_obs=True, norm_reward=True, clip_obs=10.0)

    model = PPO(
        "MlpPolicy",
        vn,
        learning_rate=3e-4,
        n_steps=2048,
        batch_size=256,
        n_epochs=10,
        gamma=0.995,
        gae_lambda=0.95,
        clip_range=0.2,
        ent_coef=cfg.get("ent_coef", 0.02),
        vf_coef=0.5,
        max_grad_norm=0.5,
        policy_kwargs=dict(
            features_extractor_class=ShepherdAttentionExtractor,
            features_extractor_kwargs=dict(
                embed_dim=args.embed_dim,
                n_heads=args.n_heads,
                n_layers=args.n_layers,
                ff_dim=args.ff_dim,
            ),
            net_arch=[256, 256],
        ),
        device="cpu",
        verbose=0,
    )

    stage_results = []
    t0 = time.time()

    try:
        for n in range(1, args.max_sheep + 1):
            if n == 1:
                print(f"\n[Stage n_sheep=1] training {args.steps_per_stage:,} steps")
                model.learn(
                    total_timesteps=args.steps_per_stage,
                    reset_num_timesteps=True,
                    callback=ProgressCallback("1 sheep", freq=100_000),
                )
            else:
                # Mixed phase: the first `half` envs stay at n-1 sheep to
                # reduce forgetting, the rest move to n sheep.
                half = max(1, args.n_envs // 2)
                mix_steps = args.steps_per_stage // 2
                full_steps = args.steps_per_stage - mix_steps

                for i in range(half):
                    vn.env_method("set_n_sheep", n - 1, indices=[i])
                for i in range(half, args.n_envs):
                    vn.env_method("set_n_sheep", n, indices=[i])
                print(f"\n[Stage n_sheep={n}] mixed ({n-1}/{n} sheep) "
                      f"{mix_steps:,} steps")
                model.learn(
                    total_timesteps=mix_steps,
                    reset_num_timesteps=False,
                    callback=ProgressCallback(f"{n-1}→{n} mix", freq=100_000),
                )

                vn.env_method("set_n_sheep", n)
                print(f"[Stage n_sheep={n}] full ({n} sheep) {full_steps:,} steps")
                model.learn(
                    total_timesteps=full_steps,
                    reset_num_timesteps=False,
                    callback=ProgressCallback(f"{n} sheep", freq=100_000),
                )

            print(f"[Stage n_sheep={n}] evaluating {args.eval_episodes} eps")
            r = evaluate_at(model, vn, n, args.eval_episodes, args.max_steps, rcfg)
            print(f"[Stage n_sheep={n}] sr={r['sr']*100:.0f}% "
                  f"mean_len={r['mean_len']:.0f} "
                  f"mean_min_pen={r['mean_min_pen']:.1f}m "
                  f"mean_act={r['mean_act']:.2f}")
            if r["failure_modes"]:
                modes = " ".join(
                    f"{k}={v}" for k, v in sorted(
                        r["failure_modes"].items(), key=lambda x: -x[1]))
                print(f"  failure modes: {modes}")
            if "reward_per_step" in r:
                rps = r["reward_per_step"]
                print("  reward/step: " + "  ".join(
                    f"{k}={v:+.4f}" for k, v in rps.items()))

            hist = run_and_record(
                model, vn, n, args.max_steps, rcfg,
                seed=1000 + n, make_env_fn=make_env_at,
            )
            tag = "success" if hist["success"] else "fail"
            plot_trajectory(hist, os.path.join(eval_dir, f"traj_{n}s_{tag}.png"))
            plot_timeseries(hist, os.path.join(eval_dir, f"ts_{n}s_{tag}.png"))
            if not args.no_gif:
                save_episode_gif(
                    hist, os.path.join(eval_dir, f"ep_{n}s_{tag}.gif"),
                    fps=args.gif_fps, skip=args.gif_skip)

            r["n_sheep"] = n
            stage_results.append(r)
            # Persist after every stage so an interrupted run still keeps the
            # results of the stages it completed.
            _write_json(results_path, stage_results)

        model.save(os.path.join(run_dir, "final_model"))
        vn.save(os.path.join(run_dir, "vecnorm.pkl"))
        _write_json(results_path, stage_results)

    finally:
        # Best-effort cleanup: report, but never mask the original error.
        try:
            vn.close()
        except Exception as exc:
            print(f"warning: failed to close training envs: {exc!r}")

    elapsed = (time.time() - t0) / 60
    print("\n" + "=" * 70)
    print("  TRAINING SUMMARY (attention policy)")
    print("=" * 70)
    for r in stage_results:
        print(f"  n_sheep={r['n_sheep']}  sr={r['sr']*100:>3.0f}%  "
              f"len={r['mean_len']:>5.0f}  "
              f"min_pen={r['mean_min_pen']:>5.1f}m  "
              f"act={r['mean_act']:.2f}")
    print(f"\n  Total time: {elapsed:.1f} min")
    print(f"  Artefacts:  {run_dir}/")

    plot_success_rate(stage_results, os.path.join(run_dir, "success_rate.png"))
    print(f"  Plots: {run_dir}/success_rate.png, {eval_dir}/")


if __name__ == "__main__":
    main()