Checkpoint 4
This commit is contained in:
@@ -21,21 +21,47 @@ from pathlib import Path
|
||||
|
||||
class PolicyHandle:
|
||||
"""Wrap a loaded PPO policy + VecNormalize so the controller can call
|
||||
``predict(obs)`` without thinking about either."""
|
||||
``predict(obs)`` without thinking about either.
|
||||
|
||||
Frame stacking is auto-detected from the policy's expected obs dim:
|
||||
if it's a multiple of the single-frame ``OBS_DIM``, the handle keeps
|
||||
a deque of the last K frames and concatenates them on each predict.
|
||||
"""
|
||||
|
||||
def __init__(self, model, vecnorm):
|
||||
self.model = model
|
||||
self.vecnorm = vecnorm
|
||||
# Lazy import to avoid forcing herding/* into the import path
|
||||
# when SB3 isn't being used.
|
||||
from herding.obs import OBS_DIM
|
||||
policy_dim = int(model.observation_space.shape[0])
|
||||
if policy_dim % OBS_DIM == 0 and policy_dim // OBS_DIM >= 1:
|
||||
self.frame_stack = policy_dim // OBS_DIM
|
||||
else:
|
||||
self.frame_stack = 1
|
||||
self._buffer: list = []
|
||||
self._single_dim = OBS_DIM
|
||||
|
||||
def predict(self, obs):
|
||||
# VecNormalize expects a batched obs of shape (n_envs, obs_dim).
|
||||
if self.vecnorm is not None:
|
||||
import numpy as np
|
||||
obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1)
|
||||
obs_b = self.vecnorm.normalize_obs(obs_b)
|
||||
import numpy as np
|
||||
single = np.asarray(obs, dtype=np.float32).reshape(-1)
|
||||
if single.shape[0] != self._single_dim:
|
||||
# Caller already passed a stacked obs — use as-is.
|
||||
stacked = single
|
||||
elif self.frame_stack > 1:
|
||||
if not self._buffer:
|
||||
self._buffer = [single.copy() for _ in range(self.frame_stack)]
|
||||
else:
|
||||
self._buffer.append(single)
|
||||
if len(self._buffer) > self.frame_stack:
|
||||
self._buffer = self._buffer[-self.frame_stack:]
|
||||
stacked = np.concatenate(self._buffer, axis=0)
|
||||
else:
|
||||
import numpy as np
|
||||
obs_b = np.asarray(obs, dtype=np.float32).reshape(1, -1)
|
||||
stacked = single
|
||||
|
||||
obs_b = stacked.reshape(1, -1)
|
||||
if self.vecnorm is not None:
|
||||
obs_b = self.vecnorm.normalize_obs(obs_b)
|
||||
action, _ = self.model.predict(obs_b, deterministic=True)
|
||||
return action[0]
|
||||
|
||||
|
||||
@@ -4,11 +4,42 @@ Mode is selected by ``HERDING_MODE`` (env var, or via the
|
||||
``herding_runtime.cfg`` file the launcher writes since Webots strips
|
||||
env vars on some setups):
|
||||
|
||||
rl → load a BC-trained SB3 policy from HERDING_POLICY_DIR
|
||||
and use its (vx, vy) action each step.
|
||||
strombom → canonical Strömbom collect/drive heuristic.
|
||||
sequential → single-target "pin and push" — drives the sheep
|
||||
closest to the pen.
|
||||
bc → behaviour-cloned MLP, trained on Strömbom demos via
|
||||
sim. Default policy directory: training/runs/bc_v3.
|
||||
rl → KL-regularised PPO fine-tune of the BC policy. Same
|
||||
obs/action space as bc; refines time-to-pen via
|
||||
environment reward while staying anchored to bc.
|
||||
Default policy directory: training/runs/rl_v1.
|
||||
dagger → DAgger data collection. Reads sheep ground-truth
|
||||
via the receiver, computes the active-scan teacher's
|
||||
recommended action at every step, drives with either
|
||||
the teacher (HERDING_DAGGER_DRIVER=teacher, default)
|
||||
or the loaded student (=student), and logs each
|
||||
(lidar_stacked_obs, teacher_action) pair. On exit
|
||||
dumps to ``training/dagger/dagger_<ts>.npz`` for
|
||||
``tools.dagger_merge_train`` to consume.
|
||||
|
||||
Sheep perception
|
||||
----------------
|
||||
The dog now perceives sheep through its **front-mounted 140° LiDAR**
|
||||
(``protos/ShepherdDog.proto``: 180 rays, 12 m max range). Each step
|
||||
the controller:
|
||||
|
||||
1. Reads ``lidar.getRangeImage()``.
|
||||
2. Runs ``herding.lidar_perception.detections_from_scan`` to cluster
|
||||
returns into world-frame ``(x, y)`` sheep estimates.
|
||||
3. Folds those into a ``herding.sheep_tracker.SheepTracker`` which
|
||||
maintains last-seen positions for sheep currently out of the
|
||||
FOV and latches "penned" once a track disappears near the gate.
|
||||
|
||||
The output of step 3 is a ``{name: (x, y)}`` dict shaped exactly like
|
||||
the receiver-based one we used to consume — so Strömbom, Sequential
|
||||
and the BC obs builder run unchanged. The sheep→dog Emitter/Receiver
|
||||
link is still up, but its messages are *not* used for perception: they
|
||||
only supply ground truth for diagnostics and for the dagger-mode teacher.
|
||||
|
||||
All modes share the same low-level differential-drive controller
|
||||
(``herding.diffdrive.velocity_to_wheels`` with cos(err)-clamped forward
|
||||
@@ -33,14 +64,19 @@ if _PROJECT_ROOT not in sys.path:
|
||||
|
||||
from controller import Robot
|
||||
|
||||
from herding.active_scan import ActiveScanTeacher
|
||||
from herding.control import modulate_speed_near_sheep
|
||||
from herding.diffdrive import velocity_to_wheels
|
||||
from herding.geometry import (
|
||||
DOG_MAX_LINEAR, DOG_MAX_WHEEL_OMEGA,
|
||||
DOG_SOUTH_LIMIT, DOG_WHEEL_RADIUS,
|
||||
PEN_ENTRY,
|
||||
PEN_ENTRY, is_penned_position,
|
||||
)
|
||||
from herding.obs import build_obs
|
||||
from herding.lidar_perception import detections_from_scan
|
||||
from herding.obs import OBS_DIM, build_obs
|
||||
from herding.sequential import compute_action_debug as sequential_action_debug
|
||||
from herding.sheep_tracker import SheepTracker
|
||||
from herding.strombom import compute_action as strombom_action
|
||||
from herding.strombom import compute_action_debug as strombom_action_debug
|
||||
|
||||
|
||||
@@ -76,60 +112,82 @@ def _load_runtime_config():
|
||||
_runtime_cfg = _load_runtime_config()
|
||||
MODE = (os.environ.get("HERDING_MODE")
|
||||
or _runtime_cfg.get("HERDING_MODE")
|
||||
or "rl").lower()
|
||||
or "bc").lower()
|
||||
|
||||
|
||||
def _resolve_policy_dir() -> str:
|
||||
"""Where to look for the trained policy.
|
||||
def _resolve_policy_dir(mode: str) -> str:
|
||||
"""Where to look for the trained policy for the given mode.
|
||||
|
||||
Priority:
|
||||
1. HERDING_POLICY_DIR env var or runtime-cfg entry, if it points
|
||||
to a real directory.
|
||||
2. ``training/runs/bc_flock`` — flock-style BC (current default;
|
||||
requires the tight-cohesion sheep regime).
|
||||
3. ``training/runs/bc_solo`` — single-target BC (1-by-1 style;
|
||||
only works if ``herding/flocking_sim.py`` is reverted to the
|
||||
loose-cohesion regime).
|
||||
2. Mode-specific default:
|
||||
bc → training/runs/bc_v3 (Strömbom-imitated MLP)
|
||||
rl → training/runs/rl_v1 (KL-PPO fine-tune of bc_v3)
|
||||
3. Fall back to bc_v3.
|
||||
All checkpoints are frame-stacked K = 4; ``policy_loader`` reads
|
||||
the stacking factor from the policy's observation space.
|
||||
"""
|
||||
env_dir = (os.environ.get("HERDING_POLICY_DIR")
|
||||
or _runtime_cfg.get("HERDING_POLICY_DIR"))
|
||||
if env_dir and os.path.isdir(env_dir):
|
||||
return env_dir
|
||||
candidates = [
|
||||
os.path.join(_PROJECT_ROOT, "training", "runs", "bc_flock"),
|
||||
os.path.join(_PROJECT_ROOT, "training", "runs", "bc_solo"),
|
||||
]
|
||||
for c in candidates:
|
||||
if os.path.isdir(c):
|
||||
return c
|
||||
# Last resort — return env var anyway so error message is informative.
|
||||
return env_dir or candidates[0]
|
||||
mode_default = {
|
||||
"bc": os.path.join(_PROJECT_ROOT, "training", "runs", "bc_v3"),
|
||||
"rl": os.path.join(_PROJECT_ROOT, "training", "runs", "rl_v1"),
|
||||
"dagger": os.path.join(_PROJECT_ROOT, "training", "runs", "bc_v3"),
|
||||
}
|
||||
primary = mode_default.get(mode, mode_default["bc"])
|
||||
if os.path.isdir(primary):
|
||||
return primary
|
||||
# Fall back to BC if the requested checkpoint isn't there yet
|
||||
# (e.g., user asked for `rl` before training the fine-tune).
|
||||
fallback = mode_default["bc"]
|
||||
if os.path.isdir(fallback):
|
||||
return fallback
|
||||
return env_dir or primary
|
||||
|
||||
|
||||
_VALID_MODES = ("rl", "strombom", "sequential")
|
||||
_VALID_MODES = ("bc", "rl", "strombom", "sequential", "dagger", "diag")
|
||||
# Back-compat: an old config saying HERDING_MODE=rl meant "the BC policy".
|
||||
# We now use `rl` strictly for the KL-PPO fine-tune. If the rl_v1
|
||||
# directory isn't present, _resolve_policy_dir below silently falls
|
||||
# back to bc_v3, preserving the old behaviour.
|
||||
if MODE not in _VALID_MODES:
|
||||
print(f"[dog] unknown HERDING_MODE={MODE!r}; defaulting to strombom.")
|
||||
MODE = "strombom"
|
||||
|
||||
POLICY_DIR = _resolve_policy_dir()
|
||||
DAGGER_DRIVER = (os.environ.get("HERDING_DAGGER_DRIVER")
|
||||
or _runtime_cfg.get("HERDING_DAGGER_DRIVER")
|
||||
or "teacher").lower()
|
||||
if DAGGER_DRIVER not in ("teacher", "student"):
|
||||
DAGGER_DRIVER = "teacher"
|
||||
|
||||
POLICY_DIR = _resolve_policy_dir(MODE)
|
||||
policy_handle = None
|
||||
if MODE == "rl":
|
||||
if MODE in ("bc", "rl", "dagger"):
|
||||
print(f"[dog] resolved POLICY_DIR={POLICY_DIR} exists={os.path.isdir(POLICY_DIR)}")
|
||||
try:
|
||||
from policy_loader import load as _load_policy
|
||||
policy_handle = _load_policy(POLICY_DIR)
|
||||
print(f"[dog] RL policy loaded from {POLICY_DIR}")
|
||||
print(f"[dog] policy loaded from {POLICY_DIR}")
|
||||
except Exception as exc:
|
||||
print(f"[dog] RL policy load failed ({exc!r}); falling back to strombom.")
|
||||
MODE = "strombom"
|
||||
print(f"[dog] running in mode={MODE}")
|
||||
if MODE in ("bc", "rl"):
|
||||
print(f"[dog] policy load failed ({exc!r}); falling back to strombom.")
|
||||
MODE = "strombom"
|
||||
else:
|
||||
# In dagger mode, no policy is fine if driver=teacher.
|
||||
print(f"[dog] policy load failed ({exc!r}); dagger driver forced to teacher.")
|
||||
policy_handle = None
|
||||
print(f"[dog] running in mode={MODE}"
|
||||
+ (f" driver={DAGGER_DRIVER}" if MODE == "dagger" else ""))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Action smoothing + safety supervisor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
ACTION_SMOOTH = 0.35
|
||||
ACTION_SMOOTH = 0.55 # was 0.35; bumped for less frame-to-frame action jitter
|
||||
prev_action = (0.0, 0.0)
|
||||
|
||||
|
||||
@@ -185,6 +243,12 @@ gps = robot.getDevice("gps"); gps.enable(timestep)
|
||||
compass = robot.getDevice("compass"); compass.enable(timestep)
|
||||
receiver = robot.getDevice("receiver"); receiver.enable(timestep)
|
||||
emitter = robot.getDevice("emitter")
|
||||
lidar = robot.getDevice("lidar"); lidar.enable(timestep)
|
||||
|
||||
# The receiver channel from sheep is no longer consumed for perception
|
||||
# (kept enabled in case any peripheral tooling reads it). Sheep
|
||||
# positions come exclusively from the LiDAR + tracker pipeline below.
|
||||
tracker = SheepTracker()
|
||||
|
||||
# Cosmetic ear motors — ignored by control logic but keep them animated.
|
||||
left_ear = robot.getDevice("left ear motor")
|
||||
@@ -202,53 +266,197 @@ EAR_RATE = 8.0
|
||||
# Main loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# {name: (x, y)} — kept across all sheep ever heard from. Sheep that drift
|
||||
# into the pen are tracked by ``penned`` so observations and Strömbom
|
||||
# agree on which ones still need herding.
|
||||
sheep_positions: dict = {}
|
||||
penned_set: set = set()
|
||||
# Active sheep positions come from the LiDAR-fed tracker each step;
|
||||
# penned_set is whatever the tracker's ``get_penned_set()`` returns. The
|
||||
# receiver queue is still drained every step (its packets only refresh the
|
||||
# ground-truth record, never perception) so the backlog can't grow unbounded.
|
||||
step_count = 0
|
||||
|
||||
from herding.geometry import is_penned_position
|
||||
import atexit
|
||||
import time
|
||||
import numpy as _np
|
||||
|
||||
# DAgger state ----------------------------------------------------------
|
||||
# Logged each step in dagger mode: (stacked_lidar_obs, teacher_action).
|
||||
DAGGER_LOG_OBS: list = []
|
||||
DAGGER_LOG_ACT: list = []
|
||||
# Diagnostic mode buffer (one dict per step).
|
||||
DIAG_BUF: list = []
|
||||
# Frame stack buffer the controller maintains itself when dagger mode is
|
||||
# active — the stacked obs we log must match what the policy sees so the
|
||||
# downstream BC consumes (stacked_obs, teacher_action) pairs cleanly.
|
||||
_FRAME_STACK = (policy_handle.frame_stack if policy_handle is not None else 4)
|
||||
_dagger_buffer: list = []
|
||||
# Active-scan teacher operates on GT (read from receiver).
|
||||
_dagger_teacher = ActiveScanTeacher(strombom_action) if MODE == "dagger" else None
|
||||
# GT positions accumulated from the receiver (sheep emit their xy each step).
|
||||
_gt_sheep: dict = {}
|
||||
|
||||
|
||||
_DAGGER_RUN_TS = int(time.time()) # one file per controller run
|
||||
_DAGGER_DUMPED = False
|
||||
# Sentinel that the auto-collection script polls — empty file written
|
||||
# when this controller decides the run is "done" (all sheep penned, by
|
||||
# GT). The launcher then kills Webots and moves on without waiting out
|
||||
# its timeout. Honoured only in dagger mode.
|
||||
_DAGGER_DONE_FILE = os.path.join(_PROJECT_ROOT, "training", "dagger", ".DONE")
|
||||
|
||||
|
||||
def _dump_dagger_log():
    """Write the accumulated (obs, teacher_action) pairs to one ``.npz`` file.

    Does nothing unless ``MODE`` is ``"dagger"`` and at least one pair has
    been logged.  The output path is
    ``<_PROJECT_ROOT>/training/dagger/dagger_<_DAGGER_RUN_TS>.npz``; the
    arrays are stored as float32 under the keys ``obs`` and ``actions``.

    Webots may SIGKILL the controller, so besides running at exit this is
    also called from the main loop every ``DAGGER_FLUSH_STEPS`` steps, which
    bounds the loss to a few seconds of data per run.  Repeated calls are
    safe: each one rewrites the same file with the latest buffer.

    The archive is written to a sibling temporary file and then moved into
    place with ``os.replace``.  A kill that lands in the middle of a flush
    therefore leaves the previously flushed archive intact instead of a
    truncated one.
    """
    global _DAGGER_DUMPED
    if MODE != "dagger" or not DAGGER_LOG_OBS:
        return

    out_dir = os.path.join(_PROJECT_ROOT, "training", "dagger")
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"dagger_{_DAGGER_RUN_TS}.npz")

    obs_arr = _np.stack(DAGGER_LOG_OBS).astype(_np.float32)
    act_arr = _np.stack(DAGGER_LOG_ACT).astype(_np.float32)

    # Hand savez an open file object so it does not append ".npz" to the
    # temporary name, then swap the finished archive in atomically.
    tmp_path = out_path + ".tmp"
    with open(tmp_path, "wb") as fh:
        _np.savez(fh, obs=obs_arr, actions=act_arr)
    os.replace(tmp_path, out_path)

    # Announce the destination only once; later flushes stay quiet.
    if not _DAGGER_DUMPED:
        print(f"[dog dagger] wrote {len(DAGGER_LOG_OBS)} pairs → {out_path}")
        _DAGGER_DUMPED = True
