"""Parity smoke-test for the herding env. Verifies (a) all imports resolve, (b) the env's reset/step contract is correct, (c) deterministic seeds give deterministic trajectories, and (d) the Strömbom baseline can drive the env without crashing. Run:: python -m training.parity_test """ from __future__ import annotations import os import sys _HERE = os.path.dirname(os.path.abspath(__file__)) _PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, "..")) if _PROJECT_ROOT not in sys.path: sys.path.insert(0, _PROJECT_ROOT) import numpy as np from herding.world.geometry import MAX_SHEEP, PEN_ENTRY from herding.obs import OBS_DIM from herding.control.strombom import compute_action from training.herding_env import HerdingEnv def test_obs_action_shapes(): env = HerdingEnv(n_sheep=3, seed=0) obs, info = env.reset() assert obs.shape == (OBS_DIM,), obs.shape assert obs.dtype == np.float32 obs2, r, term, trunc, info = env.step(np.array([0.5, 0.0], dtype=np.float32)) assert obs2.shape == (OBS_DIM,) assert isinstance(r, float) assert isinstance(term, bool) and isinstance(trunc, bool) print("[ok] shapes") def test_reset_determinism(): """Reset with the same seed should give the same initial observation. We don't require step-determinism — PPO doesn't need it, and chasing bit-exactness through the flocking jitter isn't worth the complexity. """ env_a = HerdingEnv(n_sheep=3, seed=42) env_b = HerdingEnv(n_sheep=3, seed=42) obs_a, _ = env_a.reset(seed=42) obs_b, _ = env_b.reset(seed=42) assert np.allclose(obs_a, obs_b), "Reset is non-deterministic for same seed" print("[ok] reset determinism") def test_curriculum_n_sheep_varies(): env = HerdingEnv(seed=0) sizes = set() for _ in range(40): _, info = env.reset() sizes.add(info["n_sheep"]) assert 1 in sizes assert max(sizes) <= MAX_SHEEP print(f"[ok] curriculum sampling — saw n_sheep in {sorted(sizes)}") def test_strombom_drives_env(): """Quick functional check that the analytic baseline can play the env without exploding. Not a success-rate test — just no errors / NaNs.""" env = HerdingEnv(n_sheep=2, max_steps=400, seed=1) obs, _ = env.reset() for t in range(400): positions = {f"s{i}": (float(env.sheep_x[i]), float(env.sheep_y[i])) for i in range(env.n_sheep) if not env.sheep_penned[i]} if not positions: break vx, vy, _mode = compute_action((env.dog_x, env.dog_y), positions, PEN_ENTRY) obs, r, term, trunc, info = env.step(np.array([vx, vy], dtype=np.float32)) assert np.isfinite(obs).all(), f"NaN/Inf in obs at step {t}" assert np.isfinite(r), f"NaN reward at step {t}" if term or trunc: break print(f"[ok] strombom rollout — final n_penned={int(env.sheep_penned.sum())}/{env.n_sheep} after {env.steps} steps") def main(): test_obs_action_shapes() test_reset_determinism() test_curriculum_n_sheep_varies() test_strombom_drives_env() print("\nAll parity checks passed.") if __name__ == "__main__": main()