Checkpoint 3
This commit is contained in:
+25
-11
@@ -1,20 +1,21 @@
|
||||
"""Behavior cloning of the sequential teacher into an SB3-compatible policy.
|
||||
"""Behavior cloning of an analytic teacher into an SB3-compatible policy.
|
||||
|
||||
Trains the policy network (mean-action head) of an SB3 ``MlpPolicy`` to
|
||||
mimic the demonstrations collected by ``tools.collect_demos``. The
|
||||
saved zip is loadable via ``PPO.load(...)`` and can be passed to
|
||||
``train_ppo.py --resume`` for fine-tuning.
|
||||
Trains the policy network (mean-action head) of an SB3 ``MlpPolicy``
|
||||
to mimic the (obs, action) demonstrations produced by
|
||||
``tools.collect_demos``. The saved zip is loadable via ``PPO.load(...)``
|
||||
and is what the Webots dog controller uses in ``HERDING_MODE=rl``.
|
||||
|
||||
Why this works: the teacher (sequential single-target driving) solves
|
||||
n=10 at 80%+ in our env. BC gives the RL a competent starting policy,
|
||||
so PPO doesn't have to discover behavior from scratch — it only has to
|
||||
*refine* the teacher's strategy via the sparse pen reward.
|
||||
Loss: MSE + (1 - cosine similarity). The cosine term is what stops
|
||||
the policy mean from collapsing toward zero against unit-vector
|
||||
targets. Best-by-val_cos checkpoint is restored at the end of training
|
||||
so noisy multi-modal teachers (e.g. Strömbom) don't lose progress when
|
||||
the last epoch lands on a bad gradient step.
|
||||
|
||||
Usage::
|
||||
|
||||
python -m training.bc_pretrain \\
|
||||
--demos training/demos.npz \\
|
||||
--out training/runs/bc_pretrained
|
||||
--out training/runs/bc_flock
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -80,7 +81,7 @@ def policy_forward_mean(policy, obs_batch):
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--demos", default="training/demos.npz")
|
||||
parser.add_argument("--out", default="training/runs/bc_pretrained")
|
||||
parser.add_argument("--out", default="training/runs/bc_solo")
|
||||
parser.add_argument("--epochs", type=int, default=60)
|
||||
parser.add_argument("--batch-size", type=int, default=256)
|
||||
parser.add_argument("--lr", type=float, default=1e-3)
|
||||
@@ -147,6 +148,11 @@ def main():
|
||||
f"lr={args.lr} device={args.device}")
|
||||
t_start = time.time()
|
||||
best_val = float("inf")
|
||||
best_cos = -1.0
|
||||
# Snapshot the best-by-val_cos policy weights and restore at the end —
|
||||
# training is noisy on multi-modal teachers (e.g. Strömbom collect/drive),
|
||||
# so the last epoch is often worse than an earlier one.
|
||||
best_state = None
|
||||
|
||||
def combined_loss(pred, target):
|
||||
mse = nn.functional.mse_loss(pred, target)
|
||||
@@ -201,6 +207,14 @@ def main():
|
||||
f"val_mse={val_mse:.4f} val_cos={cos_sim:+.3f}")
|
||||
if val_mse < best_val:
|
||||
best_val = val_mse
|
||||
if cos_sim > best_cos:
|
||||
best_cos = cos_sim
|
||||
best_state = {k: v.detach().cpu().clone()
|
||||
for k, v in policy.state_dict().items()}
|
||||
|
||||
if best_state is not None:
|
||||
policy.load_state_dict(best_state)
|
||||
print(f"[bc] restored best-val_cos snapshot (cos={best_cos:.3f})")
|
||||
|
||||
elapsed = time.time() - t_start
|
||||
print(f"[bc] done in {elapsed:.0f}s best_val_mse={best_val:.4f}")
|
||||
|
||||
Reference in New Issue
Block a user