Checkpoint 7

This commit is contained in:
Johnny Fernandes
2026-05-11 12:21:51 +01:00
parent fce0e0c786
commit a01a5c9cef
34 changed files with 1266 additions and 1038 deletions
+17 -9
View File
@@ -6,7 +6,7 @@ Two stages, strictly sequential:
sim demos (Strömbom on tracker output, K=4 frame stack)
bc_pretrain.py ──► runs/bc (Strömbom-imitated MLP)
bc/pretrain.py ──► runs/bc (Strömbom-imitated MLP)
▼ KL-regularised PPO fine-tune
@@ -17,10 +17,13 @@ runs/rl (deployed `rl` mode — beats BC and Strömbom)
```
herding_env.py — Gymnasium env (LiDAR raycast + tracker by default)
bc_pretrain.py — MSE + cosine BC of (obs, action) demos into MlpPolicy
train_ppo.py — KL-regularised PPO fine-tune of a BC checkpoint
bc/pretrain.py — MSE + cosine BC of (obs, action) demos into MlpPolicy
rl/train.py — KL-regularised PPO fine-tune of a BC checkpoint
eval.py — multi-seed analytic / learned policy comparison
runs/ — checkpoints (whitelisted entries in top-level .gitignore)
(Unit + integration tests live in the top-level ``tests/`` directory;
run with ``python -m pytest tests/``.)
```
## Setup
@@ -35,18 +38,23 @@ rollout collection, not gradient compute.
## End-to-end pipeline
The simplest way to run everything is the Makefile at the project
root: ``make`` does the full chain, ``make rl`` rebuilds whatever's
needed up to that point, etc. The individual stages below are kept
explicit for cases where you want to tune a single step.
```bash
# 1. Sim demos with the active-scan + Strömbom teacher under LiDAR
# perception. K=4 frame stack so the MLP has temporal context.
python -m tools.collect_demos --teacher strombom \
--out training/demos.npz --seeds-per-n 15 --subsample 3 --frame-stack 4
python -m training.bc.collect --teacher strombom \
--out training/bc/demos.npz --seeds-per-n 15 --subsample 3 --frame-stack 4
# 2. Behaviour-clone.
python -m training.bc_pretrain --demos training/demos.npz \
python -m training.bc.pretrain --demos training/bc/demos.npz \
--out training/runs/bc --epochs 60 --net-arch 512,512
# 3. KL-regularised PPO fine-tune of bc.
python -m training.train_ppo \
python -m training.rl.train \
--bc training/runs/bc --out training/runs/rl \
--total-timesteps 1000000
@@ -55,11 +63,11 @@ python -m training.eval --policy training/runs/rl \
--max-flock 10 --max-steps 15000 --n-seeds 10
```
`bc_pretrain.py` saves the **best-val_cos** snapshot, not the final
`bc/pretrain.py` saves the **best-val_cos** snapshot, not the final
epoch — multi-modal teachers make training noisy and the last epoch is
often worse than an earlier one.
`train_ppo.py` loads BC weights into both a trainable policy and a
`rl/train.py` loads BC weights into both a trainable policy and a
frozen reference, fixes `log_std` small, and adds `β · KL(π‖π_ref)` to
the loss so the policy can only move within a trust region around BC.
See the file header for hyperparameter rationale.
View File
+144
View File
@@ -0,0 +1,144 @@
"""Collect (obs, action) demonstrations from an analytic teacher.
Runs the chosen teacher across a grid of ``(n_sheep, seed)`` combos at
full difficulty, logs every Nth ``(obs, action)`` pair, and saves
successful trajectories to ``.npz`` for behaviour cloning. The teacher
is wrapped in :class:`ActiveScanTeacher` by default so it operates on
the same partial-obs view the student will have at deployment.
Usage::
python -m training.bc.collect --teacher strombom \\
--out training/bc/demos.npz --frame-stack 4
"""
from __future__ import annotations
import argparse
import time
from pathlib import Path
import numpy as np
from herding.control.active_scan import ActiveScanTeacher
from herding.world.geometry import PEN_ENTRY
from herding.control.sequential import compute_action as sequential_action
from herding.control.strombom import compute_action as strombom_action
from training.herding_env import HerdingEnv
# Registry of analytic teachers selectable with ``--teacher``: maps the CLI
# name to the matching ``compute_action`` callable imported above.
TEACHERS = {
"sequential": sequential_action,
"strombom": strombom_action,
}
def collect_one(n_sheep: int, seed: int, max_steps: int, subsample: int,
                teacher_fn, frame_stack: int = 1, privileged: bool = False):
    """Roll out one teacher episode and log subsampled (obs, action) pairs.

    Args:
        n_sheep: Flock size passed to ``HerdingEnv``.
        seed: Seed given to both the ``HerdingEnv`` constructor and
            ``env.reset``.
        max_steps: Maximum number of env steps; also passed to the env as
            its ``max_steps``.
        subsample: Keep every ``subsample``-th ``(obs, action)`` pair.
            Must be >= 1.
        teacher_fn: Analytic teacher callable returning ``(vx, vy, mode)``.
        frame_stack: Forwarded to ``HerdingEnv``.
        privileged: If true, ``teacher_fn`` is called directly on
            ground-truth positions of the not-yet-penned sheep. Otherwise
            the teacher is wrapped in ``ActiveScanTeacher`` and fed
            ``env.perceived_positions()``. The logged observation is the
            env observation in both cases.

    Returns:
        ``(obs, actions, success, steps)``: two float32 arrays of logged
        observations and actions, whether every sheep ended up penned, and
        ``env.steps`` at the end of the rollout.

    Raises:
        ValueError: If ``subsample`` is smaller than 1.
    """
    # Fail early with a clear message instead of a ZeroDivisionError (or a
    # silently odd sampling pattern for negative values) inside the loop.
    if subsample < 1:
        raise ValueError(f"subsample must be >= 1, got {subsample}")

    env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps,
                     difficulty=1.0, seed=seed, frame_stack=frame_stack)
    obs, _ = env.reset(seed=seed)
    obs_list, action_list = [], []

    # Wrap the base teacher so it opens with a rotation and walks to
    # centre when the tracker briefly empties — matches the student.
    # Only needed on the non-privileged path.
    scan_teacher = None if privileged else ActiveScanTeacher(teacher_fn)

    for step in range(max_steps):
        dog_xy = (env.dog_x, env.dog_y)
        if privileged:
            # Asymmetric variant: teacher reads ground truth while the
            # student keeps the LiDAR obs. Default off.
            positions = {
                f"s{i}": (float(env.sheep_x[i]), float(env.sheep_y[i]))
                for i in range(env.n_sheep)
                if not env.sheep_penned[i]
            }
            if not positions:
                break
            vx, vy, _mode = teacher_fn(dog_xy, positions, PEN_ENTRY)
        else:
            positions = env.perceived_positions()
            vx, vy, _mode = scan_teacher(
                dog_xy, env.dog_heading, positions, PEN_ENTRY,
            )

        action = np.array([vx, vy], dtype=np.float32)
        # Log the observation the action was computed for, i.e. the one
        # from *before* this step is applied.
        if step % subsample == 0:
            obs_list.append(obs.copy())
            action_list.append(action.copy())

        obs, _r, term, trunc, _info = env.step(action)
        if term or trunc:
            break

    success = bool(env.sheep_penned.all())
    return (
        np.asarray(obs_list, dtype=np.float32),
        np.asarray(action_list, dtype=np.float32),
        success,
        env.steps,
    )
def main():
    """CLI entry point: collect teacher demos over a grid and save them.

    Runs ``collect_one`` for every ``(n_sheep, seed)`` combination given by
    ``--n-sheep-list`` and ``--seeds-per-n``, keeps successful trajectories
    (or all of them with ``--keep-failures``), and writes the concatenated
    ``obs``, ``actions`` and per-trajectory ``meta`` arrays to ``--out``
    with ``np.savez``. Each ``meta`` row is
    ``(n_sheep, seed, n_logged, success, total_steps)``.

    Raises:
        RuntimeError: If no trajectory was kept.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--out", default="training/bc/demos.npz")
    parser.add_argument("--n-sheep-list", default="1,2,3,5,8,10")
    parser.add_argument("--seeds-per-n", type=int, default=15)
    parser.add_argument("--max-steps", type=int, default=30000)
    parser.add_argument("--subsample", type=int, default=5,
                        help="Keep every Nth (obs, action) pair.")
    parser.add_argument("--keep-failures", action="store_true",
                        help="Include partial-success trajectories. Default off.")
    parser.add_argument("--teacher", default="sequential",
                        choices=list(TEACHERS.keys()),
                        help="Which analytic teacher to demonstrate.")
    parser.add_argument("--frame-stack", type=int, default=1,
                        help="Concatenate the last K obs into a "
                             "(32·K)-D vector for the policy.")
    parser.add_argument("--privileged", action="store_true",
                        help="Teacher reads ground truth instead of "
                             "tracker output (asymmetric BC).")
    args = parser.parse_args()

    # Reject a non-positive stride up front: it would otherwise surface as
    # a ZeroDivisionError deep inside the rollout loop.
    if args.subsample < 1:
        parser.error("--subsample must be >= 1")

    teacher_fn = TEACHERS[args.teacher]
    print(f"[demos] teacher: {args.teacher}")

    # Ignore empty items so a trailing comma ("1,2,") does not crash int().
    n_sheep_list = [int(x) for x in args.n_sheep_list.split(",") if x.strip()]
    print(f"[demos] grid: n_sheep={n_sheep_list}, seeds={args.seeds_per_n}, "
          f"max_steps={args.max_steps}, subsample={args.subsample}")

    all_obs, all_actions, all_meta = [], [], []
    t_start = time.time()
    n_success = 0
    n_total = 0
    for n in n_sheep_list:
        for seed in range(args.seeds_per_n):
            obs, actions, success, total_steps = collect_one(
                n, seed, args.max_steps, args.subsample, teacher_fn,
                frame_stack=args.frame_stack, privileged=args.privileged,
            )
            n_total += 1
            if success:
                n_success += 1
            keep = success or args.keep_failures
            if keep and len(obs) > 0:
                all_obs.append(obs)
                all_actions.append(actions)
                all_meta.append((n, seed, len(obs), int(success), total_steps))
            # Both branches used to be the empty string, so every line was
            # tagged "[]"; use distinct ASCII markers instead.
            tag = "OK" if success else "FAIL"
            print(f"  [{tag}] n={n:>2d} seed={seed:>2d} steps={total_steps:>6d} "
                  f"logged={len(obs):>5d}")

    if not all_obs:
        raise RuntimeError("No trajectories kept — try --keep-failures.")

    obs = np.concatenate(all_obs, axis=0)
    actions = np.concatenate(all_actions, axis=0)
    meta = np.array(all_meta, dtype=np.int32)

    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    np.savez(args.out, obs=obs, actions=actions, meta=meta)

    elapsed = time.time() - t_start
    # n_total > 0 here: all_obs is non-empty, so at least one rollout ran.
    print(f"\n=== {n_success}/{n_total} trajectories successful ({100*n_success/n_total:.0f}%) ===")
    print(f"=== {len(obs)} transitions saved to {args.out} ===")
    print(f"=== obs={obs.shape}, actions={actions.shape}, elapsed={elapsed:.0f}s ===")


if __name__ == "__main__":
    main()
@@ -1,36 +1,27 @@
"""Behavior cloning of an analytic teacher into an SB3-compatible policy.
"""Behaviour cloning of an analytic teacher into an SB3 MlpPolicy.
Trains the policy network (mean-action head) of an SB3 ``MlpPolicy``
to mimic the (obs, action) demonstrations produced by
``tools.collect_demos``. The saved zip is loadable via ``PPO.load(...)``
and is what the Webots dog controller uses in ``HERDING_MODE=rl``.
Trains the mean-action head against ``(obs, action)`` demos from
``training.bc.collect`` using ``MSE + (1 - cos_sim)`` — the cosine
term prevents collapse toward zero against unit-vector targets. The
best-by-val_cos snapshot is restored at the end of training because
multi-modal teachers make the last epoch unreliable.
Loss: MSE + (1 - cosine similarity). The cosine term is what stops
the policy mean from collapsing toward zero against unit-vector
targets. Best-by-val_cos checkpoint is restored at the end of training
so noisy multi-modal teachers (e.g. Strömbom) don't lose progress when
the last epoch lands on a bad gradient step.
Output zip is loadable by ``PPO.load(...)`` and consumed by
``HERDING_MODE=bc`` in the dog controller.
Usage::
python -m training.bc_pretrain \\
--demos training/demos.npz \\
python -m training.bc.pretrain \\
--demos training/bc/demos.npz \\
--out training/runs/bc
"""
from __future__ import annotations
import argparse
import os
import sys
import time
from pathlib import Path
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
import numpy as np
import torch
import torch.nn as nn
@@ -64,25 +55,21 @@ def build_model(net_arch_pi, net_arch_vf, log_std_init: float,
def policy_forward_mean(policy, obs_batch):
"""Return the policy's deterministic mean action for a batch.
"""Return the deterministic mean action for an obs batch.
SB3's ActorCriticPolicy doesn't expose this directly — it goes
through a Distribution wrapper. We replicate the forward path:
extract_features → mlp_extractor → action_net.
SB3's ActorCriticPolicy routes ``forward`` through a Distribution
wrapper; we replicate the underlying chain
``extract_features → mlp_extractor → action_net``.
"""
features = policy.extract_features(obs_batch)
if isinstance(features, tuple):
# SB3 ≥ 2.0 sometimes returns (pi_features, vf_features)
pi_features = features[0]
else:
pi_features = features
latent_pi, _latent_vf = policy.mlp_extractor(pi_features)
pi_features = features[0] if isinstance(features, tuple) else features
latent_pi, _ = policy.mlp_extractor(pi_features)
return policy.action_net(latent_pi)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--demos", default="training/demos.npz")
parser.add_argument("--demos", default="training/bc/demos.npz")
parser.add_argument("--out", default="training/runs/bc")
parser.add_argument("--epochs", type=int, default=60)
parser.add_argument("--batch-size", type=int, default=256)
@@ -92,12 +79,8 @@ def main():
help="Comma-separated hidden layer widths.")
parser.add_argument("--log-std-init", type=float, default=0.5)
parser.add_argument("--cos-weight", type=float, default=1.0,
help="Weight on (1 - cosine similarity) loss term. "
"MSE alone shrinks policy output toward zero "
"(zero-magnitude action minimises mean squared "
"error against ±1 targets); cos loss keeps "
"the action pointed correctly even at small "
"magnitudes.")
help="Weight of the (1 - cosine_similarity) loss "
"term; balances against MSE.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--device", default="cpu")
args = parser.parse_args()
@@ -115,7 +98,6 @@ def main():
if obs.size == 0:
raise RuntimeError("Empty demo file.")
# Action sanity check — sequential outputs unit vectors.
a_norms = np.linalg.norm(actions, axis=1)
print(f"[bc] action L2 norm: mean={a_norms.mean():.3f} "
f"min={a_norms.min():.3f} max={a_norms.max():.3f}")
@@ -138,13 +120,11 @@ def main():
batch_size=args.batch_size, shuffle=False,
)
# --- Build model ---
net_arch_pi = [int(x) for x in args.net_arch.split(",")]
net_arch_vf = net_arch_pi[:]
# Auto-detect frame stacking from the demo file so a stacked-obs
# demo trains a stacked-obs policy without an extra CLI flag.
# Frame stack is inferred from the demo obs dim.
obs_dim = obs.shape[1]
from herding.obs import OBS_DIM as _SINGLE
from herding.perception.obs import OBS_DIM as _SINGLE
if obs_dim % _SINGLE != 0:
raise RuntimeError(f"demo obs dim {obs_dim} is not a multiple of {_SINGLE}")
frame_stack = obs_dim // _SINGLE
@@ -161,10 +141,7 @@ def main():
t_start = time.time()
best_val = float("inf")
best_cos = -1.0
# Snapshot the best-by-val_cos policy weights and restore at the end —
# training is noisy on multi-modal teachers (e.g. Strömbom collect/drive),
# so the last epoch is often worse than an earlier one.
best_state = None
best_state = None # restored at the end so noisy last epochs don't win
def combined_loss(pred, target):
mse = nn.functional.mse_loss(pred, target)
@@ -205,8 +182,6 @@ def main():
val_total += nn.functional.mse_loss(
mean_action, act_batch, reduction="sum",
).item()
# Cosine similarity in action space — useful sanity for
# "is the policy pointing the same way as the teacher?".
m_norm = mean_action.norm(dim=1).clamp_min(1e-6)
a_norm = act_batch.norm(dim=1).clamp_min(1e-6)
cos = (mean_action * act_batch).sum(dim=1) / (m_norm * a_norm)
+26 -38
View File
@@ -1,27 +1,19 @@
"""Evaluate a trained PPO policy (or the Strömbom baseline) on the env.
"""Env-side evaluation of analytic or learned policies.
Reports success rate and time-to-pen across a fixed seed grid for each
flock size 1..MAX_SHEEP. Used to produce the M5 quantitative comparison
table mentioned in plan.md.
Reports success rate, mean steps and mean penned per flock size for
``n_sheep ∈ 1..max_flock`` across ``--n-seeds`` seeds each.
Usage::
python -m training.eval --policy training/runs/latest/best
python -m training.eval --policy training/runs/rl --n-seeds 10
python -m training.eval --policy strombom
"""
from __future__ import annotations
import argparse
import os
import sys
from pathlib import Path
from statistics import mean, stdev
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
from statistics import mean
import numpy as np
@@ -33,40 +25,38 @@ from training.herding_env import HerdingEnv
def rollout(env: HerdingEnv, predict_fn, max_steps: int) -> dict:
obs, _ = env.reset()
success = False
for t in range(max_steps):
action = predict_fn(env, obs)
obs, _r, terminated, truncated, info = env.step(action)
if terminated or truncated:
success = bool(info.get("is_success", False))
return {"success": success, "steps": info.get("steps", t + 1),
"n_penned": info.get("n_penned", 0)}
return {"success": False, "steps": max_steps, "n_penned": int(env.sheep_penned.sum())}
return {
"success": bool(info.get("is_success", False)),
"steps": info.get("steps", t + 1),
"n_penned": info.get("n_penned", 0),
}
return {"success": False, "steps": max_steps,
"n_penned": int(env.sheep_penned.sum())}
def make_analytic_predictor(action_fn):
"""Wrap an analytic teacher so it runs on the env's exposed
perception (tracker in LiDAR mode, GT in privileged mode)."""
def _predict(env, _obs):
# Use whatever perception the env exposes — tracker output in
# LiDAR mode, ground truth in privileged mode. This makes
# evaluation honest: the analytic teacher sees what the
# deployed controller would see.
positions = env.perceived_positions()
vx, vy, _mode = action_fn((env.dog_x, env.dog_y), positions, PEN_ENTRY)
return np.array([vx, vy], dtype=np.float32)
return _predict
# Backwards-compat alias.
def make_strombom_predictor():
return make_analytic_predictor(strombom_action)
def make_policy_predictor(model, vecnorm):
def _predict(_env, obs):
obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1)
if vecnorm is not None:
obs_b = vecnorm.normalize_obs(np.asarray(obs, dtype=np.float32).reshape(1, -1))
else:
obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1)
obs_b = vecnorm.normalize_obs(obs_b)
action, _ = model.predict(obs_b, deterministic=True)
return action[0]
return _predict
@@ -75,16 +65,17 @@ def make_policy_predictor(model, vecnorm):
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--policy", required=True,
help="Either 'strombom' or path to an SB3 run directory.")
help="'strombom', 'sequential', or path to a "
"policy directory / zip.")
parser.add_argument("--n-seeds", type=int, default=10)
parser.add_argument("--max-steps", type=int, default=5000)
parser.add_argument("--max-flock", type=int, default=MAX_SHEEP)
# 1.0 = deployment distribution (sheep anywhere in field).
# Lower values use the training-curriculum spawn band (sheep near gate).
parser.add_argument("--difficulty", type=float, default=1.0)
parser.add_argument("--difficulty", type=float, default=1.0,
help="0 = sheep spawn near the gate (easy); "
"1 = full field (deployment distribution).")
args = parser.parse_args()
frame_stack = 1 # default; analytic predictors don't use stacked obs
frame_stack = 1
if args.policy == "strombom":
predict = make_analytic_predictor(strombom_action)
elif args.policy == "sequential":
@@ -92,23 +83,20 @@ def main():
else:
from stable_baselines3 import PPO
run = Path(args.policy)
# Resolve to a zip: directory of checkpoints, or a direct zip path.
if run.is_file():
zip_path = run
else:
for name in ("best_model.zip", "policy.zip", "final.zip"):
for name in ("policy.zip", "final.zip"):
if (run / name).exists():
zip_path = run / name
break
else:
raise FileNotFoundError(
f"No checkpoint found in {run} (tried best_model.zip, "
f"policy.zip, final.zip)"
f"No checkpoint found in {run} "
f"(tried policy.zip, final.zip)"
)
model = PPO.load(str(zip_path), device="auto")
# Auto-detect frame stacking from the policy's expected obs dim,
# so eval runs with whatever stacking the policy was trained on.
from herding.obs import OBS_DIM as _SINGLE
from herding.perception.obs import OBS_DIM as _SINGLE
policy_obs_dim = int(model.observation_space.shape[0])
if policy_obs_dim % _SINGLE == 0 and policy_obs_dim // _SINGLE >= 1:
frame_stack = policy_obs_dim // _SINGLE
+73 -179
View File
@@ -1,61 +1,30 @@
"""Gymnasium environment for the shepherd-dog herding task.
Single-agent: the agent is the dog. Sheep are environment-controlled
flocking agents whose dynamics are imported verbatim from
``herding.flocking_sim`` so a policy trained here transfers to Webots
without re-tuning. Differential-drive kinematics for both dog and sheep
match the proto specs (wheel radius, base, max wheel ω) via
``herding.diffdrive``.
Single-agent: the dog is the policy; sheep are env-controlled flocking
agents (``herding.world.flocking_sim``). Differential-drive kinematics
match the proto specs (``herding.world.diffdrive``) so a policy trained
here transfers to Webots without re-tuning.
Action space
------------
Box(-1, 1, (2,)) — the dog's desired (vx, vy) velocity *intent*. This
matches the high-level action representation the Webots controller
already uses; the env converts (vx, vy) → wheel speeds with the same
formula.
Observation space
-----------------
Box(-inf, inf, (28,)) — the order-invariant feature vector built by
``herding.obs.build_obs``. See ``herding/obs.py`` for the layout.
Reset
-----
``options["n_sheep"]`` (1..MAX_SHEEP) overrides the default flock size
for the next episode. If absent, flock size is sampled uniformly from
[1, max_n_sheep] each reset, where ``max_n_sheep`` can be raised over
training time by an outer callback.
Reward
------
Sparse + shaping (see :func:`HerdingEnv._compute_reward` for weights).
+2.0 per newly penned sheep
+0.5 · ΔCoM-distance-to-pen (positive when CoM moves closer)
+0.2 · ΔFlock-radius (positive when flock tightens)
-0.005 per step (encourages speed)
- wall and collision penalties
+10.0 terminal bonus when all sheep penned
* **Action**: ``Box(-1, 1, (2,))`` — desired ``(vx, vy)`` intent.
* **Observation**: ``Box(-inf, inf, (32·K,))`` from ``herding.perception.obs.build_obs``
with optional frame stacking K (concatenated oldest → newest).
* **Reset**: ``options["n_sheep"]`` overrides flock size; otherwise
sampled uniformly from ``[1, max_n_sheep]``.
* **Reward**: dense shaping (per-sheep distance progress, time
penalty, Strömbom-imitation cosine bonus) + sparse pen/done jackpots.
Weights live as class attributes on :class:`HerdingEnv`.
"""
from __future__ import annotations
import math
import os
import random
import sys
from typing import Optional
import gymnasium as gym
import numpy as np
from gymnasium import spaces
# Make herding/ importable when run from anywhere.
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
from herding.world.diffdrive import (
heading_speed_to_wheels, kinematics_step, velocity_to_wheels,
)
@@ -71,7 +40,7 @@ from herding.world.geometry import (
)
from herding.perception.lidar_perception import detections_from_scan
from herding.perception.lidar_sim import simulate_scan
from herding.obs import OBS_DIM, build_obs
from herding.perception.obs import OBS_DIM, build_obs
from herding.perception.sheep_tracker import SheepTracker
from herding.control.strombom import compute_action as strombom_action
@@ -85,45 +54,23 @@ class HerdingEnv(gym.Env):
metadata = {"render_modes": []}
# Reward shaping weights. Re-tuned after the first run got stuck at
# 0% success: progress reward must dominate the time penalty by a
# large margin, and the pen-event bonus must be big enough that PPO's
# advantage estimator can credit-assign across the long path that
# leads to it. Per-step shaping is bounded by the clamps inside
# _compute_reward.
# Drastically simplified after two runs got stuck farming a position
# bonus instead of penning sheep. Reward now is essentially:
# • huge jackpot for actually penning sheep (+100 per pen, +500 done)
# • small dense gradient: per-sheep mean distance to pen
# No position shaping (gameable), no compactness shaping (gameable),
# no engagement bonus (gameable). The terminal per-unpenned penalty
# forbids "good enough" partial herds.
# We have a working analytic baseline (Strömbom, 100 % on easy mode).
# Use it as a teacher: per-step bonus proportional to the cosine
# similarity between the policy's action and what Strömbom would do
# at the same state. This drags the policy out of "do nothing" local
# optima without locking it to the teacher — PPO can still find
# improvements over Strömbom because pen jackpots dominate.
W_PEN_DELTA = 100.0
W_PROGRESS = 20.0
W_IMITATE = 0.5 # per-step max ±0.5 (action cosine sim, [-1, 1])
W_TIME = 0.0
W_WALL = 0.0
W_COLLISION = 0.0
W_DONE = 500.0
# Reward weights. Sparse jackpots (W_PEN_DELTA, W_DONE) dominate;
# dense shaping (W_PROGRESS on Δ mean-distance-to-pen) provides the
# gradient; W_IMITATE adds a small cosine bonus toward the analytic
# teacher's action; W_TIME is a per-step penalty (0 by default).
W_PEN_DELTA = 100.0
W_PROGRESS = 20.0
W_IMITATE = 0.5
W_TIME = 0.0
W_WALL = 0.0
W_COLLISION = 0.0
W_DONE = 500.0
# Action smoothing during training: 0 = none. The Webots controller
# still applies its own EMA at inference for actuator stability, so
# the policy doesn't need to learn smoothness explicitly.
# In-env action EMA. 0 = none; the Webots controller applies its own
# EMA at inference, so the policy needn't learn smoothness.
ACTION_SMOOTH = 0.0
# Episode budget. ~80 s of sim time at dt=0.016. The new external-pen
# layout has paths up to ~28 m from spawn to pen entry; at sheep flee
# speed ~0.4 m/s, that's 70 s minimum. 3000 steps (48 s) was leaving
# the dog with no margin for collect-then-drive on multi-sheep cases.
DEFAULT_MAX_STEPS = 5000
# Distance under which the dog is considered "colliding" with a sheep.
COLLISION_DIST = 0.30
def __init__(
@@ -137,19 +84,15 @@ class HerdingEnv(gym.Env):
frame_stack: int = 1,
):
super().__init__()
# When True (default), the obs and the imitation-reward teacher
# see only LiDAR-perceived sheep positions through a tracker
# matching what the Webots controller has access to. When False,
# both consume ground-truth positions (legacy "privileged" mode,
# kept for ablation).
# ``use_lidar=True`` (default): obs and imitation-reward teacher
# see only LiDAR-perceived positions via a tracker, matching the
# Webots controller. ``False`` exposes ground truth for ablation.
self._use_lidar = bool(use_lidar)
self._tracker = SheepTracker() if self._use_lidar else None
self._np_rng_lidar: Optional[np.random.Generator] = None
# Frame stacking: the policy receives the last K single-frame
# observations concatenated. Lets a memoryless MLP integrate
# information across time, partly compensating for the limited
# LiDAR FOV. K=1 reproduces the legacy single-frame obs.
# Frame stacking: the policy receives the last K obs concatenated,
# giving a memoryless MLP temporal context. K=1 → single frame.
self._frame_stack = max(1, int(frame_stack))
self._frame_buffer: list[np.ndarray] = []
self.action_space = spaces.Box(-1.0, 1.0, shape=(2,), dtype=np.float32)
@@ -159,18 +102,16 @@ class HerdingEnv(gym.Env):
shape=(OBS_DIM * self._frame_stack,), dtype=np.float32,
)
# If n_sheep is None, env will sample uniformly from [1, max_n_sheep]
# on every reset — this is the default for curriculum-free training.
# n_sheep=None → sample uniformly from [1, max_n_sheep] each reset.
self._fixed_n_sheep = n_sheep
self._max_n_sheep = max_n_sheep
self.max_steps = max_steps
# difficulty ∈ [0, 1]: 0 = sheep spawn next to the gate (easy),
# 1 = sheep spawn anywhere in the field (hard, the deployment
# distribution). Curriculum bumps this from 0 → 1 over training.
# difficulty ∈ [0, 1]: 0 = sheep spawn near the gate (easy);
# 1 = sheep spawn anywhere in the field (deployment distribution).
self._difficulty = float(difficulty)
self._initial_seed = seed
# State (initialized in reset)
# State (initialised in reset)
self.dog_x = self.dog_y = self.dog_heading = 0.0
self.sheep_x = np.zeros(0, dtype=np.float32)
self.sheep_y = np.zeros(0, dtype=np.float32)
@@ -186,12 +127,10 @@ class HerdingEnv(gym.Env):
self.prev_d_pen = 0.0
self.prev_radius = 0.0
# Env-owned RNG for the flocking wander-jitter, seeded fresh on each
# reset so determinism is preserved without touching the global
# random module.
# Env-owned RNG for wander jitter, re-seeded from np_random on reset.
self._py_rng = random.Random()
# ---- public knobs (used by curriculum callback) ----
# --- Public knobs ---
def set_max_n_sheep(self, value: int) -> None:
self._max_n_sheep = int(np.clip(value, 1, MAX_SHEEP))
@@ -199,22 +138,18 @@ class HerdingEnv(gym.Env):
self._difficulty = float(np.clip(value, 0.0, 1.0))
def set_imitate_weight(self, value: float) -> None:
"""Override W_IMITATE (instance-level) — used to disable the
Strömbom imitation reward during BC fine-tuning, when the policy
already mimics a stronger teacher (sequential)."""
"""Override the instance W_IMITATE — used to disable Strömbom
imitation during PPO fine-tune."""
self.W_IMITATE = float(value)
def set_time_weight(self, value: float) -> None:
"""Override W_TIME (instance-level). Default 0.0; a small
negative value (e.g. -0.1) adds a per-step penalty that
explicitly rewards fast time-to-pen during PPO fine-tune."""
"""Override the instance W_TIME — set negative to penalise step
count and encourage faster time-to-pen during PPO fine-tune."""
self.W_TIME = float(value)
# ---- gym API ----
# --- gym API ---
def reset(self, *, seed=None, options=None):
super().reset(seed=seed)
# Re-seed the flocking RNG from np_random so flocking jitter is
# reproducible alongside everything else the env samples.
self._py_rng.seed(int(self.np_random.integers(0, 2**31 - 1)))
opts = options or {}
@@ -230,28 +165,26 @@ class HerdingEnv(gym.Env):
self.dog_y = float(self.np_random.uniform(-2.5, 2.5))
self.dog_heading = float(self.np_random.uniform(-math.pi, math.pi))
# Sheep spawn region scales with difficulty:
# 0.0 → narrow box just north of the gate (x ∈ [7, 14], y ∈ [-12, -6])
# 1.0 → full field (x ∈ [-13, 13], y ∈ [-12, 13])
# Linear interpolation between the two for intermediate values.
# Sheep spawn region linearly interpolates with difficulty:
# 0 → small box near the gate, 1 → full field.
d = self._difficulty
sx_lo = 7.0 - d * 20.0 # → -13 at d=1
sx_hi = 14.0 - d * 1.0 # → 13 at d=1
sy_lo = -12.0 + d * 0.0 # → -12 at d=1
sy_hi = -6.0 + d * 19.0 # → 13 at d=1
sx_lo = 7.0 - d * 20.0
sx_hi = 14.0 - d * 1.0
sy_lo = -12.0 + d * 0.0
sy_hi = -6.0 + d * 19.0
sxs, sys_, shs, sws = [], [], [], []
for _ in range(self.n_sheep):
for _try in range(100):
sx = float(self.np_random.uniform(sx_lo, sx_hi))
sy = float(self.np_random.uniform(sy_lo, sy_hi))
# Reject too close to dog or to other sheep.
# Reject if too close to the dog or another sheep, or
# already in the gate column (would start "penned").
if math.hypot(sx - self.dog_x, sy - self.dog_y) < 3.0:
continue
if any(math.hypot(sx - x, sy - y) < 1.5
for x, y in zip(sxs, sys_)):
continue
# Reject inside the gate column already (they'd start "penned").
if PEN_X[0] <= sx <= PEN_X[1] and sy < -8.0:
continue
break
@@ -275,10 +208,8 @@ class HerdingEnv(gym.Env):
self._tracker.reset()
self._np_rng_lidar = np.random.default_rng(
int(self.np_random.integers(0, 2**31 - 1)))
# Prime the tracker with one scan so the first obs isn't empty.
self._update_tracker()
# Clear the frame stack — the next _build_obs will repopulate.
self._frame_buffer = []
obs = self._build_obs()
@@ -288,7 +219,6 @@ class HerdingEnv(gym.Env):
def step(self, action):
action = np.clip(np.asarray(action, dtype=np.float32), -1.0, 1.0)
# EMA smoothing — the Webots controller does this too.
self.smoothed_action = (
self.ACTION_SMOOTH * self.prev_action
+ (1.0 - self.ACTION_SMOOTH) * action
@@ -296,12 +226,11 @@ class HerdingEnv(gym.Env):
self.prev_action = self.smoothed_action.copy()
vx, vy = float(self.smoothed_action[0]), float(self.smoothed_action[1])
# Safety supervisor mirrored from the controller — keeps the dog
# north of the gate so the policy can't strand itself in the pen.
# Safety supervisor — dog stays north of the gate.
if self.dog_y < DOG_SOUTH_LIMIT and vy < 0.0:
vx, vy = 0.0, 1.0
# --- Step the dog ---
# Step the dog.
wL, wR = velocity_to_wheels(
vx, vy, self.dog_heading,
max_linear=DOG_MAX_LINEAR,
@@ -313,27 +242,22 @@ class HerdingEnv(gym.Env):
self.dog_x, self.dog_y, self.dog_heading,
wL, wR, DOG_WHEEL_RADIUS, DOG_WHEEL_BASE, WEBOTS_DT,
)
# Clip dog to field bounds and out of pen — same as the Webots stone walls.
self.dog_x = float(np.clip(self.dog_x, FIELD_X[0] + 0.3, FIELD_X[1] - 0.3))
self.dog_y = float(np.clip(self.dog_y, DOG_SOUTH_LIMIT, FIELD_Y[1] - 0.3))
# --- Step each sheep ---
# Step sheep and update penned flags (GT-based).
for i in range(self.n_sheep):
self._step_one_sheep(i)
# --- Update penned state ---
for i in range(self.n_sheep):
if (not self.sheep_penned[i]
and is_penned_position(self.sheep_x[i], self.sheep_y[i])):
self.sheep_penned[i] = True
# --- Run LiDAR perception on this step's state (after sheep have
# moved). Updates the tracker that obs and the imitation-
# reward teacher consume. Reward / termination still use GT. ---
# LiDAR perception runs after sheep move; feeds the obs and the
# imitation reward. Reward/termination still use GT.
if self._tracker is not None:
self._update_tracker()
# --- Reward, termination ---
d_pen, radius = self._flock_metrics()
reward = self._compute_reward(d_pen, radius, action=action)
self.prev_d_pen = d_pen
@@ -346,12 +270,6 @@ class HerdingEnv(gym.Env):
truncated = self.steps >= self.max_steps
if all_penned:
reward += self.W_DONE
# No timeout penalty: a per-unpenned penalty made "do nothing"
# strictly preferable to noisy-random under reward-progress shaping
# (random sometimes pushes sheep away → negative progress, then
# always ate the timeout penalty), which collapsed exploration to
# tiny actions. The pen jackpot alone provides the directional
# signal once exploration is wide enough to find it.
obs = self._build_obs()
info = {
@@ -362,7 +280,7 @@ class HerdingEnv(gym.Env):
}
return obs, float(reward), terminated, truncated, info
# ---- internals ----
# --- Internals ---
def _step_one_sheep(self, i: int) -> None:
x, y = float(self.sheep_x[i]), float(self.sheep_y[i])
peers = [(float(self.sheep_x[j]), float(self.sheep_y[j]))
@@ -386,8 +304,7 @@ class HerdingEnv(gym.Env):
SHEEP_WHEEL_RADIUS, SHEEP_WHEEL_BASE, WEBOTS_DT,
)
# Wall clipping — matches Webots stone walls, except in the gate column
# where the south wall is absent.
# Wall clipping (south wall absent inside the gate column).
nx = float(np.clip(nx, FIELD_X[0] + 0.2, FIELD_X[1] - 0.2))
in_gate_col = PEN_X[0] <= nx <= PEN_X[1]
if in_gate_col:
@@ -400,12 +317,11 @@ class HerdingEnv(gym.Env):
self.sheep_h[i] = nh
def _flock_metrics(self):
"""(per-sheep mean distance to pen entry, max-radius).
"""Return (per-sheep mean distance to pen, max radius from CoM).
Using the per-sheep mean instead of CoM-distance ensures stragglers
keep contributing to the progress signal — the dog can't game the
shaping by herding the bulk of the flock and abandoning one
outlier (CoM moves toward pen, but mean-distance doesn't).
The per-sheep mean (not CoM distance) makes the progress signal
sensitive to stragglers: the dog can't game it by herding most of
the flock and abandoning one outlier.
"""
active_mask = ~self.sheep_penned
if not active_mask.any():
@@ -422,24 +338,14 @@ class HerdingEnv(gym.Env):
return d_pen, radius
def _compute_reward(self, d_pen: float, radius: float, action=None) -> float:
"""Sparse + per-sheep distance shaping + Strömbom imitation.
d_pen is the *mean* distance over active sheep, so progress only
accrues when ALL active sheep get closer to the pen on average —
the dog can't farm it by herding one sheep while ignoring others.
The imitation term is computed by querying Strömbom for the
recommended action at the *current* (post-step) state and
rewarding cosine similarity with what the policy actually did.
"""
"""Sparse pen jackpot + dense progress shaping + Strömbom imitation."""
n_penned = int(self.sheep_penned.sum())
delta_pen = n_penned - self.prev_n_penned
d_progress = max(-5.0, min(5.0, self.prev_d_pen - d_pen))
r = self.W_PEN_DELTA * delta_pen + self.W_PROGRESS * d_progress
# Per-step time penalty (0 by default). When negative, encourages
# the policy to finish quickly — used during PPO fine-tune.
r += self.W_TIME
r = (self.W_PEN_DELTA * delta_pen
+ self.W_PROGRESS * d_progress
+ self.W_TIME)
if action is not None and self.W_IMITATE > 0.0:
positions = self._perceived_positions()
@@ -457,10 +363,7 @@ class HerdingEnv(gym.Env):
def _build_single_obs(self) -> np.ndarray:
if self._tracker is not None:
# Obs sees only the tracker's active set; penned tracks are
# intentionally excluded (matches the prior receiver-based
# behaviour where penned sheep stopped contributing to the
# symbolic obs).
# LiDAR mode: the obs sees only the tracker's active set.
active = self._tracker.get_positions()
sheep_xy_list = list(active.values())
sheep_penned_list = [False] * len(sheep_xy_list)
@@ -477,22 +380,18 @@ class HerdingEnv(gym.Env):
single = self._build_single_obs()
if self._frame_stack <= 1:
return single
# On a fresh reset the buffer is empty — duplicate the first
# frame so the stack is always full-length.
# On reset the buffer is empty — pad with copies of the first frame.
if not self._frame_buffer:
self._frame_buffer = [single.copy() for _ in range(self._frame_stack)]
else:
self._frame_buffer.append(single)
if len(self._frame_buffer) > self._frame_stack:
self._frame_buffer = self._frame_buffer[-self._frame_stack:]
# Concatenate oldest → newest.
return np.concatenate(self._frame_buffer, axis=0).astype(np.float32)
# ------------------------------------------------------------------
# LiDAR perception helpers
# ------------------------------------------------------------------
# --- LiDAR perception ---
def _all_sheep_xy(self) -> list[tuple[float, float]]:
"""Every sheep, including penned ones (the LiDAR sees them)."""
"""Every sheep, including penned (the LiDAR sees them)."""
return [(float(self.sheep_x[i]), float(self.sheep_y[i]))
for i in range(self.n_sheep)]
@@ -508,19 +407,14 @@ class HerdingEnv(gym.Env):
self._tracker.update(detections)
def perceived_positions(self) -> dict[str, tuple[float, float]]:
"""Public accessor — what the controller would 'see' this step.
LiDAR mode → the tracker's active set.
Privileged mode → ground-truth active sheep.
Used by ``training.eval`` and ``tools.collect_demos`` so analytic
teachers run on the same perception the deployed controller has.
"""What the controller would "see" this step: tracker output in
LiDAR mode, ground truth in privileged mode. Used by demo
collection and analytic-policy eval so the teacher runs on the
same perception the deployed controller has.
"""
if self._tracker is not None:
return self._tracker.get_positions()
return {f"s{i}": (float(self.sheep_x[i]), float(self.sheep_y[i]))
for i in range(self.n_sheep) if not self.sheep_penned[i]}
# Internal alias so the imitation reward path doesn't need to know
# which mode it's in.
_perceived_positions = perceived_positions
+1
View File
@@ -6,3 +6,4 @@ numpy>=1.24
pyyaml>=6.0
tensorboard>=2.14
tqdm>=4.66
pytest>=8.0
View File
+57 -89
View File
@@ -1,30 +1,17 @@
"""KL-regularised PPO fine-tune of a behaviour-cloned policy.
The PPO-from-scratch and unregularised PPO-fine-tune-of-BC versions
we tried earlier failed for the standard reasons (sparse pen reward,
long horizons, exploration noise destroying BC weights). The fix is
to anchor the policy to its BC initialisation with a KL penalty in
the loss — the policy is free to refine the BC mean within a
trust-region-like ball around the reference, and the dense-enough
per-step reward signal does the rest.
The trainable policy is initialised from ``runs/bc/policy.zip``. A
frozen copy of the same weights becomes the reference; each PPO loss
gets an extra ``β · KL(π ‖ π_ref)`` term so the policy can only move
within a trust region around BC. ``log_std`` is fixed small to keep
exploration tight.
Pipeline
--------
1. Load ``bc`` weights into both the trainable policy and a frozen
reference ``ref_policy``.
2. Initialise the policy's log_std to a small fixed value (≈ -1.5)
and disable its gradient — exploration noise stays small so PPO
updates don't blow up the BC mean before reward can stabilise.
3. Override ``PPO.train()`` to add ``β · KL(π ‖ π_ref)`` to the loss
each minibatch.
4. Train for ~1–3 M timesteps with a low LR (5e-5).
Output: ``runs/rl/policy.zip`` — same SB3 format as bc, loadable
by the dog controller's ``HERDING_MODE=rl`` path.
Output: ``runs/rl/policy.zip`` — same SB3 format as the BC checkpoint,
loadable by ``HERDING_MODE=rl`` in the dog controller.
Usage::
python -m training.train_ppo \\
python -m training.rl.train \\
--bc training/runs/bc \\
--out training/runs/rl \\
--total-timesteps 2000000
@@ -33,15 +20,8 @@ Usage::
from __future__ import annotations
import argparse
import os
import sys
from pathlib import Path
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
import numpy as np
import torch as th
import torch.nn.functional as F
@@ -50,7 +30,7 @@ from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from herding.obs import OBS_DIM
from herding.perception.obs import OBS_DIM
from training.herding_env import HerdingEnv
@@ -73,15 +53,12 @@ def _make_env(rank: int, seed: int, frame_stack: int):
class KLPPO(PPO):
"""PPO with an extra KL-to-reference penalty in the policy loss.
Subclasses SB3's PPO and overrides ``train()`` only to add a single
line for the KL term — everything else (rollout buffer, clipped
surrogate, value loss, entropy bonus) is unchanged.
Overrides only ``train()``; rollout buffer, clipped surrogate, value
loss and entropy bonus are unchanged from stock SB3 PPO.
"""
def __init__(self, *args, ref_policy=None, kl_coef: float = 0.05, **kwargs):
super().__init__(*args, **kwargs)
# ref_policy is set after construction (caller can build it
# from the BC checkpoint once `self.policy` exists).
self.ref_policy = ref_policy
if self.ref_policy is not None:
self.ref_policy.set_training_mode(False)
@@ -90,9 +67,8 @@ class KLPPO(PPO):
self.kl_coef = kl_coef
def train(self) -> None:
# Copied from stable_baselines3.ppo.PPO.train (v2.x), with the
# KL-to-reference term added. Keeping the structure intact so
# behavioural parity with stock PPO is obvious.
# Stock SB3 PPO.train() structure with the KL-to-reference term
# added inside the inner minibatch loop.
self.policy.set_training_mode(True)
self._update_learning_rate(self.policy.optimizer)
clip_range = self.clip_range(self._current_progress_remaining)
@@ -139,12 +115,8 @@ class KLPPO(PPO):
entropy_loss = -th.mean(entropy)
entropy_losses.append(entropy_loss.item())
# --- KL-to-reference term ----------------------------
# Both policies are diagonal Gaussian (ActorCriticPolicy).
# KL(π ‖ π_ref) per-action-dim; sum over the action axis
# to get total KL per sample, then mean over batch.
# Computed on the rollout's observations so the penalty
# reflects what the agent actually saw.
# KL-to-reference: closed-form KL between two diagonal
# Gaussians, summed over the action axis, mean over batch.
if self.ref_policy is None:
raise RuntimeError("KLPPO.train called without ref_policy")
with th.no_grad():
@@ -153,7 +125,6 @@ class KLPPO(PPO):
kl_div = th.distributions.kl.kl_divergence(
pi_dist.distribution, ref_dist.distribution).sum(dim=-1).mean()
kl_losses.append(kl_div.item())
# ----------------------------------------------------
loss = (policy_loss
+ self.ent_coef * entropy_loss
@@ -192,7 +163,6 @@ class KLPPO(PPO):
self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
def _explained_variance(self) -> float:
# SB3 doesn't expose this as a method; replicate the computation.
y_pred = self.rollout_buffer.values.flatten()
y_true = self.rollout_buffer.returns.flatten()
var_y = np.var(y_true)
@@ -206,50 +176,41 @@ class KLPPO(PPO):
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--bc", default="training/runs/bc",
help="Directory containing the BC initialisation (policy.zip).")
help="Directory containing the BC initialisation.")
parser.add_argument("--out", default="training/runs/rl",
help="Where to save the fine-tuned policy.")
parser.add_argument("--total-timesteps", type=int, default=2_000_000)
parser.add_argument("--n-envs", type=int, default=8)
parser.add_argument("--learning-rate", type=float, default=5e-5,
help="Low LR keeps PPO close to the BC mean.")
parser.add_argument("--learning-rate", type=float, default=5e-5)
parser.add_argument("--kl-coef", type=float, default=0.05,
help="KL-to-reference penalty coefficient.")
help="Coefficient of the KL-to-reference penalty.")
parser.add_argument("--log-std", type=float, default=-1.5,
help="Initial (and frozen) log_std. σ ≈ exp(-1.5) ≈ 0.22.")
parser.add_argument("--freeze-log-std", action="store_true", default=True,
help="Keep log_std fixed; only the policy mean updates.")
parser.add_argument("--n-steps", type=int, default=2048,
help="Steps per rollout per env.")
help="Initial (and frozen) log_std for exploration.")
parser.add_argument("--freeze-log-std", action="store_true", default=True)
parser.add_argument("--n-steps", type=int, default=2048)
parser.add_argument("--batch-size", type=int, default=256)
parser.add_argument("--n-epochs", type=int, default=10)
parser.add_argument("--gamma", type=float, default=0.995)
parser.add_argument("--gae-lambda", type=float, default=0.95)
parser.add_argument("--clip-range", type=float, default=0.1,
help="Tight clip range — keep updates conservative.")
parser.add_argument("--clip-range", type=float, default=0.1)
parser.add_argument("--ent-coef", type=float, default=0.0)
parser.add_argument("--target-kl", type=float, default=0.02,
help="SB3's per-batch KL early stop; safety belt.")
help="SB3 per-batch KL early-stop guard.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--device", default="cpu")
parser.add_argument("--imitate-weight", type=float, default=None,
help="Override env.W_IMITATE for this training "
"run. Set to 0.0 to drop the Strömbom "
"cosine-imitation reward — useful during "
"PPO refinement where you want reward, "
"not teacher imitation, to drive updates.")
help="Override env.W_IMITATE (e.g. 0.0 to drop "
"Strömbom imitation during fine-tune).")
parser.add_argument("--time-weight", type=float, default=None,
help="Override env.W_TIME. Default env value is "
"0.0; setting e.g. -0.1 adds a small per-"
"step penalty that explicitly rewards "
"fast time-to-pen.")
help="Override env.W_TIME (e.g. -0.1 for a "
"per-step time penalty).")
args = parser.parse_args()
bc_zip = Path(args.bc) / "policy.zip"
if not bc_zip.exists():
raise SystemExit(
f"BC checkpoint not found at {bc_zip}. Train bc first with "
f"`python -m training.bc_pretrain`."
f"`python -m training.bc.pretrain`."
)
out = Path(args.out)
@@ -257,7 +218,7 @@ def main() -> None:
(out / "checkpoints").mkdir(exist_ok=True)
(out / "best").mkdir(exist_ok=True)
# --- Inspect BC obs dim → infer frame_stack ---
# Infer frame_stack from the BC checkpoint's obs space.
ref_only = PPO.load(str(bc_zip), device=args.device)
obs_dim = int(ref_only.observation_space.shape[0])
if obs_dim % OBS_DIM != 0:
@@ -265,12 +226,11 @@ def main() -> None:
frame_stack = obs_dim // OBS_DIM
print(f"[rl] BC obs dim {obs_dim} → frame_stack={frame_stack}")
# --- Vectorised envs (match BC obs space) ---
env_fns = [_make_env(i, args.seed, frame_stack) for i in range(args.n_envs)]
venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns)
eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, frame_stack)])
# --- Apply reward-shaping overrides to every env instance ---
# Reward-shaping overrides (broadcast to every env instance).
def _broadcast(method: str, value):
for v in (venv, eval_venv):
try:
@@ -284,10 +244,8 @@ def main() -> None:
_broadcast("set_time_weight", args.time_weight)
print(f"[rl] W_TIME overridden to {args.time_weight}")
# --- Trainable policy: load BC weights, then bolt onto PPO ---
# Trick: instantiate a PPO with the right env (so the policy
# network is constructed at the correct obs/action shape), then
# copy BC weights into it.
# Build a fresh KLPPO at the right obs/action shape, then copy BC
# weights into both the trainable policy and the frozen reference.
model = KLPPO(
"MlpPolicy", venv,
ref_policy=None, # filled in below
@@ -311,15 +269,11 @@ def main() -> None:
tensorboard_log=str(out / "tb"),
)
# --- Load BC weights into both `model.policy` and `ref_policy` ---
# strict=False — the BC value head wasn't trained; PPO trains it.
bc_state = ref_only.policy.state_dict()
# Strict=False because the value head may not have been trained in
# BC — that's fine, PPO will train it from scratch.
missing, unexpected = model.policy.load_state_dict(bc_state, strict=False)
print(f"[rl] BC → policy: missing={len(missing)} unexpected={len(unexpected)}")
# Build a separate reference policy with identical architecture and
# the BC weights, frozen.
ref_policy = type(model.policy)(
observation_space=model.observation_space,
action_space=model.action_space,
@@ -333,11 +287,8 @@ def main() -> None:
for p in model.ref_policy.parameters():
p.requires_grad = False
# Align both policies' log_std. BC was trained with log_std≈0.5
# (σ≈1.65), which would make the KL term huge from a std mismatch
# rather than the mean drift we actually care about. Force both to
# the same small value so KL measures only how far the policy mean
# has drifted from the BC mean.
# Force both policies to the same log_std so the KL term measures
# mean drift only, not a std mismatch carried over from BC.
with th.no_grad():
model.policy.log_std.fill_(args.log_std)
model.ref_policy.log_std.fill_(args.log_std)
@@ -345,15 +296,18 @@ def main() -> None:
model.policy.log_std.requires_grad = False
print(f"[rl] log_std frozen at {args.log_std} (σ{np.exp(args.log_std):.3f})")
# --- Callbacks ---
ckpt_cb = CheckpointCallback(
save_freq=max(1, 50_000 // args.n_envs),
save_path=str(out / "checkpoints"),
name_prefix="ppo",
)
# EvalCallback writes <save_path>/best_model.zip on every new best
# eval reward. We send it straight to ``out/`` and rename to
# ``policy.zip`` after training so the deployed file lives at the
# canonical path.
eval_cb = EvalCallback(
eval_venv,
best_model_save_path=str(out / "best"),
best_model_save_path=str(out),
log_path=str(out / "evals"),
eval_freq=max(1, 20_000 // args.n_envs),
n_eval_episodes=5,
@@ -365,9 +319,23 @@ def main() -> None:
model.learn(total_timesteps=args.total_timesteps,
callback=[ckpt_cb, eval_cb], progress_bar=True)
# --- Save final checkpoint in the SB3 zip the controller expects ---
model.save(out / "policy.zip")
print(f"[rl] saved fine-tuned policy → {out/'policy.zip'}")
# Save the end-of-training state for debugging convergence behaviour.
model.save(out / "final.zip")
# Promote the EvalCallback's best-by-eval-reward snapshot to the
# canonical ``policy.zip`` (what the controller loads). Fall back
# to the final state if eval never recorded a "best".
import shutil
best_zip = out / "best_model.zip"
policy_zip = out / "policy.zip"
if best_zip.exists():
if policy_zip.exists():
policy_zip.unlink()
best_zip.rename(policy_zip)
print(f"[rl] best snapshot → {policy_zip} (final state kept at {out/'final.zip'})")
else:
shutil.copy(out / "final.zip", policy_zip)
print(f"[rl] no best snapshot recorded; using final → {policy_zip}")
if __name__ == "__main__":