Checkpoint 7

This commit is contained in:
Johnny Fernandes
2026-05-11 12:21:51 +01:00
parent fce0e0c786
commit a01a5c9cef
34 changed files with 1266 additions and 1038 deletions
+19 -33
View File
@@ -1,18 +1,13 @@
"""Lazy loader for the SB3 PPO policy used by the dog controller.
"""Lazy SB3 policy loader for the dog controller.
Importing stable-baselines3 inside the Webots Python interpreter is only
needed when ``HERDING_MODE=rl``; the Strömbom mode runs without it. This
loader keeps SB3 out of the import path until you actually ask for the RL
policy, so users without SB3 installed can still run the Strömbom
baseline.
SB3 is imported only when a learned policy is actually requested,
so the analytic modes can run on installs without stable-baselines3
or torch.
The policy + VecNormalize statistics are saved together by
``training/train_ppo.py``:
runs/<name>/best/best_model.zip # SB3 PPO checkpoint
runs/<name>/best/vecnormalize.pkl # observation-normaliser stats
Pass either the directory or the explicit zip path.
The handle auto-detects frame stacking from the policy's expected
observation dimension: if it's a multiple of the single-frame
``OBS_DIM``, an internal buffer of the last K frames is maintained
and concatenated on each ``predict`` call.
"""
import os
@@ -20,20 +15,12 @@ from pathlib import Path
class PolicyHandle:
"""Wrap a loaded PPO policy + VecNormalize so the controller can call
``predict(obs)`` without thinking about either.
Frame stacking is auto-detected from the policy's expected obs dim:
if it's a multiple of the single-frame ``OBS_DIM``, the handle keeps
a deque of the last K frames and concatenates them on each predict.
"""
"""Wrap a loaded policy (+ optional VecNormalize) for ``predict(obs)``."""
def __init__(self, model, vecnorm):
self.model = model
self.vecnorm = vecnorm
# Lazy import to avoid forcing herding/* into the import path
# when SB3 isn't being used.
from herding.obs import OBS_DIM
from herding.perception.obs import OBS_DIM
policy_dim = int(model.observation_space.shape[0])
if policy_dim % OBS_DIM == 0 and policy_dim // OBS_DIM >= 1:
self.frame_stack = policy_dim // OBS_DIM
@@ -46,7 +33,7 @@ class PolicyHandle:
import numpy as np
single = np.asarray(obs, dtype=np.float32).reshape(-1)
if single.shape[0] != self._single_dim:
# Caller already passed a stacked obs — use as-is.
# Caller passed an already-stacked obs.
stacked = single
elif self.frame_stack > 1:
if not self._buffer:
@@ -67,18 +54,19 @@ class PolicyHandle:
def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
"""Load a PPO model (and optional VecNormalize) from disk.
"""Load a policy zip (+ optional VecNormalize pickle) from disk.
``model_path`` may be the .zip checkpoint or a directory containing
``best_model.zip`` (and optionally ``vecnormalize.pkl``).
``model_path`` may be a ``.zip`` file or a directory; in the
latter case ``policy.zip`` is preferred, with ``final.zip`` as
a fallback for partially-completed RL runs.
"""
p = Path(model_path)
if p.is_dir():
zip_candidates = [p / "best_model.zip", p / "final.zip", p / "policy.zip"]
zip_candidates = [p / "policy.zip", p / "final.zip"]
zip_path = next((z for z in zip_candidates if z.exists()), None)
if zip_path is None:
raise FileNotFoundError(
f"No PPO zip found in {p} (looked for best_model.zip, final.zip, policy.zip)"
f"No policy zip in {p} (looked for policy.zip, final.zip)"
)
if vecnorm_path is None:
vn = p / "vecnormalize.pkl"
@@ -87,15 +75,13 @@ def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
else:
zip_path = p
# Imports deferred so the Strömbom path doesn't require SB3.
# Deferred imports so the analytic path doesn't require SB3.
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.vec_env import VecNormalize # noqa: F401
model = PPO.load(str(zip_path), device="auto")
vecnorm = None
if vecnorm_path and os.path.exists(vecnorm_path):
# VecNormalize.load needs a venv to attach to; we only need its stats
# at inference, so we reconstruct the wrapper manually.
import pickle
with open(vecnorm_path, "rb") as f:
vecnorm = pickle.load(f)