Checkpoint 5 - incomplete
This commit is contained in:
+15
-12
@@ -7,12 +7,15 @@ policy that runs under LiDAR perception in Webots.
|
||||
sim demos (active-scan teacher on tracker output, K=4 frame stack)
|
||||
│
|
||||
▼
|
||||
bc_pretrain.py ──► runs/bc_v3 (deployed policy — beats Strömbom on n≥8)
|
||||
bc_pretrain.py ──► runs/bc (BC baseline)
|
||||
│
|
||||
▼ (optional: tools/auto_dagger.sh + tools/dagger_merge_train.py
|
||||
│ if sim-trained doesn't transfer cleanly to Webots)
|
||||
▼ KL-regularised PPO fine-tune (training/train_ppo.py)
|
||||
│
|
||||
runs/bc_dagger
|
||||
runs/rl (deployed `rl` mode)
|
||||
|
||||
# optional branch — kept for reference, not deployed:
|
||||
runs/bc_dagger (Webots-grounded DAgger refinement, useful if a
|
||||
modified world breaks sim-to-real transfer)
|
||||
```
|
||||
|
||||
## Files
|
||||
@@ -42,14 +45,14 @@ rollout collection, not gradient compute.
|
||||
# 1. Sim demos with the active-scan + Strömbom teacher under LiDAR
|
||||
# perception. K=4 frame stack so the MLP has temporal context.
|
||||
python -m tools.collect_demos --teacher strombom \
|
||||
--out demos_v3.npz --seeds-per-n 15 --subsample 3 --frame-stack 4
|
||||
--out demos.npz --seeds-per-n 15 --subsample 3 --frame-stack 4
|
||||
|
||||
# 2. Behavior-clone.
|
||||
python -m training.bc_pretrain --demos demos_v3.npz \
|
||||
--out runs/bc_v3 --epochs 60 --net-arch 512,512
|
||||
python -m training.bc_pretrain --demos demos.npz \
|
||||
--out runs/bc --epochs 60 --net-arch 512,512
|
||||
|
||||
# 3. Evaluate.
|
||||
python -m training.eval --policy runs/bc_v3 \
|
||||
python -m training.eval --policy runs/bc \
|
||||
--max-flock 10 --max-steps 8000 --n-seeds 5
|
||||
```
|
||||
|
||||
@@ -78,7 +81,7 @@ seat:
|
||||
HERDING_POLICY_DIR=$PWD/training/runs/bc_dagger \
|
||||
HERDING_DAGGER_DRIVER=student \
|
||||
tools/auto_dagger.sh 3 60
|
||||
python -m tools.dagger_merge_train --out runs/bc_dagger_v2
|
||||
python -m tools.dagger_merge_train --out runs/bc_dagger
|
||||
```
|
||||
|
||||
## Available analytic teachers
|
||||
@@ -107,6 +110,6 @@ python -m training.eval --policy sequential --max-flock 10 --max-steps 8000 --n
|
||||
tools/run_webots.sh 10 rl
|
||||
```
|
||||
|
||||
The dog controller loads the highest-priority policy that exists
|
||||
(`bc_dagger_v2` → `bc_dagger` → `bc_v3`). Override with
|
||||
`HERDING_POLICY_DIR=…` if you want a specific checkpoint.
|
||||
The dog controller loads `runs/bc` for `bc` mode and `runs/rl` for
|
||||
`rl` mode. Override with `HERDING_POLICY_DIR=…` for a specific
|
||||
checkpoint.
|
||||
|
||||
@@ -15,7 +15,7 @@ Usage::
|
||||
|
||||
python -m training.bc_pretrain \\
|
||||
--demos training/demos.npz \\
|
||||
--out training/runs/bc_flock
|
||||
--out training/runs/bc
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -83,7 +83,7 @@ def policy_forward_mean(policy, obs_batch):
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--demos", default="training/demos.npz")
|
||||
parser.add_argument("--out", default="training/runs/bc_solo")
|
||||
parser.add_argument("--out", default="training/runs/bc")
|
||||
parser.add_argument("--epochs", type=int, default=60)
|
||||
parser.add_argument("--batch-size", type=int, default=256)
|
||||
parser.add_argument("--lr", type=float, default=1e-3)
|
||||
|
||||
@@ -204,6 +204,12 @@ class HerdingEnv(gym.Env):
|
||||
already mimics a stronger teacher (sequential)."""
|
||||
self.W_IMITATE = float(value)
|
||||
|
||||
def set_time_weight(self, value: float) -> None:
|
||||
"""Override W_TIME (instance-level). Default 0.0; a small
|
||||
negative value (e.g. -0.1) adds a per-step penalty that
|
||||
explicitly rewards fast time-to-pen during PPO fine-tune."""
|
||||
self.W_TIME = float(value)
|
||||
|
||||
# ---- gym API ----
|
||||
def reset(self, *, seed=None, options=None):
|
||||
super().reset(seed=seed)
|
||||
@@ -431,6 +437,9 @@ class HerdingEnv(gym.Env):
|
||||
|
||||
d_progress = max(-5.0, min(5.0, self.prev_d_pen - d_pen))
|
||||
r = self.W_PEN_DELTA * delta_pen + self.W_PROGRESS * d_progress
|
||||
# Per-step time penalty (0 by default). When negative, encourages
|
||||
# the policy to finish quickly — used during PPO fine-tune.
|
||||
r += self.W_TIME
|
||||
|
||||
if action is not None and self.W_IMITATE > 0.0:
|
||||
positions = self._perceived_positions()
|
||||
|
||||
Binary file not shown.
+32
-7
@@ -10,7 +10,7 @@ per-step reward signal does the rest.
|
||||
|
||||
Pipeline
|
||||
--------
|
||||
1. Load ``bc_v3`` weights into both the trainable policy and a frozen
|
||||
1. Load ``bc`` weights into both the trainable policy and a frozen
|
||||
reference ``ref_policy``.
|
||||
2. Initialise the policy's log_std to a small fixed value (≈ −1.5)
|
||||
and disable its gradient — exploration noise stays small so PPO
|
||||
@@ -19,14 +19,14 @@ Pipeline
|
||||
each minibatch.
|
||||
4. Train for ~1–3 M timesteps with a low LR (5e-5).
|
||||
|
||||
Output: ``runs/rl_v1/policy.zip`` — same SB3 format as bc_v3, loadable
|
||||
Output: ``runs/rl/policy.zip`` — same SB3 format as bc, loadable
|
||||
by the dog controller's ``HERDING_MODE=rl`` path.
|
||||
|
||||
Usage::
|
||||
|
||||
python -m training.train_ppo \\
|
||||
--bc training/runs/bc_v3 \\
|
||||
--out training/runs/rl_v1 \\
|
||||
--bc training/runs/bc \\
|
||||
--out training/runs/rl \\
|
||||
--total-timesteps 2000000
|
||||
"""
|
||||
|
||||
@@ -205,9 +205,9 @@ class KLPPO(PPO):
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--bc", default="training/runs/bc_v3",
|
||||
parser.add_argument("--bc", default="training/runs/bc",
|
||||
help="Directory containing the BC initialisation (policy.zip).")
|
||||
parser.add_argument("--out", default="training/runs/rl_v1",
|
||||
parser.add_argument("--out", default="training/runs/rl",
|
||||
help="Where to save the fine-tuned policy.")
|
||||
parser.add_argument("--total-timesteps", type=int, default=2_000_000)
|
||||
parser.add_argument("--n-envs", type=int, default=8)
|
||||
@@ -232,12 +232,23 @@ def main() -> None:
|
||||
help="SB3's per-batch KL early stop; safety belt.")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--device", default="cpu")
|
||||
parser.add_argument("--imitate-weight", type=float, default=None,
|
||||
help="Override env.W_IMITATE for this training "
|
||||
"run. Set to 0.0 to drop the Strömbom "
|
||||
"cosine-imitation reward — useful during "
|
||||
"PPO refinement where you want reward, "
|
||||
"not teacher imitation, to drive updates.")
|
||||
parser.add_argument("--time-weight", type=float, default=None,
|
||||
help="Override env.W_TIME. Default env value is "
|
||||
"0.0; setting e.g. -0.1 adds a small per-"
|
||||
"step penalty that explicitly rewards "
|
||||
"fast time-to-pen.")
|
||||
args = parser.parse_args()
|
||||
|
||||
bc_zip = Path(args.bc) / "policy.zip"
|
||||
if not bc_zip.exists():
|
||||
raise SystemExit(
|
||||
f"BC checkpoint not found at {bc_zip}. Train bc_v3 first with "
|
||||
f"BC checkpoint not found at {bc_zip}. Train bc first with "
|
||||
f"`python -m training.bc_pretrain`."
|
||||
)
|
||||
|
||||
@@ -259,6 +270,20 @@ def main() -> None:
|
||||
venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns)
|
||||
eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, frame_stack)])
|
||||
|
||||
# --- Apply reward-shaping overrides to every env instance ---
|
||||
def _broadcast(method: str, value):
|
||||
for v in (venv, eval_venv):
|
||||
try:
|
||||
v.env_method(method, value)
|
||||
except AttributeError:
|
||||
v.venv.env_method(method, value)
|
||||
if args.imitate_weight is not None:
|
||||
_broadcast("set_imitate_weight", args.imitate_weight)
|
||||
print(f"[rl] W_IMITATE overridden to {args.imitate_weight}")
|
||||
if args.time_weight is not None:
|
||||
_broadcast("set_time_weight", args.time_weight)
|
||||
print(f"[rl] W_TIME overridden to {args.time_weight}")
|
||||
|
||||
# --- Trainable policy: load BC weights, then bolt onto PPO ---
|
||||
# Trick: instantiate a PPO with the right env (so the policy
|
||||
# network is constructed at the correct obs/action shape), then
|
||||
|
||||
Reference in New Issue
Block a user