Checkpoint 4
This commit is contained in:
@@ -21,21 +21,47 @@ from pathlib import Path
|
||||
|
||||
class PolicyHandle:
    """Wrap a loaded PPO policy + VecNormalize so the controller can call
    ``predict(obs)`` without thinking about either.

    Frame stacking is auto-detected from the policy's expected obs dim:
    if it's a multiple of the single-frame ``OBS_DIM``, the handle keeps
    a deque of the last K frames and concatenates them on each predict.
    Call ``reset()`` at an episode boundary so frames from the previous
    episode are not stacked into the next one.
    """

    def __init__(self, model, vecnorm):
        """Store the policy and normalizer and work out the stack depth.

        Args:
            model: Loaded policy. ``model.observation_space.shape[0]`` is
                read here and ``model.predict(obs, deterministic=True)``
                is called from ``predict``.
            vecnorm: Object exposing ``normalize_obs(batch)``, or ``None``
                to skip observation normalization.
        """
        from collections import deque

        # Lazy import to avoid forcing herding/* into the import path
        # when SB3 isn't being used.
        from herding.obs import OBS_DIM

        self.model = model
        self.vecnorm = vecnorm

        policy_dim = int(model.observation_space.shape[0])
        depth, leftover = divmod(policy_dim, OBS_DIM)
        # Stack only when the policy input is a whole number (>= 1) of
        # single frames; anything else is treated as an unstacked policy.
        self.frame_stack = depth if leftover == 0 and depth >= 1 else 1
        # maxlen makes the deque drop the oldest frame automatically.
        self._buffer = deque(maxlen=self.frame_stack)
        self._single_dim = OBS_DIM

    def reset(self):
        """Forget all buffered frames (call when a new episode starts)."""
        self._buffer.clear()

    def predict(self, obs):
        """Return the deterministic action for one observation.

        Args:
            obs: Array-like observation. If its flattened length equals the
                single-frame dim it is pushed through the frame stack (when
                stacking is active); any other length is assumed to be
                already stacked by the caller and is forwarded unchanged.

        Returns:
            The first row of the action batch returned by
            ``model.predict``.
        """
        import numpy as np

        single = np.asarray(obs, dtype=np.float32).reshape(-1)

        if single.shape[0] != self._single_dim:
            # Caller already passed a stacked obs — use as-is.
            stacked = single
        elif self.frame_stack > 1:
            # ``asarray``/``reshape`` may return a view of the caller's
            # array; copy so later in-place updates by the caller cannot
            # rewrite the frames we keep as history.
            frame = single.copy()
            if not self._buffer:
                # First frame: pad the history with repeats of it.
                self._buffer.extend([frame] * self.frame_stack)
            else:
                self._buffer.append(frame)
            stacked = np.concatenate(list(self._buffer), axis=0)
        else:
            stacked = single

        # VecNormalize expects a batched obs of shape (n_envs, obs_dim).
        obs_b = stacked.reshape(1, -1)
        if self.vecnorm is not None:
            obs_b = self.vecnorm.normalize_obs(obs_b)
        action, _ = self.model.predict(obs_b, deterministic=True)
        return action[0]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user