Trying attention method

This commit is contained in:
Johnny Fernandes
2026-04-26 22:28:43 +01:00
parent 57b1735e1a
commit a2363d882f
4 changed files with 448 additions and 12 deletions
+2 -2
View File
@@ -73,8 +73,8 @@ class HerdingEnv(gym.Env):
# Peer communication lag — sheep broadcast every 3 Webots steps
PEER_BROADCAST_INTERVAL = 3
# Action smoothing EMA alpha — matches shepherd_dog_rl.py ACTION_SMOOTH
ACTION_SMOOTH = 0.3
# Action smoothing EMA alpha; 0 = disabled (smoothing applied at Webots inference)
ACTION_SMOOTH = 0.0
# Boid parameters — identical to sheep.py
FLEE_DIST = 7.0
+32 -8
View File
@@ -286,15 +286,39 @@ def main():
try:
for n in range(1, args.max_sheep + 1):
if n > 1:
if n == 1:
print(f"\n[Stage n_sheep=1] training {args.steps_per_stage:,} steps")
model.learn(
total_timesteps=args.steps_per_stage,
reset_num_timesteps=True,
callback=ProgressCallback("1 sheep", freq=100_000),
)
else:
# Mixed transition: half envs stay at n-1, half advance to n,
# for the first half of the stage budget. This prevents the
# n+1 task's noisy early gradients from destroying the n policy
# (catastrophic forgetting) before it has a chance to adapt.
half = max(1, args.n_envs // 2)
for i in range(half):
vn.env_method("set_n_sheep", n - 1, indices=[i])
for i in range(half, args.n_envs):
vn.env_method("set_n_sheep", n, indices=[i])
mix_steps = args.steps_per_stage // 2
full_steps = args.steps_per_stage - mix_steps
print(f"\n[Stage n_sheep={n}] mixed ({n-1}/{n} sheep) "
f"{mix_steps:,} steps")
model.learn(
total_timesteps=mix_steps,
reset_num_timesteps=False,
callback=ProgressCallback(f"{n-1}{n} mix", freq=100_000),
)
vn.env_method("set_n_sheep", n)
print(f"\n[Stage n_sheep={n}] training {args.steps_per_stage:,} steps")
model.learn(
total_timesteps=args.steps_per_stage,
reset_num_timesteps=(n == 1),
callback=ProgressCallback(f"{n} sheep", freq=100_000),
)
print(f"[Stage n_sheep={n}] full ({n} sheep) {full_steps:,} steps")
model.learn(
total_timesteps=full_steps,
reset_num_timesteps=False,
callback=ProgressCallback(f"{n} sheep", freq=100_000),
)
# Evaluate
print(f"[Stage n_sheep={n}] evaluating {args.eval_episodes} eps")
+411
View File
@@ -0,0 +1,411 @@
"""
PPO training with attention-based policy (train_at.py).
Key difference from train.py
-----------------------------
- Observation exposes ALL sheep as individual per-sheep tokens rather than
only the top-3 farthest. The policy therefore has complete flock visibility
at any sheep count — no hidden sheep even at n=10.
- A TransformerFeaturesExtractor processes the sheep tokens with multi-head
self-attention (permutation-invariant), then mean-pools over valid tokens
and concatenates the result with global dog/pen features.
- Curriculum transition uses the same mixed-env approach as train.py: half
the envs stay at n-1 for the first half of each new stage to suppress
catastrophic forgetting.
Observation layout (7 + MAX_SHEEP*6 = 67 dims, fixed)
-------------------------------------------------------
Global (7):
dog_x / FIELD, dog_y / FIELD,
cos(heading), sin(heading),
(pen_x - dog_x) / D, (pen_y - dog_y) / D,
n_active / n_sheep
Per sheep i (6):
(sheep_x - dog_x) / D, (sheep_y - dog_y) / D, ← pos rel to dog
(pen_x - sheep_x) / D, (pen_y - sheep_y) / D, ← sheep-to-pen
is_active 1.0 if not penned, else 0.0
is_valid 1.0 if i < n_sheep, else 0.0 (padding sentinel)
After VecNormalize, is_valid for real sheep normalises > 0 and for
padding tokens < 0 (because mean ∈ (0,1)), so a threshold of 0 cleanly
separates real from padded without any extra bookkeeping.
Usage
-----
python train_at.py # defaults from config.json
python train_at.py --max-sheep 10 --steps-per-stage 2000000
python train_at.py --embed-dim 128 --n-heads 4 --n-layers 3
"""
import argparse
import json
import os
import time
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize
from herding_env import HerdingEnv
from train import ProgressCallback, _classify, COMPACT_RADIUS, DEFAULT_CONFIG
from viz import (
run_and_record, plot_trajectory, plot_timeseries,
plot_success_rate, save_episode_gif,
)
# ── Per-sheep token observation environment ───────────────────────────────────
class HerdingEnvAt(HerdingEnv):
"""
HerdingEnv with a per-sheep token observation for the attention policy.
Everything else (dynamics, reward, curriculum interface) is inherited.
"""
OBS_GLOBAL = 7
OBS_SHEEP = 6
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
obs_dim = self.OBS_GLOBAL + self.MAX_SHEEP * self.OBS_SHEEP
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(obs_dim,), dtype=np.float32
)
def _obs(self) -> np.ndarray:
S = self.FIELD
D = 2.0 * self.FIELD
pen_ref = self.PEN_ENTRY if self.ENTRY_AWARE else self.PEN_CENTER
active_mask = ~self.penned[:self.n_sheep]
n_active = int(active_mask.sum())
global_feats = np.array([
self.dog_pos[0] / S,
self.dog_pos[1] / S,
float(np.cos(self.dog_heading)),
float(np.sin(self.dog_heading)),
(pen_ref[0] - self.dog_pos[0]) / D,
(pen_ref[1] - self.dog_pos[1]) / D,
n_active / max(self.n_sheep, 1),
], dtype=np.float32)
sheep_feats = np.zeros((self.MAX_SHEEP, self.OBS_SHEEP), dtype=np.float32)
for i in range(self.n_sheep):
pos = self.sheep_pos[i]
sheep_feats[i] = [
(pos[0] - self.dog_pos[0]) / D,
(pos[1] - self.dog_pos[1]) / D,
(pen_ref[0] - pos[0]) / D,
(pen_ref[1] - pos[1]) / D,
float(not self.penned[i]),
1.0, # is_valid: this sheep exists
]
# i >= n_sheep: all zeros, is_valid=0 → masked out in attention
return np.concatenate([global_feats, sheep_feats.ravel()])
# ── Attention features extractor ──────────────────────────────────────────────
class ShepherdAttentionExtractor(BaseFeaturesExtractor):
"""
Multi-head self-attention over per-sheep tokens, mean-pooled over valid
(non-padding) tokens and concatenated with global dog/pen features.
After VecNormalize:
real sheep → is_valid_norm > 0 (normalised from 1.0)
padding → is_valid_norm ≤ 0 (normalised from 0.0)
so threshold at 0 is always correct regardless of curriculum stage.
"""
GLOBAL_DIM = HerdingEnvAt.OBS_GLOBAL # 7
SHEEP_DIM = HerdingEnvAt.OBS_SHEEP # 6
MAX_SHEEP = HerdingEnv.MAX_SHEEP # 10
VALID_IDX = 5 # index of is_valid within each token
def __init__(self, observation_space, embed_dim: int = 64,
n_heads: int = 4, n_layers: int = 2, ff_dim: int = 128):
super().__init__(observation_space,
features_dim=self.GLOBAL_DIM + embed_dim)
self.sheep_embed = nn.Linear(self.SHEEP_DIM, embed_dim)
encoder_layer = nn.TransformerEncoderLayer(
d_model=embed_dim, nhead=n_heads, dim_feedforward=ff_dim,
dropout=0.0, batch_first=True,
)
self.transformer = nn.TransformerEncoder(encoder_layer,
num_layers=n_layers)
def forward(self, obs: torch.Tensor) -> torch.Tensor:
B = obs.shape[0]
global_feats = obs[:, :self.GLOBAL_DIM] # (B, 7)
tokens = obs[:, self.GLOBAL_DIM:].view(
B, self.MAX_SHEEP, self.SHEEP_DIM) # (B, 10, 6)
# is_valid after VecNorm: real > 0, padding ≤ 0
is_valid_norm = tokens[:, :, self.VALID_IDX] # (B, 10)
key_padding_mask = is_valid_norm <= 0.0 # True → ignore
x = self.sheep_embed(tokens) # (B, 10, E)
x = self.transformer(x, src_key_padding_mask=key_padding_mask)
valid_w = (is_valid_norm > 0.0).float().unsqueeze(-1) # (B, 10, 1)
pooled = (x * valid_w).sum(1) / valid_w.sum(1).clamp(min=1.0)
return torch.cat([global_feats, pooled], dim=1) # (B, 7+E)
# ── Environment factory ───────────────────────────────────────────────────────
def make_env_at(n_sheep, seed, max_steps, reward_cfg=None):
def _init():
env = HerdingEnvAt(n_sheep=n_sheep, max_steps=max_steps,
reward_cfg=reward_cfg)
env.reset(seed=seed)
return env
return _init
# ── Evaluation ────────────────────────────────────────────────────────────────
def evaluate_at(model, vn_template, n_sheep, n_episodes, max_steps,
reward_cfg=None):
raw = DummyVecEnv([make_env_at(n_sheep, 9999, max_steps, reward_cfg)])
vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
vn.obs_rms = deepcopy(vn_template.obs_rms)
vn.ret_rms = deepcopy(vn_template.ret_rms)
successes = 0
ep_lens, min_pen_list, action_mags = [], [], []
failure_counts, rc_sums = {}, {}
rc_n = 0
for _ in range(n_episodes):
obs = vn.reset()
done = False
steps, min_pen = 0, float("inf")
mags, ep_radii, ep_com_dists = [], [], []
while not done:
action, _ = model.predict(obs, deterministic=True)
obs, _, dones, infos = vn.step(action)
done = dones[0]
inner = vn.envs[0]
com, radius, _ = inner._flock_stats()
min_pen = min(min_pen,
float(np.linalg.norm(com - inner.PEN_CENTER)))
mags.append(float(np.linalg.norm(action[0])))
ep_radii.append(radius)
ep_com_dists.append(float(np.linalg.norm(com - inner.PEN_CENTER)))
steps += 1
rc = infos[0].get("rcomps")
if rc:
for k, v in rc.items():
rc_sums[k] = rc_sums.get(k, 0.0) + v
rc_n += 1
n_penned = infos[0].get("n_penned", 0)
successes += int(n_penned == n_sheep)
ep_lens.append(steps)
min_pen_list.append(min_pen)
action_mags.extend(mags)
mode = _classify(ep_radii, ep_com_dists, n_penned, n_sheep)
failure_counts[mode] = failure_counts.get(mode, 0) + 1
vn.close()
result = {
"sr": successes / n_episodes,
"mean_len": float(np.mean(ep_lens)),
"mean_min_pen": float(np.mean(min_pen_list)),
"mean_act": float(np.mean(action_mags)) if action_mags else 0.0,
"failure_modes": failure_counts,
}
if rc_n > 0:
result["reward_per_step"] = {k: v / rc_n for k, v in rc_sums.items()}
return result
# ── CLI ───────────────────────────────────────────────────────────────────────
def parse_args():
p = argparse.ArgumentParser(
description="PPO + attention training for herding task")
p.add_argument("--config", type=str, default=None)
p.add_argument("--max-sheep", type=int, default=10)
p.add_argument("--steps-per-stage", type=int, default=1_500_000)
p.add_argument("--n-envs", type=int, default=8)
p.add_argument("--max-steps", type=int, default=2500)
p.add_argument("--eval-episodes", type=int, default=30)
p.add_argument("--run-dir", type=str, default=None)
p.add_argument("--no-gif", action="store_true")
p.add_argument("--gif-fps", type=int, default=20)
p.add_argument("--gif-skip", type=int, default=3)
# Attention architecture
p.add_argument("--embed-dim", type=int, default=64,
help="Transformer embedding dimension (default 64)")
p.add_argument("--n-heads", type=int, default=4,
help="Number of attention heads (default 4)")
p.add_argument("--n-layers", type=int, default=2,
help="Number of transformer encoder layers (default 2)")
p.add_argument("--ff-dim", type=int, default=128,
help="Transformer feed-forward dim (default 128)")
return p.parse_args()
# ── Main ──────────────────────────────────────────────────────────────────────
def main():
args = parse_args()
cfg = dict(DEFAULT_CONFIG)
config_path = args.config
if config_path is None and os.path.exists("config.json"):
config_path = "config.json"
if config_path:
with open(config_path) as f:
cfg.update(json.load(f))
print(f"Config loaded from {config_path}")
rcfg = {k: v for k, v in cfg.items() if hasattr(HerdingEnv, k)}
run_dir = args.run_dir or os.path.join(
"runs", "at_" + time.strftime("%Y%m%d_%H%M%S"))
eval_dir = os.path.join(run_dir, "eval")
os.makedirs(eval_dir, exist_ok=True)
with open(os.path.join(run_dir, "config.json"), "w") as f:
json.dump(cfg, f, indent=2)
print(f"Config: {cfg}")
print(f"Run dir: {run_dir}")
print(f"Curriculum: 1 → {args.max_sheep} sheep, "
f"{args.steps_per_stage:,} steps/stage")
print(f"Transformer: embed={args.embed_dim} heads={args.n_heads} "
f"layers={args.n_layers} ff={args.ff_dim}\n")
train_env = SubprocVecEnv([
make_env_at(1, seed=i, max_steps=args.max_steps, reward_cfg=rcfg)
for i in range(args.n_envs)
])
vn = VecNormalize(train_env, norm_obs=True, norm_reward=True, clip_obs=10.0)
model = PPO(
"MlpPolicy", vn,
learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10,
gamma=0.995, gae_lambda=0.95, clip_range=0.2,
ent_coef=cfg.get("ent_coef", 0.02), vf_coef=0.5, max_grad_norm=0.5,
policy_kwargs=dict(
features_extractor_class=ShepherdAttentionExtractor,
features_extractor_kwargs=dict(
embed_dim=args.embed_dim,
n_heads=args.n_heads,
n_layers=args.n_layers,
ff_dim=args.ff_dim,
),
net_arch=[256, 256],
),
device="cpu",
verbose=0,
)
stage_results = []
t0 = time.time()
try:
for n in range(1, args.max_sheep + 1):
if n == 1:
print(f"\n[Stage n_sheep=1] training {args.steps_per_stage:,} steps")
model.learn(
total_timesteps=args.steps_per_stage,
reset_num_timesteps=True,
callback=ProgressCallback("1 sheep", freq=100_000),
)
else:
half = max(1, args.n_envs // 2)
mix_steps = args.steps_per_stage // 2
full_steps = args.steps_per_stage - mix_steps
for i in range(half):
vn.env_method("set_n_sheep", n - 1, indices=[i])
for i in range(half, args.n_envs):
vn.env_method("set_n_sheep", n, indices=[i])
print(f"\n[Stage n_sheep={n}] mixed ({n-1}/{n} sheep) "
f"{mix_steps:,} steps")
model.learn(
total_timesteps=mix_steps,
reset_num_timesteps=False,
callback=ProgressCallback(f"{n-1}{n} mix", freq=100_000),
)
vn.env_method("set_n_sheep", n)
print(f"[Stage n_sheep={n}] full ({n} sheep) {full_steps:,} steps")
model.learn(
total_timesteps=full_steps,
reset_num_timesteps=False,
callback=ProgressCallback(f"{n} sheep", freq=100_000),
)
print(f"[Stage n_sheep={n}] evaluating {args.eval_episodes} eps")
r = evaluate_at(model, vn, n, args.eval_episodes,
args.max_steps, rcfg)
print(f"[Stage n_sheep={n}] sr={r['sr']*100:.0f}% "
f"mean_len={r['mean_len']:.0f} "
f"mean_min_pen={r['mean_min_pen']:.1f}m "
f"mean_act={r['mean_act']:.2f}")
if r["failure_modes"]:
modes = " ".join(
f"{k}={v}" for k, v in sorted(
r["failure_modes"].items(), key=lambda x: -x[1]))
print(f" failure modes: {modes}")
if "reward_per_step" in r:
rps = r["reward_per_step"]
print(" reward/step: " + " ".join(
f"{k}={v:+.4f}" for k, v in rps.items()))
hist = run_and_record(
model, vn, n, args.max_steps, rcfg,
seed=1000 + n, make_env_fn=make_env_at,
)
tag = "success" if hist["success"] else "fail"
plot_trajectory(hist, os.path.join(eval_dir, f"traj_{n}s_{tag}.png"))
plot_timeseries(hist, os.path.join(eval_dir, f"ts_{n}s_{tag}.png"))
if not args.no_gif:
save_episode_gif(
hist,
os.path.join(eval_dir, f"ep_{n}s_{tag}.gif"),
fps=args.gif_fps, skip=args.gif_skip)
r["n_sheep"] = n
stage_results.append(r)
model.save(os.path.join(run_dir, "final_model"))
vn.save(os.path.join(run_dir, "vecnorm.pkl"))
with open(os.path.join(run_dir, "stage_results.json"), "w") as f:
json.dump(stage_results, f, indent=2)
finally:
try:
vn.close()
except Exception:
pass
elapsed = (time.time() - t0) / 60
print("\n" + "=" * 70)
print(" TRAINING SUMMARY (attention policy)")
print("=" * 70)
for r in stage_results:
print(f" n_sheep={r['n_sheep']} sr={r['sr']*100:>3.0f}% "
f"len={r['mean_len']:>5.0f} "
f"min_pen={r['mean_min_pen']:>5.1f}m "
f"act={r['mean_act']:.2f}")
print(f"\n Total time: {elapsed:.1f} min")
print(f" Artefacts: {run_dir}/")
plot_success_rate(stage_results, os.path.join(run_dir, "success_rate.png"))
print(f" Plots: {run_dir}/success_rate.png, {eval_dir}/")
if __name__ == "__main__":
main()
+3 -2
View File
@@ -78,9 +78,10 @@ def make_eval_env(n_sheep, seed, max_steps, reward_cfg=None):
def run_and_record(model, vn_template, n_sheep, max_steps,
reward_cfg=None, seed=42):
reward_cfg=None, seed=42, make_env_fn=None):
"""Run one deterministic episode and return full trajectory history."""
raw = DummyVecEnv([make_eval_env(n_sheep, seed, max_steps, reward_cfg)])
_factory = make_env_fn or make_eval_env
raw = DummyVecEnv([_factory(n_sheep, seed, max_steps, reward_cfg)])
vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
vn.obs_rms = deepcopy(vn_template.obs_rms)
vn.ret_rms = deepcopy(vn_template.ret_rms)