Files
TIR_PROJ/controllers/shepherd_dog/policy_loader.py
T
Johnny Fernandes 6688325d89 Checkpoint 4
2026-05-11 00:42:52 +01:00

105 lines
3.9 KiB
Python

"""Lazy loader for the SB3 PPO policy used by the dog controller.

Importing stable-baselines3 inside the Webots Python interpreter is only
needed when ``HERDING_MODE=rl``; the Strömbom mode runs without it. This
loader keeps SB3 out of the import path until you actually ask for the RL
policy, so users without SB3 installed can still run the Strömbom
baseline.

The policy + VecNormalize statistics are saved together by
``training/train_ppo.py``::

    runs/<name>/best/best_model.zip     # SB3 PPO checkpoint
    runs/<name>/best/vecnormalize.pkl   # observation-normaliser stats

Pass either the run directory (``vecnormalize.pkl`` inside it is picked up
automatically) or the explicit zip path.
"""
import os
from collections import deque
from pathlib import Path
class PolicyHandle:
    """Wrap a loaded PPO policy + VecNormalize so the controller can call
    ``predict(obs)`` without thinking about either.

    Frame stacking is auto-detected from the policy's expected obs dim:
    if it's a multiple of the single-frame ``OBS_DIM``, the handle keeps
    a deque of the last K frames and concatenates them on each predict.
    Call :meth:`reset` at the start of a new episode so frames from the
    previous episode are not stacked into the new one.
    """

    def __init__(self, model, vecnorm):
        """Store the policy and normaliser and size the frame-stack buffer.

        Args:
            model: Loaded policy exposing ``observation_space`` and
                ``predict(obs, deterministic=...)`` (an SB3 PPO model).
            vecnorm: Object providing ``normalize_obs`` (a VecNormalize),
                or ``None`` to feed observations to the policy unchanged.
        """
        self.model = model
        self.vecnorm = vecnorm
        # Lazy import to avoid forcing herding/* into the import path
        # when SB3 isn't being used.
        from herding.obs import OBS_DIM

        policy_dim = int(model.observation_space.shape[0])
        # The ">= 1" guard rules out policy_dim == 0 (which is trivially a
        # multiple of OBS_DIM but would mean a zero-frame stack).
        if policy_dim % OBS_DIM == 0 and policy_dim // OBS_DIM >= 1:
            self.frame_stack = policy_dim // OBS_DIM
        else:
            self.frame_stack = 1
        self._single_dim = OBS_DIM
        self.reset()

    def reset(self) -> None:
        """Forget all buffered frames.

        The next single-frame ``predict`` call re-seeds the stack by
        repeating that frame ``frame_stack`` times.
        """
        # maxlen makes the deque drop the oldest frame automatically.
        self._buffer: deque = deque(maxlen=self.frame_stack)

    def predict(self, obs):
        """Return the deterministic policy action for one observation.

        Args:
            obs: Array-like observation, flattened to 1-D. A vector of
                exactly ``OBS_DIM`` values is treated as a single frame and
                pushed onto the frame stack (when ``frame_stack > 1``); any
                other length is assumed to be already stacked and is passed
                through unchanged.

        Returns:
            Element 0 of ``model.predict``'s batched action output, i.e. the
            action for this single observation.
        """
        import numpy as np

        single = np.asarray(obs, dtype=np.float32).reshape(-1)
        if single.shape[0] != self._single_dim:
            # Caller already passed a stacked obs — use as-is.
            stacked = single
        elif self.frame_stack > 1:
            # Copy: asarray/reshape may return a view of the caller's array,
            # and a caller that reuses its obs buffer would otherwise
            # silently rewrite frames already stored in the stack.
            frame = single.copy()
            if not self._buffer:
                # First frame after construction/reset: pad the history.
                self._buffer.extend([frame] * self.frame_stack)
            else:
                self._buffer.append(frame)
            stacked = np.concatenate(tuple(self._buffer), axis=0)
        else:
            stacked = single
        obs_b = stacked.reshape(1, -1)
        if self.vecnorm is not None:
            obs_b = self.vecnorm.normalize_obs(obs_b)
        action, _ = self.model.predict(obs_b, deterministic=True)
        return action[0]
def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
    """Load a PPO model (and optional VecNormalize) from disk.

    ``model_path`` may be the .zip checkpoint or a directory containing
    ``best_model.zip`` (falling back to ``final.zip`` then ``policy.zip``)
    and optionally ``vecnormalize.pkl``. For a directory with
    ``vecnorm_path=None``, ``vecnormalize.pkl`` inside it is used if present.

    Args:
        model_path: Path to the checkpoint zip or to a run directory.
        vecnorm_path: Explicit path to a pickled VecNormalize. If given but
            missing, a ``RuntimeWarning`` is emitted and the policy receives
            raw (unnormalised) observations.

    Returns:
        A :class:`PolicyHandle` wrapping the model and normaliser.

    Raises:
        FileNotFoundError: ``model_path`` is a directory holding none of the
            known checkpoint zips.
        TypeError: the file at ``vecnorm_path`` does not unpickle to a
            ``VecNormalize``.
    """
    import pickle
    import warnings

    p = Path(model_path)
    if p.is_dir():
        zip_candidates = [p / "best_model.zip", p / "final.zip", p / "policy.zip"]
        zip_path = next((z for z in zip_candidates if z.exists()), None)
        if zip_path is None:
            raise FileNotFoundError(
                f"No PPO zip found in {p} (looked for best_model.zip, final.zip, policy.zip)"
            )
        if vecnorm_path is None:
            vn = p / "vecnormalize.pkl"
            if vn.exists():
                vecnorm_path = str(vn)
    else:
        zip_path = p

    # Imports deferred so the Strömbom path doesn't require SB3.
    from stable_baselines3 import PPO
    from stable_baselines3.common.vec_env import VecNormalize

    model = PPO.load(str(zip_path), device="auto")

    vecnorm = None
    if vecnorm_path:
        if os.path.exists(vecnorm_path):
            # VecNormalize.load() also needs a venv to attach to; at
            # inference we only use its running statistics, so unpickle
            # the saved wrapper directly instead.
            # SECURITY: pickle.load can execute arbitrary code — only load
            # normaliser files you produced yourself or otherwise trust.
            with open(vecnorm_path, "rb") as f:
                vecnorm = pickle.load(f)
            if not isinstance(vecnorm, VecNormalize):
                raise TypeError(
                    f"{vecnorm_path} did not unpickle to a VecNormalize "
                    f"(got {type(vecnorm).__name__})"
                )
            # Freeze the statistics and leave rewards untouched at inference.
            vecnorm.training = False
            vecnorm.norm_reward = False
        else:
            # Previously ignored silently; running a policy trained on
            # normalised observations without the normaliser degrades its
            # actions, so make the mistake visible (but stay best-effort).
            warnings.warn(
                f"vecnorm_path {vecnorm_path!r} does not exist; the policy "
                "will receive unnormalised observations.",
                RuntimeWarning,
                stacklevel=2,
            )
    return PolicyHandle(model=model, vecnorm=vecnorm)