Checkpoint 2
This commit is contained in:
@@ -0,0 +1,78 @@
|
||||
"""Lazy loader for the SB3 PPO policy used by the dog controller.
|
||||
|
||||
Importing stable-baselines3 inside the Webots Python interpreter is only
|
||||
needed when ``HERDING_MODE=rl``; the Strömbom mode runs without it. This
|
||||
loader keeps SB3 out of the import path until you actually ask for the RL
|
||||
policy, so users without SB3 installed can still run the Strömbom
|
||||
baseline.
|
||||
|
||||
The policy + VecNormalize statistics are saved together by
|
||||
``training/train_ppo.py``:
|
||||
|
||||
runs/<name>/best/best_model.zip # SB3 PPO checkpoint
|
||||
runs/<name>/best/vecnormalize.pkl # observation-normaliser stats
|
||||
|
||||
Pass either the directory or the explicit zip path.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class PolicyHandle:
|
||||
"""Wrap a loaded PPO policy + VecNormalize so the controller can call
|
||||
``predict(obs)`` without thinking about either."""
|
||||
|
||||
def __init__(self, model, vecnorm):
|
||||
self.model = model
|
||||
self.vecnorm = vecnorm
|
||||
|
||||
def predict(self, obs):
|
||||
# VecNormalize expects a batched obs of shape (n_envs, obs_dim).
|
||||
if self.vecnorm is not None:
|
||||
import numpy as np
|
||||
obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1)
|
||||
obs_b = self.vecnorm.normalize_obs(obs_b)
|
||||
else:
|
||||
import numpy as np
|
||||
obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1)
|
||||
action, _ = self.model.predict(obs_b, deterministic=True)
|
||||
return action[0]
|
||||
|
||||
|
||||
def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
|
||||
"""Load a PPO model (and optional VecNormalize) from disk.
|
||||
|
||||
``model_path`` may be the .zip checkpoint or a directory containing
|
||||
``best_model.zip`` (and optionally ``vecnormalize.pkl``).
|
||||
"""
|
||||
p = Path(model_path)
|
||||
if p.is_dir():
|
||||
zip_candidates = [p / "best_model.zip", p / "final.zip", p / "policy.zip"]
|
||||
zip_path = next((z for z in zip_candidates if z.exists()), None)
|
||||
if zip_path is None:
|
||||
raise FileNotFoundError(
|
||||
f"No PPO zip found in {p} (looked for best_model.zip, final.zip, policy.zip)"
|
||||
)
|
||||
if vecnorm_path is None:
|
||||
vn = p / "vecnormalize.pkl"
|
||||
if vn.exists():
|
||||
vecnorm_path = str(vn)
|
||||
else:
|
||||
zip_path = p
|
||||
|
||||
# Imports deferred so the Strömbom path doesn't require SB3.
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import VecNormalize
|
||||
|
||||
model = PPO.load(str(zip_path), device="auto")
|
||||
vecnorm = None
|
||||
if vecnorm_path and os.path.exists(vecnorm_path):
|
||||
# VecNormalize.load needs a venv to attach to; we only need its stats
|
||||
# at inference, so we reconstruct the wrapper manually.
|
||||
import pickle
|
||||
with open(vecnorm_path, "rb") as f:
|
||||
vecnorm = pickle.load(f)
|
||||
vecnorm.training = False
|
||||
vecnorm.norm_reward = False
|
||||
return PolicyHandle(model=model, vecnorm=vecnorm)
|
||||
Reference in New Issue
Block a user