Files
TIR_PROJ/controllers/shepherd_dog/policy_loader.py
T
Johnny Fernandes a01a5c9cef Checkpoint 7
2026-05-11 12:21:51 +01:00

91 lines
3.2 KiB
Python

"""Lazy SB3 policy loader for the dog controller.
SB3 is imported only when a learned policy is actually requested,
so the analytic modes can run on installs without stable-baselines3
or torch.
The handle auto-detects frame stacking from the policy's expected
observation dimension: if it's a multiple of the single-frame
``OBS_DIM``, an internal buffer of the last K frames is maintained
and concatenated on each ``predict`` call.
"""
import os
from pathlib import Path
class PolicyHandle:
    """Wrap a loaded policy (+ optional VecNormalize) for ``predict(obs)``.

    Attributes:
        model: The loaded policy; only its ``observation_space`` and
            ``predict`` are used here.
        vecnorm: Optional normalizer whose ``normalize_obs`` is applied to
            every (stacked) observation before inference, or ``None``.
        frame_stack: Number of single frames the policy expects, i.e.
            ``policy_dim // OBS_DIM`` when the policy's input size is a
            multiple of ``OBS_DIM``, otherwise 1.
    """

    def __init__(self, model, vecnorm):
        from herding.perception.obs import OBS_DIM

        self.model = model
        self.vecnorm = vecnorm
        policy_dim = int(model.observation_space.shape[0])
        # A policy trained on K stacked frames expects K * OBS_DIM inputs.
        if policy_dim % OBS_DIM == 0 and policy_dim // OBS_DIM >= 1:
            self.frame_stack = policy_dim // OBS_DIM
        else:
            self.frame_stack = 1
        self._single_dim = OBS_DIM
        self.reset()

    def reset(self):
        """Forget all buffered frames (call this when a new episode starts).

        The next single-frame ``predict`` call re-seeds the stack by
        repeating its observation ``frame_stack`` times.
        """
        from collections import deque

        # maxlen makes the deque discard the oldest frame automatically.
        self._buffer = deque(maxlen=self.frame_stack)

    def predict(self, obs):
        """Return the deterministic action for a single observation.

        Args:
            obs: Array-like observation. If it has exactly ``OBS_DIM``
                elements and the policy is frame-stacked, it is pushed into
                the internal buffer and the last ``frame_stack`` frames are
                concatenated oldest-first. Any other length is treated as an
                already-stacked observation and passed through unchanged.

        Returns:
            The first (and only) row of the action batch returned by
            ``model.predict``.
        """
        import numpy as np

        single = np.asarray(obs, dtype=np.float32).reshape(-1)
        if single.shape[0] != self._single_dim:
            # Caller passed an already-stacked obs.
            stacked = single
        elif self.frame_stack > 1:
            # Copy: asarray/reshape may return a view of the caller's array,
            # and a reused caller buffer would otherwise silently rewrite
            # frames that are already stored here.
            frame = single.copy()
            if not self._buffer:
                # First frame since construction/reset: pad with repeats.
                self._buffer.extend([frame] * self.frame_stack)
            else:
                self._buffer.append(frame)
            stacked = np.concatenate(list(self._buffer), axis=0)
        else:
            stacked = single
        obs_b = stacked.reshape(1, -1)  # add the batch dimension
        if self.vecnorm is not None:
            obs_b = self.vecnorm.normalize_obs(obs_b)
        action, _ = self.model.predict(obs_b, deterministic=True)
        return action[0]
def load(
    model_path: str,
    vecnorm_path: str | None = None,
    *,
    device: str = "auto",
) -> PolicyHandle:
    """Load a policy zip (+ optional VecNormalize pickle) from disk.

    ``model_path`` may be a ``.zip`` file or a directory; in the
    latter case ``policy.zip`` is preferred, with ``final.zip`` as
    a fallback for partially-completed RL runs, and a
    ``vecnormalize.pkl`` in that directory is used automatically when
    ``vecnorm_path`` is not given.

    Args:
        model_path: Path to a PPO ``.zip`` file or to a run directory.
        vecnorm_path: Optional path to a pickled ``VecNormalize``. A path
            that does not exist is skipped with a ``UserWarning``, in which
            case the policy receives unnormalized observations.
        device: Device string forwarded to ``PPO.load`` (default ``"auto"``).

    Returns:
        A ``PolicyHandle`` wrapping the model and the normalizer (``None``
        when no normalizer file was used).

    Raises:
        FileNotFoundError: ``model_path`` is a directory that contains
            neither ``policy.zip`` nor ``final.zip``.
        TypeError: the normalizer pickle does not hold a ``VecNormalize``.
    """
    import pickle
    import warnings

    p = Path(model_path)
    if p.is_dir():
        zip_candidates = [p / "policy.zip", p / "final.zip"]
        zip_path = next((z for z in zip_candidates if z.exists()), None)
        if zip_path is None:
            raise FileNotFoundError(
                f"No policy zip in {p} (looked for policy.zip, final.zip)"
            )
        if vecnorm_path is None:
            vn = p / "vecnormalize.pkl"
            if vn.exists():
                vecnorm_path = str(vn)
    else:
        zip_path = p

    # Deferred imports so the analytic path doesn't require SB3.
    from stable_baselines3 import PPO
    from stable_baselines3.common.vec_env import VecNormalize

    model = PPO.load(str(zip_path), device=device)

    vecnorm = None
    if vecnorm_path:
        if os.path.exists(vecnorm_path):
            # NOTE(security): unpickling can execute arbitrary code; only
            # load normalizer files produced by trusted training runs.
            with open(vecnorm_path, "rb") as f:
                vecnorm = pickle.load(f)
            if not isinstance(vecnorm, VecNormalize):
                raise TypeError(
                    f"{vecnorm_path} does not contain a VecNormalize "
                    f"(got {type(vecnorm).__name__})"
                )
            # Inference mode: stop updating running statistics and leave
            # rewards unnormalized.
            vecnorm.training = False
            vecnorm.norm_reward = False
        else:
            # Previously ignored silently; a missing normalizer changes
            # the policy's inputs, so make it visible.
            warnings.warn(
                f"VecNormalize file {vecnorm_path} not found; "
                "running the policy on unnormalized observations.",
                stacklevel=2,
            )
    return PolicyHandle(model=model, vecnorm=vecnorm)