Checkpoint 5 - incomplete

This commit is contained in:
Johnny Fernandes
2026-05-11 10:35:39 +01:00
parent 6688325d89
commit b457155538
13 changed files with 174 additions and 74 deletions
+32 -7
View File
@@ -10,7 +10,7 @@ per-step reward signal does the rest.
Pipeline
--------
1. Load ``bc_v3`` weights into both the trainable policy and a frozen
1. Load ``bc`` weights into both the trainable policy and a frozen
reference ``ref_policy``.
2. Initialise the policy's log_std to a small fixed value (≈ 1.5)
and disable its gradient — exploration noise stays small so PPO
@@ -19,14 +19,14 @@ Pipeline
each minibatch.
4. Train for ~13 M timesteps with a low LR (5e-5).
Output: ``runs/rl_v1/policy.zip`` — same SB3 format as bc_v3, loadable
Output: ``runs/rl/policy.zip`` — same SB3 format as bc, loadable
by the dog controller's ``HERDING_MODE=rl`` path.
Usage::
python -m training.train_ppo \\
--bc training/runs/bc_v3 \\
--out training/runs/rl_v1 \\
--bc training/runs/bc \\
--out training/runs/rl \\
--total-timesteps 2000000
"""
@@ -205,9 +205,9 @@ class KLPPO(PPO):
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--bc", default="training/runs/bc_v3",
parser.add_argument("--bc", default="training/runs/bc",
help="Directory containing the BC initialisation (policy.zip).")
parser.add_argument("--out", default="training/runs/rl_v1",
parser.add_argument("--out", default="training/runs/rl",
help="Where to save the fine-tuned policy.")
parser.add_argument("--total-timesteps", type=int, default=2_000_000)
parser.add_argument("--n-envs", type=int, default=8)
@@ -232,12 +232,23 @@ def main() -> None:
help="SB3's per-batch KL early stop; safety belt.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--device", default="cpu")
parser.add_argument("--imitate-weight", type=float, default=None,
help="Override env.W_IMITATE for this training "
"run. Set to 0.0 to drop the Strömbom "
"cosine-imitation reward — useful during "
"PPO refinement where you want reward, "
"not teacher imitation, to drive updates.")
parser.add_argument("--time-weight", type=float, default=None,
help="Override env.W_TIME. Default env value is "
"0.0; setting e.g. -0.1 adds a small per-"
"step penalty that explicitly rewards "
"fast time-to-pen.")
args = parser.parse_args()
bc_zip = Path(args.bc) / "policy.zip"
if not bc_zip.exists():
raise SystemExit(
f"BC checkpoint not found at {bc_zip}. Train bc_v3 first with "
f"BC checkpoint not found at {bc_zip}. Train bc first with "
f"`python -m training.bc_pretrain`."
)
@@ -259,6 +270,20 @@ def main() -> None:
venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns)
eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, frame_stack)])
# --- Apply reward-shaping overrides to every env instance ---
def _broadcast(method: str, value):
for v in (venv, eval_venv):
try:
v.env_method(method, value)
except AttributeError:
v.venv.env_method(method, value)
if args.imitate_weight is not None:
_broadcast("set_imitate_weight", args.imitate_weight)
print(f"[rl] W_IMITATE overridden to {args.imitate_weight}")
if args.time_weight is not None:
_broadcast("set_time_weight", args.time_weight)
print(f"[rl] W_TIME overridden to {args.time_weight}")
# --- Trainable policy: load BC weights, then bolt onto PPO ---
# Trick: instantiate a PPO with the right env (so the policy
# network is constructed at the correct obs/action shape), then