52 lines
1.9 KiB
Python
52 lines
1.9 KiB
Python
"""
|
|
Tests for CV split integrity: group leakage and subsample consistency.
|
|
"""
|
|
import unittest
|
|
|
|
from src.data.splits import apply_subsample, get_splits
|
|
|
|
|
|
class _DummyDataset:
|
|
def __init__(self, samples):
|
|
self.samples = list(samples)
|
|
|
|
|
|
def _mk_samples():
|
|
samples = []
|
|
sources = ["wiki", "inpainting", "text2img", "insight"]
|
|
for person in ["a", "b", "c", "d", "e"]:
|
|
for source in sources:
|
|
label = 0 if source == "wiki" else 1
|
|
samples.append((f"/data/{source}/id/{person}.jpg", label))
|
|
return samples
|
|
|
|
|
|
class SplitGroupingTests(unittest.TestCase):
    """Check that CV splitting and subsampling never break up identity groups.

    A group is an image's file name without extension (the person id in the
    synthetic data from ``_mk_samples``). All source variants of one person
    must fall on the same side of every split, and subsampling must keep or
    drop them together.
    """

    @staticmethod
    def _group_key(path):
        """Return the identity group of *path*: its file name without extension."""
        return path.split("/")[-1].split(".")[0]

    @staticmethod
    def _source_of(path):
        """Return the source directory of a ``/data/<source>/id/<name>`` path."""
        return path.split("/")[-3]

    def _groups_of(self, ds, indices):
        """Return the set of identity groups covered by *indices* into ``ds.samples``."""
        return {self._group_key(ds.samples[i][0]) for i in indices}

    def test_group_leakage_is_blocked_across_folds(self):
        ds = _DummyDataset(_mk_samples())
        cfg = {"cv_folds": 5, "seed": 42}

        folds = list(get_splits(ds, cfg))
        # Without this guard the disjointness checks pass vacuously when no
        # folds are produced at all.
        self.assertGreater(len(folds), 0)

        for fold_no, (train_idx, val_idx, test_idx) in enumerate(folds):
            with self.subTest(fold=fold_no):
                train = self._groups_of(ds, train_idx)
                val = self._groups_of(ds, val_idx)
                test = self._groups_of(ds, test_idx)
                # Comparing intersections to the empty set reports the
                # offending groups on failure, unlike assertTrue(isdisjoint).
                self.assertSetEqual(train & val, set())
                self.assertSetEqual(train & test, set())
                self.assertSetEqual(val & test, set())

    def test_apply_subsample_keeps_full_identity_groups(self):
        all_samples = _mk_samples()
        ds = _DummyDataset(all_samples)
        # Every identity in the full data has exactly one image per source.
        expected_sources = {self._source_of(path) for path, _ in all_samples}

        sampled, total = apply_subsample(ds, {"subsample": 0.4, "seed": 7})

        self.assertEqual(total, len(all_samples))
        self.assertGreater(sampled, 0)
        # The group check below reads ds.samples, so ds.samples must hold the
        # kept subset. Otherwise the check runs on the unsubsampled data and
        # passes trivially.
        self.assertEqual(len(ds.samples), sampled)

        sources_by_group = {}
        for path, _ in ds.samples:
            sources_by_group.setdefault(self._group_key(path), []).append(
                self._source_of(path)
            )

        # Each kept group must contain every source exactly once. A plain
        # count == 4 would also accept duplicates of a single source.
        for group, sources in sources_by_group.items():
            with self.subTest(group=group):
                self.assertCountEqual(sources, expected_sources)
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly (e.g. ``python <this_file>.py``);
    # verbosity=1 is unittest.main's default and -v on the CLI still overrides it.
    unittest.main(verbosity=1)
|