Files
TIR_PROJ/training/rl/train_lstm.py
T
Johnny Fernandes 876e14e74f LSTM (RecurrentPPO) experiment + recurrent policy support
Adds RecurrentPPO-based training as an alternative to MLP+frame-stack.
The LSTM gives the policy unbounded temporal memory, addressing the
partial-obs failure mode of the 140° Webots LiDAR (tracker briefly
empties when the dog turns; sporadic phantom tracks confuse decisions).

* training/rl/train_lstm.py: from-scratch RecurrentPPO trainer (no BC
  init, no KL term since there's no reference). Uses HERDING_WEBOTS
  preset so the obs distribution matches deployment.
* training/eval.py: auto-detects RecurrentPPO zips, maintains LSTM
  hidden state across steps, resets between episodes.
* controllers/shepherd_dog/policy_loader.py: PolicyHandle supports
  recurrent policies — state managed inside, reset_recurrent() exposed.

Results (differential drive, field world) after 3M training steps:
- Gym (default 360°): 69% avg success across n=1..10
- Gym (HERDING_WEBOTS preset, training env): 2% — penning 3-4/5 but
  rarely all 5
- Webots LiDAR 140°: 0/5 (same wall as DAgger and v1 policies)

Conclusion: architectural changes (LSTM vs MLP) don't close the
perception sim-to-real gap. The gym LiDAR sim doesn't faithfully
reproduce Webots phantom-track distribution; any policy trained on the
gym proxy fails to handle real Webots phantoms regardless of
architecture. Closing this gap requires either modeling Webots phantom
patterns in the gym sim (multi-day work) or Webots-in-the-loop
training (very slow). See memory/lstm_results.md for details.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-16 19:22:32 +00:00

175 lines
6.5 KiB
Python

"""Recurrent-PPO (LSTM) policy trainer for the herding env.

Motivation
----------
The MLP+frame-stack policy struggles with partial observability under
the 140° Webots LiDAR: the tracker briefly empties when the dog turns,
and sporadic false-positive (FP) tracks at static features confuse the
policy. An LSTM gives the policy recurrent memory that is not limited to
a fixed frame-stack window, so it can:

* keep modelling sheep positions when the tracker briefly drops them,
* distinguish persistent (real) tracks from intermittent (phantom) ones.

This is the standard remedy for partial observability combined with
noisy perception. Training starts from scratch (no BC init) and uses
vanilla PPO without the KL-to-reference term (no reference policy exists
when starting clean).

Usage
-----
    python -m training.rl.train_lstm \\
        --out training/runs/lstm_differential_field \\
        --drive-mode differential --world field \\
        --total-timesteps 3000000 \\
        --use-webots-preset --fp-rate 0.0 --action-smooth 0.55

Frame stack is forced to 1 since the LSTM provides its own memory.
"""
from __future__ import annotations
import argparse
import os
import time
from pathlib import Path
import numpy as np
# Configure field geometry before other herding imports read it at module level.
from herding.world.geometry import configure_from_args as _configure_from_args
_configure_from_args()
from sb3_contrib import RecurrentPPO
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from herding.world.geometry import MAX_SHEEP
from training.herding_env import HerdingEnv
def _make_env(rank: int, seed: int, drive_mode: str, difficulty: float,
max_n_sheep: int, herding_cfg):
def _init():
env = HerdingEnv(
max_n_sheep=max_n_sheep, difficulty=difficulty,
seed=seed + rank, frame_stack=1, drive_mode=drive_mode,
herding_cfg=herding_cfg,
)
return env
return _init
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the LSTM trainer."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--out", required=True,
                        help="Output directory for the LSTM policy.")
    parser.add_argument("--total-timesteps", type=int, default=3_000_000)
    parser.add_argument("--n-envs", type=int, default=8)
    parser.add_argument("--n-steps", type=int, default=256)
    parser.add_argument("--lstm-hidden", type=int, default=128)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--max-n-sheep", type=int, default=MAX_SHEEP)
    parser.add_argument("--difficulty", type=float, default=1.0)
    parser.add_argument("--drive-mode", default="differential",
                        choices=["differential", "mecanum"])
    # --world is not read in this module; it is declared so argparse accepts it.
    # NOTE(review): presumably _configure_from_args() consumes it at import
    # time (and sets HERDING_WORLD, which main() prints) — confirm.
    parser.add_argument("--world", default=None,
                        choices=["field", "field_round"])
    parser.add_argument("--fp-rate", type=float, default=0.0)
    parser.add_argument("--action-smooth", type=float, default=0.55)
    parser.add_argument("--wheel-slip-std", type=float, default=0.05)
    parser.add_argument("--use-webots-preset", action="store_true",
                        help="Train in the HERDING_WEBOTS env (140° FOV + tight tracker).")
    parser.add_argument("--device", default="cpu")
    return parser


