LSTM (RecurrentPPO) experiment + recurrent policy support
Adds RecurrentPPO-based training as an alternative to MLP+frame-stack. The LSTM gives the policy unbounded temporal memory, addressing the partial-observability failure mode of the 140° Webots LiDAR (the tracker briefly empties when the dog turns, and sporadic phantom tracks confuse decisions).

* training/rl/train_lstm.py: from-scratch RecurrentPPO trainer (no BC init, no KL term since there's no reference). Uses HERDING_WEBOTS preset so the obs distribution matches deployment.
* training/eval.py: auto-detects RecurrentPPO zips, maintains LSTM hidden state across steps, resets between episodes.
* controllers/shepherd_dog/policy_loader.py: PolicyHandle supports recurrent policies — state managed inside, reset_recurrent() exposed.

Result on diff/field after 3M steps:
- Gym (default 360°): 69% avg success across n=1..10
- Gym (HERDING_WEBOTS preset, training env): 2% — penning 3-4/5 but rarely all 5
- Webots LiDAR 140°: 0/5 (same wall as DAgger and v1 policies)

Conclusion: architectural changes (LSTM vs MLP) don't close the perception sim-to-real gap. The gym LiDAR sim doesn't faithfully reproduce the Webots phantom-track distribution; any policy trained on the gym proxy fails to handle real Webots phantoms regardless of architecture. Closing this gap requires either modeling Webots phantom patterns in the gym sim (multi-day work) or Webots-in-the-loop training (very slow).

See memory/lstm_results.md for details.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -15,19 +15,35 @@ from pathlib import Path
|
||||
|
||||
|
||||
class PolicyHandle:
|
||||
"""Wrap a loaded policy (+ optional VecNormalize) for ``predict(obs)``."""
|
||||
"""Wrap a loaded policy (+ optional VecNormalize) for ``predict(obs)``.
|
||||
|
||||
def __init__(self, model, vecnorm):
|
||||
Supports both MLP (PPO) and recurrent (RecurrentPPO/LSTM) policies.
|
||||
For LSTM policies, frame_stack is forced to 1 and the LSTM hidden
|
||||
state is maintained across calls; ``reset_recurrent`` is exposed for
|
||||
new episodes.
|
||||
"""
|
||||
|
||||
def __init__(self, model, vecnorm, recurrent: bool = False):
|
||||
self.model = model
|
||||
self.vecnorm = vecnorm
|
||||
self.recurrent = recurrent
|
||||
from herding.perception.obs import OBS_DIM
|
||||
policy_dim = int(model.observation_space.shape[0])
|
||||
if policy_dim % OBS_DIM == 0 and policy_dim // OBS_DIM >= 1:
|
||||
if recurrent:
|
||||
self.frame_stack = 1
|
||||
elif policy_dim % OBS_DIM == 0 and policy_dim // OBS_DIM >= 1:
|
||||
self.frame_stack = policy_dim // OBS_DIM
|
||||
else:
|
||||
self.frame_stack = 1
|
||||
self._buffer: list = []
|
||||
self._single_dim = OBS_DIM
|
||||
self._lstm_state = None
|
||||
self._first_step = True
|
||||
|
||||
def reset_recurrent(self):
|
||||
self._lstm_state = None
|
||||
self._first_step = True
|
||||
self._buffer = []
|
||||
|
||||
def predict(self, obs):
|
||||
import numpy as np
|
||||
@@ -49,7 +65,15 @@ class PolicyHandle:
|
||||
obs_b = stacked.reshape(1, -1)
|
||||
if self.vecnorm is not None:
|
||||
obs_b = self.vecnorm.normalize_obs(obs_b)
|
||||
action, _ = self.model.predict(obs_b, deterministic=True)
|
||||
if self.recurrent:
|
||||
episode_start = np.array([self._first_step], dtype=bool)
|
||||
action, self._lstm_state = self.model.predict(
|
||||
obs_b, state=self._lstm_state,
|
||||
episode_start=episode_start, deterministic=True,
|
||||
)
|
||||
self._first_step = False
|
||||
else:
|
||||
action, _ = self.model.predict(obs_b, deterministic=True)
|
||||
return action[0]
|
||||
|
||||
|
||||
@@ -79,7 +103,16 @@ def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import VecNormalize # noqa: F401
|
||||
|
||||
model = PPO.load(str(zip_path), device="auto")
|
||||
# Try RecurrentPPO (LSTM) first, fall back to PPO (MLP).
|
||||
recurrent = False
|
||||
model = None
|
||||
try:
|
||||
from sb3_contrib import RecurrentPPO
|
||||
model = RecurrentPPO.load(str(zip_path), device="auto")
|
||||
recurrent = True
|
||||
except Exception:
|
||||
model = PPO.load(str(zip_path), device="auto")
|
||||
|
||||
vecnorm = None
|
||||
if vecnorm_path and os.path.exists(vecnorm_path):
|
||||
import pickle
|
||||
@@ -87,4 +120,4 @@ def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
|
||||
vecnorm = pickle.load(f)
|
||||
vecnorm.training = False
|
||||
vecnorm.norm_reward = False
|
||||
return PolicyHandle(model=model, vecnorm=vecnorm)
|
||||
return PolicyHandle(model=model, vecnorm=vecnorm, recurrent=recurrent)
|
||||
|
||||
Reference in New Issue
Block a user