Clean state

This commit is contained in:
Johnny Fernandes
2026-04-30 01:25:39 +01:00
commit bb3dfb92d5
266 changed files with 37043 additions and 0 deletions
+4
View File
@@ -0,0 +1,4 @@
from src.data.dataset import DFFDataset, PathDataset, SOURCES, get_source_name
from src.data.splits import TransformSubset, apply_subsample, build_transforms, get_splits
# Names re-exported from the dataset and splits modules imported above.
__all__ = [
    "DFFDataset",
    "PathDataset",
    "SOURCES",
    "TransformSubset",
    "apply_subsample",
    "build_transforms",
    "get_source_name",
    "get_splits",
]
+82
View File
@@ -0,0 +1,82 @@
from collections import Counter
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
# Source-name -> class-label map: one real source ("wiki") and three fake
# sources. Label 0 means real, label 1 means fake.
# The same identity basename appears in every source, so splitting must be
# done at the identity level to prevent leakage (see splits.py), e.g.:
#   data_dir/SOURCE/identity/BASENAME.jpg
#   -> the same BASENAME from every SOURCE goes into the same split.
_REAL_LABEL = 0
_FAKE_LABEL = 1
SOURCES = {
    "wiki": _REAL_LABEL,
    "inpainting": _FAKE_LABEL,
    "text2img": _FAKE_LABEL,
    "insight": _FAKE_LABEL,
}
# Returns the source directory name for a path laid out as
# data_dir/source/identity/image.jpg, i.e. the name of the directory two
# levels above the image file.
def get_source_name(path: Path) -> str:
    identity_dir = Path(path).parent
    source_dir = identity_dir.parent
    return source_dir.name
# Walks data_dir/source/identity/*.jpg and collects (path, label) pairs
class DFFDataset(Dataset):
    """Image dataset built from ``data_dir/<source>/<identity>/*.jpg``.

    Each entry of ``self.samples`` is an ``(image_path, label)`` pair whose
    label comes from the module-level ``SOURCES`` mapping. Collection order is
    deterministic: sources in the order given, then identity directories and
    image files in sorted order. Only the immediate sub-directories of each
    source are scanned, and only files matching ``*.jpg`` are collected.

    Args:
        data_dir: Dataset root containing one sub-directory per source.
        sources: Iterable of source names to load; defaults to every key of
            ``SOURCES``.
        transform: Optional callable applied to each loaded RGB image.

    Raises:
        FileNotFoundError: If ``data_dir`` or a requested source directory
            does not exist.
        KeyError: If a requested source is not a key of ``SOURCES``.
    """

    def __init__(self, data_dir, sources=None, transform=None):
        self.transform = transform
        self.samples = []

        data_dir = Path(data_dir)
        if not data_dir.exists():
            raise FileNotFoundError(
                f"Dataset root not found: {data_dir}. Expected a directory containing "
                "wiki/, inpainting/, text2img/, and insight/."
            )

        if sources is None:
            sources = list(SOURCES.keys())

        for source in sources:
            # Still a KeyError (as a plain SOURCES[source] lookup would be),
            # but with a message that says what went wrong.
            if source not in SOURCES:
                raise KeyError(
                    f"Unknown source {source!r}; expected one of {sorted(SOURCES)}."
                )
            label = SOURCES[source]

            source_dir = data_dir / source
            if not source_dir.exists():
                raise FileNotFoundError(
                    f"Missing source directory: {source_dir}. Check `data_dir` in the config."
                )

            # Sorting both levels keeps sample indices stable across runs.
            for subdir in sorted(source_dir.iterdir()):
                if not subdir.is_dir():
                    continue
                self.samples.extend(
                    (img_path, label) for img_path in sorted(subdir.glob("*.jpg"))
                )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; image is RGB."""
        path, label = self.samples[idx]
        # The context manager closes the file handle even if decoding fails;
        # convert() returns a separate image that stays valid afterwards.
        with Image.open(path) as img:
            image = img.convert("RGB")
        # Explicit None check: a callable transform that happens to be falsy
        # (e.g. defines __len__ returning 0) must still be applied.
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    # Useful for quickly verifying class balance before training
    def label_counts(self):
        """Return a Counter mapping each label to its number of samples."""
        return Counter(label for _, label in self.samples)
# Wraps a list of prediction records as a Dataset for re-scoring via DataLoader.
class PathDataset(Dataset):
    """Dataset over prediction records, yielding ``(image, label, idx)``.

    Args:
        records: Indexable collection of mappings; each record is read via
            its ``"path"`` and ``"label"`` keys.
        image_size: Forwarded to ``get_transforms(train=False, ...)`` to
            build the transform applied to every image.
        preprocess: Optional callable applied to the RGB image before the
            transform.
    """

    def __init__(self, records, image_size, preprocess=None):
        # Function-scope import kept as in the original
        # (NOTE(review): presumably avoids an import cycle - confirm).
        from src.preprocessing import get_transforms

        self.records = records
        self.transform = get_transforms(train=False, image_size=image_size)
        self.preprocess = preprocess

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        """Load record ``idx`` and return ``(transformed_image, label, idx)``."""
        record = self.records[idx]
        # The context manager closes the file handle even if decoding fails;
        # convert() returns a separate image that stays valid afterwards.
        with Image.open(record["path"]) as img:
            image = img.convert("RGB")
        if self.preprocess is not None:
            image = self.preprocess(image)
        # idx is returned so callers can map outputs back to their record.
        return self.transform(image), record["label"], idx
+116
View File
@@ -0,0 +1,116 @@
import random
from typing import Callable, List, Tuple
from torch.utils.data import Dataset, Subset
from src.data.dataset import get_source_name
from src.preprocessing import get_transforms
from src.utils import create_group_kfold_splits, get_basename
# Lives at module level so that DataLoader workers (num_workers > 0) can
# serialize it safely.
class TransformSubset(Dataset):
    """View over ``subset`` that applies ``transform`` to each item's image."""

    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        # Each underlying item is an (image, label) pair; only the image is
        # passed through the transform.
        image, target = self.subset[idx]
        transformed = self.transform(image)
        return transformed, target
