Files
TIR_PROJ/training/train.py
T
2026-04-22 23:34:58 +01:00

212 lines
7.3 KiB
Python

"""
PPO training script for the herding task.
Usage examples
--------------
# Start fresh with curriculum (1 → 5 sheep):
python train.py --curriculum
# Resume from checkpoint, skip directly to 3 sheep:
python train.py --resume runs/ppo_herding/ckpt_200000_steps.zip --n-sheep 3
# Quick smoke-test (no curriculum, single env):
python train.py --n-envs 1 --total-steps 50000
"""
import argparse
import os
import re
from collections import deque

import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import (
    BaseCallback,
    CallbackList,
    CheckpointCallback,
    EvalCallback,
)
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import SubprocVecEnv, VecNormalize

from herding_env import HerdingEnv
# ---------------------------------------------------------------------------
# Curriculum callback
# ---------------------------------------------------------------------------
class CurriculumCallback(BaseCallback):
    """
    Advance the curriculum (number of active sheep) when the rolling
    episode success rate reaches a threshold.

    An episode counts as a success when it ends without
    ``info["TimeLimit.truncated"]`` set, i.e. it terminated (all sheep
    penned, per this script's convention) rather than hitting the step limit.

    On graduation the callback calls ``set_n_sheep(n)`` on every training
    sub-environment through ``env_method`` and clears the success history,
    so the next level is judged only on its own episodes.

    Args:
        start_sheep: Sheep count the training envs were created with.
        max_sheep: Upper bound; the curriculum never advances past it.
        verbose: When non-zero, print a message on each advancement.
    """

    THRESHOLD = 0.75   # success rate required to graduate
    WINDOW = 100       # number of most recent episodes averaged over
    MIN_EPISODES = 50  # don't graduate before seeing this many episodes

    def __init__(self, start_sheep: int, max_sheep: int, verbose: int = 1):
        super().__init__(verbose)
        self.max_sheep = max_sheep
        # A bounded deque drops the oldest outcome in O(1) (list.pop(0) is O(n)).
        self._successes = deque(maxlen=self.WINDOW)
        self._cur_sheep = start_sheep

    def _on_step(self) -> bool:
        """Record finished episodes and advance the level when warranted."""
        episode_finished = False
        for info, done in zip(self.locals["infos"], self.locals["dones"]):
            if done:
                truncated = info.get("TimeLimit.truncated", False)
                self._successes.append(0 if truncated else 1)
                episode_finished = True

        # The window only changes when an episode ends, so skip the
        # (otherwise identical) check on all other steps.
        if (episode_finished
                and self._cur_sheep < self.max_sheep
                and len(self._successes) >= self.MIN_EPISODES
                and np.mean(self._successes) >= self.THRESHOLD):
            self._cur_sheep += 1
            # NOTE(review): presumably takes effect on each env's next
            # reset; episodes already in flight finish at the old count and
            # are scored against the new level -- confirm in HerdingEnv.
            self.training_env.env_method("set_n_sheep", self._cur_sheep)
            self._successes.clear()
            if self.verbose:
                print(f"\n[Curriculum] Advanced to {self._cur_sheep} sheep "
                      f"at step {self.num_timesteps}\n")
        return True

    def _on_rollout_end(self) -> None:
        """Expose the current level and success rate in the SB3 logs."""
        self.logger.record("curriculum/n_sheep", self._cur_sheep)
        if self._successes:
            self.logger.record("curriculum/success_rate",
                               float(np.mean(self._successes)))
