Files
TIR_PROJ/training/bc/pretrain.py
Johnny Fernandes 7ab69ab0f3 Rename multi-segment functions to two-concept names; polish docstrings
Naming pass: rename functions whose third+ segment is redundant or
implementation-detail, sticking to the codebase's preferred
``noun_verb`` / ``verb_noun`` two-concept idiom. Renames are atomic
across definitions, callers, and tests.

  is_penned_position        →  is_penned
  modulate_speed_near_sheep →  modulate_speed
  mecanum_kinematics_step   →  mecanum_step
  policy_forward_mean       →  forward_mean

Two-concept patterns like ``velocity_to_wheels`` / ``detections_from_scan``
/ ``make_strombom_predictor`` are left alone — they're idiomatic
converters / factories that read as a single concept, and the longer
form aids grep-ability.

Docstring polish:
* ``herding/config.py`` header drops the "previously lived as a
  module-level literal" historical framing — we ship as a single
  thing, so the refactor anecdote no longer earns its keep. The
  usage examples now mention both ``HERDING_WEBOTS`` and
  ``HERDING_MEC_WEBOTS`` presets.

126 pytest cases still pass.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-17 01:58:15 +00:00

239 lines
9.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Behaviour cloning of an analytic teacher into an SB3 MlpPolicy.
Trains the mean-action head against ``(obs, action)`` demos from
``training.bc.collect`` using ``MSE + (1 - cos_sim)`` — the cosine
term prevents collapse toward zero against unit-vector targets. The
best-by-val_cos snapshot is restored at the end of training because
multi-modal teachers make the last epoch unreliable.
Output zip is loadable by ``PPO.load(...)`` and consumed by
``HERDING_MODE=bc`` in the dog controller.
Usage::

    python -m training.bc.pretrain \\
        --demos training/bc/demos_differential_field.npz \\
        --out training/runs/bc_differential_field
