Clean state
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
from src.utils.config import load_config
|
||||
|
||||
__all__ = ["load_config"]
|
||||
@@ -0,0 +1,58 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
# Resolves the extends chain first, then overlays shared.json underneath so
|
||||
# experiment-level keys always win over shared defaults.
|
||||
def load_config(config_path: str, shared_path: Optional[str] = None) -> Dict[str, Any]:
|
||||
config_path = Path(config_path)
|
||||
cfg = _load_extends(config_path)
|
||||
|
||||
if shared_path is None:
|
||||
shared_path = config_path.parent.parent / "shared.json"
|
||||
else:
|
||||
shared_path = Path(shared_path)
|
||||
|
||||
if shared_path.exists():
|
||||
with open(shared_path) as f:
|
||||
shared_cfg = json.load(f)
|
||||
cfg = _deep_merge(shared_cfg, cfg)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
# Pops the "extends" key and recursively merges the parent config underneath;
|
||||
# the seen set catches circular inheritance before it recurses infinitely.
|
||||
def _load_extends(config_path: Path, seen: Optional[set[Path]] = None) -> Dict[str, Any]:
|
||||
if seen is None:
|
||||
seen = set()
|
||||
resolved_path = config_path.resolve()
|
||||
if resolved_path in seen:
|
||||
chain = " -> ".join(str(p) for p in [*seen, resolved_path])
|
||||
raise ValueError(f"Circular config inheritance detected: {chain}")
|
||||
seen.add(resolved_path)
|
||||
|
||||
with open(config_path) as f:
|
||||
cfg = json.load(f)
|
||||
|
||||
base_ref = cfg.pop("extends", None)
|
||||
if not base_ref:
|
||||
seen.remove(resolved_path)
|
||||
return cfg
|
||||
|
||||
base_path = (config_path.parent / base_ref).resolve()
|
||||
base_cfg = _load_extends(base_path, seen=seen)
|
||||
seen.remove(resolved_path)
|
||||
return _deep_merge(base_cfg, cfg)
|
||||
|
||||
|
||||
# Override always wins; nested dicts are merged recursively rather than replaced.
|
||||
def _deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
|
||||
result = base.copy()
|
||||
for key, value in override.items():
|
||||
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
|
||||
result[key] = _deep_merge(result[key], value)
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
Reference in New Issue
Block a user