# ---------------------------------------------------------------------------
# Environment factory
# ---------------------------------------------------------------------------
def make_env(n_sheep: int, seed: int, max_steps: int):
    """
    Return a zero-argument factory that builds one seeded HerdingEnv.

    The vectorised env calls the returned ``_init`` itself, so the
    environment is constructed there rather than in this function.

    The env is wrapped in SB3's ``Monitor`` so finished episodes carry an
    ``info["episode"]`` record. PPO needs it to log
    ``rollout/ep_rew_mean`` / ``ep_len_mean``, and ``EvalCallback`` needs it
    to report true (unnormalised) episode returns.

    Args:
        n_sheep: Initial number of sheep in the environment.
        seed: Seed passed to the first ``reset``.
        max_steps: Episode step limit passed to HerdingEnv.
    """
    def _init():
        env = Monitor(HerdingEnv(n_sheep=n_sheep, max_steps=max_steps))
        env.reset(seed=seed)
        return env
    return _init
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def parse_args():
    """
    Parse and validate the command-line arguments of the training script.

    Step-based frequencies (``--save-freq``, ``--eval-freq``) are totals
    summed over all parallel envs; ``main`` divides them by ``--n-envs``
    before handing them to the SB3 callbacks.

    Returns:
        argparse.Namespace with one attribute per option below.

    Raises:
        SystemExit: via ``parser.error`` when an argument is out of range.
    """
    p = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--n-sheep", type=int, default=1,
                   help="Starting number of sheep (or fixed count if no curriculum)")
    p.add_argument("--max-sheep", type=int, default=5,
                   help="Maximum sheep for curriculum (ignored without --curriculum)")
    p.add_argument("--n-envs", type=int, default=8,
                   help="Number of parallel environments")
    p.add_argument("--total-steps", type=int, default=5_000_000,
                   help="Environment steps to train for (added on top of the "
                        "checkpoint's step count when resuming)")
    p.add_argument("--max-steps", type=int, default=2000,
                   help="Episode step limit inside each env")
    p.add_argument("--curriculum", action="store_true",
                   help="Enable automatic curriculum advancement")
    p.add_argument("--resume", type=str, default=None,
                   help="Path to a .zip checkpoint to resume training from")
    p.add_argument("--run-dir", type=str, default="runs/ppo_herding",
                   help="Output directory for checkpoints and logs")
    p.add_argument("--save-freq", type=int, default=100_000,
                   help="Checkpoint every N env steps (total across all envs)")
    p.add_argument("--eval-freq", type=int, default=50_000,
                   help="Evaluate every N env steps (total across all envs)")
    p.add_argument("--eval-eps", type=int, default=20,
                   help="Episodes per evaluation run")
    args = p.parse_args()

    # main() divides by n_envs and builds one worker per env.
    if args.n_envs < 1:
        p.error("--n-envs must be at least 1")
    if args.n_sheep < 1:
        p.error("--n-sheep must be at least 1")
    # Otherwise the curriculum could never advance and would fail silently.
    if args.curriculum and args.max_sheep < args.n_sheep:
        p.error("--max-sheep must be >= --n-sheep when --curriculum is set")
    return args
def _find_vecnorm_stats(resume_path, fallback_path):
    """
    Locate the VecNormalize statistics that belong to *resume_path*.

    CheckpointCallback(name_prefix=P, save_vecnormalize=True) stores the
    model as ``P_<steps>_steps.zip`` and its statistics as
    ``P_vecnormalize_<steps>_steps.pkl`` in the same directory. That sibling
    file is preferred; otherwise *fallback_path* (the stats written at the
    end of a completed run) is used.

    Returns:
        The first existing candidate path, or None when neither exists.
    """
    folder, name = os.path.split(resume_path)
    # The ".zip" suffix is optional because PPO.load accepts paths without it.
    match = re.fullmatch(r"(?P<prefix>.+)_(?P<steps>\d+)_steps(?:\.zip)?", name)
    candidates = []
    if match is not None:
        candidates.append(os.path.join(
            folder,
            f"{match['prefix']}_vecnormalize_{match['steps']}_steps.pkl",
        ))
    candidates.append(fallback_path)
    return next((path for path in candidates if os.path.exists(path)), None)


