Trying attention method
This commit is contained in:
@@ -73,8 +73,8 @@ class HerdingEnv(gym.Env):
|
||||
# Peer communication lag — sheep broadcast every 3 Webots steps
|
||||
PEER_BROADCAST_INTERVAL = 3
|
||||
|
||||
# Action smoothing EMA alpha — matches shepherd_dog_rl.py ACTION_SMOOTH
|
||||
ACTION_SMOOTH = 0.3
|
||||
# Action smoothing EMA alpha; 0 = disabled (smoothing applied at Webots inference)
|
||||
ACTION_SMOOTH = 0.0
|
||||
|
||||
# Boid parameters — identical to sheep.py
|
||||
FLEE_DIST = 7.0
|
||||
|
||||
+29
-5
@@ -286,13 +286,37 @@ def main():
|
||||
|
||||
try:
|
||||
for n in range(1, args.max_sheep + 1):
|
||||
if n > 1:
|
||||
vn.env_method("set_n_sheep", n)
|
||||
|
||||
print(f"\n[Stage n_sheep={n}] training {args.steps_per_stage:,} steps")
|
||||
if n == 1:
|
||||
print(f"\n[Stage n_sheep=1] training {args.steps_per_stage:,} steps")
|
||||
model.learn(
|
||||
total_timesteps=args.steps_per_stage,
|
||||
reset_num_timesteps=(n == 1),
|
||||
reset_num_timesteps=True,
|
||||
callback=ProgressCallback("1 sheep", freq=100_000),
|
||||
)
|
||||
else:
|
||||
# Mixed transition: half envs stay at n-1, half advance to n,
|
||||
# for the first half of the stage budget. This prevents the
|
||||
# n+1 task's noisy early gradients from destroying the n policy
|
||||
# (catastrophic forgetting) before it has a chance to adapt.
|
||||
half = max(1, args.n_envs // 2)
|
||||
for i in range(half):
|
||||
vn.env_method("set_n_sheep", n - 1, indices=[i])
|
||||
for i in range(half, args.n_envs):
|
||||
vn.env_method("set_n_sheep", n, indices=[i])
|
||||
mix_steps = args.steps_per_stage // 2
|
||||
full_steps = args.steps_per_stage - mix_steps
|
||||
print(f"\n[Stage n_sheep={n}] mixed ({n-1}/{n} sheep) "
|
||||
f"{mix_steps:,} steps")
|
||||
model.learn(
|
||||
total_timesteps=mix_steps,
|
||||
reset_num_timesteps=False,
|
||||
callback=ProgressCallback(f"{n-1}→{n} mix", freq=100_000),
|
||||
)
|
||||
vn.env_method("set_n_sheep", n)
|
||||
print(f"[Stage n_sheep={n}] full ({n} sheep) {full_steps:,} steps")
|
||||
model.learn(
|
||||
total_timesteps=full_steps,
|
||||
reset_num_timesteps=False,
|
||||
callback=ProgressCallback(f"{n} sheep", freq=100_000),
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,411 @@
|
||||
"""
|
||||
PPO training with attention-based policy (train_at.py).
|
||||
|
||||
Key difference from train.py
|
||||
-----------------------------
|
||||
- Observation exposes ALL sheep as individual per-sheep tokens rather than
|
||||
only the top-3 farthest. The policy therefore has complete flock visibility
|
||||
at any sheep count — no hidden sheep even at n=10.
|
||||
- A TransformerFeaturesExtractor processes the sheep tokens with multi-head
|
||||
self-attention (permutation-invariant), then mean-pools over valid tokens
|
||||
and concatenates the result with global dog/pen features.
|
||||
- Curriculum transition uses the same mixed-env approach as train.py: half
|
||||
the envs stay at n-1 for the first half of each new stage to suppress
|
||||
catastrophic forgetting.
|
||||
|
||||
Observation layout (7 + MAX_SHEEP*6 = 67 dims, fixed)
|
||||
-------------------------------------------------------
|
||||
Global (7):
|
||||
dog_x / FIELD, dog_y / FIELD,
|
||||
cos(heading), sin(heading),
|
||||
(pen_x - dog_x) / D, (pen_y - dog_y) / D,
|
||||
n_active / n_sheep
|
||||
|
||||
Per sheep i (6):
|
||||
(sheep_x - dog_x) / D, (sheep_y - dog_y) / D, ← pos rel to dog
|
||||
(pen_x - sheep_x) / D, (pen_y - sheep_y) / D, ← sheep-to-pen
|
||||
is_active 1.0 if not penned, else 0.0
|
||||
is_valid 1.0 if i < n_sheep, else 0.0 (padding sentinel)
|
||||
|
||||
After VecNormalize, is_valid normalises to > 0 for real sheep and to ≤ 0
for padding tokens (the running mean lies in [0, 1]; a slot that has never
held a sheep has mean 0 and normalises to exactly 0), so a threshold of 0
separates real from padded without any extra bookkeeping.
|
||||
|
||||
Usage
|
||||
-----
|
||||
python train_at.py # defaults from config.json
|
||||
python train_at.py --max-sheep 10 --steps-per-stage 2000000
|
||||
python train_at.py --embed-dim 128 --n-heads 4 --n-layers 3
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from gymnasium import spaces
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize
|
||||
|
||||
from herding_env import HerdingEnv
|
||||
from train import ProgressCallback, _classify, COMPACT_RADIUS, DEFAULT_CONFIG
|
||||
from viz import (
|
||||
run_and_record, plot_trajectory, plot_timeseries,
|
||||
plot_success_rate, save_episode_gif,
|
||||
)
|
||||
|
||||
|
||||
# ── Per-sheep token observation environment ───────────────────────────────────
|
||||
|
||||
class HerdingEnvAt(HerdingEnv):
    """
    HerdingEnv variant that emits one fixed-width token per sheep slot.

    Only the observation space and ``_obs`` are overridden in this class;
    all other behaviour comes from HerdingEnv unchanged.
    """

    OBS_GLOBAL = 7  # width of the leading global feature group
    OBS_SHEEP = 6   # width of each per-sheep token

    def __init__(self, *args, **kwargs):
        """Construct the base env, then install the flat token observation space."""
        super().__init__(*args, **kwargs)
        flat_dim = self.OBS_GLOBAL + self.MAX_SHEEP * self.OBS_SHEEP
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(flat_dim,),
            dtype=np.float32,
        )

    def _obs(self) -> np.ndarray:
        """
        Build the flat float32 observation.

        Layout: OBS_GLOBAL global values followed by MAX_SHEEP tokens of
        OBS_SHEEP values each. Slots at index >= n_sheep stay all-zero, so
        their trailing is_valid entry is 0.0.
        """
        field = self.FIELD
        span = 2.0 * self.FIELD
        # Steer towards the pen entry when ENTRY_AWARE is set, else its centre.
        target = self.PEN_ENTRY if self.ENTRY_AWARE else self.PEN_CENTER
        dog_x, dog_y = self.dog_pos[0], self.dog_pos[1]

        free = ~self.penned[:self.n_sheep]
        free_count = int(free.sum())

        head = np.array(
            [
                dog_x / field,
                dog_y / field,
                float(np.cos(self.dog_heading)),
                float(np.sin(self.dog_heading)),
                (target[0] - dog_x) / span,
                (target[1] - dog_y) / span,
                # Fraction of the current flock that is not yet penned.
                free_count / max(self.n_sheep, 1),
            ],
            dtype=np.float32,
        )

        tokens = np.zeros((self.MAX_SHEEP, self.OBS_SHEEP), dtype=np.float32)
        for idx in range(self.n_sheep):
            where = self.sheep_pos[idx]
            tokens[idx] = [
                (where[0] - dog_x) / span,      # sheep position relative to dog
                (where[1] - dog_y) / span,
                (target[0] - where[0]) / span,  # vector from sheep to pen target
                (target[1] - where[1]) / span,
                float(not self.penned[idx]),    # is_active
                1.0,                            # is_valid: slot holds a real sheep
            ]

        return np.concatenate([head, tokens.ravel()])
|
||||
|
||||
|
||||
# ── Attention features extractor ──────────────────────────────────────────────
|
||||
|
||||
class ShepherdAttentionExtractor(BaseFeaturesExtractor):
    """
    Multi-head self-attention over per-sheep tokens, mean-pooled over valid
    (non-padding) tokens and concatenated with the global dog/pen features.

    A token is treated as valid when its (normalised) is_valid entry is > 0
    and as padding when it is <= 0.

    Rows in which every token is flagged as padding are attended without a
    mask: a softmax over an entirely masked key set can yield NaN, and NaN
    survives the later multiplication by the zero pooling weights. Such rows
    still pool to a zero vector because pooling uses the validity weights.
    """

    GLOBAL_DIM = HerdingEnvAt.OBS_GLOBAL   # 7
    SHEEP_DIM = HerdingEnvAt.OBS_SHEEP     # 6
    MAX_SHEEP = HerdingEnv.MAX_SHEEP       # 10
    VALID_IDX = 5                          # index of is_valid within each token

    def __init__(self, observation_space, embed_dim: int = 64,
                 n_heads: int = 4, n_layers: int = 2, ff_dim: int = 128):
        """
        Args:
            observation_space: flat observation space of the environment.
            embed_dim: token embedding width (transformer d_model).
            n_heads: number of attention heads.
            n_layers: number of stacked encoder layers.
            ff_dim: hidden width of each encoder layer's feed-forward block.
        """
        # Output = raw global features + one pooled token embedding.
        super().__init__(observation_space,
                         features_dim=self.GLOBAL_DIM + embed_dim)
        self.sheep_embed = nn.Linear(self.SHEEP_DIM, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=n_heads, dim_feedforward=ff_dim,
            dropout=0.0, batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer,
                                                 num_layers=n_layers)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Map a (B, GLOBAL_DIM + MAX_SHEEP*SHEEP_DIM) batch to (B, GLOBAL_DIM + E)."""
        B = obs.shape[0]
        global_feats = obs[:, :self.GLOBAL_DIM]                 # (B, GLOBAL_DIM)
        tokens = obs[:, self.GLOBAL_DIM:].view(
            B, self.MAX_SHEEP, self.SHEEP_DIM)                  # (B, MAX_SHEEP, SHEEP_DIM)

        is_valid_norm = tokens[:, :, self.VALID_IDX]            # (B, MAX_SHEEP)
        key_padding_mask = is_valid_norm <= 0.0                 # True → ignore

        # Guard: if a row has no valid token at all, drop the mask for that
        # row so attention stays finite. The row is zeroed by the pooling
        # weights below, so its attended values never reach the output.
        all_padding = key_padding_mask.all(dim=1, keepdim=True)  # (B, 1)
        attn_mask = key_padding_mask & ~all_padding

        x = self.sheep_embed(tokens)                            # (B, MAX_SHEEP, E)
        x = self.transformer(x, src_key_padding_mask=attn_mask)

        # Mean over valid tokens only; clamp avoids 0/0 when none are valid.
        valid_w = (is_valid_norm > 0.0).float().unsqueeze(-1)   # (B, MAX_SHEEP, 1)
        pooled = (x * valid_w).sum(1) / valid_w.sum(1).clamp(min=1.0)

        return torch.cat([global_feats, pooled], dim=1)         # (B, GLOBAL_DIM + E)
|
||||
|
||||
|
||||
# ── Environment factory ───────────────────────────────────────────────────────
|
||||
|
||||
def make_env_at(n_sheep, seed, max_steps, reward_cfg=None):
    """
    Return a zero-argument factory for a seeded HerdingEnvAt.

    The environment is created only when the returned callable is invoked;
    it is reset once with ``seed`` before being handed back.
    """
    def _init():
        created = HerdingEnvAt(
            n_sheep=n_sheep,
            max_steps=max_steps,
            reward_cfg=reward_cfg,
        )
        created.reset(seed=seed)
        return created

    return _init
|
||||
|
||||
|
||||
# ── Evaluation ────────────────────────────────────────────────────────────────
|
||||
|
||||
def evaluate_at(model, vn_template, n_sheep, n_episodes, max_steps,
                reward_cfg=None):
    """
    Run deterministic evaluation episodes and aggregate summary metrics.

    A fresh single-env VecNormalize is built in non-training mode and given
    deep copies of ``vn_template``'s observation/return statistics, so
    evaluation neither updates nor shares the training normaliser.

    Args:
        model: policy exposing ``predict(obs, deterministic=True)``.
        vn_template: VecNormalize whose ``obs_rms`` / ``ret_rms`` are copied.
        n_sheep: flock size for the evaluation env; also the success target.
        n_episodes: number of episodes to run.
        max_steps: forwarded to the environment factory.
        reward_cfg: optional reward configuration forwarded to the factory.

    Returns:
        dict with ``sr`` (success fraction), ``mean_len``, ``mean_min_pen``,
        ``mean_act``, ``failure_modes`` (label → count as returned by
        ``_classify``) and, when the env reports ``rcomps`` in its info,
        ``reward_per_step`` (per-component mean over reporting steps).
    """
    raw = DummyVecEnv([make_env_at(n_sheep, 9999, max_steps, reward_cfg)])
    vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
    vn.obs_rms = deepcopy(vn_template.obs_rms)
    vn.ret_rms = deepcopy(vn_template.ret_rms)

    successes = 0
    ep_lens, min_pen_list, action_mags = [], [], []
    failure_counts, rc_sums = {}, {}
    rc_n = 0

    try:
        inner = vn.envs[0]
        for _ in range(n_episodes):
            obs = vn.reset()
            done = False
            steps, min_pen = 0, float("inf")
            mags, ep_radii, ep_com_dists = [], [], []
            while not done:
                action, _ = model.predict(obs, deterministic=True)
                obs, _, dones, infos = vn.step(action)
                done = dones[0]

                # The vectorised env resets the underlying env inside step()
                # once an episode ends, so on the terminal step `inner`
                # already holds the NEXT episode's initial state. Skip the
                # flock sample there to keep it out of this episode's stats.
                # If the episode ended on its very first step, keep the
                # sample anyway so the series handed to _classify is never
                # empty (matches the previous behaviour for that case).
                if not done or not ep_radii:
                    com, radius, _ = inner._flock_stats()
                    com_dist = float(np.linalg.norm(com - inner.PEN_CENTER))
                    min_pen = min(min_pen, com_dist)
                    ep_radii.append(radius)
                    ep_com_dists.append(com_dist)

                mags.append(float(np.linalg.norm(action[0])))
                steps += 1

                rc = infos[0].get("rcomps")
                if rc:
                    for k, v in rc.items():
                        rc_sums[k] = rc_sums.get(k, 0.0) + v
                    rc_n += 1

            # `infos` is from the terminal step and describes the finished episode.
            n_penned = infos[0].get("n_penned", 0)
            successes += int(n_penned == n_sheep)
            ep_lens.append(steps)
            min_pen_list.append(min_pen)
            action_mags.extend(mags)
            mode = _classify(ep_radii, ep_com_dists, n_penned, n_sheep)
            failure_counts[mode] = failure_counts.get(mode, 0) + 1
    finally:
        # Release the evaluation env even if an episode raised.
        vn.close()

    result = {
        "sr": successes / n_episodes,
        "mean_len": float(np.mean(ep_lens)),
        "mean_min_pen": float(np.mean(min_pen_list)),
        "mean_act": float(np.mean(action_mags)) if action_mags else 0.0,
        "failure_modes": failure_counts,
    }
    if rc_n > 0:
        result["reward_per_step"] = {k: v / rc_n for k, v in rc_sums.items()}
    return result
|
||||
|
||||
|
||||
# ── CLI ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
def parse_args():
    """Parse the command line for the attention-policy training script."""
    parser = argparse.ArgumentParser(
        description="PPO + attention training for herding task")

    parser.add_argument("--config", type=str, default=None)

    # Curriculum / rollout sizing.
    for flag, default in (
        ("--max-sheep", 10),
        ("--steps-per-stage", 1_500_000),
        ("--n-envs", 8),
        ("--max-steps", 2500),
        ("--eval-episodes", 30),
    ):
        parser.add_argument(flag, type=int, default=default)

    # Output location and GIF rendering.
    parser.add_argument("--run-dir", type=str, default=None)
    parser.add_argument("--no-gif", action="store_true")
    for flag, default in (("--gif-fps", 20), ("--gif-skip", 3)):
        parser.add_argument(flag, type=int, default=default)

    # Attention architecture
    for flag, default, text in (
        ("--embed-dim", 64, "Transformer embedding dimension (default 64)"),
        ("--n-heads", 4, "Number of attention heads (default 4)"),
        ("--n-layers", 2,
         "Number of transformer encoder layers (default 2)"),
        ("--ff-dim", 128, "Transformer feed-forward dim (default 128)"),
    ):
        parser.add_argument(flag, type=int, default=default, help=text)

    return parser.parse_args()
|
||||
|
||||
|
||||
# ── Main ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main():
    """
    Train the attention policy through a 1 → max_sheep curriculum.

    Steps:
      1. Build the config from DEFAULT_CONFIG, overlaid with --config (or a
         local ``config.json`` if present), and write it into the run dir.
      2. Create the vectorised training envs and the PPO model that uses
         ShepherdAttentionExtractor.
      3. For each flock size n: train, evaluate, record one episode and save
         its plots (and GIF unless --no-gif).
      4. After the last stage, save the model, the VecNormalize statistics
         and the per-stage results, then print a summary.

    The training envs are closed even when a stage raises.
    """
    args = parse_args()

    cfg = dict(DEFAULT_CONFIG)
    config_path = args.config
    if config_path is None and os.path.exists("config.json"):
        config_path = "config.json"
    if config_path:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(config_path, encoding="utf-8") as f:
            cfg.update(json.load(f))
        print(f"Config loaded from {config_path}")

    # Only keys that name an attribute of HerdingEnv are passed to the env.
    rcfg = {k: v for k, v in cfg.items() if hasattr(HerdingEnv, k)}

    run_dir = args.run_dir or os.path.join(
        "runs", "at_" + time.strftime("%Y%m%d_%H%M%S"))
    eval_dir = os.path.join(run_dir, "eval")
    os.makedirs(eval_dir, exist_ok=True)
    with open(os.path.join(run_dir, "config.json"), "w",
              encoding="utf-8") as f:
        json.dump(cfg, f, indent=2)

    print(f"Config: {cfg}")
    print(f"Run dir: {run_dir}")
    print(f"Curriculum: 1 → {args.max_sheep} sheep, "
          f"{args.steps_per_stage:,} steps/stage")
    print(f"Transformer: embed={args.embed_dim} heads={args.n_heads} "
          f"layers={args.n_layers} ff={args.ff_dim}\n")

    train_env = SubprocVecEnv([
        make_env_at(1, seed=i, max_steps=args.max_steps, reward_cfg=rcfg)
        for i in range(args.n_envs)
    ])
    vn = VecNormalize(train_env, norm_obs=True, norm_reward=True, clip_obs=10.0)

    model = PPO(
        "MlpPolicy", vn,
        learning_rate=3e-4, n_steps=2048, batch_size=256, n_epochs=10,
        gamma=0.995, gae_lambda=0.95, clip_range=0.2,
        ent_coef=cfg.get("ent_coef", 0.02), vf_coef=0.5, max_grad_norm=0.5,
        policy_kwargs=dict(
            features_extractor_class=ShepherdAttentionExtractor,
            features_extractor_kwargs=dict(
                embed_dim=args.embed_dim,
                n_heads=args.n_heads,
                n_layers=args.n_layers,
                ff_dim=args.ff_dim,
            ),
            net_arch=[256, 256],
        ),
        device="cpu",
        verbose=0,
    )

    stage_results = []
    t0 = time.time()

    try:
        for n in range(1, args.max_sheep + 1):
            if n == 1:
                print(f"\n[Stage n_sheep=1] training {args.steps_per_stage:,} steps")
                model.learn(
                    total_timesteps=args.steps_per_stage,
                    reset_num_timesteps=True,
                    callback=ProgressCallback("1 sheep", freq=100_000),
                )
            else:
                # Mixed transition: for the first half of the stage budget the
                # first `half` envs stay at n-1 sheep and the rest move to n;
                # the second half of the budget runs every env at n.
                half = max(1, args.n_envs // 2)
                mix_steps = args.steps_per_stage // 2
                full_steps = args.steps_per_stage - mix_steps

                for i in range(half):
                    vn.env_method("set_n_sheep", n - 1, indices=[i])
                for i in range(half, args.n_envs):
                    vn.env_method("set_n_sheep", n, indices=[i])

                print(f"\n[Stage n_sheep={n}] mixed ({n-1}/{n} sheep) "
                      f"{mix_steps:,} steps")
                model.learn(
                    total_timesteps=mix_steps,
                    reset_num_timesteps=False,
                    callback=ProgressCallback(f"{n-1}→{n} mix", freq=100_000),
                )

                vn.env_method("set_n_sheep", n)
                print(f"[Stage n_sheep={n}] full ({n} sheep) {full_steps:,} steps")
                model.learn(
                    total_timesteps=full_steps,
                    reset_num_timesteps=False,
                    callback=ProgressCallback(f"{n} sheep", freq=100_000),
                )

            print(f"[Stage n_sheep={n}] evaluating {args.eval_episodes} eps")
            r = evaluate_at(model, vn, n, args.eval_episodes,
                            args.max_steps, rcfg)
            print(f"[Stage n_sheep={n}] sr={r['sr']*100:.0f}% "
                  f"mean_len={r['mean_len']:.0f} "
                  f"mean_min_pen={r['mean_min_pen']:.1f}m "
                  f"mean_act={r['mean_act']:.2f}")
            if r["failure_modes"]:
                # Most frequent mode first.
                modes = " ".join(
                    f"{k}={v}" for k, v in sorted(
                        r["failure_modes"].items(), key=lambda x: -x[1]))
                print(f" failure modes: {modes}")
            if "reward_per_step" in r:
                rps = r["reward_per_step"]
                print(" reward/step: " + " ".join(
                    f"{k}={v:+.4f}" for k, v in rps.items()))

            hist = run_and_record(
                model, vn, n, args.max_steps, rcfg,
                seed=1000 + n, make_env_fn=make_env_at,
            )
            tag = "success" if hist["success"] else "fail"
            plot_trajectory(hist, os.path.join(eval_dir, f"traj_{n}s_{tag}.png"))
            plot_timeseries(hist, os.path.join(eval_dir, f"ts_{n}s_{tag}.png"))
            if not args.no_gif:
                save_episode_gif(
                    hist,
                    os.path.join(eval_dir, f"ep_{n}s_{tag}.gif"),
                    fps=args.gif_fps, skip=args.gif_skip)

            r["n_sheep"] = n
            stage_results.append(r)

        model.save(os.path.join(run_dir, "final_model"))
        vn.save(os.path.join(run_dir, "vecnorm.pkl"))
        with open(os.path.join(run_dir, "stage_results.json"), "w",
                  encoding="utf-8") as f:
            json.dump(stage_results, f, indent=2)

    finally:
        # Shutdown stays best-effort (it must not mask a training error),
        # but a failure is reported instead of being dropped silently.
        try:
            vn.close()
        except Exception as exc:
            print(f"Warning: failed to close training envs: {exc!r}")

    elapsed = (time.time() - t0) / 60
    print("\n" + "=" * 70)
    print(" TRAINING SUMMARY (attention policy)")
    print("=" * 70)
    for r in stage_results:
        print(f" n_sheep={r['n_sheep']} sr={r['sr']*100:>3.0f}% "
              f"len={r['mean_len']:>5.0f} "
              f"min_pen={r['mean_min_pen']:>5.1f}m "
              f"act={r['mean_act']:.2f}")
    print(f"\n Total time: {elapsed:.1f} min")
    print(f" Artefacts: {run_dir}/")
    plot_success_rate(stage_results, os.path.join(run_dir, "success_rate.png"))
    print(f" Plots: {run_dir}/success_rate.png, {eval_dir}/")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+3
-2
@@ -78,9 +78,10 @@ def make_eval_env(n_sheep, seed, max_steps, reward_cfg=None):
|
||||
|
||||
|
||||
def run_and_record(model, vn_template, n_sheep, max_steps,
|
||||
reward_cfg=None, seed=42):
|
||||
reward_cfg=None, seed=42, make_env_fn=None):
|
||||
"""Run one deterministic episode and return full trajectory history."""
|
||||
raw = DummyVecEnv([make_eval_env(n_sheep, seed, max_steps, reward_cfg)])
|
||||
_factory = make_env_fn or make_eval_env
|
||||
raw = DummyVecEnv([_factory(n_sheep, seed, max_steps, reward_cfg)])
|
||||
vn = VecNormalize(raw, norm_obs=True, norm_reward=False, training=False)
|
||||
vn.obs_rms = deepcopy(vn_template.obs_rms)
|
||||
vn.ret_rms = deepcopy(vn_template.ret_rms)
|
||||
|
||||
Reference in New Issue
Block a user