LSTM (RecurrentPPO) experiment + recurrent policy support
Adds RecurrentPPO-based training as an alternative to MLP+frame-stack. The LSTM gives the policy unbounded temporal memory, addressing the partial-obs failure mode of the 140° Webots LiDAR (tracker briefly empties when the dog turns; sporadic phantom tracks confuse decisions). * training/rl/train_lstm.py: from-scratch RecurrentPPO trainer (no BC init, no KL term since there's no reference). Uses HERDING_WEBOTS preset so the obs distribution matches deployment. * training/eval.py: auto-detects RecurrentPPO zips, maintains LSTM hidden state across steps, resets between episodes. * controllers/shepherd_dog/policy_loader.py: PolicyHandle supports recurrent policies — state managed inside, reset_recurrent() exposed. Result on diff/field after 3M steps: - Gym (default 360°): 69% avg success across n=1..10 - Gym (HERDING_WEBOTS preset, training env): 2% — penning 3-4/5 but rarely all 5 - Webots LiDAR 140°: 0/5 (same wall as DAgger and v1 policies) Conclusion: architectural changes (LSTM vs MLP) don't close the perception sim-to-real gap. The gym LiDAR sim doesn't faithfully reproduce Webots phantom-track distribution; any policy trained on the gym proxy fails to handle real Webots phantoms regardless of architecture. Closing this gap requires either modeling Webots phantom patterns in the gym sim (multi-day work) or Webots-in-the-loop training (very slow). See memory/lstm_results.md for details. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
+35
-4
@@ -59,16 +59,36 @@ def make_strombom_predictor(drive_mode: str = "differential"):
|
||||
return make_analytic_predictor(strombom_action, drive_mode)
|
||||
|
||||
|
||||
def make_policy_predictor(model, vecnorm):
|
||||
def make_policy_predictor(model, vecnorm, recurrent: bool = False):
|
||||
state = {"lstm": None, "first": True}
|
||||
def _predict(_env, obs):
|
||||
obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1)
|
||||
if vecnorm is not None:
|
||||
obs_b = vecnorm.normalize_obs(obs_b)
|
||||
action, _ = model.predict(obs_b, deterministic=True)
|
||||
if recurrent:
|
||||
episode_start = np.array([state["first"]], dtype=bool)
|
||||
action, new_state = model.predict(
|
||||
obs_b, state=state["lstm"], episode_start=episode_start,
|
||||
deterministic=True,
|
||||
)
|
||||
state["lstm"] = new_state
|
||||
state["first"] = False
|
||||
else:
|
||||
action, _ = model.predict(obs_b, deterministic=True)
|
||||
return action[0]
|
||||
return _predict
|
||||
|
||||
|
||||
def _reset_recurrent(predict_fn):
|
||||
"""Reset the recurrent state between episodes."""
|
||||
# The closure stores `state` dict; reach in via __closure__.
|
||||
for cell in predict_fn.__closure__ or []:
|
||||
if isinstance(cell.cell_contents, dict) and "lstm" in cell.cell_contents:
|
||||
cell.cell_contents["lstm"] = None
|
||||
cell.cell_contents["first"] = True
|
||||
return
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--policy", required=True,
|
||||
@@ -110,7 +130,17 @@ def main():
|
||||
f"No checkpoint found in {run} "
|
||||
f"(tried policy.zip, final.zip)"
|
||||
)
|
||||
model = PPO.load(str(zip_path), device="auto")
|
||||
# Try RecurrentPPO first (sb3-contrib) for LSTM policies, then
|
||||
# fall back to PPO for MLP policies.
|
||||
recurrent = False
|
||||
model = None
|
||||
try:
|
||||
from sb3_contrib import RecurrentPPO
|
||||
model = RecurrentPPO.load(str(zip_path), device="auto")
|
||||
recurrent = True
|
||||
print(f"[eval] loaded RecurrentPPO (LSTM) policy")
|
||||
except Exception:
|
||||
model = PPO.load(str(zip_path), device="auto")
|
||||
from herding.perception.obs import OBS_DIM as _SINGLE
|
||||
policy_obs_dim = int(model.observation_space.shape[0])
|
||||
if policy_obs_dim % _SINGLE == 0 and policy_obs_dim // _SINGLE >= 1:
|
||||
@@ -127,7 +157,7 @@ def main():
|
||||
vecnorm = pickle.load(f)
|
||||
vecnorm.training = False
|
||||
vecnorm.norm_reward = False
|
||||
predict = make_policy_predictor(model, vecnorm)
|
||||
predict = make_policy_predictor(model, vecnorm, recurrent=recurrent)
|
||||
|
||||
# Infer drive_mode from policy action dim if using a learned policy.
|
||||
if args.policy not in ("strombom", "sequential"):
|
||||
@@ -149,6 +179,7 @@ def main():
|
||||
env = HerdingEnv(n_sheep=n, max_steps=args.max_steps,
|
||||
difficulty=args.difficulty, seed=seed,
|
||||
frame_stack=frame_stack, drive_mode=drive_mode)
|
||||
_reset_recurrent(predict)
|
||||
r = rollout(env, predict, args.max_steps)
|
||||
successes.append(int(r["success"]))
|
||||
steps.append(r["steps"])
|
||||
|
||||
Reference in New Issue
Block a user