LSTM (RecurrentPPO) experiment + recurrent policy support

Adds RecurrentPPO-based training as an alternative to MLP+frame-stack.
The LSTM gives the policy unbounded temporal memory, addressing the
partial-obs failure mode of the 140° Webots LiDAR (tracker briefly
empties when the dog turns; sporadic phantom tracks confuse decisions).

* training/rl/train_lstm.py: from-scratch RecurrentPPO trainer (no BC
  init, no KL term since there's no reference). Uses HERDING_WEBOTS
  preset so the obs distribution matches deployment.
* training/eval.py: auto-detects RecurrentPPO zips, maintains LSTM
  hidden state across steps, resets between episodes.
* controllers/shepherd_dog/policy_loader.py: PolicyHandle supports
  recurrent policies — state managed inside, reset_recurrent() exposed.

Result on diff/field after 3M steps:
- Gym (default 360°): 69% avg success across n=1..10
- Gym (HERDING_WEBOTS preset, training env): 2% — penning 3-4/5 but
  rarely all 5
- Webots LiDAR 140°: 0/5 (same wall as DAgger and v1 policies)

Conclusion: architectural changes (LSTM vs MLP) don't close the
perception sim-to-real gap. The gym LiDAR sim doesn't faithfully
reproduce Webots phantom-track distribution; any policy trained on the
gym proxy fails to handle real Webots phantoms regardless of
architecture. Closing this gap requires either modeling Webots phantom
patterns in the gym sim (multi-day work) or Webots-in-the-loop
training (very slow). See memory/lstm_results.md for details.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
Johnny Fernandes
2026-05-16 19:22:32 +00:00
parent dd5ac669e5
commit 876e14e74f
4 changed files with 248 additions and 10 deletions
+35 -4
View File
@@ -59,16 +59,36 @@ def make_strombom_predictor(drive_mode: str = "differential"):
return make_analytic_predictor(strombom_action, drive_mode)
def make_policy_predictor(model, vecnorm):
def make_policy_predictor(model, vecnorm, recurrent: bool = False):
state = {"lstm": None, "first": True}
def _predict(_env, obs):
obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1)
if vecnorm is not None:
obs_b = vecnorm.normalize_obs(obs_b)
action, _ = model.predict(obs_b, deterministic=True)
if recurrent:
episode_start = np.array([state["first"]], dtype=bool)
action, new_state = model.predict(
obs_b, state=state["lstm"], episode_start=episode_start,
deterministic=True,
)
state["lstm"] = new_state
state["first"] = False
else:
action, _ = model.predict(obs_b, deterministic=True)
return action[0]
return _predict
def _reset_recurrent(predict_fn):
"""Reset the recurrent state between episodes."""
# The closure stores `state` dict; reach in via __closure__.
for cell in predict_fn.__closure__ or []:
if isinstance(cell.cell_contents, dict) and "lstm" in cell.cell_contents:
cell.cell_contents["lstm"] = None
cell.cell_contents["first"] = True
return
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--policy", required=True,
@@ -110,7 +130,17 @@ def main():
f"No checkpoint found in {run} "
f"(tried policy.zip, final.zip)"
)
model = PPO.load(str(zip_path), device="auto")
# Try RecurrentPPO first (sb3-contrib) for LSTM policies, then
# fall back to PPO for MLP policies.
recurrent = False
model = None
try:
from sb3_contrib import RecurrentPPO
model = RecurrentPPO.load(str(zip_path), device="auto")
recurrent = True
print(f"[eval] loaded RecurrentPPO (LSTM) policy")
except Exception:
model = PPO.load(str(zip_path), device="auto")
from herding.perception.obs import OBS_DIM as _SINGLE
policy_obs_dim = int(model.observation_space.shape[0])
if policy_obs_dim % _SINGLE == 0 and policy_obs_dim // _SINGLE >= 1:
@@ -127,7 +157,7 @@ def main():
vecnorm = pickle.load(f)
vecnorm.training = False
vecnorm.norm_reward = False
predict = make_policy_predictor(model, vecnorm)
predict = make_policy_predictor(model, vecnorm, recurrent=recurrent)
# Infer drive_mode from policy action dim if using a learned policy.
if args.policy not in ("strombom", "sequential"):
@@ -149,6 +179,7 @@ def main():
env = HerdingEnv(n_sheep=n, max_steps=args.max_steps,
difficulty=args.difficulty, seed=seed,
frame_stack=frame_stack, drive_mode=drive_mode)
_reset_recurrent(predict)
r = rollout(env, predict, args.max_steps)
successes.append(int(r["success"]))
steps.append(r["steps"])