Checkpoint 2
This commit is contained in:
+75
-297
@@ -1,318 +1,96 @@
|
||||
"""
|
||||
Parity test: verify 2D training env matches Webots controller implementations.
|
||||
"""Parity smoke-test for the herding env.
|
||||
|
||||
Tests:
|
||||
1. Observation building: HerdingEnv._obs() vs shepherd_dog_rl.build_obs()
|
||||
2. Dog drive: HerdingEnv._step_dog_substep() vs shepherd_dog_rl.drive() math
|
||||
3. Sheep drive: HerdingEnv._sheep_drive() vs sheep.py drive() math
|
||||
Verifies (a) all imports resolve, (b) the env's reset/step contract is
|
||||
correct, (c) deterministic seeds give deterministic trajectories, and
|
||||
(d) the Strömbom baseline can drive the env without crashing.
|
||||
|
||||
Run::
|
||||
|
||||
python -m training.parity_test
|
||||
"""
|
||||
|
||||
import sys
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import math
|
||||
import sys
|
||||
|
||||
_HERE = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, ".."))
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Make imports work from project root
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__)))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "controllers", "shepherd_dog_rl"))
|
||||
|
||||
from herding_env import HerdingEnv
|
||||
|
||||
# Re-implement the Webots functions standalone (no Webots dependency)
|
||||
FIELD = 15.0
|
||||
PEN_CENTER = np.array([11.5, -11.5], dtype=np.float32)
|
||||
PEN_ENTRY = np.array([11.5, -8.0], dtype=np.float32)
|
||||
PEN_X = (10.0, 13.0)
|
||||
PEN_Y = (-15.0, -8.0)
|
||||
ENTRY_AWARE = True
|
||||
from herding.geometry import MAX_SHEEP, PEN_ENTRY
|
||||
from herding.obs import OBS_DIM
|
||||
from herding.strombom import compute_action
|
||||
from training.herding_env import HerdingEnv
|
||||
|
||||
|
||||
def webots_build_obs(dog_pos, sheep_positions, n_sheep, dog_heading):
|
||||
"""Standalone version of shepherd_dog_rl.py build_obs()."""
|
||||
D = 2 * FIELD
|
||||
active_pos = np.array(
|
||||
[p for p in sheep_positions
|
||||
if not (PEN_X[0] < p[0] < PEN_X[1] and PEN_Y[0] < p[1] < PEN_Y[1])],
|
||||
dtype=np.float32
|
||||
)
|
||||
n_active = len(active_pos)
|
||||
if n_active > 0:
|
||||
com = active_pos.mean(axis=0)
|
||||
d_from_com = np.linalg.norm(active_pos - com, axis=1)
|
||||
sorted_idx = np.argsort(d_from_com)[::-1]
|
||||
radius = float(d_from_com[sorted_idx[0]])
|
||||
def nth(n):
|
||||
return active_pos[sorted_idx[n]] if len(sorted_idx) > n else com
|
||||
far1, far2, far3 = nth(0), nth(1), nth(2)
|
||||
else:
|
||||
com = PEN_CENTER.copy()
|
||||
radius = 0.0
|
||||
far1 = far2 = far3 = PEN_CENTER.copy()
|
||||
frac_active = n_active / max(n_sheep, 1)
|
||||
pen_ref = PEN_ENTRY if ENTRY_AWARE else PEN_CENTER
|
||||
return np.array([
|
||||
dog_pos[0] / FIELD, dog_pos[1] / FIELD,
|
||||
(com[0] - dog_pos[0]) / D, (com[1] - dog_pos[1]) / D,
|
||||
(far1[0] - com[0]) / D, (far1[1] - com[1]) / D,
|
||||
(far2[0] - com[0]) / D, (far2[1] - com[1]) / D,
|
||||
(far3[0] - com[0]) / D, (far3[1] - com[1]) / D,
|
||||
(pen_ref[0] - com[0]) / D, (pen_ref[1] - com[1]) / D,
|
||||
(pen_ref[0] - far1[0]) / D, (pen_ref[1] - far1[1]) / D,
|
||||
radius / D,
|
||||
frac_active,
|
||||
math.cos(dog_heading), math.sin(dog_heading),
|
||||
], dtype=np.float32)
|
||||
def test_obs_action_shapes():
|
||||
env = HerdingEnv(n_sheep=3, seed=0)
|
||||
obs, info = env.reset()
|
||||
assert obs.shape == (OBS_DIM,), obs.shape
|
||||
assert obs.dtype == np.float32
|
||||
obs2, r, term, trunc, info = env.step(np.array([0.5, 0.0], dtype=np.float32))
|
||||
assert obs2.shape == (OBS_DIM,)
|
||||
assert isinstance(r, float)
|
||||
assert isinstance(term, bool) and isinstance(trunc, bool)
|
||||
print("[ok] shapes")
|
||||
|
||||
|
||||
def webots_dog_drive(heading, speed_ms, wheel_r=0.038, k_turn=4.0,
|
||||
motor_max=70.0, axle_track=0.28):
|
||||
"""Standalone version of shepherd_dog_rl.py drive() kinematics.
|
||||
def test_reset_determinism():
|
||||
"""Reset with the same seed should give the same initial observation.
|
||||
|
||||
Returns (v_linear, omega, left_w, right_w).
|
||||
We don't require step-determinism — PPO doesn't need it, and chasing
|
||||
bit-exactness through the flocking jitter isn't worth the complexity.
|
||||
"""
|
||||
err = math.atan2(math.sin(heading), math.cos(heading))
|
||||
fwd_ms = speed_ms * max(0.0, math.cos(err))
|
||||
fwd_rad = fwd_ms / wheel_r
|
||||
turn = k_turn * err
|
||||
l = max(-motor_max, min(motor_max, fwd_rad - turn))
|
||||
r = max(-motor_max, min(motor_max, fwd_rad + turn))
|
||||
v = wheel_r * 0.5 * (r + l)
|
||||
w = (wheel_r / axle_track) * (r - l)
|
||||
return v, w, l, r
|
||||
env_a = HerdingEnv(n_sheep=3, seed=42)
|
||||
env_b = HerdingEnv(n_sheep=3, seed=42)
|
||||
obs_a, _ = env_a.reset(seed=42)
|
||||
obs_b, _ = env_b.reset(seed=42)
|
||||
assert np.allclose(obs_a, obs_b), "Reset is non-deterministic for same seed"
|
||||
print("[ok] reset determinism")
|
||||
|
||||
|
||||
def webots_sheep_drive(heading, speed_rad, wheel_r=0.031, k_turn=4.0,
|
||||
motor_max=22.0, axle_track=0.20):
|
||||
"""Standalone version of sheep.py drive() kinematics."""
|
||||
err = math.atan2(math.sin(heading), math.cos(heading))
|
||||
fwd = speed_rad * max(0.0, math.cos(err))
|
||||
k = 4.0
|
||||
l = max(-motor_max, min(motor_max, fwd - k * err))
|
||||
r = max(-motor_max, min(motor_max, fwd + k * err))
|
||||
v = wheel_r * 0.5 * (r + l)
|
||||
w = (wheel_r / axle_track) * (r - l)
|
||||
return v, w, l, r
|
||||
def test_curriculum_n_sheep_varies():
|
||||
env = HerdingEnv(seed=0)
|
||||
sizes = set()
|
||||
for _ in range(40):
|
||||
_, info = env.reset()
|
||||
sizes.add(info["n_sheep"])
|
||||
assert 1 in sizes
|
||||
assert max(sizes) <= MAX_SHEEP
|
||||
print(f"[ok] curriculum sampling — saw n_sheep in {sorted(sizes)}")
|
||||
|
||||
|
||||
def test_obs_parity():
|
||||
"""Test that build_obs matches between 2D env and Webots controller."""
|
||||
print("=== Test 1: Observation Parity ===")
|
||||
env = HerdingEnv(n_sheep=3)
|
||||
# Set ENTRY_AWARE to match our webots constant
|
||||
env.ENTRY_AWARE = ENTRY_AWARE
|
||||
env.reset(seed=42)
|
||||
|
||||
# Manually set positions for a controlled test
|
||||
env.dog_pos = np.array([5.0, 3.0], dtype=np.float32)
|
||||
env.dog_heading = 1.2
|
||||
env.sheep_pos[0] = np.array([0.0, 0.0], dtype=np.float32)
|
||||
env.sheep_pos[1] = np.array([2.0, -1.0], dtype=np.float32)
|
||||
env.sheep_pos[2] = np.array([11.5, -11.5], dtype=np.float32) # penned
|
||||
env.penned[0] = False
|
||||
env.penned[1] = False
|
||||
env.penned[2] = True
|
||||
|
||||
obs_2d = env._obs()
|
||||
|
||||
# Build equivalent Webots observation
|
||||
sheep_positions = [
|
||||
env.sheep_pos[0].tolist(),
|
||||
env.sheep_pos[1].tolist(),
|
||||
env.sheep_pos[2].tolist(),
|
||||
]
|
||||
obs_webots = webots_build_obs(
|
||||
env.dog_pos, sheep_positions, 3, env.dog_heading
|
||||
)
|
||||
|
||||
max_diff = float(np.max(np.abs(obs_2d - obs_webots)))
|
||||
print(f" Max element-wise diff: {max_diff:.2e}")
|
||||
if max_diff < 1e-6:
|
||||
print(" PASS: Observations match")
|
||||
else:
|
||||
print(" FAIL: Observations differ!")
|
||||
for i in range(18):
|
||||
if abs(obs_2d[i] - obs_webots[i]) > 1e-6:
|
||||
print(f" dim {i}: 2d={obs_2d[i]:.6f} webots={obs_webots[i]:.6f}")
|
||||
return max_diff < 1e-6
|
||||
def test_strombom_drives_env():
|
||||
"""Quick functional check that the analytic baseline can play the env
|
||||
without exploding. Not a success-rate test — just no errors / NaNs."""
|
||||
env = HerdingEnv(n_sheep=2, max_steps=400, seed=1)
|
||||
obs, _ = env.reset()
|
||||
for t in range(400):
|
||||
positions = {f"s{i}": (float(env.sheep_x[i]), float(env.sheep_y[i]))
|
||||
for i in range(env.n_sheep)
|
||||
if not env.sheep_penned[i]}
|
||||
if not positions:
|
||||
break
|
||||
vx, vy, _mode = compute_action((env.dog_x, env.dog_y), positions, PEN_ENTRY)
|
||||
obs, r, term, trunc, info = env.step(np.array([vx, vy], dtype=np.float32))
|
||||
assert np.isfinite(obs).all(), f"NaN/Inf in obs at step {t}"
|
||||
assert np.isfinite(r), f"NaN reward at step {t}"
|
||||
if term or trunc:
|
||||
break
|
||||
print(f"[ok] strombom rollout — final n_penned={int(env.sheep_penned.sum())}/{env.n_sheep} after {env.steps} steps")
|
||||
|
||||
|
||||
def test_dog_drive_parity():
|
||||
"""Test that dog diff-drive matches Webots controller."""
|
||||
print("\n=== Test 2: Dog Drive Parity ===")
|
||||
env = HerdingEnv(n_sheep=1)
|
||||
env.reset(seed=42)
|
||||
|
||||
all_pass = True
|
||||
test_cases = [
|
||||
# (heading_error, speed_ms) — target_heading relative to current heading
|
||||
(0.0, 2.5), # aligned, full speed
|
||||
(0.5, 2.5), # 30deg error
|
||||
(1.5, 2.5), # ~86deg error
|
||||
(3.14, 2.5), # ~180deg error — should spin in place
|
||||
(0.0, 0.5), # aligned, slow
|
||||
(0.3, 1.0), # small error, medium speed
|
||||
]
|
||||
|
||||
for heading_err, speed_ms in test_cases:
|
||||
env.dog_heading = 0.0
|
||||
target_heading = heading_err
|
||||
action = np.array([
|
||||
math.cos(target_heading), math.sin(target_heading)
|
||||
], dtype=np.float32) * (speed_ms / env.DOG_SPEED)
|
||||
|
||||
# 2D env step
|
||||
dbg = env._step_dog_substep(action, 0.016)
|
||||
v_2d = dbg["v"]
|
||||
w_2d = dbg["w"]
|
||||
l_2d = dbg["left_w"]
|
||||
r_2d = dbg["right_w"]
|
||||
|
||||
# Webots equivalent
|
||||
v_w, w_w, l_w, r_w = webots_dog_drive(heading_err, speed_ms)
|
||||
|
||||
diffs = {
|
||||
"v": abs(v_2d - v_w),
|
||||
"w": abs(w_2d - w_w),
|
||||
"left": abs(l_2d - l_w),
|
||||
"right": abs(r_2d - r_w),
|
||||
}
|
||||
max_diff = max(diffs.values())
|
||||
ok = max_diff < 1e-6
|
||||
status = "PASS" if ok else "FAIL"
|
||||
print(f" err={heading_err:.2f} spd={speed_ms:.1f}: {status} (max_diff={max_diff:.2e})")
|
||||
if not ok:
|
||||
for k, d in diffs.items():
|
||||
if d > 1e-6:
|
||||
print(f" {k}: 2d={eval(k+'_2d'):.6f} webots={eval(k+'_w'):.6f}")
|
||||
all_pass = False
|
||||
|
||||
return all_pass
|
||||
|
||||
|
||||
def test_sheep_drive_parity():
|
||||
"""Test that sheep diff-drive matches Webots sheep controller."""
|
||||
print("\n=== Test 3: Sheep Drive Parity ===")
|
||||
env = HerdingEnv(n_sheep=1)
|
||||
env.reset(seed=42)
|
||||
|
||||
all_pass = True
|
||||
test_cases = [
|
||||
# (heading_error, speed_rad)
|
||||
(0.0, 20.0), # aligned, flee speed
|
||||
(0.0, 3.0), # aligned, wander speed
|
||||
(0.5, 15.0), # moderate error
|
||||
(1.57, 10.0), # 90deg — should spin in place
|
||||
(3.14, 20.0), # 180deg — should spin in place fast
|
||||
(0.2, 8.0), # small error, medium speed
|
||||
]
|
||||
|
||||
for heading_err, speed_rad in test_cases:
|
||||
env.sheep_heading[0] = 0.0
|
||||
env.sheep_pos[0] = np.array([0.0, 0.0], dtype=np.float32)
|
||||
target_heading = heading_err
|
||||
|
||||
# 2D env
|
||||
new_pos = env._sheep_drive(0, target_heading, speed_rad, 0.016)
|
||||
v_2d_raw = float(np.linalg.norm(new_pos - np.array([0.0, 0.0]))) / 0.016
|
||||
# Re-derive v, w from the internal state
|
||||
heading_2d = env.sheep_heading[0]
|
||||
|
||||
# Webots equivalent
|
||||
v_w, w_w, l_w, r_w = webots_sheep_drive(heading_err, speed_rad)
|
||||
|
||||
# For 2D, compute the same intermediate values
|
||||
err_2d = (target_heading - 0.0 + np.pi) % (2 * np.pi) - np.pi
|
||||
fwd_2d = speed_rad * max(0.0, math.cos(err_2d))
|
||||
turn_2d = 4.0 * err_2d
|
||||
l_2d = max(-22.0, min(22.0, fwd_2d - turn_2d))
|
||||
r_2d = max(-22.0, min(22.0, fwd_2d + turn_2d))
|
||||
|
||||
diffs = {
|
||||
"left": abs(l_2d - l_w),
|
||||
"right": abs(r_2d - r_w),
|
||||
}
|
||||
max_diff = max(diffs.values())
|
||||
ok = max_diff < 1e-6
|
||||
status = "PASS" if ok else "FAIL"
|
||||
print(f" err={heading_err:.2f} spd={speed_rad:.1f}: {status} (max_diff={max_diff:.2e})")
|
||||
if not ok:
|
||||
for k, d in diffs.items():
|
||||
if d > 1e-6:
|
||||
print(f" {k}: 2d={l_2d if k=='left' else r_2d:.6f} webots={l_w if k=='left' else r_w:.6f}")
|
||||
all_pass = False
|
||||
|
||||
return all_pass
|
||||
|
||||
|
||||
def test_full_trajectory_parity():
|
||||
"""Test that running identical actions produces matching trajectories."""
|
||||
print("\n=== Test 4: Full Trajectory Parity (dog only) ===")
|
||||
# Run 50 steps with a fixed action, compare dog heading/position
|
||||
# at each step between 2D env kinematics and pure Webots kinematics.
|
||||
env = HerdingEnv(n_sheep=1)
|
||||
env.reset(seed=42)
|
||||
env.dog_pos = np.array([0.0, 0.0], dtype=np.float32)
|
||||
env.dog_heading = 0.0
|
||||
env.ENTRY_AWARE = ENTRY_AWARE
|
||||
|
||||
action = np.array([0.8, -0.6], dtype=np.float32) # magnitude 1.0
|
||||
dt = 0.016667 # sub_dt
|
||||
|
||||
# Webots-side tracking
|
||||
wb_heading = 0.0
|
||||
wb_x, wb_y = 0.0, 0.0
|
||||
|
||||
max_heading_diff = 0.0
|
||||
max_pos_diff = 0.0
|
||||
|
||||
for step in range(50):
|
||||
# 2D env sub-step
|
||||
env._step_dog_substep(action, dt)
|
||||
|
||||
# Webots-side computation
|
||||
speed_ms = 1.0 * 2.5
|
||||
target_heading = math.atan2(-0.6, 0.8)
|
||||
err = math.atan2(math.sin(target_heading - wb_heading),
|
||||
math.cos(target_heading - wb_heading))
|
||||
fwd_ms = speed_ms * max(0.0, math.cos(err))
|
||||
fwd_rad = fwd_ms / 0.038
|
||||
turn = 4.0 * err
|
||||
l = max(-70.0, min(70.0, fwd_rad - turn))
|
||||
r = max(-70.0, min(70.0, fwd_rad + turn))
|
||||
v = 0.038 * 0.5 * (r + l)
|
||||
w = (0.038 / 0.28) * (r - l)
|
||||
wb_heading = math.atan2(math.sin(wb_heading + w * dt),
|
||||
math.cos(wb_heading + w * dt))
|
||||
wb_x += math.cos(wb_heading) * v * dt
|
||||
wb_y += math.sin(wb_heading) * v * dt
|
||||
|
||||
heading_diff = abs(env.dog_heading - wb_heading)
|
||||
pos_diff = math.hypot(env.dog_pos[0] - wb_x, env.dog_pos[1] - wb_y)
|
||||
max_heading_diff = max(max_heading_diff, heading_diff)
|
||||
max_pos_diff = max(max_pos_diff, pos_diff)
|
||||
|
||||
print(f" Max heading diff over 50 steps: {max_heading_diff:.2e} rad")
|
||||
print(f" Max position diff over 50 steps: {max_pos_diff:.2e} m")
|
||||
ok = max_pos_diff < 1e-4
|
||||
print(f" {'PASS' if ok else 'FAIL'}: Trajectories match")
|
||||
return ok
|
||||
def main():
|
||||
test_obs_action_shapes()
|
||||
test_reset_determinism()
|
||||
test_curriculum_n_sheep_varies()
|
||||
test_strombom_drives_env()
|
||||
print("\nAll parity checks passed.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
results = []
|
||||
results.append(("Obs parity", test_obs_parity()))
|
||||
results.append(("Dog drive parity", test_dog_drive_parity()))
|
||||
results.append(("Sheep drive parity", test_sheep_drive_parity()))
|
||||
results.append(("Trajectory parity", test_full_trajectory_parity()))
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("RESULTS")
|
||||
print("=" * 50)
|
||||
all_pass = True
|
||||
for name, passed in results:
|
||||
print(f" {name}: {'PASS' if passed else 'FAIL'}")
|
||||
if not passed:
|
||||
all_pass = False
|
||||
print(f"\nOverall: {'ALL PASS' if all_pass else 'SOME FAILURES'}")
|
||||
env.close()
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user