Checkpoint 4

This commit is contained in:
Johnny Fernandes
2026-05-11 00:42:52 +01:00
parent 2a6db038df
commit 6688325d89
26 changed files with 2018 additions and 503 deletions
+18 -6
View File
@@ -43,13 +43,15 @@ from stable_baselines3.common.vec_env import DummyVecEnv
from training.herding_env import HerdingEnv
def build_model(net_arch_pi, net_arch_vf, log_std_init: float):
"""Build a fresh SB3 PPO with the same architecture as train_ppo.
def build_model(net_arch_pi, net_arch_vf, log_std_init: float,
frame_stack: int = 1):
"""Build a fresh SB3 PPO solely as a vehicle for the policy weights.
We only need the policy to load weights into; PPO's training-loop
plumbing isn't used during BC.
PPO's training-loop plumbing isn't used during BC. ``frame_stack``
must match the demo file so the env's obs space agrees with the
recorded obs shape.
"""
env = DummyVecEnv([lambda: HerdingEnv()])
env = DummyVecEnv([lambda: HerdingEnv(frame_stack=frame_stack)])
model = PPO(
"MlpPolicy", env,
policy_kwargs=dict(
@@ -139,7 +141,17 @@ def main():
# --- Build model ---
net_arch_pi = [int(x) for x in args.net_arch.split(",")]
net_arch_vf = net_arch_pi[:]
model, _env = build_model(net_arch_pi, net_arch_vf, args.log_std_init)
# Auto-detect frame stacking from the demo file so a stacked-obs
# demo trains a stacked-obs policy without an extra CLI flag.
obs_dim = obs.shape[1]
from herding.obs import OBS_DIM as _SINGLE
if obs_dim % _SINGLE != 0:
raise RuntimeError(f"demo obs dim {obs_dim} is not a multiple of {_SINGLE}")
frame_stack = obs_dim // _SINGLE
if frame_stack > 1:
print(f"[bc] inferred frame_stack={frame_stack} from demo obs dim {obs_dim}")
model, _env = build_model(net_arch_pi, net_arch_vf, args.log_std_init,
frame_stack=frame_stack)
policy = model.policy.to(args.device)
optimizer = optim.Adam(policy.parameters(), lr=args.lr)