"""Recurrent-PPO (LSTM) policy trainer for the herding env. Motivation ---------- The MLP+frame-stack policy struggles with partial observability under the 140° Webots LiDAR: the tracker briefly empties when the dog turns, and sporadic FP tracks at static features confuse the policy. An LSTM gives the policy unbounded temporal memory so it can: * keep modelling sheep positions when the tracker briefly drops them, * distinguish persistent (real) tracks from intermittent (phantom) ones. This is the literature-correct fix for partial-observability + noisy perception. Trains from scratch (no BC init) using vanilla PPO without the KL-to-reference term (no reference exists when starting clean). Usage ----- python -m training.rl.train_lstm \\ --out training/runs/lstm_differential_field \\ --drive-mode differential --world field \\ --total-timesteps 3000000 \\ --use-webots-preset --fp-rate 0.0 --action-smooth 0.55 Frame stack is forced to 1 since the LSTM provides its own memory. """ from __future__ import annotations import argparse import os import time from pathlib import Path import numpy as np # Configure field geometry before other herding imports read it at module level. 
from herding.world.geometry import configure_from_args as _configure_from_args

_configure_from_args()

# The imports below deliberately come after _configure_from_args() so that any
# herding module reading the field geometry at import time sees the configured world.
from sb3_contrib import RecurrentPPO  # noqa: E402
from stable_baselines3.common.callbacks import EvalCallback  # noqa: E402
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv  # noqa: E402

from herding.world.geometry import MAX_SHEEP  # noqa: E402
from training.herding_env import HerdingEnv  # noqa: E402


def _make_env(rank: int, seed: int, drive_mode: str, difficulty: float,
              max_n_sheep: int, herding_cfg):
    """Return a zero-argument factory that builds one ``HerdingEnv``.

    Args:
        rank: Worker index; added to ``seed`` so every env is seeded differently.
        seed: Base seed shared by all workers.
        drive_mode: Forwarded to ``HerdingEnv`` (``"differential"`` or ``"mecanum"``).
        difficulty: Forwarded to ``HerdingEnv``.
        max_n_sheep: Forwarded to ``HerdingEnv``.
        herding_cfg: Config produced by ``_build_herding_cfg``, or ``None``.

    Returns:
        A callable in the form ``DummyVecEnv``/``SubprocVecEnv`` expect.
    """
    def _init():
        # frame_stack is pinned to 1: the LSTM supplies the temporal memory.
        return HerdingEnv(
            max_n_sheep=max_n_sheep,
            difficulty=difficulty,
            seed=seed + rank,
            frame_stack=1,
            drive_mode=drive_mode,
            herding_cfg=herding_cfg,
        )

    return _init


