1c197e0ff7
Two changes that together raise diff/round gym success ~52%→88% (BC)
and ~68%→88% (RL) without retraining; diff/field stays at 100%.
* TrackerConfig.consensus_k default 1 → 3 (radius 0.5 m, max_age 15
frames). The same candidate-promotion mechanism that closed the
Webots LiDAR gap also filters gym tracker phantoms. These appear on
the round field, where sheep move farther than GATE_M between
detection cycles: each new position spawns a fresh track while the
stale one persists in memory. SheepTracker() called without a
tracker_cfg keeps the legacy pass-through behaviour for backwards
compatibility.
* Strömbom + universal teachers now detect when the natural
"behind the flock" drive target lies outside the curved boundary
and instead push the flock radially inward toward the centre.
This breaks the wall-circling pattern that previously trapped both
the analytical baselines and the trained policies.
A/B numbers (n_sheep ∈ {1,2,3,5,10}, 5 seeds each, max_steps=15000):
diff/field bc: baseline 100% consensus 100%
diff/field rl: baseline 100% consensus 100%
diff/round bc: baseline 52% consensus 88%
diff/round rl: baseline 68% consensus 88%
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
268 lines
9.8 KiB
Python
"""Tests for herding/config.py — dataclass construction, defaults, overrides."""
|
|
|
|
import math
|
|
import pytest
|
|
|
|
from herding.config import (
|
|
DetectionConfig,
|
|
DomainRandomConfig,
|
|
HerdingConfig,
|
|
HERDING_DEFAULT,
|
|
HERDING_WEBOTS,
|
|
LidarConfig,
|
|
LIDAR_FULL,
|
|
LIDAR_WEBOTS,
|
|
RobotConfig,
|
|
TrackerConfig,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# LidarConfig
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestLidarConfig:
    """LidarConfig presets, immutability, and constructor validation."""

    def test_defaults_match_full_circle_preset(self):
        default_cfg = LidarConfig()
        assert default_cfg == LIDAR_FULL

    def test_webots_preset(self):
        webots = LIDAR_WEBOTS
        expected_fov = math.radians(140.0)
        assert webots.n_rays == 180
        assert abs(webots.fov_rad - expected_fov) < 1e-9

    def test_frozen(self):
        # The config is immutable: assigning to a field must fail.
        instance = LidarConfig()
        with pytest.raises((AttributeError, TypeError)):
            setattr(instance, "n_rays", 42)

    def test_invalid_n_rays(self):
        with pytest.raises(ValueError):
            LidarConfig(n_rays=0)

    def test_invalid_fov(self):
        # A zero FOV and one of 3π rad (more than a full turn) are both
        # rejected by the constructor.
        for bad_fov in (0.0, math.pi * 3):
            with pytest.raises(ValueError):
                LidarConfig(fov_rad=bad_fov)

    def test_invalid_max_range(self):
        with pytest.raises(ValueError):
            LidarConfig(max_range=-1.0)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# TrackerConfig
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestTrackerConfig:
    """TrackerConfig defaults, the Webots preset, consensus, and validation."""

    def test_defaults(self):
        cfg = TrackerConfig()
        assert cfg.forget_steps == 200
        assert cfg.max_new_tracks_per_step == 10

    def test_webots_preset_tighter(self):
        # "Tighter" refers to track creation: the Webots preset admits one
        # new track per step instead of the default 10. forget_steps is
        # deliberately NOT shorter than the default.
        cfg = HERDING_WEBOTS.tracker
        # forget_steps was extended so confirmed sheep tracks survive
        # sparse 140° FOV re-sightings; consensus blocks phantoms from
        # reaching this lifetime.
        assert cfg.forget_steps >= 200
        assert cfg.max_new_tracks_per_step == 1
        assert cfg.pen_latch_depth == 2.0

    def test_default_consensus_enabled(self):
        # Consensus is on by default — it filters tracker phantoms that
        # confused the policy on the round field (BC 52% → 88%, RL 68% → 88%)
        # at no cost on the rectangular field (100% → 100%). Pass-through
        # (k=1) remains available; see test_consensus_pass_through_available.
        cfg = TrackerConfig()
        assert cfg.consensus_k >= 2
        assert cfg.consensus_radius_m > 0.0
        assert cfg.consensus_max_age > cfg.consensus_k

    def test_consensus_pass_through_available(self):
        # Backwards compatibility: the legacy pass-through setting (k=1) must
        # still be constructible explicitly now that the default is k >= 2.
        cfg = TrackerConfig(consensus_k=1)
        assert cfg.consensus_k == 1
        assert cfg != TrackerConfig()

    def test_webots_preset_enables_consensus(self):
        cfg = HERDING_WEBOTS.tracker
        assert cfg.consensus_k > 1
        assert cfg.consensus_radius_m > 0.0
        assert cfg.consensus_max_age >= cfg.consensus_k

    def test_invalid_forget_steps(self):
        with pytest.raises(ValueError):
            TrackerConfig(forget_steps=0)

    def test_invalid_max_new_tracks(self):
        with pytest.raises(ValueError):
            TrackerConfig(max_new_tracks_per_step=0)

    def test_invalid_consensus_params(self):
        with pytest.raises(ValueError):
            TrackerConfig(consensus_k=0)
        with pytest.raises(ValueError):
            TrackerConfig(consensus_radius_m=0.0)
        with pytest.raises(ValueError):
            TrackerConfig(consensus_max_age=0)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DetectionConfig
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestDetectionConfig:
    """DetectionConfig defaults, the Webots preset, and validation."""

    def test_defaults(self):
        assert DetectionConfig().wall_reject == 0.5

    def test_webots_preset_wall_reject(self):
        # The Webots preset keeps wall_reject at 0.5 m: 1.0 m proved too
        # aggressive near the south gate.
        assert HERDING_WEBOTS.detection.wall_reject == 0.5

    def test_invalid_wall_reject(self):
        # A negative wall-rejection margin must be refused.
        with pytest.raises(ValueError):
            DetectionConfig(wall_reject=-0.1)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# RobotConfig
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestRobotConfig:
    """RobotConfig derived limits, action-smoothing values, and validation."""

    def test_max_linear_derived(self):
        robot = RobotConfig()
        # max_linear must equal wheel radius times maximum wheel speed.
        expected = robot.wheel_radius * robot.max_wheel_omega
        assert abs(robot.max_linear - expected) < 1e-9

    def test_default_action_smooth_zero(self):
        smooth = RobotConfig().action_smooth
        assert smooth == 0.0

    def test_webots_action_smooth(self):
        smooth = HERDING_WEBOTS.robot.action_smooth
        assert smooth == 0.55

    def test_invalid_action_smooth(self):
        # Both 1.0 and a negative smoothing factor are rejected.
        for bad_smooth in (1.0, -0.1):
            with pytest.raises(ValueError):
                RobotConfig(action_smooth=bad_smooth)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DomainRandomConfig
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestDomainRandomConfig:
    """DomainRandomConfig defaults (randomisation off) and validation."""

    def test_all_zeros_by_default(self):
        cfg = DomainRandomConfig()
        for field_name in ("fp_rate", "wheel_slip_std", "compass_noise_std"):
            assert getattr(cfg, field_name) == 0.0, field_name

    def test_invalid_fp_rate(self):
        with pytest.raises(ValueError):
            DomainRandomConfig(fp_rate=-1.0)

    def test_invalid_slip_std(self):
        with pytest.raises(ValueError):
            DomainRandomConfig(wheel_slip_std=-0.01)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# HerdingConfig
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestHerdingConfig:
    """HerdingConfig composition, ``replace``, and parity with legacy constants."""

    def test_default_is_herding_default(self):
        assert HerdingConfig() == HERDING_DEFAULT

    def test_replace_sub_config(self):
        new_cfg = HERDING_WEBOTS.replace(
            domain_random=DomainRandomConfig(fp_rate=2.0)
        )
        assert new_cfg.domain_random.fp_rate == 2.0
        # Replacing one sub-config must carry every other one over
        # unchanged — check all the sub-configs this file exercises.
        assert new_cfg.lidar == HERDING_WEBOTS.lidar
        assert new_cfg.detection == HERDING_WEBOTS.detection
        assert new_cfg.tracker == HERDING_WEBOTS.tracker
        assert new_cfg.robot == HERDING_WEBOTS.robot

    def test_herding_default_matches_original_module_constants(self):
        """Verify the default config reproduces the original hardcoded values."""
        from herding.perception.lidar_sim import (
            LIDAR_N_RAYS, LIDAR_FOV, LIDAR_MAX_RANGE, LIDAR_NOISE,
            SHEEP_RADIUS, POST_RADIUS,
        )
        from herding.perception.lidar_perception import (
            GAP_THRESHOLD, MAX_CLUSTER_SPAN, RANGE_HIT_EPS,
            SPLIT_RANGE_GAP, WALL_REJECT, STATIC_REJECT,
        )
        from herding.perception.sheep_tracker import (
            GATE_M, REACQUIRE_GATE_M, REACQUIRE_MIN_AGE, PENNED_GATE_M,
            FORGET_STEPS, PREDICT_STEPS, VELOCITY_CLAMP,
        )
        cfg = HERDING_DEFAULT
        assert cfg.lidar.n_rays == LIDAR_N_RAYS
        assert cfg.lidar.fov_rad == LIDAR_FOV
        assert cfg.lidar.max_range == LIDAR_MAX_RANGE
        assert cfg.lidar.noise_std == LIDAR_NOISE
        assert cfg.lidar.sheep_radius == SHEEP_RADIUS
        assert cfg.lidar.post_radius == POST_RADIUS
        assert cfg.detection.gap_threshold == GAP_THRESHOLD
        assert cfg.detection.max_cluster_span == MAX_CLUSTER_SPAN
        assert cfg.detection.range_hit_eps == RANGE_HIT_EPS
        assert cfg.detection.split_range_gap == SPLIT_RANGE_GAP
        assert cfg.detection.wall_reject == WALL_REJECT
        assert cfg.detection.static_reject == STATIC_REJECT
        assert cfg.tracker.gate_m == GATE_M
        assert cfg.tracker.reacquire_gate_m == REACQUIRE_GATE_M
        assert cfg.tracker.reacquire_min_age == REACQUIRE_MIN_AGE
        assert cfg.tracker.penned_gate_m == PENNED_GATE_M
        assert cfg.tracker.forget_steps == FORGET_STEPS
        assert cfg.tracker.predict_steps == PREDICT_STEPS
        assert cfg.tracker.velocity_clamp == VELOCITY_CLAMP
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Integration: HerdingEnv honours the config
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestHerdingEnvConfig:
    """End-to-end checks that HerdingEnv accepts and applies a HerdingConfig.

    Random actions come from a seeded action space so each test replays the
    same trajectory on every run. Rollouts stop as soon as an episode ends
    rather than stepping a finished episode.
    """

    @staticmethod
    def _rollout(env, n_steps=5, action_seed=0):
        """Reset *env*, take up to *n_steps* seeded random actions, return the last obs.

        Action spaces sample from their own RNG, which passing ``seed=`` to the
        env does not necessarily touch. Seeding it here keeps the sampled
        actions (and so the trajectory) reproducible.
        """
        reset_obs, _ = env.reset()
        env.action_space.seed(action_seed)
        obs = reset_obs
        for _ in range(n_steps):
            obs, _, terminated, truncated, _ = env.step(env.action_space.sample())
            # Every step must emit the same observation layout as reset().
            assert obs.shape == reset_obs.shape
            if terminated or truncated:
                break
        return obs

    def test_default_env_unchanged(self):
        """HerdingEnv() still works with no config — zero behaviour change."""
        from training.herding_env import HerdingEnv
        env = HerdingEnv(n_sheep=1, max_steps=5, difficulty=1.0, seed=0)
        obs, _ = env.reset()
        assert obs.shape == (32,)
        env.action_space.seed(0)
        obs2, *_ = env.step(env.action_space.sample())
        assert obs2.shape == (32,)

    def test_webots_config_propagates_action_smooth(self):
        from training.herding_env import HerdingEnv
        env = HerdingEnv(herding_cfg=HERDING_WEBOTS)
        assert env.ACTION_SMOOTH == 0.55

    def test_webots_config_runs(self):
        from training.herding_env import HerdingEnv
        env = HerdingEnv(
            n_sheep=2, max_steps=10, difficulty=1.0, seed=42,
            herding_cfg=HERDING_WEBOTS,
        )
        obs = self._rollout(env, n_steps=5, action_seed=42)
        assert obs.shape == (32,)

    def test_domain_random_fp_runs(self):
        from training.herding_env import HerdingEnv
        cfg = HERDING_WEBOTS.replace(
            domain_random=DomainRandomConfig(fp_rate=3.0, fp_std_pos=0.2)
        )
        env = HerdingEnv(n_sheep=2, max_steps=10, difficulty=1.0, seed=7, herding_cfg=cfg)
        # False-positive injection adds detections, not observation slots.
        obs = self._rollout(env, n_steps=5, action_seed=7)
        assert obs.shape == (32,)

    def test_domain_random_slip_runs(self):
        from training.herding_env import HerdingEnv
        cfg = HERDING_WEBOTS.replace(
            domain_random=DomainRandomConfig(wheel_slip_std=0.05, compass_noise_std=0.02)
        )
        env = HerdingEnv(n_sheep=1, max_steps=10, difficulty=1.0, seed=3,
                         drive_mode="mecanum", herding_cfg=cfg)
        # Only step/reset shape consistency is checked here; the mecanum
        # observation layout is not pinned by this file.
        self._rollout(env, n_steps=5, action_seed=3)
|