Clean state
This commit is contained in:
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
Tests for CV split integrity: group leakage and subsample consistency.
|
||||
"""
|
||||
import unittest
|
||||
|
||||
from src.data.splits import apply_subsample, get_splits
|
||||
|
||||
|
||||
class _DummyDataset:
|
||||
def __init__(self, samples):
|
||||
self.samples = list(samples)
|
||||
|
||||
|
||||
def _mk_samples():
|
||||
samples = []
|
||||
sources = ["wiki", "inpainting", "text2img", "insight"]
|
||||
for person in ["a", "b", "c", "d", "e"]:
|
||||
for source in sources:
|
||||
label = 0 if source == "wiki" else 1
|
||||
samples.append((f"/data/{source}/id/{person}.jpg", label))
|
||||
return samples
|
||||
|
||||
|
||||
class SplitGroupingTests(unittest.TestCase):
|
||||
def test_group_leakage_is_blocked_across_folds(self):
|
||||
ds = _DummyDataset(_mk_samples())
|
||||
cfg = {"cv_folds": 5, "seed": 42}
|
||||
for train_idx, val_idx, test_idx in get_splits(ds, cfg):
|
||||
train_bases = {ds.samples[i][0].split("/")[-1].split(".")[0] for i in train_idx}
|
||||
val_bases = {ds.samples[i][0].split("/")[-1].split(".")[0] for i in val_idx}
|
||||
test_bases = {ds.samples[i][0].split("/")[-1].split(".")[0] for i in test_idx}
|
||||
self.assertTrue(train_bases.isdisjoint(val_bases))
|
||||
self.assertTrue(train_bases.isdisjoint(test_bases))
|
||||
self.assertTrue(val_bases.isdisjoint(test_bases))
|
||||
|
||||
def test_apply_subsample_keeps_full_identity_groups(self):
|
||||
ds = _DummyDataset(_mk_samples())
|
||||
sampled, total = apply_subsample(ds, {"subsample": 0.4, "seed": 7})
|
||||
self.assertEqual(total, 20)
|
||||
self.assertGreater(sampled, 0)
|
||||
by_basename = {}
|
||||
for path, _ in ds.samples:
|
||||
base = path.split("/")[-1]
|
||||
by_basename.setdefault(base, 0)
|
||||
by_basename[base] += 1
|
||||
# Each kept basename should include all 4 source variants.
|
||||
self.assertTrue(all(count == 4 for count in by_basename.values()))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user