"""Lazy loader for the SB3 PPO policy used by the dog controller. Importing stable-baselines3 inside the Webots Python interpreter is only needed when ``HERDING_MODE=rl``; the Strömbom mode runs without it. This loader keeps SB3 out of the import path until you actually ask for the RL policy, so users without SB3 installed can still run the Strömbom baseline. The policy + VecNormalize statistics are saved together by ``training/train_ppo.py``: runs//best/best_model.zip # SB3 PPO checkpoint runs//best/vecnormalize.pkl # observation-normaliser stats Pass either the directory or the explicit zip path. """ import os from pathlib import Path class PolicyHandle: """Wrap a loaded PPO policy + VecNormalize so the controller can call ``predict(obs)`` without thinking about either. Frame stacking is auto-detected from the policy's expected obs dim: if it's a multiple of the single-frame ``OBS_DIM``, the handle keeps a deque of the last K frames and concatenates them on each predict. """ def __init__(self, model, vecnorm): self.model = model self.vecnorm = vecnorm # Lazy import to avoid forcing herding/* into the import path # when SB3 isn't being used. from herding.obs import OBS_DIM policy_dim = int(model.observation_space.shape[0]) if policy_dim % OBS_DIM == 0 and policy_dim // OBS_DIM >= 1: self.frame_stack = policy_dim // OBS_DIM else: self.frame_stack = 1 self._buffer: list = [] self._single_dim = OBS_DIM def predict(self, obs): import numpy as np single = np.asarray(obs, dtype=np.float32).reshape(-1) if single.shape[0] != self._single_dim: # Caller already passed a stacked obs — use as-is. stacked = single elif self.frame_stack > 1: if not self._buffer: self._buffer = [single.copy() for _ in range(self.frame_stack)] else: self._buffer.append(single) if len(self._buffer) > self.frame_stack: self._buffer = self._buffer[-self.frame_stack:] stacked = np.concatenate(self._buffer, axis=0) else: stacked = single obs_b = stacked.reshape(1, -1) if self.vecnorm is not None: obs_b = self.vecnorm.normalize_obs(obs_b) action, _ = self.model.predict(obs_b, deterministic=True) return action[0] def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle: """Load a PPO model (and optional VecNormalize) from disk. ``model_path`` may be the .zip checkpoint or a directory containing ``best_model.zip`` (and optionally ``vecnormalize.pkl``). """ p = Path(model_path) if p.is_dir(): zip_candidates = [p / "best_model.zip", p / "final.zip", p / "policy.zip"] zip_path = next((z for z in zip_candidates if z.exists()), None) if zip_path is None: raise FileNotFoundError( f"No PPO zip found in {p} (looked for best_model.zip, final.zip, policy.zip)" ) if vecnorm_path is None: vn = p / "vecnormalize.pkl" if vn.exists(): vecnorm_path = str(vn) else: zip_path = p # Imports deferred so the Strömbom path doesn't require SB3. from stable_baselines3 import PPO from stable_baselines3.common.vec_env import VecNormalize model = PPO.load(str(zip_path), device="auto") vecnorm = None if vecnorm_path and os.path.exists(vecnorm_path): # VecNormalize.load needs a venv to attach to; we only need its stats # at inference, so we reconstruct the wrapper manually. import pickle with open(vecnorm_path, "rb") as f: vecnorm = pickle.load(f) vecnorm.training = False vecnorm.norm_reward = False return PolicyHandle(model=model, vecnorm=vecnorm)