Files
TIR_PROJ/tests/parity_test.py
T
Johnny Fernandes fce0e0c786 Checkpoint 6
2026-05-11 10:35:48 +01:00

97 lines
3.1 KiB
Python

"""Parity smoke-test for the herding env.
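
Verifies (a) all imports resolve, (b) the env's reset/step contract is
correct, (c) seeded resets are deterministic (step-determinism is
deliberately not checked), (d) curriculum resets sample valid flock sizes,
and (e) the Strömbom baseline can drive the env without crashing.

Run from the project root::

    python -m tests.parity_test

or directly as ``python tests/parity_test.py`` — the script puts the
project root on ``sys.path`` itself.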
"""
from __future__ import annotations
import os
import sys

# Absolute directory containing this file, and the project root one level up.
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, os.pardir))

# Prepend the project root (only once) so the `herding` and `training`
# packages imported below resolve when this file is run directly.
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)
import numpy as np
from herding.world.geometry import MAX_SHEEP, PEN_ENTRY
from herding.obs import OBS_DIM
from herding.control.strombom import compute_action
from training.herding_env import HerdingEnv
def test_obs_action_shapes():
    """Check the reset/step return contract: shapes, dtypes and Python types.

    ``reset`` must return ``(obs, info)`` and ``step`` must return
    ``(obs, reward, terminated, truncated, info)``, where both observations
    are float32 vectors of length ``OBS_DIM``, the reward is a Python
    ``float``, the two flags are ``bool`` and both infos are ``dict``.
    """
    env = HerdingEnv(n_sheep=3, seed=0)
    obs, info = env.reset()
    assert obs.shape == (OBS_DIM,), obs.shape
    assert obs.dtype == np.float32, obs.dtype
    assert isinstance(info, dict), type(info)

    obs2, r, term, trunc, info = env.step(np.array([0.5, 0.0], dtype=np.float32))
    assert obs2.shape == (OBS_DIM,), obs2.shape
    # reset and step must agree on dtype; previously only the reset obs was checked.
    assert obs2.dtype == np.float32, obs2.dtype
    assert isinstance(r, float), type(r)
    assert isinstance(term, bool) and isinstance(trunc, bool), (type(term), type(trunc))
    assert isinstance(info, dict), type(info)
    print("[ok] shapes")
def test_reset_determinism():
    """Two envs reset with the same seed must produce the same first observation.

    Only reset-determinism is checked. Step-determinism is intentionally out
    of scope: PPO does not rely on it, and forcing bit-exact trajectories
    through the flocking jitter is not worth the extra complexity.
    """
    seed = 42
    # Build both envs before resetting either, matching the original call order.
    envs = [HerdingEnv(n_sheep=3, seed=seed) for _ in range(2)]
    first_obs, second_obs = (env.reset(seed=seed)[0] for env in envs)
    assert np.allclose(first_obs, second_obs), "Reset is non-deterministic for same seed"
    print("[ok] reset determinism")
def test_curriculum_n_sheep_varies():
    """Repeated resets without a fixed ``n_sheep`` should sample several flock sizes.

    The env is built without ``n_sheep``, so each ``reset()`` reports the
    sampled size in ``info["n_sheep"]``. Every size must lie in
    ``[1, MAX_SHEEP]``, the single-sheep case must appear, and more than one
    distinct size must be seen. With ``seed=0`` the draws are presumably
    reproducible, so this should not be flaky (TODO confirm env seeding).
    """
    env = HerdingEnv(seed=0)
    n_resets = 40
    sizes = {env.reset()[1]["n_sheep"] for _ in range(n_resets)}
    assert 1 in sizes, f"n_sheep=1 never sampled: {sorted(sizes)}"
    assert min(sizes) >= 1, f"empty flock sampled: {sorted(sizes)}"
    assert max(sizes) <= MAX_SHEEP, f"n_sheep above MAX_SHEEP={MAX_SHEEP}: {sorted(sizes)}"
    # Without this, a curriculum stuck at a single size would still pass.
    assert len(sizes) > 1, f"curriculum never varied n_sheep over {n_resets} resets: {sorted(sizes)}"
    print(f"[ok] curriculum sampling — saw n_sheep in {sorted(sizes)}")
def test_strombom_drives_env():
    """Drive the env with the analytic Strömbom baseline and check nothing blows up.

    This is a quick functional check, not a success-rate test. The observation
    from ``reset``, every baseline action, every step observation and every
    reward must all be finite. The rollout stops early once no unpenned sheep
    remain or the env reports ``terminated``/``truncated``.
    """
    max_steps = 400  # single source for both the env's episode cap and the loop bound
    env = HerdingEnv(n_sheep=2, max_steps=max_steps, seed=1)
    obs, _ = env.reset()
    assert np.isfinite(obs).all(), "NaN/Inf in obs after reset"
    for t in range(max_steps):
        # Only sheep that are not yet penned are passed to the baseline.
        positions = {
            f"s{i}": (float(env.sheep_x[i]), float(env.sheep_y[i]))
            for i in range(env.n_sheep)
            if not env.sheep_penned[i]
        }
        if not positions:
            break
        vx, vy, _mode = compute_action((env.dog_x, env.dog_y), positions, PEN_ENTRY)
        action = np.array([vx, vy], dtype=np.float32)
        # Check the controller's output separately so a NaN/Inf is attributed
        # to the baseline rather than to the env.
        assert np.isfinite(action).all(), f"NaN/Inf in Strömbom action at step {t}"
        obs, r, term, trunc, _info = env.step(action)
        assert np.isfinite(obs).all(), f"NaN/Inf in obs at step {t}"
        assert np.isfinite(r), f"NaN/Inf reward at step {t}"
        if term or trunc:
            break
    print(f"[ok] strombom rollout — final n_penned={int(env.sheep_penned.sum())}/{env.n_sheep} after {env.steps} steps")
def main():
    """Run every parity check in sequence, then print a summary line.

    A failing check raises (an ``AssertionError`` for a failed assert),
    which aborts the run before the summary is printed.
    """
    checks = (
        test_obs_action_shapes,
        test_reset_determinism,
        test_curriculum_n_sheep_varies,
        test_strombom_drives_env,
    )
    for check in checks:
        check()
    print("\nAll parity checks passed.")


if __name__ == "__main__":
    main()