Checkpoint 7
This commit is contained in:
@@ -0,0 +1,71 @@
|
||||
"""Observation builder — shape, normalisation, order invariance."""
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from herding.perception.obs import OBS_DIM, build_obs
|
||||
|
||||
|
||||
def test_obs_shape_and_dtype():
|
||||
obs = build_obs((0.0, 0.0), 0.0, [(5.0, 5.0)], [False])
|
||||
assert obs.shape == (OBS_DIM,)
|
||||
assert obs.dtype == np.float32
|
||||
|
||||
|
||||
def test_obs_no_active_sheep_terminal():
|
||||
# All sheep penned → flock-summary fields zero, count zero.
|
||||
obs = build_obs((0.0, 0.0), 0.0, [(1.0, 1.0), (2.0, 2.0)], [True, True])
|
||||
assert obs[19] == 0.0
|
||||
# Aggregate fields (CoM, radius, std, vectors) should all be zero.
|
||||
assert np.allclose(obs[4:12], 0.0)
|
||||
|
||||
|
||||
def test_obs_dog_pose_normalised():
|
||||
obs = build_obs((15.0, -15.0), math.pi / 2, [(0.0, 0.0)], [False])
|
||||
assert math.isclose(obs[0], 1.0)
|
||||
assert math.isclose(obs[1], -1.0)
|
||||
assert math.isclose(obs[2], math.cos(math.pi / 2), abs_tol=1e-6)
|
||||
assert math.isclose(obs[3], math.sin(math.pi / 2), abs_tol=1e-6)
|
||||
|
||||
|
||||
def test_obs_order_invariance():
|
||||
"""Sheep order in the input list must not affect the observation."""
|
||||
sheep = [(3.0, 2.0), (-5.0, 1.0), (0.0, 8.0)]
|
||||
p = [False] * 3
|
||||
a = build_obs((0.0, 0.0), 0.0, sheep, p)
|
||||
b = build_obs((0.0, 0.0), 0.0, list(reversed(sheep)), list(reversed(p)))
|
||||
assert np.allclose(a, b)
|
||||
|
||||
|
||||
def test_obs_count_field_normalised_by_n_max():
|
||||
sheep = [(1.0, 1.0)] * 5
|
||||
p = [False] * 5
|
||||
obs = build_obs((0.0, 0.0), 0.0, sheep, p, n_max=10)
|
||||
assert math.isclose(obs[19], 0.5)
|
||||
|
||||
|
||||
def test_obs_polar_histogram_sums_to_one():
|
||||
sheep = [(1.0, 0.0), (-1.0, 0.0), (0.0, 1.0), (0.0, -1.0)]
|
||||
obs = build_obs((0.0, 0.0), 0.0, sheep, [False] * 4)
|
||||
assert math.isclose(float(obs[20:28].sum()), 1.0, abs_tol=1e-6)
|
||||
|
||||
|
||||
def test_obs_named_channels_closest_rearmost():
|
||||
# Channels 28..29 = (closest_to_pen - dog) / 15
|
||||
# Channels 30..31 = (rearmost - dog) / 15
|
||||
pen_x, pen_y = 11.5, -15.0
|
||||
near = (pen_x + 1.0, pen_y + 1.0)
|
||||
far = (-10.0, 10.0)
|
||||
obs = build_obs((0.0, 0.0), 0.0, [near, far], [False, False])
|
||||
tol = 1e-5
|
||||
assert math.isclose(obs[28], near[0] / 15.0, abs_tol=tol)
|
||||
assert math.isclose(obs[29], near[1] / 15.0, abs_tol=tol)
|
||||
assert math.isclose(obs[30], far[0] / 15.0, abs_tol=tol)
|
||||
assert math.isclose(obs[31], far[1] / 15.0, abs_tol=tol)
|
||||
|
||||
|
||||
def test_obs_pen_vector_zero_at_pen_entry():
|
||||
obs = build_obs((11.5, -15.0), 0.0, [(0.0, 0.0)], [False])
|
||||
assert math.isclose(obs[14], 0.0) # distance to pen
|
||||
Reference in New Issue
Block a user