Checkpoint 3

This commit is contained in:
Johnny Fernandes
2026-05-10 12:46:14 +01:00
parent 1bb9415414
commit 2a6db038df
16 changed files with 305 additions and 662 deletions
+17 -4
View File
@@ -27,11 +27,19 @@ if _PROJECT_ROOT not in sys.path:
import numpy as np
from herding.geometry import PEN_ENTRY
from herding.sequential import compute_action
from herding.sequential import compute_action as sequential_action
from herding.strombom import compute_action as strombom_action
from training.herding_env import HerdingEnv
def collect_one(n_sheep: int, seed: int, max_steps: int, subsample: int):
# Registry of the analytic teachers selectable with --teacher: each key is a
# CLI choice and each value is that module's compute_action, which collect_one
# invokes as teacher_fn((dog_x, dog_y), positions, PEN_ENTRY) and unpacks as
# (vx, vy, mode).
TEACHERS = {
"sequential": sequential_action,
"strombom": strombom_action,
}
def collect_one(n_sheep: int, seed: int, max_steps: int, subsample: int,
teacher_fn):
env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps,
difficulty=1.0, seed=seed)
obs, _ = env.reset(seed=seed)
@@ -41,7 +49,7 @@ def collect_one(n_sheep: int, seed: int, max_steps: int, subsample: int):
for i in range(env.n_sheep) if not env.sheep_penned[i]}
if not positions:
break
vx, vy, _mode = compute_action(
vx, vy, _mode = teacher_fn(
(env.dog_x, env.dog_y), positions, PEN_ENTRY,
)
action = np.array([vx, vy], dtype=np.float32)
@@ -70,7 +78,12 @@ def main():
help="Keep every Nth (obs, action) pair.")
parser.add_argument("--keep-failures", action="store_true",
help="Include partial-success trajectories. Default off.")
parser.add_argument("--teacher", default="sequential",
choices=list(TEACHERS.keys()),
help="Which analytic teacher to demonstrate.")
args = parser.parse_args()
teacher_fn = TEACHERS[args.teacher]
print(f"[demos] teacher: {args.teacher}")
n_sheep_list = [int(x) for x in args.n_sheep_list.split(",")]
print(f"[demos] grid: n_sheep={n_sheep_list}, seeds={args.seeds_per_n}, "
@@ -83,7 +96,7 @@ def main():
for n in n_sheep_list:
for seed in range(args.seeds_per_n):
obs, actions, success, total_steps = collect_one(
n, seed, args.max_steps, args.subsample,
n, seed, args.max_steps, args.subsample, teacher_fn,
)
n_total += 1
if success: