Files
TIR_PROJ/tests/test_obs.py
T
Johnny Fernandes a01a5c9cef Checkpoint 7
2026-05-11 12:21:51 +01:00

72 lines
2.4 KiB
Python

"""Observation builder — shape, normalisation, order invariance."""
import math
import numpy as np
import pytest
from herding.perception.obs import OBS_DIM, build_obs
def test_obs_shape_and_dtype():
    """build_obs returns a flat float32 vector of length OBS_DIM."""
    dog_xy, heading = (0.0, 0.0), 0.0
    vec = build_obs(dog_xy, heading, [(5.0, 5.0)], [False])
    assert vec.shape == (OBS_DIM,)
    assert vec.dtype == np.float32
def test_obs_no_active_sheep_terminal():
    """With every sheep penned, the flock-summary channels are all zero."""
    positions = [(1.0, 1.0), (2.0, 2.0)]
    penned = [True] * len(positions)
    vec = build_obs((0.0, 0.0), 0.0, positions, penned)
    # Active-sheep count channel (index 19) must be zero.
    assert vec[19] == 0.0
    # Aggregate channels 4..11 (CoM, radius, std, vectors) must all be zero.
    assert np.allclose(vec[4:12], 0.0)
def test_obs_dog_pose_normalised():
    """Dog position (+/-15 maps to +/-1) and heading as (cos, sin) fill channels 0..3."""
    heading = math.pi / 2
    vec = build_obs((15.0, -15.0), heading, [(0.0, 0.0)], [False])
    x_norm, y_norm, cos_h, sin_h = vec[0], vec[1], vec[2], vec[3]
    assert math.isclose(x_norm, 1.0)
    assert math.isclose(y_norm, -1.0)
    # Trig encodings come out of float32, so compare with an absolute tolerance.
    for got, want in ((cos_h, math.cos(heading)), (sin_h, math.sin(heading))):
        assert math.isclose(got, want, abs_tol=1e-6)
def test_obs_order_invariance():
    """Sheep order in the input list must not affect the observation."""
    flock = [(3.0, 2.0), (-5.0, 1.0), (0.0, 8.0)]
    penned = [False for _ in flock]
    dog, heading = (0.0, 0.0), 0.0
    forward = build_obs(dog, heading, flock, penned)
    # Reverse positions and flags together so each flag stays with its sheep.
    backward = build_obs(dog, heading, flock[::-1], penned[::-1])
    assert np.allclose(forward, backward)
def test_obs_count_field_normalised_by_n_max():
    """The active-count channel (index 19) equals n_active / n_max."""
    n_active, n_max = 5, 10
    flock = [(1.0, 1.0) for _ in range(n_active)]
    vec = build_obs((0.0, 0.0), 0.0, flock, [False] * n_active, n_max=n_max)
    assert math.isclose(vec[19], n_active / n_max)
def test_obs_polar_histogram_sums_to_one():
    """The 8-bin polar histogram in channels 20..27 sums to one."""
    compass = [(1.0, 0.0), (-1.0, 0.0), (0.0, 1.0), (0.0, -1.0)]
    vec = build_obs((0.0, 0.0), 0.0, compass, [False] * len(compass))
    hist = vec[20:28]
    assert math.isclose(float(hist.sum()), 1.0, abs_tol=1e-6)
def test_obs_named_channels_closest_rearmost():
    """Named-sheep channels hold dog-relative offsets scaled by 1/15.

    Channels 28..29 = (closest_to_pen - dog) / 15
    Channels 30..31 = (rearmost - dog) / 15

    The dog is deliberately placed off the origin so the ``- dog`` term is
    actually exercised: with the dog at (0, 0), an implementation that forgot
    to subtract the dog position would still pass. The chosen dog position
    keeps ``near`` closest to both the pen and the dog, keeps ``far`` farthest
    from both, and keeps every expected channel value inside [-1, 1].
    """
    pen_x, pen_y = 11.5, -15.0
    near = (pen_x + 1.0, pen_y + 1.0)  # one unit off the pen in x and y
    far = (-10.0, 10.0)  # far side of the field; expected to be rearmost
    dog = (2.0, -3.0)
    obs = build_obs(dog, 0.0, [near, far], [False, False])
    tol = 1e-5
    assert math.isclose(obs[28], (near[0] - dog[0]) / 15.0, abs_tol=tol)
    assert math.isclose(obs[29], (near[1] - dog[1]) / 15.0, abs_tol=tol)
    assert math.isclose(obs[30], (far[0] - dog[0]) / 15.0, abs_tol=tol)
    assert math.isclose(obs[31], (far[1] - dog[1]) / 15.0, abs_tol=tol)
def test_obs_pen_vector_zero_at_pen_entry():
    """With the dog standing on the pen entry, the distance-to-pen channel is ~0."""
    pen_entry = (11.5, -15.0)
    obs = build_obs(pen_entry, 0.0, [(0.0, 0.0)], [False])
    # math.isclose(x, 0.0) with the default abs_tol=0.0 is an exact-equality
    # check (rel_tol scales with the operands, which are ~0 here). Use an
    # absolute tolerance, as the other float32 comparisons in this file do.
    assert math.isclose(obs[14], 0.0, abs_tol=1e-6)  # distance to pen