Checkpoint 2
This commit is contained in:
@@ -0,0 +1,136 @@
|
||||
"""Evaluate a trained PPO policy (or the Strömbom baseline) on the env.
|
||||
|
||||
Reports success rate and time-to-pen across a fixed seed grid for each
|
||||
flock size 1..MAX_SHEEP. Used to produce the M5 quantitative comparison
|
||||
table mentioned in plan.md.
|
||||
|
||||
Usage::
|
||||
|
||||
python -m training.eval --policy training/runs/latest/best
|
||||
python -m training.eval --policy strombom
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from statistics import mean, stdev
|
||||
|
||||
_HERE = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
import numpy as np
|
||||
|
||||
from herding.geometry import MAX_SHEEP, PEN_ENTRY
|
||||
from herding.strombom import compute_action as strombom_action
|
||||
from herding.sequential import compute_action as sequential_action
|
||||
from training.herding_env import HerdingEnv
|
||||
|
||||
|
||||
def rollout(env: HerdingEnv, predict_fn, max_steps: int) -> dict:
|
||||
obs, _ = env.reset()
|
||||
success = False
|
||||
for t in range(max_steps):
|
||||
action = predict_fn(env, obs)
|
||||
obs, _r, terminated, truncated, info = env.step(action)
|
||||
if terminated or truncated:
|
||||
success = bool(info.get("is_success", False))
|
||||
return {"success": success, "steps": info.get("steps", t + 1),
|
||||
"n_penned": info.get("n_penned", 0)}
|
||||
return {"success": False, "steps": max_steps, "n_penned": int(env.sheep_penned.sum())}
|
||||
|
||||
|
||||
def make_analytic_predictor(action_fn):
|
||||
def _predict(env, _obs):
|
||||
positions = {f"s{i}": (float(env.sheep_x[i]), float(env.sheep_y[i]))
|
||||
for i in range(env.n_sheep)
|
||||
if not env.sheep_penned[i]}
|
||||
vx, vy, _mode = action_fn((env.dog_x, env.dog_y), positions, PEN_ENTRY)
|
||||
return np.array([vx, vy], dtype=np.float32)
|
||||
return _predict
|
||||
|
||||
|
||||
# Backwards-compat alias.
|
||||
def make_strombom_predictor():
|
||||
return make_analytic_predictor(strombom_action)
|
||||
|
||||
|
||||
def make_policy_predictor(model, vecnorm):
|
||||
def _predict(_env, obs):
|
||||
if vecnorm is not None:
|
||||
obs_b = vecnorm.normalize_obs(np.asarray(obs, dtype=np.float32).reshape(1, -1))
|
||||
else:
|
||||
obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1)
|
||||
action, _ = model.predict(obs_b, deterministic=True)
|
||||
return action[0]
|
||||
return _predict
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--policy", required=True,
|
||||
help="Either 'strombom' or path to an SB3 run directory.")
|
||||
parser.add_argument("--n-seeds", type=int, default=10)
|
||||
parser.add_argument("--max-steps", type=int, default=5000)
|
||||
parser.add_argument("--max-flock", type=int, default=MAX_SHEEP)
|
||||
# 1.0 = deployment distribution (sheep anywhere in field).
|
||||
# Lower values use the training-curriculum spawn band (sheep near gate).
|
||||
parser.add_argument("--difficulty", type=float, default=1.0)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.policy == "strombom":
|
||||
predict = make_analytic_predictor(strombom_action)
|
||||
elif args.policy == "sequential":
|
||||
predict = make_analytic_predictor(sequential_action)
|
||||
else:
|
||||
from stable_baselines3 import PPO
|
||||
run = Path(args.policy)
|
||||
# Resolve to a zip: directory of checkpoints, or a direct zip path.
|
||||
if run.is_file():
|
||||
zip_path = run
|
||||
else:
|
||||
for name in ("best_model.zip", "policy.zip", "final.zip"):
|
||||
if (run / name).exists():
|
||||
zip_path = run / name
|
||||
break
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"No checkpoint found in {run} (tried best_model.zip, "
|
||||
f"policy.zip, final.zip)"
|
||||
)
|
||||
model = PPO.load(str(zip_path), device="auto")
|
||||
vecnorm = None
|
||||
vn_path = run / "vecnormalize.pkl"
|
||||
if not vn_path.exists() and run.parent.name != "best":
|
||||
vn_path = run.parent / "vecnormalize.pkl"
|
||||
if vn_path.exists():
|
||||
import pickle
|
||||
with open(vn_path, "rb") as f:
|
||||
vecnorm = pickle.load(f)
|
||||
vecnorm.training = False
|
||||
vecnorm.norm_reward = False
|
||||
predict = make_policy_predictor(model, vecnorm)
|
||||
|
||||
print(f"{'n_sheep':>8} {'success%':>10} {'mean_steps':>12} {'mean_penned':>12}")
|
||||
print("-" * 46)
|
||||
for n in range(1, args.max_flock + 1):
|
||||
successes, steps, penned = [], [], []
|
||||
for seed in range(args.n_seeds):
|
||||
env = HerdingEnv(n_sheep=n, max_steps=args.max_steps,
|
||||
difficulty=args.difficulty, seed=seed)
|
||||
r = rollout(env, predict, args.max_steps)
|
||||
successes.append(int(r["success"]))
|
||||
steps.append(r["steps"])
|
||||
penned.append(r["n_penned"])
|
||||
sr = 100.0 * mean(successes)
|
||||
ms = mean(steps)
|
||||
mp = mean(penned)
|
||||
print(f"{n:>8d} {sr:>9.1f}% {ms:>12.0f} {mp:>12.2f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user