diff --git a/controllers/shepherd_dog/policy_loader.py b/controllers/shepherd_dog/policy_loader.py index 26e7a0e..817a32f 100644 --- a/controllers/shepherd_dog/policy_loader.py +++ b/controllers/shepherd_dog/policy_loader.py @@ -15,19 +15,35 @@ from pathlib import Path class PolicyHandle: - """Wrap a loaded policy (+ optional VecNormalize) for ``predict(obs)``.""" + """Wrap a loaded policy (+ optional VecNormalize) for ``predict(obs)``. - def __init__(self, model, vecnorm): + Supports both MLP (PPO) and recurrent (RecurrentPPO/LSTM) policies. + For LSTM policies, frame_stack is forced to 1 and the LSTM hidden + state is maintained across calls; ``reset_recurrent`` is exposed for + new episodes. + """ + + def __init__(self, model, vecnorm, recurrent: bool = False): self.model = model self.vecnorm = vecnorm + self.recurrent = recurrent from herding.perception.obs import OBS_DIM policy_dim = int(model.observation_space.shape[0]) - if policy_dim % OBS_DIM == 0 and policy_dim // OBS_DIM >= 1: + if recurrent: + self.frame_stack = 1 + elif policy_dim % OBS_DIM == 0 and policy_dim // OBS_DIM >= 1: self.frame_stack = policy_dim // OBS_DIM else: self.frame_stack = 1 self._buffer: list = [] self._single_dim = OBS_DIM + self._lstm_state = None + self._first_step = True + + def reset_recurrent(self): + self._lstm_state = None + self._first_step = True + self._buffer = [] def predict(self, obs): import numpy as np @@ -49,7 +65,15 @@ class PolicyHandle: obs_b = stacked.reshape(1, -1) if self.vecnorm is not None: obs_b = self.vecnorm.normalize_obs(obs_b) - action, _ = self.model.predict(obs_b, deterministic=True) + if self.recurrent: + episode_start = np.array([self._first_step], dtype=bool) + action, self._lstm_state = self.model.predict( + obs_b, state=self._lstm_state, + episode_start=episode_start, deterministic=True, + ) + self._first_step = False + else: + action, _ = self.model.predict(obs_b, deterministic=True) return action[0] @@ -79,7 +103,16 @@ def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle: from stable_baselines3 import PPO from stable_baselines3.common.vec_env import VecNormalize # noqa: F401 - model = PPO.load(str(zip_path), device="auto") + # Try RecurrentPPO (LSTM) first, fall back to PPO (MLP). + recurrent = False + model = None + try: + from sb3_contrib import RecurrentPPO + model = RecurrentPPO.load(str(zip_path), device="auto") + recurrent = True + except Exception: + model = PPO.load(str(zip_path), device="auto") + vecnorm = None if vecnorm_path and os.path.exists(vecnorm_path): import pickle @@ -87,4 +120,4 @@ def load(model_path: str, vecnorm_path: str | None = None) -> PolicyHandle: vecnorm = pickle.load(f) vecnorm.training = False vecnorm.norm_reward = False - return PolicyHandle(model=model, vecnorm=vecnorm) + return PolicyHandle(model=model, vecnorm=vecnorm, recurrent=recurrent) diff --git a/training/eval.py b/training/eval.py index 790e51a..bf312ad 100644 --- a/training/eval.py +++ b/training/eval.py @@ -59,16 +59,36 @@ def make_strombom_predictor(drive_mode: str = "differential"): return make_analytic_predictor(strombom_action, drive_mode) -def make_policy_predictor(model, vecnorm): +def make_policy_predictor(model, vecnorm, recurrent: bool = False): + state = {"lstm": None, "first": True} def _predict(_env, obs): obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1) if vecnorm is not None: obs_b = vecnorm.normalize_obs(obs_b) - action, _ = model.predict(obs_b, deterministic=True) + if recurrent: + episode_start = np.array([state["first"]], dtype=bool) + action, new_state = model.predict( + obs_b, state=state["lstm"], episode_start=episode_start, + deterministic=True, + ) + state["lstm"] = new_state + state["first"] = False + else: + action, _ = model.predict(obs_b, deterministic=True) return action[0] return _predict +def _reset_recurrent(predict_fn): + """Reset the recurrent state between episodes.""" + # The closure stores `state` dict; reach in via __closure__. + for cell in predict_fn.__closure__ or []: + if isinstance(cell.cell_contents, dict) and "lstm" in cell.cell_contents: + cell.cell_contents["lstm"] = None + cell.cell_contents["first"] = True + return + + def main(): parser = argparse.ArgumentParser() parser.add_argument("--policy", required=True, @@ -110,7 +130,17 @@ def main(): f"No checkpoint found in {run} " f"(tried policy.zip, final.zip)" ) - model = PPO.load(str(zip_path), device="auto") + # Try RecurrentPPO first (sb3-contrib) for LSTM policies, then + # fall back to PPO for MLP policies. + recurrent = False + model = None + try: + from sb3_contrib import RecurrentPPO + model = RecurrentPPO.load(str(zip_path), device="auto") + recurrent = True + print(f"[eval] loaded RecurrentPPO (LSTM) policy") + except Exception: + model = PPO.load(str(zip_path), device="auto") from herding.perception.obs import OBS_DIM as _SINGLE policy_obs_dim = int(model.observation_space.shape[0]) if policy_obs_dim % _SINGLE == 0 and policy_obs_dim // _SINGLE >= 1: @@ -127,7 +157,7 @@ def main(): vecnorm = pickle.load(f) vecnorm.training = False vecnorm.norm_reward = False - predict = make_policy_predictor(model, vecnorm) + predict = make_policy_predictor(model, vecnorm, recurrent=recurrent) # Infer drive_mode from policy action dim if using a learned policy. if args.policy not in ("strombom", "sequential"): @@ -149,6 +179,7 @@ def main(): env = HerdingEnv(n_sheep=n, max_steps=args.max_steps, difficulty=args.difficulty, seed=seed, frame_stack=frame_stack, drive_mode=drive_mode) + _reset_recurrent(predict) r = rollout(env, predict, args.max_steps) successes.append(int(r["success"])) steps.append(r["steps"]) diff --git a/training/rl/train_lstm.py b/training/rl/train_lstm.py new file mode 100644 index 0000000..84630fb --- /dev/null +++ b/training/rl/train_lstm.py @@ -0,0 +1,174 @@ +"""Recurrent-PPO (LSTM) policy trainer for the herding env. + +Motivation +---------- +The MLP+frame-stack policy struggles with partial observability under +the 140° Webots LiDAR: the tracker briefly empties when the dog turns, +and sporadic FP tracks at static features confuse the policy. An LSTM +gives the policy unbounded temporal memory so it can: + +* keep modelling sheep positions when the tracker briefly drops them, +* distinguish persistent (real) tracks from intermittent (phantom) ones. + +This is the literature-correct fix for partial-observability + noisy +perception. Trains from scratch (no BC init) using vanilla PPO without +the KL-to-reference term (no reference exists when starting clean). + +Usage +----- + python -m training.rl.train_lstm \\ + --out training/runs/lstm_differential_field \\ + --drive-mode differential --world field \\ + --total-timesteps 3000000 \\ + --use-webots-preset --fp-rate 0.0 --action-smooth 0.55 + +Frame stack is forced to 1 since the LSTM provides its own memory. +""" + +from __future__ import annotations + +import argparse +import os +import time +from pathlib import Path + +import numpy as np + +# Configure field geometry before other herding imports read it at module level. +from herding.world.geometry import configure_from_args as _configure_from_args +_configure_from_args() + +from sb3_contrib import RecurrentPPO +from stable_baselines3.common.callbacks import EvalCallback +from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv + +from herding.world.geometry import MAX_SHEEP +from training.herding_env import HerdingEnv + + +def _make_env(rank: int, seed: int, drive_mode: str, difficulty: float, + max_n_sheep: int, herding_cfg): + def _init(): + env = HerdingEnv( + max_n_sheep=max_n_sheep, difficulty=difficulty, + seed=seed + rank, frame_stack=1, drive_mode=drive_mode, + herding_cfg=herding_cfg, + ) + return env + return _init + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--out", required=True, + help="Output directory for the LSTM policy.") + parser.add_argument("--total-timesteps", type=int, default=3_000_000) + parser.add_argument("--n-envs", type=int, default=8) + parser.add_argument("--n-steps", type=int, default=256) + parser.add_argument("--lstm-hidden", type=int, default=128) + parser.add_argument("--lr", type=float, default=3e-4) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--max-n-sheep", type=int, default=MAX_SHEEP) + parser.add_argument("--difficulty", type=float, default=1.0) + parser.add_argument("--drive-mode", default="differential", + choices=["differential", "mecanum"]) + parser.add_argument("--world", default=None, + choices=["field", "field_round"]) + parser.add_argument("--fp-rate", type=float, default=0.0) + parser.add_argument("--action-smooth", type=float, default=0.55) + parser.add_argument("--wheel-slip-std", type=float, default=0.05) + parser.add_argument("--use-webots-preset", action="store_true", + help="Train in the HERDING_WEBOTS env (140° FOV + tight tracker).") + parser.add_argument("--device", default="cpu") + args = parser.parse_args() + + from herding.config import HerdingConfig, HERDING_WEBOTS, DomainRandomConfig, RobotConfig + + if args.use_webots_preset: + herding_cfg = HERDING_WEBOTS.replace( + domain_random=DomainRandomConfig( + fp_rate=args.fp_rate, + wheel_slip_std=args.wheel_slip_std, + ), + robot=RobotConfig(action_smooth=args.action_smooth), + ) + print(f"[lstm] HERDING_WEBOTS preset + DR: fp_rate={args.fp_rate}") + else: + herding_cfg = None + if args.fp_rate > 0.0 or args.action_smooth > 0.0 or args.wheel_slip_std > 0.0: + herding_cfg = HerdingConfig( + domain_random=DomainRandomConfig( + fp_rate=args.fp_rate, + wheel_slip_std=args.wheel_slip_std, + ), + robot=RobotConfig(action_smooth=args.action_smooth), + ) + + env_fns = [_make_env(i, args.seed, args.drive_mode, args.difficulty, + args.max_n_sheep, herding_cfg) + for i in range(args.n_envs)] + venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns) + eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, args.drive_mode, + args.difficulty, args.max_n_sheep, + herding_cfg)]) + + out = Path(args.out) + out.mkdir(parents=True, exist_ok=True) + + print(f"[lstm] drive_mode={args.drive_mode} world={os.environ.get('HERDING_WORLD', 'field')}") + print(f"[lstm] total_timesteps={args.total_timesteps} n_envs={args.n_envs} " + f"lr={args.lr} lstm_hidden={args.lstm_hidden}") + + model = RecurrentPPO( + "MlpLstmPolicy", venv, + learning_rate=args.lr, + n_steps=args.n_steps, + batch_size=args.n_steps, # full rollout = one batch (matches LSTM episode boundaries) + n_epochs=4, + gamma=0.99, + gae_lambda=0.95, + clip_range=0.2, + ent_coef=0.0, + max_grad_norm=0.5, + policy_kwargs=dict( + net_arch=dict(pi=[256, 256], vf=[256, 256]), + lstm_hidden_size=args.lstm_hidden, + n_lstm_layers=1, + shared_lstm=False, + enable_critic_lstm=True, + ), + device=args.device, + verbose=1, + seed=args.seed, + tensorboard_log=str(out / "tb"), + ) + + eval_cb = EvalCallback( + eval_venv, + best_model_save_path=str(out / "best"), + log_path=str(out / "evals"), + eval_freq=max(args.n_steps * args.n_envs, 20_000) // args.n_envs, + n_eval_episodes=5, + deterministic=True, + render=False, + ) + + t0 = time.time() + model.learn(total_timesteps=args.total_timesteps, callback=eval_cb, + progress_bar=True) + print(f"[lstm] training done in {time.time() - t0:.0f}s") + + # Save best (by eval) if it exists; otherwise save final. + best = out / "best" / "best_model.zip" + if best.exists(): + import shutil + shutil.copy(best, out / "policy.zip") + print(f"[lstm] best snapshot → {out / 'policy.zip'}") + else: + model.save(str(out / "policy.zip")) + print(f"[lstm] no eval beat init; final snapshot → {out / 'policy.zip'}") + model.save(str(out / "final.zip")) + + +if __name__ == "__main__": + main() diff --git a/training/runs/lstm_differential_field/policy.zip b/training/runs/lstm_differential_field/policy.zip new file mode 100644 index 0000000..bff65ff Binary files /dev/null and b/training/runs/lstm_differential_field/policy.zip differ