Checkpoint 4
This commit is contained in:
+18
-6
@@ -43,13 +43,15 @@ from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
from training.herding_env import HerdingEnv
|
||||
|
||||
|
||||
def build_model(net_arch_pi, net_arch_vf, log_std_init: float):
|
||||
"""Build a fresh SB3 PPO with the same architecture as train_ppo.
|
||||
def build_model(net_arch_pi, net_arch_vf, log_std_init: float,
|
||||
frame_stack: int = 1):
|
||||
"""Build a fresh SB3 PPO solely as a vehicle for the policy weights.
|
||||
|
||||
We only need the policy to load weights into; PPO's training-loop
|
||||
plumbing isn't used during BC.
|
||||
PPO's training-loop plumbing isn't used during BC. ``frame_stack``
|
||||
must match the demo file so the env's obs space agrees with the
|
||||
recorded obs shape.
|
||||
"""
|
||||
env = DummyVecEnv([lambda: HerdingEnv()])
|
||||
env = DummyVecEnv([lambda: HerdingEnv(frame_stack=frame_stack)])
|
||||
model = PPO(
|
||||
"MlpPolicy", env,
|
||||
policy_kwargs=dict(
|
||||
@@ -139,7 +141,17 @@ def main():
|
||||
# --- Build model ---
|
||||
net_arch_pi = [int(x) for x in args.net_arch.split(",")]
|
||||
net_arch_vf = net_arch_pi[:]
|
||||
model, _env = build_model(net_arch_pi, net_arch_vf, args.log_std_init)
|
||||
# Auto-detect frame stacking from the demo file so a stacked-obs
|
||||
# demo trains a stacked-obs policy without an extra CLI flag.
|
||||
obs_dim = obs.shape[1]
|
||||
from herding.obs import OBS_DIM as _SINGLE
|
||||
if obs_dim % _SINGLE != 0:
|
||||
raise RuntimeError(f"demo obs dim {obs_dim} is not a multiple of {_SINGLE}")
|
||||
frame_stack = obs_dim // _SINGLE
|
||||
if frame_stack > 1:
|
||||
print(f"[bc] inferred frame_stack={frame_stack} from demo obs dim {obs_dim}")
|
||||
model, _env = build_model(net_arch_pi, net_arch_vf, args.log_std_init,
|
||||
frame_stack=frame_stack)
|
||||
policy = model.policy.to(args.device)
|
||||
optimizer = optim.Adam(policy.parameters(), lr=args.lr)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user