Files
TIR_PROJ/controllers/shepherd_dog/policy_loader.py
T
Johnny Fernandes 1bb9415414 Checkpoint 2
2026-05-07 22:00:10 +01:00

79 lines
2.8 KiB
Python

"""Lazy loader for the SB3 PPO policy used by the dog controller.
Importing stable-baselines3 inside the Webots Python interpreter is only
needed when ``HERDING_MODE=rl``; the Strömbom mode runs without it. This
loader keeps SB3 out of the import path until you actually ask for the RL
policy, so users without SB3 installed can still run the Strömbom
baseline.
The policy + VecNormalize statistics are saved together by
``training/train_ppo.py``:
runs/<name>/best/best_model.zip # SB3 PPO checkpoint
runs/<name>/best/vecnormalize.pkl # observation-normaliser stats
Pass either the directory or the explicit zip path.
"""
import os
from pathlib import Path
class PolicyHandle:
    """Wrap a loaded PPO policy + VecNormalize so the controller can call
    ``predict(obs)`` without thinking about either.

    ``vecnorm`` may be ``None``; observations are then handed to the model
    without normalisation.
    """

    def __init__(self, model, vecnorm):
        self.model = model
        self.vecnorm = vecnorm

    def predict(self, obs):
        """Return the deterministic action for one observation.

        ``obs`` is cast to float32 and flattened into a single-row batch,
        normalised when a ``vecnorm`` is attached, and passed to
        ``model.predict``. The first (only) row of the returned action batch
        is handed back to the caller.
        """
        # numpy stays a function-local import so importing this module does
        # not require it.
        import numpy as np

        # VecNormalize expects a batched obs of shape (n_envs, obs_dim).
        obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1)
        if self.vecnorm is not None:
            obs_b = self.vecnorm.normalize_obs(obs_b)
        action, _ = self.model.predict(obs_b, deterministic=True)
        return action[0]
def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
    """Load a PPO model (and optional VecNormalize) from disk.

    ``model_path`` may be the .zip checkpoint or a directory containing
    ``best_model.zip`` (and optionally ``vecnormalize.pkl``).

    Args:
        model_path: Checkpoint file, or a directory that is searched for
            ``best_model.zip``, ``final.zip`` and ``policy.zip`` in that
            order.
        vecnorm_path: Pickled VecNormalize statistics. When ``None``, a
            ``vecnormalize.pkl`` sitting next to the checkpoint is used if it
            exists. Pass an empty string to skip normalisation entirely.

    Returns:
        A ``PolicyHandle`` wrapping the model and, when statistics were
        found, the normaliser (otherwise ``vecnorm`` is ``None``).

    Raises:
        FileNotFoundError: ``model_path`` is a directory that contains none
            of the known checkpoint names.
    """
    p = Path(model_path)
    if p.is_dir():
        zip_candidates = [p / "best_model.zip", p / "final.zip", p / "policy.zip"]
        zip_path = next((z for z in zip_candidates if z.exists()), None)
        if zip_path is None:
            raise FileNotFoundError(
                f"No PPO zip found in {p} (looked for best_model.zip, final.zip, policy.zip)"
            )
    else:
        zip_path = p

    if vecnorm_path is None:
        # The training script saves the statistics beside the checkpoint, so
        # look there whether the caller passed the directory or the zip.
        sibling = zip_path.parent / "vecnormalize.pkl"
        if sibling.exists():
            vecnorm_path = str(sibling)
    elif vecnorm_path and not os.path.exists(vecnorm_path):
        # An explicitly requested file that is missing used to be ignored
        # silently; keep going (best effort) but make it visible.
        import warnings

        warnings.warn(
            f"VecNormalize stats not found at {vecnorm_path}; "
            "the policy will receive unnormalised observations",
            stacklevel=2,
        )

    # Import deferred so the Strömbom path doesn't require SB3.
    from stable_baselines3 import PPO

    model = PPO.load(str(zip_path), device="auto")

    vecnorm = None
    if vecnorm_path and os.path.exists(vecnorm_path):
        # VecNormalize.load needs a venv to attach to; we only need its stats
        # at inference, so we unpickle the wrapper directly.
        # NOTE(security): pickle executes arbitrary code while loading. Only
        # point this at files produced by your own training runs.
        import pickle

        with open(vecnorm_path, "rb") as f:
            vecnorm = pickle.load(f)
        # Freeze the running statistics and leave rewards untouched.
        vecnorm.training = False
        vecnorm.norm_reward = False
    return PolicyHandle(model=model, vecnorm=vecnorm)