"""Multi-target tracker for LiDAR-detected sheep.

Greedy nearest-neighbour data association across frames, with a wider
re-acquisition gate for stale tracks (sheep flee during occlusion and
reappear off-position), plus memory of last-seen positions for sheep out
of FOV. Output is ``{name: (x, y)}`` — Strömbom / Sequential consume it
directly.

A track is marked penned once its estimated position crosses the gate
plane south (``is_penned_position``). Penned tracks are excluded from
``get_positions`` and kept indefinitely.
"""

from __future__ import annotations

import math
from collections.abc import Iterable, Sequence

from herding.world.geometry import MAX_SHEEP, in_pen, is_penned_position

GATE_M = 2.5             # m — primary NN gate (recently observed tracks)
REACQUIRE_GATE_M = 4.5   # m — wider gate for re-binding stale tracks
REACQUIRE_MIN_AGE = 20   # steps — track must be this stale to use the wider gate
PENNED_GATE_M = 4.0      # m — gate for matching detections to existing penned tracks
FORGET_STEPS = 200       # ~3.2 s — delete stale active tracks (penned ones kept forever)
MAX_ACTIVE_TRACKS = MAX_SHEEP


class SheepTracker:
    """Online tracker with NN association and forgetful memory.

    Tracks live in ``self._tracks`` keyed by an integer id; each value is a
    plain tuple ``(x, y, last_seen_step, penned)``. The public name of track
    ``tid`` is ``f"t{tid}"``.
    """

    def __init__(self, gate: float = GATE_M):
        self.gate = gate  # primary NN gate (m) used in pass 1
        self._tracks: dict[int, tuple[float, float, int, bool]] = {}
        self._next_id = 0
        self.step = 0

    def reset(self) -> None:
        """Drop every track (active and penned) and restart ids and the step counter."""
        self._tracks.clear()
        self._next_id = 0
        self.step = 0

    def update(
        self, detections: Iterable[tuple[float, float]]
    ) -> dict[str, tuple[float, float]]:
        """Fold a new set of detections in and return active positions.

        Args:
            detections: ``(x, y)`` detections for this frame. Any iterable of
                pairs is accepted (list, tuple, generator, ``(N, 2)`` array);
                it is materialised once because every pass re-scans it.

        Returns:
            The result of ``get_positions()`` after the update, i.e. the
            active (not-penned) tracks only.
        """
        self.step += 1
        dets = [(float(x), float(y)) for x, y in detections]
        used: set[int] = set()

        # Pass 1 — match active tracks within the primary gate. Oldest-
        # seen tracks bind first so a re-emerging long-lost sheep keeps
        # its old ID instead of being grabbed by a fresh neighbour.
        active_tids = [tid for tid, t in self._tracks.items() if not t[3]]
        active_tids.sort(key=lambda tid: self._tracks[tid][2])
        updated = self._match(active_tids, dets, used, self.gate, penned=False)

        # Pass 1b — re-acquisition. Sheep flee at ~0.6 m/s, so over a
        # 1–2 s occlusion the same sheep may reappear outside the primary
        # gate. Allow rebinding within a wider gate for stale-enough
        # tracks; otherwise phantom tracks accumulate and corrupt CoM.
        # Staleness is judged on last_seen from before this step; pass 1
        # only touched tracks in ``updated``, which are excluded anyway.
        stale_tids = [
            tid
            for tid in active_tids
            if tid not in updated
            and (self.step - self._tracks[tid][2]) >= REACQUIRE_MIN_AGE
        ]
        self._match(stale_tids, dets, used, REACQUIRE_GATE_M, penned=False)

        # Pass 2 — match remaining detections to penned tracks.
        penned_tids = [tid for tid, t in self._tracks.items() if t[3]]
        self._match(penned_tids, dets, used, PENNED_GATE_M, penned=True)

        # Housekeeping; the order matters. Promotion runs before forgetting,
        # so a stale track already past the gate becomes penned (kept
        # forever) instead of being deleted.
        self._spawn_unmatched(dets, used)
        self._promote_crossed()
        self._forget_stale()
        self._cap_active()
        return self.get_positions()

    @staticmethod
    def _nearest_free(
        x: float,
        y: float,
        dets: Sequence[tuple[float, float]],
        used: set[int],
        gate: float,
    ) -> int:
        """Return the index of the closest unused detection strictly within *gate*, else -1."""
        best_j, best_d = -1, gate
        for j, (dx, dy) in enumerate(dets):
            if j in used:
                continue
            d = math.hypot(dx - x, dy - y)
            if d < best_d:  # strict: the first detection at the minimum distance wins
                best_j, best_d = j, d
        return best_j

    def _match(
        self,
        tids: Iterable[int],
        dets: Sequence[tuple[float, float]],
        used: set[int],
        gate: float,
        *,
        penned: bool,
    ) -> set[int]:
        """Greedily bind each track in *tids* (in order) to its nearest free detection.

        A bound track jumps to the detection, is stamped with the current
        step and gets the given *penned* flag. Its detection index is added
        to *used* so later tracks and passes cannot claim it.

        Returns:
            The ids of the tracks that were bound.
        """
        matched: set[int] = set()
        for tid in tids:
            tx, ty, _, _ = self._tracks[tid]
            j = self._nearest_free(tx, ty, dets, used, gate)
            if j < 0:
                continue
            dx, dy = dets[j]
            self._tracks[tid] = (dx, dy, self.step, penned)
            used.add(j)
            matched.add(tid)
        return matched

    def _spawn_unmatched(self, dets: Sequence[tuple[float, float]], used: set[int]) -> None:
        """Open a new track for every detection that no existing track claimed.

        A newborn track starts penned if its detection already satisfies
        ``in_pen`` or ``is_penned_position``.
        """
        for j, (dx, dy) in enumerate(dets):
            if j in used:
                continue
            penned = bool(in_pen(dx, dy) or is_penned_position(dx, dy))
            self._tracks[self._next_id] = (dx, dy, self.step, penned)
            self._next_id += 1

    def _promote_crossed(self) -> None:
        """Mark active tracks whose current estimate has crossed the gate as penned."""
        for tid, (tx, ty, last, penned) in list(self._tracks.items()):
            if not penned and is_penned_position(tx, ty):
                self._tracks[tid] = (tx, ty, last, True)

    def _forget_stale(self) -> None:
        """Delete active tracks unseen for more than FORGET_STEPS; penned tracks are kept."""
        expired = [
            tid
            for tid, (_, _, last, penned) in self._tracks.items()
            if not penned and (self.step - last) > FORGET_STEPS
        ]
        for tid in expired:
            del self._tracks[tid]

    def _cap_active(self) -> None:
        """Enforce MAX_ACTIVE_TRACKS by dropping the least recently seen active tracks."""
        active = [
            (tid, last) for tid, (_, _, last, penned) in self._tracks.items() if not penned
        ]
        overflow = len(active) - MAX_ACTIVE_TRACKS
        if overflow <= 0:
            return
        # Stable sort: among equally old tracks, the earliest-inserted go first.
        active.sort(key=lambda kv: kv[1])
        for tid, _ in active[:overflow]:
            del self._tracks[tid]

    def get_positions(self) -> dict[str, tuple[float, float]]:
        """Active (not-penned) tracks as a ``{name: (x, y)}`` dict."""
        return {
            f"t{tid}": (x, y)
            for tid, (x, y, _, penned) in self._tracks.items()
            if not penned
        }

    def get_penned_set(self) -> set[str]:
        """Names (``"t<id>"``) of all penned tracks."""
        return {f"t{tid}" for tid, (_, _, _, penned) in self._tracks.items() if penned}

    def n_active(self) -> int:
        """Number of active (not-penned) tracks."""
        return sum(1 for _, _, _, penned in self._tracks.values() if not penned)

    def n_penned(self) -> int:
        """Number of penned tracks."""
        return sum(1 for _, _, _, penned in self._tracks.values() if penned)