def _parse_args() -> argparse.Namespace:
    """Parse and validate the command line for the LSTM trainer."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--out", required=True, help="Output directory for the LSTM policy.")
    parser.add_argument("--total-timesteps", type=int, default=3_000_000)
    parser.add_argument("--n-envs", type=int, default=8)
    parser.add_argument("--n-steps", type=int, default=256)
    parser.add_argument("--lstm-hidden", type=int, default=128)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--max-n-sheep", type=int, default=MAX_SHEEP)
    parser.add_argument("--difficulty", type=float, default=1.0)
    parser.add_argument("--drive-mode", default="differential", choices=["differential", "mecanum"])
    # NOTE(review): --world is never read from `args` in this file; presumably
    # _configure_from_args() consumes it from sys.argv at import time (the world
    # is later reported via HERDING_WORLD). Declared here so argparse accepts it.
    parser.add_argument("--world", default=None, choices=["field", "field_round"])
    parser.add_argument("--fp-rate", type=float, default=0.0)
    parser.add_argument("--action-smooth", type=float, default=0.55)
    parser.add_argument("--wheel-slip-std", type=float, default=0.05)
    parser.add_argument("--use-webots-preset", action="store_true",
                        help="Train in the HERDING_WEBOTS env (140° FOV + tight tracker).")
    parser.add_argument("--device", default="cpu")
    args = parser.parse_args()
    if args.n_envs < 1:
        # A VecEnv needs at least one env, and eval_freq divides by n_envs.
        parser.error("--n-envs must be >= 1")
    return args


def _build_herding_cfg(args: argparse.Namespace):
    """Build the herding config implied by the CLI flags, or ``None`` for env defaults.

    With ``--use-webots-preset`` the HERDING_WEBOTS preset is used with its
    ``domain_random`` and ``robot`` fields replaced from the CLI flags.
    Otherwise a plain ``HerdingConfig`` is built only when at least one of
    fp_rate / action_smooth / wheel_slip_std is positive.
    """
    # Imported here (after geometry configuration and argument parsing),
    # matching the original import timing.
    from herding.config import HerdingConfig, HERDING_WEBOTS, DomainRandomConfig, RobotConfig

    if args.use_webots_preset:
        # NOTE(review): RobotConfig is built fresh, so any robot settings the
        # preset carries besides action_smooth presumably revert to
        # RobotConfig defaults — confirm this is intended.
        herding_cfg = HERDING_WEBOTS.replace(
            domain_random=DomainRandomConfig(
                fp_rate=args.fp_rate,
                wheel_slip_std=args.wheel_slip_std,
            ),
            robot=RobotConfig(action_smooth=args.action_smooth),
        )
        print(f"[lstm] HERDING_WEBOTS preset + DR: fp_rate={args.fp_rate}")
        return herding_cfg

    if args.fp_rate > 0.0 or args.action_smooth > 0.0 or args.wheel_slip_std > 0.0:
        return HerdingConfig(
            domain_random=DomainRandomConfig(
                fp_rate=args.fp_rate,
                wheel_slip_std=args.wheel_slip_std,
            ),
            robot=RobotConfig(action_smooth=args.action_smooth),
        )
    return None


def _export_policy(model, out: Path, eval_cb) -> None:
    """Write ``policy.zip`` (best eval snapshot from this run, else final weights) and ``final.zip``."""
    import shutil

    policy_path = out / "policy.zip"
    best = out / "best" / "best_model.zip"
    # EvalCallback starts best_mean_reward at -inf and saves best_model.zip every
    # time an evaluation improves on it, so a value above -inf means this run
    # wrote the file. This avoids promoting a stale snapshot left behind in a
    # reused --out directory by an earlier run.
    saved_this_run = eval_cb.best_mean_reward > float("-inf")
    if saved_this_run and best.exists():
        shutil.copy(best, policy_path)
        print(f"[lstm] best snapshot → {policy_path}")
    else:
        model.save(str(policy_path))
        print(f"[lstm] no best-eval snapshot from this run; final snapshot → {policy_path}")
    model.save(str(out / "final.zip"))


def main():
    """Train a recurrent-PPO (LSTM) herding policy and export it to ``--out``.

    Outputs under ``--out``: ``policy.zip`` (best eval snapshot when one was
    saved during this run, otherwise the final weights), ``final.zip``,
    ``best/`` and ``evals/`` from the EvalCallback, and TensorBoard logs in ``tb/``.
    """
    args = _parse_args()
    herding_cfg = _build_herding_cfg(args)

    # Create the output dir before spawning worker processes.
    out = Path(args.out)
    out.mkdir(parents=True, exist_ok=True)

    env_fns = [
        _make_env(i, args.seed, args.drive_mode, args.difficulty, args.max_n_sheep, herding_cfg)
        for i in range(args.n_envs)
    ]
    venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns)
    eval_venv = None
    try:
        eval_venv = DummyVecEnv([
            _make_env(99, args.seed + 999, args.drive_mode, args.difficulty,
                      args.max_n_sheep, herding_cfg)
        ])

        print(f"[lstm] drive_mode={args.drive_mode} world={os.environ.get('HERDING_WORLD', 'field')}")
        print(f"[lstm] total_timesteps={args.total_timesteps} n_envs={args.n_envs} "
              f"lr={args.lr} lstm_hidden={args.lstm_hidden}")

        model = RecurrentPPO(
            "MlpLstmPolicy",
            venv,
            learning_rate=args.lr,
            n_steps=args.n_steps,
            # Minibatches of n_steps timesteps over the n_steps * n_envs rollout,
            # i.e. n_envs minibatches per epoch (RecurrentPPO pads sequences
            # that cross minibatch boundaries).
            batch_size=args.n_steps,
            n_epochs=4,
            gamma=0.99,
            gae_lambda=0.95,
            clip_range=0.2,
            ent_coef=0.0,
            max_grad_norm=0.5,
            policy_kwargs=dict(
                net_arch=dict(pi=[256, 256], vf=[256, 256]),
                lstm_hidden_size=args.lstm_hidden,
                n_lstm_layers=1,
                shared_lstm=False,
                enable_critic_lstm=True,
            ),
            device=args.device,
            verbose=1,
            seed=args.seed,
            tensorboard_log=str(out / "tb"),
        )

        eval_cb = EvalCallback(
            eval_venv,
            best_model_save_path=str(out / "best"),
            log_path=str(out / "evals"),
            # eval_freq counts callback calls (each = n_envs env steps), so this
            # evaluates every max(one rollout, 20k) total env steps.
            eval_freq=max(args.n_steps * args.n_envs, 20_000) // args.n_envs,
            n_eval_episodes=5,
            deterministic=True,
            render=False,
        )

        t0 = time.time()
        model.learn(total_timesteps=args.total_timesteps, callback=eval_cb, progress_bar=True)
        print(f"[lstm] training done in {time.time() - t0:.0f}s")

        _export_policy(model, out, eval_cb)
    finally:
        # SubprocVecEnv owns worker processes; shut them down even if training fails.
        if eval_venv is not None:
            eval_venv.close()
        venv.close()


if __name__ == "__main__":
    main()