Checkpoint 7
This commit is contained in:
@@ -1,18 +1,13 @@
|
||||
"""Lazy loader for the SB3 PPO policy used by the dog controller.
|
||||
"""Lazy SB3 policy loader for the dog controller.
|
||||
|
||||
Importing stable-baselines3 inside the Webots Python interpreter is only
|
||||
needed when ``HERDING_MODE=rl``; the Strömbom mode runs without it. This
|
||||
loader keeps SB3 out of the import path until you actually ask for the RL
|
||||
policy, so users without SB3 installed can still run the Strömbom
|
||||
baseline.
|
||||
SB3 is imported only when a learned policy is actually requested,
|
||||
so the analytic modes can run on installs without stable-baselines3
|
||||
or torch.
|
||||
|
||||
The policy + VecNormalize statistics are saved together by
|
||||
``training/train_ppo.py``:
|
||||
|
||||
runs/<name>/best/best_model.zip # SB3 PPO checkpoint
|
||||
runs/<name>/best/vecnormalize.pkl # observation-normaliser stats
|
||||
|
||||
Pass either the directory or the explicit zip path.
|
||||
The handle auto-detects frame stacking from the policy's expected
|
||||
observation dimension: if it's a multiple of the single-frame
|
||||
``OBS_DIM``, an internal buffer of the last K frames is maintained
|
||||
and concatenated on each ``predict`` call.
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -20,20 +15,12 @@ from pathlib import Path
|
||||
|
||||
|
||||
class PolicyHandle:
|
||||
"""Wrap a loaded PPO policy + VecNormalize so the controller can call
|
||||
``predict(obs)`` without thinking about either.
|
||||
|
||||
Frame stacking is auto-detected from the policy's expected obs dim:
|
||||
if it's a multiple of the single-frame ``OBS_DIM``, the handle keeps
|
||||
a deque of the last K frames and concatenates them on each predict.
|
||||
"""
|
||||
"""Wrap a loaded policy (+ optional VecNormalize) for ``predict(obs)``."""
|
||||
|
||||
def __init__(self, model, vecnorm):
|
||||
self.model = model
|
||||
self.vecnorm = vecnorm
|
||||
# Lazy import to avoid forcing herding/* into the import path
|
||||
# when SB3 isn't being used.
|
||||
from herding.obs import OBS_DIM
|
||||
from herding.perception.obs import OBS_DIM
|
||||
policy_dim = int(model.observation_space.shape[0])
|
||||
if policy_dim % OBS_DIM == 0 and policy_dim // OBS_DIM >= 1:
|
||||
self.frame_stack = policy_dim // OBS_DIM
|
||||
@@ -46,7 +33,7 @@ class PolicyHandle:
|
||||
import numpy as np
|
||||
single = np.asarray(obs, dtype=np.float32).reshape(-1)
|
||||
if single.shape[0] != self._single_dim:
|
||||
# Caller already passed a stacked obs — use as-is.
|
||||
# Caller passed an already-stacked obs.
|
||||
stacked = single
|
||||
elif self.frame_stack > 1:
|
||||
if not self._buffer:
|
||||
@@ -67,18 +54,19 @@ class PolicyHandle:
|
||||
|
||||
|
||||
def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
|
||||
"""Load a PPO model (and optional VecNormalize) from disk.
|
||||
"""Load a policy zip (+ optional VecNormalize pickle) from disk.
|
||||
|
||||
``model_path`` may be the .zip checkpoint or a directory containing
|
||||
``best_model.zip`` (and optionally ``vecnormalize.pkl``).
|
||||
``model_path`` may be a ``.zip`` file or a directory; in the
|
||||
latter case ``policy.zip`` is preferred, with ``final.zip`` as
|
||||
a fallback for partially-completed RL runs.
|
||||
"""
|
||||
p = Path(model_path)
|
||||
if p.is_dir():
|
||||
zip_candidates = [p / "best_model.zip", p / "final.zip", p / "policy.zip"]
|
||||
zip_candidates = [p / "policy.zip", p / "final.zip"]
|
||||
zip_path = next((z for z in zip_candidates if z.exists()), None)
|
||||
if zip_path is None:
|
||||
raise FileNotFoundError(
|
||||
f"No PPO zip found in {p} (looked for best_model.zip, final.zip, policy.zip)"
|
||||
f"No policy zip in {p} (looked for policy.zip, final.zip)"
|
||||
)
|
||||
if vecnorm_path is None:
|
||||
vn = p / "vecnormalize.pkl"
|
||||
@@ -87,15 +75,13 @@ def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
|
||||
else:
|
||||
zip_path = p
|
||||
|
||||
# Imports deferred so the Strömbom path doesn't require SB3.
|
||||
# Deferred imports so the analytic path doesn't require SB3.
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import VecNormalize
|
||||
from stable_baselines3.common.vec_env import VecNormalize # noqa: F401
|
||||
|
||||
model = PPO.load(str(zip_path), device="auto")
|
||||
vecnorm = None
|
||||
if vecnorm_path and os.path.exists(vecnorm_path):
|
||||
# VecNormalize.load needs a venv to attach to; we only need its stats
|
||||
# at inference, so we unpickle the saved wrapper object directly.
|
||||
import pickle
|
||||
with open(vecnorm_path, "rb") as f:
|
||||
vecnorm = pickle.load(f)
|
||||
|
||||
Reference in New Issue
Block a user