Files
TIR_PROJ/herding/perception/sheep_tracker.py
T
Johnny Fernandes a01a5c9cef Checkpoint 7
2026-05-11 12:21:51 +01:00

162 lines
6.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Multi-target tracker for LiDAR-detected sheep.
Greedy nearest-neighbour data association across frames, with a wider
re-acquisition gate for stale tracks (sheep flee during occlusion and
reappear off-position), plus a bounded memory of last-seen positions
for sheep out of FOV (active tracks are dropped after ``FORGET_STEPS``
unseen steps). Output is ``{name: (x, y)}`` — Strömbom / Sequential
consume it directly.
A track is marked penned once its estimated position crosses the gate
plane south (``is_penned_position``). Penned tracks are excluded from
``get_positions`` and kept indefinitely.
"""
from __future__ import annotations
import math
from herding.world.geometry import MAX_SHEEP, in_pen, is_penned_position
# Primary nearest-neighbour gate, in metres, for recently observed tracks.
GATE_M: float = 2.5
# Wider gate, in metres, used when re-binding stale tracks to detections.
REACQUIRE_GATE_M: float = 4.5
# A track must have gone unseen for at least this many steps before the
# wider REACQUIRE_GATE_M may be used for it.
REACQUIRE_MIN_AGE: int = 20
# Gate, in metres, for matching leftover detections to already-penned tracks.
PENNED_GATE_M: float = 4.0
# Active tracks unseen for more than this many steps are deleted; penned
# tracks are kept forever. (~3.2 s — presumably a 16 ms step; confirm.)
FORGET_STEPS: int = 200
# Upper bound on the number of simultaneously active (not penned) tracks.
MAX_ACTIVE_TRACKS: int = MAX_SHEEP
class SheepTracker:
    """Online tracker with nearest-neighbour association and forgetful memory.

    Each track is stored as ``(x, y, last_seen_step, penned)`` under an
    integer track id; callers see ids as names of the form ``"t<id>"``.
    Active (not penned) tracks are deleted once unseen for more than
    ``FORGET_STEPS`` steps; penned tracks are never deleted.
    """

    def __init__(self, gate: float = GATE_M):
        """Create an empty tracker.

        Args:
            gate: primary nearest-neighbour gate (m) used to match detections
                to active tracks. The re-acquisition and penned gates come
                from the module constants ``REACQUIRE_GATE_M`` and
                ``PENNED_GATE_M``.
        """
        self.gate = gate
        # tid -> (x, y, last_seen_step, penned)
        self._tracks: dict[int, tuple[float, float, int, bool]] = {}
        self._next_id = 0
        self.step = 0

    def reset(self) -> None:
        """Drop every track (penned ones included) and restart ids and steps."""
        self._tracks.clear()
        self._next_id = 0
        self.step = 0

    @staticmethod
    def _nearest_detection(
        tx: float,
        ty: float,
        detections: list[tuple[float, float]],
        det_used: set[int],
        gate: float,
    ) -> int:
        """Return the index of the nearest unused detection within ``gate``.

        The distance must be strictly below ``gate``. On equal distances the
        lower index wins. Returns -1 when no unused detection qualifies.
        """
        best_j, best_d = -1, gate
        for j, (dx, dy) in enumerate(detections):
            if j in det_used:
                continue
            d = math.hypot(dx - tx, dy - ty)
            if d < best_d:
                best_d = d
                best_j = j
        return best_j

    def update(self, detections: list[tuple[float, float]]) -> dict[str, tuple[float, float]]:
        """Fold one frame of detections into the track set.

        Args:
            detections: ``(x, y)`` positions detected this frame. Each
                detection is assigned to at most one track; unmatched ones
                spawn new tracks.

        Returns:
            The active (not penned) tracks, exactly as ``get_positions``.
        """
        self.step += 1
        det_used: set[int] = set()
        updated_tids: set[int] = set()

        # Pass 1 — match active tracks within the primary gate. Oldest-
        # seen tracks bind first so a re-emerging long-lost sheep keeps
        # its old ID instead of being grabbed by a fresh neighbour.
        active_tids = [tid for tid, t in self._tracks.items() if not t[3]]
        active_tids.sort(key=lambda tid: self._tracks[tid][2])
        for tid in active_tids:
            tx, ty, _, _ = self._tracks[tid]
            j = self._nearest_detection(tx, ty, detections, det_used, self.gate)
            if j >= 0:
                dx, dy = detections[j]
                self._tracks[tid] = (dx, dy, self.step, False)
                det_used.add(j)
                updated_tids.add(tid)

        # Pass 1b — re-acquisition. Sheep flee at ~0.6 m/s, so a sheep that
        # was occluded for a while may reappear outside the primary gate.
        # Tracks unseen for at least REACQUIRE_MIN_AGE steps may rebind
        # within the wider REACQUIRE_GATE_M; otherwise phantom tracks
        # accumulate and corrupt CoM.
        # NOTE(review): an earlier comment cited a 12 s occlusion, but active
        # tracks are deleted after FORGET_STEPS unseen steps (~3.2 s per the
        # constant's note), so longer occlusions cannot be re-acquired here.
        for tid in active_tids:
            if tid in updated_tids:
                continue
            tx, ty, last, _ = self._tracks[tid]
            if (self.step - last) < REACQUIRE_MIN_AGE:
                continue
            j = self._nearest_detection(tx, ty, detections, det_used, REACQUIRE_GATE_M)
            if j >= 0:
                dx, dy = detections[j]
                self._tracks[tid] = (dx, dy, self.step, False)
                det_used.add(j)
                updated_tids.add(tid)

        # Pass 2 — match remaining detections to penned tracks so sheep
        # already in the pen do not spawn duplicate tracks.
        penned_tids = [tid for tid, t in self._tracks.items() if t[3]]
        for tid in penned_tids:
            tx, ty, _, _ = self._tracks[tid]
            j = self._nearest_detection(tx, ty, detections, det_used, PENNED_GATE_M)
            if j >= 0:
                dx, dy = detections[j]
                self._tracks[tid] = (dx, dy, self.step, True)
                det_used.add(j)

        # Spawn new tracks for unmatched detections. Born "penned" if
        # the detection already sits inside the pen geometry.
        # NOTE(review): penned tracks are neither capped nor forgotten, so a
        # detection inside the pen that misses every penned track by more
        # than PENNED_GATE_M adds to n_penned() permanently — confirm intended.
        for j, (dx, dy) in enumerate(detections):
            if j in det_used:
                continue
            penned = in_pen(dx, dy) or is_penned_position(dx, dy)
            self._tracks[self._next_id] = (dx, dy, self.step, penned)
            self._next_id += 1

        # Promote active tracks whose current estimate crosses the gate.
        # NOTE(review): spawning also checks in_pen(), promotion does not —
        # confirm is_penned_position() covers every in_pen() position.
        for tid, (tx, ty, last, penned) in list(self._tracks.items()):
            if penned:
                continue
            if is_penned_position(tx, ty):
                self._tracks[tid] = (tx, ty, last, True)

        # Forget stale active tracks; penned tracks live forever.
        for tid, (_, _, last, penned) in list(self._tracks.items()):
            if penned:
                continue
            if (self.step - last) > FORGET_STEPS:
                del self._tracks[tid]

        # Hard cap on the active set — drop the oldest-seen overflow
        # (stable sort: ties keep insertion order).
        active = [(tid, last) for tid, (_, _, last, p) in self._tracks.items()
                  if not p]
        if len(active) > MAX_ACTIVE_TRACKS:
            active.sort(key=lambda kv: kv[1])
            for tid, _ in active[: len(active) - MAX_ACTIVE_TRACKS]:
                del self._tracks[tid]

        return self.get_positions()

    def get_positions(self) -> dict[str, tuple[float, float]]:
        """Active (not-penned) tracks as a ``{name: (x, y)}`` dict."""
        return {f"t{tid}": (x, y)
                for tid, (x, y, _, penned) in self._tracks.items()
                if not penned}

    def get_penned_set(self) -> set[str]:
        """Names (``"t<id>"``) of every penned track."""
        return {f"t{tid}" for tid, (_, _, _, penned) in self._tracks.items() if penned}

    def n_active(self) -> int:
        """Number of active (not penned) tracks."""
        return sum(1 for _, _, _, penned in self._tracks.values() if not penned)

    def n_penned(self) -> int:
        """Number of penned tracks (these are never forgotten)."""
        return sum(1 for _, _, _, penned in self._tracks.values() if penned)