Checkpoint 4
This commit is contained in:
+70
-54
@@ -1,21 +1,29 @@
|
||||
# Training pipeline
|
||||
|
||||
Behavior cloning of analytic herding teachers into a neural network
|
||||
policy that runs in Webots. PPO from scratch and PPO fine-tune of BC
|
||||
were tried earlier and are kept under `train_ppo.py` as experimental
|
||||
options, but the BC route alone is what we ship.
|
||||
Behavior cloning of analytic herding teachers into a neural-network
|
||||
policy that runs under LiDAR perception in Webots.
|
||||
|
||||
```
|
||||
sim demos (active-scan teacher on tracker output, K=4 frame stack)
|
||||
│
|
||||
▼
|
||||
bc_pretrain.py ──► runs/bc_v3 (deployed policy — beats Strömbom on n≥8)
|
||||
│
|
||||
▼ (optional: tools/auto_dagger.sh + tools/dagger_merge_train.py
|
||||
│ if sim-trained doesn't transfer cleanly to Webots)
|
||||
│
|
||||
runs/bc_dagger
|
||||
```
|
||||
|
||||
## Files
|
||||
|
||||
```
|
||||
herding_env.py — Gymnasium env (used for demo collection + eval)
|
||||
bc_pretrain.py — supervised MSE+cosine training of an SB3 MlpPolicy
|
||||
against (obs, action) demos
|
||||
herding_env.py — Gymnasium env (LiDAR raycast + tracker by default)
|
||||
bc_pretrain.py — MSE + cosine BC of (obs, action) demos into MlpPolicy
|
||||
eval.py — analytic teachers + BC policies, full n=1..10 grid
|
||||
parity_test.py — shape/determinism/baseline smoke test
|
||||
train_ppo.py — PPO trainer (experimental — see Appendix below)
|
||||
configs/ — PPO hyperparameter YAML
|
||||
runs/ — checkpoints (.gitignored)
|
||||
parity_test.py — shape / determinism / baseline smoke test
|
||||
runs/ — checkpoints (most are .gitignored; the deployed
|
||||
ones are whitelisted in the top-level .gitignore)
|
||||
```
|
||||
|
||||
## Setup
|
||||
@@ -31,66 +39,74 @@ rollout collection, not gradient compute.
|
||||
## The BC pipeline
|
||||
|
||||
```
|
||||
# 1. Generate demos from an analytic teacher.
|
||||
# --teacher: strombom (default), sequential, drive_only, hybrid, strombom_smooth
|
||||
# 1. Sim demos with the active-scan + Strömbom teacher under LiDAR
|
||||
# perception. K=4 frame stack so the MLP has temporal context.
|
||||
python -m tools.collect_demos --teacher strombom \
|
||||
--out demos.npz --seeds-per-n 30 --subsample 3
|
||||
--out demos_v3.npz --seeds-per-n 15 --subsample 3 --frame-stack 4
|
||||
|
||||
# 2. Behavior-clone the demos into an MLP policy.
|
||||
python -m training.bc_pretrain --demos demos.npz \
|
||||
--out runs/bc_flock --epochs 100 --net-arch 512,512
|
||||
# 2. Behavior-clone.
|
||||
python -m training.bc_pretrain --demos demos_v3.npz \
|
||||
--out runs/bc_v3 --epochs 60 --net-arch 512,512
|
||||
|
||||
# 3. Evaluate the resulting policy.
|
||||
python -m training.eval --policy runs/bc_flock \
|
||||
--max-flock 10 --max-steps 30000 --n-seeds 5
|
||||
# 3. Evaluate.
|
||||
python -m training.eval --policy runs/bc_v3 \
|
||||
--max-flock 10 --max-steps 8000 --n-seeds 5
|
||||
```
|
||||
|
||||
Wall time: ~10 min demos + ~5 min BC training + ~5 min eval.
|
||||
|
||||
`bc_pretrain.py` saves the **best-val_cos** snapshot, not the final
|
||||
epoch — multi-modal teachers (Strömbom's collect/drive switch) make
|
||||
training noisy and the last epoch is often worse than an earlier one.
|
||||
epoch — multi-modal teachers make training noisy and the last epoch is
|
||||
often worse than an earlier one.
|
||||
|
||||
## DAgger from Webots
|
||||
|
||||
Sim-only BC plateaus because the env's 2D raycast can't reproduce all
|
||||
the false-positive clusters Webots generates from real geometry. The
|
||||
fix is to collect (obs, teacher_action) pairs from inside Webots:
|
||||
|
||||
```
|
||||
# Headless DAgger collection: 5 flock sizes × 3 runs each.
|
||||
tools/auto_dagger.sh 3 60
|
||||
|
||||
# Merge with the sim baseline + retrain.
|
||||
python -m tools.dagger_merge_train --out runs/bc_dagger
|
||||
```
|
||||
|
||||
Iterate by re-running collection with the new student in the driver's
|
||||
seat:
|
||||
|
||||
```
|
||||
HERDING_POLICY_DIR=$PWD/training/runs/bc_dagger \
|
||||
HERDING_DAGGER_DRIVER=student \
|
||||
tools/auto_dagger.sh 3 60
|
||||
python -m tools.dagger_merge_train --out runs/bc_dagger_v2
|
||||
```
|
||||
|
||||
## Available analytic teachers
|
||||
|
||||
| Name | What it does | Best for |
|
||||
| Name | What it does | Notes |
|
||||
|---|---|---|
|
||||
| `strombom` | Canonical Strömbom — collect when flock is scattered, drive CoM otherwise | Tight-cohesion regime, n=1-10 |
|
||||
| `sequential` | Pick the sheep closest to the pen and drive only it | Loose-cohesion regime, n=1-10 |
|
||||
| `drive_only` | Strömbom drive without collect mode (continuous action) | Easier-to-BC alternative; less reliable than full Strömbom |
|
||||
| `hybrid` | Drive rearmost sheep when far, switch to closest near gate | Failed experiment, kept for write-up |
|
||||
| `strombom_smooth` | Sigmoid-blended Strömbom collect↔drive | Failed experiment |
|
||||
| `strombom` | Canonical Strömbom — collect when flock is scattered, drive CoM otherwise | Default; works well for n=1–10 under tight cohesion |
|
||||
| `sequential` | Pick the sheep closest to the pen and drive only it | Alternative; needs loose-cohesion regime |
|
||||
|
||||
## Evaluating the analytic teachers directly
|
||||
Both are wrapped at demo-collection time in
|
||||
`herding/active_scan.py:ActiveScanTeacher`, which adds an opening
|
||||
in-place rotation, walk-to-centre when the LiDAR sees nothing, and
|
||||
near-sheep speed modulation (the same modulation `herding/control.py`
|
||||
applies to every dog mode at inference).
|
||||
|
||||
## Evaluating analytic teachers directly
|
||||
|
||||
```
|
||||
python -m training.eval --policy strombom --max-flock 10 --max-steps 30000 --n-seeds 5
|
||||
python -m training.eval --policy sequential --max-flock 10 --max-steps 30000 --n-seeds 5
|
||||
python -m training.eval --policy strombom --max-flock 10 --max-steps 8000 --n-seeds 5
|
||||
python -m training.eval --policy sequential --max-flock 10 --max-steps 8000 --n-seeds 5
|
||||
```
|
||||
|
||||
## Webots inference
|
||||
|
||||
The Webots dog controller (`controllers/shepherd_dog/shepherd_dog.py`)
|
||||
loads a saved BC zip when launched in `rl` mode:
|
||||
|
||||
```
|
||||
HERDING_POLICY_DIR=$PWD/runs/bc_flock tools/run_webots.sh 10 rl
|
||||
tools/run_webots.sh 10 rl
|
||||
```
|
||||
|
||||
It auto-discovers a checkpoint named `policy.zip`, `best_model.zip`, or
|
||||
`final.zip` in the directory.
|
||||
|
||||
## Appendix — experimental PPO scripts
|
||||
|
||||
`train_ppo.py` contains the PPO/RL pipeline tried before BC:
|
||||
* PPO from scratch with curriculum learning over flock size + spawn area.
|
||||
* PPO fine-tune of a BC checkpoint.
|
||||
|
||||
Both ran into stability issues (PPO's exploration noise destroys BC
|
||||
weights faster than the reward signal can rebuild them; PPO from
|
||||
scratch never sees pen events often enough during random exploration to
|
||||
credit-assign the +500 done bonus).
|
||||
|
||||
The script is left in place because the abstractions are sound and the
|
||||
code is reusable for follow-up work (e.g. KL-regularised fine-tune
|
||||
with a frozen reference policy). Not part of the deliverable pipeline.
|
||||
The dog controller loads the highest-priority policy that exists
|
||||
(`bc_dagger_v2` → `bc_dagger` → `bc_v3`). Override with
|
||||
`HERDING_POLICY_DIR=…` if you want a specific checkpoint.
|
||||
|
||||
+18
-6
@@ -43,13 +43,15 @@ from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
from training.herding_env import HerdingEnv
|
||||
|
||||
|
||||
def build_model(net_arch_pi, net_arch_vf, log_std_init: float):
|
||||
"""Build a fresh SB3 PPO with the same architecture as train_ppo.
|
||||
def build_model(net_arch_pi, net_arch_vf, log_std_init: float,
|
||||
frame_stack: int = 1):
|
||||
"""Build a fresh SB3 PPO solely as a vehicle for the policy weights.
|
||||
|
||||
We only need the policy to load weights into; PPO's training-loop
|
||||
plumbing isn't used during BC.
|
||||
PPO's training-loop plumbing isn't used during BC. ``frame_stack``
|
||||
must match the demo file so the env's obs space agrees with the
|
||||
recorded obs shape.
|
||||
"""
|
||||
env = DummyVecEnv([lambda: HerdingEnv()])
|
||||
env = DummyVecEnv([lambda: HerdingEnv(frame_stack=frame_stack)])
|
||||
model = PPO(
|
||||
"MlpPolicy", env,
|
||||
policy_kwargs=dict(
|
||||
@@ -139,7 +141,17 @@ def main():
|
||||
# --- Build model ---
|
||||
net_arch_pi = [int(x) for x in args.net_arch.split(",")]
|
||||
net_arch_vf = net_arch_pi[:]
|
||||
model, _env = build_model(net_arch_pi, net_arch_vf, args.log_std_init)
|
||||
# Auto-detect frame stacking from the demo file so a stacked-obs
|
||||
# demo trains a stacked-obs policy without an extra CLI flag.
|
||||
obs_dim = obs.shape[1]
|
||||
from herding.obs import OBS_DIM as _SINGLE
|
||||
if obs_dim % _SINGLE != 0:
|
||||
raise RuntimeError(f"demo obs dim {obs_dim} is not a multiple of {_SINGLE}")
|
||||
frame_stack = obs_dim // _SINGLE
|
||||
if frame_stack > 1:
|
||||
print(f"[bc] inferred frame_stack={frame_stack} from demo obs dim {obs_dim}")
|
||||
model, _env = build_model(net_arch_pi, net_arch_vf, args.log_std_init,
|
||||
frame_stack=frame_stack)
|
||||
policy = model.policy.to(args.device)
|
||||
optimizer = optim.Adam(policy.parameters(), lr=args.lr)
|
||||
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
# PPO hyperparameters for the herding env. Tuned for a 28-D obs / 2-D
|
||||
# continuous action space with 16 parallel envs on GPU. These are SB3
|
||||
# defaults nudged toward longer credit assignment (gamma=0.995) and a
|
||||
# slightly higher entropy bonus to keep exploration alive while curriculum
|
||||
# expands the flock size.
|
||||
|
||||
# --- PPO ---
|
||||
learning_rate: 3.0e-4
|
||||
n_steps: 2048 # rollout length per env before each update
|
||||
batch_size: 256
|
||||
n_epochs: 10
|
||||
gamma: 0.995
|
||||
gae_lambda: 0.95
|
||||
clip_range: 0.2
|
||||
ent_coef: 0.05 # was 0.01 — earlier runs collapsed to ~0 actions
|
||||
vf_coef: 0.5
|
||||
max_grad_norm: 0.5
|
||||
target_kl: null # disable early-stop on KL
|
||||
|
||||
# --- Network ---
|
||||
policy: MlpPolicy
|
||||
net_arch_pi: [128, 128]
|
||||
net_arch_vf: [128, 128]
|
||||
log_std_init: 0.5 # std≈1.6 instead of default 1.0 — more exploration
|
||||
|
||||
# --- Training schedule ---
|
||||
total_timesteps: 10_000_000
|
||||
n_envs: 16
|
||||
checkpoint_freq: 500_000 # in env steps
|
||||
eval_freq: 100_000 # in env steps
|
||||
n_eval_episodes: 20
|
||||
|
||||
# --- Curriculum (max-n_sheep schedule, in env steps) ---
|
||||
# Each entry: at step s, raise the env's max_n_sheep to k. The env samples
|
||||
# uniformly from [1, max_n_sheep] each reset, so this widens the
|
||||
# distribution gradually rather than swapping fixed sizes.
|
||||
#
|
||||
# State-space curriculum: difficulty controls sheep spawn area
|
||||
# (0 = sheep spawn just north of gate, 1 = sheep spawn anywhere in field).
|
||||
# Plus the existing flock-size curriculum.
|
||||
#
|
||||
# The two together let the policy first learn "what penning looks like"
|
||||
# in a regime where random exploration reliably triggers it, then
|
||||
# gradually generalise to the deployment distribution.
|
||||
curriculum:
|
||||
- { step: 0, max_n_sheep: 1, difficulty: 0.0 }
|
||||
- { step: 1_000_000, max_n_sheep: 1, difficulty: 0.3 }
|
||||
- { step: 2_000_000, max_n_sheep: 2, difficulty: 0.5 }
|
||||
- { step: 4_000_000, max_n_sheep: 3, difficulty: 0.8 }
|
||||
- { step: 6_000_000, max_n_sheep: 5, difficulty: 1.0 }
|
||||
- { step: 8_000_000, max_n_sheep: 8, difficulty: 1.0 }
|
||||
- { step: 9_000_000, max_n_sheep: 10, difficulty: 1.0 }
|
||||
+16
-4
@@ -46,9 +46,11 @@ def rollout(env: HerdingEnv, predict_fn, max_steps: int) -> dict:
|
||||
|
||||
def make_analytic_predictor(action_fn):
|
||||
def _predict(env, _obs):
|
||||
positions = {f"s{i}": (float(env.sheep_x[i]), float(env.sheep_y[i]))
|
||||
for i in range(env.n_sheep)
|
||||
if not env.sheep_penned[i]}
|
||||
# Use whatever perception the env exposes — tracker output in
|
||||
# LiDAR mode, ground truth in privileged mode. This makes
|
||||
# evaluation honest: the analytic teacher sees what the
|
||||
# deployed controller would see.
|
||||
positions = env.perceived_positions()
|
||||
vx, vy, _mode = action_fn((env.dog_x, env.dog_y), positions, PEN_ENTRY)
|
||||
return np.array([vx, vy], dtype=np.float32)
|
||||
return _predict
|
||||
@@ -82,6 +84,7 @@ def main():
|
||||
parser.add_argument("--difficulty", type=float, default=1.0)
|
||||
args = parser.parse_args()
|
||||
|
||||
frame_stack = 1 # default; analytic predictors don't use stacked obs
|
||||
if args.policy == "strombom":
|
||||
predict = make_analytic_predictor(strombom_action)
|
||||
elif args.policy == "sequential":
|
||||
@@ -103,6 +106,14 @@ def main():
|
||||
f"policy.zip, final.zip)"
|
||||
)
|
||||
model = PPO.load(str(zip_path), device="auto")
|
||||
# Auto-detect frame stacking from the policy's expected obs dim,
|
||||
# so eval runs with whatever stacking the policy was trained on.
|
||||
from herding.obs import OBS_DIM as _SINGLE
|
||||
policy_obs_dim = int(model.observation_space.shape[0])
|
||||
if policy_obs_dim % _SINGLE == 0 and policy_obs_dim // _SINGLE >= 1:
|
||||
frame_stack = policy_obs_dim // _SINGLE
|
||||
if frame_stack > 1:
|
||||
print(f"[eval] policy expects frame_stack={frame_stack}")
|
||||
vecnorm = None
|
||||
vn_path = run / "vecnormalize.pkl"
|
||||
if not vn_path.exists() and run.parent.name != "best":
|
||||
@@ -121,7 +132,8 @@ def main():
|
||||
successes, steps, penned = [], [], []
|
||||
for seed in range(args.n_seeds):
|
||||
env = HerdingEnv(n_sheep=n, max_steps=args.max_steps,
|
||||
difficulty=args.difficulty, seed=seed)
|
||||
difficulty=args.difficulty, seed=seed,
|
||||
frame_stack=frame_stack)
|
||||
r = rollout(env, predict, args.max_steps)
|
||||
successes.append(int(r["success"]))
|
||||
steps.append(r["steps"])
|
||||
|
||||
+104
-8
@@ -69,7 +69,10 @@ from herding.geometry import (
|
||||
SHEEP_MAX_WHEEL_OMEGA, SHEEP_WHEEL_BASE, SHEEP_WHEEL_RADIUS,
|
||||
WEBOTS_DT, is_penned_position,
|
||||
)
|
||||
from herding.lidar_perception import detections_from_scan
|
||||
from herding.lidar_sim import simulate_scan
|
||||
from herding.obs import OBS_DIM, build_obs
|
||||
from herding.sheep_tracker import SheepTracker
|
||||
from herding.strombom import compute_action as strombom_action
|
||||
|
||||
|
||||
@@ -130,11 +133,30 @@ class HerdingEnv(gym.Env):
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
difficulty: float = 0.0,
|
||||
seed: Optional[int] = None,
|
||||
use_lidar: bool = True,
|
||||
frame_stack: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
# When True (default), the obs and the imitation-reward teacher
|
||||
# see only LiDAR-perceived sheep positions through a tracker —
|
||||
# matching what the Webots controller has access to. When False,
|
||||
# both consume ground-truth positions (legacy "privileged" mode,
|
||||
# kept for ablation).
|
||||
self._use_lidar = bool(use_lidar)
|
||||
self._tracker = SheepTracker() if self._use_lidar else None
|
||||
self._np_rng_lidar: Optional[np.random.Generator] = None
|
||||
|
||||
# Frame stacking: the policy receives the last K single-frame
|
||||
# observations concatenated. Lets a memoryless MLP integrate
|
||||
# information across time, partly compensating for the limited
|
||||
# LiDAR FOV. K=1 reproduces the legacy single-frame obs.
|
||||
self._frame_stack = max(1, int(frame_stack))
|
||||
self._frame_buffer: list[np.ndarray] = []
|
||||
self.action_space = spaces.Box(-1.0, 1.0, shape=(2,), dtype=np.float32)
|
||||
self._single_obs_dim = OBS_DIM
|
||||
self.observation_space = spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(OBS_DIM,), dtype=np.float32,
|
||||
low=-np.inf, high=np.inf,
|
||||
shape=(OBS_DIM * self._frame_stack,), dtype=np.float32,
|
||||
)
|
||||
|
||||
# If n_sheep is None, env will sample uniformly from [1, max_n_sheep]
|
||||
@@ -243,6 +265,16 @@ class HerdingEnv(gym.Env):
|
||||
self.prev_n_penned = 0
|
||||
self.prev_d_pen, self.prev_radius = self._flock_metrics()
|
||||
|
||||
if self._tracker is not None:
|
||||
self._tracker.reset()
|
||||
self._np_rng_lidar = np.random.default_rng(
|
||||
int(self.np_random.integers(0, 2**31 - 1)))
|
||||
# Prime the tracker with one scan so the first obs isn't empty.
|
||||
self._update_tracker()
|
||||
|
||||
# Clear the frame stack — the next _build_obs will repopulate.
|
||||
self._frame_buffer = []
|
||||
|
||||
obs = self._build_obs()
|
||||
info = {"n_sheep": self.n_sheep}
|
||||
return obs, info
|
||||
@@ -289,6 +321,12 @@ class HerdingEnv(gym.Env):
|
||||
and is_penned_position(self.sheep_x[i], self.sheep_y[i])):
|
||||
self.sheep_penned[i] = True
|
||||
|
||||
# --- Run LiDAR perception on this step's state (after sheep have
|
||||
# moved). Updates the tracker that obs and the imitation-
|
||||
# reward teacher consume. Reward / termination still use GT. ---
|
||||
if self._tracker is not None:
|
||||
self._update_tracker()
|
||||
|
||||
# --- Reward, termination ---
|
||||
d_pen, radius = self._flock_metrics()
|
||||
reward = self._compute_reward(d_pen, radius, action=action)
|
||||
@@ -395,10 +433,7 @@ class HerdingEnv(gym.Env):
|
||||
r = self.W_PEN_DELTA * delta_pen + self.W_PROGRESS * d_progress
|
||||
|
||||
if action is not None and self.W_IMITATE > 0.0:
|
||||
positions = {
|
||||
f"s{i}": (float(self.sheep_x[i]), float(self.sheep_y[i]))
|
||||
for i in range(self.n_sheep) if not self.sheep_penned[i]
|
||||
}
|
||||
positions = self._perceived_positions()
|
||||
if positions:
|
||||
sx, sy, _mode = strombom_action(
|
||||
(self.dog_x, self.dog_y), positions, PEN_ENTRY,
|
||||
@@ -411,11 +446,72 @@ class HerdingEnv(gym.Env):
|
||||
|
||||
return float(r)
|
||||
|
||||
def _build_obs(self) -> np.ndarray:
|
||||
sheep_xy_list = list(zip(self.sheep_x.tolist(), self.sheep_y.tolist()))
|
||||
sheep_penned_list = self.sheep_penned.tolist()
|
||||
def _build_single_obs(self) -> np.ndarray:
|
||||
if self._tracker is not None:
|
||||
# Obs sees only the tracker's active set; penned tracks are
|
||||
# intentionally excluded (matches the prior receiver-based
|
||||
# behaviour where penned sheep stopped contributing to the
|
||||
# symbolic obs).
|
||||
active = self._tracker.get_positions()
|
||||
sheep_xy_list = list(active.values())
|
||||
sheep_penned_list = [False] * len(sheep_xy_list)
|
||||
else:
|
||||
sheep_xy_list = list(zip(self.sheep_x.tolist(), self.sheep_y.tolist()))
|
||||
sheep_penned_list = self.sheep_penned.tolist()
|
||||
return build_obs(
|
||||
(self.dog_x, self.dog_y), self.dog_heading,
|
||||
sheep_xy_list, sheep_penned_list,
|
||||
n_max=self._max_n_sheep,
|
||||
)
|
||||
|
||||
def _build_obs(self) -> np.ndarray:
|
||||
single = self._build_single_obs()
|
||||
if self._frame_stack <= 1:
|
||||
return single
|
||||
# On a fresh reset the buffer is empty — duplicate the first
|
||||
# frame so the stack is always full-length.
|
||||
if not self._frame_buffer:
|
||||
self._frame_buffer = [single.copy() for _ in range(self._frame_stack)]
|
||||
else:
|
||||
self._frame_buffer.append(single)
|
||||
if len(self._frame_buffer) > self._frame_stack:
|
||||
self._frame_buffer = self._frame_buffer[-self._frame_stack:]
|
||||
# Concatenate oldest → newest.
|
||||
return np.concatenate(self._frame_buffer, axis=0).astype(np.float32)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# LiDAR perception helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _all_sheep_xy(self) -> list[tuple[float, float]]:
|
||||
"""Every sheep, including penned ones (the LiDAR sees them)."""
|
||||
return [(float(self.sheep_x[i]), float(self.sheep_y[i]))
|
||||
for i in range(self.n_sheep)]
|
||||
|
||||
def _update_tracker(self) -> None:
|
||||
ranges = simulate_scan(
|
||||
self.dog_x, self.dog_y, self.dog_heading,
|
||||
self._all_sheep_xy(),
|
||||
rng=self._np_rng_lidar,
|
||||
)
|
||||
detections = detections_from_scan(
|
||||
ranges, self.dog_x, self.dog_y, self.dog_heading,
|
||||
)
|
||||
self._tracker.update(detections)
|
||||
|
||||
def perceived_positions(self) -> dict[str, tuple[float, float]]:
|
||||
"""Public accessor — what the controller would 'see' this step.
|
||||
|
||||
LiDAR mode → the tracker's active set.
|
||||
Privileged mode → ground-truth active sheep.
|
||||
|
||||
Used by ``training.eval`` and ``tools.collect_demos`` so analytic
|
||||
teachers run on the same perception the deployed controller has.
|
||||
"""
|
||||
if self._tracker is not None:
|
||||
return self._tracker.get_positions()
|
||||
return {f"s{i}": (float(self.sheep_x[i]), float(self.sheep_y[i]))
|
||||
for i in range(self.n_sheep) if not self.sheep_penned[i]}
|
||||
|
||||
# Internal alias so the imitation reward path doesn't need to know
|
||||
# which mode it's in.
|
||||
_perceived_positions = perceived_positions
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
+275
-206
@@ -1,31 +1,33 @@
|
||||
"""PPO trainer for the shepherd-dog policy — EXPERIMENTAL.
|
||||
"""KL-regularised PPO fine-tune of a behaviour-cloned policy.
|
||||
|
||||
The deliverable pipeline is `bc_pretrain.py` (see ``training/README.md``).
|
||||
This script is kept in the tree because it implements:
|
||||
The PPO-from-scratch and unregularised PPO-fine-tune-of-BC versions
|
||||
we tried earlier failed for the standard reasons (sparse pen reward,
|
||||
long horizons, exploration noise destroying BC weights). The fix is
|
||||
to anchor the policy to its BC initialisation with a KL penalty in
|
||||
the loss — the policy is free to refine the BC mean within a
|
||||
trust-region-like ball around the reference, and the dense-enough
|
||||
per-step reward signal does the rest.
|
||||
|
||||
* PPO from scratch with curriculum over flock size + spawn area, and
|
||||
* PPO fine-tune of a behavior-cloned policy.
|
||||
Pipeline
|
||||
--------
|
||||
1. Load ``bc_v3`` weights into both the trainable policy and a frozen
|
||||
reference ``ref_policy``.
|
||||
2. Initialise the policy's log_std to a small fixed value (≈ −1.5)
|
||||
and disable its gradient — exploration noise stays small so PPO
|
||||
updates don't blow up the BC mean before reward can stabilise.
|
||||
3. Override ``PPO.train()`` to add ``β · KL(π ‖ π_ref)`` to the loss
|
||||
each minibatch.
|
||||
4. Train for ~1–3 M timesteps with a low LR (5e-5).
|
||||
|
||||
Both ran into stability issues in our setting (long-horizon credit
|
||||
assignment for sparse pen reward, BC-degradation under PPO exploration
|
||||
noise). The abstractions are reusable for follow-up work — e.g.
|
||||
KL-regularised fine-tune with a frozen reference policy — so we leave
|
||||
the code in place.
|
||||
Output: ``runs/rl_v1/policy.zip`` — same SB3 format as bc_v3, loadable
|
||||
by the dog controller's ``HERDING_MODE=rl`` path.
|
||||
|
||||
Usage (PPO from scratch)::
|
||||
Usage::
|
||||
|
||||
python -m training.train_ppo \
|
||||
--config training/configs/ppo_default.yaml \
|
||||
--out-dir training/runs/ppo_scratch
|
||||
|
||||
Usage (PPO fine-tune of BC)::
|
||||
|
||||
python -m training.train_ppo \
|
||||
--resume training/runs/bc_flock/policy.zip \
|
||||
--out-dir training/runs/bc_ppo \
|
||||
--no-vecnorm --no-curriculum --imitate-weight 0 \
|
||||
--difficulty 1.0 --log-std -1.5 --learning-rate 5e-5 \
|
||||
--total-timesteps 3000000
|
||||
python -m training.train_ppo \\
|
||||
--bc training/runs/bc_v3 \\
|
||||
--out training/runs/rl_v1 \\
|
||||
--total-timesteps 2000000
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -35,8 +37,6 @@ import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
_HERE = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
@@ -44,236 +44,305 @@ if _PROJECT_ROOT not in sys.path:
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
import torch.nn.functional as F
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.callbacks import (
|
||||
BaseCallback, CheckpointCallback, EvalCallback,
|
||||
)
|
||||
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.vec_env import (
|
||||
DummyVecEnv, SubprocVecEnv, VecNormalize,
|
||||
)
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
|
||||
|
||||
from herding.obs import OBS_DIM
|
||||
from training.herding_env import HerdingEnv
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# Env factories
|
||||
# --------------------------------------------------------------------------
|
||||
# --------------------------------------------------------------------
|
||||
# Env factory
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
def _make_env(rank: int, seed: int = 0):
|
||||
def _make_env(rank: int, seed: int, frame_stack: int):
|
||||
def _thunk():
|
||||
env = HerdingEnv(seed=seed + rank)
|
||||
env = HerdingEnv(seed=seed + rank, frame_stack=frame_stack)
|
||||
env = Monitor(env, info_keywords=("is_success", "n_sheep", "n_penned"))
|
||||
return env
|
||||
return _thunk
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# Curriculum callback
|
||||
# --------------------------------------------------------------------------
|
||||
# --------------------------------------------------------------------
|
||||
# KL-regularised PPO
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
class CurriculumCallback(BaseCallback):
|
||||
"""Drive the env's flock-size + state-space difficulty curriculum.
|
||||
class KLPPO(PPO):
|
||||
"""PPO with an extra KL-to-reference penalty in the policy loss.
|
||||
|
||||
Schedule entries: {step, max_n_sheep, difficulty}. The largest entry
|
||||
whose step <= num_timesteps wins; both knobs update together.
|
||||
Subclasses SB3's PPO and overrides ``train()`` only to add a single
|
||||
line for the KL term — everything else (rollout buffer, clipped
|
||||
surrogate, value loss, entropy bonus) is unchanged.
|
||||
"""
|
||||
|
||||
def __init__(self, schedule, vec_envs, verbose: int = 0):
|
||||
super().__init__(verbose)
|
||||
self.schedule = sorted(schedule, key=lambda d: d["step"])
|
||||
# Accept a list of envs so the eval env tracks training difficulty.
|
||||
self.vec_envs = vec_envs if isinstance(vec_envs, (list, tuple)) else [vec_envs]
|
||||
self._last_n = None
|
||||
self._last_d = None
|
||||
def __init__(self, *args, ref_policy=None, kl_coef: float = 0.05, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# ref_policy is set after construction (caller can build it
|
||||
# from the BC checkpoint once `self.policy` exists).
|
||||
self.ref_policy = ref_policy
|
||||
if self.ref_policy is not None:
|
||||
self.ref_policy.set_training_mode(False)
|
||||
for p in self.ref_policy.parameters():
|
||||
p.requires_grad = False
|
||||
self.kl_coef = kl_coef
|
||||
|
||||
def _call(self, method, value):
|
||||
for v in self.vec_envs:
|
||||
try:
|
||||
v.env_method(method, value)
|
||||
except AttributeError:
|
||||
v.venv.env_method(method, value)
|
||||
def train(self) -> None:
|
||||
# Copied from stable_baselines3.ppo.PPO.train (v2.x), with the
|
||||
# KL-to-reference term added. Keeping the structure intact so
|
||||
# behavioural parity with stock PPO is obvious.
|
||||
self.policy.set_training_mode(True)
|
||||
self._update_learning_rate(self.policy.optimizer)
|
||||
clip_range = self.clip_range(self._current_progress_remaining)
|
||||
if self.clip_range_vf is not None:
|
||||
clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
t = self.num_timesteps
|
||||
n = self.schedule[0]["max_n_sheep"]
|
||||
d = self.schedule[0].get("difficulty", 1.0)
|
||||
for entry in self.schedule:
|
||||
if t >= entry["step"]:
|
||||
n = entry["max_n_sheep"]
|
||||
d = entry.get("difficulty", 1.0)
|
||||
if n != self._last_n:
|
||||
self._call("set_max_n_sheep", n)
|
||||
self._last_n = n
|
||||
if d != self._last_d:
|
||||
self._call("set_difficulty", d)
|
||||
self._last_d = d
|
||||
if self.verbose:
|
||||
print(f"[curriculum] t={t} → max_n_sheep={n} difficulty={d}")
|
||||
return True
|
||||
entropy_losses, pg_losses, value_losses, kl_losses = [], [], [], []
|
||||
clip_fractions = []
|
||||
continue_training = True
|
||||
|
||||
for epoch in range(self.n_epochs):
|
||||
approx_kl_divs = []
|
||||
for rollout_data in self.rollout_buffer.get(self.batch_size):
|
||||
actions = rollout_data.actions
|
||||
if isinstance(self.action_space, th.distributions.Categorical.__bases__):
|
||||
actions = rollout_data.actions.long().flatten()
|
||||
|
||||
values, log_prob, entropy = self.policy.evaluate_actions(
|
||||
rollout_data.observations, actions)
|
||||
values = values.flatten()
|
||||
advantages = rollout_data.advantages
|
||||
if self.normalize_advantage and len(advantages) > 1:
|
||||
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
||||
|
||||
ratio = th.exp(log_prob - rollout_data.old_log_prob)
|
||||
policy_loss_1 = advantages * ratio
|
||||
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
|
||||
policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
|
||||
pg_losses.append(policy_loss.item())
|
||||
clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
|
||||
clip_fractions.append(clip_fraction)
|
||||
|
||||
if self.clip_range_vf is None:
|
||||
values_pred = values
|
||||
else:
|
||||
values_pred = rollout_data.old_values + th.clamp(
|
||||
values - rollout_data.old_values, -clip_range_vf, clip_range_vf)
|
||||
value_loss = F.mse_loss(rollout_data.returns, values_pred)
|
||||
value_losses.append(value_loss.item())
|
||||
|
||||
if entropy is None:
|
||||
entropy_loss = -th.mean(-log_prob)
|
||||
else:
|
||||
entropy_loss = -th.mean(entropy)
|
||||
entropy_losses.append(entropy_loss.item())
|
||||
|
||||
# --- KL-to-reference term ----------------------------
|
||||
# Both policies are diagonal Gaussian (ActorCriticPolicy).
|
||||
# KL(π ‖ π_ref) per-action-dim; sum over the action axis
|
||||
# to get total KL per sample, then mean over batch.
|
||||
# Computed on the rollout's observations so the penalty
|
||||
# reflects what the agent actually saw.
|
||||
if self.ref_policy is None:
|
||||
raise RuntimeError("KLPPO.train called without ref_policy")
|
||||
with th.no_grad():
|
||||
ref_dist = self.ref_policy.get_distribution(rollout_data.observations)
|
||||
pi_dist = self.policy.get_distribution(rollout_data.observations)
|
||||
kl_div = th.distributions.kl.kl_divergence(
|
||||
pi_dist.distribution, ref_dist.distribution).sum(dim=-1).mean()
|
||||
kl_losses.append(kl_div.item())
|
||||
# ----------------------------------------------------
|
||||
|
||||
loss = (policy_loss
|
||||
+ self.ent_coef * entropy_loss
|
||||
+ self.vf_coef * value_loss
|
||||
+ self.kl_coef * kl_div)
|
||||
|
||||
with th.no_grad():
|
||||
log_ratio = log_prob - rollout_data.old_log_prob
|
||||
approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
|
||||
approx_kl_divs.append(approx_kl_div)
|
||||
|
||||
if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
|
||||
continue_training = False
|
||||
if self.verbose >= 1:
|
||||
print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
|
||||
break
|
||||
|
||||
self.policy.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
|
||||
self.policy.optimizer.step()
|
||||
|
||||
self._n_updates += 1
|
||||
if not continue_training:
|
||||
break
|
||||
|
||||
explained_var = self._explained_variance()
|
||||
self.logger.record("train/entropy_loss", float(np.mean(entropy_losses)))
|
||||
self.logger.record("train/policy_gradient_loss", float(np.mean(pg_losses)))
|
||||
self.logger.record("train/value_loss", float(np.mean(value_losses)))
|
||||
self.logger.record("train/kl_to_reference", float(np.mean(kl_losses)))
|
||||
self.logger.record("train/approx_kl", float(np.mean(approx_kl_divs)))
|
||||
self.logger.record("train/clip_fraction", float(np.mean(clip_fractions)))
|
||||
self.logger.record("train/explained_variance", float(explained_var))
|
||||
if hasattr(self.policy, "log_std"):
|
||||
self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
|
||||
|
||||
def _explained_variance(self) -> float:
|
||||
# SB3 doesn't expose this as a method; replicate the computation.
|
||||
y_pred = self.rollout_buffer.values.flatten()
|
||||
y_true = self.rollout_buffer.returns.flatten()
|
||||
var_y = np.var(y_true)
|
||||
return float("nan") if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# --------------------------------------------------------------------
|
||||
# Main
|
||||
# --------------------------------------------------------------------------
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", default=os.path.join(_HERE, "configs", "ppo_default.yaml"))
|
||||
parser.add_argument("--out-dir", default=os.path.join(_HERE, "runs", "latest"))
|
||||
parser.add_argument("--n-envs", type=int, default=None,
|
||||
help="Override config n_envs.")
|
||||
parser.add_argument("--total-timesteps", type=int, default=None,
|
||||
help="Override config total_timesteps.")
|
||||
parser.add_argument("--bc", default="training/runs/bc_v3",
|
||||
help="Directory containing the BC initialisation (policy.zip).")
|
||||
parser.add_argument("--out", default="training/runs/rl_v1",
|
||||
help="Where to save the fine-tuned policy.")
|
||||
parser.add_argument("--total-timesteps", type=int, default=2_000_000)
|
||||
parser.add_argument("--n-envs", type=int, default=8)
|
||||
parser.add_argument("--learning-rate", type=float, default=5e-5,
|
||||
help="Low LR keeps PPO close to the BC mean.")
|
||||
parser.add_argument("--kl-coef", type=float, default=0.05,
|
||||
help="KL-to-reference penalty coefficient.")
|
||||
parser.add_argument("--log-std", type=float, default=-1.5,
|
||||
help="Initial (and frozen) log_std. σ ≈ exp(-1.5) ≈ 0.22.")
|
||||
parser.add_argument("--freeze-log-std", action="store_true", default=True,
|
||||
help="Keep log_std fixed; only the policy mean updates.")
|
||||
parser.add_argument("--n-steps", type=int, default=2048,
|
||||
help="Steps per rollout per env.")
|
||||
parser.add_argument("--batch-size", type=int, default=256)
|
||||
parser.add_argument("--n-epochs", type=int, default=10)
|
||||
parser.add_argument("--gamma", type=float, default=0.995)
|
||||
parser.add_argument("--gae-lambda", type=float, default=0.95)
|
||||
parser.add_argument("--clip-range", type=float, default=0.1,
|
||||
help="Tight clip range — keep updates conservative.")
|
||||
parser.add_argument("--ent-coef", type=float, default=0.0)
|
||||
parser.add_argument("--target-kl", type=float, default=0.02,
|
||||
help="SB3's per-batch KL early stop; safety belt.")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--resume", type=str, default=None,
|
||||
help="Path to a SB3 zip to resume from.")
|
||||
# SB3 recommends CPU for MlpPolicy — GPU helps CNN policies, not MLPs
|
||||
# of this size. Override with --device cuda if you really want it.
|
||||
parser.add_argument("--device", default="cpu")
|
||||
parser.add_argument("--no-vecnorm", action="store_true",
|
||||
help="Disable VecNormalize wrapper. Required when "
|
||||
"resuming from a BC-pretrained policy that "
|
||||
"wasn't trained under it.")
|
||||
parser.add_argument("--no-curriculum", action="store_true",
|
||||
help="Skip curriculum callback (resumed policy is "
|
||||
"already competent across the distribution).")
|
||||
parser.add_argument("--imitate-weight", type=float, default=None,
|
||||
help="Override env W_IMITATE. Set to 0 to disable "
|
||||
"Strömbom imitation reward.")
|
||||
parser.add_argument("--difficulty", type=float, default=None,
|
||||
help="Override env difficulty (0=easy, 1=hard). "
|
||||
"Used in BC fine-tune to skip easy curriculum.")
|
||||
parser.add_argument("--log-std", type=float, default=None,
|
||||
help="Override the policy's log_std after load. "
|
||||
"BC trained with std≈1.6 (log_std=0.5) which "
|
||||
"is too noisy for fine-tune. Use -1.5 (std≈0.22) "
|
||||
"to keep PPO close to the BC mean while still "
|
||||
"exploring locally.")
|
||||
parser.add_argument("--learning-rate", type=float, default=None,
|
||||
help="Override config learning rate. For BC "
|
||||
"fine-tune, 5e-5 is much safer than the 3e-4 "
|
||||
"default.")
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config) as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
bc_zip = Path(args.bc) / "policy.zip"
|
||||
if not bc_zip.exists():
|
||||
raise SystemExit(
|
||||
f"BC checkpoint not found at {bc_zip}. Train bc_v3 first with "
|
||||
f"`python -m training.bc_pretrain`."
|
||||
)
|
||||
|
||||
n_envs = args.n_envs or cfg["n_envs"]
|
||||
total_timesteps = args.total_timesteps or cfg["total_timesteps"]
|
||||
|
||||
out = Path(args.out_dir)
|
||||
out = Path(args.out)
|
||||
out.mkdir(parents=True, exist_ok=True)
|
||||
(out / "checkpoints").mkdir(exist_ok=True)
|
||||
(out / "best").mkdir(exist_ok=True)
|
||||
(out / "evals").mkdir(exist_ok=True)
|
||||
|
||||
print(f"[train] out={out} n_envs={n_envs} total={total_timesteps} device={args.device}")
|
||||
# --- Inspect BC obs dim → infer frame_stack ---
|
||||
ref_only = PPO.load(str(bc_zip), device=args.device)
|
||||
obs_dim = int(ref_only.observation_space.shape[0])
|
||||
if obs_dim % OBS_DIM != 0:
|
||||
raise SystemExit(f"BC obs dim {obs_dim} is not a multiple of {OBS_DIM}.")
|
||||
frame_stack = obs_dim // OBS_DIM
|
||||
print(f"[rl] BC obs dim {obs_dim} → frame_stack={frame_stack}")
|
||||
|
||||
# --- Train env (vectorised, optionally normalised) ---
|
||||
env_fns = [_make_env(i, seed=args.seed) for i in range(n_envs)]
|
||||
venv = SubprocVecEnv(env_fns) if n_envs > 1 else DummyVecEnv(env_fns)
|
||||
eval_venv = DummyVecEnv([_make_env(99, seed=args.seed + 999)])
|
||||
if not args.no_vecnorm:
|
||||
venv = VecNormalize(venv, norm_obs=True, norm_reward=False, clip_obs=10.0)
|
||||
eval_venv = VecNormalize(eval_venv, norm_obs=True, norm_reward=False,
|
||||
clip_obs=10.0, training=False)
|
||||
eval_venv.obs_rms = venv.obs_rms
|
||||
else:
|
||||
print("[train] VecNormalize disabled (resumed policy was trained without it).")
|
||||
# --- Vectorised envs (match BC obs space) ---
|
||||
env_fns = [_make_env(i, args.seed, frame_stack) for i in range(args.n_envs)]
|
||||
venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns)
|
||||
eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, frame_stack)])
|
||||
|
||||
# Apply env-level overrides (used by BC fine-tune to disable Strömbom
|
||||
# imitation and start at full deployment difficulty).
|
||||
def _env_call(method, value):
|
||||
for v in (venv, eval_venv):
|
||||
try:
|
||||
v.env_method(method, value)
|
||||
except AttributeError:
|
||||
v.venv.env_method(method, value)
|
||||
|
||||
if args.imitate_weight is not None:
|
||||
_env_call("set_imitate_weight", args.imitate_weight)
|
||||
print(f"[train] W_IMITATE overridden to {args.imitate_weight}")
|
||||
if args.difficulty is not None:
|
||||
_env_call("set_difficulty", args.difficulty)
|
||||
print(f"[train] difficulty pinned to {args.difficulty}")
|
||||
|
||||
# --- Model ---
|
||||
policy_kwargs = dict(
|
||||
net_arch=dict(pi=cfg["net_arch_pi"], vf=cfg["net_arch_vf"]),
|
||||
log_std_init=cfg.get("log_std_init", 0.0),
|
||||
# --- Trainable policy: load BC weights, then bolt onto PPO ---
|
||||
# Trick: instantiate a PPO with the right env (so the policy
|
||||
# network is constructed at the correct obs/action shape), then
|
||||
# copy BC weights into it.
|
||||
model = KLPPO(
|
||||
"MlpPolicy", venv,
|
||||
ref_policy=None, # filled in below
|
||||
kl_coef=args.kl_coef,
|
||||
learning_rate=args.learning_rate,
|
||||
n_steps=args.n_steps,
|
||||
batch_size=args.batch_size,
|
||||
n_epochs=args.n_epochs,
|
||||
gamma=args.gamma,
|
||||
gae_lambda=args.gae_lambda,
|
||||
clip_range=args.clip_range,
|
||||
ent_coef=args.ent_coef,
|
||||
target_kl=args.target_kl,
|
||||
policy_kwargs=dict(
|
||||
net_arch=dict(pi=[512, 512], vf=[512, 512]),
|
||||
log_std_init=args.log_std,
|
||||
),
|
||||
verbose=1,
|
||||
seed=args.seed,
|
||||
device=args.device,
|
||||
tensorboard_log=str(out / "tb"),
|
||||
)
|
||||
|
||||
if args.resume:
|
||||
print(f"[train] resuming from {args.resume}")
|
||||
custom_objects = {}
|
||||
if args.learning_rate is not None:
|
||||
custom_objects["learning_rate"] = args.learning_rate
|
||||
model = PPO.load(args.resume, env=venv, device=args.device,
|
||||
tensorboard_log=str(out / "tb"),
|
||||
custom_objects=custom_objects or None)
|
||||
if args.log_std is not None:
|
||||
import torch as _th
|
||||
with _th.no_grad():
|
||||
model.policy.log_std.fill_(args.log_std)
|
||||
print(f"[train] log_std overridden to {args.log_std} "
|
||||
f"(std≈{2.71828 ** args.log_std:.2f})")
|
||||
if args.learning_rate is not None:
|
||||
print(f"[train] learning_rate overridden to {args.learning_rate}")
|
||||
else:
|
||||
model = PPO(
|
||||
cfg["policy"], venv,
|
||||
learning_rate=cfg["learning_rate"],
|
||||
n_steps=cfg["n_steps"],
|
||||
batch_size=cfg["batch_size"],
|
||||
n_epochs=cfg["n_epochs"],
|
||||
gamma=cfg["gamma"],
|
||||
gae_lambda=cfg["gae_lambda"],
|
||||
clip_range=cfg["clip_range"],
|
||||
ent_coef=cfg["ent_coef"],
|
||||
vf_coef=cfg["vf_coef"],
|
||||
max_grad_norm=cfg["max_grad_norm"],
|
||||
target_kl=cfg.get("target_kl"),
|
||||
policy_kwargs=policy_kwargs,
|
||||
tensorboard_log=str(out / "tb"),
|
||||
seed=args.seed,
|
||||
device=args.device,
|
||||
verbose=1,
|
||||
)
|
||||
# --- Load BC weights into both `model.policy` and `ref_policy` ---
|
||||
bc_state = ref_only.policy.state_dict()
|
||||
# Strict=False because the value head may not have been trained in
|
||||
# BC — that's fine, PPO will train it from scratch.
|
||||
missing, unexpected = model.policy.load_state_dict(bc_state, strict=False)
|
||||
print(f"[rl] BC → policy: missing={len(missing)} unexpected={len(unexpected)}")
|
||||
|
||||
# Build a separate reference policy with identical architecture and
|
||||
# the BC weights, frozen.
|
||||
ref_policy = type(model.policy)(
|
||||
observation_space=model.observation_space,
|
||||
action_space=model.action_space,
|
||||
lr_schedule=lambda _: 0.0,
|
||||
net_arch=dict(pi=[512, 512], vf=[512, 512]),
|
||||
log_std_init=args.log_std,
|
||||
).to(args.device)
|
||||
ref_policy.load_state_dict(bc_state, strict=False)
|
||||
model.ref_policy = ref_policy
|
||||
model.ref_policy.set_training_mode(False)
|
||||
for p in model.ref_policy.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
# Align both policies' log_std. BC was trained with log_std≈0.5
|
||||
# (σ≈1.65), which would make the KL term huge from a std mismatch
|
||||
# rather than the mean drift we actually care about. Force both to
|
||||
# the same small value so KL measures only how far the policy mean
|
||||
# has drifted from the BC mean.
|
||||
with th.no_grad():
|
||||
model.policy.log_std.fill_(args.log_std)
|
||||
model.ref_policy.log_std.fill_(args.log_std)
|
||||
if args.freeze_log_std:
|
||||
model.policy.log_std.requires_grad = False
|
||||
print(f"[rl] log_std frozen at {args.log_std} (σ ≈ {np.exp(args.log_std):.3f})")
|
||||
|
||||
# --- Callbacks ---
|
||||
ckpt_cb = CheckpointCallback(
|
||||
save_freq=max(1, cfg["checkpoint_freq"] // n_envs),
|
||||
save_path=str(out / "checkpoints"), name_prefix="ppo",
|
||||
save_vecnormalize=True,
|
||||
save_freq=max(1, 50_000 // args.n_envs),
|
||||
save_path=str(out / "checkpoints"),
|
||||
name_prefix="ppo",
|
||||
)
|
||||
eval_cb = EvalCallback(
|
||||
eval_venv,
|
||||
best_model_save_path=str(out / "best"),
|
||||
log_path=str(out / "evals"),
|
||||
eval_freq=max(1, cfg["eval_freq"] // n_envs),
|
||||
n_eval_episodes=cfg["n_eval_episodes"],
|
||||
eval_freq=max(1, 20_000 // args.n_envs),
|
||||
n_eval_episodes=5,
|
||||
deterministic=True,
|
||||
)
|
||||
callbacks = [ckpt_cb, eval_cb]
|
||||
if not args.no_curriculum and "curriculum" in cfg and cfg["curriculum"]:
|
||||
callbacks.append(CurriculumCallback(
|
||||
cfg["curriculum"], [venv, eval_venv], verbose=1,
|
||||
))
|
||||
elif args.no_curriculum:
|
||||
print("[train] curriculum disabled — env knobs left at their current values.")
|
||||
|
||||
# --- Train ---
|
||||
model.learn(total_timesteps=total_timesteps, callback=callbacks,
|
||||
progress_bar=True)
|
||||
print(f"[rl] training: total_timesteps={args.total_timesteps} "
|
||||
f"n_envs={args.n_envs} lr={args.learning_rate} kl_coef={args.kl_coef}")
|
||||
model.learn(total_timesteps=args.total_timesteps,
|
||||
callback=[ckpt_cb, eval_cb], progress_bar=True)
|
||||
|
||||
# --- Save final model + VecNormalize stats ---
|
||||
model.save(out / "final.zip")
|
||||
venv.save(str(out / "vecnormalize.pkl"))
|
||||
# The EvalCallback already wrote best_model.zip into out/best/ — drop the
|
||||
# VecNormalize stats next to it for the controller to pick up.
|
||||
venv.save(str(out / "best" / "vecnormalize.pkl"))
|
||||
print(f"[train] done. saved to {out}")
|
||||
# --- Save final checkpoint in the SB3 zip the controller expects ---
|
||||
model.save(out / "policy.zip")
|
||||
print(f"[rl] saved fine-tuned policy → {out/'policy.zip'}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user