Files
TIR_PROJ/tests/test_env.py
T
Johnny Fernandes a01a5c9cef Checkpoint 7
2026-05-11 12:21:51 +01:00

109 lines
3.5 KiB
Python

"""Gymnasium env: contract, determinism, reward components."""
import math
import numpy as np
import pytest
from herding.world.geometry import MAX_SHEEP, PEN_ENTRY
from herding.perception.obs import OBS_DIM
from herding.control.strombom import compute_action as strombom_action
from training.herding_env import HerdingEnv
def test_env_obs_action_shapes_single_frame():
    """reset() and step() emit a flat float32 observation of length OBS_DIM.

    Also checks the Gymnasium step() return contract for a single
    (non-stacked) frame: a Python float reward and bool
    terminated/truncated flags.
    """
    env = HerdingEnv(n_sheep=3, seed=0, use_lidar=False)

    first_obs, _ = env.reset()
    assert first_obs.shape == (OBS_DIM,)
    assert first_obs.dtype == np.float32

    action = np.array([0.5, 0.0], dtype=np.float32)
    next_obs, reward, terminated, truncated, _ = env.step(action)
    assert next_obs.shape == (OBS_DIM,)
    assert isinstance(reward, float)
    assert isinstance(terminated, bool)
    assert isinstance(truncated, bool)
def test_env_observation_space_matches_frame_stack():
    """With frame_stack=k, observations are flattened to length OBS_DIM * k.

    Both the observation actually returned by reset() and the declared
    observation_space must report the stacked length.
    """
    n_frames = 4
    env = HerdingEnv(n_sheep=2, seed=0, use_lidar=False, frame_stack=n_frames)
    stacked_shape = (OBS_DIM * n_frames,)

    obs, _ = env.reset()
    assert obs.shape == stacked_shape
    assert env.observation_space.shape == stacked_shape
def test_env_reset_determinism_same_seed():
    """Two envs reset with the same seed must produce identical observations.

    Exact equality is checked rather than a float tolerance:
    ``np.allclose`` would accept near-identical (i.e. non-deterministic)
    output, silently broadcast mismatched shapes, and spuriously fail on
    NaNs in matching positions. ``assert_array_equal`` checks the shape,
    compares element-wise exactly, treats co-located NaNs as equal, and
    reports the differing elements on failure.
    """
    env_a = HerdingEnv(n_sheep=3, seed=42, use_lidar=False)
    env_b = HerdingEnv(n_sheep=3, seed=42, use_lidar=False)
    obs_a, _ = env_a.reset(seed=42)
    obs_b, _ = env_b.reset(seed=42)
    np.testing.assert_array_equal(obs_a, obs_b)
def test_env_curriculum_samples_full_range():
    """Without a fixed n_sheep, reset() chooses the flock size itself.

    Over 40 resets the sizes reported in ``info["n_sheep"]`` must include
    the single-sheep case and never exceed MAX_SHEEP.
    """
    env = HerdingEnv(seed=0, use_lidar=False)
    # Index 1 of reset()'s (obs, info) pair is the info dict.
    seen_sizes = {env.reset()[1]["n_sheep"] for _ in range(40)}
    assert 1 in seen_sizes
    assert max(seen_sizes) <= MAX_SHEEP
def test_env_step_returns_finite_values():
    """The reset observation and every step's observation/reward are finite.

    Drives the env with a constant action until it terminates, truncates,
    or ``max_steps`` steps have been taken.
    """
    max_steps = 200
    env = HerdingEnv(n_sheep=2, max_steps=max_steps, seed=1, use_lidar=False)
    obs, _ = env.reset()
    # The reset observation is the first thing a policy sees; previously it
    # was overwritten by the loop before ever being checked.
    assert np.isfinite(obs).all()
    for _ in range(max_steps):
        # A fresh array each step, in case step() modifies its input in place.
        action = np.array([0.5, 0.5], dtype=np.float32)
        obs, reward, term, trunc, _ = env.step(action)
        assert np.isfinite(obs).all()
        assert math.isfinite(reward)
        if term or trunc:
            break
def test_env_options_n_sheep_overrides_curriculum():
    """reset(options={"n_sheep": k}) reports exactly k sheep in its info."""
    requested = 7
    env = HerdingEnv(seed=0, use_lidar=False)
    _obs, reset_info = env.reset(options={"n_sheep": requested})
    assert reset_info["n_sheep"] == requested
def test_env_perceived_positions_lidar_vs_privileged():
    """perceived_positions() completeness depends on the perception mode.

    Privileged mode (use_lidar=False) must report every sheep. LiDAR mode
    reports only what the tracker currently holds: possibly fewer sheep
    (those out of field of view or range), but never more than the flock.
    """
    flock = 3

    privileged = HerdingEnv(n_sheep=flock, seed=0, use_lidar=False)
    privileged.reset(seed=0)
    assert len(privileged.perceived_positions()) == flock

    lidar = HerdingEnv(n_sheep=flock, seed=0, use_lidar=True)
    lidar.reset(seed=0)
    assert len(lidar.perceived_positions()) <= flock
def test_env_set_time_weight_affects_reward():
    """A negative time weight lowers the reward of a replayed first step.

    The same seeded reset and the same all-zero action are replayed before
    and after ``set_time_weight(-1.0)``. The intent is that only the weight
    differs between the two runs.
    """
    env = HerdingEnv(n_sheep=1, seed=0, use_lidar=False)

    def first_step_reward():
        # Replay reset(seed=0) followed by one zero-action step.
        env.reset(seed=0)
        _obs, reward, *_rest = env.step(np.zeros(2, dtype=np.float32))
        return reward

    baseline = first_step_reward()
    env.set_time_weight(-1.0)
    assert first_step_reward() < baseline
def test_env_strombom_rollout_moves_dog():
    """Feeding Strombom controller actions into step() moves the dog.

    Rolls out up to 400 steps. It stops early if no sheep are perceived or
    the episode terminates or truncates. It then requires the dog's net
    displacement from its post-reset position to exceed 0.05.
    """
    horizon = 400
    env = HerdingEnv(n_sheep=2, max_steps=horizon, seed=1, use_lidar=False)
    env.reset()
    x0, y0 = env.dog_x, env.dog_y

    for _ in range(horizon):
        sheep = env.perceived_positions()
        if not sheep:
            break
        # The controller's third return value is not needed here.
        vx, vy, _ = strombom_action((env.dog_x, env.dog_y), sheep, PEN_ENTRY)
        _obs, _reward, terminated, truncated, _info = env.step(
            np.array([vx, vy], dtype=np.float32))
        if terminated or truncated:
            break

    assert math.hypot(env.dog_x - x0, env.dog_y - y0) > 0.05