876e14e74f
Adds RecurrentPPO-based training as an alternative to MLP+frame-stack. The LSTM gives the policy unbounded temporal memory, addressing the partial-obs failure mode of the 140° Webots LiDAR (tracker briefly empties when the dog turns; sporadic phantom tracks confuse decisions). * training/rl/train_lstm.py: from-scratch RecurrentPPO trainer (no BC init, no KL term since there's no reference). Uses HERDING_WEBOTS preset so the obs distribution matches deployment. * training/eval.py: auto-detects RecurrentPPO zips, maintains LSTM hidden state across steps, resets between episodes. * controllers/shepherd_dog/policy_loader.py: PolicyHandle supports recurrent policies — state managed inside, reset_recurrent() exposed. Result on diff/field after 3M steps: - Gym (default 360°): 69% avg success across n=1..10 - Gym (HERDING_WEBOTS preset, training env): 2% — penning 3-4/5 but rarely all 5 - Webots LiDAR 140°: 0/5 (same wall as DAgger and v1 policies) Conclusion: architectural changes (LSTM vs MLP) don't close the perception sim-to-real gap. The gym LiDAR sim doesn't faithfully reproduce Webots phantom-track distribution; any policy trained on the gym proxy fails to handle real Webots phantoms regardless of architecture. Closing this gap requires either modeling Webots phantom patterns in the gym sim (multi-day work) or Webots-in-the-loop training (very slow). See memory/lstm_results.md for details. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
124 lines
4.4 KiB
Python
124 lines
4.4 KiB
Python
"""Lazy SB3 policy loader for the dog controller.
|
|
|
|
SB3 is imported only when a learned policy is actually requested,
|
|
so the analytic modes can run on installs without stable-baselines3
|
|
or torch.
|
|
|
|
The handle auto-detects frame stacking from the policy's expected
|
|
observation dimension: if it's a multiple of the single-frame
|
|
``OBS_DIM``, an internal buffer of the last K frames is maintained
|
|
and concatenated on each ``predict`` call.
|
|
"""
|
|
|
|
import os
|
|
from pathlib import Path
|
|
|
|
|
|
class PolicyHandle:
    """Wrap a loaded policy (+ optional VecNormalize) for ``predict(obs)``.

    Supports both MLP (PPO) and recurrent (RecurrentPPO/LSTM) policies.
    For LSTM policies, frame_stack is forced to 1 and the LSTM hidden
    state is maintained across calls; ``reset_recurrent`` is exposed for
    new episodes.

    For non-recurrent policies the frame-stack depth is derived from the
    policy's observation dimension: when it is a whole multiple of the
    single-frame ``OBS_DIM`` the handle keeps the last K frames and feeds
    their concatenation to the model; otherwise no stacking is done.
    """

    def __init__(self, model, vecnorm, recurrent: bool = False):
        """Store the model and work out the frame-stack depth.

        Args:
            model: Loaded policy object; must expose ``observation_space``
                and ``predict``.
            vecnorm: Object with a ``normalize_obs`` method, or ``None`` to
                feed raw observations to the model.
            recurrent: ``True`` when ``model`` carries hidden state that has
                to be threaded through successive ``predict`` calls.
        """
        self.model = model
        self.vecnorm = vecnorm
        self.recurrent = recurrent
        # Function-local import, matching the module's deferred-import style.
        from herding.perception.obs import OBS_DIM
        policy_dim = int(model.observation_space.shape[0])
        if recurrent:
            # The recurrent state replaces frame stacking as temporal memory.
            self.frame_stack = 1
        elif policy_dim % OBS_DIM == 0 and policy_dim // OBS_DIM >= 1:
            self.frame_stack = policy_dim // OBS_DIM
        else:
            self.frame_stack = 1
        # History of the most recent single frames (oldest first).
        self._buffer: list = []
        self._single_dim = OBS_DIM
        # Hidden state returned by the previous recurrent predict call.
        self._lstm_state = None
        # True until the first recurrent predict after construction/reset.
        self._first_step = True

    def reset_recurrent(self):
        """Forget all per-episode state.

        Clears the recurrent hidden state, re-arms the episode-start flag
        and empties the frame-stack history, so the next ``predict`` call
        behaves like the first call on a fresh handle.
        """
        self._lstm_state = None
        self._first_step = True
        self._buffer = []

    def predict(self, obs):
        """Return the deterministic action for one observation.

        Args:
            obs: Array-like observation. It is flattened to a 1-D float32
                vector. If its length equals the single-frame dimension it
                is treated as one frame (and stacked with history when
                ``frame_stack > 1``); any other length is assumed to be an
                already-stacked observation and is passed through unchanged.

        Returns:
            The first row of the model's batched action output.
        """
        import numpy as np

        single = np.asarray(obs, dtype=np.float32).reshape(-1)
        if single.shape[0] != self._single_dim:
            # Caller passed an already-stacked obs.
            stacked = single
        elif self.frame_stack > 1:
            if not self._buffer:
                # First frame of an episode: pad the history with copies of it.
                self._buffer = [single.copy() for _ in range(self.frame_stack)]
            else:
                # Store a private copy: ``single`` can be a view of the
                # caller's array, and a caller that reuses one observation
                # buffer in place would otherwise rewrite older frames.
                self._buffer.append(single.copy())
                if len(self._buffer) > self.frame_stack:
                    self._buffer = self._buffer[-self.frame_stack:]
            stacked = np.concatenate(self._buffer, axis=0)
        else:
            stacked = single

        # The model is called with a batch of one observation.
        obs_b = stacked.reshape(1, -1)
        if self.vecnorm is not None:
            obs_b = self.vecnorm.normalize_obs(obs_b)
        if self.recurrent:
            # Flag the first step so the model starts from a fresh state.
            episode_start = np.array([self._first_step], dtype=bool)
            action, self._lstm_state = self.model.predict(
                obs_b, state=self._lstm_state,
                episode_start=episode_start, deterministic=True,
            )
            self._first_step = False
        else:
            action, _ = self.model.predict(obs_b, deterministic=True)
        return action[0]
|
|
|
|
|
|
def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
    """Load a policy zip (+ optional VecNormalize pickle) from disk.

    ``model_path`` may be a ``.zip`` file or a directory; in the
    latter case ``policy.zip`` is preferred, with ``final.zip`` as
    a fallback for partially-completed RL runs.

    Args:
        model_path: Path to a policy zip, or to a directory containing one.
        vecnorm_path: Path to a pickled VecNormalize object. When omitted
            and ``model_path`` is a directory, ``vecnormalize.pkl`` inside
            that directory is used if it exists.

    Returns:
        A ``PolicyHandle`` wrapping the loaded model, flagged as recurrent
        when the zip loaded successfully as a ``RecurrentPPO`` model.

    Raises:
        FileNotFoundError: ``model_path`` is a directory that contains
            neither ``policy.zip`` nor ``final.zip``.

    Warns:
        RuntimeWarning: ``vecnorm_path`` was given but does not exist; the
            policy is then returned without observation normalization.
    """
    p = Path(model_path)
    if p.is_dir():
        zip_candidates = [p / "policy.zip", p / "final.zip"]
        zip_path = next((z for z in zip_candidates if z.exists()), None)
        if zip_path is None:
            raise FileNotFoundError(
                f"No policy zip in {p} (looked for policy.zip, final.zip)"
            )
        if vecnorm_path is None:
            vn = p / "vecnormalize.pkl"
            if vn.exists():
                vecnorm_path = str(vn)
    else:
        zip_path = p

    # Deferred imports so the analytic path doesn't require SB3.
    from stable_baselines3 import PPO
    from stable_baselines3.common.vec_env import VecNormalize  # noqa: F401

    # Try RecurrentPPO (LSTM) first, fall back to PPO (MLP).
    recurrent = False
    model = None
    try:
        from sb3_contrib import RecurrentPPO
        model = RecurrentPPO.load(str(zip_path), device="auto")
        recurrent = True
    except Exception:
        # Deliberately broad: covers both a missing sb3_contrib install and
        # a zip that is not a RecurrentPPO model. If the PPO load fails too,
        # its exception propagates with this one attached as context.
        model = PPO.load(str(zip_path), device="auto")

    vecnorm = None
    if vecnorm_path:
        if os.path.exists(vecnorm_path):
            import pickle
            # NOTE(security): pickle executes arbitrary code on load; only
            # point this at VecNormalize files from a trusted training run.
            with open(vecnorm_path, "rb") as f:
                vecnorm = pickle.load(f)
            # Inference only: freeze the running statistics and leave
            # rewards untouched.
            vecnorm.training = False
            vecnorm.norm_reward = False
        else:
            # Reachable only for an explicitly supplied path; the directory
            # auto-discovery above sets vecnorm_path only when the file exists.
            import warnings
            warnings.warn(
                f"VecNormalize file not found: {vecnorm_path}; "
                "the policy will receive unnormalized observations",
                RuntimeWarning,
                stacklevel=2,
            )
    return PolicyHandle(model=model, vecnorm=vecnorm, recurrent=recurrent)
|