Clean state

This commit is contained in:
Johnny Fernandes
2026-04-30 01:25:39 +01:00
commit bb3dfb92d5
266 changed files with 37043 additions and 0 deletions
+3
View File
@@ -0,0 +1,3 @@
from src.utils.config import load_config
__all__ = ["load_config"]
+58
View File
@@ -0,0 +1,58 @@
import json
from pathlib import Path
from typing import Any, Dict, Optional
# Resolves the extends chain first, then overlays shared.json underneath so
# experiment-level keys always win over shared defaults.
def load_config(config_path: str, shared_path: Optional[str] = None) -> Dict[str, Any]:
config_path = Path(config_path)
cfg = _load_extends(config_path)
if shared_path is None:
shared_path = config_path.parent.parent / "shared.json"
else:
shared_path = Path(shared_path)
if shared_path.exists():
with open(shared_path) as f:
shared_cfg = json.load(f)
cfg = _deep_merge(shared_cfg, cfg)
return cfg
# Pops the "extends" key and recursively merges the parent config underneath;
# the seen set catches circular inheritance before it recurses infinitely.
def _load_extends(config_path: Path, seen: Optional[set[Path]] = None) -> Dict[str, Any]:
if seen is None:
seen = set()
resolved_path = config_path.resolve()
if resolved_path in seen:
chain = " -> ".join(str(p) for p in [*seen, resolved_path])
raise ValueError(f"Circular config inheritance detected: {chain}")
seen.add(resolved_path)
with open(config_path) as f:
cfg = json.load(f)
base_ref = cfg.pop("extends", None)
if not base_ref:
seen.remove(resolved_path)
return cfg
base_path = (config_path.parent / base_ref).resolve()
base_cfg = _load_extends(base_path, seen=seen)
seen.remove(resolved_path)
return _deep_merge(base_cfg, cfg)
# Override always wins; nested dicts are merged recursively rather than replaced.
def _deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
result = base.copy()
for key, value in override.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
result[key] = _deep_merge(result[key], value)
else:
result[key] = value
return result