Checkpoint 8
This commit is contained in:
+65
-4
@@ -20,8 +20,26 @@ Usage::
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Early CLI pre-parse for --world so geometry is configured before any
|
||||
# herding.* / training.* import binds geometry constants. Matches the
|
||||
# pattern used by training.bc.collect and training.eval.
|
||||
_pre_argv = [a for a in os.sys.argv[1:]]
|
||||
_pre_world = None
|
||||
for i, a in enumerate(_pre_argv):
|
||||
if a == "--world" and i + 1 < len(_pre_argv):
|
||||
_pre_world = _pre_argv[i + 1]
|
||||
break
|
||||
if a.startswith("--world="):
|
||||
_pre_world = a.split("=", 1)[1]
|
||||
break
|
||||
if _pre_world is not None:
|
||||
from herding.world.geometry import configure as _geo_configure
|
||||
_geo_configure(_pre_world)
|
||||
os.environ["HERDING_WORLD"] = _pre_world
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
import torch.nn.functional as F
|
||||
@@ -38,9 +56,14 @@ from training.herding_env import HerdingEnv
|
||||
# Env factory
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
def _make_env(rank: int, seed: int, frame_stack: int):
|
||||
def _make_env(rank: int, seed: int, frame_stack: int,
|
||||
drive_mode: str = "differential",
|
||||
difficulty: float = 1.0,
|
||||
max_n_sheep: int = 10):
|
||||
def _thunk():
|
||||
env = HerdingEnv(seed=seed + rank, frame_stack=frame_stack)
|
||||
env = HerdingEnv(seed=seed + rank, frame_stack=frame_stack,
|
||||
drive_mode=drive_mode, difficulty=difficulty,
|
||||
max_n_sheep=max_n_sheep)
|
||||
env = Monitor(env, info_keywords=("is_success", "n_sheep", "n_penned"))
|
||||
return env
|
||||
return _thunk
|
||||
@@ -198,13 +221,34 @@ def main() -> None:
|
||||
help="SB3 per-batch KL early-stop guard.")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--device", default="cpu")
|
||||
parser.add_argument("--drive-mode", default=None,
|
||||
choices=["differential", "mecanum"],
|
||||
help="Drive mode. If not set, inferred from "
|
||||
"BC action dimension (2→differential, 3→mecanum).")
|
||||
parser.add_argument("--imitate-weight", type=float, default=None,
|
||||
help="Override env.W_IMITATE (e.g. 0.0 to drop "
|
||||
"Strömbom imitation during fine-tune).")
|
||||
parser.add_argument("--time-weight", type=float, default=None,
|
||||
help="Override env.W_TIME (e.g. -0.1 for a "
|
||||
"per-step time penalty).")
|
||||
parser.add_argument("--difficulty", type=float, default=1.0,
|
||||
help="HerdingEnv difficulty for PPO rollouts. "
|
||||
"Must match eval (1.0) to avoid train/eval "
|
||||
"distribution mismatch.")
|
||||
parser.add_argument("--max-n-sheep", type=int, default=10,
|
||||
help="Upper bound on flock size sampled each reset.")
|
||||
parser.add_argument("--world", default=None,
|
||||
choices=["field", "field_round"],
|
||||
help="World shape. If not set, uses HERDING_WORLD "
|
||||
"env var or defaults to 'field'.")
|
||||
args = parser.parse_args()
|
||||
# --world was already honoured in the early pre-parse above; here we
|
||||
# just sanity-check that the final argparse view agrees.
|
||||
if args.world is not None:
|
||||
from herding.world.geometry import FIELD_SHAPE as _CURRENT_SHAPE
|
||||
if args.world != _CURRENT_SHAPE:
|
||||
print(f"[rl] WARNING: --world={args.world} but geometry is "
|
||||
f"'{_CURRENT_SHAPE}'. File a bug.")
|
||||
|
||||
bc_zip = Path(args.bc) / "policy.zip"
|
||||
if not bc_zip.exists():
|
||||
@@ -226,9 +270,26 @@ def main() -> None:
|
||||
frame_stack = obs_dim // OBS_DIM
|
||||
print(f"[rl] BC obs dim {obs_dim} → frame_stack={frame_stack}")
|
||||
|
||||
env_fns = [_make_env(i, args.seed, frame_stack) for i in range(args.n_envs)]
|
||||
# Infer drive mode from BC action dim if not explicitly set.
|
||||
bc_action_dim = int(ref_only.action_space.shape[0])
|
||||
if args.drive_mode is not None:
|
||||
drive_mode = args.drive_mode
|
||||
elif bc_action_dim == 3:
|
||||
drive_mode = "mecanum"
|
||||
else:
|
||||
drive_mode = "differential"
|
||||
print(f"[rl] drive_mode={drive_mode} (BC action_dim={bc_action_dim})")
|
||||
|
||||
env_fns = [_make_env(i, args.seed, frame_stack, drive_mode,
|
||||
difficulty=args.difficulty,
|
||||
max_n_sheep=args.max_n_sheep)
|
||||
for i in range(args.n_envs)]
|
||||
venv = SubprocVecEnv(env_fns) if args.n_envs > 1 else DummyVecEnv(env_fns)
|
||||
eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, frame_stack)])
|
||||
eval_venv = DummyVecEnv([_make_env(99, args.seed + 999, frame_stack,
|
||||
drive_mode,
|
||||
difficulty=args.difficulty,
|
||||
max_n_sheep=args.max_n_sheep)])
|
||||
print(f"[rl] difficulty={args.difficulty} max_n_sheep={args.max_n_sheep}")
|
||||
|
||||
# Reward-shaping overrides (broadcast to every env instance).
|
||||
def _broadcast(method: str, value):
|
||||
|
||||
Reference in New Issue
Block a user