Checkpoint 8
This commit is contained in:
+19
-3
@@ -35,14 +35,15 @@ from training.herding_env import HerdingEnv
|
||||
|
||||
|
||||
def build_model(net_arch_pi, net_arch_vf, log_std_init: float,
|
||||
frame_stack: int = 1):
|
||||
frame_stack: int = 1, drive_mode: str = "differential"):
|
||||
"""Build a fresh SB3 PPO solely as a vehicle for the policy weights.
|
||||
|
||||
PPO's training-loop plumbing isn't used during BC. ``frame_stack``
|
||||
must match the demo file so the env's obs space agrees with the
|
||||
recorded obs shape.
|
||||
"""
|
||||
env = DummyVecEnv([lambda: HerdingEnv(frame_stack=frame_stack)])
|
||||
env = DummyVecEnv([lambda: HerdingEnv(frame_stack=frame_stack,
|
||||
drive_mode=drive_mode)])
|
||||
model = PPO(
|
||||
"MlpPolicy", env,
|
||||
policy_kwargs=dict(
|
||||
@@ -83,6 +84,10 @@ def main():
|
||||
"term; balances against MSE.")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--device", default="cpu")
|
||||
parser.add_argument("--drive-mode", default=None,
|
||||
choices=["differential", "mecanum"],
|
||||
help="Drive mode. If not set, inferred from "
|
||||
"demo action dimension (2→differential, 3→mecanum).")
|
||||
args = parser.parse_args()
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
@@ -130,8 +135,19 @@ def main():
|
||||
frame_stack = obs_dim // _SINGLE
|
||||
if frame_stack > 1:
|
||||
print(f"[bc] inferred frame_stack={frame_stack} from demo obs dim {obs_dim}")
|
||||
|
||||
# Infer drive mode from action dimension if not explicitly set.
|
||||
action_dim = actions.shape[1]
|
||||
if args.drive_mode is not None:
|
||||
drive_mode = args.drive_mode
|
||||
elif action_dim == 3:
|
||||
drive_mode = "mecanum"
|
||||
else:
|
||||
drive_mode = "differential"
|
||||
print(f"[bc] drive_mode={drive_mode} (action_dim={action_dim})")
|
||||
|
||||
model, _env = build_model(net_arch_pi, net_arch_vf, args.log_std_init,
|
||||
frame_stack=frame_stack)
|
||||
frame_stack=frame_stack, drive_mode=drive_mode)
|
||||
policy = model.policy.to(args.device)
|
||||
optimizer = optim.Adam(policy.parameters(), lr=args.lr)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user