Files
DRL_PROJ/classifier/tests/test_splits.py
T
Johnny Fernandes bb3dfb92d5 Clean state
2026-04-30 01:25:39 +01:00

52 lines
1.9 KiB
Python

"""
Tests for CV split integrity: group leakage and subsample consistency.
"""
import unittest
from src.data.splits import apply_subsample, get_splits
class _DummyDataset:
def __init__(self, samples):
self.samples = list(samples)
def _mk_samples():
samples = []
sources = ["wiki", "inpainting", "text2img", "insight"]
for person in ["a", "b", "c", "d", "e"]:
for source in sources:
label = 0 if source == "wiki" else 1
samples.append((f"/data/{source}/id/{person}.jpg", label))
return samples
def _identity_of(path):
    """Return the identity key of a sample path: its file name without extension.

    For example ``"/data/wiki/id/a.jpg"`` maps to ``"a"``. In the synthetic
    data every source variant of one person shares this key, so it is the
    grouping unit these tests check for.
    """
    return path.rsplit("/", 1)[-1].split(".")[0]


class SplitGroupingTests(unittest.TestCase):
    """Check that split/subsample helpers treat each identity as one group."""

    def test_group_leakage_is_blocked_across_folds(self):
        """No identity may appear in more than one of train/val/test in any fold."""
        ds = _DummyDataset(_mk_samples())
        cfg = {"cv_folds": 5, "seed": 42}
        folds_seen = 0
        for fold, (train_idx, val_idx, test_idx) in enumerate(get_splits(ds, cfg)):
            folds_seen += 1
            with self.subTest(fold=fold):
                train_ids = {_identity_of(ds.samples[i][0]) for i in train_idx}
                val_ids = {_identity_of(ds.samples[i][0]) for i in val_idx}
                test_ids = {_identity_of(ds.samples[i][0]) for i in test_idx}
                self.assertTrue(train_ids.isdisjoint(val_ids))
                self.assertTrue(train_ids.isdisjoint(test_ids))
                self.assertTrue(val_ids.isdisjoint(test_ids))
        # Without this, an empty generator from get_splits would pass silently.
        self.assertGreater(folds_seen, 0)

    def test_apply_subsample_keeps_full_identity_groups(self):
        """Subsampling keeps or drops whole identities, never partial groups."""
        ds = _DummyDataset(_mk_samples())
        sampled, total = apply_subsample(ds, {"subsample": 0.4, "seed": 7})
        self.assertEqual(total, 20)
        self.assertGreater(sampled, 0)
        self.assertLessEqual(sampled, total)
        # The checks below inspect ds.samples, which only makes sense if
        # apply_subsample shrank it in place; an untouched dataset would
        # trivially satisfy the per-group assertion.
        self.assertEqual(len(ds.samples), sampled)
        counts = {}
        for path, _ in ds.samples:
            ident = _identity_of(path)
            counts[ident] = counts.get(ident, 0) + 1
        # Each kept identity should include all 4 source variants.
        for ident, count in counts.items():
            self.assertEqual(count, 4, msg=f"identity {ident!r} was split")
if __name__ == "__main__":
    # Run this module's tests when executed directly as a script.
    unittest.main(module="__main__")