def main():
    """
    Train PPO on HerdingEnv according to the command-line arguments.

    Writes the following under ``--run-dir``:
      * periodic checkpoints and their VecNormalize stats (``checkpoints/``);
      * the best evaluated model (``best_model/``);
      * TensorBoard logs;
      * after a normal finish, ``final_model.zip`` and ``vecnorm.pkl``.

    The worker processes of both vectorised envs are closed even if
    training raises.
    """
    args = parse_args()
    os.makedirs(args.run_dir, exist_ok=True)
    ckpt_dir = os.path.join(args.run_dir, "checkpoints")
    best_dir = os.path.join(args.run_dir, "best_model")
    norm_path = os.path.join(args.run_dir, "vecnorm.pkl")
    os.makedirs(ckpt_dir, exist_ok=True)

    # Training envs: one worker process per env, seeds 0..n_envs-1.
    train_env = SubprocVecEnv([
        make_env(args.n_sheep, seed=i, max_steps=args.max_steps)
        for i in range(args.n_envs)
    ])
    stats_path = (_find_vecnorm_stats(args.resume, norm_path)
                  if args.resume else None)
    if stats_path is not None:
        print(f"Restoring VecNormalize statistics from {stats_path}")
        train_env = VecNormalize.load(stats_path, train_env)
        # Keep updating the running statistics while training continues.
        train_env.training = True
        train_env.norm_reward = True
    else:
        if args.resume:
            # Without matching statistics the restored policy sees differently
            # scaled observations until the fresh running stats settle.
            print(f"Warning: no VecNormalize statistics found for "
                  f"{args.resume}; normalisation restarts from scratch.")
        train_env = VecNormalize(train_env, norm_obs=True, norm_reward=True,
                                 clip_obs=10.0)

    # Eval env: no reward normalisation and frozen stats (training=False);
    # EvalCallback syncs the training env's normalisation into it.
    # NOTE(review): CurriculumCallback only calls set_n_sheep on the training
    # env, so with --curriculum evaluation (and best_model selection) keeps
    # using --n-sheep sheep.
    eval_env = SubprocVecEnv([
        make_env(args.n_sheep, seed=1000 + i, max_steps=args.max_steps)
        for i in range(2)
    ])
    eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False,
                            clip_obs=10.0, training=False)

    try:
        # Callback frequencies count callback calls; each call advances
        # n_envs timesteps, so divide to express them in total env steps.
        checkpoint_cb = CheckpointCallback(
            save_freq=max(args.save_freq // args.n_envs, 1),
            save_path=ckpt_dir,
            name_prefix="ckpt",
            save_vecnormalize=True,
        )
        eval_cb = EvalCallback(
            eval_env,
            best_model_save_path=best_dir,
            log_path=args.run_dir,
            eval_freq=max(args.eval_freq // args.n_envs, 1),
            n_eval_episodes=args.eval_eps,
            deterministic=True,
            verbose=1,
        )
        callbacks = [checkpoint_cb, eval_cb]
        if args.curriculum:
            callbacks.append(CurriculumCallback(start_sheep=args.n_sheep,
                                                max_sheep=args.max_sheep))
        callback_list = CallbackList(callbacks)

        ppo_kwargs = dict(
            policy="MlpPolicy",
            env=train_env,
            learning_rate=3e-4,
            n_steps=2048,
            batch_size=256,
            n_epochs=10,
            gamma=0.995,
            gae_lambda=0.95,
            clip_range=0.2,
            ent_coef=0.005,
            vf_coef=0.5,
            max_grad_norm=0.5,
            policy_kwargs=dict(net_arch=[256, 256]),
            tensorboard_log=args.run_dir,
            verbose=1,
        )
        if args.resume:
            print(f"Resuming from {args.resume}")
            # The policy class comes from the checkpoint and env is passed
            # explicitly; re-apply every other hyper-parameter.
            model = PPO.load(args.resume, env=train_env, **{
                k: v for k, v in ppo_kwargs.items()
                if k not in ("policy", "env")
            })
        else:
            model = PPO(**ppo_kwargs)

        # When resuming, the step counter is kept, so --total-steps is
        # trained on top of the checkpoint's existing count.
        model.learn(
            total_timesteps=args.total_steps,
            callback=callback_list,
            reset_num_timesteps=args.resume is None,
            tb_log_name="ppo",
        )

        # Save final artefacts
        model.save(os.path.join(args.run_dir, "final_model"))
        train_env.save(norm_path)
        print(f"\nTraining complete. Artefacts saved to {args.run_dir}/")
    finally:
        # Shut down the SubprocVecEnv worker processes.
        eval_env.close()
        train_env.close()


if __name__ == "__main__":
    main()