Checkpoint 4

This commit is contained in:
Johnny Fernandes
2026-05-11 00:42:52 +01:00
parent 2a6db038df
commit 6688325d89
26 changed files with 2018 additions and 503 deletions
+16 -4
View File
@@ -46,9 +46,11 @@ def rollout(env: HerdingEnv, predict_fn, max_steps: int) -> dict:
def make_analytic_predictor(action_fn):
    """Wrap an analytic herding controller as a rollout predict function.

    Args:
        action_fn: Callable invoked as
            ``action_fn((dog_x, dog_y), positions, PEN_ENTRY)``. It must
            return a 3-tuple whose first two items are the commanded
            velocity components; the third item is ignored here.

    Returns:
        A function ``_predict(env, _obs)`` that returns a float32 array
        ``[vx, vy]``. The observation argument is accepted but unused.
    """
    def _predict(env, _obs):
        # Use whatever perception the env exposes — tracker output in
        # LiDAR mode, ground truth in privileged mode. This makes
        # evaluation honest: the analytic teacher sees what the
        # deployed controller would see.
        #
        # The positions are taken only from env.perceived_positions();
        # rebuilding them from env.sheep_x / env.sheep_y would bypass the
        # perception layer and was a dead store anyway.
        positions = env.perceived_positions()
        vx, vy, _mode = action_fn((env.dog_x, env.dog_y), positions, PEN_ENTRY)
        return np.array([vx, vy], dtype=np.float32)

    return _predict
@@ -82,6 +84,7 @@ def main():
parser.add_argument("--difficulty", type=float, default=1.0)
args = parser.parse_args()
frame_stack = 1 # default; analytic predictors don't use stacked obs
if args.policy == "strombom":
predict = make_analytic_predictor(strombom_action)
elif args.policy == "sequential":
@@ -103,6 +106,14 @@ def main():
f"policy.zip, final.zip)"
)
model = PPO.load(str(zip_path), device="auto")
# Auto-detect frame stacking from the policy's expected obs dim,
# so eval runs with whatever stacking the policy was trained on.
from herding.obs import OBS_DIM as _SINGLE
policy_obs_dim = int(model.observation_space.shape[0])
if policy_obs_dim % _SINGLE == 0 and policy_obs_dim // _SINGLE >= 1:
frame_stack = policy_obs_dim // _SINGLE
if frame_stack > 1:
print(f"[eval] policy expects frame_stack={frame_stack}")
vecnorm = None
vn_path = run / "vecnormalize.pkl"
if not vn_path.exists() and run.parent.name != "best":
@@ -121,7 +132,8 @@ def main():
successes, steps, penned = [], [], []
for seed in range(args.n_seeds):
env = HerdingEnv(n_sheep=n, max_steps=args.max_steps,
difficulty=args.difficulty, seed=seed)
difficulty=args.difficulty, seed=seed,
frame_stack=frame_stack)
r = rollout(env, predict, args.max_steps)
successes.append(int(r["success"]))
steps.append(r["steps"])