97 lines
3.1 KiB
Python
97 lines
3.1 KiB
Python
"""Parity smoke-test for the herding env.
|
|
|
|
Verifies (a) all imports resolve, (b) the env's reset/step contract is
correct, (c) resetting with a fixed seed gives a deterministic initial
observation (step-level determinism is not checked), (d) curriculum
sampling of the flock size includes single-sheep episodes and never
exceeds MAX_SHEEP, and (e) the Strömbom baseline can drive the env
without crashing.

Run::

    python -m training.parity_test
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import sys
|
|
|
|
# Prepend the parent of this file's directory to sys.path (only if it is not
# already there) so the ``herding`` and ``training`` imports below resolve.
# Presumably this is what lets the file also run as a plain script, not only
# via ``python -m`` — confirm if the launch method changes.
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, os.pardir))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)
|
|
|
|
import numpy as np
|
|
|
|
from herding.geometry import MAX_SHEEP, PEN_ENTRY
|
|
from herding.obs import OBS_DIM
|
|
from herding.strombom import compute_action
|
|
from training.herding_env import HerdingEnv
|
|
|
|
|
|
def test_obs_action_shapes() -> None:
    """Check the reset/step contract of a 3-sheep env.

    After ``reset`` the observation must be a float32 vector of length
    ``OBS_DIM``. After one ``step`` the observation must keep that shape,
    the reward must be a Python ``float``, and both termination flags must
    be ``bool``.
    """
    env = HerdingEnv(n_sheep=3, seed=0)

    first_obs, _ = env.reset()
    assert first_obs.shape == (OBS_DIM,), first_obs.shape
    assert first_obs.dtype == np.float32

    action = np.array([0.5, 0.0], dtype=np.float32)
    next_obs, reward, terminated, truncated, _ = env.step(action)
    assert next_obs.shape == (OBS_DIM,)
    assert isinstance(reward, float)
    assert isinstance(terminated, bool) and isinstance(truncated, bool)

    print("[ok] shapes")
|
|
|
|
|
|
def test_reset_determinism() -> None:
    """Check that two fresh envs reset with seed 42 yield the same first obs.

    Only the initial observation is compared. Step-level determinism is out
    of scope on purpose: PPO does not rely on it, and making the flocking
    jitter bit-exact is not worth the added complexity.
    """
    first, second = [HerdingEnv(n_sheep=3, seed=42) for _ in range(2)]
    obs_first, _ = first.reset(seed=42)
    obs_second, _ = second.reset(seed=42)
    assert np.allclose(obs_first, obs_second), "Reset is non-deterministic for same seed"
    print("[ok] reset determinism")
|
|
|
|
|
|
def test_curriculum_n_sheep_varies() -> None:
    """Reset an env built without ``n_sheep`` 40 times and inspect flock sizes.

    Collects ``info["n_sheep"]`` from every reset, then asserts that a
    single-sheep episode was seen and that no size exceeds ``MAX_SHEEP``.
    """
    env = HerdingEnv(seed=0)
    # Unpacking in the comprehension still requires reset() to return a pair.
    seen = {info["n_sheep"] for _, info in (env.reset() for _ in range(40))}
    # NOTE(review): despite the name, nothing asserts len(seen) > 1; add that
    # only if the curriculum is known to sample several sizes this early.
    assert 1 in seen
    assert max(seen) <= MAX_SHEEP
    print(f"[ok] curriculum sampling — saw n_sheep in {sorted(seen)}")
|
|
|
|
|
|
def test_strombom_drives_env() -> None:
    """Smoke-test the analytic Strömbom baseline as a controller for the env.

    This is not a success-rate check. It only verifies that a 2-sheep rollout
    of up to 400 steps runs without raising, with finite observations and
    rewards at every step. The rollout stops early once no loose sheep
    remain or the env reports termination/truncation.
    """
    env = HerdingEnv(n_sheep=2, max_steps=400, seed=1)
    _obs, _ = env.reset()
    for step in range(400):
        # The baseline only sees sheep that are not yet penned.
        loose = {}
        for idx in range(env.n_sheep):
            if env.sheep_penned[idx]:
                continue
            loose[f"s{idx}"] = (float(env.sheep_x[idx]), float(env.sheep_y[idx]))
        if not loose:
            break
        vx, vy, _mode = compute_action((env.dog_x, env.dog_y), loose, PEN_ENTRY)
        obs, reward, terminated, truncated, _ = env.step(np.array([vx, vy], dtype=np.float32))
        assert np.isfinite(obs).all(), f"NaN/Inf in obs at step {step}"
        assert np.isfinite(reward), f"NaN reward at step {step}"
        if terminated or truncated:
            break
    print(f"[ok] strombom rollout — final n_penned={int(env.sheep_penned.sum())}/{env.n_sheep} after {env.steps} steps")
|
|
|
|
|
|
def main() -> None:
    """Run every parity check in order, then report overall success.

    The checks use plain ``assert`` statements. The first failure raises
    and aborts the run before the success message is printed.
    """
    for check in (
        test_obs_action_shapes,
        test_reset_determinism,
        test_curriculum_n_sheep_varies,
        test_strombom_drives_env,
    ):
        check()
    print("\nAll parity checks passed.")


if __name__ == "__main__":
    main()
|