# ── Splitting ──────────────────────────────────────────────────────────────
# Builds grouped CV fold indices from config; this is the only split strategy used.
def get_splits(
    raw_dataset,
    cfg,
) -> List[Tuple[List[int], List[int], List[int]]]:
    """Return per-fold ``(train, val, test)`` index lists for ``raw_dataset``.

    Folds come from ``create_group_kfold_splits`` over ``raw_dataset.samples``
    using ``cfg["cv_folds"]`` (default 5) and ``cfg["seed"]`` (default 42).

    Optional source holdout: when ``cfg["train_sources"]`` and/or
    ``cfg["eval_sources"]`` is set, train/val indices are restricted to
    samples from ``train_sources`` and test indices to samples from
    ``eval_sources``. Whichever of the two is unset defaults to every source
    present in the dataset. Either value may be a single source name or an
    iterable of names.

    Raises:
        ValueError: If a requested source does not occur in the dataset.
    """
    splits = create_group_kfold_splits(
        raw_dataset.samples,
        n_splits=cfg.get("cv_folds", 5),
        seed=cfg.get("seed", 42),
    )

    train_sources = cfg.get("train_sources")
    eval_sources = cfg.get("eval_sources")
    if not (train_sources or eval_sources):
        return splits

    # A bare string would otherwise be exploded into a set of characters by
    # set(...) below and then rejected with a misleading "unknown" error.
    if isinstance(train_sources, str):
        train_sources = [train_sources]
    if isinstance(eval_sources, str):
        eval_sources = [eval_sources]

    # Resolve each sample's source once instead of once per index per fold.
    sample_sources = [get_source_name(path) for path, _ in raw_dataset.samples]
    all_sources = set(sample_sources)

    ts = set(train_sources or all_sources)
    es = set(eval_sources or all_sources)
    unknown = (ts | es) - all_sources
    if unknown:
        raise ValueError(f"Unknown sources requested: {sorted(unknown)}")

    return [
        (
            [i for i in tr if sample_sources[i] in ts],
            [i for i in val if sample_sources[i] in ts],
            [i for i in te if sample_sources[i] in es],
        )
        for tr, val, te in splits
    ]
# Deterministic subsampling shared by training and reevaluation.
def apply_subsample(raw_dataset, cfg) -> tuple[int, int] | None:
subsample = cfg.get("subsample", 1.0)
if subsample >= 1.0:
return None
total = len(raw_dataset.samples)
if total == 0:
return 0, 0
# Subsample at basename-group level to preserve identity grouping guarantees.
group_to_indices = {}
for idx, (path, _) in enumerate(raw_dataset.samples):
group = get_basename(str(path))
group_to_indices.setdefault(group, []).append(idx)
groups = list(group_to_indices.keys())
n_groups = len(groups)
target_groups = max(1, int(n_groups * subsample))
rng = random.Random(cfg.get("seed", 42))
rng.shuffle(groups)
keep_groups = set(groups[:target_groups])
keep_indices = [
idx
for group in keep_groups
for idx in group_to_indices[group]
]
keep_indices.sort()
raw_dataset.samples = [raw_dataset.samples[i] for i in keep_indices]
return len(raw_dataset.samples), total
# Returns a factory that turns a list of indices into a transformed dataset.
# `augment` controls the stochastic augmentations (flip, jitter, etc.):
#   augment=False -> NO_AUGMENT preset (square-crop + resize + normalize still run)
#   augment=None  -> pipeline defaults
#   augment=dict  -> override specific params
# Face cropping is handled upstream via data_dir swap, not here.
def build_transforms(raw_dataset, cfg, augment=None) -> Callable:
    size = cfg["image_size"]

    augment_spec = augment
    if augment_spec is False:
        # Imported only when the no-augmentation preset is requested.
        from src.preprocessing.pipeline import DFFImagePipeline

        augment_spec = DFFImagePipeline.NO_AUGMENT

    def transform_builder(indices, train=True, normalize_mean=None, normalize_std=None):
        # Wrap the selected indices first, then build the matching transform.
        view = Subset(raw_dataset, indices)
        pipeline = get_transforms(
            train=train,
            image_size=size,
            augment=augment_spec,
            normalize_mean=normalize_mean,
            normalize_std=normalize_std,
        )
        return TransformSubset(view, pipeline)

    return transform_builder