Checkpoint 8

This commit is contained in:
Johnny Fernandes
2026-05-12 22:41:03 +01:00
parent a01a5c9cef
commit 5c2ee4bba5
31 changed files with 3189 additions and 380 deletions
+45 -6
View File
@@ -12,11 +12,28 @@ Usage::
from __future__ import annotations
import argparse
import os
from pathlib import Path
from statistics import mean
import numpy as np
# Early CLI pre-parse for --world so geometry is configured before
# other herding.* modules are imported.
_pre_argv = [a for a in os.sys.argv[1:]]
_pre_world = None
for i, a in enumerate(_pre_argv):
if a == "--world" and i + 1 < len(_pre_argv):
_pre_world = _pre_argv[i + 1]
break
if a.startswith("--world="):
_pre_world = a.split("=", 1)[1]
break
if _pre_world is not None:
from herding.world.geometry import configure as _geo_configure
_geo_configure(_pre_world)
os.environ["HERDING_WORLD"] = _pre_world
from herding.world.geometry import MAX_SHEEP, PEN_ENTRY
from herding.control.sequential import compute_action as sequential_action
from herding.control.strombom import compute_action as strombom_action
@@ -38,18 +55,20 @@ def rollout(env: HerdingEnv, predict_fn, max_steps: int) -> dict:
"n_penned": int(env.sheep_penned.sum())}
def make_analytic_predictor(action_fn):
def make_analytic_predictor(action_fn, drive_mode: str = "differential"):
"""Wrap an analytic teacher so it runs on the env's exposed
perception (tracker in LiDAR mode, GT in privileged mode)."""
def _predict(env, _obs):
positions = env.perceived_positions()
vx, vy, _mode = action_fn((env.dog_x, env.dog_y), positions, PEN_ENTRY)
if drive_mode == "mecanum":
return np.array([vx, vy, 0.0], dtype=np.float32)
return np.array([vx, vy], dtype=np.float32)
return _predict
def make_strombom_predictor():
return make_analytic_predictor(strombom_action)
def make_strombom_predictor(drive_mode: str = "differential"):
return make_analytic_predictor(strombom_action, drive_mode)
def make_policy_predictor(model, vecnorm):
@@ -73,13 +92,21 @@ def main():
parser.add_argument("--difficulty", type=float, default=1.0,
help="0 = sheep spawn near the gate (easy); "
"1 = full field (deployment distribution).")
parser.add_argument("--drive-mode", default="differential",
choices=["differential", "mecanum"],
help="Drive mode for the dog robot.")
parser.add_argument("--world", default=None,
choices=["field", "field_round"],
help="World shape. If not set, uses HERDING_WORLD "
"env var or defaults to 'field'.")
args = parser.parse_args()
drive_mode = args.drive_mode
frame_stack = 1
if args.policy == "strombom":
predict = make_analytic_predictor(strombom_action)
predict = make_analytic_predictor(strombom_action, drive_mode)
elif args.policy == "sequential":
predict = make_analytic_predictor(sequential_action)
predict = make_analytic_predictor(sequential_action, drive_mode)
else:
from stable_baselines3 import PPO
run = Path(args.policy)
@@ -114,6 +141,18 @@ def main():
vecnorm.norm_reward = False
predict = make_policy_predictor(model, vecnorm)
# Infer drive_mode from policy action dim if using a learned policy.
if args.policy not in ("strombom", "sequential"):
policy_action_dim = int(model.action_space.shape[0])
if policy_action_dim == 2 and drive_mode == "mecanum":
drive_mode = "differential"
print(f"[eval] policy has 2D actions — overriding drive_mode "
f"to differential")
elif policy_action_dim == 3 and drive_mode == "differential":
drive_mode = "mecanum"
print(f"[eval] policy has 3D actions — overriding drive_mode "
f"to mecanum")
print(f"{'n_sheep':>8} {'success%':>10} {'mean_steps':>12} {'mean_penned':>12}")
print("-" * 46)
for n in range(1, args.max_flock + 1):
@@ -121,7 +160,7 @@ def main():
for seed in range(args.n_seeds):
env = HerdingEnv(n_sheep=n, max_steps=args.max_steps,
difficulty=args.difficulty, seed=seed,
frame_stack=frame_stack)
frame_stack=frame_stack, drive_mode=drive_mode)
r = rollout(env, predict, args.max_steps)
successes.append(int(r["success"]))
steps.append(r["steps"])