Checkpoint 8
This commit is contained in:
+45
-6
@@ -12,11 +12,28 @@ Usage::
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
from statistics import mean
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Early CLI pre-parse for --world: the world geometry has to be configured
# before the other herding.* modules are imported (those imports follow
# directly below), so this cannot wait for argparse in main().
import sys  # explicit import; do not rely on ``os`` happening to expose ``sys``


def _find_world_arg(argv):
    """Return the value of the first ``--world`` option in *argv*, or None.

    Both ``--world NAME`` and ``--world=NAME`` spellings are recognised.
    A trailing ``--world`` with no value after it is ignored, and None is
    returned when the option does not appear at all.
    """
    for idx, arg in enumerate(argv):
        if arg == "--world" and idx + 1 < len(argv):
            return argv[idx + 1]
        if arg.startswith("--world="):
            # Split once so a value that itself contains '=' is kept intact.
            return arg.split("=", 1)[1]
    return None


_pre_argv = sys.argv[1:]
_pre_world = _find_world_arg(_pre_argv)
if _pre_world is not None:
    from herding.world.geometry import configure as _geo_configure
    _geo_configure(_pre_world)
    # Also record the choice in HERDING_WORLD, the environment variable the
    # --world help text in main() names as the fallback source.
    os.environ["HERDING_WORLD"] = _pre_world
|
||||
|
||||
from herding.world.geometry import MAX_SHEEP, PEN_ENTRY
|
||||
from herding.control.sequential import compute_action as sequential_action
|
||||
from herding.control.strombom import compute_action as strombom_action
|
||||
@@ -38,18 +55,20 @@ def rollout(env: HerdingEnv, predict_fn, max_steps: int) -> dict:
|
||||
"n_penned": int(env.sheep_penned.sum())}
|
||||
|
||||
|
||||
def make_analytic_predictor(action_fn, drive_mode: str = "differential"):
    """Build a predictor that drives the dog with an analytic teacher.

    The teacher is fed the env's exposed perception (tracker in LiDAR
    mode, GT in privileged mode) rather than the raw observation.

    Args:
        action_fn: Called as ``action_fn((dog_x, dog_y), positions,
            PEN_ENTRY)``; must return a ``(vx, vy, mode)`` triple.
        drive_mode: ``"mecanum"`` yields a 3-element action whose last
            component is always 0.0; any other value yields ``[vx, vy]``.

    Returns:
        A ``_predict(env, _obs)`` callable returning a float32 numpy array.
    """
    wants_three_axes = drive_mode == "mecanum"

    def _predict(env, _obs):
        dog_xy = (env.dog_x, env.dog_y)
        sheep_xy = env.perceived_positions()
        # The teacher's third return value (its mode) is not needed here.
        vx, vy, _mode = action_fn(dog_xy, sheep_xy, PEN_ENTRY)
        command = [vx, vy, 0.0] if wants_three_axes else [vx, vy]
        return np.array(command, dtype=np.float32)

    return _predict
|
||||
|
||||
|
||||
def make_strombom_predictor(drive_mode: str = "differential"):
    """Return an analytic predictor wired to ``strombom_action``.

    ``drive_mode`` is forwarded unchanged to ``make_analytic_predictor``.
    """
    predictor = make_analytic_predictor(strombom_action, drive_mode=drive_mode)
    return predictor
|
||||
|
||||
|
||||
def make_policy_predictor(model, vecnorm):
|
||||
@@ -73,13 +92,21 @@ def main():
|
||||
parser.add_argument("--difficulty", type=float, default=1.0,
|
||||
help="0 = sheep spawn near the gate (easy); "
|
||||
"1 = full field (deployment distribution).")
|
||||
parser.add_argument("--drive-mode", default="differential",
|
||||
choices=["differential", "mecanum"],
|
||||
help="Drive mode for the dog robot.")
|
||||
parser.add_argument("--world", default=None,
|
||||
choices=["field", "field_round"],
|
||||
help="World shape. If not set, uses HERDING_WORLD "
|
||||
"env var or defaults to 'field'.")
|
||||
args = parser.parse_args()
|
||||
|
||||
drive_mode = args.drive_mode
|
||||
frame_stack = 1
|
||||
if args.policy == "strombom":
|
||||
predict = make_analytic_predictor(strombom_action)
|
||||
predict = make_analytic_predictor(strombom_action, drive_mode)
|
||||
elif args.policy == "sequential":
|
||||
predict = make_analytic_predictor(sequential_action)
|
||||
predict = make_analytic_predictor(sequential_action, drive_mode)
|
||||
else:
|
||||
from stable_baselines3 import PPO
|
||||
run = Path(args.policy)
|
||||
@@ -114,6 +141,18 @@ def main():
|
||||
vecnorm.norm_reward = False
|
||||
predict = make_policy_predictor(model, vecnorm)
|
||||
|
||||
# Infer drive_mode from policy action dim if using a learned policy.
|
||||
if args.policy not in ("strombom", "sequential"):
|
||||
policy_action_dim = int(model.action_space.shape[0])
|
||||
if policy_action_dim == 2 and drive_mode == "mecanum":
|
||||
drive_mode = "differential"
|
||||
print(f"[eval] policy has 2D actions — overriding drive_mode "
|
||||
f"to differential")
|
||||
elif policy_action_dim == 3 and drive_mode == "differential":
|
||||
drive_mode = "mecanum"
|
||||
print(f"[eval] policy has 3D actions — overriding drive_mode "
|
||||
f"to mecanum")
|
||||
|
||||
print(f"{'n_sheep':>8} {'success%':>10} {'mean_steps':>12} {'mean_penned':>12}")
|
||||
print("-" * 46)
|
||||
for n in range(1, args.max_flock + 1):
|
||||
@@ -121,7 +160,7 @@ def main():
|
||||
for seed in range(args.n_seeds):
|
||||
env = HerdingEnv(n_sheep=n, max_steps=args.max_steps,
|
||||
difficulty=args.difficulty, seed=seed,
|
||||
frame_stack=frame_stack)
|
||||
frame_stack=frame_stack, drive_mode=drive_mode)
|
||||
r = rollout(env, predict, args.max_steps)
|
||||
successes.append(int(r["success"]))
|
||||
steps.append(r["steps"])
|
||||
|
||||
Reference in New Issue
Block a user