245 lines
9.5 KiB
Python
245 lines
9.5 KiB
Python
"""Behavior cloning of an analytic teacher into an SB3-compatible policy.

Trains the policy network (mean-action head) of an SB3 ``MlpPolicy``
to mimic the (obs, action) demonstrations produced by
``tools.collect_demos``. The saved zip is loadable via ``PPO.load(...)``
and is what the Webots dog controller uses in ``HERDING_MODE=rl``.

Loss: MSE + ``--cos-weight`` * (1 - cosine similarity). The cosine term
stops the policy mean from collapsing toward zero against unit-vector
targets. The best-by-val_cos checkpoint is restored at the end of training
so noisy multi-modal teachers (e.g. Strömbom) don't lose progress when
the last epoch lands on a bad gradient step.

Usage::

    python -m training.bc_pretrain \\
        --demos training/demos.npz \\
        --out training/runs/bc
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
# Put the repository root (the parent of this file's directory) at the
# front of sys.path so the ``training`` / ``herding`` imports in this file
# resolve no matter how the script was launched.
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, os.pardir))
if _PROJECT_ROOT not in sys.path:
    # Prepend (not append) so project packages shadow same-named installs.
    sys.path.insert(0, _PROJECT_ROOT)
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.utils.data import DataLoader, TensorDataset
|
|
|
|
from stable_baselines3 import PPO
|
|
from stable_baselines3.common.vec_env import DummyVecEnv
|
|
|
|
from training.herding_env import HerdingEnv
|
|
|
|
|
|
def build_model(net_arch_pi, net_arch_vf, log_std_init: float,
                frame_stack: int = 1):
    """Create an untrained SB3 PPO whose policy network will hold the BC weights.

    Only the ``MlpPolicy`` is trained; PPO's own rollout/update loop is never
    run here. ``frame_stack`` has to match the demo file so the environment's
    observation space has the same width as the recorded observations.

    Args:
        net_arch_pi: hidden-layer widths of the actor MLP.
        net_arch_vf: hidden-layer widths of the critic MLP.
        log_std_init: initial log standard deviation of the action distribution.
        frame_stack: number of stacked observation frames passed to HerdingEnv.

    Returns:
        ``(model, env)`` — the PPO instance and the single-env DummyVecEnv it wraps.
    """
    def _make_env():
        return HerdingEnv(frame_stack=frame_stack)

    vec_env = DummyVecEnv([_make_env])
    policy_kwargs = {
        "net_arch": {"pi": net_arch_pi, "vf": net_arch_vf},
        "log_std_init": log_std_init,
    }
    model = PPO("MlpPolicy", vec_env, policy_kwargs=policy_kwargs, verbose=0)
    return model, vec_env
|
|
|
|
|
|
def policy_forward_mean(policy, obs_batch):
    """Return the policy's deterministic mean action for a batch.

    SB3's ``ActorCriticPolicy`` does not expose the mean directly; it goes
    through a Distribution wrapper. The actor path is therefore replayed by
    hand: ``extract_features`` -> ``mlp_extractor`` (actor branch only) ->
    ``action_net``. Gradients flow through every step, as BC training needs.

    Args:
        policy: an SB3 ``ActorCriticPolicy`` (anything exposing
            ``extract_features``, ``mlp_extractor`` and ``action_net``).
        obs_batch: observation tensor; callers move it to the policy's device.

    Returns:
        The ``action_net`` output for ``obs_batch``.
    """
    features = policy.extract_features(obs_batch)
    # With separate actor/critic feature extractors SB3 returns a
    # ``(pi_features, vf_features)`` tuple; only the actor half is needed.
    pi_features = features[0] if isinstance(features, tuple) else features

    mlp = policy.mlp_extractor
    forward_actor = getattr(mlp, "forward_actor", None)
    if forward_actor is not None:
        # Skip the value-network branch: its output is discarded here, so
        # evaluating it only costs compute and autograd memory per batch.
        latent_pi = forward_actor(pi_features)
    else:
        # Fallback for extractors without ``forward_actor``: run both
        # branches and keep the actor latent.
        latent_pi, _latent_vf = mlp(pi_features)
    return policy.action_net(latent_pi)
|
|
|
|
|
|
def _batch_cosine(pred, target, eps: float = 1e-6):
    """Return the row-wise cosine similarity of two equally shaped tensors.

    Similarity is taken along dim 1. Norms are clamped to ``eps`` so an
    all-zero row scores 0.0 instead of producing NaN.
    """
    pred_norm = pred.norm(dim=1).clamp_min(eps)
    target_norm = target.norm(dim=1).clamp_min(eps)
    return (pred * target).sum(dim=1) / (pred_norm * target_norm)


def _combined_loss(pred, target, cos_weight: float):
    """BC objective: MSE + ``cos_weight`` * mean(1 - cosine similarity).

    Returns ``(loss, mse, mean_cos)``: ``loss`` is the differentiable tensor,
    the other two are Python floats for logging.
    """
    mse = nn.functional.mse_loss(pred, target)
    cos_sim = _batch_cosine(pred, target)
    loss = mse + cos_weight * (1.0 - cos_sim).mean()
    return loss, mse.item(), cos_sim.mean().item()


def _infer_frame_stack(obs_dim: int) -> int:
    """Return how many single-step observations make up one demo obs row.

    Raises:
        RuntimeError: if ``obs_dim`` is not a multiple of ``herding.obs.OBS_DIM``.
    """
    from herding.obs import OBS_DIM as single_dim

    if obs_dim % single_dim != 0:
        raise RuntimeError(
            f"demo obs dim {obs_dim} is not a multiple of {single_dim}")
    return obs_dim // single_dim


def _train_epoch(policy, loader, optimizer, device, cos_weight: float):
    """Run one optimisation pass; return sample-weighted ``(mse, cos)`` over it."""
    policy.train()
    mse_total, cos_total, count = 0.0, 0.0, 0
    for ob_batch, act_batch in loader:
        ob_batch = ob_batch.to(device)
        act_batch = act_batch.to(device)
        optimizer.zero_grad()
        mean_action = policy_forward_mean(policy, ob_batch)
        loss, mse_val, cos_val = _combined_loss(mean_action, act_batch, cos_weight)
        loss.backward()
        optimizer.step()
        bs = ob_batch.size(0)
        mse_total += mse_val * bs
        cos_total += cos_val * bs
        count += bs
    return mse_total / max(1, count), cos_total / max(1, count)


def _evaluate(policy, loader, device, action_dim: int):
    """Score the policy mean on ``loader`` without tracking gradients.

    Returns ``(mse, cos)``: MSE averaged over every action element (matching
    ``mse_loss``'s default mean reduction) and the mean per-sample cosine
    similarity ("is the policy pointing the same way as the teacher?").
    Both are 0.0 for an empty loader.
    """
    policy.eval()
    sq_err_total, cos_total, count = 0.0, 0.0, 0
    with torch.no_grad():
        for ob_batch, act_batch in loader:
            ob_batch = ob_batch.to(device)
            act_batch = act_batch.to(device)
            mean_action = policy_forward_mean(policy, ob_batch)
            sq_err_total += nn.functional.mse_loss(
                mean_action, act_batch, reduction="sum",
            ).item()
            cos_total += _batch_cosine(mean_action, act_batch).sum().item()
            count += ob_batch.size(0)
    return sq_err_total / max(1, count) / action_dim, cos_total / max(1, count)


def main():
    """CLI entry point: load demos, fit the policy mean by BC, save ``policy.zip``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--demos", default="training/demos.npz")
    parser.add_argument("--out", default="training/runs/bc")
    parser.add_argument("--epochs", type=int, default=60)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--val-split", type=float, default=0.1)
    parser.add_argument("--net-arch", default="256,256",
                        help="Comma-separated hidden layer widths.")
    parser.add_argument("--log-std-init", type=float, default=0.5)
    parser.add_argument("--cos-weight", type=float, default=1.0,
                        help="Weight on (1 - cosine similarity) loss term. "
                             "MSE alone shrinks policy output toward zero "
                             "(zero-magnitude action minimises mean squared "
                             "error against ±1 targets); cos loss keeps "
                             "the action pointed correctly even at small "
                             "magnitudes.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device", default="cpu")
    args = parser.parse_args()
    # A split of 1.0 leaves nothing to train on; a negative one turns the
    # slice below into "all but the last k" and silently corrupts the split.
    if not 0.0 <= args.val_split < 1.0:
        parser.error(f"--val-split must be in [0, 1), got {args.val_split}")

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # --- Load demos ---
    print(f"[bc] loading demos from {args.demos}")
    # np.load on an .npz keeps the archive open until closed; the context
    # manager releases it once the arrays below have been read into memory.
    with np.load(args.demos) as data:
        obs = data["obs"].astype(np.float32)
        actions = data["actions"].astype(np.float32)
        meta = data["meta"]
    print(f"[bc] obs={obs.shape} actions={actions.shape} trajectories={len(meta)}")
    if obs.size == 0:
        raise RuntimeError("Empty demo file.")

    # Action sanity check — sequential outputs unit vectors.
    a_norms = np.linalg.norm(actions, axis=1)
    print(f"[bc] action L2 norm: mean={a_norms.mean():.3f} "
          f"min={a_norms.min():.3f} max={a_norms.max():.3f}")

    # --- Train/val split ---
    n = len(obs)
    perm = np.random.permutation(n)
    n_val = int(n * args.val_split)
    val_idx, train_idx = perm[:n_val], perm[n_val:]
    print(f"[bc] train={len(train_idx)} val={len(val_idx)}")
    has_val = n_val > 0
    if not has_val:
        # Without validation data there is nothing to select a "best" epoch
        # by, so the final-epoch weights are kept as-is.
        print("[bc] no validation samples; keeping final-epoch weights")

    obs_t = torch.from_numpy(obs)
    act_t = torch.from_numpy(actions)
    train_loader = DataLoader(
        TensorDataset(obs_t[train_idx], act_t[train_idx]),
        batch_size=args.batch_size, shuffle=True,
    )
    val_loader = DataLoader(
        TensorDataset(obs_t[val_idx], act_t[val_idx]),
        batch_size=args.batch_size, shuffle=False,
    )

    # --- Build model ---
    net_arch_pi = [int(x) for x in args.net_arch.split(",")]
    net_arch_vf = net_arch_pi[:]
    # Auto-detect frame stacking from the demo file so a stacked-obs
    # demo trains a stacked-obs policy without an extra CLI flag.
    obs_dim = obs.shape[1]
    frame_stack = _infer_frame_stack(obs_dim)
    if frame_stack > 1:
        print(f"[bc] inferred frame_stack={frame_stack} from demo obs dim {obs_dim}")
    model, env = build_model(net_arch_pi, net_arch_vf, args.log_std_init,
                             frame_stack=frame_stack)
    try:
        policy = model.policy.to(args.device)
        optimizer = optim.Adam(policy.parameters(), lr=args.lr)

        # --- Train ---
        print(f"[bc] training: epochs={args.epochs} batch={args.batch_size} "
              f"lr={args.lr} device={args.device}")
        t_start = time.time()
        best_val = float("inf")       # lowest val MSE seen in any epoch
        best_cos = float("-inf")      # highest val cos seen in any epoch
        best_cos_mse = float("nan")   # val MSE of the best-cos epoch
        best_epoch = None
        # Snapshot the best-by-val_cos policy weights and restore at the end —
        # training is noisy on multi-modal teachers (e.g. Strömbom collect/drive),
        # so the last epoch is often worse than an earlier one.
        best_state = None

        for epoch in range(args.epochs):
            train_mse, train_cos = _train_epoch(
                policy, train_loader, optimizer, args.device, args.cos_weight)
            line = (f" epoch {epoch+1:>2d}/{args.epochs} "
                    f"train_mse={train_mse:.4f} train_cos={train_cos:+.3f}")
            if not has_val:
                print(line)
                continue

            val_mse, val_cos = _evaluate(
                policy, val_loader, args.device, actions.shape[1])
            print(f"{line} val_mse={val_mse:.4f} val_cos={val_cos:+.3f}")

            best_val = min(best_val, val_mse)
            # A NaN cos compares False here, so a diverged epoch is never kept.
            if val_cos > best_cos:
                best_cos, best_cos_mse, best_epoch = val_cos, val_mse, epoch + 1
                best_state = {k: v.detach().cpu().clone()
                              for k, v in policy.state_dict().items()}

        if best_state is not None:
            policy.load_state_dict(best_state)
            print(f"[bc] restored best-val_cos snapshot from epoch {best_epoch} "
                  f"(cos={best_cos:.3f} val_mse={best_cos_mse:.4f})")

        elapsed = time.time() - t_start
        if has_val:
            # best_val is the minimum over all epochs and may come from a
            # different epoch than the restored snapshot.
            print(f"[bc] done in {elapsed:.0f}s best_val_mse={best_val:.4f}")
        else:
            print(f"[bc] done in {elapsed:.0f}s")

        # --- Save ---
        out_dir = Path(args.out)
        out_dir.mkdir(parents=True, exist_ok=True)
        model.save(out_dir / "policy.zip")
        print(f"[bc] saved policy to {out_dir / 'policy.zip'}")
        print(f"\n[bc] verify with: "
              f"python -m training.eval --policy {out_dir}")
    finally:
        # The vec-env only exists to give PPO its spaces; release it.
        env.close()


if __name__ == "__main__":
    main()
|