Files
TIR_PROJ/training/train_ppo.py
T
Johnny Fernandes 2a6db038df Checkpoint 3
2026-05-10 12:46:14 +01:00

281 lines
11 KiB
Python

"""PPO trainer for the shepherd-dog policy — EXPERIMENTAL.
The deliverable pipeline is `bc_pretrain.py` (see ``training/README.md``).
This script is kept in the tree because it implements:
* PPO from scratch with curriculum over flock size + spawn area, and
* PPO fine-tune of a behavior-cloned policy.
Both ran into stability issues in our setting (long-horizon credit
assignment for sparse pen reward, BC-degradation under PPO exploration
noise). The abstractions are reusable for follow-up work — e.g.
KL-regularised fine-tune with a frozen reference policy — so we leave
the code in place.
Usage (PPO from scratch)::
python -m training.train_ppo \
--config training/configs/ppo_default.yaml \
--out-dir training/runs/ppo_scratch
Usage (PPO fine-tune of BC)::
python -m training.train_ppo \
--resume training/runs/bc_flock/policy.zip \
--out-dir training/runs/bc_ppo \
--no-vecnorm --no-curriculum --imitate-weight 0 \
--difficulty 1.0 --log-std -1.5 --learning-rate 5e-5 \
--total-timesteps 3000000
"""
from __future__ import annotations
import argparse
import os
import sys
from pathlib import Path
import yaml
# Directory containing this script, and the repository root one level above.
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))

# Prepend the root (once) so `training.*` imports resolve even when this file
# is executed directly rather than via `python -m training.train_ppo`.
if _PROJECT_ROOT not in sys.path:
    sys.path[:0] = [_PROJECT_ROOT]
import numpy as np
import torch as th
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import (
BaseCallback, CheckpointCallback, EvalCallback,
)
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import (
DummyVecEnv, SubprocVecEnv, VecNormalize,
)
from training.herding_env import HerdingEnv
# --------------------------------------------------------------------------
# Env factories
# --------------------------------------------------------------------------
def _make_env(rank: int, seed: int = 0):
    """Return a zero-argument factory that builds one monitored HerdingEnv.

    SB3's ``DummyVecEnv`` / ``SubprocVecEnv`` take callables rather than env
    instances, hence the closure. Each worker is seeded with ``seed + rank``
    so parallel envs get distinct but reproducible seeds.

    Args:
        rank: Index of the worker within the vectorised env.
        seed: Base seed shared by all workers.

    Returns:
        A callable that, when invoked, constructs the wrapped env.
    """
    env_seed = seed + rank
    monitored_keys = ("is_success", "n_sheep", "n_penned")

    def _thunk():
        # Monitor records per-episode stats plus these info-dict entries.
        return Monitor(HerdingEnv(seed=env_seed), info_keywords=monitored_keys)

    return _thunk