def _build_herding_cfg(args: argparse.Namespace):
    """Build the env config from CLI args, or return None.

    With ``--use-webots-preset`` the HERDING_WEBOTS preset is used, with its
    domain-randomisation and robot sub-configs overridden from the CLI.
    Otherwise a plain ``HerdingConfig`` is built only when any DR/robot knob
    is positive. The default ``--action-smooth`` / ``--wheel-slip-std`` are
    positive, so that is the usual path. ``None`` is returned only when every
    knob is zero.
    """
    from herding.config import (
        HERDING_WEBOTS, DomainRandomConfig, HerdingConfig, RobotConfig,
    )

    def _overrides():
        # Fresh sub-configs carrying only the CLI-controlled fields.
        return dict(
            domain_random=DomainRandomConfig(
                fp_rate=args.fp_rate,
                wheel_slip_std=args.wheel_slip_std,
            ),
            robot=RobotConfig(action_smooth=args.action_smooth),
        )

    if args.use_webots_preset:
        # NOTE(review): this swaps in newly constructed sub-configs, so any
        # other non-default DomainRandomConfig / RobotConfig fields that the
        # HERDING_WEBOTS preset sets are presumably reset to class defaults —
        # confirm that is intended.
        cfg = HERDING_WEBOTS.replace(**_overrides())
        print(f"[lstm] HERDING_WEBOTS preset + DR: fp_rate={args.fp_rate} "
              f"wheel_slip_std={args.wheel_slip_std} "
              f"action_smooth={args.action_smooth}")
        return cfg
    if args.fp_rate > 0.0 or args.action_smooth > 0.0 or args.wheel_slip_std > 0.0:
        return HerdingConfig(**_overrides())
    return None


def _build_model(args: argparse.Namespace, venv, out: Path):
    """Construct the RecurrentPPO learner (separate actor and critic LSTMs)."""
    return RecurrentPPO(
        "MlpLstmPolicy", venv,
        learning_rate=args.lr,
        n_steps=args.n_steps,
        # Each rollout holds n_steps * n_envs transitions, so a minibatch of
        # n_steps gives n_envs minibatches per epoch (not one full-rollout
        # batch).
        batch_size=args.n_steps,
        n_epochs=4,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        ent_coef=0.0,
        max_grad_norm=0.5,
        policy_kwargs=dict(
            net_arch=dict(pi=[256, 256], vf=[256, 256]),
            lstm_hidden_size=args.lstm_hidden,
            n_lstm_layers=1,
            shared_lstm=False,
            enable_critic_lstm=True,
        ),
        device=args.device,
        verbose=1,
        seed=args.seed,
        tensorboard_log=str(out / "tb"),
    )


def _export_policy(model, eval_cb, out: Path) -> None:
    """Write ``policy.zip`` (this run's best eval snapshot, else final) and ``final.zip``."""
    import shutil

    best = out / "best" / "best_model.zip"
    policy = out / "policy.zip"
    # EvalCallback starts best_mean_reward at -inf and saves best_model.zip
    # each time it improves, so a value above -inf means *this* run wrote the
    # file. Checking best.exists() alone would pick up a stale snapshot left
    # by an earlier run in the same --out directory.
    if eval_cb.best_mean_reward > -np.inf and best.exists():
        shutil.copy(best, policy)
        print(f"[lstm] best snapshot → {policy}")
    else:
        model.save(str(policy))
        print(f"[lstm] no eval snapshot from this run; final snapshot → {policy}")
    model.save(str(out / "final.zip"))


def main():
    """Parse CLI args, train a RecurrentPPO policy, and export it under --out.

    Writes ``policy.zip`` (best evaluated snapshot of this run, or the final
    model if no evaluation produced one), ``final.zip``, and the EvalCallback
    and TensorBoard artifacts under ``--out``.
    """
    from contextlib import ExitStack

    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.n_envs < 1:
        parser.error("--n-envs must be >= 1")
    herding_cfg = _build_herding_cfg(args)

    env_fns = [_make_env(i, args.seed, args.drive_mode, args.difficulty,
                         args.max_n_sheep, herding_cfg)
               for i in range(args.n_envs)]
    eval_fn = _make_env(99, args.seed + 999, args.drive_mode,
                        args.difficulty, args.max_n_sheep, herding_cfg)
    vec_cls = SubprocVecEnv if args.n_envs > 1 else DummyVecEnv

    # ExitStack closes the vec-envs (including SubprocVecEnv worker
    # processes) even if training raises.
    with ExitStack() as stack:
        venv = vec_cls(env_fns)
        stack.callback(venv.close)
        eval_venv = DummyVecEnv([eval_fn])
        stack.callback(eval_venv.close)

        out = Path(args.out)
        out.mkdir(parents=True, exist_ok=True)
        print(f"[lstm] drive_mode={args.drive_mode} world={os.environ.get('HERDING_WORLD', 'field')}")
        print(f"[lstm] total_timesteps={args.total_timesteps} n_envs={args.n_envs} "
              f"lr={args.lr} lstm_hidden={args.lstm_hidden}")

        model = _build_model(args, venv, out)
        eval_cb = EvalCallback(
            eval_venv,
            best_model_save_path=str(out / "best"),
            log_path=str(out / "evals"),
            # eval_freq counts callback calls (one per vec-env step, i.e.
            # n_envs env steps), so divide to evaluate roughly every
            # max(rollout, 20k) env steps.
            eval_freq=max(args.n_steps * args.n_envs, 20_000) // args.n_envs,
            n_eval_episodes=5,
            deterministic=True,
            render=False,
        )

        t0 = time.perf_counter()
        model.learn(total_timesteps=args.total_timesteps, callback=eval_cb,
                    progress_bar=True)
        print(f"[lstm] training done in {time.perf_counter() - t0:.0f}s")

        _export_policy(model, eval_cb, out)


if __name__ == "__main__":
    main()