Clean state

Johnny Fernandes
2026-04-30 01:25:39 +01:00
commit bb3dfb92d5
266 changed files with 37043 additions and 0 deletions
+4
@@ -0,0 +1,4 @@
from src.data.dataset import DFFDataset, PathDataset, SOURCES, get_source_name
from src.data.splits import TransformSubset, apply_subsample, build_transforms, get_splits
__all__ = ["DFFDataset", "PathDataset", "SOURCES", "TransformSubset", "apply_subsample", "build_transforms", "get_source_name", "get_splits"]
+82
@@ -0,0 +1,82 @@
from collections import Counter
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
# One real source (wiki) and three fake sources; 0 = real, 1 = fake
# The same identity basename appears in every source, so splitting must happen
# at the identity level to prevent leakage (see splits.py), e.g.:
# data_dir/SOURCE/identity/BASENAME.jpg -> the same BASENAME across all SOURCEs lands in the same split
SOURCES = {
"wiki": 0,
"inpainting": 1,
"text2img": 1,
"insight": 1,
}
# Extracts source name from path assuming data_dir/source/identity/image.jpg layout
def get_source_name(path: Path) -> str:
return Path(path).parent.parent.name
# Walks data_dir/source/identity/*.jpg and collects (path, label) pairs
class DFFDataset(Dataset):
def __init__(self, data_dir, sources=None, transform=None):
self.transform = transform
self.samples = []
data_dir = Path(data_dir)
if not data_dir.exists():
raise FileNotFoundError(
f"Dataset root not found: {data_dir}. Expected a directory containing "
"wiki/, inpainting/, text2img/, and insight/."
)
if sources is None:
sources = list(SOURCES.keys())
for source in sources:
            if source not in SOURCES:
                raise ValueError(f"Unknown source: {source!r}. Known sources: {sorted(SOURCES)}")
            label = SOURCES[source]
source_dir = data_dir / source
if not source_dir.exists():
raise FileNotFoundError(
f"Missing source directory: {source_dir}. Check `data_dir` in the config."
)
for subdir in sorted(source_dir.iterdir()):
if subdir.is_dir():
for img_path in sorted(subdir.glob("*.jpg")):
self.samples.append((img_path, label))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
path, label = self.samples[idx]
image = Image.open(path).convert("RGB")
if self.transform:
image = self.transform(image)
return image, label
# Useful for quickly verifying class balance before training
def label_counts(self):
return Counter(label for _, label in self.samples)
# Wraps a list of prediction records as a Dataset for re-scoring via DataLoader.
class PathDataset(Dataset):
def __init__(self, records, image_size, preprocess=None):
        # Local import: defers the src.preprocessing dependency until a PathDataset is built
        from src.preprocessing import get_transforms
self.records = records
self.transform = get_transforms(train=False, image_size=image_size)
self.preprocess = preprocess
def __len__(self):
return len(self.records)
def __getitem__(self, idx):
record = self.records[idx]
image = Image.open(record["path"]).convert("RGB")
if self.preprocess is not None:
image = self.preprocess(image)
return self.transform(image), record["label"], idx
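# A minimal usage sketch; "data/faces" is a hypothetical root laid out as
# data_dir/SOURCE/identity/BASENAME.jpg, and the guard keeps it from running on import.
if __name__ == "__main__":
    ds = DFFDataset("data/faces")
    print(f"{len(ds)} samples, class balance: {ds.label_counts()}")
    path, label = ds.samples[0]
    # get_source_name recovers the SOURCE segment from the fixed layout
    print(get_source_name(path), "->", "fake" if label else "real")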
+116
@@ -0,0 +1,116 @@
import random
from typing import Callable, List, Tuple
from torch.utils.data import Dataset, Subset
from src.data.dataset import get_source_name
from src.preprocessing import get_transforms
from src.utils import create_group_kfold_splits, get_basename
# Defined at module level so DataLoader workers (num_workers > 0) can serialize it safely
class TransformSubset(Dataset):
def __init__(self, subset, transform):
self.subset, self.transform = subset, transform
def __len__(self):
return len(self.subset)
def __getitem__(self, idx):
img, label = self.subset[idx]
return self.transform(img), label
# ── Splitting ──────────────────────────────────────────────────────────────
# Builds grouped CV fold indices from config; this is the only split strategy used.
def get_splits(
raw_dataset,
cfg,
) -> List[Tuple[List[int], List[int], List[int]]]:
splits = create_group_kfold_splits(
raw_dataset.samples,
n_splits=cfg.get("cv_folds", 5),
seed=cfg.get("seed", 42),
)
# Optional source holdout: train/val from train_sources, test from eval_sources.
train_sources = cfg.get("train_sources")
eval_sources = cfg.get("eval_sources")
if train_sources or eval_sources:
all_sources = {get_source_name(path) for path, _ in raw_dataset.samples}
ts = set(train_sources or all_sources)
es = set(eval_sources or all_sources)
unknown = (ts | es) - all_sources
if unknown:
raise ValueError(f"Unknown sources requested: {sorted(unknown)}")
splits = [
(
[i for i in tr if get_source_name(raw_dataset.samples[i][0]) in ts],
[i for i in val if get_source_name(raw_dataset.samples[i][0]) in ts],
[i for i in te if get_source_name(raw_dataset.samples[i][0]) in es],
)
for tr, val, te in splits
]
return splits
# Deterministic subsampling shared by training and reevaluation.
def apply_subsample(raw_dataset, cfg) -> tuple[int, int] | None:
subsample = cfg.get("subsample", 1.0)
if subsample >= 1.0:
return None
total = len(raw_dataset.samples)
if total == 0:
return 0, 0
# Subsample at basename-group level to preserve identity grouping guarantees.
group_to_indices = {}
for idx, (path, _) in enumerate(raw_dataset.samples):
group = get_basename(str(path))
group_to_indices.setdefault(group, []).append(idx)
groups = list(group_to_indices.keys())
n_groups = len(groups)
target_groups = max(1, int(n_groups * subsample))
rng = random.Random(cfg.get("seed", 42))
rng.shuffle(groups)
keep_groups = set(groups[:target_groups])
keep_indices = [
idx
for group in keep_groups
for idx in group_to_indices[group]
]
keep_indices.sort()
raw_dataset.samples = [raw_dataset.samples[i] for i in keep_indices]
return len(raw_dataset.samples), total
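# Worked example (illustrative numbers): with 4,000 samples spanning 1,000
# basename groups and subsample=0.25, target_groups = max(1, int(1000 * 0.25)) = 250,
# so roughly 1,000 samples survive and the function returns (1000, 4000).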
# Controls stochastic augmentations (flip, jitter, etc.)
# augment=False -> NO_AUGMENT preset (square-crop + resize + normalize still run)
# augment=None -> pipeline defaults; augment=dict -> override specific params
# Face cropping is handled upstream via data_dir swap, not here
def build_transforms(raw_dataset, cfg, augment=None) -> Callable:
image_size = cfg["image_size"]
if augment is False:
from src.preprocessing.pipeline import DFFImagePipeline
augment = DFFImagePipeline.NO_AUGMENT
def transform_builder(indices, train=True, normalize_mean=None, normalize_std=None):
subset = Subset(raw_dataset, indices)
return TransformSubset(
subset,
get_transforms(
train=train,
image_size=image_size,
augment=augment,
normalize_mean=normalize_mean,
normalize_std=normalize_std,
),
)
return transform_builder
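# End-to-end wiring sketch, guarded so the module stays import-safe; the
# "data/faces" path and cfg values are hypothetical, not project defaults.
if __name__ == "__main__":
    from src.data.dataset import DFFDataset
    cfg = {"cv_folds": 5, "seed": 42, "subsample": 1.0, "image_size": 224}
    raw = DFFDataset("data/faces")
    apply_subsample(raw, cfg)  # no-op here: subsample >= 1.0 returns None
    folds = get_splits(raw, cfg)
    builder = build_transforms(raw, cfg)
    train_ds = builder(folds[0][0], train=True)  # TransformSubset over fold-0 train indices
    print(f"{len(folds)} folds; {len(train_ds)} train samples in fold 0")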
+11
@@ -0,0 +1,11 @@
from src.evaluation.evaluate import (
predict_rows,
rescore_rows,
save_errors,
save_hists,
save_preds,
save_summary,
)
from src.evaluation.metrics import binary_metrics, calc_metrics, pair_metrics, source_metrics
__all__ = ["binary_metrics", "calc_metrics", "pair_metrics", "predict_rows", "rescore_rows", "save_errors", "save_hists", "save_preds", "save_summary", "source_metrics"]
+135
@@ -0,0 +1,135 @@
import csv
import json
from pathlib import Path
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from src.data import PathDataset, get_source_name
# ── Inference ──────────────────────────────────────────────────────────────
# Run model inference and return one prediction record per sample
# raw_dataset and indices supply path/source/label metadata, decoupling this
# function from the wrapper layout (Subset/TransformSubset) of `dataset`
def predict_rows(model, dataset, raw_dataset, indices, batch_size, device, *, num_workers=4):
loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
records = []
model.eval().to(device)
with torch.no_grad():
offset = 0
for images, labels in loader:
logits = model(images.to(device)).squeeze(1).cpu()
probs = torch.sigmoid(logits)
preds = (probs >= 0.5).long()
for i in range(len(labels)):
sample_idx = indices[offset + i]
path, label = raw_dataset.samples[sample_idx]
records.append({
"path": str(path),
"basename": Path(path).name,
"source": get_source_name(path),
"label": int(label),
"pred": int(preds[i].item()),
"prob_fake": float(probs[i].item()),
"logit": float(logits[i].item()),
})
offset += len(labels)
return records
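# Each record is one flat dict, e.g. (values illustrative):
#   {"path": ".../wiki/id42/0007.jpg", "basename": "0007.jpg", "source": "wiki",
#    "label": 0, "pred": 0, "prob_fake": 0.12, "logit": -1.99}
# sigmoid(-1.99) ≈ 0.12, so prob_fake and logit are mutually consistent.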
# Re-run model scoring for an existing list of records
def rescore_rows(
model, records, image_size, batch_size, device, preprocess=None, *, num_workers=4
):
dataset = PathDataset(records, image_size=image_size, preprocess=preprocess)
loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
outputs = []
model.eval().to(device)
with torch.no_grad():
for images, _, indices in loader:
logits = model(images.to(device)).squeeze(1).cpu()
probs = torch.sigmoid(logits)
preds = (probs >= 0.5).long()
for j in range(len(indices)):
base = dict(records[int(indices[j].item())])
base["prob_fake"] = float(probs[j].item())
base["pred"] = int(preds[j].item())
base["logit"] = float(logits[j].item())
outputs.append(base)
return outputs
# ── Export ─────────────────────────────────────────────────────────────────
# Writes all prediction records to CSV in a fixed column order
def save_preds(records, output_path):
output_path.parent.mkdir(parents=True, exist_ok=True)
fieldnames = ["path", "basename", "source", "label", "pred", "prob_fake", "logit"]
with open(output_path, "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(records)
# Saves the top-k highest-confidence false positives and false negatives for error analysis
def save_errors(records, output_path, top_k=32):
false_positives = sorted(
(r for r in records if r["label"] == 0 and r["pred"] == 1),
key=lambda r: r["prob_fake"],
reverse=True,
)[:top_k]
false_negatives = sorted(
(r for r in records if r["label"] == 1 and r["pred"] == 0),
key=lambda r: r["prob_fake"],
)[:top_k]
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
json.dump({
"top_false_positives": false_positives,
"top_false_negatives": false_negatives,
}, f, indent=2)
# Saves P(fake) histograms: one overall class comparison and one per source
def save_hists(records, output_dir):
output_dir.mkdir(parents=True, exist_ok=True)
fig, ax = plt.subplots(figsize=(8, 5))
real_probs = [r["prob_fake"] for r in records if r["label"] == 0]
fake_probs = [r["prob_fake"] for r in records if r["label"] == 1]
ax.hist(real_probs, bins=30, alpha=0.6, label="real", density=True)
ax.hist(fake_probs, bins=30, alpha=0.6, label="fake", density=True)
ax.set_xlabel("Predicted P(fake)")
ax.set_ylabel("Density")
ax.set_title("Confidence by class")
ax.legend()
fig.tight_layout()
fig.savefig(output_dir / "confidence_by_class.png", dpi=160)
plt.close(fig)
for source in sorted({r["source"] for r in records}):
fig, ax = plt.subplots(figsize=(8, 5))
source_probs = [r["prob_fake"] for r in records if r["source"] == source]
ax.hist(source_probs, bins=30, alpha=0.8)
ax.set_xlabel("Predicted P(fake)")
ax.set_ylabel("Count")
ax.set_title(f"Confidence distribution: {source}")
fig.tight_layout()
fig.savefig(output_dir / f"confidence_{source}.png", dpi=160)
plt.close(fig)
# Saves the aggregated results dict as a formatted JSON file
def save_summary(summary, output_path):
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
json.dump(summary, f, indent=2)
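# Artifact-layout sketch with a hypothetical output root; records would come
# from predict_rows or rescore_rows above. Guarded so nothing runs on import.
if __name__ == "__main__":
    records = []  # e.g. predict_rows(model, dataset, raw_dataset, indices, 32, "cpu")
    out = Path("outputs/eval/demo")
    save_preds(records, out / "preds.csv")     # CSV with the fixed column order
    save_errors(records, out / "errors.json")  # top-k false positives/negatives
    save_summary({"n_records": len(records)}, out / "summary.json")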
+70
@@ -0,0 +1,70 @@
import torch
import numpy as np
from sklearn.metrics import (
accuracy_score,
roc_auc_score,
f1_score,
confusion_matrix,
)
# AUC and F1 are undefined when only one class is present in the batch;
# returns None for those fields rather than raising
def binary_metrics(logits: torch.Tensor, labels: torch.Tensor) -> dict:
probs = torch.sigmoid(logits).numpy()
preds = (probs >= 0.5).astype(int)
y = labels.numpy().astype(int)
has_both_classes = len(np.unique(y)) > 1
auc_roc = float(roc_auc_score(y, probs)) if has_both_classes else None
f1 = float(f1_score(y, preds, zero_division=0)) if has_both_classes else None
return {
"accuracy": float(accuracy_score(y, preds)),
"auc_roc": auc_roc,
"f1": f1,
"confusion_matrix": confusion_matrix(y, preds, labels=[0, 1]),
}
# Converts per-sample records to tensors and delegates to binary_metrics
def calc_metrics(records):
logits = torch.tensor([r["logit"] for r in records], dtype=torch.float32)
labels = torch.tensor([r["label"] for r in records], dtype=torch.float32)
metrics = binary_metrics(logits, labels)
metrics["confusion_matrix"] = metrics["confusion_matrix"].tolist()
return metrics
# Per-source summaries; fake sources get detection_rate + pairwise_auc since AUC is undefined for a single class
def source_metrics(records, real_source="wiki"):
wiki_records = [r for r in records if r["source"] == real_source]
by_source = {}
for source in sorted({r["source"] for r in records}):
source_records = [r for r in records if r["source"] == source]
metrics = calc_metrics(source_records)
metrics["n"] = len(source_records)
labels = [r["label"] for r in source_records]
if len(set(labels)) == 1:
if labels[0] == 1: # all fake
metrics["detection_rate"] = metrics["accuracy"]
if wiki_records:
pair_m = calc_metrics(wiki_records + source_records)
metrics["pairwise_auc"] = pair_m["auc_roc"]
metrics["pairwise_f1"] = pair_m["f1"]
else: # all real (wiki)
metrics["false_alarm_rate"] = 1.0 - metrics["accuracy"]
by_source[source] = metrics
return by_source
# Real-vs-one-fake AUC/F1 per fake source - more interpretable than global AUC when class ratios vary
def pair_metrics(records, real_source="wiki"):
fake_sources = sorted({r["source"] for r in records if r["source"] != real_source})
pairwise = {}
for fake_source in fake_sources:
subset = [r for r in records if r["source"] in {real_source, fake_source}]
if subset:
pairwise[f"{real_source}_vs_{fake_source}"] = {
"sources": [real_source, fake_source],
"n": len(subset),
**calc_metrics(subset),
}
return pairwise
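# Self-contained sanity check on synthetic records (hypothetical values);
# runnable directly, no dataset required.
if __name__ == "__main__":
    records = [
        {"source": "wiki", "label": 0, "logit": -2.0},
        {"source": "wiki", "label": 0, "logit": -1.0},
        {"source": "text2img", "label": 1, "logit": 1.5},
        {"source": "insight", "label": 1, "logit": 0.5},
    ]
    print(calc_metrics(records))         # overall accuracy/AUC/F1 + confusion matrix
    print(pair_metrics(records).keys())  # wiki_vs_insight, wiki_vs_text2img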
+35
@@ -0,0 +1,35 @@
from pathlib import Path
from typing import Callable, Union
import torch
import torch.nn as nn
# Maps backbone name -> builder function; populated by each model module at import time
_REGISTRY: dict[str, Callable[[dict], nn.Module]] = {}
# Called by each model module to advertise its backbone(s) to get_model
def register(name: str, builder: Callable[[dict], nn.Module]) -> None:
_REGISTRY[name] = builder
# Instantiates the backbone requested in cfg["backbone"]
def get_model(cfg: dict) -> nn.Module:
backbone = cfg.get("backbone", "simple_cnn")
builder = _REGISTRY.get(backbone)
if builder is None:
available = ", ".join(sorted(_REGISTRY))
raise ValueError(f"Unknown backbone: {backbone!r}. Available: {available}")
return builder(cfg)
# Loads a saved state-dict into model in-place and returns it
def load_checkpoint(model: nn.Module, path: Union[Path, str], device) -> nn.Module:
model.load_state_dict(torch.load(path, map_location=device, weights_only=True))
return model
# Importing the modules triggers their register() calls
from src.models import simple_cnn, resnet, efficientnet # noqa: E402, F401
__all__ = ["get_model", "load_checkpoint", "register"]
+23
@@ -0,0 +1,23 @@
import torch.nn as nn
from torchvision import models
from src.models import register
# EfficientNet's classification head is a Sequential; [-1] targets the final Linear
def build(cfg: dict) -> nn.Module:
backbone = cfg.get("backbone", "efficientnet_b0")
pretrained = cfg.get("pretrained", True)
if backbone == "efficientnet_b0":
weights = models.EfficientNet_B0_Weights.DEFAULT if pretrained else None
model = models.efficientnet_b0(weights=weights)
else:
raise ValueError(f"Unsupported EfficientNet backbone: {backbone!r}. Supported: efficientnet_b0")
in_features = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(in_features, 1)
return model
register("efficientnet_b0", build)
+30
@@ -0,0 +1,30 @@
import torch.nn as nn
from torchvision import models
from src.models import register
# Loads pretrained ResNet and replaces the 1000-class head with a single logit for binary detection
def build(cfg: dict) -> nn.Module:
backbone = cfg.get("backbone", "resnet18")
pretrained = cfg.get("pretrained", True)
if backbone == "resnet18":
weights = models.ResNet18_Weights.DEFAULT if pretrained else None
model = models.resnet18(weights=weights)
elif backbone == "resnet34":
weights = models.ResNet34_Weights.DEFAULT if pretrained else None
model = models.resnet34(weights=weights)
elif backbone == "resnet50":
weights = models.ResNet50_Weights.DEFAULT if pretrained else None
model = models.resnet50(weights=weights)
else:
raise ValueError(f"Unsupported backbone: {backbone!r}")
model.fc = nn.Linear(model.fc.in_features, 1)
return model
register("resnet18", build)
register("resnet34", build)
register("resnet50", build)
+49
@@ -0,0 +1,49 @@
import torch.nn as nn
from src.models import register
# Named presets map cnn_preset config values to channel lists
CNN_PRESETS = {
"micro": [8, 16],
"small": [8, 16, 32],
"medium": [16, 32, 64, 64],
"large": [32, 64, 128, 256],
}
# Each entry in channels builds a Conv -> BN -> ReLU -> Pool block;
# the last block pools to 1×1 so the head is resolution-independent
class SimpleCNN(nn.Module):
def __init__(self, channels=None, in_channels=3, dropout=0.0):
super().__init__()
if channels is None:
channels = CNN_PRESETS["medium"]
layers = []
prev = in_channels
for i, ch in enumerate(channels):
layers += [nn.Conv2d(prev, ch, 3, padding=1), nn.BatchNorm2d(ch), nn.ReLU()]
if i < len(channels) - 1:
layers.append(nn.MaxPool2d(2))
else:
layers.append(nn.AdaptiveAvgPool2d(1))
prev = ch
self.features = nn.Sequential(*layers)
head = []
if dropout > 0:
head.append(nn.Dropout(dropout))
head.append(nn.Linear(channels[-1], 1))
self.classifier = nn.Sequential(*head)
def forward(self, x):
return self.classifier(self.features(x).flatten(1))
# Resolves cnn_channels > cnn_preset > "medium" fallback
def build(cfg: dict) -> nn.Module:
channels = cfg.get("cnn_channels") or CNN_PRESETS.get(cfg.get("cnn_preset", "medium"), CNN_PRESETS["medium"])
return SimpleCNN(channels=channels, dropout=cfg.get("dropout", 0.0))
register("simple_cnn", build)
+3
@@ -0,0 +1,3 @@
from src.preprocessing.pipeline import DFFImagePipeline, get_transforms
__all__ = ["DFFImagePipeline", "get_transforms"]
+231
@@ -0,0 +1,231 @@
import io
import random
import numpy as np
import torch
from PIL import Image, ImageFilter
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F
# Per-channel mean and std of the ImageNet training set (RGB order)
# Required when using torchvision pretrained weights — they were trained with
# this exact normalization and expect it at inference time
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)
# Computes per-channel mean and std from real (label=0) training samples only
# Used for the real-norm experiment to test whether the model relies on
# colour/brightness differences between real and fake rather than identity cues
def compute_real_stats(dataset, indices, max_samples=1000, seed=42):
real_indices = [i for i in indices if dataset.samples[i][1] == 0]
if not real_indices:
return _IMAGENET_MEAN, _IMAGENET_STD
if len(real_indices) > max_samples:
rng = np.random.RandomState(seed)
real_indices = rng.choice(real_indices, max_samples, replace=False).tolist()
means, vars_ = [], []
for i in real_indices:
path, _ = dataset.samples[i]
img = np.array(Image.open(path).convert("RGB"), dtype=np.float32) / 255.0
means.append(img.mean(axis=(0, 1)))
vars_.append(img.var(axis=(0, 1)))
mean = tuple(float(x) for x in np.mean(means, axis=0))
std = tuple(float(x) for x in np.sqrt(np.mean(vars_, axis=0)))
return mean, std
# Single-image preprocessing pipeline for training and evaluation
# Square-crops first to remove the real-rectangular / fake-square geometry cue,
# then resizes, augments (train only), and normalizes
# augment=None uses DEFAULTS; augment=dict overrides specific keys; NO_AUGMENT disables all stochastic ops
class DFFImagePipeline:
DEFAULTS = {
"crop_scale": [0.85, 1.0],
"center_jitter": 0.1,
"hflip_p": 0.5,
"rotation_degrees": 15,
"brightness": 0.4,
"contrast": 0.4,
"saturation": 0.3,
"hue": 0.05,
"grayscale_p": 0.2,
"blur_p": 0.3,
"blur_radius": [0.1, 1.5],
"jpeg_p": 0.3,
"jpeg_quality": [65, 95],
"erase_p": 0.3,
"erase_scale": [0.02, 0.15],
"noise_p": 0.2,
"noise_std": 0.05,
}
# Pass as augment=DFFImagePipeline.NO_AUGMENT to keep crop+resize+normalize but skip all randomness
NO_AUGMENT: dict = {
"crop_scale": [1.0, 1.0],
"center_jitter": 0.0,
"hflip_p": 0.0,
"rotation_degrees": 0,
"brightness": 0.0,
"contrast": 0.0,
"saturation": 0.0,
"hue": 0.0,
"grayscale_p": 0.0,
"blur_p": 0.0,
"jpeg_p": 0.0,
"erase_p": 0.0,
"noise_p": 0.0,
}
def __init__(self, *, image_size: int, train: bool, augment: dict | None = None,
normalize_mean=None, normalize_std=None):
self.image_size = image_size
self.train = train
self.normalize_mean = normalize_mean or _IMAGENET_MEAN
self.normalize_std = normalize_std or _IMAGENET_STD
cfg = {**self.DEFAULTS, **(augment or {})}
self.crop_scale = tuple(cfg["crop_scale"])
self.center_jitter = cfg["center_jitter"]
self.hflip_p = cfg["hflip_p"]
self.rotation_degrees = cfg["rotation_degrees"]
self.brightness = cfg["brightness"]
self.contrast = cfg["contrast"]
self.saturation = cfg["saturation"]
self.hue = cfg["hue"]
self.grayscale_p = cfg["grayscale_p"]
self.blur_p = cfg["blur_p"]
self.blur_radius = tuple(cfg["blur_radius"])
self.jpeg_p = cfg["jpeg_p"]
self.jpeg_quality = tuple(cfg["jpeg_quality"])
self.erase_p = cfg["erase_p"]
self.erase_scale = tuple(cfg["erase_scale"])
self.noise_p = cfg["noise_p"]
self.noise_std = cfg["noise_std"]
# Geometry transforms
def _crop_square(self, img: Image.Image) -> Image.Image:
width, height = img.size
short_side = min(width, height)
center_top = max((height - short_side) // 2, 0)
center_left = max((width - short_side) // 2, 0)
if self.train:
min_scale, max_scale = self.crop_scale
scale = random.uniform(min_scale, max_scale)
crop_size = max(1, int(short_side * scale))
top = max((height - crop_size) // 2, 0)
left = max((width - crop_size) // 2, 0)
jitter_y = int((height - crop_size) * self.center_jitter)
jitter_x = int((width - crop_size) * self.center_jitter)
if jitter_y > 0:
top += random.randint(-jitter_y, jitter_y)
if jitter_x > 0:
left += random.randint(-jitter_x, jitter_x)
top = max(0, min(top, height - crop_size))
left = max(0, min(left, width - crop_size))
else:
crop_size = short_side
top = center_top
left = center_left
return F.crop(img, top=top, left=left, height=crop_size, width=crop_size)
def _maybe_flip(self, img: Image.Image) -> Image.Image:
if self.train and random.random() < self.hflip_p:
return F.hflip(img)
return img
def _maybe_rotate(self, img: Image.Image) -> Image.Image:
if self.train and self.rotation_degrees > 0:
angle = random.uniform(-self.rotation_degrees, self.rotation_degrees)
return F.rotate(img, angle, interpolation=InterpolationMode.BILINEAR, fill=0)
return img
# Photometric transforms
def _jitter_factor(self, amount: float) -> float:
return random.uniform(max(0.0, 1.0 - amount), 1.0 + amount)
def _maybe_color_jitter(self, img: Image.Image) -> Image.Image:
if not self.train:
return img
img = F.adjust_brightness(img, self._jitter_factor(self.brightness))
img = F.adjust_contrast(img, self._jitter_factor(self.contrast))
img = F.adjust_saturation(img, self._jitter_factor(self.saturation))
img = F.adjust_hue(img, random.uniform(-self.hue, self.hue))
return img
def _maybe_grayscale(self, img: Image.Image) -> Image.Image:
if self.train and random.random() < self.grayscale_p:
return F.to_grayscale(img, num_output_channels=3)
return img
def _maybe_blur(self, img: Image.Image) -> Image.Image:
if self.train and random.random() < self.blur_p:
radius = random.uniform(*self.blur_radius)
return img.filter(ImageFilter.GaussianBlur(radius=radius))
return img
# Bias-reduction transforms
# JPEG recompression removes high-frequency GAN artifacts that survive other augmentations
def _maybe_jpeg(self, img: Image.Image) -> Image.Image:
if self.train and random.random() < self.jpeg_p:
quality = random.randint(*self.jpeg_quality)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=quality)
buf.seek(0)
img = Image.open(buf).convert("RGB")
img.load() # decode pixels while buf is still in scope
return img
# Random erasing forces the model to use multiple regions rather than a single discriminative patch
def _maybe_erase(self, tensor: torch.Tensor) -> torch.Tensor:
if self.train and random.random() < self.erase_p:
c, h, w = tensor.shape
area = h * w
min_scale, max_scale = self.erase_scale
erase_area = area * random.uniform(min_scale, max_scale)
aspect = random.uniform(0.3, 1.0 / 0.3)
eh = int(round((erase_area * aspect) ** 0.5))
ew = int(round((erase_area / aspect) ** 0.5))
eh, ew = min(eh, h), min(ew, w)
top = random.randint(0, h - eh)
left = random.randint(0, w - ew)
tensor = tensor.clone()
tensor[:, top:top + eh, left:left + ew] = torch.rand(c, eh, ew)
return tensor
# Additive Gaussian noise discourages reliance on low-level sensor-vs-GAN noise patterns
def _maybe_noise(self, tensor: torch.Tensor) -> torch.Tensor:
if self.train and random.random() < self.noise_p:
noise = torch.randn_like(tensor) * self.noise_std
tensor = tensor + noise
return tensor
# Pipeline entrypoint
def __call__(self, img: Image.Image) -> torch.Tensor:
img = self._crop_square(img)
img = F.resize(img, [self.image_size, self.image_size], interpolation=InterpolationMode.BILINEAR)
img = self._maybe_flip(img)
img = self._maybe_rotate(img)
img = self._maybe_color_jitter(img)
img = self._maybe_grayscale(img)
img = self._maybe_blur(img)
img = self._maybe_jpeg(img)
tensor = F.to_tensor(img)
tensor = self._maybe_erase(tensor)
tensor = self._maybe_noise(tensor)
return F.normalize(tensor, self.normalize_mean, self.normalize_std)
# Convenience wrapper used by splits.py and evaluate.py
def get_transforms(train=True, image_size=224, augment=None,
normalize_mean=None, normalize_std=None):
return DFFImagePipeline(image_size=image_size, train=train, augment=augment,
normalize_mean=normalize_mean, normalize_std=normalize_std)
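# Standalone sketch: run the pipeline on a synthetic rectangular image to
# confirm the square-crop + resize contract (no dataset needed).
if __name__ == "__main__":
    img = Image.new("RGB", (320, 240), color=(128, 128, 128))
    eval_tf = get_transforms(train=False, image_size=224)
    train_tf = get_transforms(train=True, image_size=224, augment=DFFImagePipeline.NO_AUGMENT)
    print(eval_tf(img).shape, train_tf(img).shape)  # both torch.Size([3, 224, 224])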
+3
@@ -0,0 +1,3 @@
from src.training.trainer import train_classifier, train_classifier_cv
__all__ = ["train_classifier", "train_classifier_cv"]
+374
@@ -0,0 +1,374 @@
import json
from collections import Counter
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
from src.evaluation.metrics import binary_metrics
# ── AMP compatibility shim ─────────────────────────────────────────────────
# torch.amp.GradScaler / autocast moved from torch.cuda.amp in PyTorch 2.3+
if hasattr(torch.amp, "GradScaler"):
_GradScaler = torch.amp.GradScaler
_autocast = torch.amp.autocast
else:
from torch.cuda.amp import GradScaler as _OldGradScaler, autocast as _OldAutocast
_GradScaler = lambda device="", enabled=True, **kw: _OldGradScaler(enabled=enabled, **kw)
_autocast = lambda device_type="", enabled=True, **kw: _OldAutocast(enabled=enabled, **kw)
# ── Single-fold training ───────────────────────────────────────────────────
# Trains one fold; saves best checkpoint by val AUC-ROC and final checkpoint.
# pos_weight is passed through to BCEWithLogitsLoss to handle class imbalance.
def train_classifier(
model,
train_dataset,
val_dataset,
*,
epochs=10,
batch_size=16,
lr=1e-4,
weight_decay=1e-4,
device="cuda",
save_dir="outputs/models",
run_name="classifier",
early_stopping_patience=0,
num_workers=4,
grad_clip_norm=1.0,
T_max=None,
pos_weight=None,
):
device = torch.device(device)
if device.type == "cuda":
print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
print("Using CPU")
use_amp = device.type == "cuda"
model = model.to(device)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_params:,}")
train_loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
num_workers=num_workers, pin_memory=True
)
val_loader = DataLoader(
val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
)
pw = torch.tensor([pos_weight], device=device) if pos_weight is not None else None
criterion = nn.BCEWithLogitsLoss(pos_weight=pw)
optimizer = torch.optim.AdamW(
filter(lambda p: p.requires_grad, model.parameters()),
lr=lr, weight_decay=weight_decay,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max or epochs)
scaler = _GradScaler("cuda", enabled=use_amp)
print(f"Device: {device} AMP: {use_amp}")
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
history = {
"train_loss": [], "train_acc": [], "train_auc": [], "train_f1": [],
"val_loss": [], "val_acc": [], "val_auc": [], "val_f1": [],
}
best_auc = 0.0
patience_counter = 0
for epoch in range(1, epochs + 1):
# ── train ──
model.train()
total_loss = 0.0
train_logits, train_labels = [], []
for images, labels in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs} [train]", leave=False):
images = images.to(device)
labels = labels.float().to(device)
optimizer.zero_grad()
with _autocast("cuda", enabled=use_amp):
logits = model(images).squeeze(1)
loss = criterion(logits, labels)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
scaler.step(optimizer)
scaler.update()
total_loss += loss.item() * len(images)
train_logits.append(logits.detach().cpu())
train_labels.append(labels.detach().cpu())
train_loss = total_loss / sum(len(logit) for logit in train_logits)
train_m = binary_metrics(torch.cat(train_logits), torch.cat(train_labels))
scheduler.step()
# ── validate ──
model.eval()
val_loss = 0.0
all_logits, all_labels = [], []
with torch.no_grad():
for images, labels in tqdm(val_loader, desc=f"Epoch {epoch}/{epochs} [val]", leave=False):
images = images.to(device)
labels = labels.float().to(device)
with _autocast("cuda", enabled=use_amp):
logits = model(images).squeeze(1)
batch_loss = criterion(logits, labels)
if not (torch.isnan(batch_loss) or torch.isinf(batch_loss)):
val_loss += batch_loss.item() * len(images)
all_logits.append(logits.cpu())
all_labels.append(labels.cpu())
val_loss /= len(val_dataset)
val_m = binary_metrics(torch.cat(all_logits), torch.cat(all_labels))
# ── record ──
history["train_loss"].append(train_loss)
history["train_acc"].append(train_m["accuracy"])
history["train_auc"].append(train_m["auc_roc"])
history["train_f1"].append(train_m["f1"])
history["val_loss"].append(val_loss)
history["val_acc"].append(val_m["accuracy"])
history["val_auc"].append(val_m["auc_roc"])
history["val_f1"].append(val_m["f1"])
        gap_loss = train_loss - val_loss
        gap_acc = train_m["accuracy"] - val_m["accuracy"]
        # AUC/F1 can be None when a split contains a single class (see binary_metrics);
        # format them defensively instead of crashing on None
        fmt = lambda v: f"{v:.4f}" if v is not None else "n/a"
        print(
            f"[{epoch:03d}/{epochs}] "
            f"loss: {train_loss:.4f}/{val_loss:.4f} (gap {gap_loss:+.4f}) "
            f"acc: {train_m['accuracy']:.4f}/{val_m['accuracy']:.4f} (gap {gap_acc:+.4f}) "
            f"auc: {fmt(train_m['auc_roc'])}/{fmt(val_m['auc_roc'])} "
            f"f1: {fmt(train_m['f1'])}/{fmt(val_m['f1'])}"
        )
# ── checkpoint ──
if val_m["auc_roc"] is not None and val_m["auc_roc"] > best_auc:
best_auc = val_m["auc_roc"]
torch.save(model.state_dict(), save_dir / f"{run_name}_best.pt")
patience_counter = 0
else:
patience_counter += 1
if early_stopping_patience > 0 and patience_counter >= early_stopping_patience:
print(f"Early stopping at epoch {epoch} (no improvement for {early_stopping_patience} epochs)")
break
torch.save(model.state_dict(), save_dir / f"{run_name}_final.pt")
return history
# ── CV training ────────────────────────────────────────────────────────────
# Iterates over pre-built splits, trains one model per fold, evaluates on the
# held-out test fold, then aggregates metrics across folds with mean ± std
def train_classifier_cv(
model_fn,
raw_dataset,
splits,
*,
epochs=10,
batch_size=16,
lr=1e-4,
weight_decay=1e-4,
device="cuda",
save_dir="outputs/models",
run_name="classifier_cv",
early_stopping_patience=0,
num_workers=4,
transform_builder=None,
grad_clip_norm=1.0,
T_max=None,
normalization=None,
logs_dir=None,
):
from src.evaluation.evaluate import (
predict_rows, save_errors, save_hists, save_preds, save_summary,
)
from src.evaluation.metrics import binary_metrics, calc_metrics, source_metrics, pair_metrics
from src.utils.cross_validation import aggregate_fold_metrics
device = torch.device(device if torch.cuda.is_available() else "cpu")
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
eval_dir = Path(logs_dir) / run_name if logs_dir is not None else None
if eval_dir is not None:
eval_dir.mkdir(parents=True, exist_ok=True)
fold_results = []
all_records = []
for fold_idx, (train_idx, val_idx, test_idx) in enumerate(splits):
print(f"\n{'='*60}")
print(f"Fold {fold_idx + 1}/{len(splits)}")
print(f" Train: {len(train_idx)} Val: {len(val_idx)} Test: {len(test_idx)}")
print(f"{'='*60}")
model = model_fn().to(device)
norm_mean = norm_std = None
if normalization == "real_norm":
from src.preprocessing.pipeline import compute_real_stats
norm_mean, norm_std = compute_real_stats(raw_dataset, train_idx)
print(f" Real-norm stats: mean={norm_mean}, std={norm_std}")
if transform_builder is not None:
train_dataset = transform_builder(train_idx, train=True,
normalize_mean=norm_mean, normalize_std=norm_std)
val_dataset = transform_builder(val_idx, train=False,
normalize_mean=norm_mean, normalize_std=norm_std)
test_dataset = transform_builder(test_idx, train=False,
normalize_mean=norm_mean, normalize_std=norm_std)
else:
train_dataset = Subset(raw_dataset, train_idx)
val_dataset = Subset(raw_dataset, val_idx)
test_dataset = Subset(raw_dataset, test_idx)
# Compute pos_weight = n_real / n_fake for BCEWithLogitsLoss class balancing
train_labels = [raw_dataset.samples[i][1] for i in train_idx]
class_counts = Counter(train_labels)
pos_weight = class_counts[0] / class_counts[1] if class_counts[1] > 0 else 1.0
fold_run_name = f"{run_name}_fold{fold_idx}"
history = train_classifier(
model,
train_dataset,
val_dataset,
epochs=epochs,
batch_size=batch_size,
lr=lr,
weight_decay=weight_decay,
device=device,
save_dir=save_dir,
run_name=fold_run_name,
early_stopping_patience=early_stopping_patience,
num_workers=num_workers,
grad_clip_norm=grad_clip_norm,
T_max=T_max,
pos_weight=pos_weight,
)
# Load best checkpoint and evaluate on test set
checkpoint_path = save_dir / f"{fold_run_name}_best.pt"
if checkpoint_path.exists():
model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
model.eval()
records = predict_rows(
model, test_dataset, raw_dataset, test_idx,
batch_size, device, num_workers=num_workers,
)
# Compute aggregate and per-source test metrics
test_metrics = calc_metrics(records)
src_metrics = source_metrics(records)
pairwise = pair_metrics(records)
fold_result = {
"fold": fold_idx,
"train_size": len(train_idx),
"val_size": len(val_idx),
"test_size": len(test_idx),
"history": history,
"test_metrics": test_metrics,
"source_metrics": src_metrics,
"pair_metrics": pairwise,
}
fold_results.append(fold_result)
all_records.extend(records)
if eval_dir is not None:
fold_dir = eval_dir / f"fold{fold_idx}"
save_preds(records, fold_dir / "preds.csv")
save_errors(records, fold_dir / "errors.json")
print(f"\nFold {fold_idx + 1} Test Metrics:")
for key, value in test_metrics.items():
if key != "confusion_matrix":
print(f" {key}: {value}")
print(f" Per-source AUC:")
for source, sm in sorted(src_metrics.items()):
pa = sm.get("pairwise_auc")
dr = sm.get("detection_rate")
label = f"pairwise_auc={pa:.4f}" if pa is not None else f"detection_rate={dr:.4f}" if dr is not None else ""
print(f" {source}: {label}")
# Aggregate metrics across folds
test_metrics_list = [f["test_metrics"] for f in fold_results]
aggregated = aggregate_fold_metrics(test_metrics_list)
# Aggregate per-source metrics across folds
all_sources = sorted({s for f in fold_results for s in f["source_metrics"]})
aggregated_per_source = {}
for source in all_sources:
source_fold_metrics = []
for f in fold_results:
sm = f["source_metrics"].get(source)
if sm:
# Only keep scalar numeric fields for aggregation
source_fold_metrics.append({
k: v for k, v in sm.items()
if isinstance(v, (int, float)) and k != "fold"
})
if source_fold_metrics:
aggregated_per_source[source] = aggregate_fold_metrics(source_fold_metrics)
# Aggregate pairwise source metrics across folds
all_pairs = sorted({p for f in fold_results for p in f["pair_metrics"]})
aggregated_pairwise = {}
for pair in all_pairs:
pair_fold_metrics = []
for f in fold_results:
pm = f["pair_metrics"].get(pair)
if pm:
pair_fold_metrics.append({
k: v for k, v in pm.items()
if isinstance(v, (int, float)) and k not in ("fold", "n")
})
if pair_fold_metrics:
aggregated_pairwise[pair] = aggregate_fold_metrics(pair_fold_metrics)
results = {
"run_name": run_name,
"n_folds": len(splits),
"fold_results": fold_results,
"aggregated_metrics": aggregated,
"aggregated_per_source": aggregated_per_source,
"aggregated_pairwise": aggregated_pairwise,
}
if eval_dir is not None:
save_hists(all_records, eval_dir / "hists")
save_summary(results, eval_dir / "summary.json")
print(f"\n{'='*60}")
print("Cross-Validation Results (Aggregated)")
print(f"{'='*60}")
for key, value in aggregated.items():
print(f" {key}:")
print(f" mean: {value['mean']:.4f}")
print(f" std: {value['std']:.4f}")
print(f" 95% CI: ±{value['ci_95']:.4f}")
if aggregated_per_source:
print(f"\nPer-Source Pairwise AUC (wiki vs. fake source):")
for source in sorted(aggregated_per_source):
ps = aggregated_per_source[source]
pa = ps.get("pairwise_auc", {})
if pa:
print(f" {source}: {pa['mean']:.4f} ± {pa['std']:.4f}")
dr = ps.get("detection_rate")
if dr and not pa:
print(f" {source}: detection_rate={dr['mean']:.4f} ± {dr['std']:.4f}")
return results
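# How the pieces fit together, as a guarded hypothetical sketch; "data/faces"
# and the cfg values are illustrative, not project defaults.
if __name__ == "__main__":
    from src.data import DFFDataset, build_transforms, get_splits
    from src.models import get_model
    cfg = {"backbone": "simple_cnn", "image_size": 224, "cv_folds": 5, "seed": 42}
    raw = DFFDataset("data/faces")
    results = train_classifier_cv(
        model_fn=lambda: get_model(cfg),
        raw_dataset=raw,
        splits=get_splits(raw, cfg),
        transform_builder=build_transforms(raw, cfg),
        epochs=5,
        run_name="demo_cv",
    )
    print(results["aggregated_metrics"])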
+13
@@ -0,0 +1,13 @@
from src.utils.config import load_config
from src.utils.cross_validation import (
aggregate_fold_metrics,
create_group_kfold_splits,
get_basename,
)
__all__ = [
"load_config",
"aggregate_fold_metrics",
"create_group_kfold_splits",
"get_basename",
]
+60
@@ -0,0 +1,60 @@
import json
from pathlib import Path
from typing import Any, Dict, Optional
# ── Loading ────────────────────────────────────────────────────────────────
# Resolves the extends chain first, then overlays shared.json underneath so
# experiment-level keys always win over shared defaults
def load_config(config_path: str, shared_path: Optional[str] = None) -> Dict[str, Any]:
config_path = Path(config_path)
cfg = _load_extends(config_path)
if shared_path is None:
shared_path = config_path.parent.parent / "shared.json"
else:
shared_path = Path(shared_path)
if shared_path.exists():
with open(shared_path) as f:
shared_cfg = json.load(f)
cfg = _deep_merge(shared_cfg, cfg)
return cfg
# Pops the "extends" key and recursively merges the parent config underneath;
# the seen set catches circular inheritance before it recurses infinitely
def _load_extends(config_path: Path, seen: Optional[set[Path]] = None) -> Dict[str, Any]:
if seen is None:
seen = set()
resolved_path = config_path.resolve()
if resolved_path in seen:
chain = " -> ".join(str(p) for p in [*seen, resolved_path])
raise ValueError(f"Circular config inheritance detected: {chain}")
seen.add(resolved_path)
with open(config_path) as f:
cfg = json.load(f)
base_ref = cfg.pop("extends", None)
if not base_ref:
seen.remove(resolved_path)
return cfg
base_path = (config_path.parent / base_ref).resolve()
base_cfg = _load_extends(base_path, seen=seen)
seen.remove(resolved_path)
return _deep_merge(base_cfg, cfg)
# override always wins; nested dicts are merged recursively rather than replaced
def _deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
result = base.copy()
for key, value in override.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
result[key] = _deep_merge(result[key], value)
else:
result[key] = value
return result
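# Behavior sketch (runnable): a child config extends a base, and experiment
# keys win over base values in nested dicts. Uses temp files so nothing in
# the repo is touched; the shared path deliberately does not exist.
if __name__ == "__main__":
    import tempfile
    with tempfile.TemporaryDirectory() as tmp:
        base = Path(tmp) / "base.json"
        child = Path(tmp) / "child.json"
        base.write_text(json.dumps({"lr": 1e-4, "model": {"backbone": "resnet18"}}))
        child.write_text(json.dumps({"extends": "base.json", "model": {"backbone": "resnet50"}}))
        cfg = load_config(child, shared_path=Path(tmp) / "missing_shared.json")
        print(cfg)  # {'lr': 0.0001, 'model': {'backbone': 'resnet50'}}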
+85
View File
@@ -0,0 +1,85 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from sklearn.model_selection import StratifiedGroupKFold
# Groups by filename stem so sibling images of the same identity (same name
# across wiki/inpainting/text2img/insight) always stay in the same fold
def get_basename(path: str) -> str:
return Path(path).stem
# ── Splits ─────────────────────────────────────────────────────────────────
# Outer fold: StratifiedGroupKFold holds out one fold as test.
# Inner val: 10% of remaining groups are held out randomly — no per-class
# stratification needed since every DFF basename is multi-source (mixed label).
def create_group_kfold_splits(
samples: List[Tuple[str, int]],
n_splits: int = 5,
seed: int = 42,
) -> List[Tuple[List[int], List[int], List[int]]]:
paths = [s[0] for s in samples]
labels = np.array([s[1] for s in samples])
groups = np.array([get_basename(p) for p in paths])
sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=seed)
splits = []
for fold_idx, (train_val_idx, test_idx) in enumerate(sgkf.split(paths, labels, groups)):
train_val_groups = groups[train_val_idx]
unique_groups = np.unique(train_val_groups)
n_val_groups = max(1, int(len(unique_groups) * 0.1))
rng = np.random.RandomState(seed + fold_idx)
val_groups = set(rng.choice(unique_groups, n_val_groups, replace=False))
train_idx = []
val_idx = []
for i, g in enumerate(train_val_groups):
if g in val_groups:
val_idx.append(train_val_idx[i])
else:
train_idx.append(train_val_idx[i])
splits.append((train_idx, val_idx, test_idx.tolist()))
return splits
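# Leakage sanity-check sketch (runnable): sibling paths sharing a basename must
# never straddle train/val/test. Uses synthetic two-source samples, label 0 vs 1.
if __name__ == "__main__":
    samples = [(f"{src}/id/{i:04d}.jpg", lbl)
               for i in range(50) for src, lbl in [("wiki", 0), ("text2img", 1)]]
    for tr, va, te in create_group_kfold_splits(samples, n_splits=5):
        tr_g = {get_basename(samples[i][0]) for i in tr}
        va_g = {get_basename(samples[i][0]) for i in va}
        te_g = {get_basename(samples[i][0]) for i in te}
        assert not (tr_g & va_g) and not (tr_g & te_g) and not (va_g & te_g)
    print("train/val/test group sets are disjoint in every fold")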
# ── Aggregation ────────────────────────────────────────────────────────────
# Infers numeric keys from the first fold if metric_keys is not supplied;
# uses sample std (ddof=1) and normal-approximation 95% CI
def aggregate_fold_metrics(
fold_metrics: List[Dict[str, Any]],
    metric_keys: Optional[List[str]] = None,
) -> Dict[str, Any]:
if metric_keys is None:
metric_keys = [
k for k, v in fold_metrics[0].items()
if isinstance(v, (int, float)) and not isinstance(v, bool)
]
aggregated = {}
for key in metric_keys:
values = [fold[key] for fold in fold_metrics if key in fold]
if not values:
continue
values = np.array(values)
mean = np.mean(values)
std = np.std(values, ddof=1)
n = len(values)
ci_95 = 1.96 * std / np.sqrt(n) if n > 1 else 0.0
aggregated[key] = {
"mean": float(mean),
"std": float(std),
"ci_95": float(ci_95),
"values": values.tolist(),
}
return aggregated
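# Worked example (runnable): five fold AUCs -> mean 0.90, sample std (ddof=1)
# ≈ 0.0158, 95% CI ≈ 1.96 * 0.0158 / sqrt(5) ≈ 0.0139.
if __name__ == "__main__":
    folds = [{"auc_roc": v} for v in (0.88, 0.89, 0.90, 0.91, 0.92)]
    print(aggregate_fold_metrics(folds)["auc_roc"])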