LSTM (RecurrentPPO) experiment + recurrent policy support
Adds RecurrentPPO-based training as an alternative to MLP+frame-stack. The LSTM gives the policy unbounded temporal memory, addressing the partial-obs failure mode of the 140° Webots LiDAR (tracker briefly empties when the dog turns; sporadic phantom tracks confuse decisions). * training/rl/train_lstm.py: from-scratch RecurrentPPO trainer (no BC init, no KL term since there's no reference). Uses HERDING_WEBOTS preset so the obs distribution matches deployment. * training/eval.py: auto-detects RecurrentPPO zips, maintains LSTM hidden state across steps, resets between episodes. * controllers/shepherd_dog/policy_loader.py: PolicyHandle supports recurrent policies — state managed inside, reset_recurrent() exposed. Result on diff/field after 3M steps: - Gym (default 360°): 69% avg success across n=1..10 - Gym (HERDING_WEBOTS preset, training env): 2% — penning 3-4/5 but rarely all 5 - Webots LiDAR 140°: 0/5 (same wall as DAgger and v1 policies) Conclusion: architectural changes (LSTM vs MLP) don't close the perception sim-to-real gap. The gym LiDAR sim doesn't faithfully reproduce Webots phantom-track distribution; any policy trained on the gym proxy fails to handle real Webots phantoms regardless of architecture. Closing this gap requires either modeling Webots phantom patterns in the gym sim (multi-day work) or Webots-in-the-loop training (very slow). See memory/lstm_results.md for details. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,174 @@
|
||||
"""Recurrent-PPO (LSTM) policy trainer for the herding env.
|
||||
|
||||
Motivation
|
||||
----------
|
||||
The MLP+frame-stack policy struggles with partial observability under
|
||||
the 140° Webots LiDAR: the tracker briefly empties when the dog turns,
|
||||
and sporadic FP tracks at static features confuse the policy. An LSTM
|
||||
gives the policy unbounded temporal memory so it can:
|
||||
|
||||
* keep modelling sheep positions when the tracker briefly drops them,
|
||||
* distinguish persistent (real) tracks from intermittent (phantom) ones.
|
||||
|
||||
This is the literature-correct fix for partial-observability + noisy
|
||||
perception. Trains from scratch (no BC init) using vanilla PPO without
|
||||
the KL-to-reference term (no reference exists when starting clean).
|
||||
|
||||
Usage
|
||||
-----
|
||||
python -m training.rl.train_lstm \\
|
||||
--out training/runs/lstm_differential_field \\
|
||||
--drive-mode differential --world field \\
|
||||
--total-timesteps 3000000 \\
|
||||
--use-webots-preset --fp-rate 0.0 --action-smooth 0.55
|
||||
|
||||
Frame stack is forced to 1 since the LSTM provides its own memory.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Configure field geometry before other herding imports read it at module level.
|
||||
from herding.world.geometry import configure_from_args as _configure_from_args
|
||||
_configure_from_args()
|
||||
|
||||
from sb3_contrib import RecurrentPPO
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
|
||||
|
||||
from herding.world.geometry import MAX_SHEEP
|
||||
from training.herding_env import HerdingEnv
|
||||
|
||||
|
||||
def _make_env(rank: int, seed: int, drive_mode: str, difficulty: float,
|
||||
max_n_sheep: int, herding_cfg):
|
||||
def _init():
|
||||
env = HerdingEnv(
|
||||
max_n_sheep=max_n_sheep, difficulty=difficulty,
|
||||
seed=seed + rank, frame_stack=1, drive_mode=drive_mode,
|
||||
herding_cfg=herding_cfg,
|
||||
)
|
||||
return env
|
||||
return _init
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--out", required=True,
|
||||
help="Output directory for the LSTM policy.")
|
||||
parser.add_argument("--total-timesteps", type=int, default=3_000_000)
|
||||
parser.add_argument("--n-envs", type=int, default=8)
|
||||
parser.add_argument("--n-steps", type=int, default=256)
|
||||
parser.add_argument("--lstm-hidden", type=int, default=128)
|
||||
parser.add_argument("--lr", type=float, default=3e-4)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--max-n-sheep", type=int, default=MAX_SHEEP)
|
||||
parser.add_argument("--difficulty", type=float, default=1.0)
|
||||
parser.add_argument("--drive-mode", default="differential",
|
||||
choices=["differential", "mecanum"])
|
||||
parser.add_argument("--world", default=None,
|
||||
choices=["field", "field_round"])
|
||||
parser.add_argument("--fp-rate", type=float, default=0.0)
|
||||
parser.add_argument("--action-smooth", type=float, default=0.55)
|
||||
parser.add_argument("--wheel-slip-std", type=float, default=0.05)
|
||||
parser.add_argument("--use-webots-preset", action="store_true",
|
||||
help="Train in the HERDING_WEBOTS env (140° FOV + tight tracker).")
|
||||
parser.add_argument("--device", default="cpu")
|
||||
args = parser.parse_args()
|
||||
|
||||
from herding.config import HerdingConfig, HERDING_WEBOTS, DomainRandomConfig, RobotConfig
|
||||
|
||||
if args.use_webots_preset:
|
||||
herding_cfg = HERDING_WEBOTS.replace(
|
||||
domain_random=DomainRandomConfig(
|
||||
fp_rate=args.fp_rate,
|
||||
wheel_slip_std=args.wheel_slip_std,
|
||||
),
|
||||
robot=RobotConfig(action_smooth=args.action_smooth),
|
||||
)
|
||||
print(f"[lstm] HERDING_WEBOTS preset + DR: fp_rate={args.fp_rate}")
|
||||
else:
|
||||
herding_cfg = None
|
||||
if args.fp_rate > 0.0 or args.action_smooth > 0.0 or args.wheel_slip_std > 0.0:
|
||||
herding_cfg = HerdingConfig(
|
||||
domain_random=DomainRandomConfig(
|
||||
fp_rate=args.fp_rate,
|
||||
wheel_slip_std=args.wheel_slip_std,
|
||||
),
|
||||
robot=RobotConfig(action_smooth=args.action_smooth),
|
||||
)
|
||||
|
||||
env_fns = [_make_env(i, args.seed, args.drive_mode, args.difficulty,
|
||||
args.max_n_sheep, herding_cfg)
|
||||
for i in range(args.n_envs)]
|
||||
venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns)
|
||||
eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, args.drive_mode,
|
||||
args.difficulty, args.max_n_sheep,
|
||||
herding_cfg)])
|
||||
|
||||
out = Path(args.out)
|
||||
out.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"[lstm] drive_mode={args.drive_mode} world={os.environ.get('HERDING_WORLD', 'field')}")
|
||||
print(f"[lstm] total_timesteps={args.total_timesteps} n_envs={args.n_envs} "
|
||||
f"lr={args.lr} lstm_hidden={args.lstm_hidden}")
|
||||
|
||||
model = RecurrentPPO(
|
||||
"MlpLstmPolicy", venv,
|
||||
learning_rate=args.lr,
|
||||
n_steps=args.n_steps,
|
||||
batch_size=args.n_steps, # full rollout = one batch (matches LSTM episode boundaries)
|
||||
n_epochs=4,
|
||||
gamma=0.99,
|
||||
gae_lambda=0.95,
|
||||
clip_range=0.2,
|
||||
ent_coef=0.0,
|
||||
max_grad_norm=0.5,
|
||||
policy_kwargs=dict(
|
||||
net_arch=dict(pi=[256, 256], vf=[256, 256]),
|
||||
lstm_hidden_size=args.lstm_hidden,
|
||||
n_lstm_layers=1,
|
||||
shared_lstm=False,
|
||||
enable_critic_lstm=True,
|
||||
),
|
||||
device=args.device,
|
||||
verbose=1,
|
||||
seed=args.seed,
|
||||
tensorboard_log=str(out / "tb"),
|
||||
)
|
||||
|
||||
eval_cb = EvalCallback(
|
||||
eval_venv,
|
||||
best_model_save_path=str(out / "best"),
|
||||
log_path=str(out / "evals"),
|
||||
eval_freq=max(args.n_steps * args.n_envs, 20_000) // args.n_envs,
|
||||
n_eval_episodes=5,
|
||||
deterministic=True,
|
||||
render=False,
|
||||
)
|
||||
|
||||
t0 = time.time()
|
||||
model.learn(total_timesteps=args.total_timesteps, callback=eval_cb,
|
||||
progress_bar=True)
|
||||
print(f"[lstm] training done in {time.time() - t0:.0f}s")
|
||||
|
||||
# Save best (by eval) if it exists; otherwise save final.
|
||||
best = out / "best" / "best_model.zip"
|
||||
if best.exists():
|
||||
import shutil
|
||||
shutil.copy(best, out / "policy.zip")
|
||||
print(f"[lstm] best snapshot → {out / 'policy.zip'}")
|
||||
else:
|
||||
model.save(str(out / "policy.zip"))
|
||||
print(f"[lstm] no eval beat init; final snapshot → {out / 'policy.zip'}")
|
||||
model.save(str(out / "final.zip"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user