Checkpoint 3
This commit is contained in:
+17
-4
@@ -27,11 +27,19 @@ if _PROJECT_ROOT not in sys.path:
|
||||
import numpy as np
|
||||
|
||||
from herding.geometry import PEN_ENTRY
|
||||
from herding.sequential import compute_action
|
||||
from herding.sequential import compute_action as sequential_action
|
||||
from herding.strombom import compute_action as strombom_action
|
||||
from training.herding_env import HerdingEnv
|
||||
|
||||
|
||||
def collect_one(n_sheep: int, seed: int, max_steps: int, subsample: int):
|
||||
TEACHERS = {
|
||||
"sequential": sequential_action,
|
||||
"strombom": strombom_action,
|
||||
}
|
||||
|
||||
|
||||
def collect_one(n_sheep: int, seed: int, max_steps: int, subsample: int,
|
||||
teacher_fn):
|
||||
env = HerdingEnv(n_sheep=n_sheep, max_steps=max_steps,
|
||||
difficulty=1.0, seed=seed)
|
||||
obs, _ = env.reset(seed=seed)
|
||||
@@ -41,7 +49,7 @@ def collect_one(n_sheep: int, seed: int, max_steps: int, subsample: int):
|
||||
for i in range(env.n_sheep) if not env.sheep_penned[i]}
|
||||
if not positions:
|
||||
break
|
||||
vx, vy, _mode = compute_action(
|
||||
vx, vy, _mode = teacher_fn(
|
||||
(env.dog_x, env.dog_y), positions, PEN_ENTRY,
|
||||
)
|
||||
action = np.array([vx, vy], dtype=np.float32)
|
||||
@@ -70,7 +78,12 @@ def main():
|
||||
help="Keep every Nth (obs, action) pair.")
|
||||
parser.add_argument("--keep-failures", action="store_true",
|
||||
help="Include partial-success trajectories. Default off.")
|
||||
parser.add_argument("--teacher", default="sequential",
|
||||
choices=list(TEACHERS.keys()),
|
||||
help="Which analytic teacher to demonstrate.")
|
||||
args = parser.parse_args()
|
||||
teacher_fn = TEACHERS[args.teacher]
|
||||
print(f"[demos] teacher: {args.teacher}")
|
||||
|
||||
n_sheep_list = [int(x) for x in args.n_sheep_list.split(",")]
|
||||
print(f"[demos] grid: n_sheep={n_sheep_list}, seeds={args.seeds_per_n}, "
|
||||
@@ -83,7 +96,7 @@ def main():
|
||||
for n in n_sheep_list:
|
||||
for seed in range(args.seeds_per_n):
|
||||
obs, actions, success, total_steps = collect_one(
|
||||
n, seed, args.max_steps, args.subsample,
|
||||
n, seed, args.max_steps, args.subsample, teacher_fn,
|
||||
)
|
||||
n_total += 1
|
||||
if success:
|
||||
|
||||
Reference in New Issue
Block a user