Checkpoint 4

This commit is contained in:
Johnny Fernandes
2026-05-11 00:42:52 +01:00
parent 2a6db038df
commit 6688325d89
26 changed files with 2018 additions and 503 deletions
+70 -54
View File
@@ -1,21 +1,29 @@
# Training pipeline
Behavior cloning of analytic herding teachers into a neural network
policy that runs in Webots. PPO from scratch and PPO fine-tune of BC
were tried earlier and are kept under `train_ppo.py` as experimental
options, but the BC route alone is what we ship.
Behavior cloning of analytic herding teachers into a neural-network
policy that runs under LiDAR perception in Webots.
```
sim demos (active-scan teacher on tracker output, K=4 frame stack)
bc_pretrain.py ──► runs/bc_v3 (deployed policy — beats Strömbom on n≥8)
▼ (optional: tools/auto_dagger.sh + tools/dagger_merge_train.py
│ if sim-trained doesn't transfer cleanly to Webots)
runs/bc_dagger
```
## Files
```
herding_env.py — Gymnasium env (used for demo collection + eval)
bc_pretrain.py — supervised MSE+cosine training of an SB3 MlpPolicy
against (obs, action) demos
herding_env.py — Gymnasium env (LiDAR raycast + tracker by default)
bc_pretrain.py — MSE + cosine BC of (obs, action) demos into MlpPolicy
eval.py — analytic teachers + BC policies, full n=1..10 grid
parity_test.py — shape/determinism/baseline smoke test
train_ppo.py — PPO trainer (experimental — see Appendix below)
configs/ — PPO hyperparameter YAML
runs/ — checkpoints (.gitignored)
parity_test.py — shape / determinism / baseline smoke test
runs/ — checkpoints (most are .gitignored; the deployed
ones are whitelisted in the top-level .gitignore)
```
## Setup
@@ -31,66 +39,74 @@ rollout collection, not gradient compute.
## The BC pipeline
```
# 1. Generate demos from an analytic teacher.
# --teacher: strombom (default), sequential, drive_only, hybrid, strombom_smooth
# 1. Sim demos with the active-scan + Strömbom teacher under LiDAR
# perception. K=4 frame stack so the MLP has temporal context.
python -m tools.collect_demos --teacher strombom \
--out demos.npz --seeds-per-n 30 --subsample 3
--out demos_v3.npz --seeds-per-n 15 --subsample 3 --frame-stack 4
# 2. Behavior-clone the demos into an MLP policy.
python -m training.bc_pretrain --demos demos.npz \
--out runs/bc_flock --epochs 100 --net-arch 512,512
# 2. Behavior-clone.
python -m training.bc_pretrain --demos demos_v3.npz \
--out runs/bc_v3 --epochs 60 --net-arch 512,512
# 3. Evaluate the resulting policy.
python -m training.eval --policy runs/bc_flock \
--max-flock 10 --max-steps 30000 --n-seeds 5
# 3. Evaluate.
python -m training.eval --policy runs/bc_v3 \
--max-flock 10 --max-steps 8000 --n-seeds 5
```
Wall time: ~10 min demos + ~5 min BC training + ~5 min eval.
`bc_pretrain.py` saves the **best-val_cos** snapshot, not the final
epoch — multi-modal teachers (Strömbom's collect/drive switch) make
training noisy and the last epoch is often worse than an earlier one.
epoch — multi-modal teachers make training noisy and the last epoch is
often worse than an earlier one.
## DAgger from Webots
Sim-only BC plateaus because the env's 2D raycast can't reproduce all
the false-positive clusters Webots generates from real geometry. The
fix is to collect (obs, teacher_action) pairs from inside Webots:
```
# Headless DAgger collection: 5 flock sizes × 3 runs each.
tools/auto_dagger.sh 3 60
# Merge with the sim baseline + retrain.
python -m tools.dagger_merge_train --out runs/bc_dagger
```
Iterate by re-running collection with the new student in the driver's
seat:
```
HERDING_POLICY_DIR=$PWD/training/runs/bc_dagger \
HERDING_DAGGER_DRIVER=student \
tools/auto_dagger.sh 3 60
python -m tools.dagger_merge_train --out runs/bc_dagger_v2
```
## Available analytic teachers
| Name | What it does | Best for |
| Name | What it does | Notes |
|---|---|---|
| `strombom` | Canonical Strömbom — collect when flock is scattered, drive CoM otherwise | Tight-cohesion regime, n=1-10 |
| `sequential` | Pick the sheep closest to the pen and drive only it | Loose-cohesion regime, n=1-10 |
| `drive_only` | Strömbom drive without collect mode (continuous action) | Easier-to-BC alternative; less reliable than full Strömbom |
| `hybrid` | Drive rearmost sheep when far, switch to closest near gate | Failed experiment, kept for write-up |
| `strombom_smooth` | Sigmoid-blended Strömbom collect↔drive | Failed experiment |
| `strombom` | Canonical Strömbom — collect when flock is scattered, drive CoM otherwise | Default; works well for n=110 under tight cohesion |
| `sequential` | Pick the sheep closest to the pen and drive only it | Alternative; needs loose-cohesion regime |
## Evaluating the analytic teachers directly
Both are wrapped at demo-collection time in
`herding/active_scan.py:ActiveScanTeacher`, which adds an opening
in-place rotation, walk-to-centre when the LiDAR sees nothing, and
near-sheep speed modulation (the same modulation `herding/control.py`
applies to every dog mode at inference).
## Evaluating analytic teachers directly
```
python -m training.eval --policy strombom --max-flock 10 --max-steps 30000 --n-seeds 5
python -m training.eval --policy sequential --max-flock 10 --max-steps 30000 --n-seeds 5
python -m training.eval --policy strombom --max-flock 10 --max-steps 8000 --n-seeds 5
python -m training.eval --policy sequential --max-flock 10 --max-steps 8000 --n-seeds 5
```
## Webots inference
The Webots dog controller (`controllers/shepherd_dog/shepherd_dog.py`)
loads a saved BC zip when launched in `rl` mode:
```
HERDING_POLICY_DIR=$PWD/runs/bc_flock tools/run_webots.sh 10 rl
tools/run_webots.sh 10 rl
```
It auto-discovers a checkpoint named `policy.zip`, `best_model.zip`, or
`final.zip` in the directory.
## Appendix — experimental PPO scripts
`train_ppo.py` contains the PPO/RL pipeline tried before BC:
* PPO from scratch with curriculum learning over flock size + spawn area.
* PPO fine-tune of a BC checkpoint.
Both ran into stability issues (PPO's exploration noise destroys BC
weights faster than the reward signal can rebuild them; PPO from
scratch never sees pen events often enough during random exploration to
credit-assign the +500 done bonus).
The script is left in place because the abstractions are sound and the
code is reusable for follow-up work (e.g. KL-regularised fine-tune
with a frozen reference policy). Not part of the deliverable pipeline.
The dog controller loads the highest-priority policy that exists
(`bc_dagger_v2``bc_dagger``bc_v3`). Override with
`HERDING_POLICY_DIR=…` if you want a specific checkpoint.
+18 -6
View File
@@ -43,13 +43,15 @@ from stable_baselines3.common.vec_env import DummyVecEnv
from training.herding_env import HerdingEnv
def build_model(net_arch_pi, net_arch_vf, log_std_init: float):
"""Build a fresh SB3 PPO with the same architecture as train_ppo.
def build_model(net_arch_pi, net_arch_vf, log_std_init: float,
frame_stack: int = 1):
"""Build a fresh SB3 PPO solely as a vehicle for the policy weights.
We only need the policy to load weights into; PPO's training-loop
plumbing isn't used during BC.
PPO's training-loop plumbing isn't used during BC. ``frame_stack``
must match the demo file so the env's obs space agrees with the
recorded obs shape.
"""
env = DummyVecEnv([lambda: HerdingEnv()])
env = DummyVecEnv([lambda: HerdingEnv(frame_stack=frame_stack)])
model = PPO(
"MlpPolicy", env,
policy_kwargs=dict(
@@ -139,7 +141,17 @@ def main():
# --- Build model ---
net_arch_pi = [int(x) for x in args.net_arch.split(",")]
net_arch_vf = net_arch_pi[:]
model, _env = build_model(net_arch_pi, net_arch_vf, args.log_std_init)
# Auto-detect frame stacking from the demo file so a stacked-obs
# demo trains a stacked-obs policy without an extra CLI flag.
obs_dim = obs.shape[1]
from herding.obs import OBS_DIM as _SINGLE
if obs_dim % _SINGLE != 0:
raise RuntimeError(f"demo obs dim {obs_dim} is not a multiple of {_SINGLE}")
frame_stack = obs_dim // _SINGLE
if frame_stack > 1:
print(f"[bc] inferred frame_stack={frame_stack} from demo obs dim {obs_dim}")
model, _env = build_model(net_arch_pi, net_arch_vf, args.log_std_init,
frame_stack=frame_stack)
policy = model.policy.to(args.device)
optimizer = optim.Adam(policy.parameters(), lr=args.lr)
-52
View File
@@ -1,52 +0,0 @@
# PPO hyperparameters for the herding env. Tuned for a 28-D obs / 2-D
# continuous action space with 16 parallel envs on GPU. These are SB3
# defaults nudged toward longer credit assignment (gamma=0.995) and a
# slightly higher entropy bonus to keep exploration alive while curriculum
# expands the flock size.
# --- PPO ---
learning_rate: 3.0e-4
n_steps: 2048 # rollout length per env before each update
batch_size: 256
n_epochs: 10
gamma: 0.995
gae_lambda: 0.95
clip_range: 0.2
ent_coef: 0.05 # was 0.01 — earlier runs collapsed to ~0 actions
vf_coef: 0.5
max_grad_norm: 0.5
target_kl: null # disable early-stop on KL
# --- Network ---
policy: MlpPolicy
net_arch_pi: [128, 128]
net_arch_vf: [128, 128]
log_std_init: 0.5 # std≈1.6 instead of default 1.0 — more exploration
# --- Training schedule ---
total_timesteps: 10_000_000
n_envs: 16
checkpoint_freq: 500_000 # in env steps
eval_freq: 100_000 # in env steps
n_eval_episodes: 20
# --- Curriculum (max-n_sheep schedule, in env steps) ---
# Each entry: at step s, raise the env's max_n_sheep to k. The env samples
# uniformly from [1, max_n_sheep] each reset, so this widens the
# distribution gradually rather than swapping fixed sizes.
#
# State-space curriculum: difficulty controls sheep spawn area
# (0 = sheep spawn just north of gate, 1 = sheep spawn anywhere in field).
# Plus the existing flock-size curriculum.
#
# The two together let the policy first learn "what penning looks like"
# in a regime where random exploration reliably triggers it, then
# gradually generalise to the deployment distribution.
curriculum:
- { step: 0, max_n_sheep: 1, difficulty: 0.0 }
- { step: 1_000_000, max_n_sheep: 1, difficulty: 0.3 }
- { step: 2_000_000, max_n_sheep: 2, difficulty: 0.5 }
- { step: 4_000_000, max_n_sheep: 3, difficulty: 0.8 }
- { step: 6_000_000, max_n_sheep: 5, difficulty: 1.0 }
- { step: 8_000_000, max_n_sheep: 8, difficulty: 1.0 }
- { step: 9_000_000, max_n_sheep: 10, difficulty: 1.0 }
+16 -4
View File
@@ -46,9 +46,11 @@ def rollout(env: HerdingEnv, predict_fn, max_steps: int) -> dict:
def make_analytic_predictor(action_fn):
def _predict(env, _obs):
positions = {f"s{i}": (float(env.sheep_x[i]), float(env.sheep_y[i]))
for i in range(env.n_sheep)
if not env.sheep_penned[i]}
# Use whatever perception the env exposes — tracker output in
# LiDAR mode, ground truth in privileged mode. This makes
# evaluation honest: the analytic teacher sees what the
# deployed controller would see.
positions = env.perceived_positions()
vx, vy, _mode = action_fn((env.dog_x, env.dog_y), positions, PEN_ENTRY)
return np.array([vx, vy], dtype=np.float32)
return _predict
@@ -82,6 +84,7 @@ def main():
parser.add_argument("--difficulty", type=float, default=1.0)
args = parser.parse_args()
frame_stack = 1 # default; analytic predictors don't use stacked obs
if args.policy == "strombom":
predict = make_analytic_predictor(strombom_action)
elif args.policy == "sequential":
@@ -103,6 +106,14 @@ def main():
f"policy.zip, final.zip)"
)
model = PPO.load(str(zip_path), device="auto")
# Auto-detect frame stacking from the policy's expected obs dim,
# so eval runs with whatever stacking the policy was trained on.
from herding.obs import OBS_DIM as _SINGLE
policy_obs_dim = int(model.observation_space.shape[0])
if policy_obs_dim % _SINGLE == 0 and policy_obs_dim // _SINGLE >= 1:
frame_stack = policy_obs_dim // _SINGLE
if frame_stack > 1:
print(f"[eval] policy expects frame_stack={frame_stack}")
vecnorm = None
vn_path = run / "vecnormalize.pkl"
if not vn_path.exists() and run.parent.name != "best":
@@ -121,7 +132,8 @@ def main():
successes, steps, penned = [], [], []
for seed in range(args.n_seeds):
env = HerdingEnv(n_sheep=n, max_steps=args.max_steps,
difficulty=args.difficulty, seed=seed)
difficulty=args.difficulty, seed=seed,
frame_stack=frame_stack)
r = rollout(env, predict, args.max_steps)
successes.append(int(r["success"]))
steps.append(r["steps"])
+104 -8
View File
@@ -69,7 +69,10 @@ from herding.geometry import (
SHEEP_MAX_WHEEL_OMEGA, SHEEP_WHEEL_BASE, SHEEP_WHEEL_RADIUS,
WEBOTS_DT, is_penned_position,
)
from herding.lidar_perception import detections_from_scan
from herding.lidar_sim import simulate_scan
from herding.obs import OBS_DIM, build_obs
from herding.sheep_tracker import SheepTracker
from herding.strombom import compute_action as strombom_action
@@ -130,11 +133,30 @@ class HerdingEnv(gym.Env):
max_steps: int = DEFAULT_MAX_STEPS,
difficulty: float = 0.0,
seed: Optional[int] = None,
use_lidar: bool = True,
frame_stack: int = 1,
):
super().__init__()
# When True (default), the obs and the imitation-reward teacher
# see only LiDAR-perceived sheep positions through a tracker —
# matching what the Webots controller has access to. When False,
# both consume ground-truth positions (legacy "privileged" mode,
# kept for ablation).
self._use_lidar = bool(use_lidar)
self._tracker = SheepTracker() if self._use_lidar else None
self._np_rng_lidar: Optional[np.random.Generator] = None
# Frame stacking: the policy receives the last K single-frame
# observations concatenated. Lets a memoryless MLP integrate
# information across time, partly compensating for the limited
# LiDAR FOV. K=1 reproduces the legacy single-frame obs.
self._frame_stack = max(1, int(frame_stack))
self._frame_buffer: list[np.ndarray] = []
self.action_space = spaces.Box(-1.0, 1.0, shape=(2,), dtype=np.float32)
self._single_obs_dim = OBS_DIM
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(OBS_DIM,), dtype=np.float32,
low=-np.inf, high=np.inf,
shape=(OBS_DIM * self._frame_stack,), dtype=np.float32,
)
# If n_sheep is None, env will sample uniformly from [1, max_n_sheep]
@@ -243,6 +265,16 @@ class HerdingEnv(gym.Env):
self.prev_n_penned = 0
self.prev_d_pen, self.prev_radius = self._flock_metrics()
if self._tracker is not None:
self._tracker.reset()
self._np_rng_lidar = np.random.default_rng(
int(self.np_random.integers(0, 2**31 - 1)))
# Prime the tracker with one scan so the first obs isn't empty.
self._update_tracker()
# Clear the frame stack — the next _build_obs will repopulate.
self._frame_buffer = []
obs = self._build_obs()
info = {"n_sheep": self.n_sheep}
return obs, info
@@ -289,6 +321,12 @@ class HerdingEnv(gym.Env):
and is_penned_position(self.sheep_x[i], self.sheep_y[i])):
self.sheep_penned[i] = True
# --- Run LiDAR perception on this step's state (after sheep have
# moved). Updates the tracker that obs and the imitation-
# reward teacher consume. Reward / termination still use GT. ---
if self._tracker is not None:
self._update_tracker()
# --- Reward, termination ---
d_pen, radius = self._flock_metrics()
reward = self._compute_reward(d_pen, radius, action=action)
@@ -395,10 +433,7 @@ class HerdingEnv(gym.Env):
r = self.W_PEN_DELTA * delta_pen + self.W_PROGRESS * d_progress
if action is not None and self.W_IMITATE > 0.0:
positions = {
f"s{i}": (float(self.sheep_x[i]), float(self.sheep_y[i]))
for i in range(self.n_sheep) if not self.sheep_penned[i]
}
positions = self._perceived_positions()
if positions:
sx, sy, _mode = strombom_action(
(self.dog_x, self.dog_y), positions, PEN_ENTRY,
@@ -411,11 +446,72 @@ class HerdingEnv(gym.Env):
return float(r)
def _build_obs(self) -> np.ndarray:
sheep_xy_list = list(zip(self.sheep_x.tolist(), self.sheep_y.tolist()))
sheep_penned_list = self.sheep_penned.tolist()
def _build_single_obs(self) -> np.ndarray:
if self._tracker is not None:
# Obs sees only the tracker's active set; penned tracks are
# intentionally excluded (matches the prior receiver-based
# behaviour where penned sheep stopped contributing to the
# symbolic obs).
active = self._tracker.get_positions()
sheep_xy_list = list(active.values())
sheep_penned_list = [False] * len(sheep_xy_list)
else:
sheep_xy_list = list(zip(self.sheep_x.tolist(), self.sheep_y.tolist()))
sheep_penned_list = self.sheep_penned.tolist()
return build_obs(
(self.dog_x, self.dog_y), self.dog_heading,
sheep_xy_list, sheep_penned_list,
n_max=self._max_n_sheep,
)
def _build_obs(self) -> np.ndarray:
single = self._build_single_obs()
if self._frame_stack <= 1:
return single
# On a fresh reset the buffer is empty — duplicate the first
# frame so the stack is always full-length.
if not self._frame_buffer:
self._frame_buffer = [single.copy() for _ in range(self._frame_stack)]
else:
self._frame_buffer.append(single)
if len(self._frame_buffer) > self._frame_stack:
self._frame_buffer = self._frame_buffer[-self._frame_stack:]
# Concatenate oldest → newest.
return np.concatenate(self._frame_buffer, axis=0).astype(np.float32)
# ------------------------------------------------------------------
# LiDAR perception helpers
# ------------------------------------------------------------------
def _all_sheep_xy(self) -> list[tuple[float, float]]:
"""Every sheep, including penned ones (the LiDAR sees them)."""
return [(float(self.sheep_x[i]), float(self.sheep_y[i]))
for i in range(self.n_sheep)]
def _update_tracker(self) -> None:
ranges = simulate_scan(
self.dog_x, self.dog_y, self.dog_heading,
self._all_sheep_xy(),
rng=self._np_rng_lidar,
)
detections = detections_from_scan(
ranges, self.dog_x, self.dog_y, self.dog_heading,
)
self._tracker.update(detections)
def perceived_positions(self) -> dict[str, tuple[float, float]]:
"""Public accessor — what the controller would 'see' this step.
LiDAR mode → the tracker's active set.
Privileged mode → ground-truth active sheep.
Used by ``training.eval`` and ``tools.collect_demos`` so analytic
teachers run on the same perception the deployed controller has.
"""
if self._tracker is not None:
return self._tracker.get_positions()
return {f"s{i}": (float(self.sheep_x[i]), float(self.sheep_y[i]))
for i in range(self.n_sheep) if not self.sheep_penned[i]}
# Internal alias so the imitation reward path doesn't need to know
# which mode it's in.
_perceived_positions = perceived_positions
Binary file not shown.
Binary file not shown.
Binary file not shown.
+275 -206
View File
@@ -1,31 +1,33 @@
"""PPO trainer for the shepherd-dog policy — EXPERIMENTAL.
"""KL-regularised PPO fine-tune of a behaviour-cloned policy.
The deliverable pipeline is `bc_pretrain.py` (see ``training/README.md``).
This script is kept in the tree because it implements:
The PPO-from-scratch and unregularised PPO-fine-tune-of-BC versions
we tried earlier failed for the standard reasons (sparse pen reward,
long horizons, exploration noise destroying BC weights). The fix is
to anchor the policy to its BC initialisation with a KL penalty in
the loss — the policy is free to refine the BC mean within a
trust-region-like ball around the reference, and the dense-enough
per-step reward signal does the rest.
* PPO from scratch with curriculum over flock size + spawn area, and
* PPO fine-tune of a behavior-cloned policy.
Pipeline
--------
1. Load ``bc_v3`` weights into both the trainable policy and a frozen
reference ``ref_policy``.
2. Initialise the policy's log_std to a small fixed value (≈ 1.5)
and disable its gradient — exploration noise stays small so PPO
updates don't blow up the BC mean before reward can stabilise.
3. Override ``PPO.train()`` to add ``β · KL(π ‖ π_ref)`` to the loss
each minibatch.
4. Train for ~13 M timesteps with a low LR (5e-5).
Both ran into stability issues in our setting (long-horizon credit
assignment for sparse pen reward, BC-degradation under PPO exploration
noise). The abstractions are reusable for follow-up work — e.g.
KL-regularised fine-tune with a frozen reference policy — so we leave
the code in place.
Output: ``runs/rl_v1/policy.zip`` — same SB3 format as bc_v3, loadable
by the dog controller's ``HERDING_MODE=rl`` path.
Usage (PPO from scratch)::
Usage::
python -m training.train_ppo \
--config training/configs/ppo_default.yaml \
--out-dir training/runs/ppo_scratch
Usage (PPO fine-tune of BC)::
python -m training.train_ppo \
--resume training/runs/bc_flock/policy.zip \
--out-dir training/runs/bc_ppo \
--no-vecnorm --no-curriculum --imitate-weight 0 \
--difficulty 1.0 --log-std -1.5 --learning-rate 5e-5 \
--total-timesteps 3000000
python -m training.train_ppo \\
--bc training/runs/bc_v3 \\
--out training/runs/rl_v1 \\
--total-timesteps 2000000
"""
from __future__ import annotations
@@ -35,8 +37,6 @@ import os
import sys
from pathlib import Path
import yaml
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
if _PROJECT_ROOT not in sys.path:
@@ -44,236 +44,305 @@ if _PROJECT_ROOT not in sys.path:
import numpy as np
import torch as th
import torch.nn.functional as F
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import (
BaseCallback, CheckpointCallback, EvalCallback,
)
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import (
DummyVecEnv, SubprocVecEnv, VecNormalize,
)
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from herding.obs import OBS_DIM
from training.herding_env import HerdingEnv
# --------------------------------------------------------------------------
# Env factories
# --------------------------------------------------------------------------
# --------------------------------------------------------------------
# Env factory
# --------------------------------------------------------------------
def _make_env(rank: int, seed: int = 0):
def _make_env(rank: int, seed: int, frame_stack: int):
def _thunk():
env = HerdingEnv(seed=seed + rank)
env = HerdingEnv(seed=seed + rank, frame_stack=frame_stack)
env = Monitor(env, info_keywords=("is_success", "n_sheep", "n_penned"))
return env
return _thunk
# --------------------------------------------------------------------------
# Curriculum callback
# --------------------------------------------------------------------------
# --------------------------------------------------------------------
# KL-regularised PPO
# --------------------------------------------------------------------
class CurriculumCallback(BaseCallback):
"""Drive the env's flock-size + state-space difficulty curriculum.
class KLPPO(PPO):
"""PPO with an extra KL-to-reference penalty in the policy loss.
Schedule entries: {step, max_n_sheep, difficulty}. The largest entry
whose step <= num_timesteps wins; both knobs update together.
Subclasses SB3's PPO and overrides ``train()`` only to add a single
line for the KL term — everything else (rollout buffer, clipped
surrogate, value loss, entropy bonus) is unchanged.
"""
def __init__(self, schedule, vec_envs, verbose: int = 0):
super().__init__(verbose)
self.schedule = sorted(schedule, key=lambda d: d["step"])
# Accept a list of envs so the eval env tracks training difficulty.
self.vec_envs = vec_envs if isinstance(vec_envs, (list, tuple)) else [vec_envs]
self._last_n = None
self._last_d = None
def __init__(self, *args, ref_policy=None, kl_coef: float = 0.05, **kwargs):
super().__init__(*args, **kwargs)
# ref_policy is set after construction (caller can build it
# from the BC checkpoint once `self.policy` exists).
self.ref_policy = ref_policy
if self.ref_policy is not None:
self.ref_policy.set_training_mode(False)
for p in self.ref_policy.parameters():
p.requires_grad = False
self.kl_coef = kl_coef
def _call(self, method, value):
for v in self.vec_envs:
try:
v.env_method(method, value)
except AttributeError:
v.venv.env_method(method, value)
def train(self) -> None:
# Copied from stable_baselines3.ppo.PPO.train (v2.x), with the
# KL-to-reference term added. Keeping the structure intact so
# behavioural parity with stock PPO is obvious.
self.policy.set_training_mode(True)
self._update_learning_rate(self.policy.optimizer)
clip_range = self.clip_range(self._current_progress_remaining)
if self.clip_range_vf is not None:
clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
def _on_step(self) -> bool:
t = self.num_timesteps
n = self.schedule[0]["max_n_sheep"]
d = self.schedule[0].get("difficulty", 1.0)
for entry in self.schedule:
if t >= entry["step"]:
n = entry["max_n_sheep"]
d = entry.get("difficulty", 1.0)
if n != self._last_n:
self._call("set_max_n_sheep", n)
self._last_n = n
if d != self._last_d:
self._call("set_difficulty", d)
self._last_d = d
if self.verbose:
print(f"[curriculum] t={t} → max_n_sheep={n} difficulty={d}")
return True
entropy_losses, pg_losses, value_losses, kl_losses = [], [], [], []
clip_fractions = []
continue_training = True
for epoch in range(self.n_epochs):
approx_kl_divs = []
for rollout_data in self.rollout_buffer.get(self.batch_size):
actions = rollout_data.actions
if isinstance(self.action_space, th.distributions.Categorical.__bases__):
actions = rollout_data.actions.long().flatten()
values, log_prob, entropy = self.policy.evaluate_actions(
rollout_data.observations, actions)
values = values.flatten()
advantages = rollout_data.advantages
if self.normalize_advantage and len(advantages) > 1:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
ratio = th.exp(log_prob - rollout_data.old_log_prob)
policy_loss_1 = advantages * ratio
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
pg_losses.append(policy_loss.item())
clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
clip_fractions.append(clip_fraction)
if self.clip_range_vf is None:
values_pred = values
else:
values_pred = rollout_data.old_values + th.clamp(
values - rollout_data.old_values, -clip_range_vf, clip_range_vf)
value_loss = F.mse_loss(rollout_data.returns, values_pred)
value_losses.append(value_loss.item())
if entropy is None:
entropy_loss = -th.mean(-log_prob)
else:
entropy_loss = -th.mean(entropy)
entropy_losses.append(entropy_loss.item())
# --- KL-to-reference term ----------------------------
# Both policies are diagonal Gaussian (ActorCriticPolicy).
# KL(π ‖ π_ref) per-action-dim; sum over the action axis
# to get total KL per sample, then mean over batch.
# Computed on the rollout's observations so the penalty
# reflects what the agent actually saw.
if self.ref_policy is None:
raise RuntimeError("KLPPO.train called without ref_policy")
with th.no_grad():
ref_dist = self.ref_policy.get_distribution(rollout_data.observations)
pi_dist = self.policy.get_distribution(rollout_data.observations)
kl_div = th.distributions.kl.kl_divergence(
pi_dist.distribution, ref_dist.distribution).sum(dim=-1).mean()
kl_losses.append(kl_div.item())
# ----------------------------------------------------
loss = (policy_loss
+ self.ent_coef * entropy_loss
+ self.vf_coef * value_loss
+ self.kl_coef * kl_div)
with th.no_grad():
log_ratio = log_prob - rollout_data.old_log_prob
approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
approx_kl_divs.append(approx_kl_div)
if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
continue_training = False
if self.verbose >= 1:
print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
break
self.policy.optimizer.zero_grad()
loss.backward()
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()
self._n_updates += 1
if not continue_training:
break
explained_var = self._explained_variance()
self.logger.record("train/entropy_loss", float(np.mean(entropy_losses)))
self.logger.record("train/policy_gradient_loss", float(np.mean(pg_losses)))
self.logger.record("train/value_loss", float(np.mean(value_losses)))
self.logger.record("train/kl_to_reference", float(np.mean(kl_losses)))
self.logger.record("train/approx_kl", float(np.mean(approx_kl_divs)))
self.logger.record("train/clip_fraction", float(np.mean(clip_fractions)))
self.logger.record("train/explained_variance", float(explained_var))
if hasattr(self.policy, "log_std"):
self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
def _explained_variance(self) -> float:
# SB3 doesn't expose this as a method; replicate the computation.
y_pred = self.rollout_buffer.values.flatten()
y_true = self.rollout_buffer.returns.flatten()
var_y = np.var(y_true)
return float("nan") if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
# --------------------------------------------------------------------------
# --------------------------------------------------------------------
# Main
# --------------------------------------------------------------------------
# --------------------------------------------------------------------
def main():
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--config", default=os.path.join(_HERE, "configs", "ppo_default.yaml"))
parser.add_argument("--out-dir", default=os.path.join(_HERE, "runs", "latest"))
parser.add_argument("--n-envs", type=int, default=None,
help="Override config n_envs.")
parser.add_argument("--total-timesteps", type=int, default=None,
help="Override config total_timesteps.")
parser.add_argument("--bc", default="training/runs/bc_v3",
help="Directory containing the BC initialisation (policy.zip).")
parser.add_argument("--out", default="training/runs/rl_v1",
help="Where to save the fine-tuned policy.")
parser.add_argument("--total-timesteps", type=int, default=2_000_000)
parser.add_argument("--n-envs", type=int, default=8)
parser.add_argument("--learning-rate", type=float, default=5e-5,
help="Low LR keeps PPO close to the BC mean.")
parser.add_argument("--kl-coef", type=float, default=0.05,
help="KL-to-reference penalty coefficient.")
parser.add_argument("--log-std", type=float, default=-1.5,
help="Initial (and frozen) log_std. σ ≈ exp(-1.5) ≈ 0.22.")
parser.add_argument("--freeze-log-std", action="store_true", default=True,
help="Keep log_std fixed; only the policy mean updates.")
parser.add_argument("--n-steps", type=int, default=2048,
help="Steps per rollout per env.")
parser.add_argument("--batch-size", type=int, default=256)
parser.add_argument("--n-epochs", type=int, default=10)
parser.add_argument("--gamma", type=float, default=0.995)
parser.add_argument("--gae-lambda", type=float, default=0.95)
parser.add_argument("--clip-range", type=float, default=0.1,
help="Tight clip range — keep updates conservative.")
parser.add_argument("--ent-coef", type=float, default=0.0)
parser.add_argument("--target-kl", type=float, default=0.02,
help="SB3's per-batch KL early stop; safety belt.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--resume", type=str, default=None,
help="Path to a SB3 zip to resume from.")
# SB3 recommends CPU for MlpPolicy — GPU helps CNN policies, not MLPs
# of this size. Override with --device cuda if you really want it.
parser.add_argument("--device", default="cpu")
parser.add_argument("--no-vecnorm", action="store_true",
help="Disable VecNormalize wrapper. Required when "
"resuming from a BC-pretrained policy that "
"wasn't trained under it.")
parser.add_argument("--no-curriculum", action="store_true",
help="Skip curriculum callback (resumed policy is "
"already competent across the distribution).")
parser.add_argument("--imitate-weight", type=float, default=None,
help="Override env W_IMITATE. Set to 0 to disable "
"Strömbom imitation reward.")
parser.add_argument("--difficulty", type=float, default=None,
help="Override env difficulty (0=easy, 1=hard). "
"Used in BC fine-tune to skip easy curriculum.")
parser.add_argument("--log-std", type=float, default=None,
help="Override the policy's log_std after load. "
"BC trained with std≈1.6 (log_std=0.5) which "
"is too noisy for fine-tune. Use -1.5 (std≈0.22) "
"to keep PPO close to the BC mean while still "
"exploring locally.")
parser.add_argument("--learning-rate", type=float, default=None,
help="Override config learning rate. For BC "
"fine-tune, 5e-5 is much safer than the 3e-4 "
"default.")
args = parser.parse_args()
with open(args.config) as f:
cfg = yaml.safe_load(f)
bc_zip = Path(args.bc) / "policy.zip"
if not bc_zip.exists():
raise SystemExit(
f"BC checkpoint not found at {bc_zip}. Train bc_v3 first with "
f"`python -m training.bc_pretrain`."
)
n_envs = args.n_envs or cfg["n_envs"]
total_timesteps = args.total_timesteps or cfg["total_timesteps"]
out = Path(args.out_dir)
out = Path(args.out)
out.mkdir(parents=True, exist_ok=True)
(out / "checkpoints").mkdir(exist_ok=True)
(out / "best").mkdir(exist_ok=True)
(out / "evals").mkdir(exist_ok=True)
print(f"[train] out={out} n_envs={n_envs} total={total_timesteps} device={args.device}")
# --- Inspect BC obs dim → infer frame_stack ---
ref_only = PPO.load(str(bc_zip), device=args.device)
obs_dim = int(ref_only.observation_space.shape[0])
if obs_dim % OBS_DIM != 0:
raise SystemExit(f"BC obs dim {obs_dim} is not a multiple of {OBS_DIM}.")
frame_stack = obs_dim // OBS_DIM
print(f"[rl] BC obs dim {obs_dim} → frame_stack={frame_stack}")
# --- Train env (vectorised, optionally normalised) ---
env_fns = [_make_env(i, seed=args.seed) for i in range(n_envs)]
venv = SubprocVecEnv(env_fns) if n_envs > 1 else DummyVecEnv(env_fns)
eval_venv = DummyVecEnv([_make_env(99, seed=args.seed + 999)])
if not args.no_vecnorm:
venv = VecNormalize(venv, norm_obs=True, norm_reward=False, clip_obs=10.0)
eval_venv = VecNormalize(eval_venv, norm_obs=True, norm_reward=False,
clip_obs=10.0, training=False)
eval_venv.obs_rms = venv.obs_rms
else:
print("[train] VecNormalize disabled (resumed policy was trained without it).")
# --- Vectorised envs (match BC obs space) ---
env_fns = [_make_env(i, args.seed, frame_stack) for i in range(args.n_envs)]
venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns)
eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, frame_stack)])
# Apply env-level overrides (used by BC fine-tune to disable Strömbom
# imitation and start at full deployment difficulty).
def _env_call(method, value):
for v in (venv, eval_venv):
try:
v.env_method(method, value)
except AttributeError:
v.venv.env_method(method, value)
if args.imitate_weight is not None:
_env_call("set_imitate_weight", args.imitate_weight)
print(f"[train] W_IMITATE overridden to {args.imitate_weight}")
if args.difficulty is not None:
_env_call("set_difficulty", args.difficulty)
print(f"[train] difficulty pinned to {args.difficulty}")
# --- Model ---
policy_kwargs = dict(
net_arch=dict(pi=cfg["net_arch_pi"], vf=cfg["net_arch_vf"]),
log_std_init=cfg.get("log_std_init", 0.0),
# --- Trainable policy: load BC weights, then bolt onto PPO ---
# Trick: instantiate a PPO with the right env (so the policy
# network is constructed at the correct obs/action shape), then
# copy BC weights into it.
model = KLPPO(
"MlpPolicy", venv,
ref_policy=None, # filled in below
kl_coef=args.kl_coef,
learning_rate=args.learning_rate,
n_steps=args.n_steps,
batch_size=args.batch_size,
n_epochs=args.n_epochs,
gamma=args.gamma,
gae_lambda=args.gae_lambda,
clip_range=args.clip_range,
ent_coef=args.ent_coef,
target_kl=args.target_kl,
policy_kwargs=dict(
net_arch=dict(pi=[512, 512], vf=[512, 512]),
log_std_init=args.log_std,
),
verbose=1,
seed=args.seed,
device=args.device,
tensorboard_log=str(out / "tb"),
)
if args.resume:
print(f"[train] resuming from {args.resume}")
custom_objects = {}
if args.learning_rate is not None:
custom_objects["learning_rate"] = args.learning_rate
model = PPO.load(args.resume, env=venv, device=args.device,
tensorboard_log=str(out / "tb"),
custom_objects=custom_objects or None)
if args.log_std is not None:
import torch as _th
with _th.no_grad():
model.policy.log_std.fill_(args.log_std)
print(f"[train] log_std overridden to {args.log_std} "
f"(std≈{2.71828 ** args.log_std:.2f})")
if args.learning_rate is not None:
print(f"[train] learning_rate overridden to {args.learning_rate}")
else:
model = PPO(
cfg["policy"], venv,
learning_rate=cfg["learning_rate"],
n_steps=cfg["n_steps"],
batch_size=cfg["batch_size"],
n_epochs=cfg["n_epochs"],
gamma=cfg["gamma"],
gae_lambda=cfg["gae_lambda"],
clip_range=cfg["clip_range"],
ent_coef=cfg["ent_coef"],
vf_coef=cfg["vf_coef"],
max_grad_norm=cfg["max_grad_norm"],
target_kl=cfg.get("target_kl"),
policy_kwargs=policy_kwargs,
tensorboard_log=str(out / "tb"),
seed=args.seed,
device=args.device,
verbose=1,
)
# --- Load BC weights into both `model.policy` and `ref_policy` ---
bc_state = ref_only.policy.state_dict()
# Strict=False because the value head may not have been trained in
# BC — that's fine, PPO will train it from scratch.
missing, unexpected = model.policy.load_state_dict(bc_state, strict=False)
print(f"[rl] BC → policy: missing={len(missing)} unexpected={len(unexpected)}")
# Build a separate reference policy with identical architecture and
# the BC weights, frozen.
ref_policy = type(model.policy)(
observation_space=model.observation_space,
action_space=model.action_space,
lr_schedule=lambda _: 0.0,
net_arch=dict(pi=[512, 512], vf=[512, 512]),
log_std_init=args.log_std,
).to(args.device)
ref_policy.load_state_dict(bc_state, strict=False)
model.ref_policy = ref_policy
model.ref_policy.set_training_mode(False)
for p in model.ref_policy.parameters():
p.requires_grad = False
# Align both policies' log_std. BC was trained with log_std≈0.5
# (σ≈1.65), which would make the KL term huge from a std mismatch
# rather than the mean drift we actually care about. Force both to
# the same small value so KL measures only how far the policy mean
# has drifted from the BC mean.
with th.no_grad():
model.policy.log_std.fill_(args.log_std)
model.ref_policy.log_std.fill_(args.log_std)
if args.freeze_log_std:
model.policy.log_std.requires_grad = False
print(f"[rl] log_std frozen at {args.log_std} (σ{np.exp(args.log_std):.3f})")
# --- Callbacks ---
ckpt_cb = CheckpointCallback(
save_freq=max(1, cfg["checkpoint_freq"] // n_envs),
save_path=str(out / "checkpoints"), name_prefix="ppo",
save_vecnormalize=True,
save_freq=max(1, 50_000 // args.n_envs),
save_path=str(out / "checkpoints"),
name_prefix="ppo",
)
eval_cb = EvalCallback(
eval_venv,
best_model_save_path=str(out / "best"),
log_path=str(out / "evals"),
eval_freq=max(1, cfg["eval_freq"] // n_envs),
n_eval_episodes=cfg["n_eval_episodes"],
eval_freq=max(1, 20_000 // args.n_envs),
n_eval_episodes=5,
deterministic=True,
)
callbacks = [ckpt_cb, eval_cb]
if not args.no_curriculum and "curriculum" in cfg and cfg["curriculum"]:
callbacks.append(CurriculumCallback(
cfg["curriculum"], [venv, eval_venv], verbose=1,
))
elif args.no_curriculum:
print("[train] curriculum disabled — env knobs left at their current values.")
# --- Train ---
model.learn(total_timesteps=total_timesteps, callback=callbacks,
progress_bar=True)
print(f"[rl] training: total_timesteps={args.total_timesteps} "
f"n_envs={args.n_envs} lr={args.learning_rate} kl_coef={args.kl_coef}")
model.learn(total_timesteps=args.total_timesteps,
callback=[ckpt_cb, eval_cb], progress_bar=True)
# --- Save final model + VecNormalize stats ---
model.save(out / "final.zip")
venv.save(str(out / "vecnormalize.pkl"))
# The EvalCallback already wrote best_model.zip into out/best/ — drop the
# VecNormalize stats next to it for the controller to pick up.
venv.save(str(out / "best" / "vecnormalize.pkl"))
print(f"[train] done. saved to {out}")
# --- Save final checkpoint in the SB3 zip the controller expects ---
model.save(out / "policy.zip")
print(f"[rl] saved fine-tuned policy → {out/'policy.zip'}")
if __name__ == "__main__":