"""Lazy SB3 policy loader for the dog controller. SB3 is imported only when a learned policy is actually requested, so the analytic modes can run on installs without stable-baselines3 or torch. The handle auto-detects frame stacking from the policy's expected observation dimension: if it's a multiple of the single-frame ``OBS_DIM``, an internal buffer of the last K frames is maintained and concatenated on each ``predict`` call. """ import os from pathlib import Path class PolicyHandle: """Wrap a loaded policy (+ optional VecNormalize) for ``predict(obs)``. Supports both MLP (PPO) and recurrent (RecurrentPPO/LSTM) policies. For LSTM policies, frame_stack is forced to 1 and the LSTM hidden state is maintained across calls; ``reset_recurrent`` is exposed for new episodes. """ def __init__(self, model, vecnorm, recurrent: bool = False): self.model = model self.vecnorm = vecnorm self.recurrent = recurrent from herding.perception.obs import OBS_DIM policy_dim = int(model.observation_space.shape[0]) if recurrent: self.frame_stack = 1 elif policy_dim % OBS_DIM == 0 and policy_dim // OBS_DIM >= 1: self.frame_stack = policy_dim // OBS_DIM else: self.frame_stack = 1 self._buffer: list = [] self._single_dim = OBS_DIM self._lstm_state = None self._first_step = True def reset_recurrent(self): self._lstm_state = None self._first_step = True self._buffer = [] def predict(self, obs): import numpy as np single = np.asarray(obs, dtype=np.float32).reshape(-1) if single.shape[0] != self._single_dim: # Caller passed an already-stacked obs. stacked = single elif self.frame_stack > 1: if not self._buffer: self._buffer = [single.copy() for _ in range(self.frame_stack)] else: self._buffer.append(single) if len(self._buffer) > self.frame_stack: self._buffer = self._buffer[-self.frame_stack:] stacked = np.concatenate(self._buffer, axis=0) else: stacked = single obs_b = stacked.reshape(1, -1) if self.vecnorm is not None: obs_b = self.vecnorm.normalize_obs(obs_b) if self.recurrent: episode_start = np.array([self._first_step], dtype=bool) action, self._lstm_state = self.model.predict( obs_b, state=self._lstm_state, episode_start=episode_start, deterministic=True, ) self._first_step = False else: action, _ = self.model.predict(obs_b, deterministic=True) return action[0] def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle: """Load a policy zip (+ optional VecNormalize pickle) from disk. ``model_path`` may be a ``.zip`` file or a directory; in the latter case ``policy.zip`` is preferred, with ``final.zip`` as a fallback for partially-completed RL runs. """ p = Path(model_path) if p.is_dir(): zip_candidates = [p / "policy.zip", p / "final.zip"] zip_path = next((z for z in zip_candidates if z.exists()), None) if zip_path is None: raise FileNotFoundError( f"No policy zip in {p} (looked for policy.zip, final.zip)" ) if vecnorm_path is None: vn = p / "vecnormalize.pkl" if vn.exists(): vecnorm_path = str(vn) else: zip_path = p # Deferred imports so the analytic path doesn't require SB3. from stable_baselines3 import PPO from stable_baselines3.common.vec_env import VecNormalize # noqa: F401 # Try RecurrentPPO (LSTM) first, fall back to PPO (MLP). recurrent = False model = None try: from sb3_contrib import RecurrentPPO model = RecurrentPPO.load(str(zip_path), device="auto") recurrent = True except Exception: model = PPO.load(str(zip_path), device="auto") vecnorm = None if vecnorm_path and os.path.exists(vecnorm_path): import pickle with open(vecnorm_path, "rb") as f: vecnorm = pickle.load(f) vecnorm.training = False vecnorm.norm_reward = False return PolicyHandle(model=model, vecnorm=vecnorm, recurrent=recurrent)