# --------------------------------------------------------------------------
# Curriculum callback
# --------------------------------------------------------------------------
class CurriculumCallback(BaseCallback):
    """Drive the env's flock-size + state-space difficulty curriculum.

    Schedule entries are dicts ``{step, max_n_sheep, difficulty}``, where
    ``difficulty`` is optional and defaults to 1.0. On every training step the
    last entry whose ``step <= num_timesteps`` wins, and both knobs are taken
    from that entry together. Before the first threshold is reached, the
    earliest entry is used. A knob is pushed to the envs only when its value
    changes.

    Args:
        schedule: Iterable of schedule-entry dicts. Must be non-empty.
        vec_envs: A VecEnv (or VecEnv wrapper), or a list/tuple of them.
            Passing the eval env as well keeps evaluation at training
            difficulty.
        verbose: If non-zero, print a line whenever either knob changes.

    Raises:
        ValueError: If ``schedule`` is empty.
    """

    def __init__(self, schedule, vec_envs, verbose: int = 0):
        super().__init__(verbose)
        self.schedule = sorted(schedule, key=lambda d: d["step"])
        if not self.schedule:
            # Fail at construction instead of with an IndexError mid-training.
            raise ValueError("CurriculumCallback: schedule must contain at least one entry")
        # Accept a list of envs so the eval env tracks training difficulty.
        self.vec_envs = vec_envs if isinstance(vec_envs, (list, tuple)) else [vec_envs]
        self._last_n = None
        self._last_d = None

    def _call(self, method, value):
        """Invoke ``method(value)`` on every sub-env of every managed VecEnv."""
        for v in self.vec_envs:
            try:
                v.env_method(method, value)
            except AttributeError:
                # Fall back to the wrapped VecEnv only if there is one;
                # otherwise surface the original error instead of a
                # misleading "no attribute 'venv'".
                inner = getattr(v, "venv", None)
                if inner is None:
                    raise
                inner.env_method(method, value)

    def _current_knobs(self):
        """Return ``(max_n_sheep, difficulty)`` for the current timestep."""
        t = self.num_timesteps
        active = self.schedule[0]
        for entry in self.schedule:
            # Schedule is sorted by step, so nothing later can be active.
            if entry["step"] > t:
                break
            active = entry
        return active["max_n_sheep"], active.get("difficulty", 1.0)

    def _on_step(self) -> bool:
        n, d = self._current_knobs()
        changed = False
        if n != self._last_n:
            self._call("set_max_n_sheep", n)
            self._last_n = n
            changed = True
        if d != self._last_d:
            self._call("set_difficulty", d)
            self._last_d = d
            changed = True
        # Log any transition, including ones that only change flock size.
        if changed and self.verbose:
            t = self.num_timesteps
            print(f"[curriculum] t={t} → max_n_sheep={n} difficulty={d}")
        return True
# --------------------------------------------------------------------------
# Main
# --------------------------------------------------------------------------
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the CLI parser for this script.

    ``--learning-rate`` and ``--log-std`` apply both when resuming
    (``--resume``) and when training from scratch.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default=os.path.join(_HERE, "configs", "ppo_default.yaml"))
    parser.add_argument("--out-dir", default=os.path.join(_HERE, "runs", "latest"))
    parser.add_argument("--n-envs", type=int, default=None,
                        help="Override config n_envs.")
    parser.add_argument("--total-timesteps", type=int, default=None,
                        help="Override config total_timesteps.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--resume", type=str, default=None,
                        help="Path to a SB3 zip to resume from.")
    # SB3 recommends CPU for MlpPolicy — GPU helps CNN policies, not MLPs
    # of this size. Override with --device cuda if you really want it.
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--no-vecnorm", action="store_true",
                        help="Disable VecNormalize wrapper. Required when "
                             "resuming from a BC-pretrained policy that "
                             "wasn't trained under it.")
    parser.add_argument("--no-curriculum", action="store_true",
                        help="Skip curriculum callback (resumed policy is "
                             "already competent across the distribution).")
    parser.add_argument("--imitate-weight", type=float, default=None,
                        help="Override env W_IMITATE. Set to 0 to disable "
                             "Strömbom imitation reward.")
    parser.add_argument("--difficulty", type=float, default=None,
                        help="Override env difficulty (0=easy, 1=hard). "
                             "Used in BC fine-tune to skip easy curriculum. "
                             "The config curriculum overrides this unless "
                             "--no-curriculum is also given.")
    parser.add_argument("--log-std", type=float, default=None,
                        help="Override the policy's log_std after load "
                             "(or log_std_init when training from scratch). "
                             "BC trained with std≈1.6 (log_std=0.5) which "
                             "is too noisy for fine-tune. Use -1.5 (std≈0.22) "
                             "to keep PPO close to the BC mean while still "
                             "exploring locally.")
    parser.add_argument("--learning-rate", type=float, default=None,
                        help="Override config learning rate. For BC "
                             "fine-tune, 5e-5 is much safer than the 3e-4 "
                             "default.")
    return parser


