Checkpoint 4
This commit is contained in:
+16
-4
@@ -46,9 +46,11 @@ def rollout(env: HerdingEnv, predict_fn, max_steps: int) -> dict:
|
||||
|
||||
def make_analytic_predictor(action_fn):
    """Wrap an analytic controller as a ``predict(env, obs)`` callable.

    Args:
        action_fn: Callable invoked as
            ``action_fn((dog_x, dog_y), positions, PEN_ENTRY)``. It must
            return a 3-tuple; the first two items are used as the
            velocity command and the third is ignored here.

    Returns:
        A function ``_predict(env, _obs)`` that returns
        ``np.array([vx, vy], dtype=np.float32)``. The observation
        argument is accepted but not used; everything the controller
        needs is read from ``env`` directly.
    """

    def _predict(env, _obs):
        # Use whatever perception the env exposes — tracker output in
        # LiDAR mode, ground truth in privileged mode. This makes
        # evaluation honest: the analytic teacher sees what the
        # deployed controller would see.
        #
        # The positions are deliberately NOT rebuilt from env.sheep_x /
        # env.sheep_y here: that ground-truth dict was a dead store,
        # overwritten by the perceived positions before it was ever read.
        positions = env.perceived_positions()
        vx, vy, _mode = action_fn((env.dog_x, env.dog_y), positions, PEN_ENTRY)
        return np.array([vx, vy], dtype=np.float32)

    return _predict
|
||||
@@ -82,6 +84,7 @@ def main():
|
||||
parser.add_argument("--difficulty", type=float, default=1.0)
|
||||
args = parser.parse_args()
|
||||
|
||||
frame_stack = 1 # default; analytic predictors don't use stacked obs
|
||||
if args.policy == "strombom":
|
||||
predict = make_analytic_predictor(strombom_action)
|
||||
elif args.policy == "sequential":
|
||||
@@ -103,6 +106,14 @@ def main():
|
||||
f"policy.zip, final.zip)"
|
||||
)
|
||||
model = PPO.load(str(zip_path), device="auto")
|
||||
# Auto-detect frame stacking from the policy's expected obs dim,
|
||||
# so eval runs with whatever stacking the policy was trained on.
|
||||
from herding.obs import OBS_DIM as _SINGLE
|
||||
policy_obs_dim = int(model.observation_space.shape[0])
|
||||
if policy_obs_dim % _SINGLE == 0 and policy_obs_dim // _SINGLE >= 1:
|
||||
frame_stack = policy_obs_dim // _SINGLE
|
||||
if frame_stack > 1:
|
||||
print(f"[eval] policy expects frame_stack={frame_stack}")
|
||||
vecnorm = None
|
||||
vn_path = run / "vecnormalize.pkl"
|
||||
if not vn_path.exists() and run.parent.name != "best":
|
||||
@@ -121,7 +132,8 @@ def main():
|
||||
successes, steps, penned = [], [], []
|
||||
for seed in range(args.n_seeds):
|
||||
env = HerdingEnv(n_sheep=n, max_steps=args.max_steps,
|
||||
difficulty=args.difficulty, seed=seed)
|
||||
difficulty=args.difficulty, seed=seed,
|
||||
frame_stack=frame_stack)
|
||||
r = rollout(env, predict, args.max_steps)
|
||||
successes.append(int(r["success"]))
|
||||
steps.append(r["steps"])
|
||||
|
||||
Reference in New Issue
Block a user