LSTM (RecurrentPPO) experiment + recurrent policy support
Adds RecurrentPPO-based training as an alternative to MLP+frame-stack. The LSTM gives the policy unbounded temporal memory, addressing the partial-obs failure mode of the 140° Webots LiDAR (tracker briefly empties when the dog turns; sporadic phantom tracks confuse decisions). * training/rl/train_lstm.py: from-scratch RecurrentPPO trainer (no BC init, no KL term since there's no reference). Uses HERDING_WEBOTS preset so the obs distribution matches deployment. * training/eval.py: auto-detects RecurrentPPO zips, maintains LSTM hidden state across steps, resets between episodes. * controllers/shepherd_dog/policy_loader.py: PolicyHandle supports recurrent policies — state managed inside, reset_recurrent() exposed. Result on diff/field after 3M steps: - Gym (default 360°): 69% avg success across n=1..10 - Gym (HERDING_WEBOTS preset, training env): 2% — penning 3-4/5 but rarely all 5 - Webots LiDAR 140°: 0/5 (same wall as DAgger and v1 policies) Conclusion: architectural changes (LSTM vs MLP) don't close the perception sim-to-real gap. The gym LiDAR sim doesn't faithfully reproduce Webots phantom-track distribution; any policy trained on the gym proxy fails to handle real Webots phantoms regardless of architecture. Closing this gap requires either modeling Webots phantom patterns in the gym sim (multi-day work) or Webots-in-the-loop training (very slow). See memory/lstm_results.md for details. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -15,19 +15,35 @@ from pathlib import Path
|
|||||||
|
|
||||||
|
|
||||||
class PolicyHandle:
|
class PolicyHandle:
|
||||||
"""Wrap a loaded policy (+ optional VecNormalize) for ``predict(obs)``."""
|
"""Wrap a loaded policy (+ optional VecNormalize) for ``predict(obs)``.
|
||||||
|
|
||||||
def __init__(self, model, vecnorm):
|
Supports both MLP (PPO) and recurrent (RecurrentPPO/LSTM) policies.
|
||||||
|
For LSTM policies, frame_stack is forced to 1 and the LSTM hidden
|
||||||
|
state is maintained across calls; ``reset_recurrent`` is exposed for
|
||||||
|
new episodes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model, vecnorm, recurrent: bool = False):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.vecnorm = vecnorm
|
self.vecnorm = vecnorm
|
||||||
|
self.recurrent = recurrent
|
||||||
from herding.perception.obs import OBS_DIM
|
from herding.perception.obs import OBS_DIM
|
||||||
policy_dim = int(model.observation_space.shape[0])
|
policy_dim = int(model.observation_space.shape[0])
|
||||||
if policy_dim % OBS_DIM == 0 and policy_dim // OBS_DIM >= 1:
|
if recurrent:
|
||||||
|
self.frame_stack = 1
|
||||||
|
elif policy_dim % OBS_DIM == 0 and policy_dim // OBS_DIM >= 1:
|
||||||
self.frame_stack = policy_dim // OBS_DIM
|
self.frame_stack = policy_dim // OBS_DIM
|
||||||
else:
|
else:
|
||||||
self.frame_stack = 1
|
self.frame_stack = 1
|
||||||
self._buffer: list = []
|
self._buffer: list = []
|
||||||
self._single_dim = OBS_DIM
|
self._single_dim = OBS_DIM
|
||||||
|
self._lstm_state = None
|
||||||
|
self._first_step = True
|
||||||
|
|
||||||
|
def reset_recurrent(self):
|
||||||
|
self._lstm_state = None
|
||||||
|
self._first_step = True
|
||||||
|
self._buffer = []
|
||||||
|
|
||||||
def predict(self, obs):
|
def predict(self, obs):
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -49,7 +65,15 @@ class PolicyHandle:
|
|||||||
obs_b = stacked.reshape(1, -1)
|
obs_b = stacked.reshape(1, -1)
|
||||||
if self.vecnorm is not None:
|
if self.vecnorm is not None:
|
||||||
obs_b = self.vecnorm.normalize_obs(obs_b)
|
obs_b = self.vecnorm.normalize_obs(obs_b)
|
||||||
action, _ = self.model.predict(obs_b, deterministic=True)
|
if self.recurrent:
|
||||||
|
episode_start = np.array([self._first_step], dtype=bool)
|
||||||
|
action, self._lstm_state = self.model.predict(
|
||||||
|
obs_b, state=self._lstm_state,
|
||||||
|
episode_start=episode_start, deterministic=True,
|
||||||
|
)
|
||||||
|
self._first_step = False
|
||||||
|
else:
|
||||||
|
action, _ = self.model.predict(obs_b, deterministic=True)
|
||||||
return action[0]
|
return action[0]
|
||||||
|
|
||||||
|
|
||||||
@@ -79,7 +103,16 @@ def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
|
|||||||
from stable_baselines3 import PPO
|
from stable_baselines3 import PPO
|
||||||
from stable_baselines3.common.vec_env import VecNormalize # noqa: F401
|
from stable_baselines3.common.vec_env import VecNormalize # noqa: F401
|
||||||
|
|
||||||
model = PPO.load(str(zip_path), device="auto")
|
# Try RecurrentPPO (LSTM) first, fall back to PPO (MLP).
|
||||||
|
recurrent = False
|
||||||
|
model = None
|
||||||
|
try:
|
||||||
|
from sb3_contrib import RecurrentPPO
|
||||||
|
model = RecurrentPPO.load(str(zip_path), device="auto")
|
||||||
|
recurrent = True
|
||||||
|
except Exception:
|
||||||
|
model = PPO.load(str(zip_path), device="auto")
|
||||||
|
|
||||||
vecnorm = None
|
vecnorm = None
|
||||||
if vecnorm_path and os.path.exists(vecnorm_path):
|
if vecnorm_path and os.path.exists(vecnorm_path):
|
||||||
import pickle
|
import pickle
|
||||||
@@ -87,4 +120,4 @@ def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle:
|
|||||||
vecnorm = pickle.load(f)
|
vecnorm = pickle.load(f)
|
||||||
vecnorm.training = False
|
vecnorm.training = False
|
||||||
vecnorm.norm_reward = False
|
vecnorm.norm_reward = False
|
||||||
return PolicyHandle(model=model, vecnorm=vecnorm)
|
return PolicyHandle(model=model, vecnorm=vecnorm, recurrent=recurrent)
|
||||||
|
|||||||
+35
-4
@@ -59,16 +59,36 @@ def make_strombom_predictor(drive_mode: str = "differential"):
|
|||||||
return make_analytic_predictor(strombom_action, drive_mode)
|
return make_analytic_predictor(strombom_action, drive_mode)
|
||||||
|
|
||||||
|
|
||||||
def make_policy_predictor(model, vecnorm):
|
def make_policy_predictor(model, vecnorm, recurrent: bool = False):
|
||||||
|
state = {"lstm": None, "first": True}
|
||||||
def _predict(_env, obs):
|
def _predict(_env, obs):
|
||||||
obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1)
|
obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1)
|
||||||
if vecnorm is not None:
|
if vecnorm is not None:
|
||||||
obs_b = vecnorm.normalize_obs(obs_b)
|
obs_b = vecnorm.normalize_obs(obs_b)
|
||||||
action, _ = model.predict(obs_b, deterministic=True)
|
if recurrent:
|
||||||
|
episode_start = np.array([state["first"]], dtype=bool)
|
||||||
|
action, new_state = model.predict(
|
||||||
|
obs_b, state=state["lstm"], episode_start=episode_start,
|
||||||
|
deterministic=True,
|
||||||
|
)
|
||||||
|
state["lstm"] = new_state
|
||||||
|
state["first"] = False
|
||||||
|
else:
|
||||||
|
action, _ = model.predict(obs_b, deterministic=True)
|
||||||
return action[0]
|
return action[0]
|
||||||
return _predict
|
return _predict
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_recurrent(predict_fn):
|
||||||
|
"""Reset the recurrent state between episodes."""
|
||||||
|
# The closure stores `state` dict; reach in via __closure__.
|
||||||
|
for cell in predict_fn.__closure__ or []:
|
||||||
|
if isinstance(cell.cell_contents, dict) and "lstm" in cell.cell_contents:
|
||||||
|
cell.cell_contents["lstm"] = None
|
||||||
|
cell.cell_contents["first"] = True
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--policy", required=True,
|
parser.add_argument("--policy", required=True,
|
||||||
@@ -110,7 +130,17 @@ def main():
|
|||||||
f"No checkpoint found in {run} "
|
f"No checkpoint found in {run} "
|
||||||
f"(tried policy.zip, final.zip)"
|
f"(tried policy.zip, final.zip)"
|
||||||
)
|
)
|
||||||
model = PPO.load(str(zip_path), device="auto")
|
# Try RecurrentPPO first (sb3-contrib) for LSTM policies, then
|
||||||
|
# fall back to PPO for MLP policies.
|
||||||
|
recurrent = False
|
||||||
|
model = None
|
||||||
|
try:
|
||||||
|
from sb3_contrib import RecurrentPPO
|
||||||
|
model = RecurrentPPO.load(str(zip_path), device="auto")
|
||||||
|
recurrent = True
|
||||||
|
print(f"[eval] loaded RecurrentPPO (LSTM) policy")
|
||||||
|
except Exception:
|
||||||
|
model = PPO.load(str(zip_path), device="auto")
|
||||||
from herding.perception.obs import OBS_DIM as _SINGLE
|
from herding.perception.obs import OBS_DIM as _SINGLE
|
||||||
policy_obs_dim = int(model.observation_space.shape[0])
|
policy_obs_dim = int(model.observation_space.shape[0])
|
||||||
if policy_obs_dim % _SINGLE == 0 and policy_obs_dim // _SINGLE >= 1:
|
if policy_obs_dim % _SINGLE == 0 and policy_obs_dim // _SINGLE >= 1:
|
||||||
@@ -127,7 +157,7 @@ def main():
|
|||||||
vecnorm = pickle.load(f)
|
vecnorm = pickle.load(f)
|
||||||
vecnorm.training = False
|
vecnorm.training = False
|
||||||
vecnorm.norm_reward = False
|
vecnorm.norm_reward = False
|
||||||
predict = make_policy_predictor(model, vecnorm)
|
predict = make_policy_predictor(model, vecnorm, recurrent=recurrent)
|
||||||
|
|
||||||
# Infer drive_mode from policy action dim if using a learned policy.
|
# Infer drive_mode from policy action dim if using a learned policy.
|
||||||
if args.policy not in ("strombom", "sequential"):
|
if args.policy not in ("strombom", "sequential"):
|
||||||
@@ -149,6 +179,7 @@ def main():
|
|||||||
env = HerdingEnv(n_sheep=n, max_steps=args.max_steps,
|
env = HerdingEnv(n_sheep=n, max_steps=args.max_steps,
|
||||||
difficulty=args.difficulty, seed=seed,
|
difficulty=args.difficulty, seed=seed,
|
||||||
frame_stack=frame_stack, drive_mode=drive_mode)
|
frame_stack=frame_stack, drive_mode=drive_mode)
|
||||||
|
_reset_recurrent(predict)
|
||||||
r = rollout(env, predict, args.max_steps)
|
r = rollout(env, predict, args.max_steps)
|
||||||
successes.append(int(r["success"]))
|
successes.append(int(r["success"]))
|
||||||
steps.append(r["steps"])
|
steps.append(r["steps"])
|
||||||
|
|||||||
@@ -0,0 +1,174 @@
|
|||||||
|
"""Recurrent-PPO (LSTM) policy trainer for the herding env.
|
||||||
|
|
||||||
|
Motivation
|
||||||
|
----------
|
||||||
|
The MLP+frame-stack policy struggles with partial observability under
|
||||||
|
the 140° Webots LiDAR: the tracker briefly empties when the dog turns,
|
||||||
|
and sporadic FP tracks at static features confuse the policy. An LSTM
|
||||||
|
gives the policy unbounded temporal memory so it can:
|
||||||
|
|
||||||
|
* keep modelling sheep positions when the tracker briefly drops them,
|
||||||
|
* distinguish persistent (real) tracks from intermittent (phantom) ones.
|
||||||
|
|
||||||
|
This is the literature-correct fix for partial-observability + noisy
|
||||||
|
perception. Trains from scratch (no BC init) using vanilla PPO without
|
||||||
|
the KL-to-reference term (no reference exists when starting clean).
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
python -m training.rl.train_lstm \\
|
||||||
|
--out training/runs/lstm_differential_field \\
|
||||||
|
--drive-mode differential --world field \\
|
||||||
|
--total-timesteps 3000000 \\
|
||||||
|
--use-webots-preset --fp-rate 0.0 --action-smooth 0.55
|
||||||
|
|
||||||
|
Frame stack is forced to 1 since the LSTM provides its own memory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Configure field geometry before other herding imports read it at module level.
|
||||||
|
from herding.world.geometry import configure_from_args as _configure_from_args
|
||||||
|
_configure_from_args()
|
||||||
|
|
||||||
|
from sb3_contrib import RecurrentPPO
|
||||||
|
from stable_baselines3.common.callbacks import EvalCallback
|
||||||
|
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
|
||||||
|
|
||||||
|
from herding.world.geometry import MAX_SHEEP
|
||||||
|
from training.herding_env import HerdingEnv
|
||||||
|
|
||||||
|
|
||||||
|
def _make_env(rank: int, seed: int, drive_mode: str, difficulty: float,
|
||||||
|
max_n_sheep: int, herding_cfg):
|
||||||
|
def _init():
|
||||||
|
env = HerdingEnv(
|
||||||
|
max_n_sheep=max_n_sheep, difficulty=difficulty,
|
||||||
|
seed=seed + rank, frame_stack=1, drive_mode=drive_mode,
|
||||||
|
herding_cfg=herding_cfg,
|
||||||
|
)
|
||||||
|
return env
|
||||||
|
return _init
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--out", required=True,
|
||||||
|
help="Output directory for the LSTM policy.")
|
||||||
|
parser.add_argument("--total-timesteps", type=int, default=3_000_000)
|
||||||
|
parser.add_argument("--n-envs", type=int, default=8)
|
||||||
|
parser.add_argument("--n-steps", type=int, default=256)
|
||||||
|
parser.add_argument("--lstm-hidden", type=int, default=128)
|
||||||
|
parser.add_argument("--lr", type=float, default=3e-4)
|
||||||
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
|
parser.add_argument("--max-n-sheep", type=int, default=MAX_SHEEP)
|
||||||
|
parser.add_argument("--difficulty", type=float, default=1.0)
|
||||||
|
parser.add_argument("--drive-mode", default="differential",
|
||||||
|
choices=["differential", "mecanum"])
|
||||||
|
parser.add_argument("--world", default=None,
|
||||||
|
choices=["field", "field_round"])
|
||||||
|
parser.add_argument("--fp-rate", type=float, default=0.0)
|
||||||
|
parser.add_argument("--action-smooth", type=float, default=0.55)
|
||||||
|
parser.add_argument("--wheel-slip-std", type=float, default=0.05)
|
||||||
|
parser.add_argument("--use-webots-preset", action="store_true",
|
||||||
|
help="Train in the HERDING_WEBOTS env (140° FOV + tight tracker).")
|
||||||
|
parser.add_argument("--device", default="cpu")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
from herding.config import HerdingConfig, HERDING_WEBOTS, DomainRandomConfig, RobotConfig
|
||||||
|
|
||||||
|
if args.use_webots_preset:
|
||||||
|
herding_cfg = HERDING_WEBOTS.replace(
|
||||||
|
domain_random=DomainRandomConfig(
|
||||||
|
fp_rate=args.fp_rate,
|
||||||
|
wheel_slip_std=args.wheel_slip_std,
|
||||||
|
),
|
||||||
|
robot=RobotConfig(action_smooth=args.action_smooth),
|
||||||
|
)
|
||||||
|
print(f"[lstm] HERDING_WEBOTS preset + DR: fp_rate={args.fp_rate}")
|
||||||
|
else:
|
||||||
|
herding_cfg = None
|
||||||
|
if args.fp_rate > 0.0 or args.action_smooth > 0.0 or args.wheel_slip_std > 0.0:
|
||||||
|
herding_cfg = HerdingConfig(
|
||||||
|
domain_random=DomainRandomConfig(
|
||||||
|
fp_rate=args.fp_rate,
|
||||||
|
wheel_slip_std=args.wheel_slip_std,
|
||||||
|
),
|
||||||
|
robot=RobotConfig(action_smooth=args.action_smooth),
|
||||||
|
)
|
||||||
|
|
||||||
|
env_fns = [_make_env(i, args.seed, args.drive_mode, args.difficulty,
|
||||||
|
args.max_n_sheep, herding_cfg)
|
||||||
|
for i in range(args.n_envs)]
|
||||||
|
venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns)
|
||||||
|
eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, args.drive_mode,
|
||||||
|
args.difficulty, args.max_n_sheep,
|
||||||
|
herding_cfg)])
|
||||||
|
|
||||||
|
out = Path(args.out)
|
||||||
|
out.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
print(f"[lstm] drive_mode={args.drive_mode} world={os.environ.get('HERDING_WORLD', 'field')}")
|
||||||
|
print(f"[lstm] total_timesteps={args.total_timesteps} n_envs={args.n_envs} "
|
||||||
|
f"lr={args.lr} lstm_hidden={args.lstm_hidden}")
|
||||||
|
|
||||||
|
model = RecurrentPPO(
|
||||||
|
"MlpLstmPolicy", venv,
|
||||||
|
learning_rate=args.lr,
|
||||||
|
n_steps=args.n_steps,
|
||||||
|
batch_size=args.n_steps, # full rollout = one batch (matches LSTM episode boundaries)
|
||||||
|
n_epochs=4,
|
||||||
|
gamma=0.99,
|
||||||
|
gae_lambda=0.95,
|
||||||
|
clip_range=0.2,
|
||||||
|
ent_coef=0.0,
|
||||||
|
max_grad_norm=0.5,
|
||||||
|
policy_kwargs=dict(
|
||||||
|
net_arch=dict(pi=[256, 256], vf=[256, 256]),
|
||||||
|
lstm_hidden_size=args.lstm_hidden,
|
||||||
|
n_lstm_layers=1,
|
||||||
|
shared_lstm=False,
|
||||||
|
enable_critic_lstm=True,
|
||||||
|
),
|
||||||
|
device=args.device,
|
||||||
|
verbose=1,
|
||||||
|
seed=args.seed,
|
||||||
|
tensorboard_log=str(out / "tb"),
|
||||||
|
)
|
||||||
|
|
||||||
|
eval_cb = EvalCallback(
|
||||||
|
eval_venv,
|
||||||
|
best_model_save_path=str(out / "best"),
|
||||||
|
log_path=str(out / "evals"),
|
||||||
|
eval_freq=max(args.n_steps * args.n_envs, 20_000) // args.n_envs,
|
||||||
|
n_eval_episodes=5,
|
||||||
|
deterministic=True,
|
||||||
|
render=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
model.learn(total_timesteps=args.total_timesteps, callback=eval_cb,
|
||||||
|
progress_bar=True)
|
||||||
|
print(f"[lstm] training done in {time.time() - t0:.0f}s")
|
||||||
|
|
||||||
|
# Save best (by eval) if it exists; otherwise save final.
|
||||||
|
best = out / "best" / "best_model.zip"
|
||||||
|
if best.exists():
|
||||||
|
import shutil
|
||||||
|
shutil.copy(best, out / "policy.zip")
|
||||||
|
print(f"[lstm] best snapshot → {out / 'policy.zip'}")
|
||||||
|
else:
|
||||||
|
model.save(str(out / "policy.zip"))
|
||||||
|
print(f"[lstm] no eval beat init; final snapshot → {out / 'policy.zip'}")
|
||||||
|
model.save(str(out / "final.zip"))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Binary file not shown.
Reference in New Issue
Block a user