Checkpoint 3

This commit is contained in:
Johnny Fernandes
2026-05-10 12:46:14 +01:00
parent 1bb9415414
commit 2a6db038df
16 changed files with 305 additions and 662 deletions
+25 -11
View File
@@ -1,20 +1,21 @@
"""Behavior cloning of the sequential teacher into an SB3-compatible policy.
"""Behavior cloning of an analytic teacher into an SB3-compatible policy.
Trains the policy network (mean-action head) of an SB3 ``MlpPolicy`` to
mimic the demonstrations collected by ``tools.collect_demos``. The
saved zip is loadable via ``PPO.load(...)`` and can be passed to
``train_ppo.py --resume`` for fine-tuning.
Trains the policy network (mean-action head) of an SB3 ``MlpPolicy``
to mimic the (obs, action) demonstrations produced by
``tools.collect_demos``. The saved zip is loadable via ``PPO.load(...)``
and is what the Webots dog controller uses in ``HERDING_MODE=rl``.
Why this works: the teacher (sequential single-target driving) solves
n=10 at 80%+ in our env. BC gives the RL a competent starting policy,
so PPO doesn't have to discover behavior from scratch — it only has to
*refine* the teacher's strategy via the sparse pen reward.
Loss: MSE + (1 - cosine similarity). The cosine term is what stops
the policy mean from collapsing toward zero against unit-vector
targets. Best-by-val_cos checkpoint is restored at the end of training
so noisy multi-modal teachers (e.g. Strömbom) don't lose progress when
the last epoch lands on a bad gradient step.
Usage::
python -m training.bc_pretrain \\
--demos training/demos.npz \\
--out training/runs/bc_pretrained
--out training/runs/bc_flock
"""
from __future__ import annotations
@@ -80,7 +81,7 @@ def policy_forward_mean(policy, obs_batch):
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--demos", default="training/demos.npz")
parser.add_argument("--out", default="training/runs/bc_pretrained")
parser.add_argument("--out", default="training/runs/bc_solo")
parser.add_argument("--epochs", type=int, default=60)
parser.add_argument("--batch-size", type=int, default=256)
parser.add_argument("--lr", type=float, default=1e-3)
@@ -147,6 +148,11 @@ def main():
f"lr={args.lr} device={args.device}")
t_start = time.time()
best_val = float("inf")
best_cos = -1.0
# Snapshot the best-by-val_cos policy weights and restore at the end —
# training is noisy on multi-modal teachers (e.g. Strömbom collect/drive),
# so the last epoch is often worse than an earlier one.
best_state = None
def combined_loss(pred, target):
mse = nn.functional.mse_loss(pred, target)
@@ -201,6 +207,14 @@ def main():
f"val_mse={val_mse:.4f} val_cos={cos_sim:+.3f}")
if val_mse < best_val:
best_val = val_mse
if cos_sim > best_cos:
best_cos = cos_sim
best_state = {k: v.detach().cpu().clone()
for k, v in policy.state_dict().items()}
if best_state is not None:
policy.load_state_dict(best_state)
print(f"[bc] restored best-val_cos snapshot (cos={best_cos:.3f})")
elapsed = time.time() - t_start
print(f"[bc] done in {elapsed:.0f}s best_val_mse={best_val:.4f}")