43 lines
1.3 KiB
Python
43 lines
1.3 KiB
Python
"""
|
|
Tests for config loading: shared.json inheritance and extends merging.
|
|
"""
|
|
import json
|
|
import tempfile
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
from src.utils.config import load_config
|
|
|
|
|
|
class ConfigMergeTests(unittest.TestCase):
    """Run ``load_config`` against a throwaway config tree written to disk."""

    @staticmethod
    def _write_json(path, payload):
        """Serialize *payload* as JSON into the file at *path*."""
        # Explicit encoding so the fixture bytes do not depend on the
        # platform's default locale.
        path.write_text(json.dumps(payload), encoding="utf-8")

    def test_shared_and_extends_merge(self):
        """Check the merged values for a config that extends a sibling file.

        Files created under a temporary directory:

        * ``configs/shared.json`` -- ``batch_size``, ``lr`` and two
          ``augment`` probabilities.
        * ``configs/phaseX/base.json`` -- ``epochs`` and an ``hflip_p``
          override.
        * ``configs/phaseX/exp.json`` -- extends ``base.json`` and
          overrides ``epochs`` and ``blur_p``.
        """
        with tempfile.TemporaryDirectory() as td:
            root = Path(td)
            cfg_dir = root / "configs" / "phaseX"
            cfg_dir.mkdir(parents=True)

            self._write_json(root / "configs" / "shared.json", {
                "batch_size": 32,
                "lr": 1e-4,
                "augment": {"hflip_p": 0.5, "blur_p": 0.2},
            })
            self._write_json(cfg_dir / "base.json", {
                "epochs": 10,
                "augment": {"hflip_p": 0.1},
            })
            self._write_json(cfg_dir / "exp.json", {
                "extends": "base.json",
                "epochs": 15,
                "augment": {"blur_p": 0.0},
            })

            # Load while the temporary directory still exists.
            cfg = load_config(cfg_dir / "exp.json")

            # Only shared.json defines batch_size.
            self.assertEqual(cfg["batch_size"], 32)
            # exp.json (15) is expected to win over base.json (10).
            self.assertEqual(cfg["epochs"], 15)
            # base.json's hflip_p must survive exp.json overriding only
            # blur_p inside the same nested "augment" mapping.
            self.assertEqual(cfg["augment"]["hflip_p"], 0.1)
            self.assertEqual(cfg["augment"]["blur_p"], 0.0)
|
|
|
|
|
|
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|