Clean state
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
from pathlib import Path
import sys

# Full suite:
# python -m unittest discover -s classifier/tests -p "test_*.py" -t classifier

# Make `from src...` imports resolve when tests are run from the repo root.
ROOT = Path(__file__).resolve().parents[1]
_root_entry = str(ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
||||
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
Tests for config loading: shared.json inheritance and extends merging.
|
||||
"""
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from src.utils.config import load_config
|
||||
|
||||
|
||||
class ConfigMergeTests(unittest.TestCase):
    """Exercise load_config's layering: shared.json < base.json < exp.json."""

    def test_shared_and_extends_merge(self):
        """Values from shared.json and the `extends` chain merge, with the
        most specific file winning on conflicts."""
        with tempfile.TemporaryDirectory() as tmp:
            tmp_root = Path(tmp)
            phase_dir = tmp_root / "configs" / "phaseX"
            phase_dir.mkdir(parents=True)

            # Three-layer fixture: repo-wide shared.json, a phase-level
            # base.json, and an experiment config extending the base.
            payloads = {
                tmp_root / "configs" / "shared.json": {
                    "batch_size": 32,
                    "lr": 1e-4,
                    "augment": {"hflip_p": 0.5, "blur_p": 0.2},
                },
                phase_dir / "base.json": {
                    "epochs": 10,
                    "augment": {"hflip_p": 0.1},
                },
                phase_dir / "exp.json": {
                    "extends": "base.json",
                    "epochs": 15,
                    "augment": {"blur_p": 0.0},
                },
            }
            for path, payload in payloads.items():
                path.write_text(json.dumps(payload))

            merged = load_config(phase_dir / "exp.json")

            # Untouched shared.json value survives the merge.
            self.assertEqual(merged["batch_size"], 32)
            # exp.json overrides base.json's epochs.
            self.assertEqual(merged["epochs"], 15)
            # Nested dicts merge key-by-key: base overrides shared's hflip_p...
            self.assertEqual(merged["augment"]["hflip_p"], 0.1)
            # ...while exp zeroes out blur_p.
            self.assertEqual(merged["augment"]["blur_p"], 0.0)
# Allow running this test module directly (`python path/to/this_file.py`).
if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
Tests for binary_metrics edge cases: single-class inputs return null AUC/F1.
|
||||
"""
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from src.evaluation.metrics import binary_metrics
|
||||
|
||||
|
||||
class OneClassMetricTests(unittest.TestCase):
    """binary_metrics must degrade gracefully when labels hold one class."""

    def test_one_class_returns_none_for_auc_and_f1(self):
        """All-positive labels: AUC/F1 are undefined and reported as None."""
        scores = torch.tensor([0.1, -0.2, 0.3], dtype=torch.float32)
        all_positive = torch.ones(3, dtype=torch.float32)

        result = binary_metrics(scores, all_positive)

        # AUC-ROC needs both classes to rank; F1 is likewise undefined here.
        self.assertIsNone(result["auc_roc"])
        self.assertIsNone(result["f1"])
        # Accuracy is still computable and must be reported.
        self.assertIn("accuracy", result)
# Allow running this test module directly (`python path/to/this_file.py`).
if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
Tests for CV split integrity: group leakage and subsample consistency.
|
||||
"""
|
||||
import unittest
|
||||
|
||||
from src.data.splits import apply_subsample, get_splits
|
||||
|
||||
|
||||
class _DummyDataset:
|
||||
def __init__(self, samples):
|
||||
self.samples = list(samples)
|
||||
|
||||
|
||||
def _mk_samples():
|
||||
samples = []
|
||||
sources = ["wiki", "inpainting", "text2img", "insight"]
|
||||
for person in ["a", "b", "c", "d", "e"]:
|
||||
for source in sources:
|
||||
label = 0 if source == "wiki" else 1
|
||||
samples.append((f"/data/{source}/id/{person}.jpg", label))
|
||||
return samples
|
||||
|
||||
|
||||
class SplitGroupingTests(unittest.TestCase):
    """Integrity checks for CV splits built over identity-grouped samples."""

    def test_group_leakage_is_blocked_across_folds(self):
        """No identity (path basename stem) may land in two partitions."""
        dataset = _DummyDataset(_mk_samples())
        cfg = {"cv_folds": 5, "seed": 42}

        def stems(indices):
            # "/data/<src>/id/<person>.jpg" -> "<person>"
            return {
                dataset.samples[i][0].split("/")[-1].split(".")[0]
                for i in indices
            }

        for train_idx, val_idx, test_idx in get_splits(dataset, cfg):
            partitions = [stems(train_idx), stems(val_idx), stems(test_idx)]
            # Every pair of partitions must share zero identities.
            for left in range(len(partitions)):
                for right in range(left + 1, len(partitions)):
                    self.assertTrue(
                        partitions[left].isdisjoint(partitions[right])
                    )

    def test_apply_subsample_keeps_full_identity_groups(self):
        """Subsampling drops whole identities, never individual variants."""
        dataset = _DummyDataset(_mk_samples())
        sampled, total = apply_subsample(dataset, {"subsample": 0.4, "seed": 7})

        # 5 identities x 4 sources exist before subsampling.
        self.assertEqual(total, 20)
        self.assertGreater(sampled, 0)

        # NOTE(review): this tallies dataset.samples, which only reflects the
        # subsample if apply_subsample mutates the dataset in place — confirm
        # against the implementation.
        variants_per_identity = {}
        for path, _label in dataset.samples:
            key = path.split("/")[-1]
            variants_per_identity[key] = variants_per_identity.get(key, 0) + 1

        # Each kept basename should include all 4 source variants.
        self.assertTrue(all(n == 4 for n in variants_per_identity.values()))
# Allow running this test module directly (`python path/to/this_file.py`).
if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
Tests for preprocessing transforms: eval pipeline is deterministic and test-safe.
|
||||
"""
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from src.preprocessing.pipeline import get_transforms
|
||||
|
||||
|
||||
class TransformTests(unittest.TestCase):
    """The eval-time preprocessing pipeline must be deterministic."""

    def test_eval_transform_is_deterministic(self):
        """Two passes over one image yield identical (3, 64, 64) tensors."""
        rng = np.random.RandomState(0)
        pixels = (rng.rand(128, 128, 3) * 255).astype(np.uint8)
        image = Image.fromarray(pixels, mode="RGB")

        transform = get_transforms(train=False, image_size=64)
        first = transform(image)
        second = transform(image)

        # Channel-first output at the requested size.
        self.assertEqual(tuple(first.shape), (3, 64, 64))
        # Eval pipeline must contain no randomness.
        self.assertTrue(np.allclose(first.numpy(), second.numpy()))
# Allow running this test module directly (`python path/to/this_file.py`).
if __name__ == "__main__":
    unittest.main()
Reference in New Issue
Block a user