Files
TIR_PROJ/training/bc/pretrain.py
T
Johnny Fernandes 10c01a938e Drop versioning vocabulary, polish docstrings, fix world-aware policy resolution
User-facing cleanup pass, made once the project was settled as a single
submission with no intermediate iterations.

* Remove every "v1"/"v2"/"versioning" reference from the docs:
  - README mecanum section trims the "v1 predates the rewrite" prose
    in favour of a self-contained retrain recipe.
  - The 3.2 GB `training/runs/v1_clean/` backup directory is deleted.
* Refresh control-layer docstrings:
  - `sheep_tracker.py` header now describes the three actual pipeline
    stages (consensus, prediction, pen latching) instead of layering
    the consensus stage on top of a stale "predictive mode" preamble.
  - `controllers/shepherd_dog/shepherd_dog.py` mode list is
    up-to-date — adds `universal`, removes outdated single-policy
    default paths, mentions `HERDING_USE_GT=1` as the perception
    ablation.
* Refresh training command examples:
  - `training/bc/collect.py` and `training/bc/pretrain.py` usage
    snippets show the world-suffixed paths the Makefile actually
    uses; the `--out` arg is now required so old "demos.npz"
    invocations error loudly instead of silently overwriting.
  - `training/README.md` rewritten — drops the legacy `runs/bc`
    diagram, documents the per-(drive, world) pipeline, and adds
    the mecanum retraining caveat.
* Fix policy-directory resolution end-to-end:
  - `tools/run_webots.sh` now tries
    `training/runs/{bc,rl}_<drive>_<world>` first, then the drive-
    only path, then the bare-mode legacy path — matching the actual
    on-disk layout. Previously it looked for `bc_<drive>` (no
    world) and silently fell back to `bc`, masking the world
    selection.
  - `controllers/shepherd_dog/shepherd_dog.py:_resolve_policy_dir`
    has the same fix, plus a fix for a latent NameError it exposed: it
    referenced `DRIVE_MODE` before that variable was set at module load. The
    block is restructured so MODE/DRIVE_MODE/WORLD are resolved
    first, then the function uses them as explicit arguments.

126 pytest cases still pass.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-17 01:50:54 +00:00

239 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Behaviour cloning of an analytic teacher into an SB3 MlpPolicy.
Trains the mean-action head against ``(obs, action)`` demos from
``training.bc.collect`` using ``MSE + cos_weight * (1 - cos_sim)`` — the cosine
term prevents collapse toward zero against unit-vector targets. The
best-by-val_cos snapshot is restored at the end of training because
multi-modal teachers make the last epoch unreliable.
Output zip is loadable by ``PPO.load(...)`` and consumed by
``HERDING_MODE=bc`` in the dog controller.
Usage::
python -m training.bc.pretrain \\
--demos training/bc/demos_differential_field.npz \\
--out training/runs/bc_differential_field
"""
from __future__ import annotations
import argparse
import time
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from training.herding_env import HerdingEnv
def build_model(net_arch_pi, net_arch_vf, log_std_init: float,
                frame_stack: int = 1, drive_mode: str = "differential"):
    """Create an untrained SB3 PPO whose policy network will receive BC weights.

    Only ``model.policy`` is of interest here; PPO's own training loop is
    never run during behaviour cloning. ``frame_stack`` has to agree with
    the demo file so the environment's observation space matches the width
    of the recorded observations.

    Args:
        net_arch_pi: hidden-layer widths of the actor MLP.
        net_arch_vf: hidden-layer widths of the critic MLP.
        log_std_init: initial log standard deviation of the action
            distribution (forwarded to ``policy_kwargs``).
        frame_stack: number of stacked frames, forwarded to ``HerdingEnv``.
        drive_mode: drive mode, forwarded to ``HerdingEnv``.

    Returns:
        ``(model, env)``: the PPO instance and the single-env
        ``DummyVecEnv`` it was constructed on.
    """
    def make_env():
        return HerdingEnv(frame_stack=frame_stack, drive_mode=drive_mode)

    vec_env = DummyVecEnv([make_env])
    kwargs = {
        "net_arch": {"pi": net_arch_pi, "vf": net_arch_vf},
        "log_std_init": log_std_init,
    }
    model = PPO("MlpPolicy", vec_env, policy_kwargs=kwargs, verbose=0)
    return model, vec_env
def policy_forward_mean(policy, obs_batch):
    """Compute the deterministic (mean) action for a batch of observations.

    SB3's ActorCriticPolicy sends ``forward`` through a Distribution
    wrapper, so the underlying chain
    ``extract_features → mlp_extractor → action_net`` is replayed directly.
    If ``extract_features`` hands back a tuple (as SB3 does when actor and
    critic use separate feature extractors), only its first (actor) element
    is used. Likewise only the actor latent from ``mlp_extractor`` is used.
    """
    extracted = policy.extract_features(obs_batch)
    if isinstance(extracted, tuple):
        actor_features = extracted[0]
    else:
        actor_features = extracted
    actor_latent, _critic_latent = policy.mlp_extractor(actor_features)
    return policy.action_net(actor_latent)
def _parse_args():
    """Parse and validate the BC-pretraining command line.

    ``--val-split`` must lie in ``[0, 1)``. A negative value would make the
    ``perm[:n_val]`` slice select "all but the last few" samples for
    validation, and a value of 1 or more would leave nothing to train on.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--demos", required=True,
                        help="Path to demos .npz collected by training.bc.collect.")
    parser.add_argument("--out", required=True,
                        help="Output directory (convention: "
                             "training/runs/bc_<drive>_<world>).")
    parser.add_argument("--epochs", type=int, default=60)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--val-split", type=float, default=0.1)
    parser.add_argument("--net-arch", default="256,256",
                        help="Comma-separated hidden layer widths.")
    parser.add_argument("--log-std-init", type=float, default=0.5)
    parser.add_argument("--cos-weight", type=float, default=1.0,
                        help="Weight of the (1 - cosine_similarity) loss "
                             "term; balances against MSE.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--drive-mode", default=None,
                        choices=["differential", "mecanum"],
                        help="Drive mode. If not set, inferred from "
                             "demo action dimension (2→differential, 3→mecanum).")
    args = parser.parse_args()
    if not 0.0 <= args.val_split < 1.0:
        parser.error(f"--val-split must be in [0, 1), got {args.val_split}")
    return args


