72 lines
2.4 KiB
Python
72 lines
2.4 KiB
Python
"""Observation builder — shape, normalisation, order invariance."""
|
|
|
|
import math
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from herding.perception.obs import OBS_DIM, build_obs
|
|
|
|
|
|
def test_obs_shape_and_dtype():
    """The observation is a flat ``float32`` vector of length ``OBS_DIM``."""
    dog_xy = (0.0, 0.0)
    heading = 0.0
    vec = build_obs(dog_xy, heading, [(5.0, 5.0)], [False])

    expected_shape = (OBS_DIM,)
    assert vec.shape == expected_shape
    assert vec.dtype == np.float32
|
|
|
|
|
|
def test_obs_no_active_sheep_terminal():
    """With every sheep flagged as penned, the flock summary is all zeros."""
    positions = [(1.0, 1.0), (2.0, 2.0)]
    all_penned = [True, True]
    vec = build_obs((0.0, 0.0), 0.0, positions, all_penned)

    # Slot 19 is the count field: no sheep left to count.
    assert vec[19] == 0.0

    # Slots 4..11 hold the aggregate fields (CoM, radius, std, vectors).
    aggregates = vec[4:12]
    assert np.allclose(aggregates, 0.0)
|
|
|
|
|
|
def test_obs_dog_pose_normalised():
    """Slots 0-1 carry the scaled dog position; slots 2-3 the heading cos/sin."""
    heading = math.pi / 2
    vec = build_obs((15.0, -15.0), heading, [(0.0, 0.0)], [False])

    # A dog at (15, -15) is expected to map to (+1, -1).
    x_norm = vec[0]
    y_norm = vec[1]
    assert math.isclose(x_norm, 1.0)
    assert math.isclose(y_norm, -1.0)

    # cos(pi/2) is ~0, so an absolute tolerance is required for the comparison.
    cos_h = vec[2]
    sin_h = vec[3]
    assert math.isclose(cos_h, math.cos(heading), abs_tol=1e-6)
    assert math.isclose(sin_h, math.sin(heading), abs_tol=1e-6)
|
|
|
|
|
|
def test_obs_order_invariance():
    """Reversing the sheep list (and its flags) must yield the same observation."""
    flock = [(3.0, 2.0), (-5.0, 1.0), (0.0, 8.0)]
    penned = [False] * 3
    dog_xy = (0.0, 0.0)

    forward = build_obs(dog_xy, 0.0, flock, penned)
    # Reverse both lists together so each sheep keeps its own flag.
    backward = build_obs(dog_xy, 0.0, flock[::-1], penned[::-1])

    assert np.allclose(forward, backward)
|
|
|
|
|
|
def test_obs_count_field_normalised_by_n_max():
    """Five active sheep with ``n_max=10`` give a count field (slot 19) of 0.5."""
    n_active = 5
    flock = [(1.0, 1.0)] * n_active
    penned = [False] * n_active

    vec = build_obs((0.0, 0.0), 0.0, flock, penned, n_max=10)

    count_field = vec[19]
    assert math.isclose(count_field, 0.5)
|
|
|
|
|
|
def test_obs_polar_histogram_sums_to_one():
    """Slots 20..27 (the polar histogram) must sum to one."""
    # One sheep on each axis direction around a dog at the origin.
    flock = [(1.0, 0.0), (-1.0, 0.0), (0.0, 1.0), (0.0, -1.0)]
    penned = [False] * 4

    vec = build_obs((0.0, 0.0), 0.0, flock, penned)

    histogram = vec[20:28]
    total = float(histogram.sum())
    assert math.isclose(total, 1.0, abs_tol=1e-6)
|
|
|
|
|
|
def test_obs_named_channels_closest_rearmost():
    """Check the closest-to-pen and rearmost sheep channels (slots 28..31).

    Channels 28..29 = (closest_to_pen - dog) / 15
    Channels 30..31 = (rearmost - dog) / 15

    The dog sits at the origin, so each expected value is simply the sheep
    coordinate divided by 15.
    """
    pen_x, pen_y = 11.5, -15.0
    near = (pen_x + 1.0, pen_y + 1.0)
    far = (-10.0, 10.0)

    vec = build_obs((0.0, 0.0), 0.0, [near, far], [False, False])

    tol = 1e-5
    # (channel index, un-normalised expected coordinate), in ascending order.
    expectations = [
        (28, near[0]),
        (29, near[1]),
        (30, far[0]),
        (31, far[1]),
    ]
    for channel, coord in expectations:
        assert math.isclose(vec[channel], coord / 15.0, abs_tol=tol)
|
|
|
|
|
|
def test_obs_pen_vector_zero_at_pen_entry():
    """A dog at (11.5, -15.0) must report a distance to pen (slot 14) of ~zero."""
    obs = build_obs((11.5, -15.0), 0.0, [(0.0, 0.0)], [False])

    # math.isclose defaults to abs_tol=0.0, so comparing against 0.0 only
    # succeeds on an exact zero. An explicit absolute tolerance keeps the
    # check meaningful if the float32 distance carries rounding residue.
    assert math.isclose(obs[14], 0.0, abs_tol=1e-6)  # distance to pen
|