LSTM (RecurrentPPO) experiment + recurrent policy support

Adds RecurrentPPO-based training as an alternative to MLP+frame-stack.
The LSTM gives the policy unbounded temporal memory, addressing the
partial-obs failure mode of the 140° Webots LiDAR (tracker briefly
empties when the dog turns; sporadic phantom tracks confuse decisions).

* training/rl/train_lstm.py: from-scratch RecurrentPPO trainer (no BC
  init, no KL term since there's no reference). Uses HERDING_WEBOTS
  preset so the obs distribution matches deployment.
* training/eval.py: auto-detects RecurrentPPO zips, maintains LSTM
  hidden state across steps, resets between episodes.
* controllers/shepherd_dog/policy_loader.py: PolicyHandle supports
  recurrent policies — state managed inside, reset_recurrent() exposed.

Result on diff/field after 3M steps:
- Gym (default 360°): 69% avg success across n=1..10
- Gym (HERDING_WEBOTS preset, training env): 2% — penning 3-4/5 but
  rarely all 5
- Webots LiDAR 140°: 0/5 (same wall as DAgger and v1 policies)

Conclusion: architectural changes (LSTM vs MLP) don't close the
perception sim-to-real gap. The gym LiDAR sim doesn't faithfully
reproduce Webots phantom-track distribution; any policy trained on the
gym proxy fails to handle real Webots phantoms regardless of
architecture. Closing this gap requires either modeling Webots phantom
patterns in the gym sim (multi-day work) or Webots-in-the-loop
training (very slow). See memory/lstm_results.md for details.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
Johnny Fernandes
2026-05-16 19:22:32 +00:00
parent dd5ac669e5
commit 876e14e74f
4 changed files with 248 additions and 10 deletions
+35 -4
View File
@@ -59,16 +59,36 @@ def make_strombom_predictor(drive_mode: str = "differential"):
return make_analytic_predictor(strombom_action, drive_mode)
def make_policy_predictor(model, vecnorm):
def make_policy_predictor(model, vecnorm, recurrent: bool = False):
state = {"lstm": None, "first": True}
def _predict(_env, obs):
obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1)
if vecnorm is not None:
obs_b = vecnorm.normalize_obs(obs_b)
action, _ = model.predict(obs_b, deterministic=True)
if recurrent:
episode_start = np.array([state["first"]], dtype=bool)
action, new_state = model.predict(
obs_b, state=state["lstm"], episode_start=episode_start,
deterministic=True,
)
state["lstm"] = new_state
state["first"] = False
else:
action, _ = model.predict(obs_b, deterministic=True)
return action[0]
return _predict
def _reset_recurrent(predict_fn):
"""Reset the recurrent state between episodes."""
# The closure stores `state` dict; reach in via __closure__.
for cell in predict_fn.__closure__ or []:
if isinstance(cell.cell_contents, dict) and "lstm" in cell.cell_contents:
cell.cell_contents["lstm"] = None
cell.cell_contents["first"] = True
return
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--policy", required=True,
@@ -110,7 +130,17 @@ def main():
f"No checkpoint found in {run} "
f"(tried policy.zip, final.zip)"
)
model = PPO.load(str(zip_path), device="auto")
# Try RecurrentPPO first (sb3-contrib) for LSTM policies, then
# fall back to PPO for MLP policies.
recurrent = False
model = None
try:
from sb3_contrib import RecurrentPPO
model = RecurrentPPO.load(str(zip_path), device="auto")
recurrent = True
print(f"[eval] loaded RecurrentPPO (LSTM) policy")
except Exception:
model = PPO.load(str(zip_path), device="auto")
from herding.perception.obs import OBS_DIM as _SINGLE
policy_obs_dim = int(model.observation_space.shape[0])
if policy_obs_dim % _SINGLE == 0 and policy_obs_dim // _SINGLE >= 1:
@@ -127,7 +157,7 @@ def main():
vecnorm = pickle.load(f)
vecnorm.training = False
vecnorm.norm_reward = False
predict = make_policy_predictor(model, vecnorm)
predict = make_policy_predictor(model, vecnorm, recurrent=recurrent)
# Infer drive_mode from policy action dim if using a learned policy.
if args.policy not in ("strombom", "sequential"):
@@ -149,6 +179,7 @@ def main():
env = HerdingEnv(n_sheep=n, max_steps=args.max_steps,
difficulty=args.difficulty, seed=seed,
frame_stack=frame_stack, drive_mode=drive_mode)
_reset_recurrent(predict)
r = rollout(env, predict, args.max_steps)
successes.append(int(r["success"]))
steps.append(r["steps"])
+174
View File
@@ -0,0 +1,174 @@
"""Recurrent-PPO (LSTM) policy trainer for the herding env.
Motivation
----------
The MLP+frame-stack policy struggles with partial observability under
the 140° Webots LiDAR: the tracker briefly empties when the dog turns,
and sporadic FP tracks at static features confuse the policy. An LSTM
gives the policy unbounded temporal memory so it can:
* keep modelling sheep positions when the tracker briefly drops them,
* distinguish persistent (real) tracks from intermittent (phantom) ones.
This is the literature-correct fix for partial-observability + noisy
perception. Trains from scratch (no BC init) using vanilla PPO without
the KL-to-reference term (no reference exists when starting clean).
Usage
-----
python -m training.rl.train_lstm \\
--out training/runs/lstm_differential_field \\
--drive-mode differential --world field \\
--total-timesteps 3000000 \\
--use-webots-preset --fp-rate 0.0 --action-smooth 0.55
Frame stack is forced to 1 since the LSTM provides its own memory.
"""
from __future__ import annotations
import argparse
import os
import time
from pathlib import Path
import numpy as np
# Configure field geometry before other herding imports read it at module level.
from herding.world.geometry import configure_from_args as _configure_from_args
_configure_from_args()
from sb3_contrib import RecurrentPPO
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from herding.world.geometry import MAX_SHEEP
from training.herding_env import HerdingEnv
def _make_env(rank: int, seed: int, drive_mode: str, difficulty: float,
max_n_sheep: int, herding_cfg):
def _init():
env = HerdingEnv(
max_n_sheep=max_n_sheep, difficulty=difficulty,
seed=seed + rank, frame_stack=1, drive_mode=drive_mode,
herding_cfg=herding_cfg,
)
return env
return _init
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--out", required=True,
help="Output directory for the LSTM policy.")
parser.add_argument("--total-timesteps", type=int, default=3_000_000)
parser.add_argument("--n-envs", type=int, default=8)
parser.add_argument("--n-steps", type=int, default=256)
parser.add_argument("--lstm-hidden", type=int, default=128)
parser.add_argument("--lr", type=float, default=3e-4)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--max-n-sheep", type=int, default=MAX_SHEEP)
parser.add_argument("--difficulty", type=float, default=1.0)
parser.add_argument("--drive-mode", default="differential",
choices=["differential", "mecanum"])
parser.add_argument("--world", default=None,
choices=["field", "field_round"])
parser.add_argument("--fp-rate", type=float, default=0.0)
parser.add_argument("--action-smooth", type=float, default=0.55)
parser.add_argument("--wheel-slip-std", type=float, default=0.05)
parser.add_argument("--use-webots-preset", action="store_true",
help="Train in the HERDING_WEBOTS env (140° FOV + tight tracker).")
parser.add_argument("--device", default="cpu")
args = parser.parse_args()
from herding.config import HerdingConfig, HERDING_WEBOTS, DomainRandomConfig, RobotConfig
if args.use_webots_preset:
herding_cfg = HERDING_WEBOTS.replace(
domain_random=DomainRandomConfig(
fp_rate=args.fp_rate,
wheel_slip_std=args.wheel_slip_std,
),
robot=RobotConfig(action_smooth=args.action_smooth),
)
print(f"[lstm] HERDING_WEBOTS preset + DR: fp_rate={args.fp_rate}")
else:
herding_cfg = None
if args.fp_rate > 0.0 or args.action_smooth > 0.0 or args.wheel_slip_std > 0.0:
herding_cfg = HerdingConfig(
domain_random=DomainRandomConfig(
fp_rate=args.fp_rate,
wheel_slip_std=args.wheel_slip_std,
),
robot=RobotConfig(action_smooth=args.action_smooth),
)
env_fns = [_make_env(i, args.seed, args.drive_mode, args.difficulty,
args.max_n_sheep, herding_cfg)
for i in range(args.n_envs)]
venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns)
eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, args.drive_mode,
args.difficulty, args.max_n_sheep,
herding_cfg)])
out = Path(args.out)
out.mkdir(parents=True, exist_ok=True)
print(f"[lstm] drive_mode={args.drive_mode} world={os.environ.get('HERDING_WORLD', 'field')}")
print(f"[lstm] total_timesteps={args.total_timesteps} n_envs={args.n_envs} "
f"lr={args.lr} lstm_hidden={args.lstm_hidden}")
model = RecurrentPPO(
"MlpLstmPolicy", venv,
learning_rate=args.lr,
n_steps=args.n_steps,
batch_size=args.n_steps, # full rollout = one batch (matches LSTM episode boundaries)
n_epochs=4,
gamma=0.99,
gae_lambda=0.95,
clip_range=0.2,
ent_coef=0.0,
max_grad_norm=0.5,
policy_kwargs=dict(
net_arch=dict(pi=[256, 256], vf=[256, 256]),
lstm_hidden_size=args.lstm_hidden,
n_lstm_layers=1,
shared_lstm=False,
enable_critic_lstm=True,
),
device=args.device,
verbose=1,
seed=args.seed,
tensorboard_log=str(out / "tb"),
)
eval_cb = EvalCallback(
eval_venv,
best_model_save_path=str(out / "best"),
log_path=str(out / "evals"),
eval_freq=max(args.n_steps * args.n_envs, 20_000) // args.n_envs,
n_eval_episodes=5,
deterministic=True,
render=False,
)
t0 = time.time()
model.learn(total_timesteps=args.total_timesteps, callback=eval_cb,
progress_bar=True)
print(f"[lstm] training done in {time.time() - t0:.0f}s")
# Save best (by eval) if it exists; otherwise save final.
best = out / "best" / "best_model.zip"
if best.exists():
import shutil
shutil.copy(best, out / "policy.zip")
print(f"[lstm] best snapshot → {out / 'policy.zip'}")
else:
model.save(str(out / "policy.zip"))
print(f"[lstm] no eval beat init; final snapshot → {out / 'policy.zip'}")
model.save(str(out / "final.zip"))
if __name__ == "__main__":
main()
Binary file not shown.