Clean state
This commit is contained in:
@@ -0,0 +1,4 @@
|
||||
from src.data.dataset import DFFDataset, PathDataset, SOURCES, get_source_name
|
||||
from src.data.splits import TransformSubset, apply_subsample, build_transforms, get_splits
|
||||
|
||||
__all__ = ["DFFDataset", "PathDataset", "SOURCES", "TransformSubset", "apply_subsample", "build_transforms", "get_source_name", "get_splits"]
|
||||
@@ -0,0 +1,82 @@
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
# One real source (wiki) and three fake sources; 0 = real, 1 = fake
|
||||
# The same identity basename appears in every source - splitting must happen
|
||||
# at the identity level to prevent leakage (see splits.py), e.g:
|
||||
# data_dir/SOURCE/identity/BASENAME.jpg -> same BASENAME for all SOURCEs go into same split
|
||||
SOURCES = {
|
||||
"wiki": 0,
|
||||
"inpainting": 1,
|
||||
"text2img": 1,
|
||||
"insight": 1,
|
||||
}
|
||||
|
||||
|
||||
# Extracts source name from path assuming data_dir/source/identity/image.jpg layout
|
||||
def get_source_name(path: Path) -> str:
|
||||
return Path(path).parent.parent.name
|
||||
|
||||
|
||||
# Walks data_dir/source/identity/*.jpg and collects (path, label) pairs
|
||||
class DFFDataset(Dataset):
|
||||
def __init__(self, data_dir, sources=None, transform=None):
|
||||
self.transform = transform
|
||||
self.samples = []
|
||||
|
||||
data_dir = Path(data_dir)
|
||||
if not data_dir.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Dataset root not found: {data_dir}. Expected a directory containing "
|
||||
"wiki/, inpainting/, text2img/, and insight/."
|
||||
)
|
||||
if sources is None:
|
||||
sources = list(SOURCES.keys())
|
||||
|
||||
for source in sources:
|
||||
label = SOURCES[source]
|
||||
source_dir = data_dir / source
|
||||
if not source_dir.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Missing source directory: {source_dir}. Check `data_dir` in the config."
|
||||
)
|
||||
for subdir in sorted(source_dir.iterdir()):
|
||||
if subdir.is_dir():
|
||||
for img_path in sorted(subdir.glob("*.jpg")):
|
||||
self.samples.append((img_path, label))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
path, label = self.samples[idx]
|
||||
image = Image.open(path).convert("RGB")
|
||||
if self.transform:
|
||||
image = self.transform(image)
|
||||
return image, label
|
||||
|
||||
# Useful for quickly verifying class balance before training
|
||||
def label_counts(self):
|
||||
return Counter(label for _, label in self.samples)
|
||||
|
||||
|
||||
# Wraps a list of prediction records as a Dataset for re-scoring via DataLoader.
|
||||
class PathDataset(Dataset):
|
||||
def __init__(self, records, image_size, preprocess=None):
|
||||
from src.preprocessing import get_transforms
|
||||
self.records = records
|
||||
self.transform = get_transforms(train=False, image_size=image_size)
|
||||
self.preprocess = preprocess
|
||||
|
||||
def __len__(self):
|
||||
return len(self.records)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
record = self.records[idx]
|
||||
image = Image.open(record["path"]).convert("RGB")
|
||||
if self.preprocess is not None:
|
||||
image = self.preprocess(image)
|
||||
return self.transform(image), record["label"], idx
|
||||
@@ -0,0 +1,116 @@
|
||||
import random
|
||||
from typing import Callable, List, Tuple
|
||||
|
||||
from torch.utils.data import Dataset, Subset
|
||||
|
||||
from src.data.dataset import get_source_name
|
||||
from src.preprocessing import get_transforms
|
||||
from src.utils import create_group_kfold_splits, get_basename
|
||||
|
||||
|
||||
# Defined at module level so DataLoader workers (num_workers > 0) can serialize it safely
|
||||
class TransformSubset(Dataset):
|
||||
def __init__(self, subset, transform):
|
||||
self.subset, self.transform = subset, transform
|
||||
|
||||
def __len__(self):
|
||||
return len(self.subset)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img, label = self.subset[idx]
|
||||
return self.transform(img), label
|
||||
|
||||
|
||||
# ── Splitting ──────────────────────────────────────────────────────────────
|
||||
|
||||
# Builds grouped CV fold indices from config; this is the only split strategy used.
|
||||
def get_splits(
|
||||
raw_dataset,
|
||||
cfg,
|
||||
) -> List[Tuple[List[int], List[int], List[int]]]:
|
||||
splits = create_group_kfold_splits(
|
||||
raw_dataset.samples,
|
||||
n_splits=cfg.get("cv_folds", 5),
|
||||
seed=cfg.get("seed", 42),
|
||||
)
|
||||
|
||||
# Optional source holdout: train/val from train_sources, test from eval_sources.
|
||||
train_sources = cfg.get("train_sources")
|
||||
eval_sources = cfg.get("eval_sources")
|
||||
if train_sources or eval_sources:
|
||||
all_sources = {get_source_name(path) for path, _ in raw_dataset.samples}
|
||||
ts = set(train_sources or all_sources)
|
||||
es = set(eval_sources or all_sources)
|
||||
unknown = (ts | es) - all_sources
|
||||
if unknown:
|
||||
raise ValueError(f"Unknown sources requested: {sorted(unknown)}")
|
||||
splits = [
|
||||
(
|
||||
[i for i in tr if get_source_name(raw_dataset.samples[i][0]) in ts],
|
||||
[i for i in val if get_source_name(raw_dataset.samples[i][0]) in ts],
|
||||
[i for i in te if get_source_name(raw_dataset.samples[i][0]) in es],
|
||||
)
|
||||
for tr, val, te in splits
|
||||
]
|
||||
return splits
|
||||
|
||||
|
||||
# Deterministic subsampling shared by training and reevaluation.
|
||||
def apply_subsample(raw_dataset, cfg) -> tuple[int, int] | None:
|
||||
subsample = cfg.get("subsample", 1.0)
|
||||
if subsample >= 1.0:
|
||||
return None
|
||||
|
||||
total = len(raw_dataset.samples)
|
||||
if total == 0:
|
||||
return 0, 0
|
||||
|
||||
# Subsample at basename-group level to preserve identity grouping guarantees.
|
||||
group_to_indices = {}
|
||||
for idx, (path, _) in enumerate(raw_dataset.samples):
|
||||
group = get_basename(str(path))
|
||||
group_to_indices.setdefault(group, []).append(idx)
|
||||
|
||||
groups = list(group_to_indices.keys())
|
||||
n_groups = len(groups)
|
||||
target_groups = max(1, int(n_groups * subsample))
|
||||
rng = random.Random(cfg.get("seed", 42))
|
||||
rng.shuffle(groups)
|
||||
keep_groups = set(groups[:target_groups])
|
||||
|
||||
keep_indices = [
|
||||
idx
|
||||
for group in keep_groups
|
||||
for idx in group_to_indices[group]
|
||||
]
|
||||
keep_indices.sort()
|
||||
raw_dataset.samples = [raw_dataset.samples[i] for i in keep_indices]
|
||||
return len(raw_dataset.samples), total
|
||||
|
||||
|
||||
# Controls stochastic augmentations (flip, jitter, etc.)
|
||||
# augment=False -> NO_AUGMENT preset (square-crop + resize + normalize still run)
|
||||
# augment=None, pipeline defaults, augment=dict -> override specific params
|
||||
# Face cropping is handled upstream via data_dir swap, not here
|
||||
def build_transforms(raw_dataset, cfg, augment=None) -> Callable:
|
||||
image_size = cfg["image_size"]
|
||||
|
||||
if augment is False:
|
||||
from src.preprocessing.pipeline import DFFImagePipeline
|
||||
|
||||
augment = DFFImagePipeline.NO_AUGMENT
|
||||
|
||||
def transform_builder(indices, train=True, normalize_mean=None, normalize_std=None):
|
||||
subset = Subset(raw_dataset, indices)
|
||||
return TransformSubset(
|
||||
subset,
|
||||
get_transforms(
|
||||
train=train,
|
||||
image_size=image_size,
|
||||
augment=augment,
|
||||
normalize_mean=normalize_mean,
|
||||
normalize_std=normalize_std,
|
||||
),
|
||||
)
|
||||
return transform_builder
|
||||
|
||||
Reference in New Issue
Block a user