From a2363d882f0acbe143f64a84ee4cf150f0983b0e Mon Sep 17 00:00:00 2001 From: Johnny Fernandes Date: Sun, 26 Apr 2026 22:28:43 +0100 Subject: [PATCH] Trying attention method --- training/herding_env.py | 4 +- training/train.py | 40 +++- training/train_at.py | 411 ++++++++++++++++++++++++++++++++++++++++ training/viz.py | 5 +- 4 files changed, 448 insertions(+), 12 deletions(-) create mode 100644 training/train_at.py diff --git a/training/herding_env.py b/training/herding_env.py index 4a722e3..3cc9fd2 100644 --- a/training/herding_env.py +++ b/training/herding_env.py @@ -73,8 +73,8 @@ class HerdingEnv(gym.Env): # Peer communication lag — sheep broadcast every 3 Webots steps PEER_BROADCAST_INTERVAL = 3 - # Action smoothing EMA alpha — matches shepherd_dog_rl.py ACTION_SMOOTH - ACTION_SMOOTH = 0.3 + # Action smoothing EMA alpha; 0 = disabled (smoothing applied at Webots inference) + ACTION_SMOOTH = 0.0 # Boid parameters — identical to sheep.py FLEE_DIST = 7.0 diff --git a/training/train.py b/training/train.py index 9de24f2..94a1f44 100644 --- a/training/train.py +++ b/training/train.py @@ -286,15 +286,39 @@ def main(): try: for n in range(1, args.max_sheep + 1): - if n > 1: + if n == 1: + print(f"\n[Stage n_sheep=1] training {args.steps_per_stage:,} steps") + model.learn( + total_timesteps=args.steps_per_stage, + reset_num_timesteps=True, + callback=ProgressCallback("1 sheep", freq=100_000), + ) + else: + # Mixed transition: half envs stay at n-1, half advance to n, + # for the first half of the stage budget. This prevents the + # n+1 task's noisy early gradients from destroying the n policy + # (catastrophic forgetting) before it has a chance to adapt. + half = max(1, args.n_envs // 2) + for i in range(half): + vn.env_method("set_n_sheep", n - 1, indices=[i]) + for i in range(half, args.n_envs): + vn.env_method("set_n_sheep", n, indices=[i]) + mix_steps = args.steps_per_stage // 2 + full_steps = args.steps_per_stage - mix_steps + print(f"\n[Stage n_sheep={n}] mixed ({n-1}/{n} sheep) " + f"{mix_steps:,} steps") + model.learn( + total_timesteps=mix_steps, + reset_num_timesteps=False, + callback=ProgressCallback(f"{n-1}→{n} mix", freq=100_000), + ) vn.env_method("set_n_sheep", n) - - print(f"\n[Stage n_sheep={n}] training {args.steps_per_stage:,} steps") - model.learn( - total_timesteps=args.steps_per_stage, - reset_num_timesteps=(n == 1), - callback=ProgressCallback(f"{n} sheep", freq=100_000), - ) + print(f"[Stage n_sheep={n}] full ({n} sheep) {full_steps:,} steps") + model.learn( + total_timesteps=full_steps, + reset_num_timesteps=False, + callback=ProgressCallback(f"{n} sheep", freq=100_000), + ) # Evaluate print(f"[Stage n_sheep={n}] evaluating {args.eval_episodes} eps") diff --git a/training/train_at.py b/training/train_at.py new file mode 100644 index 0000000..cbf37f3 --- /dev/null +++ b/training/train_at.py @@ -0,0 +1,411 @@ +""" +PPO training with attention-based policy (train_at.py). + +Key difference from train.py +----------------------------- +- Observation exposes ALL sheep as individual per-sheep tokens rather than + only the top-3 farthest. The policy therefore has complete flock visibility + at any sheep count — no hidden sheep even at n=10. +- A TransformerFeaturesExtractor processes the sheep tokens with multi-head + self-attention (permutation-invariant), then mean-pools over valid tokens + and concatenates the result with global dog/pen features. +- Curriculum transition uses the same mixed-env approach as train.py: half + the envs stay at n-1 for the first half of each new stage to suppress + catastrophic forgetting. + +Observation layout (7 + MAX_SHEEP*6 = 67 dims, fixed) +------------------------------------------------------- + Global (7): + dog_x / FIELD, dog_y / FIELD, + cos(heading), sin(heading), + (pen_x - dog_x) / D, (pen_y - dog_y) / D, + n_active / n_sheep + + Per sheep i (6): + (sheep_x - dog_x) / D, (sheep_y - dog_y) / D, ← pos rel to dog + (pen_x - sheep_x) / D, (pen_y - sheep_y) / D, ← sheep-to-pen + is_active 1.0 if not penned, else 0.0 + is_valid 1.0 if i < n_sheep, else 0.0 (padding sentinel) + + After VecNormalize, is_valid for real sheep normalises > 0 and for + padding tokens < 0 (because mean ∈ (0,1)), so a threshold of 0 cleanly + separates real from padded without any extra bookkeeping. + +Usage +----- + python train_at.py # defaults from config.json + python train_at.py --max-sheep 10 --steps-per-stage 2000000 + python train_at.py --embed-dim 128 --n-heads 4 --n-layers 3 +""" + +import argparse +import json +import os +import time +from copy import deepcopy + +import numpy as np +import torch +import torch.nn as nn +from gymnasium import spaces +from stable_baselines3 import PPO +from stable_baselines3.common.torch_layers import BaseFeaturesExtractor +from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize + +from herding_env import HerdingEnv +from train import ProgressCallback, _classify, COMPACT_RADIUS, DEFAULT_CONFIG +from viz import ( + run_and_record, plot_trajectory, plot_timeseries, + plot_success_rate, save_episode_gif, +) + + +# ── Per-sheep token observation environment ─────────────────────────────────── + +class HerdingEnvAt(HerdingEnv): + """ + HerdingEnv with a per-sheep token observation for the attention policy. + Everything else (dynamics, reward, curriculum interface) is inherited. + """ + + OBS_GLOBAL = 7 + OBS_SHEEP = 6 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + obs_dim = self.OBS_GLOBAL + self.MAX_SHEEP * self.OBS_SHEEP + self.observation_space = spaces.Box( + low=-np.inf, high=np.inf, shape=(obs_dim,), dtype=np.float32 + ) + + def _obs(self) -> np.ndarray: + S = self.FIELD + D = 2.0 * self.FIELD + pen_ref = self.PEN_ENTRY if self.ENTRY_AWARE else self.PEN_CENTER + active_mask = ~self.penned[:self.n_sheep] + n_active = int(active_mask.sum()) + + global_feats = np.array([ + self.dog_pos[0] / S, + self.dog_pos[1] / S, + float(np.cos(self.dog_heading)), + float(np.sin(self.dog_heading)), + (pen_ref[0] - self.dog_pos[0]) / D, + (pen_ref[1] - self.dog_pos[1]) / D, + n_active / max(self.n_sheep, 1), + ], dtype=np.float32) + + sheep_feats = np.zeros((self.MAX_SHEEP, self.OBS_SHEEP), dtype=np.float32) + for i in range(self.n_sheep): + pos = self.sheep_pos[i] + sheep_feats[i] = [ + (pos[0] - self.dog_pos[0]) / D, + (pos[1] - self.dog_pos[1]) / D, + (pen_ref[0] - pos[0]) / D, + (pen_ref[1] - pos[1]) / D, + float(not self.penned[i]), + 1.0, # is_valid: this sheep exists + ] + # i >= n_sheep: all zeros, is_valid=0 → masked out in attention + + return np.concatenate([global_feats, sheep_feats.ravel()]) + + +# ── Attention features extractor ────────────────────────────────────────────── + +class ShepherdAttentionExtractor(BaseFeaturesExtractor): + """ + Multi-head self-attention over per-sheep tokens, mean-pooled over valid + (non-padding) tokens and concatenated with global dog/pen features. + + After VecNormalize: + real sheep → is_valid_norm > 0 (normalised from 1.0) + padding → is_valid_norm ≤ 0 (normalised from 0.0) + so threshold at 0 is always correct regardless of curriculum stage. + """ + + GLOBAL_DIM = HerdingEnvAt.OBS_GLOBAL # 7 + SHEEP_DIM = HerdingEnvAt.OBS_SHEEP # 6 + MAX_SHEEP = HerdingEnv.MAX_SHEEP # 10 + VALID_IDX = 5 # index of is_valid within each token + + def __init__(self, observation_space, embed_dim: int = 64, + n_heads: int = 4, n_layers: int = 2, ff_dim: int = 128): + super().__init__(observation_space, + features_dim=self.GLOBAL_DIM + embed_dim) + self.sheep_embed = nn.Linear(self.SHEEP_DIM, embed_dim) + encoder_layer = nn.TransformerEncoderLayer( + d_model=embed_dim, nhead=n_heads, dim_feedforward=ff_dim, + dropout=0.0, batch_first=True, + ) + self.transformer = nn.TransformerEncoder(encoder_layer, + num_layers=n_layers) + + def forward(self, obs: torch.Tensor) -> torch.Tensor: + B = obs.shape[0] + global_feats = obs[:, :self.GLOBAL_DIM] # (B, 7) + tokens = obs[:, self.GLOBAL_DIM:].view( + B, self.MAX_SHEEP, self.SHEEP_DIM) # (B, 10, 6) + + # is_valid after VecNorm: real > 0, padding ≤ 0 + is_valid_norm = tokens[:, :, self.VALID_IDX] # (B, 10) + key_padding_mask = is_valid_norm <= 0.0 # True → ignore + + x = self.sheep_embed(tokens) # (B, 10, E) + x = self.transformer(x, src_key_padding_mask=key_padding_mask) + + valid_w = (is_valid_norm > 0.0).float().unsqueeze(-1) # (B, 10, 1) + pooled = (x * valid_w).sum(1) / valid_w.sum(1).clamp(min=1.0) + + return torch.cat([global_feats, pooled], dim=1) # (B, 7+E) + + +# ── Environment factory ─────────────────────────────────────────────────────── + +def make_env_at(n_sheep, seed, max_steps, reward_cfg=None): + def _init(): + env = HerdingEnvAt(n_sheep=n_sheep, max_steps=max_steps, + reward_cfg=reward_cfg) + env.reset(seed=seed) + return env + return _init + + +# ── Evaluation ──────────────────────────────────────────────────────────────── + +def evaluate_at(model, vn_template, n_sheep, n_episodes, max_steps, + reward_cfg=None): + raw = DummyVecEnv([make_env_at(n_sheep, 9999, max_steps, reward_cfg)]) + vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False) + vn.obs_rms = deepcopy(vn_template.obs_rms) + vn.ret_rms = deepcopy(vn_template.ret_rms) + + successes = 0 + ep_lens, min_pen_list, action_mags = [], [], [] + failure_counts, rc_sums = {}, {} + rc_n = 0 + + for _ in range(n_episodes): + obs = vn.reset() + done = False + steps, min_pen = 0, float("inf") + mags, ep_radii, ep_com_dists = [], [], [] + while not done: + action, _ = model.predict(obs, deterministic=True) + obs, _, dones, infos = vn.step(action) + done = dones[0] + inner = vn.envs[0] + com, radius, _ = inner._flock_stats() + min_pen = min(min_pen, + float(np.linalg.norm(com - inner.PEN_CENTER))) + mags.append(float(np.linalg.norm(action[0]))) + ep_radii.append(radius) + ep_com_dists.append(float(np.linalg.norm(com - inner.PEN_CENTER))) + steps += 1 + rc = infos[0].get("rcomps") + if rc: + for k, v in rc.items(): + rc_sums[k] = rc_sums.get(k, 0.0) + v + rc_n += 1 + n_penned = infos[0].get("n_penned", 0) + successes += int(n_penned == n_sheep) + ep_lens.append(steps) + min_pen_list.append(min_pen) + action_mags.extend(mags) + mode = _classify(ep_radii, ep_com_dists, n_penned, n_sheep) + failure_counts[mode] = failure_counts.get(mode, 0) + 1 + + vn.close() + result = { + "sr": successes / n_episodes, + "mean_len": float(np.mean(ep_lens)), + "mean_min_pen": float(np.mean(min_pen_list)), + "mean_act": float(np.mean(action_mags)) if action_mags else 0.0, + "failure_modes": failure_counts, + } + if rc_n > 0: + result["reward_per_step"] = {k: v / rc_n for k, v in rc_sums.items()} + return result + + +# ── CLI ─────────────────────────────────────────────────────────────────────── + +def parse_args(): + p = argparse.ArgumentParser( + description="PPO + attention training for herding task") + p.add_argument("--config", type=str, default=None) + p.add_argument("--max-sheep", type=int, default=10) + p.add_argument("--steps-per-stage", type=int, default=1_500_000) + p.add_argument("--n-envs", type=int, default=8) + p.add_argument("--max-steps", type=int, default=2500) + p.add_argument("--eval-episodes", type=int, default=30) + p.add_argument("--run-dir", type=str, default=None) + p.add_argument("--no-gif", action="store_true") + p.add_argument("--gif-fps", type=int, default=20) + p.add_argument("--gif-skip", type=int, default=3) + # Attention architecture + p.add_argument("--embed-dim", type=int, default=64, + help="Transformer embedding dimension (default 64)") + p.add_argument("--n-heads", type=int, default=4, + help="Number of attention heads (default 4)") + p.add_argument("--n-layers", type=int, default=2, + help="Number of transformer encoder layers (default 2)") + p.add_argument("--ff-dim", type=int, default=128, + help="Transformer feed-forward dim (default 128)") + return p.parse_args() + + +# ── Main ────────────────────────────────────────────────────────────────────── + +def main(): + args = parse_args() + + cfg = dict(DEFAULT_CONFIG) + config_path = args.config + if config_path is None and os.path.exists("config.json"): + config_path = "config.json" + if config_path: + with open(config_path) as f: + cfg.update(json.load(f)) + print(f"Config loaded from {config_path}") + + rcfg = {k: v for k, v in cfg.items() if hasattr(HerdingEnv, k)} + + run_dir = args.run_dir or os.path.join( + "runs", "at_" + time.strftime("%Y%m%d_%H%M%S")) + eval_dir = os.path.join(run_dir, "eval") + os.makedirs(eval_dir, exist_ok=True) + with open(os.path.join(run_dir, "config.json"), "w") as f: + json.dump(cfg, f, indent=2) + + print(f"Config: {cfg}") + print(f"Run dir: {run_dir}") + print(f"Curriculum: 1 → {args.max_sheep} sheep, " + f"{args.steps_per_stage:,} steps/stage") + print(f"Transformer: embed={args.embed_dim} heads={args.n_heads} " + f"layers={args.n_layers} ff={args.ff_dim}\n") + + train_env = SubprocVecEnv([ + make_env_at(1, seed=i, max_steps=args.max_steps, reward_cfg=rcfg) + for i in range(args.n_envs) + ]) + vn = VecNormalize(train_env, norm_obs=True, norm_reward=True, clip_obs=10.0) + + model = PPO( + "MlpPolicy", vn, + learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10, + gamma=0.995, gae_lambda=0.95, clip_range=0.2, + ent_coef=cfg.get("ent_coef", 0.02), vf_coef=0.5, max_grad_norm=0.5, + policy_kwargs=dict( + features_extractor_class=ShepherdAttentionExtractor, + features_extractor_kwargs=dict( + embed_dim=args.embed_dim, + n_heads=args.n_heads, + n_layers=args.n_layers, + ff_dim=args.ff_dim, + ), + net_arch=[256, 256], + ), + device="cpu", + verbose=0, + ) + + stage_results = [] + t0 = time.time() + + try: + for n in range(1, args.max_sheep + 1): + if n == 1: + print(f"\n[Stage n_sheep=1] training {args.steps_per_stage:,} steps") + model.learn( + total_timesteps=args.steps_per_stage, + reset_num_timesteps=True, + callback=ProgressCallback("1 sheep", freq=100_000), + ) + else: + half = max(1, args.n_envs // 2) + mix_steps = args.steps_per_stage // 2 + full_steps = args.steps_per_stage - mix_steps + + for i in range(half): + vn.env_method("set_n_sheep", n - 1, indices=[i]) + for i in range(half, args.n_envs): + vn.env_method("set_n_sheep", n, indices=[i]) + + print(f"\n[Stage n_sheep={n}] mixed ({n-1}/{n} sheep) " + f"{mix_steps:,} steps") + model.learn( + total_timesteps=mix_steps, + reset_num_timesteps=False, + callback=ProgressCallback(f"{n-1}→{n} mix", freq=100_000), + ) + + vn.env_method("set_n_sheep", n) + print(f"[Stage n_sheep={n}] full ({n} sheep) {full_steps:,} steps") + model.learn( + total_timesteps=full_steps, + reset_num_timesteps=False, + callback=ProgressCallback(f"{n} sheep", freq=100_000), + ) + + print(f"[Stage n_sheep={n}] evaluating {args.eval_episodes} eps") + r = evaluate_at(model, vn, n, args.eval_episodes, + args.max_steps, rcfg) + print(f"[Stage n_sheep={n}] sr={r['sr']*100:.0f}% " + f"mean_len={r['mean_len']:.0f} " + f"mean_min_pen={r['mean_min_pen']:.1f}m " + f"mean_act={r['mean_act']:.2f}") + if r["failure_modes"]: + modes = " ".join( + f"{k}={v}" for k, v in sorted( + r["failure_modes"].items(), key=lambda x: -x[1])) + print(f" failure modes: {modes}") + if "reward_per_step" in r: + rps = r["reward_per_step"] + print(" reward/step: " + " ".join( + f"{k}={v:+.4f}" for k, v in rps.items())) + + hist = run_and_record( + model, vn, n, args.max_steps, rcfg, + seed=1000 + n, make_env_fn=make_env_at, + ) + tag = "success" if hist["success"] else "fail" + plot_trajectory(hist, os.path.join(eval_dir, f"traj_{n}s_{tag}.png")) + plot_timeseries(hist, os.path.join(eval_dir, f"ts_{n}s_{tag}.png")) + if not args.no_gif: + save_episode_gif( + hist, + os.path.join(eval_dir, f"ep_{n}s_{tag}.gif"), + fps=args.gif_fps, skip=args.gif_skip) + + r["n_sheep"] = n + stage_results.append(r) + + model.save(os.path.join(run_dir, "final_model")) + vn.save(os.path.join(run_dir, "vecnorm.pkl")) + with open(os.path.join(run_dir, "stage_results.json"), "w") as f: + json.dump(stage_results, f, indent=2) + + finally: + try: + vn.close() + except Exception: + pass + + elapsed = (time.time() - t0) / 60 + print("\n" + "=" * 70) + print(" TRAINING SUMMARY (attention policy)") + print("=" * 70) + for r in stage_results: + print(f" n_sheep={r['n_sheep']} sr={r['sr']*100:>3.0f}% " + f"len={r['mean_len']:>5.0f} " + f"min_pen={r['mean_min_pen']:>5.1f}m " + f"act={r['mean_act']:.2f}") + print(f"\n Total time: {elapsed:.1f} min") + print(f" Artefacts: {run_dir}/") + plot_success_rate(stage_results, os.path.join(run_dir, "success_rate.png")) + print(f" Plots: {run_dir}/success_rate.png, {eval_dir}/") + + +if __name__ == "__main__": + main() diff --git a/training/viz.py b/training/viz.py index c2882ad..1b3ada2 100644 --- a/training/viz.py +++ b/training/viz.py @@ -78,9 +78,10 @@ def make_eval_env(n_sheep, seed, max_steps, reward_cfg=None): def run_and_record(model, vn_template, n_sheep, max_steps, - reward_cfg=None, seed=42): + reward_cfg=None, seed=42, make_env_fn=None): """Run one deterministic episode and return full trajectory history.""" - raw = DummyVecEnv([make_eval_env(n_sheep, seed, max_steps, reward_cfg)]) + _factory = make_env_fn or make_eval_env + raw = DummyVecEnv([_factory(n_sheep, seed, max_steps, reward_cfg)]) vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False) vn.obs_rms = deepcopy(vn_template.obs_rms) vn.ret_rms = deepcopy(vn_template.ret_rms)