Checkpoint 7

This commit is contained in:
Johnny Fernandes
2026-05-11 12:21:51 +01:00
parent fce0e0c786
commit a01a5c9cef
34 changed files with 1266 additions and 1038 deletions
+26 -38
View File
@@ -1,27 +1,19 @@
"""Evaluate a trained PPO policy (or the Strömbom baseline) on the env.
"""Env-side evaluation of analytic or learned policies.
Reports success rate and time-to-pen across a fixed seed grid for each
flock size 1..MAX_SHEEP. Used to produce the M5 quantitative comparison
table mentioned in plan.md.
Reports success rate, mean steps and mean penned per flock size for
``n_sheep ∈ 1..max_flock`` across ``--n-seeds`` seeds each.
Usage::
python -m training.eval --policy training/runs/latest/best
python -m training.eval --policy training/runs/rl --n-seeds 10
python -m training.eval --policy strombom
"""
from __future__ import annotations
import argparse
import os
import sys
from pathlib import Path
from statistics import mean, stdev
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
from statistics import mean
import numpy as np
@@ -33,40 +25,38 @@ from training.herding_env import HerdingEnv
def rollout(env: HerdingEnv, predict_fn, max_steps: int) -> dict:
obs, _ = env.reset()
success = False
for t in range(max_steps):
action = predict_fn(env, obs)
obs, _r, terminated, truncated, info = env.step(action)
if terminated or truncated:
success = bool(info.get("is_success", False))
return {"success": success, "steps": info.get("steps", t + 1),
"n_penned": info.get("n_penned", 0)}
return {"success": False, "steps": max_steps, "n_penned": int(env.sheep_penned.sum())}
return {
"success": bool(info.get("is_success", False)),
"steps": info.get("steps", t + 1),
"n_penned": info.get("n_penned", 0),
}
return {"success": False, "steps": max_steps,
"n_penned": int(env.sheep_penned.sum())}
def make_analytic_predictor(action_fn):
"""Wrap an analytic teacher so it runs on the env's exposed
perception (tracker in LiDAR mode, GT in privileged mode)."""
def _predict(env, _obs):
# Use whatever perception the env exposes — tracker output in
# LiDAR mode, ground truth in privileged mode. This makes
# evaluation honest: the analytic teacher sees what the
# deployed controller would see.
positions = env.perceived_positions()
vx, vy, _mode = action_fn((env.dog_x, env.dog_y), positions, PEN_ENTRY)
return np.array([vx, vy], dtype=np.float32)
return _predict
# Backwards-compat alias.
def make_strombom_predictor():
return make_analytic_predictor(strombom_action)
def make_policy_predictor(model, vecnorm):
def _predict(_env, obs):
obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1)
if vecnorm is not None:
obs_b = vecnorm.normalize_obs(np.asarray(obs, dtype=np.float32).reshape(1, -1))
else:
obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1)
obs_b = vecnorm.normalize_obs(obs_b)
action, _ = model.predict(obs_b, deterministic=True)
return action[0]
return _predict
@@ -75,16 +65,17 @@ def make_policy_predictor(model, vecnorm):
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--policy", required=True,
help="Either 'strombom' or path to an SB3 run directory.")
help="'strombom', 'sequential', or path to a "
"policy directory / zip.")
parser.add_argument("--n-seeds", type=int, default=10)
parser.add_argument("--max-steps", type=int, default=5000)
parser.add_argument("--max-flock", type=int, default=MAX_SHEEP)
# 1.0 = deployment distribution (sheep anywhere in field).
# Lower values use the training-curriculum spawn band (sheep near gate).
parser.add_argument("--difficulty", type=float, default=1.0)
parser.add_argument("--difficulty", type=float, default=1.0,
help="0 = sheep spawn near the gate (easy); "
"1 = full field (deployment distribution).")
args = parser.parse_args()
frame_stack = 1 # default; analytic predictors don't use stacked obs
frame_stack = 1
if args.policy == "strombom":
predict = make_analytic_predictor(strombom_action)
elif args.policy == "sequential":
@@ -92,23 +83,20 @@ def main():
else:
from stable_baselines3 import PPO
run = Path(args.policy)
# Resolve to a zip: directory of checkpoints, or a direct zip path.
if run.is_file():
zip_path = run
else:
for name in ("best_model.zip", "policy.zip", "final.zip"):
for name in ("policy.zip", "final.zip"):
if (run / name).exists():
zip_path = run / name
break
else:
raise FileNotFoundError(
f"No checkpoint found in {run} (tried best_model.zip, "
f"policy.zip, final.zip)"
f"No checkpoint found in {run} "
f"(tried policy.zip, final.zip)"
)
model = PPO.load(str(zip_path), device="auto")
# Auto-detect frame stacking from the policy's expected obs dim,
# so eval runs with whatever stacking the policy was trained on.
from herding.obs import OBS_DIM as _SINGLE
from herding.perception.obs import OBS_DIM as _SINGLE
policy_obs_dim = int(model.observation_space.shape[0])
if policy_obs_dim % _SINGLE == 0 and policy_obs_dim // _SINGLE >= 1:
frame_stack = policy_obs_dim // _SINGLE