"""
from __future__ import annotations
import argparse
import time
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from training.herding_env import HerdingEnv
def build_model(net_arch_pi, net_arch_vf, log_std_init: float,
                frame_stack: int = 1, drive_mode: str = "differential"):
    """Create an untrained SB3 PPO whose only job is to hold policy weights.

    None of PPO's own training loop is exercised during behaviour cloning;
    the model exists so the result can be saved in PPO's format.

    Args:
        net_arch_pi: Hidden layer widths for the actor (``pi``) network.
        net_arch_vf: Hidden layer widths for the critic (``vf``) network.
        log_std_init: Passed through to the policy as ``log_std_init``.
        frame_stack: Forwarded to ``HerdingEnv``; has to agree with the demo
            file so the env's observation space matches the recorded obs.
        drive_mode: Forwarded to ``HerdingEnv``.

    Returns:
        A ``(model, env)`` tuple: the PPO instance and the single-env
        ``DummyVecEnv`` it was constructed on.
    """
    def _make_env():
        # Factory handed to DummyVecEnv; it closes over the caller's settings.
        return HerdingEnv(frame_stack=frame_stack, drive_mode=drive_mode)

    vec_env = DummyVecEnv([_make_env])
    policy_kwargs = {
        "net_arch": {"pi": net_arch_pi, "vf": net_arch_vf},
        "log_std_init": log_std_init,
    }
    ppo = PPO("MlpPolicy", vec_env, policy_kwargs=policy_kwargs, verbose=0)
    return ppo, vec_env
def forward_mean(policy, obs_batch):
    """Compute the deterministic mean action for a batch of observations.

    SB3's ActorCriticPolicy wraps ``forward`` in a Distribution, so the
    underlying chain ``extract_features → mlp_extractor → action_net`` is
    walked by hand here instead.

    Args:
        policy: Object exposing ``extract_features``, ``mlp_extractor`` and
            ``action_net``.
        obs_batch: Batch of observations passed to ``extract_features``.

    Returns:
        Whatever ``policy.action_net`` returns for the actor latent.
    """
    extracted = policy.extract_features(obs_batch)
    # When the extractor hands back a tuple, only its first entry feeds the
    # actor branch.
    if isinstance(extracted, tuple):
        extracted = extracted[0]
    actor_latent, _critic_latent = policy.mlp_extractor(extracted)
    return policy.action_net(actor_latent)
# Action dimension → drive mode, as advertised by the ``--drive-mode`` help.
_DRIVE_MODE_BY_ACTION_DIM = {2: "differential", 3: "mecanum"}


def _row_cosine(pred, target):
    """Return the per-row cosine similarity of two tensors, reducing over dim 1.

    Norms are clamped to ``1e-6`` so all-zero rows do not divide by zero.
    """
    pred_norm = pred.norm(dim=1).clamp_min(1e-6)
    target_norm = target.norm(dim=1).clamp_min(1e-6)
    return (pred * target).sum(dim=1) / (pred_norm * target_norm)


def _resolve_drive_mode(requested, action_dim):
    """Return the drive mode: the explicit CLI value, else inferred from action_dim.

    Raises:
        RuntimeError: If no mode was requested and ``action_dim`` is not one
            of the supported dimensions.
    """
    inferred = _DRIVE_MODE_BY_ACTION_DIM.get(action_dim)
    if requested is not None:
        if inferred != requested:
            # Honour the explicit override, but make the mismatch visible.
            print(f"[bc] WARNING: --drive-mode={requested} does not match "
                  f"demo action_dim={action_dim}")
        return requested
    if inferred is None:
        supported = sorted(_DRIVE_MODE_BY_ACTION_DIM)
        raise RuntimeError(
            f"cannot infer drive mode from action_dim={action_dim} "
            f"(supported: {supported}); pass --drive-mode explicitly")
    return inferred


def _evaluate(policy, loader, device, action_dim):
    """Return ``(mse, mean_cosine, n_samples)`` of the policy mean over ``loader``.

    MSE is averaged over samples and action components. With an empty loader
    both metrics are 0.0 and ``n_samples`` is 0.
    """
    sq_err_total, cos_total, count = 0.0, 0.0, 0
    with torch.no_grad():
        for ob_batch, act_batch in loader:
            ob_batch = ob_batch.to(device)
            act_batch = act_batch.to(device)
            mean_action = forward_mean(policy, ob_batch)
            sq_err_total += nn.functional.mse_loss(
                mean_action, act_batch, reduction="sum",
            ).item()
            cos_total += _row_cosine(mean_action, act_batch).sum().item()
            count += ob_batch.size(0)
    denom = max(1, count)
    return sq_err_total / denom / action_dim, cos_total / denom, count


def main() -> None:
    """CLI entry point: clone the demo actions into a PPO policy and save it.

    Loads ``obs`` / ``actions`` / ``meta`` from the ``--demos`` archive,
    splits off a validation set, trains the policy mean with
    ``MSE + cos_weight * (1 - cos_sim)``, restores the snapshot with the best
    validation cosine (when a validation set exists), and writes
    ``policy.zip`` into ``--out``.

    Raises:
        RuntimeError: On an empty demo file, mismatched obs/action counts, an
            obs dim that is not a multiple of the single-frame obs dim, or an
            action dim from which no drive mode can be inferred.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--demos", required=True,
                        help="Path to demos .npz collected by training.bc.collect.")
    parser.add_argument("--out", required=True,
                        help="Output directory (convention: "
                             "training/runs/bc_<drive>_<world>).")
    parser.add_argument("--epochs", type=int, default=60)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--val-split", type=float, default=0.1)
    parser.add_argument("--net-arch", default="256,256",
                        help="Comma-separated hidden layer widths.")
    parser.add_argument("--log-std-init", type=float, default=0.5)
    parser.add_argument("--cos-weight", type=float, default=1.0,
                        help="Weight of the (1 - cosine_similarity) loss "
                             "term; balances against MSE.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--drive-mode", default=None,
                        choices=["differential", "mecanum"],
                        help="Drive mode. If not set, inferred from "
                             "demo action dimension (2→differential, 3→mecanum).")
    args = parser.parse_args()
    # A split outside [0, 1) would leave the training set empty or slice the
    # permutation in surprising ways.
    if not 0.0 <= args.val_split < 1.0:
        parser.error("--val-split must be in the range [0, 1)")

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # --- Load demos ---
    print(f"[bc] loading demos from {args.demos}")
    # Context manager closes the underlying archive once arrays are read.
    with np.load(args.demos) as data:
        obs = data["obs"].astype(np.float32)
        actions = data["actions"].astype(np.float32)
        meta = data["meta"]
    print(f"[bc] obs={obs.shape} actions={actions.shape} trajectories={len(meta)}")
    if obs.size == 0:
        raise RuntimeError("Empty demo file.")
    if len(obs) != len(actions):
        raise RuntimeError(
            f"demo obs/actions length mismatch: {len(obs)} vs {len(actions)}")
    a_norms = np.linalg.norm(actions, axis=1)
    print(f"[bc] action L2 norm: mean={a_norms.mean():.3f} "
          f"min={a_norms.min():.3f} max={a_norms.max():.3f}")

    # --- Train/val split ---
    n = len(obs)
    perm = np.random.permutation(n)
    n_val = int(n * args.val_split)
    val_idx, train_idx = perm[:n_val], perm[n_val:]
    has_val = n_val > 0
    print(f"[bc] train={len(train_idx)} val={len(val_idx)}")
    if not has_val:
        print("[bc] WARNING: validation split is empty; "
              "keeping last-epoch weights")
    obs_t = torch.from_numpy(obs)
    act_t = torch.from_numpy(actions)
    train_loader = DataLoader(
        TensorDataset(obs_t[train_idx], act_t[train_idx]),
        batch_size=args.batch_size, shuffle=True,
    )
    val_loader = DataLoader(
        TensorDataset(obs_t[val_idx], act_t[val_idx]),
        batch_size=args.batch_size, shuffle=False,
    )

    net_arch_pi = [int(x) for x in args.net_arch.split(",")]
    net_arch_vf = net_arch_pi[:]

    # Frame stack is inferred from the demo obs dim.
    obs_dim = obs.shape[1]
    from herding.perception.obs import OBS_DIM as _SINGLE
    if obs_dim % _SINGLE != 0:
        raise RuntimeError(f"demo obs dim {obs_dim} is not a multiple of {_SINGLE}")
    frame_stack = obs_dim // _SINGLE
    if frame_stack > 1:
        print(f"[bc] inferred frame_stack={frame_stack} from demo obs dim {obs_dim}")

    # Infer drive mode from action dimension if not explicitly set.
    action_dim = actions.shape[1]
    drive_mode = _resolve_drive_mode(args.drive_mode, action_dim)
    print(f"[bc] drive_mode={drive_mode} (action_dim={action_dim})")

    model, _env = build_model(net_arch_pi, net_arch_vf, args.log_std_init,
                              frame_stack=frame_stack, drive_mode=drive_mode)
    policy = model.policy.to(args.device)
    optimizer = optim.Adam(policy.parameters(), lr=args.lr)

    # --- Train ---
    print(f"[bc] training: epochs={args.epochs} batch={args.batch_size} "
          f"lr={args.lr} device={args.device}")
    t_start = time.time()
    best_val = float("inf")
    best_cos = -1.0
    best_state = None  # restored at the end so noisy last epochs don't win

    def combined_loss(pred, target):
        # Returns (loss tensor for backprop, batch MSE, batch mean cosine).
        mse = nn.functional.mse_loss(pred, target)
        cos_sim = _row_cosine(pred, target)
        cos_loss = (1.0 - cos_sim).mean()
        return mse + args.cos_weight * cos_loss, mse.item(), cos_sim.mean().item()

    for epoch in range(args.epochs):
        policy.train()
        train_mse_total, train_cos_total, train_count = 0.0, 0.0, 0
        for ob_batch, act_batch in train_loader:
            ob_batch = ob_batch.to(args.device)
            act_batch = act_batch.to(args.device)
            optimizer.zero_grad()
            mean_action = forward_mean(policy, ob_batch)
            loss, mse_val, cos_val = combined_loss(mean_action, act_batch)
            loss.backward()
            optimizer.step()
            bs = ob_batch.size(0)
            # Weight batch means by batch size so a short last batch is fair.
            train_mse_total += mse_val * bs
            train_cos_total += cos_val * bs
            train_count += bs
        train_mse = train_mse_total / max(1, train_count)
        train_cos = train_cos_total / max(1, train_count)

        policy.eval()
        val_mse, cos_sim, _ = _evaluate(policy, val_loader, args.device, action_dim)
        print(f" epoch {epoch+1:>2d}/{args.epochs} "
              f"train_mse={train_mse:.4f} train_cos={train_cos:+.3f} "
              f"val_mse={val_mse:.4f} val_cos={cos_sim:+.3f}")

        # Without validation data the metrics are constant placeholders, so
        # tracking a "best" epoch would pin the snapshot to epoch 1 and throw
        # away all later training when it is restored.
        if not has_val:
            continue
        if val_mse < best_val:
            best_val = val_mse
        if cos_sim > best_cos:
            best_cos = cos_sim
            best_state = {k: v.detach().cpu().clone()
                          for k, v in policy.state_dict().items()}

    if best_state is not None:
        policy.load_state_dict(best_state)
        print(f"[bc] restored best-val_cos snapshot (cos={best_cos:.3f})")
    elapsed = time.time() - t_start
    print(f"[bc] done in {elapsed:.0f}s best_val_mse={best_val:.4f}")

    # --- Save ---
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)
    model.save(out_dir / "policy.zip")
    print(f"[bc] saved policy to {out_dir / 'policy.zip'}")
    print(f"\n[bc] verify with: "
          f"python -m training.eval --policy {out_dir}")
# Script entry point: run the BC pretraining CLI only when this module is
# executed directly, never when it is imported.
if __name__ == "__main__":
    main()