def _call_on_envs(vec_envs, method, value):
    """Invoke ``method(value)`` on every sub-env of each VecEnv in *vec_envs*."""
    for v in vec_envs:
        try:
            v.env_method(method, value)
        except AttributeError:
            # Only fall back when there is a wrapped VecEnv to fall back to;
            # otherwise re-raise the original, informative error.
            inner = getattr(v, "venv", None)
            if inner is None:
                raise
            inner.env_method(method, value)


def _build_envs(args, n_envs):
    """Create the training VecEnv and a single-env evaluation VecEnv.

    Unless ``--no-vecnorm`` is given, both are wrapped in VecNormalize
    (observations only, clipped to ±10). The eval wrapper is frozen
    (``training=False``) and starts out sharing the training obs statistics.

    Returns:
        ``(venv, eval_venv)``.
    """
    env_fns = [_make_env(i, seed=args.seed) for i in range(n_envs)]
    venv = SubprocVecEnv(env_fns) if n_envs > 1 else DummyVecEnv(env_fns)
    eval_venv = DummyVecEnv([_make_env(99, seed=args.seed + 999)])
    if args.no_vecnorm:
        print("[train] VecNormalize disabled (resumed policy was trained without it).")
    else:
        venv = VecNormalize(venv, norm_obs=True, norm_reward=False, clip_obs=10.0)
        eval_venv = VecNormalize(eval_venv, norm_obs=True, norm_reward=False,
                                 clip_obs=10.0, training=False)
        eval_venv.obs_rms = venv.obs_rms
    return venv, eval_venv


def _build_model(args, cfg, venv, out):
    """Load a PPO model from ``--resume``, or construct a fresh one from *cfg*.

    ``--learning-rate`` and ``--log-std`` override the config in both paths.
    """
    if args.resume:
        print(f"[train] resuming from {args.resume}")
        custom_objects = {}
        if args.learning_rate is not None:
            # custom_objects replaces the value stored in the saved zip.
            custom_objects["learning_rate"] = args.learning_rate
        model = PPO.load(args.resume, env=venv, device=args.device,
                         tensorboard_log=str(out / "tb"),
                         custom_objects=custom_objects or None)
        if args.log_std is not None:
            with th.no_grad():
                model.policy.log_std.fill_(args.log_std)
            print(f"[train] log_std overridden to {args.log_std} "
                  f"(std≈{float(np.exp(args.log_std)):.2f})")
        if args.learning_rate is not None:
            print(f"[train] learning_rate overridden to {args.learning_rate}")
        return model

    log_std_init = cfg.get("log_std_init", 0.0) if args.log_std is None else args.log_std
    learning_rate = cfg["learning_rate"] if args.learning_rate is None else args.learning_rate
    if args.log_std is not None:
        print(f"[train] log_std_init overridden to {args.log_std} "
              f"(std≈{float(np.exp(args.log_std)):.2f})")
    if args.learning_rate is not None:
        print(f"[train] learning_rate overridden to {args.learning_rate}")
    policy_kwargs = dict(
        net_arch=dict(pi=cfg["net_arch_pi"], vf=cfg["net_arch_vf"]),
        log_std_init=log_std_init,
    )
    return PPO(
        cfg["policy"], venv,
        learning_rate=learning_rate,
        n_steps=cfg["n_steps"],
        batch_size=cfg["batch_size"],
        n_epochs=cfg["n_epochs"],
        gamma=cfg["gamma"],
        gae_lambda=cfg["gae_lambda"],
        clip_range=cfg["clip_range"],
        ent_coef=cfg["ent_coef"],
        vf_coef=cfg["vf_coef"],
        max_grad_norm=cfg["max_grad_norm"],
        target_kl=cfg.get("target_kl"),
        policy_kwargs=policy_kwargs,
        tensorboard_log=str(out / "tb"),
        seed=args.seed,
        device=args.device,
        verbose=1,
    )


