"""Multi-target tracker for LiDAR-detected sheep.

Detections are associated with tracks frame-to-frame by greedy
nearest-neighbour matching. Tracks that have gone unobserved for a while
get a wider re-acquisition gate, because sheep flee while occluded and
reappear away from where they were lost. Sheep that leave the field of
view are remembered at their last-seen position. The output is a
``{name: (x, y)}`` dict that Strömbom / Sequential consume directly.

Every track carries a constant-velocity state ``(vx, vy)``. It is an
exponentially smoothed estimate built from successive observations, not
just the last two. While a track is occluded, ``SheepTracker`` always
reports its position extrapolated along this velocity for up to
``PREDICT_STEPS`` steps. This keeps the teacher's centre-of-mass estimate
stable through brief losses. Once prediction expires, the track falls back
to its last-seen position (static memory) until ``FORGET_STEPS`` deletes
it entirely.

A track is marked penned once its estimated position crosses the gate
plane south (``is_penned_position``). Penned tracks are excluded from
``get_positions`` and kept indefinitely.
"""

from __future__ import annotations

import math

from herding.world.geometry import MAX_SHEEP, in_pen, is_penned_position

GATE_M = 2.5  # m — primary NN gate (recently observed tracks)
REACQUIRE_GATE_M = 4.5  # m — wider gate for re-binding stale tracks
REACQUIRE_MIN_AGE = 20  # steps — track must be this stale to use the wider gate
PENNED_GATE_M = 4.0  # m — gate for matching detections to existing penned tracks
FORGET_STEPS = 200  # ~3.2 s — delete stale active tracks (penned ones kept forever)
MAX_ACTIVE_TRACKS = MAX_SHEEP  # hard cap on simultaneously active (not-penned) tracks

# Predictive tracking constants.
PREDICT_STEPS = 120  # ~1.9 s — extrapolate velocity this many frames
VELOCITY_CLAMP = 1.0  # m/s — max predicted speed (sheep max is ~0.78 m/s)
STEP_DT_S = 0.016  # s — duration of one tracker step (steps → seconds)
VELOCITY_ALPHA = 0.6  # weight of the newest sample in the velocity EMA


