Checkpoint 4

This commit is contained in:
Johnny Fernandes
2026-05-11 00:42:52 +01:00
parent 2a6db038df
commit 6688325d89
26 changed files with 2018 additions and 503 deletions
+34 -8
View File
@@ -21,21 +21,47 @@ from pathlib import Path
class PolicyHandle:
    """Wrap a loaded PPO policy + VecNormalize so the controller can call
    ``predict(obs)`` without thinking about either.

    Frame stacking is auto-detected from the policy's expected obs dim:
    if it's a multiple of the single-frame ``OBS_DIM``, the handle keeps
    a deque of the last K frames and concatenates them on each predict.
    """

    def __init__(self, model, vecnorm):
        """Store the policy and derive the frame-stack depth.

        Args:
            model: Loaded policy; must expose ``observation_space.shape``
                and ``predict(obs, deterministic=...)``.
            vecnorm: Object exposing ``normalize_obs(batch)``, or ``None``
                to skip observation normalization.
        """
        from collections import deque

        # Lazy import to avoid forcing herding/* into the import path
        # when SB3 isn't being used.
        from herding.obs import OBS_DIM

        self.model = model
        self.vecnorm = vecnorm

        policy_dim = int(model.observation_space.shape[0])
        if policy_dim % OBS_DIM == 0 and policy_dim // OBS_DIM >= 1:
            self.frame_stack = policy_dim // OBS_DIM
        else:
            # Policy dim is not a whole number of frames: no stacking.
            self.frame_stack = 1
        # Bounded history: appending past ``frame_stack`` entries drops the
        # oldest frame automatically.
        self._buffer = deque(maxlen=self.frame_stack)
        self._single_dim = OBS_DIM

    def reset(self):
        """Drop the buffered frames so the next single-frame ``predict``
        re-seeds the history from its own observation."""
        self._buffer.clear()

    def predict(self, obs):
        """Return the deterministic action for one observation.

        Args:
            obs: Array-like observation. If its flattened length equals the
                single-frame dim it is pushed into the frame history (when
                stacking is active); any other length is treated as already
                stacked and bypasses the history.

        Returns:
            The first row of the action batch returned by ``model.predict``.
        """
        import numpy as np

        single = np.asarray(obs, dtype=np.float32).reshape(-1)
        if single.shape[0] != self._single_dim:
            # Caller already passed a stacked obs — use as-is.
            stacked = single
        elif self.frame_stack > 1:
            # Copy before storing: ``single`` may be a view of a caller-owned
            # array that is overwritten in place on the next step.
            frame = single.copy()
            if not self._buffer:
                # Empty history: seed every slot with the first frame.
                self._buffer.extend(frame for _ in range(self.frame_stack))
            else:
                self._buffer.append(frame)
            stacked = np.concatenate(tuple(self._buffer), axis=0)
        else:
            stacked = single

        # VecNormalize expects a batched obs of shape (n_envs, obs_dim).
        obs_b = stacked.reshape(1, -1)
        if self.vecnorm is not None:
            obs_b = self.vecnorm.normalize_obs(obs_b)
        action, _ = self.model.predict(obs_b, deterministic=True)
        return action[0]