def _cosine_similarity(pred, target, eps=1e-6):
    """Row-wise cosine similarity between two ``(batch, dim)`` tensors.

    Norms are clamped to ``eps`` so an all-zero row yields 0 instead of NaN.
    """
    p_norm = pred.norm(dim=1).clamp_min(eps)
    t_norm = target.norm(dim=1).clamp_min(eps)
    return (pred * target).sum(dim=1) / (p_norm * t_norm)


def main():
    """Run BC pretraining end to end and save ``<out>/policy.zip``.

    Loads ``obs``/``actions``/``meta`` from the demo ``.npz`` and splits the
    samples into train/val. Frame stack and drive mode are inferred from the
    demo shapes. A fresh PPO policy is built, and its mean action is fit by
    minimising ``MSE + cos_weight * (1 - cos_sim)``. The epoch snapshot with
    the best cosine similarity is restored before saving. Snapshots are
    selected on ``val_cos``, or on ``train_cos`` when the validation split
    is empty.

    Raises:
        RuntimeError: empty or inconsistent demo file; a demo obs width
            that is not a multiple of the single-frame obs dim; or a demo
            action width that does not match the built policy's action
            space.
    """
    args = _parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # --- Load demos ---
    print(f"[bc] loading demos from {args.demos}")
    # ``with`` releases the .npz file handle once the arrays have been
    # read; ``astype`` already returns in-memory copies.
    with np.load(args.demos) as data:
        obs = data["obs"].astype(np.float32)
        actions = data["actions"].astype(np.float32)
        meta = data["meta"]
    print(f"[bc] obs={obs.shape} actions={actions.shape} trajectories={len(meta)}")
    if obs.size == 0:
        raise RuntimeError("Empty demo file.")
    if len(obs) != len(actions):
        raise RuntimeError(
            f"demo obs/actions length mismatch: {len(obs)} vs {len(actions)}")
    a_norms = np.linalg.norm(actions, axis=1)
    print(f"[bc] action L2 norm: mean={a_norms.mean():.3f} "
          f"min={a_norms.min():.3f} max={a_norms.max():.3f}")

    # --- Train/val split ---
    n = len(obs)
    perm = np.random.permutation(n)
    n_val = int(n * args.val_split)
    val_idx, train_idx = perm[:n_val], perm[n_val:]
    print(f"[bc] train={len(train_idx)} val={len(val_idx)}")
    # With no validation samples every val metric is a constant 0. That
    # would freeze the "best" snapshot at epoch 1, so fall back to train_cos.
    has_val = n_val > 0
    select_name = "val_cos" if has_val else "train_cos"
    if not has_val:
        print("[bc] WARNING: empty validation split; selecting the "
              "snapshot by train_cos instead of val_cos")

    obs_t = torch.from_numpy(obs)
    act_t = torch.from_numpy(actions)
    train_loader = DataLoader(
        TensorDataset(obs_t[train_idx], act_t[train_idx]),
        batch_size=args.batch_size, shuffle=True,
    )
    val_loader = DataLoader(
        TensorDataset(obs_t[val_idx], act_t[val_idx]),
        batch_size=args.batch_size, shuffle=False,
    )

    net_arch_pi = [int(x) for x in args.net_arch.split(",")]
    net_arch_vf = net_arch_pi[:]

    # Frame stack is inferred from the demo obs dim.
    obs_dim = obs.shape[1]
    from herding.perception.obs import OBS_DIM as _SINGLE
    if obs_dim % _SINGLE != 0:
        raise RuntimeError(f"demo obs dim {obs_dim} is not a multiple of {_SINGLE}")
    frame_stack = obs_dim // _SINGLE
    if frame_stack > 1:
        print(f"[bc] inferred frame_stack={frame_stack} from demo obs dim {obs_dim}")

    # Infer drive mode from action dimension if not explicitly set.
    action_dim = actions.shape[1]
    if args.drive_mode is not None:
        drive_mode = args.drive_mode
    elif action_dim == 3:
        drive_mode = "mecanum"
    else:
        drive_mode = "differential"
    print(f"[bc] drive_mode={drive_mode} (action_dim={action_dim})")

    model, _env = build_model(net_arch_pi, net_arch_vf, args.log_std_init,
                              frame_stack=frame_stack, drive_mode=drive_mode)
    # Fail fast if the env's action space disagrees with the demos (wrong
    # --drive-mode or an unexpected action width). Otherwise the mismatch
    # only surfaces later, inside the loss computation.
    policy_action_dim = int(np.prod(model.action_space.shape))
    if policy_action_dim != action_dim:
        raise RuntimeError(
            f"demo action dim {action_dim} does not match the "
            f"{drive_mode} policy action dim {policy_action_dim}")
    policy = model.policy.to(args.device)
    optimizer = optim.Adam(policy.parameters(), lr=args.lr)

    # --- Train ---
    print(f"[bc] training: epochs={args.epochs} batch={args.batch_size} "
          f"lr={args.lr} device={args.device}")
    t_start = time.time()
    best_val = float("inf")          # lowest val MSE seen in any epoch
    best_cos = float("-inf")         # selection metric of the kept snapshot
    best_epoch = 0
    snapshot_val_mse = float("nan")  # val MSE of the kept snapshot
    best_state = None  # restored at the end so noisy last epochs don't win

    def combined_loss(pred, target):
        mse = nn.functional.mse_loss(pred, target)
        cos_sim = _cosine_similarity(pred, target)
        cos_loss = (1.0 - cos_sim).mean()
        return mse + args.cos_weight * cos_loss, mse.item(), cos_sim.mean().item()

    for epoch in range(args.epochs):
        policy.train()
        train_mse_total, train_cos_total, train_count = 0.0, 0.0, 0
        for ob_batch, act_batch in train_loader:
            ob_batch = ob_batch.to(args.device)
            act_batch = act_batch.to(args.device)
            optimizer.zero_grad()
            mean_action = policy_forward_mean(policy, ob_batch)
            loss, mse_val, cos_val = combined_loss(mean_action, act_batch)
            loss.backward()
            optimizer.step()
            bs = ob_batch.size(0)
            train_mse_total += mse_val * bs
            train_cos_total += cos_val * bs
            train_count += bs
        train_mse = train_mse_total / max(1, train_count)
        train_cos = train_cos_total / max(1, train_count)

        policy.eval()
        val_sq_err, val_cos_total, val_count = 0.0, 0.0, 0
        with torch.no_grad():
            for ob_batch, act_batch in val_loader:
                ob_batch = ob_batch.to(args.device)
                act_batch = act_batch.to(args.device)
                mean_action = policy_forward_mean(policy, ob_batch)
                val_sq_err += nn.functional.mse_loss(
                    mean_action, act_batch, reduction="sum",
                ).item()
                val_cos_total += _cosine_similarity(mean_action, act_batch).sum().item()
                val_count += ob_batch.size(0)
        # Summed squared error -> per-element MSE (samples x action dims).
        val_mse = val_sq_err / max(1, val_count) / action_dim
        val_cos = val_cos_total / max(1, val_count)
        print(f"  epoch {epoch+1:>2d}/{args.epochs}  "
              f"train_mse={train_mse:.4f} train_cos={train_cos:+.3f}  "
              f"val_mse={val_mse:.4f} val_cos={val_cos:+.3f}")

        if has_val and val_mse < best_val:
            best_val = val_mse
        select_cos = val_cos if has_val else train_cos
        if select_cos > best_cos:
            best_cos = select_cos
            best_epoch = epoch + 1
            snapshot_val_mse = val_mse
            best_state = {k: v.detach().cpu().clone()
                          for k, v in policy.state_dict().items()}

    if best_state is not None:
        policy.load_state_dict(best_state)
        print(f"[bc] restored best-{select_name} snapshot "
              f"(epoch {best_epoch}, cos={best_cos:.3f})")
    elapsed = time.time() - t_start
    if has_val:
        print(f"[bc] done in {elapsed:.0f}s  best_val_mse={best_val:.4f}  "
              f"snapshot_val_mse={snapshot_val_mse:.4f}")
    else:
        print(f"[bc] done in {elapsed:.0f}s  (no validation split)")

    # --- Save ---
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)
    model.save(out_dir / "policy.zip")
    print(f"[bc] saved policy to {out_dir / 'policy.zip'}")
    print(f"\n[bc] verify with: "
          f"python -m training.eval --policy {out_dir}")


if __name__ == "__main__":
    main()