class Track:
    """Single track with position, velocity, and age.

    ``x``/``y`` hold the last *observed* position, ``vx``/``vy`` an
    exponentially smoothed velocity in m/s (assuming ``STEP_DT_S`` seconds
    per step), ``last_seen`` the tracker step of the last observation, and
    ``penned`` whether the track has been promoted to the pen.
    """

    __slots__ = ("x", "y", "vx", "vy", "last_seen", "penned")

    def __init__(self, x: float, y: float, step: int, penned: bool = False):
        self.x = x
        self.y = y
        self.vx = 0.0
        self.vy = 0.0
        self.last_seen = step
        self.penned = penned

    @property
    def age(self) -> int:
        """Unsupported: a track's age depends on the tracker's current step.

        Call :meth:`steps_since_seen` with the current step instead.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def steps_since_seen(self, current_step: int) -> int:
        """Return how many steps have elapsed since the last observation."""
        return current_step - self.last_seen

    def predicted_position(self, current_step: int) -> tuple[float, float]:
        """Return the position extrapolated to *current_step*.

        Within ``PREDICT_STEPS`` steps of the last observation, the last
        observed position is advanced along the velocity direction. The
        travelled distance is capped at ``VELOCITY_CLAMP`` m/s. The last
        observed position is returned unchanged in three cases:
        - the track was observed this step;
        - the prediction window has expired;
        - the velocity is negligible.
        """
        dt = self.steps_since_seen(current_step)
        if dt <= 0 or dt > PREDICT_STEPS:
            return self.x, self.y
        speed = math.hypot(self.vx, self.vy)
        if speed < 1e-4:  # no meaningful heading to extrapolate along
            return self.x, self.y
        # Distance travelled at the estimated speed, clamped to the max speed.
        travelled = speed * dt * STEP_DT_S
        max_travel = VELOCITY_CLAMP * dt * STEP_DT_S
        d = min(travelled, max_travel)
        return (
            self.x + d * (self.vx / speed),
            self.y + d * (self.vy / speed),
        )

    def update(self, x: float, y: float, step: int) -> None:
        """Absorb a detection observed at *step* and re-estimate velocity.

        The finite-difference velocity since the last observation is
        blended into the running estimate with weight ``VELOCITY_ALPHA``.
        If no time has elapsed (``step <= last_seen``), only the position
        and ``last_seen`` are replaced.
        """
        dt = self.steps_since_seen(step)
        if dt > 0:
            dt_s = dt * STEP_DT_S
            new_vx = (x - self.x) / dt_s
            new_vy = (y - self.y) / dt_s
            # Exponential smoothing on velocity.
            keep = 1.0 - VELOCITY_ALPHA
            self.vx = VELOCITY_ALPHA * new_vx + keep * self.vx
            self.vy = VELOCITY_ALPHA * new_vy + keep * self.vy
        self.x = x
        self.y = y
        self.last_seen = step


class SheepTracker:
    """Online tracker with NN association, prediction, and forgetful memory.

    Each track is a :class:`Track` with position, velocity estimate,
    last-seen step, and penned flag. Tracks are exposed under the name
    ``"t<id>"``, with ids assigned in creation order.
    """

    def __init__(self, gate: float = GATE_M):
        self.gate = gate  # primary association gate (m) used in Pass 1
        self._tracks: dict[int, Track] = {}
        self._next_id = 0
        self.step = 0

    def reset(self) -> None:
        """Drop all tracks and restart the id and step counters at zero."""
        self._tracks.clear()
        self._next_id = 0
        self.step = 0

    @staticmethod
    def _nearest_detection(
        detections: list[tuple[float, float]],
        used: set[int],
        x: float,
        y: float,
        gate: float,
    ) -> int:
        """Index of the closest unused detection strictly within *gate*, else -1.

        On equal distances, the lowest index wins.
        """
        best_j, best_d = -1, gate
        for j, (dx, dy) in enumerate(detections):
            if j in used:
                continue
            d = math.hypot(dx - x, dy - y)
            if d < best_d:
                best_j, best_d = j, d
        return best_j

    def _match(
        self,
        track: Track,
        detections: list[tuple[float, float]],
        used: set[int],
        x: float,
        y: float,
        gate: float,
    ) -> bool:
        """Bind *track* to the nearest unused detection around ``(x, y)``.

        On success, the track is updated at the current step, the detection
        index is added to *used*, and ``True`` is returned.
        """
        j = self._nearest_detection(detections, used, x, y, gate)
        if j < 0:
            return False
        dx, dy = detections[j]
        track.update(dx, dy, self.step)
        used.add(j)
        return True

    def update(self, detections: list[tuple[float, float]]) -> dict[str, tuple[float, float]]:
        """Fold a new set of detections in and return active positions.

        Advances the step counter, then:

        - Pass 1: active tracks, stalest first, claim the nearest detection
          within ``self.gate`` of their predicted position.
        - Pass 1b: unmatched tracks unseen for at least
          ``REACQUIRE_MIN_AGE`` steps retry with ``REACQUIRE_GATE_M``.
        - Pass 2: penned tracks claim leftover detections within
          ``PENNED_GATE_M`` of their last observed position.
        - Every detection still unmatched spawns a new track. The new track
          is penned if ``in_pen`` or ``is_penned_position`` holds for it.
        - Active tracks whose predicted position is penned are promoted.
        - Active tracks unseen for more than ``FORGET_STEPS`` are deleted.
          The active set is then trimmed to ``MAX_ACTIVE_TRACKS`` by
          dropping the stalest tracks.

        Args:
            detections: ``(x, y)`` detection positions for this frame.

        Returns:
            The result of :meth:`get_positions` after the update.
        """
        self.step += 1
        step = self.step
        det_used: set[int] = set()
        updated_tids: set[int] = set()

        # Pass 1 — primary gate around predicted positions, oldest-first.
        active_tids = sorted(
            (tid for tid, t in self._tracks.items() if not t.penned),
            key=lambda tid: self._tracks[tid].last_seen,
        )
        for tid in active_tids:
            track = self._tracks[tid]
            tx, ty = track.predicted_position(step)
            if self._match(track, detections, det_used, tx, ty, self.gate):
                updated_tids.add(tid)

        # Pass 1b — re-acquisition with the wider gate, stale tracks only.
        for tid in active_tids:
            if tid in updated_tids:
                continue
            track = self._tracks[tid]
            if track.steps_since_seen(step) < REACQUIRE_MIN_AGE:
                continue
            tx, ty = track.predicted_position(step)
            if self._match(track, detections, det_used, tx, ty, REACQUIRE_GATE_M):
                updated_tids.add(tid)

        # Pass 2 — remaining detections to penned tracks (no extrapolation).
        for track in self._tracks.values():
            if track.penned:
                self._match(
                    track, detections, det_used, track.x, track.y, PENNED_GATE_M
                )

        # Spawn new tracks for unmatched detections.
        for j, (dx, dy) in enumerate(detections):
            if j in det_used:
                continue
            penned = in_pen(dx, dy) or is_penned_position(dx, dy)
            self._tracks[self._next_id] = Track(dx, dy, step, penned)
            self._next_id += 1

        # Promote active tracks whose current estimate crosses the gate.
        # NOTE(review): this uses the *predicted* position, so an occluded
        # track extrapolated across the gate plane is penned permanently even
        # if it was never observed there.
        for track in self._tracks.values():
            if not track.penned and is_penned_position(*track.predicted_position(step)):
                track.penned = True

        # Forget stale active tracks; penned tracks live forever.
        stale = [
            tid
            for tid, t in self._tracks.items()
            if not t.penned and t.steps_since_seen(step) > FORGET_STEPS
        ]
        for tid in stale:
            del self._tracks[tid]

        # Hard cap on the active set — drop the oldest-seen overflow.
        active = [(tid, t.last_seen) for tid, t in self._tracks.items() if not t.penned]
        overflow = len(active) - MAX_ACTIVE_TRACKS
        if overflow > 0:
            active.sort(key=lambda kv: kv[1])
            for tid, _ in active[:overflow]:
                del self._tracks[tid]

        return self.get_positions()

    def get_positions(self) -> dict[str, tuple[float, float]]:
        """Active (not-penned) tracks as a ``{name: (x, y)}`` dict.

        For tracks currently being predicted (occluded, but within
        PREDICT_STEPS), this returns the extrapolated position so the teacher
        sees a smooth estimate.
        """
        return {
            f"t{tid}": track.predicted_position(self.step)
            for tid, track in self._tracks.items()
            if not track.penned
        }

    def get_penned_set(self) -> set[str]:
        """Names (``"t<id>"``) of all penned tracks."""
        return {f"t{tid}" for tid, t in self._tracks.items() if t.penned}

    def n_active(self) -> int:
        """Number of active (not-penned) tracks."""
        return sum(1 for t in self._tracks.values() if not t.penned)

    def n_penned(self) -> int:
        """Number of penned tracks."""
        return sum(1 for t in self._tracks.values() if t.penned)

    def n_predicted(self) -> int:
        """Number of active tracks currently being extrapolated (not directly observed).

        A track counts if it is unseen this step but within
        ``PREDICT_STEPS``. Near-stationary tracks, whose predicted position
        equals their last observation, also count.
        """
        return sum(
            1
            for t in self._tracks.values()
            if not t.penned and 0 < t.steps_since_seen(self.step) <= PREDICT_STEPS
        )