Checkpoint 8

This commit is contained in:
Johnny Fernandes
2026-05-12 22:41:03 +01:00
parent a01a5c9cef
commit 5c2ee4bba5
31 changed files with 3189 additions and 380 deletions
+6 -12
View File
@@ -26,14 +26,15 @@ import math
import numpy as np
from herding.world.geometry import (
FIELD_X, FIELD_Y, PEN_ENTRY, MAX_SHEEP,
PEN_ENTRY, MAX_SHEEP, distance_to_wall,
)
OBS_DIM = 32
def build_obs(dog_xy, dog_heading, sheep_xy_list, sheep_penned_list,
n_max: int = MAX_SHEEP) -> np.ndarray:
n_max: int = MAX_SHEEP,
n_expected: int | None = None) -> np.ndarray:
"""Assemble the dog policy's observation vector.
Parameters
@@ -43,6 +44,7 @@ def build_obs(dog_xy, dog_heading, sheep_xy_list, sheep_penned_list,
sheep_xy_list : iterable of (x, y) for ALL known sheep
sheep_penned_list : parallel iterable of bool — True if sheep is penned
n_max : maximum supported flock size used for the count normaliser
n_expected : unused, kept for API compatibility.
"""
dog_x, dog_y = dog_xy
obs = np.zeros(OBS_DIM, dtype=np.float32)
@@ -89,16 +91,8 @@ def build_obs(dog_xy, dog_heading, sheep_xy_list, sheep_penned_list,
obs[15] = float(rel[far_idx, 0]) / 15.0
obs[16] = float(rel[far_idx, 1]) / 15.0
min_sheep_wall = min(
float(np.min(arr[:, 0] - FIELD_X[0])),
float(np.min(FIELD_X[1] - arr[:, 0])),
float(np.min(arr[:, 1] - FIELD_Y[0])),
float(np.min(FIELD_Y[1] - arr[:, 1])),
)
min_dog_wall = min(
dog_x - FIELD_X[0], FIELD_X[1] - dog_x,
dog_y - FIELD_Y[0], FIELD_Y[1] - dog_y,
)
min_sheep_wall = float(min(distance_to_wall(sx, sy) for sx, sy in active))
min_dog_wall = distance_to_wall(dog_x, dog_y)
obs[17] = min_sheep_wall / 15.0
obs[18] = float(min_dog_wall) / 15.0
obs[19] = n / n_max