# Training pipeline for the shepherd-dog herding project.
# Stages chain via output files in training/.
#
# Usage:
#   make            # full pipeline: bc_demos -> bc -> rl -> eval
#   make bc_demos   # generate sim demos
#   make bc         # behaviour clone (rebuilds bc_demos if missing)
#   make rl         # KL-PPO fine-tune (rebuilds bc if missing)
#   make eval       # EVAL_SEEDS-seed (default 10) env eval of rl
#   make rl_fast    # Stage-2 speed pass: continue PPO from rl with a time penalty
#   make eval_fast  # env eval of rl_fast
#   make test       # pytest suite
#   make webots N=10 MODE=rl   # launch Webots in the chosen mode
#   WEBOTS_HEADLESS=1 make webots   # no 3D view, fast mode (still needs DISPLAY or xvfb-run)
#   make webots_sweep   # headless Webots sweep; results go to webots_sweep.log
#   make clean      # delete bc_demos and run artefacts for the current DRIVE+WORLD
#   make clean_all  # delete artefacts for all combinations
#   make help       # print the target table
#
# Override any hyperparameter on the command line, for example:
#   make rl PPO_STEPS=2000000 KL=0.02
#   make eval EVAL_SEEDS=20
#
# Drive mode selects the locomotion model:
#   make DRIVE=differential       2-wheel diff-drive (default)
#   make DRIVE=mecanum             4-wheel omnidirectional
#
# World shape:
#   make WORLD=field              rectangular (default)
#   make WORLD=field_round        circular fence
#
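# To train all 4 combinations:
#   make train_all          # Stage 1 (evaluate with: make eval_all)
#   make train_all_fast     # Stage-2 speed pass (evaluate with: make eval_all_fast)
#   make remote_full        # clean_all + both stages + evals, each logged to a file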


# Interpreter used for every training / evaluation entry point.
PY := python

# Locomotion model and arena shape.  Every DRIVE x WORLD pair writes its
# own artefacts (keyed by TAG below), so combinations never collide.
DRIVE ?= differential
WORLD ?= field

# ---- Per-combination artefact locations ----------------------------------
TAG := $(DRIVE)_$(WORLD)

BC_DEMOS := training/bc/demos_$(TAG).npz

# Run directories: behaviour cloning, Stage-1 KL-PPO, and the Stage-2
# "speed pass" (PPO continued from RL_DIR with TIME_W < 0 so the policy
# keeps Stage-1's success rate but cuts time-to-pen).  Stage 2 gets its
# own directory so the Stage-1 run stays available for comparison.
BC_DIR      := training/runs/bc_$(TAG)
RL_DIR      := training/runs/rl_$(TAG)
RL_FAST_DIR := training/runs/rl_fast_$(TAG)

# Each run directory's policy.zip is the file make tracks for that stage.
BC_POLICY      := $(BC_DIR)/policy.zip
RL_POLICY      := $(RL_DIR)/policy.zip
RL_FAST_POLICY := $(RL_FAST_DIR)/policy.zip

# --- Demo collection and behaviour cloning ---
# SEEDS_PER_N (teacher demos per flock size) and BC_EPOCHS depend on the
# DRIVE x WORLD combination:
#   * Mecanum has more complex dynamics and a weaker teacher imitation
#     signal (val_cos ≈ 0.70 vs ≥ 0.88 for differential), so it gets more
#     demos and longer BC training to compensate.
#   * The round field is harder; for differential, the extra demos give
#     BC a fair shot at 60%+.
TEACHER          ?= universal

ifeq ($(WORLD),field_round)
ifeq ($(DRIVE),mecanum)
SEEDS_PER_N      ?= 80
BC_EPOCHS        ?= 200
else
SEEDS_PER_N      ?= 60
BC_EPOCHS        ?= 150
endif
else
ifeq ($(DRIVE),mecanum)
SEEDS_PER_N      ?= 50
BC_EPOCHS        ?= 100
else
SEEDS_PER_N      ?= 25
BC_EPOCHS        ?= 60
endif
endif

# Combination-independent collection / BC settings.
SUBSAMPLE        ?= 3
FRAME_STACK      ?= 4
DEMO_MAX_STEPS   ?= 100000
BC_NET_ARCH      ?= 512,512

# --- Domain randomisation (passed to bc_demos, rl and rl_fast) ---
#   FP_RATE              mean false-positive detections injected per step
#                        (Poisson λ).
#   ACTION_SMOOTH_TRAIN  EMA coefficient on actions, matching the Webots
#                        controller's 0.55.
#   WHEEL_SLIP_STD       Gaussian wheel-speed noise covering the mecanum
#                        dynamics gap.
#
# Keep FP_RATE identical between BC demos and RL.  BC collection runs in
# PRIVILEGED mode (the teacher sees ground truth while the student's
# observation is the FP-injected tracker output), so the policy learns to
# denoise towards the GT signal.  A BC/RL FP_RATE mismatch caused an
# earlier regression (BC=0, RL=2 → PPO stalled at 0% success).
WHEEL_SLIP_STD      ?= 0.05
ACTION_SMOOTH_TRAIN ?= 0.55
FP_RATE             ?= 0.0

# --- KL-PPO fine-tune (Stage 1) ---
# The rectangular field uses the base budget.  The round field gets twice
# the PPO steps and a smaller KL coefficient (a looser KL constraint).
ifneq ($(WORLD),field_round)
PPO_STEPS        ?= 2000000
KL               ?= 0.05
else
PPO_STEPS        ?= 4000000
KL               ?= 0.02
endif

# No time penalty in Stage 1 for any combination: success must be learned
# before speed is rewarded.  Earlier runs showed TIME_W=-0.05 traded ~10
# pts of success for speed on hard combos, so speed is optimised later by
# the rl_fast pass.
TIME_W           ?= 0.0

# Passed as --imitate-weight to both PPO stages.
IMITATE          ?= 0.0

# PPO rollouts run at full difficulty so the training distribution matches
# eval (deployment).  Anything lower causes a train/eval mismatch that can
# make RL eval worse than BC.
DIFFICULTY       ?= 1.0

# --- Stage-2 speed pass (rl_fast) ---
# Continues PPO from the Stage-1 policy in RL_DIR with a negative TIME_W,
# i.e. a per-step time penalty, so step counts drop while success stays
# near Stage 1.  RL_FAST_KL (0.05) equals the rectangular-field Stage-1 KL
# and is tighter than the round-field one.
RL_FAST_STEPS    ?= 1000000
RL_FAST_KL       ?= 0.05

# The penalty is drive-specific.  Differential only needs a light touch;
# stronger penalties trade success for speed without a net gain.  Mecanum
# needs a stronger penalty to achieve any speed gain.
ifneq ($(DRIVE),mecanum)
RL_FAST_TIME_W   ?= -0.02
else
RL_FAST_TIME_W   ?= -0.05
endif

# --- Evaluation (eval, eval_fast, eval_all, eval_all_fast) ---
#   EVAL_SEEDS      passed as --n-seeds
#   EVAL_MAX_STEPS  passed as --max-steps
EVAL_SEEDS       ?= 10
EVAL_MAX_STEPS   ?= 15000

# --- Webots launcher (defaults shown by `make help`) ---
N                ?= 10
MODE             ?= rl


# Command-style targets: always run, never confused with same-named files.
.PHONY: all bc_demos bc rl rl_fast eval eval_fast eval_all eval_all_fast \
        test webots webots_sweep clean clean_all help \
        train_all train_diff_rect train_diff_round \
        train_mec_rect train_mec_round \
        train_all_fast train_diff_rect_fast train_diff_round_fast \
        train_mec_rect_fast train_mec_round_fast \
        remote_full

# If a recipe exits non-zero after it has started writing its target
# (e.g. a crashed collect run leaving a truncated demos .npz, or a failed
# pretrain leaving a partial policy.zip), delete that target.  The next
# `make` then rebuilds it instead of treating the partial file as up to
# date.
.DELETE_ON_ERROR:

# Default goal: the full Stage-1 pipeline for the current DRIVE+WORLD.
all: eval

# geometry.py picks HERDING_WORLD up at import time, so hand every recipe
# (and every recursive make) the current WORLD through the environment.
HERDING_WORLD := $(WORLD)
export HERDING_WORLD

# Unbuffered Python stdout/stderr keeps progress visible live when the
# build runs under tee / nohup / tmux pipes.
PYTHONUNBUFFERED := 1
export PYTHONUNBUFFERED

# Mecanum demo collection needs --use-webots-preset so training.bc.collect
# picks up HERDING_MEC_WEBOTS.  That matches the gym mecanum kinematics'
# strafe efficiency and forward bleed against the physical-roller Webots
# proto.  Without it the policy trains on textbook X-pattern mecanum and
# fails on deployment.
# Only the bc_demos recipe passes this flag.  The rl rule does not need it
# (see its note): rl/train.py applies HERDING_MEC_WEBOTS itself when
# drive=mecanum.
ifneq ($(DRIVE),mecanum)
WEBOTS_PRESET_FLAG :=
else
WEBOTS_PRESET_FLAG := --use-webots-preset
endif

# Teacher demonstrations for the current DRIVE+WORLD.  The demos file has
# no prerequisites, so it is built only when missing.  After changing
# collection settings (SEEDS_PER_N, FP_RATE, ...), run `make clean` or
# delete the .npz to force a fresh collection.
bc_demos: $(BC_DEMOS)

$(BC_DEMOS):
	$(PY) -m training.bc.collect \
		--teacher $(TEACHER) --out $@ \
		--seeds-per-n $(SEEDS_PER_N) --subsample $(SUBSAMPLE) \
		--frame-stack $(FRAME_STACK) --drive-mode $(DRIVE) \
		--world $(WORLD) --max-steps $(DEMO_MAX_STEPS) \
		--fp-rate $(FP_RATE) --action-smooth $(ACTION_SMOOTH_TRAIN) \
		--wheel-slip-std $(WHEEL_SLIP_STD) $(WEBOTS_PRESET_FLAG)

# Behaviour cloning on the collected demos.  BC retrains whenever the
# demos file is newer than the policy.  pretrain is given BC_DIR as
# --out; this rule expects it to leave policy.zip (the target) there.
bc: $(BC_POLICY)

$(BC_POLICY): $(BC_DEMOS)
	$(PY) -m training.bc.pretrain --demos $< --out $(@D) \
		--epochs $(BC_EPOCHS) --net-arch $(BC_NET_ARCH)

# Stage-1 KL-PPO fine-tune, starting from the BC run in BC_DIR (passed as
# --bc).  It is rebuilt whenever the BC policy is newer.
#
# No --use-webots-preset here: rl/train.py auto-applies HERDING_MEC_WEBOTS
# when drive=mecanum.  Keep this note as a make comment.  A tab-indented
# `#` line inside the recipe would be echoed by make and run through a
# shell after every training run.
rl: $(RL_POLICY)

$(RL_POLICY): $(BC_POLICY)
	$(PY) -m training.rl.train \
		--bc $(BC_DIR) --out $(RL_DIR) \
		--total-timesteps $(PPO_STEPS) --kl-coef $(KL) \
		--imitate-weight $(IMITATE) --time-weight $(TIME_W) \
		--difficulty $(DIFFICULTY) \
		--drive-mode $(DRIVE) --world $(WORLD) \
		--fp-rate $(FP_RATE) \
		--action-smooth $(ACTION_SMOOTH_TRAIN) \
		--wheel-slip-std $(WHEEL_SLIP_STD)

# Evaluate the Stage-1 RL policy (its run directory is $(<D)) over
# EVAL_SEEDS seeds.  rl is built first if its policy is missing or stale.
eval: $(RL_POLICY)
	$(PY) -m training.eval --policy $(<D) --max-flock 10 \
		--max-steps $(EVAL_MAX_STEPS) --n-seeds $(EVAL_SEEDS) \
		--drive-mode $(DRIVE) --world $(WORLD)

# --- Stage-2 speed pass ---
# Continue PPO from the Stage-1 run (this rule's prerequisite; its
# directory is passed as --bc) with the per-step time penalty
# RL_FAST_TIME_W.  The aim is to keep Stage-1's success rate while cutting
# mean steps-to-pen.  Run `make rl_fast` once Stage-1 RL has converged
# (success ≥ teacher).
rl_fast: $(RL_FAST_POLICY)

$(RL_FAST_POLICY): $(RL_POLICY)
	$(PY) -m training.rl.train \
		--bc $(<D) --out $(@D) \
		--total-timesteps $(RL_FAST_STEPS) --kl-coef $(RL_FAST_KL) \
		--imitate-weight $(IMITATE) --time-weight $(RL_FAST_TIME_W) \
		--difficulty $(DIFFICULTY) --drive-mode $(DRIVE) --world $(WORLD) \
		--fp-rate $(FP_RATE) --action-smooth $(ACTION_SMOOTH_TRAIN) \
		--wheel-slip-std $(WHEEL_SLIP_STD)

# Evaluate the Stage-2 (rl_fast) policy, whose run directory is $(<D),
# with the same settings as `make eval`.
eval_fast: $(RL_FAST_POLICY)
	$(PY) -m training.eval --policy $(<D) --max-flock 10 \
		--max-steps $(EVAL_MAX_STEPS) --n-seeds $(EVAL_SEEDS) \
		--drive-mode $(DRIVE) --world $(WORLD)

# Run the pytest suite in tests/ with the configured interpreter.
test:
	$(PY) -m pytest tests/

# Launch Webots through tools/webots_menu.sh.
# NOTE(review): GNU make exports only variables set on the make command
# line (e.g. `make webots N=10 MODE=rl`) or already in the environment.
# The ?= defaults for N, MODE, DRIVE and WORLD in this file are not
# exported; HERDING_WORLD is always exported.  Confirm how webots_menu.sh
# chooses its defaults when these are unset.
webots:
	@bash tools/webots_menu.sh

# Headless sweep across all modes × worlds × flock sizes, driven by
# tools/webots_sweep.sh.  Results are written to webots_sweep.log.
#   USE_GT=1  bypass the LiDAR tracker (isolate perception from policy).
#             Unset, empty or 0 leaves the tracker on.
# When CONDA_PREFIX is set, the active conda env's bin/ goes first on PATH.
# Outside conda, PATH is left untouched, because prepending "$(CONDA_PREFIX)/bin"
# with an empty prefix would put a bare /bin ahead of e.g. a venv's python.
webots_sweep:
	env $(if $(filter-out 0,$(USE_GT)),HERDING_USE_GT=1) \
	    $(if $(CONDA_PREFIX),PATH="$(CONDA_PREFIX)/bin:$$PATH") \
	    bash tools/webots_sweep.sh webots_sweep.log

# Delete the demos and every run directory (BC, Stage-1 RL and Stage-2
# rl_fast) for the current DRIVE+WORLD only.  clean_all covers every
# combination.
clean:
	rm -f $(BC_DEMOS)
	rm -rf $(BC_DIR) $(RL_DIR) $(RL_FAST_DIR)

# Delete demos and run directories for every DRIVE+WORLD combination.
# The rl_* glob also matches the Stage-2 rl_fast_* directories.
clean_all:
	rm -rf training/runs/bc_* training/runs/rl_* \
	       training/bc/demos_*.npz

# --- Train all 4 combinations ---
# Each target runs the default goal (bc_demos -> bc -> rl -> eval) in a
# recursive make pinned to one DRIVE+WORLD pair.
train_diff_rect:  ; $(MAKE) DRIVE=differential WORLD=field
train_diff_round: ; $(MAKE) DRIVE=differential WORLD=field_round
train_mec_rect:   ; $(MAKE) DRIVE=mecanum WORLD=field
train_mec_round:  ; $(MAKE) DRIVE=mecanum WORLD=field_round

train_all: train_diff_rect train_diff_round train_mec_rect train_mec_round

# Gym eval sweep of the Stage-1 BC and RL policies over all 4 combos.  Use
# it after train_all; eval_all_fast covers the Stage-2 rl_fast policies.
#
# Each eval gets HERDING_WORLD set to its own world.  The file-level export
# carries this make's WORLD, which would be wrong for the other world in
# the sweep, so the import-time geometry now matches --world.
# Every combo is evaluated even if one fails; the target then exits
# non-zero instead of reporting only the last eval's status.
eval_all:
	@rc=0; \
	for d in differential mecanum; do \
	  for w in field field_round; do \
	    echo ""; \
	    echo "=== BC  $$d / $$w ==="; \
	    HERDING_WORLD=$$w $(PY) -m training.eval \
	      --policy training/runs/bc_$${d}_$${w} \
	      --max-flock 10 --max-steps $(EVAL_MAX_STEPS) --n-seeds $(EVAL_SEEDS) \
	      --drive-mode $$d --world $$w || rc=1; \
	    echo ""; \
	    echo "=== RL  $$d / $$w ==="; \
	    HERDING_WORLD=$$w $(PY) -m training.eval \
	      --policy training/runs/rl_$${d}_$${w} \
	      --max-flock 10 --max-steps $(EVAL_MAX_STEPS) --n-seeds $(EVAL_SEEDS) \
	      --drive-mode $$d --world $$w || rc=1; \
	  done; \
	done; \
	exit $$rc

# One-shot remote runbook: clean → Stage-1 train → Stage-1 eval → Stage-2
# train → Stage-2 eval.  Each step is tee'd to its own log file in the repo
# root so the run is fully unattended.
#
# The recipe runs under bash with pipefail.  With plain sh, each line's
# exit status is tee's, so a failed training stage would silently be
# followed by the next stage and the "Done" banner.  Training failures now
# stop the run.  The eval steps carry a `-` prefix, so a failed evaluation
# is logged (and reported by make as ignored) without blocking Stage 2.
remote_full: SHELL := bash
remote_full: .SHELLFLAGS := -o pipefail -c
remote_full:
	$(MAKE) clean_all
	$(MAKE) train_all 2>&1 | tee stage1_train.log
	-$(MAKE) eval_all 2>&1 | tee stage1_eval.log
	$(MAKE) train_all_fast 2>&1 | tee stage2_train.log
	-$(MAKE) eval_all_fast 2>&1 | tee stage2_eval.log
	@echo ""
	@echo "===================================================="
	@echo "  Done. Logs: stage1_train.log stage1_eval.log"
	@echo "              stage2_train.log stage2_eval.log"
	@echo "===================================================="

# Gym eval sweep of the Stage-2 rl_fast policies over all 4 combos; use it
# after train_all_fast.  As in eval_all, each eval gets HERDING_WORLD set
# to its own world, so the import-time geometry matches --world rather
# than this make's WORLD.  Every combo is evaluated even if one fails; the
# target then exits non-zero.
eval_all_fast:
	@rc=0; \
	for d in differential mecanum; do \
	  for w in field field_round; do \
	    echo ""; \
	    echo "=== RL_FAST  $$d / $$w ==="; \
	    HERDING_WORLD=$$w $(PY) -m training.eval \
	      --policy training/runs/rl_fast_$${d}_$${w} \
	      --max-flock 10 --max-steps $(EVAL_MAX_STEPS) --n-seeds $(EVAL_SEEDS) \
	      --drive-mode $$d --world $$w || rc=1; \
	  done; \
	done; \
	exit $$rc

# --- Stage-2 sweep ---
# Each target runs rl_fast in a recursive make pinned to one DRIVE+WORLD
# pair.  rl_fast depends on the Stage-1 policy, so that is built first if
# it is missing.
train_diff_rect_fast:  ; $(MAKE) DRIVE=differential WORLD=field rl_fast
train_diff_round_fast: ; $(MAKE) DRIVE=differential WORLD=field_round rl_fast
train_mec_rect_fast:   ; $(MAKE) DRIVE=mecanum WORLD=field rl_fast
train_mec_round_fast:  ; $(MAKE) DRIVE=mecanum WORLD=field_round rl_fast

train_all_fast: train_diff_rect_fast train_diff_round_fast \
                train_mec_rect_fast train_mec_round_fast

# Print the target table and the current value of every tunable.  The
# values shown reflect any command-line overrides.
help:
	@echo "Targets:"
	@echo "  make              full pipeline (bc_demos -> bc -> rl -> eval)"
	@echo "  make bc_demos     sim demos via the '$(TEACHER)' teacher"
	@echo "  make bc           train BC (rebuilds bc_demos if missing)"
	@echo "  make rl           KL-PPO fine-tune (rebuilds bc if missing)"
	@echo "  make eval         $(EVAL_SEEDS)-seed env eval of rl"
	@echo "  make rl_fast      Stage-2 speed pass from rl (TIME_W=$(RL_FAST_TIME_W))"
	@echo "  make eval_fast    $(EVAL_SEEDS)-seed env eval of rl_fast"
	@echo "  make test         pytest suite"
	@echo "  make webots [N=$(N)] [MODE=$(MODE)] [DRIVE=$(DRIVE)] [WORLD=$(WORLD)]"
	@echo "                    launch Webots in the chosen mode"
	@echo "  WEBOTS_HEADLESS=1 make webots …   no 3D view + fast + --batch"
	@echo "  make webots_sweep [USE_GT=1]      headless Webots sweep -> webots_sweep.log"
	@echo "  make clean        delete artefacts for current DRIVE+WORLD"
	@echo "  make clean_all    delete artefacts for all combinations"
	@echo ""
	@echo "Combinations:"
	@echo "  make DRIVE=differential WORLD=field       diff + rectangular (default)"
	@echo "  make DRIVE=differential WORLD=field_round  diff + circular"
	@echo "  make DRIVE=mecanum     WORLD=field         mecanum + rectangular"
	@echo "  make DRIVE=mecanum     WORLD=field_round   mecanum + circular"
	@echo "  make train_all                            all 4 in sequence"
	@echo "  make train_all_fast                       Stage-2 rl_fast for all 4"
	@echo "  make eval_all / eval_all_fast             gym eval of all 4 (BC+RL / rl_fast)"
	@echo "  make remote_full                          clean_all + both stages + evals, logged"
	@echo ""
	@echo "Hyperparameter overrides (showing defaults):"
	@echo "  TEACHER=$(TEACHER) SEEDS_PER_N=$(SEEDS_PER_N) SUBSAMPLE=$(SUBSAMPLE) FRAME_STACK=$(FRAME_STACK) DEMO_MAX_STEPS=$(DEMO_MAX_STEPS)"
	@echo "  BC_EPOCHS=$(BC_EPOCHS) BC_NET_ARCH=$(BC_NET_ARCH)"
	@echo "  PPO_STEPS=$(PPO_STEPS) KL=$(KL) IMITATE=$(IMITATE) TIME_W=$(TIME_W)"
	@echo "  DIFFICULTY=$(DIFFICULTY) FP_RATE=$(FP_RATE) ACTION_SMOOTH_TRAIN=$(ACTION_SMOOTH_TRAIN) WHEEL_SLIP_STD=$(WHEEL_SLIP_STD)"
	@echo "  RL_FAST_STEPS=$(RL_FAST_STEPS) RL_FAST_KL=$(RL_FAST_KL) RL_FAST_TIME_W=$(RL_FAST_TIME_W)"
	@echo "  EVAL_SEEDS=$(EVAL_SEEDS) EVAL_MAX_STEPS=$(EVAL_MAX_STEPS)"