def _build_callbacks(args, cfg, n_envs, venv, eval_venv, out):
    """Return checkpoint + eval callbacks, plus the curriculum when enabled."""
    # Frequencies in the config are total env steps; SB3 counts callback
    # calls, and each call advances all n_envs envs.
    ckpt_cb = CheckpointCallback(
        save_freq=max(1, cfg["checkpoint_freq"] // n_envs),
        save_path=str(out / "checkpoints"), name_prefix="ppo",
        save_vecnormalize=True,
    )
    eval_cb = EvalCallback(
        eval_venv,
        best_model_save_path=str(out / "best"),
        log_path=str(out / "evals"),
        eval_freq=max(1, cfg["eval_freq"] // n_envs),
        n_eval_episodes=cfg["n_eval_episodes"],
        deterministic=True,
    )
    callbacks = [ckpt_cb, eval_cb]
    if args.no_curriculum:
        print("[train] curriculum disabled — env knobs left at their current values.")
    elif cfg.get("curriculum"):
        if args.difficulty is not None:
            # The curriculum pushes its own difficulty on its first step.
            print("[train] warning: config curriculum will override --difficulty; "
                  "pass --no-curriculum to keep it pinned.")
        callbacks.append(CurriculumCallback(
            cfg["curriculum"], [venv, eval_venv], verbose=1,
        ))
    return callbacks


def main():
    """Parse CLI args, build envs/model/callbacks, train PPO and save results.

    Writes ``final.zip``, ``checkpoints/``, ``best/`` (via EvalCallback),
    ``evals/`` and ``tb/`` under ``--out-dir``. When VecNormalize is active,
    it also writes ``vecnormalize.pkl`` to the run root and to ``best/``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    # `is not None` instead of `or`: an explicit 0 must not silently fall
    # back to the config value.
    n_envs = args.n_envs if args.n_envs is not None else cfg["n_envs"]
    total_timesteps = (args.total_timesteps if args.total_timesteps is not None
                       else cfg["total_timesteps"])
    if n_envs < 1:
        parser.error(f"n_envs must be >= 1, got {n_envs}")
    if total_timesteps < 1:
        parser.error(f"total_timesteps must be >= 1, got {total_timesteps}")

    out = Path(args.out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for sub in ("checkpoints", "best", "evals"):
        (out / sub).mkdir(exist_ok=True)
    print(f"[train] out={out} n_envs={n_envs} total={total_timesteps} device={args.device}")

    venv, eval_venv = _build_envs(args, n_envs)
    try:
        # Env-level overrides (used by BC fine-tune to disable Strömbom
        # imitation and start at full deployment difficulty).
        if args.imitate_weight is not None:
            _call_on_envs((venv, eval_venv), "set_imitate_weight", args.imitate_weight)
            print(f"[train] W_IMITATE overridden to {args.imitate_weight}")
        if args.difficulty is not None:
            _call_on_envs((venv, eval_venv), "set_difficulty", args.difficulty)
            print(f"[train] difficulty pinned to {args.difficulty}")

        model = _build_model(args, cfg, venv, out)
        callbacks = _build_callbacks(args, cfg, n_envs, venv, eval_venv, out)

        model.learn(total_timesteps=total_timesteps, callback=callbacks,
                    progress_bar=True)

        model.save(out / "final.zip")
        # A plain DummyVecEnv/SubprocVecEnv (--no-vecnorm) has no `save`;
        # calling it unconditionally crashed the BC fine-tune path.
        if isinstance(venv, VecNormalize):
            venv.save(str(out / "vecnormalize.pkl"))
            # EvalCallback already wrote best_model.zip into out/best/; drop
            # the stats next to it for the controller to pick up.
            # NOTE(review): these are end-of-training stats, not necessarily
            # the stats in effect when best_model.zip was saved.
            venv.save(str(out / "best" / "vecnormalize.pkl"))
    finally:
        # Shut down SubprocVecEnv workers even if training fails.
        venv.close()
        eval_venv.close()
    print(f"[train] done. saved to {out}")


if __name__ == "__main__":
    main()