Files
TIR_PROJ/herding/perception/sheep_tracker.py
T
Johnny Fernandes fce0e0c786 Checkpoint 6
2026-05-11 10:35:48 +01:00

198 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Multi-target tracker for LiDAR-detected sheep.
Greedy nearest-neighbour data association (with a distance gate) across
frames, plus a memory of last-seen positions for tracks that fall out
of the dog's FOV. Output is a ``{name: (x, y)}`` dict shaped exactly
like the receiver-based ``sheep_positions`` used previously by the
Webots controller and by the env, so Strömbom and Sequential can
consume it unchanged.
Penned-detection heuristic
--------------------------
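A track is marked penned purely by geometry, in one of two ways:
1. An active track is promoted when its current estimated position is
south of the gate plane and within the gate column (the
``is_penned_position`` test the env already uses on ground truth).
2. An unmatched detection that is already inside the pen (``in_pen`` or
``is_penned_position``) is born as a penned track, so sheep that
arrived in the pen during occlusion never become active tracks.
An earlier heuristic — marking a track penned when it had not been
observed for ``STALE_STEPS`` and was last seen inside the gate-approach
band — is no longer applied. It was motivated by the dog's LiDAR only
seeing ~2 m into the pen through the open gate, but it fired on
ordinary occlusion near the gate and created phantom penned tracks.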
Tracks marked penned are excluded from ``get_positions()`` (which is
what Strömbom consumes), matching the prior receiver-based behaviour.
"""
from __future__ import annotations
import math
from herding.world.geometry import MAX_SHEEP, in_pen, is_penned_position
# Primary nearest-neighbour association gate (metres), used for recently
# seen active tracks.
GATE_M = 2.5
# Wider gate (metres) for re-acquiring stale active tracks: the sheep has
# moved during the occlusion, so the primary gate is too small.
REACQUIRE_GATE_M = 4.5
# Minimum staleness (steps since last observation) before a track may be
# rebound through the wide re-acquisition gate.
REACQUIRE_MIN_AGE = 20
# Wide gate (metres) for matching detections against already-penned
# tracks; the pen is small (3×7 m) so duplicates are easy without it.
PENNED_GATE_M = 4.0
# Steps (~3.2 s) after which an unobserved ACTIVE track is deleted;
# tighter than 5 s to limit phantoms but long enough to bridge typical
# FOV gaps.
FORGET_STEPS = 200
# Hard cap on simultaneously active tracks: the worst-case real flock size.
MAX_ACTIVE_TRACKS = MAX_SHEEP
# Penned tracks are never forgotten: sheep don't leave the pen, and
# losing the track makes the counter oscillate as the same sheep gets
# re-detected and counted multiple times.
class SheepTracker:
    """Online tracker with NN association and a forgetful memory.

    Each track stores ``(x, y, last_seen_step, penned)`` keyed by an
    integer track id. Active (not-penned) tracks are deleted once they
    have gone unobserved for more than ``FORGET_STEPS`` steps; penned
    tracks are kept until :meth:`reset` is called.
    """

    def __init__(self, gate: float = GATE_M):
        """Create an empty tracker.

        Args:
            gate: Primary association gate (metres) used when matching
                detections against active tracks.
        """
        self.gate = gate
        # tid → (x, y, last_seen_step, penned)
        self._tracks: dict[int, tuple[float, float, int, bool]] = {}
        self._next_id = 0
        self.step = 0

    def reset(self) -> None:
        """Drop every track and restart the id and step counters."""
        self._tracks.clear()
        self._next_id = 0
        self.step = 0

    # ------------------------------------------------------------------
    # Update
    # ------------------------------------------------------------------
    def _bind_nearest(
        self,
        tid: int,
        detections: list[tuple[float, float]],
        det_used: set[int],
        gate: float,
        *,
        penned: bool,
    ) -> bool:
        """Bind track ``tid`` to its nearest unclaimed detection inside ``gate``.

        On success the track is moved to the detection, stamped with the
        current step and given the ``penned`` flag, and the detection
        index is added to ``det_used``.

        Args:
            tid: Id of an existing track.
            detections: Candidate ``(x, y)`` detections for this frame.
            det_used: Indices of detections already claimed; mutated in place.
            gate: Exclusive upper bound on the match distance.
            penned: Penned flag to store on the updated track.

        Returns:
            ``True`` if a detection was bound, ``False`` otherwise.
        """
        tx, ty, _, _ = self._tracks[tid]
        # Strict ``<`` keeps the gate exclusive and lets the first of
        # several equidistant detections win.
        best_j, best_d = -1, gate
        for j, (dx, dy) in enumerate(detections):
            if j in det_used:
                continue
            d = math.hypot(dx - tx, dy - ty)
            if d < best_d:
                best_d = d
                best_j = j
        if best_j < 0:
            return False
        dx, dy = detections[best_j]
        self._tracks[tid] = (dx, dy, self.step, penned)
        det_used.add(best_j)
        return True

    def update(self, detections: list[tuple[float, float]]) -> dict[str, tuple[float, float]]:
        """Fold a new set of detections in and return active positions.

        Association is greedy: each track, in turn, claims its nearest
        still-unclaimed detection within the applicable gate.

        Args:
            detections: ``(x, y)`` detections for the current frame.

        Returns:
            ``{"t<id>": (x, y)}`` for every track that is not penned
            after this update (same as :meth:`get_positions`).
        """
        self.step += 1
        det_used: set[int] = set()
        updated_tids: set[int] = set()

        # Pass 1: match against ACTIVE tracks first (oldest-seen-first so
        # a re-emerging long-lost sheep grabs its old ID before a fresh
        # neighbour does).
        active_tids = [tid for tid, t in self._tracks.items() if not t[3]]
        active_tids.sort(key=lambda tid: self._tracks[tid][2])
        for tid in active_tids:
            if self._bind_nearest(tid, detections, det_used, self.gate, penned=False):
                updated_tids.add(tid)

        # Pass 1b: re-acquisition with a wider gate for tracks that have
        # been stale for ≥ REACQUIRE_MIN_AGE steps. A sheep keeps moving
        # while occluded, so a fresh detection can lie outside the
        # primary gate yet still clearly be the same sheep. Without this,
        # phantom tracks accumulate and corrupt the CoM.
        for tid in active_tids:
            if tid in updated_tids:
                continue
            last_seen = self._tracks[tid][2]
            if (self.step - last_seen) < REACQUIRE_MIN_AGE:
                continue
            self._bind_nearest(tid, detections, det_used, REACQUIRE_GATE_M, penned=False)

        # Pass 2: match remaining detections against PENNED tracks using
        # the dedicated PENNED_GATE_M gate. Without this, every frame
        # near the gate spawns a fresh penned track for the same sheep,
        # which under a long Webots run leads to thousands of phantom
        # penned tracks.
        # NOTE(review): the matched detection is not checked for being
        # inside the pen, so an unclaimed detection just outside the pen
        # can be absorbed by a nearby penned track — confirm intended.
        penned_tids = [tid for tid, t in self._tracks.items() if t[3]]
        for tid in penned_tids:
            self._bind_nearest(tid, detections, det_used, PENNED_GATE_M, penned=True)

        # Unmatched detections → new tracks. A detection that is already
        # inside the pen is born "penned" so we don't accumulate active
        # tracks for sheep that arrived in the pen during occlusion.
        for j, (dx, dy) in enumerate(detections):
            if j in det_used:
                continue
            penned = in_pen(dx, dy) or is_penned_position(dx, dy)
            self._tracks[self._next_id] = (dx, dy, self.step, penned)
            self._next_id += 1

        # Promote active tracks to penned ONLY by geometric position.
        # The previous "stale + near gate" heuristic was firing on
        # ordinary occlusion near the gate and creating phantom penned
        # tracks.
        for tid, (tx, ty, last, penned) in list(self._tracks.items()):
            if penned:
                continue
            if is_penned_position(tx, ty):
                self._tracks[tid] = (tx, ty, last, True)

        # Forget stale ACTIVE tracks after FORGET_STEPS. Penned tracks
        # are kept indefinitely — sheep can't escape the pen, so once a
        # track is marked penned, that sheep is permanently penned.
        for tid, (_, _, last, penned) in list(self._tracks.items()):
            if penned:
                continue
            if (self.step - last) > FORGET_STEPS:
                del self._tracks[tid]

        # Hard cap on the active set. If we somehow have more than
        # MAX_ACTIVE_TRACKS active tracks, drop the oldest-seen ones
        # first — they are most likely false positives from world
        # geometry (walls, gate posts) the env's raycaster doesn't
        # model, and a bloated active set wrecks the downstream CoM.
        active = [(tid, last) for tid, (_, _, last, p) in self._tracks.items()
                  if not p]
        if len(active) > MAX_ACTIVE_TRACKS:
            active.sort(key=lambda kv: kv[1])  # oldest-seen first
            for tid, _ in active[: len(active) - MAX_ACTIVE_TRACKS]:
                del self._tracks[tid]

        return self.get_positions()

    # ------------------------------------------------------------------
    # Outputs
    # ------------------------------------------------------------------
    def get_positions(self) -> dict[str, tuple[float, float]]:
        """Active (not-yet-penned) tracks. Same shape as receiver dict."""
        return {f"t{tid}": (x, y)
                for tid, (x, y, _, penned) in self._tracks.items()
                if not penned}

    def get_penned_set(self) -> set[str]:
        """Return the ``"t<id>"`` names of all tracks marked penned."""
        return {f"t{tid}" for tid, (_, _, _, penned) in self._tracks.items() if penned}

    def n_active(self) -> int:
        """Return the number of tracks that are not marked penned."""
        return sum(1 for _, _, _, penned in self._tracks.values() if not penned)

    def n_penned(self) -> int:
        """Return the number of tracks that are marked penned."""
        return sum(1 for _, _, _, penned in self._tracks.values() if penned)