"""Tests for CV split integrity: group leakage and subsample consistency."""

import unittest
from collections import Counter

from src.data.splits import apply_subsample, get_splits


class _DummyDataset:
    """Minimal dataset stand-in exposing only a ``samples`` list."""

    def __init__(self, samples):
        # Copy so that changes to the dataset never touch the caller's list.
        self.samples = list(samples)


def _mk_samples():
    """Build 20 ``(path, label)`` pairs: 5 identities x 4 source folders.

    Every identity appears once under each source folder with the same file
    name, so samples sharing a file name form one identity group of size 4.
    The label is 0 for the ``wiki`` source and 1 for every other source.
    """
    sources = ["wiki", "inpainting", "text2img", "insight"]
    persons = ["a", "b", "c", "d", "e"]
    return [
        (f"/data/{source}/id/{person}.jpg", 0 if source == "wiki" else 1)
        for person in persons
        for source in sources
    ]


def _basename(path):
    """Return the last ``/``-separated component of *path*."""
    return path.split("/")[-1]


def _stem(path):
    """Return the basename of *path* up to (not including) its first dot."""
    return _basename(path).split(".")[0]


def _stems(ds, indices):
    """Return the set of identity stems for the samples at *indices*."""
    return {_stem(ds.samples[i][0]) for i in indices}


class SplitGroupingTests(unittest.TestCase):
    """Checks that splitting and subsampling treat identity groups atomically."""

    def test_group_leakage_is_blocked_across_folds(self):
        """No identity may appear in more than one of train/val/test per fold."""
        ds = _DummyDataset(_mk_samples())
        cfg = {"cv_folds": 5, "seed": 42}

        folds_seen = 0
        for fold, (train_idx, val_idx, test_idx) in enumerate(get_splits(ds, cfg)):
            folds_seen += 1
            train_ids = _stems(ds, train_idx)
            val_ids = _stems(ds, val_idx)
            test_ids = _stems(ds, test_idx)

            # Comparing the intersection with an empty set (rather than
            # asserting isdisjoint) makes a failure show the leaking ids.
            self.assertEqual(
                train_ids & val_ids, set(), f"fold {fold}: train/val leakage"
            )
            self.assertEqual(
                train_ids & test_ids, set(), f"fold {fold}: train/test leakage"
            )
            self.assertEqual(
                val_ids & test_ids, set(), f"fold {fold}: val/test leakage"
            )

        # Without this guard the loop above would pass vacuously if
        # get_splits produced no folds at all.
        self.assertGreater(folds_seen, 0, "get_splits yielded no folds")

    def test_apply_subsample_keeps_full_identity_groups(self):
        """Subsampling must keep or drop all 4 variants of an identity together."""
        ds = _DummyDataset(_mk_samples())
        sampled, total = apply_subsample(ds, {"subsample": 0.4, "seed": 7})

        self.assertEqual(total, 20)
        self.assertGreater(sampled, 0)

        # NOTE(review): ds.samples is inspected after the call, so this test
        # presumably relies on apply_subsample updating the dataset in place
        # -- confirm against src.data.splits.
        by_basename = Counter(_basename(path) for path, _ in ds.samples)

        # all() over an empty mapping is True, so make sure something remains.
        self.assertTrue(by_basename, "no samples left after subsampling")

        # Each kept basename should include all 4 source variants.
        incomplete = {
            base: count for base, count in by_basename.items() if count != 4
        }
        self.assertEqual(incomplete, {}, "identity groups were split by subsampling")


if __name__ == "__main__":
    unittest.main()