Checkpoint 7
This commit is contained in:
@@ -0,0 +1,144 @@
|
||||
"""Collect (obs, action) demonstrations from an analytic teacher.
|
||||
|
||||
Runs the chosen teacher across a grid of ``(n_sheep, seed)`` combos at
|
||||
full difficulty, logs every Nth ``(obs, action)`` pair, and saves
|
||||
successful trajectories to ``.npz`` for behaviour cloning. The teacher
|
||||
is wrapped in :class:`ActiveScanTeacher` by default so it operates on
|
||||
the same partial-obs view the student will have at deployment.
|
||||
|
||||
Usage::
|
||||
|
||||
python -m training.bc.collect --teacher strombom \\
|
||||
--out training/bc/demos.npz --frame-stack 4
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from herding.control.active_scan import ActiveScanTeacher
|
||||
from herding.world.geometry import PEN_ENTRY
|
||||
from herding.control.sequential import compute_action as sequential_action
|
||||
from herding.control.strombom import compute_action as strombom_action
|
||||
from training.herding_env import HerdingEnv
|
||||
|
||||
|
||||
# Registry of analytic teachers, keyed by the name accepted on the command
# line via ``--teacher``.
TEACHERS = dict(
    sequential=sequential_action,
    strombom=strombom_action,
)
|
||||
|
||||
|
||||
def collect_one(n_sheep: int, seed: int, max_steps: int, subsample: int,
                teacher_fn, frame_stack: int = 1, privileged: bool = False):
    """Roll out one teacher episode and log subsampled ``(obs, action)`` pairs.

    Args:
        n_sheep: Number of sheep passed to ``HerdingEnv``.
        seed: Seed used both to construct and to reset the env.
        max_steps: Upper bound on control steps; also forwarded to the env.
        subsample: Keep every ``subsample``-th ``(obs, action)`` pair,
            starting with step 0. Must be at least 1.
        teacher_fn: Base analytic teacher. Called as
            ``teacher_fn(dog_xy, positions, PEN_ENTRY)`` and expected to
            return a ``(vx, vy, mode)`` triple.
        frame_stack: Forwarded to ``HerdingEnv``.
        privileged: When true the base teacher is fed the ground-truth
            positions of all un-penned sheep; otherwise the
            ``ActiveScanTeacher`` wrapper is fed ``env.perceived_positions()``.

    Returns:
        ``(obs, actions, success, steps)`` where ``obs`` and ``actions`` are
        float32 arrays of the logged pairs, ``success`` is true when every
        sheep is penned at the end of the rollout, and ``steps`` is
        ``env.steps``.

    Raises:
        ValueError: If ``subsample`` is smaller than 1.
    """
    # Fail fast with a clear message instead of a ZeroDivisionError from the
    # modulo below after the env has already been built.
    if subsample < 1:
        raise ValueError(f"subsample must be >= 1, got {subsample}")

    env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps,
                     difficulty=1.0, seed=seed, frame_stack=frame_stack)
    obs, _ = env.reset(seed=seed)
    obs_list, action_list = [], []
    # Wrap the base teacher so it opens with a rotation and walks to
    # centre when the tracker briefly empties — matches the student.
    scan_teacher = ActiveScanTeacher(teacher_fn)
    for step in range(max_steps):
        if privileged:
            # Asymmetric variant: teacher reads ground truth while the
            # student keeps the LiDAR obs. Default off.
            positions = {
                f"s{i}": (float(env.sheep_x[i]), float(env.sheep_y[i]))
                for i in range(env.n_sheep)
                if not env.sheep_penned[i]
            }
            if not positions:
                # Nothing left to herd, so there is no action to demonstrate.
                break
            vx, vy, _mode = teacher_fn(
                (env.dog_x, env.dog_y), positions, PEN_ENTRY,
            )
        else:
            positions = env.perceived_positions()
            vx, vy, _mode = scan_teacher(
                (env.dog_x, env.dog_y), env.dog_heading,
                positions, PEN_ENTRY,
            )
        action = np.array([vx, vy], dtype=np.float32)
        if step % subsample == 0:
            # Log the observation the action was computed for, i.e. the one
            # from *before* the env is stepped.
            obs_list.append(obs.copy())
            action_list.append(action.copy())
        obs, _r, term, trunc, _info = env.step(action)
        if term or trunc:
            break
    success = bool(env.sheep_penned.all())
    return (
        np.asarray(obs_list, dtype=np.float32),
        np.asarray(action_list, dtype=np.float32),
        success,
        env.steps,
    )
|
||||
|
||||
|
||||
def main():
    """Collect teacher demonstrations over a ``(n_sheep, seed)`` grid.

    Parses the command line, runs ``collect_one`` for every combination,
    keeps successful rollouts (or all non-empty ones with
    ``--keep-failures``) and writes ``obs``, ``actions`` and ``meta`` arrays
    to the ``--out`` file with ``np.savez``. Each ``meta`` row is
    ``(n_sheep, seed, n_logged, success, total_steps)``.

    Raises:
        RuntimeError: If no trajectory was kept.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--out", default="training/bc/demos.npz")
    parser.add_argument("--n-sheep-list", default="1,2,3,5,8,10")
    parser.add_argument("--seeds-per-n", type=int, default=15)
    parser.add_argument("--max-steps", type=int, default=30000)
    parser.add_argument("--subsample", type=int, default=5,
                        help="Keep every Nth (obs, action) pair.")
    parser.add_argument("--keep-failures", action="store_true",
                        help="Include partial-success trajectories. Default off.")
    parser.add_argument("--teacher", default="sequential",
                        choices=list(TEACHERS.keys()),
                        help="Which analytic teacher to demonstrate.")
    parser.add_argument("--frame-stack", type=int, default=1,
                        help="Concatenate the last K obs into a "
                             "(32·K)-D vector for the policy.")
    parser.add_argument("--privileged", action="store_true",
                        help="Teacher reads ground truth instead of "
                             "tracker output (asymmetric BC).")
    args = parser.parse_args()

    # Reject values that would otherwise only fail later: a zero subsample
    # hits a modulo-by-zero inside the rollout, and an empty grid ends in
    # the misleading "try --keep-failures" error below.
    if args.subsample < 1:
        parser.error("--subsample must be >= 1")
    if args.seeds_per_n < 1:
        parser.error("--seeds-per-n must be >= 1")
    # Skip blank tokens so a trailing comma does not crash int("").
    n_sheep_list = [int(tok) for tok in args.n_sheep_list.split(",")
                    if tok.strip()]
    if not n_sheep_list:
        parser.error("--n-sheep-list must contain at least one integer")

    teacher_fn = TEACHERS[args.teacher]
    print(f"[demos] teacher: {args.teacher}")
    print(f"[demos] grid: n_sheep={n_sheep_list}, seeds={args.seeds_per_n}, "
          f"max_steps={args.max_steps}, subsample={args.subsample}")

    all_obs, all_actions, all_meta = [], [], []
    t_start = time.time()
    n_success = 0
    n_total = 0

    for n in n_sheep_list:
        for seed in range(args.seeds_per_n):
            obs, actions, success, total_steps = collect_one(
                n, seed, args.max_steps, args.subsample, teacher_fn,
                frame_stack=args.frame_stack, privileged=args.privileged,
            )
            n_total += 1
            if success:
                n_success += 1
            keep = success or args.keep_failures
            # Empty rollouts carry no samples and would break concatenation
            # (they have shape (0,) rather than (0, obs_dim)).
            if keep and len(obs) > 0:
                all_obs.append(obs)
                all_actions.append(actions)
                all_meta.append((n, seed, len(obs), int(success), total_steps))
            tag = "✓" if success else "✗"
            print(f" [{tag}] n={n:>2d} seed={seed:>2d} steps={total_steps:>6d} "
                  f"logged={len(obs):>5d}")

    if not all_obs:
        raise RuntimeError("No trajectories kept — try --keep-failures.")

    obs = np.concatenate(all_obs, axis=0)
    actions = np.concatenate(all_actions, axis=0)
    meta = np.array(all_meta, dtype=np.int32)

    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    np.savez(args.out, obs=obs, actions=actions, meta=meta)

    elapsed = time.time() - t_start
    print(f"\n=== {n_success}/{n_total} trajectories successful ({100*n_success/n_total:.0f}%) ===")
    print(f"=== {len(obs)} transitions saved to {args.out} ===")
    print(f"=== obs={obs.shape}, actions={actions.shape}, elapsed={elapsed:.0f}s ===")


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,219 @@
|
||||
"""Behaviour cloning of an analytic teacher into an SB3 MlpPolicy.
|
||||
|
||||
Trains the mean-action head against ``(obs, action)`` demos from
|
||||
``training.bc.collect`` using ``MSE + (1 − cos_sim)`` — the cosine
|
||||
term prevents collapse toward zero against unit-vector targets. The
|
||||
best-by-val_cos snapshot is restored at the end of training because
|
||||
multi-modal teachers make the last epoch unreliable.
|
||||
|
||||
Output zip is loadable by ``PPO.load(...)`` and consumed by
|
||||
``HERDING_MODE=bc`` in the dog controller.
|
||||
|
||||
Usage::
|
||||
|
||||
python -m training.bc.pretrain \\
|
||||
--demos training/bc/demos.npz \\
|
||||
--out training/runs/bc
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
|
||||
from training.herding_env import HerdingEnv
|
||||
|
||||
|
||||
def build_model(net_arch_pi, net_arch_vf, log_std_init: float,
                frame_stack: int = 1):
    """Create an untrained SB3 PPO whose policy network BC will fit.

    Only the policy weights matter here; PPO's own training loop is never
    run during behaviour cloning. ``frame_stack`` has to agree with the
    demo file, otherwise the env's observation space will not match the
    recorded obs shape.

    Returns:
        ``(model, env)``: the PPO instance and the ``DummyVecEnv`` it was
        built on.
    """
    def _make_env():
        # Factory handed to DummyVecEnv, which builds the env itself.
        return HerdingEnv(frame_stack=frame_stack)

    vec_env = DummyVecEnv([_make_env])
    policy_kwargs = {
        "net_arch": {"pi": net_arch_pi, "vf": net_arch_vf},
        "log_std_init": log_std_init,
    }
    ppo = PPO("MlpPolicy", vec_env, policy_kwargs=policy_kwargs, verbose=0)
    return ppo, vec_env
|
||||
|
||||
|
||||
def policy_forward_mean(policy, obs_batch):
    """Compute the policy's mean action for a batch of observations.

    Bypasses the Distribution wrapper that SB3's ActorCriticPolicy uses in
    ``forward`` and calls the underlying chain directly:
    ``extract_features → mlp_extractor → action_net``.
    """
    extracted = policy.extract_features(obs_batch)
    if isinstance(extracted, tuple):
        # A tuple result is taken to be ordered actor-first; only the
        # first entry feeds the action head.
        extracted = extracted[0]
    actor_latent, _unused_latent = policy.mlp_extractor(extracted)
    mean_action = policy.action_net(actor_latent)
    return mean_action
|
||||
|
||||
|
||||
def main():
    """Fit the policy's mean-action head to recorded demos and save it.

    Loads ``obs`` / ``actions`` / ``meta`` from the ``--demos`` file, splits
    off a random validation subset, trains with ``MSE + cos_weight *
    (1 - cos_sim)`` and writes ``policy.zip`` into the ``--out`` directory.
    The snapshot with the highest validation cosine similarity is restored
    before saving. When the validation subset is empty no snapshot is taken
    and the final-epoch weights are saved instead.

    Raises:
        RuntimeError: If the demo file is empty or its obs dimension is not
            a multiple of the single-frame ``OBS_DIM``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--demos", default="training/bc/demos.npz")
    parser.add_argument("--out", default="training/runs/bc")
    parser.add_argument("--epochs", type=int, default=60)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--val-split", type=float, default=0.1)
    parser.add_argument("--net-arch", default="256,256",
                        help="Comma-separated hidden layer widths.")
    parser.add_argument("--log-std-init", type=float, default=0.5)
    parser.add_argument("--cos-weight", type=float, default=1.0,
                        help="Weight of the (1 - cosine_similarity) loss "
                             "term; balances against MSE.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device", default="cpu")
    args = parser.parse_args()

    # A split of 1.0 or more leaves no training data, which would only
    # surface later as an opaque error from the shuffling DataLoader.
    if not 0.0 <= args.val_split < 1.0:
        parser.error("--val-split must be in [0, 1)")

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # --- Load demos ---
    print(f"[bc] loading demos from {args.demos}")
    # Materialise the arrays inside the context manager so the archive's
    # file handle is closed as soon as we are done reading.
    with np.load(args.demos) as data:
        obs = data["obs"].astype(np.float32)
        actions = data["actions"].astype(np.float32)
        meta = data["meta"]
    print(f"[bc] obs={obs.shape} actions={actions.shape} trajectories={len(meta)}")
    if obs.size == 0:
        raise RuntimeError("Empty demo file.")

    a_norms = np.linalg.norm(actions, axis=1)
    print(f"[bc] action L2 norm: mean={a_norms.mean():.3f} "
          f"min={a_norms.min():.3f} max={a_norms.max():.3f}")

    # --- Train/val split ---
    n = len(obs)
    perm = np.random.permutation(n)
    n_val = int(n * args.val_split)
    val_idx, train_idx = perm[:n_val], perm[n_val:]
    print(f"[bc] train={len(train_idx)} val={len(val_idx)}")
    has_val = n_val > 0
    if not has_val:
        print("[bc] no validation samples — keeping final-epoch weights")

    obs_t = torch.from_numpy(obs)
    act_t = torch.from_numpy(actions)
    train_loader = DataLoader(
        TensorDataset(obs_t[train_idx], act_t[train_idx]),
        batch_size=args.batch_size, shuffle=True,
    )
    val_loader = DataLoader(
        TensorDataset(obs_t[val_idx], act_t[val_idx]),
        batch_size=args.batch_size, shuffle=False,
    )

    net_arch_pi = [int(x) for x in args.net_arch.split(",")]
    net_arch_vf = net_arch_pi[:]
    # Frame stack is inferred from the demo obs dim.
    obs_dim = obs.shape[1]
    from herding.perception.obs import OBS_DIM as _SINGLE
    if obs_dim % _SINGLE != 0:
        raise RuntimeError(f"demo obs dim {obs_dim} is not a multiple of {_SINGLE}")
    frame_stack = obs_dim // _SINGLE
    if frame_stack > 1:
        print(f"[bc] inferred frame_stack={frame_stack} from demo obs dim {obs_dim}")
    model, _env = build_model(net_arch_pi, net_arch_vf, args.log_std_init,
                              frame_stack=frame_stack)
    policy = model.policy.to(args.device)
    optimizer = optim.Adam(policy.parameters(), lr=args.lr)

    # --- Train ---
    print(f"[bc] training: epochs={args.epochs} batch={args.batch_size} "
          f"lr={args.lr} device={args.device}")
    t_start = time.time()
    best_val = float("inf")
    best_cos = -1.0
    best_state = None  # restored at the end so noisy last epochs don't win

    def combined_loss(pred, target):
        """Return ``(loss, mse, mean_cos_sim)`` for one batch."""
        mse = nn.functional.mse_loss(pred, target)
        # Clamp the norms so a zero vector cannot divide by zero.
        p_norm = pred.norm(dim=1).clamp_min(1e-6)
        t_norm = target.norm(dim=1).clamp_min(1e-6)
        cos_sim = (pred * target).sum(dim=1) / (p_norm * t_norm)
        cos_loss = (1.0 - cos_sim).mean()
        return mse + args.cos_weight * cos_loss, mse.item(), cos_sim.mean().item()

    for epoch in range(args.epochs):
        policy.train()
        train_mse_total, train_cos_total, train_count = 0.0, 0.0, 0
        for ob_batch, act_batch in train_loader:
            ob_batch = ob_batch.to(args.device)
            act_batch = act_batch.to(args.device)
            optimizer.zero_grad()
            mean_action = policy_forward_mean(policy, ob_batch)
            loss, mse_val, cos_val = combined_loss(mean_action, act_batch)
            loss.backward()
            optimizer.step()
            # Weight the per-batch means by batch size so the short final
            # batch does not skew the epoch average.
            bs = ob_batch.size(0)
            train_mse_total += mse_val * bs
            train_cos_total += cos_val * bs
            train_count += bs
        train_mse = train_mse_total / max(1, train_count)
        train_cos = train_cos_total / max(1, train_count)

        policy.eval()
        val_total, val_count = 0.0, 0
        cos_sim_total = 0.0
        with torch.no_grad():
            for ob_batch, act_batch in val_loader:
                ob_batch = ob_batch.to(args.device)
                act_batch = act_batch.to(args.device)
                mean_action = policy_forward_mean(policy, ob_batch)
                bs = ob_batch.size(0)
                val_total += nn.functional.mse_loss(
                    mean_action, act_batch, reduction="sum",
                ).item()
                m_norm = mean_action.norm(dim=1).clamp_min(1e-6)
                a_norm = act_batch.norm(dim=1).clamp_min(1e-6)
                cos = (mean_action * act_batch).sum(dim=1) / (m_norm * a_norm)
                cos_sim_total += cos.sum().item()
                val_count += bs
        if has_val:
            # Summed squared error -> mean per sample -> mean per action dim.
            val_mse = val_total / max(1, val_count) / actions.shape[1]
            cos_sim = cos_sim_total / max(1, val_count)
        else:
            # Without validation data the metrics are undefined. Reporting
            # 0.0 would beat the initial best_cos of -1.0 and freeze the
            # epoch-1 weights as the "best" snapshot.
            val_mse = float("nan")
            cos_sim = float("nan")
        print(f" epoch {epoch+1:>2d}/{args.epochs} "
              f"train_mse={train_mse:.4f} train_cos={train_cos:+.3f} "
              f"val_mse={val_mse:.4f} val_cos={cos_sim:+.3f}")
        if has_val and val_mse < best_val:
            best_val = val_mse
        if has_val and cos_sim > best_cos:
            best_cos = cos_sim
            best_state = {k: v.detach().cpu().clone()
                          for k, v in policy.state_dict().items()}

    if best_state is not None:
        policy.load_state_dict(best_state)
        print(f"[bc] restored best-val_cos snapshot (cos={best_cos:.3f})")

    elapsed = time.time() - t_start
    print(f"[bc] done in {elapsed:.0f}s best_val_mse={best_val:.4f}")

    # --- Save ---
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)
    model.save(out_dir / "policy.zip")
    print(f"[bc] saved policy to {out_dir / 'policy.zip'}")
    print(f"\n[bc] verify with: "
          f"python -m training.eval --policy {out_dir}")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user