|
||||
|
||||
|
||||
DAGGER_FLUSH_STEPS = 500
|
||||
|
||||
|
||||
atexit.register(_dump_dagger_log)
|
||||
|
||||
|
||||
while robot.step(timestep) != -1:
|
||||
step_count += 1
|
||||
|
||||
# Drain receiver. In every mode we capture GT for the diagnostic
|
||||
# log line — perception still comes from LiDAR, the GT is read-only.
|
||||
while receiver.getQueueLength() > 0:
|
||||
msg = receiver.getString()
|
||||
receiver.nextPacket()
|
||||
parts = msg.split(":")
|
||||
if len(parts) == 4 and parts[0] == "sheep":
|
||||
try:
|
||||
x, y = float(parts[2]), float(parts[3])
|
||||
_gt_sheep[parts[1]] = (float(parts[2]), float(parts[3]))
|
||||
except ValueError:
|
||||
continue
|
||||
sheep_positions[parts[1]] = (x, y)
|
||||
if parts[1] not in penned_set and is_penned_position(x, y):
|
||||
penned_set.add(parts[1])
|
||||
pass
|
||||
|
||||
pos = gps.getValues()
|
||||
dog_xy = (pos[0], pos[1])
|
||||
n = compass.getValues()
|
||||
dog_heading = math.atan2(n[0], n[1])
|
||||
|
||||
# ---- Action selection ----
|
||||
if MODE == "rl" and policy_handle is not None:
|
||||
sheep_xy_list = list(sheep_positions.values())
|
||||
sheep_names = list(sheep_positions.keys())
|
||||
sheep_penned_list = [s in penned_set for s in sheep_names]
|
||||
obs = build_obs(dog_xy, dog_heading, sheep_xy_list, sheep_penned_list)
|
||||
action = policy_handle.predict(obs)
|
||||
vx, vy = float(action[0]), float(action[1])
|
||||
elif MODE == "sequential":
|
||||
vx, vy, _mode_str, _dbg = sequential_action_debug(
|
||||
dog_xy, sheep_positions, PEN_ENTRY,
|
||||
)
|
||||
# ---- LiDAR perception → tracker → sheep_positions dict ----
|
||||
ranges = _np.asarray(lidar.getRangeImage(), dtype=_np.float32)
|
||||
detections = detections_from_scan(ranges, dog_xy[0], dog_xy[1], dog_heading)
|
||||
sheep_positions = tracker.update(detections)
|
||||
penned_set = tracker.get_penned_set()
|
||||
|
||||
# ---- Diagnostic mode: dump the first DIAG_STEPS scans + GT to disk.
|
||||
if MODE == "diag":
|
||||
DIAG_STEPS = 80
|
||||
if step_count <= DIAG_STEPS:
|
||||
DIAG_BUF.append(dict(
|
||||
step=step_count,
|
||||
ranges=ranges.copy(),
|
||||
dog_x=dog_xy[0], dog_y=dog_xy[1], dog_h=dog_heading,
|
||||
gt_sheep=dict(_gt_sheep),
|
||||
detections=list(detections),
|
||||
))
|
||||
if step_count == DIAG_STEPS:
|
||||
_diag_path = os.path.join(_PROJECT_ROOT, "training", "dagger",
|
||||
f"diag_{int(time.time())}.npz")
|
||||
os.makedirs(os.path.dirname(_diag_path), exist_ok=True)
|
||||
_np.savez(
|
||||
_diag_path,
|
||||
ranges=_np.stack([d["ranges"] for d in DIAG_BUF]),
|
||||
dog_xy=_np.array([[d["dog_x"], d["dog_y"]] for d in DIAG_BUF],
|
||||
dtype=_np.float32),
|
||||
dog_h=_np.array([d["dog_h"] for d in DIAG_BUF], dtype=_np.float32),
|
||||
# Per-step GT serialised: max-pad to 10 sheep.
|
||||
gt_xy=_np.array([
|
||||
[list(d["gt_sheep"].get(f"sheep{i}", (1e9, 1e9)))
|
||||
for i in range(1, 11)]
|
||||
for d in DIAG_BUF
|
||||
], dtype=_np.float32),
|
||||
detections=_np.array([
|
||||
len(d["detections"]) for d in DIAG_BUF
|
||||
], dtype=_np.int32),
|
||||
)
|
||||
print(f"[dog diag] wrote {DIAG_STEPS} scans → {_diag_path}")
|
||||
|
||||
# Build the single-frame LiDAR obs (matches what the env produces).
|
||||
sheep_xy_list = list(sheep_positions.values())
|
||||
sheep_penned_list = [False] * len(sheep_xy_list)
|
||||
single_obs = build_obs(dog_xy, dog_heading, sheep_xy_list, sheep_penned_list)
|
||||
# Maintain our own frame stack so logged obs == what policy sees.
|
||||
if not _dagger_buffer:
|
||||
_dagger_buffer = [single_obs.copy() for _ in range(_FRAME_STACK)]
|
||||
else:
|
||||
# Strömbom (canonical baseline).
|
||||
vx, vy, _mode_str, _dbg = strombom_action_debug(
|
||||
dog_xy, sheep_positions, PEN_ENTRY,
|
||||
_dagger_buffer.append(single_obs)
|
||||
if len(_dagger_buffer) > _FRAME_STACK:
|
||||
_dagger_buffer = _dagger_buffer[-_FRAME_STACK:]
|
||||
stacked_obs = _np.concatenate(_dagger_buffer, axis=0).astype(_np.float32)
|
||||
|
||||
# ---- Action selection ----
|
||||
if MODE == "diag":
|
||||
# Diagnostic mode: rotate in place so the captured scans cover
|
||||
# all 360° of view from one position. Target = heading + π →
|
||||
# cos(err) clamps forward to ~0, the dog spins.
|
||||
_t = dog_heading + math.pi
|
||||
vx, vy = math.cos(_t), math.sin(_t)
|
||||
elif MODE == "dagger":
|
||||
# Teacher: active-scan + Strömbom on GT (active sheep only).
|
||||
gt_active = {name: xy for name, xy in _gt_sheep.items()
|
||||
if not is_penned_position(xy[0], xy[1])}
|
||||
t_vx, t_vy, _mode_str = _dagger_teacher(
|
||||
dog_xy, dog_heading, gt_active, PEN_ENTRY,
|
||||
)
|
||||
# Student (if a policy is loaded).
|
||||
s_vx, s_vy = None, None
|
||||
if policy_handle is not None:
|
||||
action = policy_handle.predict(stacked_obs)
|
||||
s_vx, s_vy = float(action[0]), float(action[1])
|
||||
# Drive selection.
|
||||
if DAGGER_DRIVER == "student" and policy_handle is not None:
|
||||
vx, vy = s_vx, s_vy
|
||||
else:
|
||||
vx, vy = t_vx, t_vy
|
||||
# Always log the teacher action (this is the supervision signal).
|
||||
DAGGER_LOG_OBS.append(stacked_obs.copy())
|
||||
DAGGER_LOG_ACT.append(_np.array([t_vx, t_vy], dtype=_np.float32))
|
||||
elif MODE in ("bc", "rl") and policy_handle is not None:
|
||||
# Pass the single-frame obs; the policy_loader maintains its own
|
||||
# frame stack internally. Both bc and rl use the same control
|
||||
# interface — the only difference is which checkpoint loaded.
|
||||
action = policy_handle.predict(single_obs)
|
||||
vx, vy = float(action[0]), float(action[1])
|
||||
elif MODE in ("strombom", "sequential"):
|
||||
# Wrap the analytic teacher in ActiveScanTeacher so the dog
|
||||
# rotates / walks-to-centre when the tracker briefly empties,
|
||||
# instead of going idle. Without this wrapper, the first 2 s
|
||||
# of LiDAR-blind operation kills the run because Strömbom and
|
||||
# Sequential both return (0, 0) when there are no positions.
|
||||
if "_analytic_teacher" not in globals():
|
||||
from herding.sequential import compute_action as sequential_action
|
||||
_analytic_teacher = ActiveScanTeacher(
|
||||
strombom_action if MODE == "strombom" else sequential_action
|
||||
)
|
||||
vx, vy, _mode_str = _analytic_teacher(
|
||||
dog_xy, dog_heading, sheep_positions, PEN_ENTRY,
|
||||
)
|
||||
|
||||
# Shared post-process: speed modulation near sheep. Applies to bc,
|
||||
# rl, strombom, sequential — every mode where the action source is
|
||||
# nominally unit-magnitude. In dagger mode the active-scan teacher
|
||||
# has already modulated, and the diag mode action is hand-built for
|
||||
# rotation; both skip.
|
||||
if MODE not in ("dagger", "diag"):
|
||||
vx, vy = modulate_speed_near_sheep(vx, vy, dog_xy, sheep_positions)
|
||||
|
||||
# EMA smoothing — reduces oscillation from policy or Strömbom flips.
|
||||
vx = ACTION_SMOOTH * prev_action[0] + (1.0 - ACTION_SMOOTH) * vx
|
||||
@@ -269,7 +477,31 @@ while robot.step(timestep) != -1:
|
||||
left_ear.setPosition(ear_pos)
|
||||
right_ear.setPosition(-ear_pos)
|
||||
|
||||
# --- DAgger: early-stop when all GT sheep are penned ---
|
||||
if MODE == "dagger" and _gt_sheep:
|
||||
gt_active_count = sum(1 for x, y in _gt_sheep.values()
|
||||
if not is_penned_position(x, y))
|
||||
if gt_active_count == 0 and not os.path.exists(_DAGGER_DONE_FILE):
|
||||
_dump_dagger_log()
|
||||
open(_DAGGER_DONE_FILE, "w").close()
|
||||
print(f"[dog dagger] all {len(_gt_sheep)} sheep penned — "
|
||||
f"wrote {_DAGGER_DONE_FILE}, exiting early")
|
||||
|
||||
if MODE == "dagger" and step_count % DAGGER_FLUSH_STEPS == 0 and DAGGER_LOG_OBS:
|
||||
_dump_dagger_log()
|
||||
|
||||
if step_count % 200 == 0:
|
||||
n_active = sum(1 for s in sheep_positions if s not in penned_set)
|
||||
print(f"[dog mode={MODE}] step={step_count} known={len(sheep_positions)} "
|
||||
f"penned={len(penned_set)} active={n_active} action=({vx:+.2f}, {vy:+.2f})")
|
||||
gt_penned = sum(1 for x, y in _gt_sheep.values()
|
||||
if is_penned_position(x, y))
|
||||
gt_total = len(_gt_sheep)
|
||||
extra = ""
|
||||
if MODE == "dagger":
|
||||
extra = f" logged={len(DAGGER_LOG_OBS)}"
|
||||
print(f"[dog mode={MODE}] step={step_count} "
|
||||
f"GT_penned={gt_penned}/{gt_total} "
|
||||
f"tracks_active={tracker.n_active()} "
|
||||
f"tracks_penned={tracker.n_penned()} "
|
||||
f"detections={len(detections)} action=({vx:+.2f}, {vy:+.2f}){extra}")
|
||||
|
||||
# Loop ended (Webots told us to quit). Flush any remaining DAgger log.
|
||||
_dump_dagger_log()
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
"""Backwards-compat shim — Strömbom logic now lives in ``herding.strombom``."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
_HERE = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, "..", ".."))
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
from herding.strombom import ( # noqa: F401
|
||||
F_FACTOR, DELTA_COLLECT, DELTA_DRIVE,
|
||||
compute_action, compute_action_debug,
|
||||
)
|
||||
from herding.geometry import ( # noqa: F401
|
||||
PEN_X, PEN_Y, PEN_CENTER, PEN_ENTRY,
|
||||
in_pen,
|
||||
)
|
||||
Reference in New Issue
Block a user