LSTM (RecurrentPPO) experiment + recurrent policy support

Adds RecurrentPPO-based training as an alternative to MLP+frame-stack.
The LSTM lets the policy carry memory beyond a fixed frame-stack window,
targeting the partial-observability failure mode of the 140° Webots LiDAR
(the tracker briefly empties when the dog turns, and sporadic phantom
tracks confuse decisions).

* training/rl/train_lstm.py: from-scratch RecurrentPPO trainer (no BC
  init, no KL term since there's no reference). Uses HERDING_WEBOTS
  preset so the obs distribution matches deployment.
* training/eval.py: auto-detects RecurrentPPO zips, maintains LSTM
  hidden state across steps, resets between episodes.
* controllers/shepherd_dog/policy_loader.py: PolicyHandle supports
  recurrent policies — state managed inside, reset_recurrent() exposed.
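
The state handling described for training/eval.py can be sketched roughly as below. This is a minimal illustration, not the code from this commit: `StubRecurrentModel`, `StubEnv`, and `run_episode` are hypothetical stand-ins so the loop runs without sb3_contrib. Only the `predict(obs, state=..., episode_start=..., deterministic=...)` call shape mirrors what RecurrentPPO uses (there, `episode_start` is a NumPy bool array rather than a list).

```python
# Sketch only: StubRecurrentModel and StubEnv are hypothetical stand-ins,
# not part of this repository or of sb3_contrib.


class StubRecurrentModel:
    """Mimics the predict() call shape of a recurrent policy."""

    def predict(self, obs, state=None, episode_start=None, deterministic=True):
        # A real LSTM resets its hidden state where episode_start is True.
        if state is None or episode_start[0]:
            state = 0
        state += 1        # stand-in for "memory": steps seen this episode
        action = [state]  # batch of one action
        return action, state


class StubEnv:
    """Fixed-length episode; observations are just the step index."""

    def __init__(self, horizon=3):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return [float(self.t)]

    def step(self, action):
        self.t += 1
        done = self.t >= self.horizon
        return [float(self.t)], done


def run_episode(model, env):
    """Carry the recurrent state across steps; start fresh each episode."""
    obs = env.reset()
    state = None            # nothing carried over from a previous episode
    episode_start = [True]  # True only on the first step
    actions, done = [], False
    while not done:
        action, state = model.predict(
            obs, state=state, episode_start=episode_start, deterministic=True
        )
        episode_start = [False]
        actions.append(action[0])
        obs, done = env.step(action)
    return actions
```

The point of the sketch is that the state returned by one `predict` call is fed into the next, and that both the state and the `episode_start` flag are re-initialised at every episode boundary.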

Result on diff/field after 3M steps:
- Gym (default 360°): 69% avg success across n=1..10
- Gym (HERDING_WEBOTS preset, training env): 2%; it pens 3-4 of 5 but
  rarely all 5
- Webots LiDAR 140°: 0/5 (same wall as DAgger and v1 policies)

Conclusion: switching architecture (LSTM vs MLP) does not close the
perception sim-to-real gap. The gym LiDAR sim does not faithfully
reproduce the Webots phantom-track distribution, so any policy trained on
the gym proxy fails to handle real Webots phantoms regardless of
architecture. Closing this gap requires either modeling Webots phantom
patterns in the gym sim (multi-day work) or Webots-in-the-loop
training (very slow). See memory/lstm_results.md for details.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Johnny Fernandes
2026-05-16 19:22:32 +00:00
parent dd5ac669e5
commit 876e14e74f
4 changed files with 248 additions and 10 deletions
+39 -6
@@ -15,19 +15,35 @@ from pathlib import Path
 class PolicyHandle:
-    """Wrap a loaded policy (+ optional VecNormalize) for ``predict(obs)``."""
+    """Wrap a loaded policy (+ optional VecNormalize) for ``predict(obs)``.
-    def __init__(self, model, vecnorm):
+    Supports both MLP (PPO) and recurrent (RecurrentPPO/LSTM) policies.
+    For LSTM policies, frame_stack is forced to 1 and the LSTM hidden
+    state is maintained across calls; ``reset_recurrent`` is exposed for
+    new episodes.
+    """
+    def __init__(self, model, vecnorm, recurrent: bool = False):
         self.model = model
         self.vecnorm = vecnorm
+        self.recurrent = recurrent
         from herding.perception.obs import OBS_DIM
         policy_dim = int(model.observation_space.shape[0])
-        if policy_dim % OBS_DIM == 0 and policy_dim // OBS_DIM >= 1:
+        if recurrent:
+            self.frame_stack = 1
+        elif policy_dim % OBS_DIM == 0 and policy_dim // OBS_DIM >= 1:
             self.frame_stack = policy_dim // OBS_DIM
         else:
             self.frame_stack = 1
         self._buffer: list = []
         self._single_dim = OBS_DIM
+        self._lstm_state = None
+        self._first_step = True
+    def reset_recurrent(self):
+        self._lstm_state = None
+        self._first_step = True
+        self._buffer = []
     def predict(self, obs):
         import numpy as np
@@ -49,7 +65,15 @@ class PolicyHandle:
         obs_b = stacked.reshape(1, -1)
         if self.vecnorm is not None:
             obs_b = self.vecnorm.normalize_obs(obs_b)
-        action, _ = self.model.predict(obs_b, deterministic=True)
+        if self.recurrent:
+            episode_start = np.array([self._first_step], dtype=bool)
+            action, self._lstm_state = self.model.predict(
+                obs_b, state=self._lstm_state,
+                episode_start=episode_start, deterministic=True,
+            )
+            self._first_step = False
+        else:
+            action, _ = self.model.predict(obs_b, deterministic=True)
         return action[0]
@@ -79,7 +103,16 @@ def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
     from stable_baselines3 import PPO
     from stable_baselines3.common.vec_env import VecNormalize  # noqa: F401
-    model = PPO.load(str(zip_path), device="auto")
+    # Try RecurrentPPO (LSTM) first, fall back to PPO (MLP).
+    recurrent = False
+    model = None
+    try:
+        from sb3_contrib import RecurrentPPO
+        model = RecurrentPPO.load(str(zip_path), device="auto")
+        recurrent = True
+    except Exception:
+        model = PPO.load(str(zip_path), device="auto")
     vecnorm = None
     if vecnorm_path and os.path.exists(vecnorm_path):
         import pickle
@@ -87,4 +120,4 @@ def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
             vecnorm = pickle.load(f)
         vecnorm.training = False
         vecnorm.norm_reward = False
-    return PolicyHandle(model=model, vecnorm=vecnorm)
+    return PolicyHandle(model=model, vecnorm=vecnorm, recurrent=recurrent)
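
The frame-stack selection in `PolicyHandle.__init__` can be read as a small pure function. The sketch below is an illustration rather than code from this commit: `infer_frame_stack` is a hypothetical helper name, and the dimensions in the usage note are made-up values, since the real `OBS_DIM` is not shown in this diff.

```python
def infer_frame_stack(policy_dim: int, obs_dim: int, recurrent: bool) -> int:
    """Return how many stacked frames the policy expects.

    Mirrors the branch order in PolicyHandle.__init__: recurrent policies
    always use a single frame; otherwise the stack depth is the ratio of the
    policy's input size to the single-frame observation size.
    """
    if recurrent:
        return 1  # the LSTM hidden state stands in for frame stacking
    if policy_dim % obs_dim == 0 and policy_dim // obs_dim >= 1:
        return policy_dim // obs_dim
    return 1  # sizes do not divide evenly: fall back to a single frame
```

With an assumed single-frame size of 10, `infer_frame_stack(40, 10, False)` gives 4, while the same call with `recurrent=True` gives 1. Checking `recurrent` first matters because a recurrent policy's input size is always a whole multiple (1x) of the frame size and would otherwise take the MLP branch.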