Clean state
This commit is contained in:
@@ -0,0 +1,4 @@
|
||||
from src.data.dataset import DFFDataset, PathDataset, SOURCES, get_source_name
|
||||
from src.data.splits import TransformSubset, apply_subsample, build_transforms, get_splits
|
||||
|
||||
__all__ = ["DFFDataset", "PathDataset", "SOURCES", "TransformSubset", "apply_subsample", "build_transforms", "get_source_name", "get_splits"]
|
||||
@@ -0,0 +1,82 @@
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
# One real source (wiki) and three fake sources; 0 = real, 1 = fake
# The same identity basename appears in every source - splitting must happen
# at the identity level to prevent leakage (see splits.py), e.g:
# data_dir/SOURCE/identity/BASENAME.jpg -> same BASENAME for all SOURCEs go into same split
SOURCES = {
    "wiki": 0,
    "inpainting": 1,
    "text2img": 1,
    "insight": 1,
}


# Walks data_dir/source/identity/*.jpg and collects (path, label) pairs
class DFFDataset(Dataset):
    """Image dataset laid out as data_dir/source/identity/*.jpg.

    Labels come from SOURCES (0 = real, 1 = fake). Directories and files are
    traversed in sorted order so sample order is deterministic across runs.
    """

    def __init__(self, data_dir, sources=None, transform=None):
        """
        Args:
            data_dir: root directory containing one subdirectory per source.
            sources: iterable of source names to load; defaults to all SOURCES.
            transform: optional callable applied to each PIL image in __getitem__.

        Raises:
            FileNotFoundError: if data_dir or a requested source dir is missing.
            ValueError: if a requested source name is not in SOURCES.
        """
        self.transform = transform
        self.samples = []

        data_dir = Path(data_dir)
        if not data_dir.exists():
            raise FileNotFoundError(
                f"Dataset root not found: {data_dir}. Expected a directory containing "
                "wiki/, inpainting/, text2img/, and insight/."
            )
        if sources is None:
            sources = list(SOURCES.keys())

        # Fail fast with a clear message on typos instead of a bare KeyError
        # from the SOURCES lookup below.
        unknown = [s for s in sources if s not in SOURCES]
        if unknown:
            raise ValueError(
                f"Unknown sources: {unknown}. Available: {sorted(SOURCES)}"
            )

        for source in sources:
            label = SOURCES[source]
            source_dir = data_dir / source
            if not source_dir.exists():
                raise FileNotFoundError(
                    f"Missing source directory: {source_dir}. Check `data_dir` in the config."
                )
            # Sorted traversal keeps sample order stable for reproducible splits.
            for subdir in sorted(source_dir.iterdir()):
                if subdir.is_dir():
                    for img_path in sorted(subdir.glob("*.jpg")):
                        self.samples.append((img_path, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        image = Image.open(path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

    # Useful for quickly verifying class balance before training
    def label_counts(self):
        """Return a Counter mapping label (0/1) -> number of samples."""
        return Counter(label for _, label in self.samples)
|
||||
|
||||
|
||||
# Wraps a list of prediction records as a Dataset for re-scoring via DataLoader.
class PathDataset(Dataset):
    """Dataset over prediction-record dicts (each with "path" and "label").

    Yields (tensor, label, index) so callers can map model outputs back to
    the originating record.
    """

    def __init__(self, records, image_size, preprocess=None):
        # Local import keeps src.preprocessing out of module import time.
        from src.preprocessing import get_transforms

        self.records = records
        # Evaluation transform: deterministic crop/resize/normalize only.
        self.transform = get_transforms(train=False, image_size=image_size)
        self.preprocess = preprocess

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        entry = self.records[idx]
        image = Image.open(entry["path"]).convert("RGB")
        # Optional extra PIL-level step applied before the eval transform.
        if self.preprocess is not None:
            image = self.preprocess(image)
        return self.transform(image), entry["label"], idx
|
||||
@@ -0,0 +1,116 @@
|
||||
import random
|
||||
from typing import Callable, List, Tuple
|
||||
|
||||
from torch.utils.data import Dataset, Subset
|
||||
|
||||
from src.data.dataset import get_source_name
|
||||
from src.preprocessing import get_transforms
|
||||
from src.utils import create_group_kfold_splits, get_basename
|
||||
|
||||
|
||||
# Defined at module level so DataLoader workers (num_workers > 0) can serialize it safely
class TransformSubset(Dataset):
    """Applies `transform` to the image of each (image, label) pair in `subset`."""

    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        image, label = self.subset[idx]
        return self.transform(image), label
|
||||
|
||||
|
||||
# ── Splitting ──────────────────────────────────────────────────────────────
|
||||
|
||||
# Builds grouped CV fold indices from config; this is the only split strategy used.
def get_splits(
    raw_dataset,
    cfg,
) -> List[Tuple[List[int], List[int], List[int]]]:
    """Return per-fold (train, val, test) index lists into raw_dataset.samples.

    Folds come from create_group_kfold_splits (grouped by identity basename).
    Optional cfg keys `train_sources` / `eval_sources` restrict train/val and
    test indices to the named sources; unknown names raise ValueError.
    """
    folds = create_group_kfold_splits(
        raw_dataset.samples,
        n_splits=cfg.get("cv_folds", 5),
        seed=cfg.get("seed", 42),
    )

    # Optional source holdout: train/val from train_sources, test from eval_sources.
    train_sources = cfg.get("train_sources")
    eval_sources = cfg.get("eval_sources")
    if not (train_sources or eval_sources):
        return folds

    all_sources = {get_source_name(path) for path, _ in raw_dataset.samples}
    allowed_train = set(train_sources) if train_sources else all_sources
    allowed_eval = set(eval_sources) if eval_sources else all_sources
    unknown = (allowed_train | allowed_eval) - all_sources
    if unknown:
        raise ValueError(f"Unknown sources requested: {sorted(unknown)}")

    # Drop indices whose sample source falls outside the allowed set.
    def _filter(indices, allowed):
        return [i for i in indices if get_source_name(raw_dataset.samples[i][0]) in allowed]

    return [
        (_filter(tr, allowed_train), _filter(val, allowed_train), _filter(te, allowed_eval))
        for tr, val, te in folds
    ]
|
||||
|
||||
|
||||
# Deterministic subsampling shared by training and reevaluation.
|
||||
def apply_subsample(raw_dataset, cfg) -> tuple[int, int] | None:
|
||||
subsample = cfg.get("subsample", 1.0)
|
||||
if subsample >= 1.0:
|
||||
return None
|
||||
|
||||
total = len(raw_dataset.samples)
|
||||
if total == 0:
|
||||
return 0, 0
|
||||
|
||||
# Subsample at basename-group level to preserve identity grouping guarantees.
|
||||
group_to_indices = {}
|
||||
for idx, (path, _) in enumerate(raw_dataset.samples):
|
||||
group = get_basename(str(path))
|
||||
group_to_indices.setdefault(group, []).append(idx)
|
||||
|
||||
groups = list(group_to_indices.keys())
|
||||
n_groups = len(groups)
|
||||
target_groups = max(1, int(n_groups * subsample))
|
||||
rng = random.Random(cfg.get("seed", 42))
|
||||
rng.shuffle(groups)
|
||||
keep_groups = set(groups[:target_groups])
|
||||
|
||||
keep_indices = [
|
||||
idx
|
||||
for group in keep_groups
|
||||
for idx in group_to_indices[group]
|
||||
]
|
||||
keep_indices.sort()
|
||||
raw_dataset.samples = [raw_dataset.samples[i] for i in keep_indices]
|
||||
return len(raw_dataset.samples), total
|
||||
|
||||
|
||||
# Controls stochastic augmentations (flip, jitter, etc.)
# augment=False -> NO_AUGMENT preset (square-crop + resize + normalize still run)
# augment=None, pipeline defaults, augment=dict -> override specific params
# Face cropping is handled upstream via data_dir swap, not here
def build_transforms(raw_dataset, cfg, augment=None) -> Callable:
    """Return builder(indices, train, normalize_mean, normalize_std) -> Dataset."""
    image_size = cfg["image_size"]

    if augment is False:
        from src.preprocessing.pipeline import DFFImagePipeline

        # NO_AUGMENT keeps crop+resize+normalize but zeroes all stochastic ops.
        augment = DFFImagePipeline.NO_AUGMENT

    def transform_builder(indices, train=True, normalize_mean=None, normalize_std=None):
        pipeline = get_transforms(
            train=train,
            image_size=image_size,
            augment=augment,
            normalize_mean=normalize_mean,
            normalize_std=normalize_std,
        )
        return TransformSubset(Subset(raw_dataset, indices), pipeline)

    return transform_builder
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
from src.evaluation.evaluate import (
|
||||
predict_rows,
|
||||
rescore_rows,
|
||||
save_errors,
|
||||
save_hists,
|
||||
save_preds,
|
||||
save_summary,
|
||||
)
|
||||
from src.evaluation.metrics import binary_metrics, calc_metrics, pair_metrics, source_metrics
|
||||
|
||||
__all__ = ["binary_metrics", "calc_metrics", "pair_metrics", "predict_rows", "rescore_rows", "save_errors", "save_hists", "save_preds", "save_summary", "source_metrics"]
|
||||
@@ -0,0 +1,135 @@
|
||||
import csv
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from src.data import PathDataset, get_source_name
|
||||
|
||||
|
||||
# ── Inference ──────────────────────────────────────────────────────────────
|
||||
|
||||
# Run model inference and return one prediction record per sample
# raw_dataset and indices supply path/source/label metadata, decoupling this
# function from the wrapper layout (Subset/TransformSubset) of `dataset`
def predict_rows(model, dataset, raw_dataset, indices, batch_size, device, *, num_workers=4):
    """Return one dict per sample: path/basename/source/label/pred/prob_fake/logit."""
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
    rows = []

    model.eval().to(device)
    with torch.no_grad():
        seen = 0  # samples consumed so far; aligns loader order with `indices`
        for images, labels in loader:
            logits = model(images.to(device)).squeeze(1).cpu()
            probs = torch.sigmoid(logits)
            preds = (probs >= 0.5).long()

            for i in range(len(labels)):
                path, label = raw_dataset.samples[indices[seen + i]]
                rows.append({
                    "path": str(path),
                    "basename": Path(path).name,
                    "source": get_source_name(path),
                    "label": int(label),
                    "pred": int(preds[i].item()),
                    "prob_fake": float(probs[i].item()),
                    "logit": float(logits[i].item()),
                })
            seen += len(labels)

    return rows
|
||||
|
||||
|
||||
# Re-run model scoring for an existing list of records
def rescore_rows(
    model, records, image_size, batch_size, device, preprocess=None, *, num_workers=4
):
    """Return copies of `records` with pred/prob_fake/logit recomputed by `model`."""
    dataset = PathDataset(records, image_size=image_size, preprocess=preprocess)
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
    rescored = []

    model.eval().to(device)
    with torch.no_grad():
        for images, _, indices in loader:
            logits = model(images.to(device)).squeeze(1).cpu()
            probs = torch.sigmoid(logits)
            preds = (probs >= 0.5).long()

            for j, rec_idx in enumerate(indices):
                # Copy so the caller's records are never mutated.
                row = dict(records[int(rec_idx.item())])
                row["prob_fake"] = float(probs[j].item())
                row["pred"] = int(preds[j].item())
                row["logit"] = float(logits[j].item())
                rescored.append(row)

    return rescored
|
||||
|
||||
|
||||
# ── Export ─────────────────────────────────────────────────────────────────
|
||||
|
||||
# Writes all prediction records to CSV in a fixed column order
def save_preds(records, output_path):
    """Write prediction records to `output_path` as CSV; parent dirs are created."""
    output_path.parent.mkdir(parents=True, exist_ok=True)
    columns = ["path", "basename", "source", "label", "pred", "prob_fake", "logit"]
    with open(output_path, "w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        for record in records:
            writer.writerow(record)
|
||||
|
||||
|
||||
# Saves the top-k highest-confidence false positives and false negatives for error analysis
def save_errors(records, output_path, top_k=32):
    """Dump the `top_k` most confident FP and FN records to JSON at `output_path`."""
    # FPs: real samples called fake, most confident (highest P(fake)) first.
    fps = [r for r in records if r["label"] == 0 and r["pred"] == 1]
    fps.sort(key=lambda r: r["prob_fake"], reverse=True)
    # FNs: fake samples called real, most confident (lowest P(fake)) first.
    fns = [r for r in records if r["label"] == 1 and r["pred"] == 0]
    fns.sort(key=lambda r: r["prob_fake"])

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as handle:
        json.dump({
            "top_false_positives": fps[:top_k],
            "top_false_negatives": fns[:top_k],
        }, handle, indent=2)
|
||||
|
||||
|
||||
# Saves P(fake) histograms: one overall class comparison and one per source
def save_hists(records, output_dir):
    """Write confidence_by_class.png plus one confidence_<source>.png per source."""
    output_dir.mkdir(parents=True, exist_ok=True)

    # Overall plot: real vs fake density histograms on one axis.
    fig, ax = plt.subplots(figsize=(8, 5))
    for class_label, class_name in ((0, "real"), (1, "fake")):
        probs = [r["prob_fake"] for r in records if r["label"] == class_label]
        ax.hist(probs, bins=30, alpha=0.6, label=class_name, density=True)
    ax.set_xlabel("Predicted P(fake)")
    ax.set_ylabel("Density")
    ax.set_title("Confidence by class")
    ax.legend()
    fig.tight_layout()
    fig.savefig(output_dir / "confidence_by_class.png", dpi=160)
    plt.close(fig)

    # One raw-count histogram per source.
    for source in sorted({r["source"] for r in records}):
        fig, ax = plt.subplots(figsize=(8, 5))
        probs = [r["prob_fake"] for r in records if r["source"] == source]
        ax.hist(probs, bins=30, alpha=0.8)
        ax.set_xlabel("Predicted P(fake)")
        ax.set_ylabel("Count")
        ax.set_title(f"Confidence distribution: {source}")
        fig.tight_layout()
        fig.savefig(output_dir / f"confidence_{source}.png", dpi=160)
        plt.close(fig)
|
||||
|
||||
|
||||
# Saves the aggregated results dict as a formatted JSON file
def save_summary(summary, output_path):
    """Write `summary` to `output_path` as indented JSON; parent dirs are created."""
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(json.dumps(summary, indent=2))
|
||||
@@ -0,0 +1,70 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from sklearn.metrics import (
|
||||
accuracy_score,
|
||||
roc_auc_score,
|
||||
f1_score,
|
||||
confusion_matrix,
|
||||
)
|
||||
|
||||
# AUC and F1 are undefined when only one class is present in the batch;
# returns None for those fields rather than raising
def binary_metrics(logits: torch.Tensor, labels: torch.Tensor) -> dict:
    """Compute accuracy / AUC-ROC / F1 / confusion matrix for binary logits."""
    probs = torch.sigmoid(logits).numpy()
    preds = (probs >= 0.5).astype(int)
    y_true = labels.numpy().astype(int)

    metrics = {
        "accuracy": float(accuracy_score(y_true, preds)),
        "auc_roc": None,
        "f1": None,
        "confusion_matrix": confusion_matrix(y_true, preds, labels=[0, 1]),
    }
    # Only defined when both classes appear in the batch.
    if len(np.unique(y_true)) > 1:
        metrics["auc_roc"] = float(roc_auc_score(y_true, probs))
        metrics["f1"] = float(f1_score(y_true, preds, zero_division=0))
    return metrics
|
||||
|
||||
|
||||
# Converts per-sample records to tensors and delegates to binary_metrics
def calc_metrics(records):
    """Return binary_metrics over `records` with a JSON-safe confusion matrix."""
    logit_tensor = torch.tensor([r["logit"] for r in records], dtype=torch.float32)
    label_tensor = torch.tensor([r["label"] for r in records], dtype=torch.float32)
    result = binary_metrics(logit_tensor, label_tensor)
    # ndarray -> nested lists so the dict is directly json.dump-able.
    result["confusion_matrix"] = result["confusion_matrix"].tolist()
    return result
|
||||
|
||||
|
||||
# Per-source summaries; fake sources get detection_rate + pairwise_auc since AUC is undefined single-class
def source_metrics(records, real_source="wiki"):
    """Return {source: metrics dict}, with single-class extras per source kind."""
    real_records = [r for r in records if r["source"] == real_source]
    results = {}
    for source in sorted({r["source"] for r in records}):
        subset = [r for r in records if r["source"] == source]
        stats = calc_metrics(subset)
        stats["n"] = len(subset)

        unique_labels = {r["label"] for r in subset}
        if len(unique_labels) == 1:
            if 1 in unique_labels:  # all fake
                # With only fakes present, accuracy == fraction detected.
                stats["detection_rate"] = stats["accuracy"]
                if real_records:
                    # Pair with real samples so AUC/F1 become defined.
                    paired = calc_metrics(real_records + subset)
                    stats["pairwise_auc"] = paired["auc_roc"]
                    stats["pairwise_f1"] = paired["f1"]
            else:  # all real (wiki)
                stats["false_alarm_rate"] = 1.0 - stats["accuracy"]
        results[source] = stats
    return results
|
||||
|
||||
|
||||
# Real-vs-one-fake AUC/F1 per fake source - more interpretable than global AUC when class ratios vary
def pair_metrics(records, real_source="wiki"):
    """Return {"<real>_vs_<fake>": metrics} for each fake source paired with real."""
    pairwise = {}
    for fake_source in sorted({r["source"] for r in records if r["source"] != real_source}):
        pair_records = [r for r in records if r["source"] in (real_source, fake_source)]
        if not pair_records:
            continue
        pairwise[f"{real_source}_vs_{fake_source}"] = {
            "sources": [real_source, fake_source],
            "n": len(pair_records),
            **calc_metrics(pair_records),
        }
    return pairwise
|
||||
@@ -0,0 +1,35 @@
|
||||
from pathlib import Path
|
||||
from typing import Callable, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Maps backbone name -> builder function; populated by each model module at import time
_REGISTRY: dict[str, Callable[[dict], nn.Module]] = {}


# Called by each model module to advertise its backbone(s) to get_model
def register(name: str, builder: Callable[[dict], nn.Module]) -> None:
    """Register `builder` under `name`; a later registration overwrites an earlier one."""
    _REGISTRY[name] = builder
|
||||
|
||||
|
||||
# Instantiates the backbone requested in cfg["backbone"]
def get_model(cfg: dict) -> nn.Module:
    """Build and return the model named by cfg["backbone"] (default "simple_cnn")."""
    name = cfg.get("backbone", "simple_cnn")
    builder = _REGISTRY.get(name)
    if builder is None:
        raise ValueError(
            f"Unknown backbone: {name!r}. Available: {', '.join(sorted(_REGISTRY))}"
        )
    return builder(cfg)
|
||||
|
||||
|
||||
# Loads a saved state-dict into model in-place and returns it
def load_checkpoint(model: nn.Module, path: Union[Path, str], device) -> nn.Module:
    """Load weights from `path` onto `device` into `model`; returned for chaining."""
    # weights_only=True restricts deserialization to tensors (safer torch.load).
    state = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state)
    return model
|
||||
|
||||
|
||||
# Importing the modules triggers their register() calls
|
||||
from src.models import simple_cnn, resnet, efficientnet # noqa: E402, F401
|
||||
|
||||
__all__ = ["get_model", "load_checkpoint", "register"]
|
||||
@@ -0,0 +1,23 @@
|
||||
import torch.nn as nn
|
||||
from torchvision import models
|
||||
|
||||
from src.models import register
|
||||
|
||||
|
||||
# EfficientNet's classification head is a Sequential; [-1] targets the final Linear
def build(cfg: dict) -> nn.Module:
    """Build an EfficientNet-B0 with a single-logit binary head."""
    name = cfg.get("backbone", "efficientnet_b0")
    pretrained = cfg.get("pretrained", True)

    if name != "efficientnet_b0":
        raise ValueError(f"Unsupported EfficientNet backbone: {name!r}. Supported: efficientnet_b0")

    weights = models.EfficientNet_B0_Weights.DEFAULT if pretrained else None
    model = models.efficientnet_b0(weights=weights)

    # Swap the 1000-class ImageNet head for a single logit.
    head_in = model.classifier[-1].in_features
    model.classifier[-1] = nn.Linear(head_in, 1)
    return model


register("efficientnet_b0", build)
|
||||
@@ -0,0 +1,30 @@
|
||||
import torch.nn as nn
|
||||
from torchvision import models
|
||||
|
||||
from src.models import register
|
||||
|
||||
|
||||
# Loads pretrained ResNet and replaces the 1000-class head with a single logit for binary detection
def build(cfg: dict) -> nn.Module:
    """Build resnet18/34/50 with a single-logit binary head."""
    name = cfg.get("backbone", "resnet18")
    pretrained = cfg.get("pretrained", True)

    # Dispatch table instead of an if/elif chain.
    variants = {
        "resnet18": (models.resnet18, models.ResNet18_Weights),
        "resnet34": (models.resnet34, models.ResNet34_Weights),
        "resnet50": (models.resnet50, models.ResNet50_Weights),
    }
    if name not in variants:
        raise ValueError(f"Unsupported backbone: {name!r}")
    ctor, weight_enum = variants[name]
    model = ctor(weights=weight_enum.DEFAULT if pretrained else None)

    # Single logit instead of the 1000 ImageNet classes.
    model.fc = nn.Linear(model.fc.in_features, 1)
    return model


register("resnet18", build)
register("resnet34", build)
register("resnet50", build)
|
||||
@@ -0,0 +1,49 @@
|
||||
import torch.nn as nn
|
||||
|
||||
from src.models import register
|
||||
|
||||
# Named presets map cnn_preset config values to channel lists
CNN_PRESETS = {
    "micro": [8, 16],
    "small": [8, 16, 32],
    "medium": [16, 32, 64, 64],
    "large": [32, 64, 128, 256],
}


# Each entry in channels builds a Conv -> BN -> ReLU -> Pool block
# the last block pools to 1×1 so the head is resolution-independent
class SimpleCNN(nn.Module):
    """Small configurable CNN emitting a single binary logit."""

    def __init__(self, channels=None, in_channels=3, dropout=0.0):
        super().__init__()
        channels = CNN_PRESETS["medium"] if channels is None else channels

        blocks = []
        in_ch = in_channels
        last_idx = len(channels) - 1
        for idx, out_ch in enumerate(channels):
            blocks.extend([
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ])
            # Halve the resolution between blocks; final block collapses to 1x1.
            blocks.append(nn.MaxPool2d(2) if idx < last_idx else nn.AdaptiveAvgPool2d(1))
            in_ch = out_ch
        self.features = nn.Sequential(*blocks)

        head_layers = [nn.Dropout(dropout)] if dropout > 0 else []
        head_layers.append(nn.Linear(channels[-1], 1))
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        feats = self.features(x).flatten(1)
        return self.classifier(feats)
|
||||
|
||||
|
||||
# Resolves cnn_channels > cnn_preset > "medium" fallback
def build(cfg: dict) -> nn.Module:
    """Build a SimpleCNN from cfg; explicit cnn_channels wins over cnn_preset."""
    preset = CNN_PRESETS.get(cfg.get("cnn_preset", "medium"), CNN_PRESETS["medium"])
    channels = cfg.get("cnn_channels") or preset
    return SimpleCNN(channels=channels, dropout=cfg.get("dropout", 0.0))


register("simple_cnn", build)
|
||||
@@ -0,0 +1,3 @@
|
||||
from src.preprocessing.pipeline import DFFImagePipeline, get_transforms
|
||||
|
||||
__all__ = ["DFFImagePipeline", "get_transforms"]
|
||||
@@ -0,0 +1,231 @@
|
||||
import io
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageFilter
|
||||
from torchvision.transforms import InterpolationMode
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
# Per-channel mean and std of the ImageNet training set (RGB order)
# Required when using torchvision pretrained weights — they were trained with
# this exact normalisation and expect it at inference time
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


# Computes per-channel mean and std from real (label=0) training samples only
# Used for the real-norm experiment to test whether the model relies on
# colour/brightness differences between real and fake rather than identity cues
def compute_real_stats(dataset, indices, max_samples=1000, seed=42):
    """Return (mean, std) RGB tuples over real samples; ImageNet stats if none."""
    real_idx = [i for i in indices if dataset.samples[i][1] == 0]
    if not real_idx:
        return _IMAGENET_MEAN, _IMAGENET_STD

    # Cap the scan for speed; deterministic given `seed`.
    if len(real_idx) > max_samples:
        real_idx = np.random.RandomState(seed).choice(real_idx, max_samples, replace=False).tolist()

    channel_means, channel_vars = [], []
    for i in real_idx:
        img_path, _ = dataset.samples[i]
        pixels = np.array(Image.open(img_path).convert("RGB"), dtype=np.float32) / 255.0
        channel_means.append(pixels.mean(axis=(0, 1)))
        channel_vars.append(pixels.var(axis=(0, 1)))

    # Pooled std: sqrt of the mean per-image variance.
    mean = tuple(float(v) for v in np.mean(channel_means, axis=0))
    std = tuple(float(v) for v in np.sqrt(np.mean(channel_vars, axis=0)))
    return mean, std
|
||||
|
||||
|
||||
# Single-image preprocessing pipeline for training and evaluation
# Square-crops first to remove the real-rectangular / fake-square geometry cue,
# then resizes, augments (train only), and normalizes
# augment=None uses DEFAULTS; augment=dict overrides specific keys; NO_AUGMENT disables all stochastic ops
class DFFImagePipeline:
    """Callable PIL image -> normalized float tensor.

    Fixed order of operations (see __call__): square-crop, resize, hflip,
    rotate, color jitter, grayscale, blur, JPEG recompression, to-tensor,
    random erase, Gaussian noise, normalize. Every stochastic step runs only
    when `train` is True and draws from the module-level `random` RNG.
    """

    # Default augmentation parameters. Keys ending in `_p` are per-image
    # probabilities; two-element lists are [min, max] sampling ranges.
    DEFAULTS = {
        "crop_scale": [0.85, 1.0],
        "center_jitter": 0.1,
        "hflip_p": 0.5,
        "rotation_degrees": 15,
        "brightness": 0.4,
        "contrast": 0.4,
        "saturation": 0.3,
        "hue": 0.05,
        "grayscale_p": 0.2,
        "blur_p": 0.3,
        "blur_radius": [0.1, 1.5],
        "jpeg_p": 0.3,
        "jpeg_quality": [65, 95],
        "erase_p": 0.3,
        "erase_scale": [0.02, 0.15],
        "noise_p": 0.2,
        "noise_std": 0.05,
    }

    # Pass as augment=DFFImagePipeline.NO_AUGMENT to keep crop+resize+normalize but skip all randomness
    # Range keys (blur_radius, jpeg_quality, erase_scale, noise_std) are left
    # at their DEFAULTS but are never sampled because each matching *_p is 0.
    NO_AUGMENT: dict = {
        "crop_scale": [1.0, 1.0],
        "center_jitter": 0.0,
        "hflip_p": 0.0,
        "rotation_degrees": 0,
        "brightness": 0.0,
        "contrast": 0.0,
        "saturation": 0.0,
        "hue": 0.0,
        "grayscale_p": 0.0,
        "blur_p": 0.0,
        "jpeg_p": 0.0,
        "erase_p": 0.0,
        "noise_p": 0.0,
    }

    def __init__(self, *, image_size: int, train: bool, augment: dict | None = None,
                 normalize_mean=None, normalize_std=None):
        """
        Args:
            image_size: output side length in pixels; images become square.
            train: enables all stochastic augmentations.
            augment: partial override of DEFAULTS, merged key-by-key.
            normalize_mean / normalize_std: per-channel stats; ImageNet by default.
        """
        self.image_size = image_size
        self.train = train
        # `or` falls back to ImageNet stats for None (and any falsy value).
        self.normalize_mean = normalize_mean or _IMAGENET_MEAN
        self.normalize_std = normalize_std or _IMAGENET_STD

        # Merge user overrides over the defaults; unspecified keys keep DEFAULTS.
        cfg = {**self.DEFAULTS, **(augment or {})}
        self.crop_scale = tuple(cfg["crop_scale"])
        self.center_jitter = cfg["center_jitter"]
        self.hflip_p = cfg["hflip_p"]
        self.rotation_degrees = cfg["rotation_degrees"]
        self.brightness = cfg["brightness"]
        self.contrast = cfg["contrast"]
        self.saturation = cfg["saturation"]
        self.hue = cfg["hue"]
        self.grayscale_p = cfg["grayscale_p"]
        self.blur_p = cfg["blur_p"]
        self.blur_radius = tuple(cfg["blur_radius"])
        self.jpeg_p = cfg["jpeg_p"]
        self.jpeg_quality = tuple(cfg["jpeg_quality"])
        self.erase_p = cfg["erase_p"]
        self.erase_scale = tuple(cfg["erase_scale"])
        self.noise_p = cfg["noise_p"]
        self.noise_std = cfg["noise_std"]

    # Geometry transforms
    def _crop_square(self, img: Image.Image) -> Image.Image:
        """Crop to a square: random scale + jittered center when training, plain center crop otherwise."""
        width, height = img.size
        short_side = min(width, height)
        center_top = max((height - short_side) // 2, 0)
        center_left = max((width - short_side) // 2, 0)

        if self.train:
            # Shrink the square by a random factor within crop_scale.
            min_scale, max_scale = self.crop_scale
            scale = random.uniform(min_scale, max_scale)
            crop_size = max(1, int(short_side * scale))
            top = max((height - crop_size) // 2, 0)
            left = max((width - crop_size) // 2, 0)

            # Offset the window by up to center_jitter of the available slack.
            jitter_y = int((height - crop_size) * self.center_jitter)
            jitter_x = int((width - crop_size) * self.center_jitter)
            if jitter_y > 0:
                top += random.randint(-jitter_y, jitter_y)
            if jitter_x > 0:
                left += random.randint(-jitter_x, jitter_x)

            # Clamp so the window stays fully inside the image.
            top = max(0, min(top, height - crop_size))
            left = max(0, min(left, width - crop_size))
        else:
            crop_size = short_side
            top = center_top
            left = center_left

        return F.crop(img, top=top, left=left, height=crop_size, width=crop_size)

    def _maybe_flip(self, img: Image.Image) -> Image.Image:
        """Horizontal flip with probability hflip_p (train only)."""
        if self.train and random.random() < self.hflip_p:
            return F.hflip(img)
        return img

    def _maybe_rotate(self, img: Image.Image) -> Image.Image:
        """Rotate by a uniform angle in ±rotation_degrees, black fill (train only)."""
        if self.train and self.rotation_degrees > 0:
            angle = random.uniform(-self.rotation_degrees, self.rotation_degrees)
            return F.rotate(img, angle, interpolation=InterpolationMode.BILINEAR, fill=0)
        return img

    # Photometric transforms
    def _jitter_factor(self, amount: float) -> float:
        """Multiplicative factor in [max(0, 1-amount), 1+amount]."""
        return random.uniform(max(0.0, 1.0 - amount), 1.0 + amount)

    def _maybe_color_jitter(self, img: Image.Image) -> Image.Image:
        """Apply brightness/contrast/saturation/hue jitter in fixed order (train only)."""
        if not self.train:
            return img
        img = F.adjust_brightness(img, self._jitter_factor(self.brightness))
        img = F.adjust_contrast(img, self._jitter_factor(self.contrast))
        img = F.adjust_saturation(img, self._jitter_factor(self.saturation))
        # Hue is additive in [-hue, hue], unlike the multiplicative factors above.
        img = F.adjust_hue(img, random.uniform(-self.hue, self.hue))
        return img

    def _maybe_grayscale(self, img: Image.Image) -> Image.Image:
        """Convert to 3-channel grayscale with probability grayscale_p (train only)."""
        if self.train and random.random() < self.grayscale_p:
            return F.to_grayscale(img, num_output_channels=3)
        return img

    def _maybe_blur(self, img: Image.Image) -> Image.Image:
        """Gaussian blur with a radius sampled from blur_radius (train only)."""
        if self.train and random.random() < self.blur_p:
            radius = random.uniform(*self.blur_radius)
            return img.filter(ImageFilter.GaussianBlur(radius=radius))
        return img

    # Bias-reduction transforms
    # JPEG recompression removes high-frequency GAN artifacts that survive other augmentations
    def _maybe_jpeg(self, img: Image.Image) -> Image.Image:
        """Round-trip through an in-memory JPEG at a random quality (train only)."""
        if self.train and random.random() < self.jpeg_p:
            quality = random.randint(*self.jpeg_quality)
            buf = io.BytesIO()
            img.save(buf, format="JPEG", quality=quality)
            buf.seek(0)
            img = Image.open(buf).convert("RGB")
            img.load()  # decode pixels while buf is still in scope
        return img

    # Random erasing forces the model to use multiple regions rather than a single discriminative patch
    def _maybe_erase(self, tensor: torch.Tensor) -> torch.Tensor:
        """Overwrite a random rectangle with uniform noise (train only; input is CHW)."""
        if self.train and random.random() < self.erase_p:
            c, h, w = tensor.shape
            area = h * w
            min_scale, max_scale = self.erase_scale
            erase_area = area * random.uniform(min_scale, max_scale)
            # Rectangle aspect ratio between 0.3 and 1/0.3.
            aspect = random.uniform(0.3, 1.0 / 0.3)
            eh = int(round((erase_area * aspect) ** 0.5))
            ew = int(round((erase_area / aspect) ** 0.5))
            eh, ew = min(eh, h), min(ew, w)
            top = random.randint(0, h - eh)
            left = random.randint(0, w - ew)
            # Clone so the caller's tensor is not modified in place.
            tensor = tensor.clone()
            tensor[:, top:top + eh, left:left + ew] = torch.rand(c, eh, ew)
        return tensor

    # Gaussian noise improves robustness to sensor noise vs GAN noise patterns
    def _maybe_noise(self, tensor: torch.Tensor) -> torch.Tensor:
        """Add zero-mean Gaussian noise with std noise_std (train only)."""
        if self.train and random.random() < self.noise_p:
            noise = torch.randn_like(tensor) * self.noise_std
            tensor = tensor + noise
        return tensor

    # Pipeline entrypoint
    def __call__(self, img: Image.Image) -> torch.Tensor:
        """Run the full pipeline; returns a normalized (C, image_size, image_size) tensor."""
        # Geometry first: crop to square, then resize to the target resolution.
        img = self._crop_square(img)
        img = F.resize(img, [self.image_size, self.image_size], interpolation=InterpolationMode.BILINEAR)
        img = self._maybe_flip(img)
        img = self._maybe_rotate(img)
        # Photometric / PIL-level degradations.
        img = self._maybe_color_jitter(img)
        img = self._maybe_grayscale(img)
        img = self._maybe_blur(img)
        img = self._maybe_jpeg(img)
        # Tensor-level ops, then normalization last.
        tensor = F.to_tensor(img)
        tensor = self._maybe_erase(tensor)
        tensor = self._maybe_noise(tensor)
        return F.normalize(tensor, self.normalize_mean, self.normalize_std)
|
||||
|
||||
|
||||
# Convenience wrapper used by splits.py and evaluate.py
def get_transforms(train=True, image_size=224, augment=None,
                   normalize_mean=None, normalize_std=None):
    """Build a DFFImagePipeline; see DFFImagePipeline for parameter semantics."""
    return DFFImagePipeline(
        image_size=image_size,
        train=train,
        augment=augment,
        normalize_mean=normalize_mean,
        normalize_std=normalize_std,
    )
|
||||
@@ -0,0 +1,3 @@
|
||||
from src.training.trainer import train_classifier, train_classifier_cv
|
||||
|
||||
__all__ = ["train_classifier", "train_classifier_cv"]
|
||||
@@ -0,0 +1,374 @@
|
||||
import json
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
from tqdm import tqdm
|
||||
|
||||
from src.evaluation.metrics import binary_metrics
|
||||
|
||||
# ── AMP compatibility shim ─────────────────────────────────────────────────

# torch.amp.GradScaler / autocast moved from torch.cuda.amp in PyTorch 2.3+
if hasattr(torch.amp, "GradScaler"):
    # Modern API: device-aware GradScaler/autocast live under torch.amp.
    _GradScaler = torch.amp.GradScaler
    _autocast = torch.amp.autocast
else:
    # Legacy API takes no device argument; wrap so call sites can still pass one.
    from torch.cuda.amp import GradScaler as _LegacyGradScaler, autocast as _legacy_autocast

    def _GradScaler(device="", enabled=True, **kw):
        return _LegacyGradScaler(enabled=enabled, **kw)

    def _autocast(device_type="", enabled=True, **kw):
        return _legacy_autocast(enabled=enabled, **kw)
|
||||
|
||||
|
||||
# ── Single-fold training ───────────────────────────────────────────────────

# Trains one fold; saves best checkpoint by val AUC-ROC and final checkpoint.
# pos_weight is passed through to BCEWithLogitsLoss to handle class imbalance.
def train_classifier(
    model,
    train_dataset,
    val_dataset,
    *,
    epochs=10,
    batch_size=16,
    lr=1e-4,
    weight_decay=1e-4,
    device="cuda",
    save_dir="outputs/models",
    run_name="classifier",
    early_stopping_patience=0,
    num_workers=4,
    grad_clip_norm=1.0,
    T_max=None,
    pos_weight=None,
):
    """Train a binary real/fake classifier for one fold and checkpoint it.

    Saves ``{run_name}_best.pt`` (highest val AUC-ROC seen so far) and
    ``{run_name}_final.pt`` (state after the last epoch) under *save_dir*.

    Args:
        model: Module whose forward returns (batch, 1) logits — they are
            ``squeeze(1)``-ed before the loss.
        train_dataset: Dataset yielding (image_tensor, label) pairs.
        val_dataset: Same layout, used for per-epoch validation.
        epochs: Number of training epochs.
        batch_size: Mini-batch size for both loaders.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        device: "cuda" or "cpu"; AMP is enabled only on CUDA.
        save_dir: Checkpoint directory (created if missing).
        run_name: Prefix for checkpoint filenames.
        early_stopping_patience: Stop after this many epochs without a new
            best val AUC; 0 disables early stopping.
        num_workers: DataLoader worker process count.
        grad_clip_norm: Max gradient L2 norm for clipping.
        T_max: Cosine-annealing period in epochs; defaults to ``epochs``.
        pos_weight: Positive-class weight for BCEWithLogitsLoss
            (n_real / n_fake as computed by the CV driver), or None.

    Returns:
        dict: per-epoch history lists — train/val loss, accuracy, AUC, F1.
    """
    device = torch.device(device)
    if device.type == "cuda":
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        print("Using CPU")

    # Mixed precision is only enabled on CUDA; on CPU the scaler/autocast
    # wrappers below run as no-ops (enabled=False).
    use_amp = device.type == "cuda"

    model = model.to(device)

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {n_params:,}")

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
    )

    # pos_weight rebalances BCE toward the positive (fake) class.
    pw = torch.tensor([pos_weight], device=device) if pos_weight is not None else None
    criterion = nn.BCEWithLogitsLoss(pos_weight=pw)
    # Only optimize parameters that are actually trainable (frozen backbones
    # stay frozen).
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=lr, weight_decay=weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max or epochs)
    scaler = _GradScaler("cuda", enabled=use_amp)
    print(f"Device: {device} AMP: {use_amp}")

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    history = {
        "train_loss": [], "train_acc": [], "train_auc": [], "train_f1": [],
        "val_loss": [], "val_acc": [], "val_auc": [], "val_f1": [],
    }
    best_auc = 0.0
    patience_counter = 0

    for epoch in range(1, epochs + 1):
        # ── train ──
        model.train()
        total_loss = 0.0
        train_logits, train_labels = [], []
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs} [train]", leave=False):
            images = images.to(device)
            # BCEWithLogitsLoss needs float targets, not int class labels.
            labels = labels.float().to(device)

            optimizer.zero_grad()
            with _autocast("cuda", enabled=use_amp):
                logits = model(images).squeeze(1)
                loss = criterion(logits, labels)

            # Unscale before clipping so the norm is measured in true
            # (unscaled) gradient units.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
            scaler.step(optimizer)
            scaler.update()
            # Sample-weighted running sum (last batch may be smaller).
            total_loss += loss.item() * len(images)
            train_logits.append(logits.detach().cpu())
            train_labels.append(labels.detach().cpu())

        # Mean loss over every sample actually seen this epoch.
        train_loss = total_loss / sum(len(logit) for logit in train_logits)
        train_m = binary_metrics(torch.cat(train_logits), torch.cat(train_labels))
        scheduler.step()

        # ── validate ──
        model.eval()
        val_loss = 0.0
        all_logits, all_labels = [], []
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc=f"Epoch {epoch}/{epochs} [val]", leave=False):
                images = images.to(device)
                labels = labels.float().to(device)
                with _autocast("cuda", enabled=use_amp):
                    logits = model(images).squeeze(1)
                    batch_loss = criterion(logits, labels)
                # Skip non-finite batch losses so one bad batch cannot poison
                # the epoch average; logits are still kept for metrics.
                if not (torch.isnan(batch_loss) or torch.isinf(batch_loss)):
                    val_loss += batch_loss.item() * len(images)
                all_logits.append(logits.cpu())
                all_labels.append(labels.cpu())

        # NOTE(review): skipped non-finite batches still count in this
        # denominator, slightly deflating val_loss when any were dropped.
        val_loss /= len(val_dataset)
        val_m = binary_metrics(torch.cat(all_logits), torch.cat(all_labels))

        # ── record ──
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_m["accuracy"])
        history["train_auc"].append(train_m["auc_roc"])
        history["train_f1"].append(train_m["f1"])
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_m["accuracy"])
        history["val_auc"].append(val_m["auc_roc"])
        history["val_f1"].append(val_m["f1"])

        # Train-vs-val gaps are printed as a quick overfitting indicator.
        gap_loss = train_loss - val_loss
        gap_acc = train_m["accuracy"] - val_m["accuracy"]
        print(
            f"[{epoch:03d}/{epochs}] "
            f"loss: {train_loss:.4f}/{val_loss:.4f} (gap {gap_loss:+.4f}) "
            f"acc: {train_m['accuracy']:.4f}/{val_m['accuracy']:.4f} (gap {gap_acc:+.4f}) "
            f"auc: {train_m['auc_roc']:.4f}/{val_m['auc_roc']:.4f} "
            f"f1: {train_m['f1']:.4f}/{val_m['f1']:.4f}"
        )

        # ── checkpoint ──
        # AUC can be None (e.g. a single-class val fold); such epochs count
        # against patience rather than being treated as improvements.
        if val_m["auc_roc"] is not None and val_m["auc_roc"] > best_auc:
            best_auc = val_m["auc_roc"]
            torch.save(model.state_dict(), save_dir / f"{run_name}_best.pt")
            patience_counter = 0
        else:
            patience_counter += 1

        if early_stopping_patience > 0 and patience_counter >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch} (no improvement for {early_stopping_patience} epochs)")
            break

    torch.save(model.state_dict(), save_dir / f"{run_name}_final.pt")
    return history
|
||||
|
||||
|
||||
# ── CV training ────────────────────────────────────────────────────────────

# Iterates over pre-built splits, trains one model per fold, evaluates on the
# held-out test fold, then aggregates metrics across folds with mean ± std
def train_classifier_cv(
    model_fn,
    raw_dataset,
    splits,
    *,
    epochs=10,
    batch_size=16,
    lr=1e-4,
    weight_decay=1e-4,
    device="cuda",
    save_dir="outputs/models",
    run_name="classifier_cv",
    early_stopping_patience=0,
    num_workers=4,
    transform_builder=None,
    grad_clip_norm=1.0,
    T_max=None,
    normalization=None,
    logs_dir=None,
):
    """Run k-fold cross-validated training and evaluation.

    For each (train, val, test) index triple in *splits*: builds a fresh
    model via *model_fn*, trains it with :func:`train_classifier`, reloads
    the best checkpoint, predicts on the test fold, and collects aggregate,
    per-source, and pairwise (wiki-vs-fake-source) metrics. Fold metrics are
    then aggregated across folds (mean / std / 95% CI).

    Args:
        model_fn: Zero-arg callable returning a fresh model per fold.
        raw_dataset: Full dataset; must expose ``.samples`` as
            (path, label) pairs (used for pos_weight computation).
        splits: List of (train_idx, val_idx, test_idx) index lists.
        transform_builder: Optional callable(indices, train=, normalize_mean=,
            normalize_std=) returning a transformed dataset view; falls back
            to plain ``Subset`` when None.
        normalization: "real_norm" computes normalization stats from the
            real images of the train split; anything else uses defaults.
        logs_dir: If given, per-fold predictions/errors and the summary are
            written under ``logs_dir/run_name``.
        (remaining args are forwarded to :func:`train_classifier`)

    Returns:
        dict with per-fold results plus aggregated metric summaries.
    """
    # Deferred imports: these evaluation/aggregation helpers are only needed
    # when CV actually runs (also keeps module import cheap).
    from src.evaluation.evaluate import (
        predict_rows, save_errors, save_hists, save_preds, save_summary,
    )
    from src.evaluation.metrics import binary_metrics, calc_metrics, source_metrics, pair_metrics
    from src.utils.cross_validation import aggregate_fold_metrics

    # Fall back to CPU silently when CUDA was requested but is unavailable.
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    eval_dir = Path(logs_dir) / run_name if logs_dir is not None else None
    if eval_dir is not None:
        eval_dir.mkdir(parents=True, exist_ok=True)

    fold_results = []
    all_records = []

    for fold_idx, (train_idx, val_idx, test_idx) in enumerate(splits):
        print(f"\n{'='*60}")
        print(f"Fold {fold_idx + 1}/{len(splits)}")
        print(f" Train: {len(train_idx)} Val: {len(val_idx)} Test: {len(test_idx)}")
        print(f"{'='*60}")

        # Fresh model per fold — no weight leakage between folds.
        model = model_fn().to(device)

        norm_mean = norm_std = None
        if normalization == "real_norm":
            # Stats come from the train split only, to avoid test leakage.
            from src.preprocessing.pipeline import compute_real_stats
            norm_mean, norm_std = compute_real_stats(raw_dataset, train_idx)
            print(f" Real-norm stats: mean={norm_mean}, std={norm_std}")

        if transform_builder is not None:
            # Train split gets augmentations; val/test get eval transforms.
            train_dataset = transform_builder(train_idx, train=True,
                                             normalize_mean=norm_mean, normalize_std=norm_std)
            val_dataset = transform_builder(val_idx, train=False,
                                           normalize_mean=norm_mean, normalize_std=norm_std)
            test_dataset = transform_builder(test_idx, train=False,
                                            normalize_mean=norm_mean, normalize_std=norm_std)
        else:
            train_dataset = Subset(raw_dataset, train_idx)
            val_dataset = Subset(raw_dataset, val_idx)
            test_dataset = Subset(raw_dataset, test_idx)

        # Compute pos_weight = n_real / n_fake for BCEWithLogitsLoss class balancing
        train_labels = [raw_dataset.samples[i][1] for i in train_idx]
        class_counts = Counter(train_labels)
        pos_weight = class_counts[0] / class_counts[1] if class_counts[1] > 0 else 1.0

        fold_run_name = f"{run_name}_fold{fold_idx}"
        history = train_classifier(
            model,
            train_dataset,
            val_dataset,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            weight_decay=weight_decay,
            device=device,
            save_dir=save_dir,
            run_name=fold_run_name,
            early_stopping_patience=early_stopping_patience,
            num_workers=num_workers,
            grad_clip_norm=grad_clip_norm,
            T_max=T_max,
            pos_weight=pos_weight,
        )

        # Load best checkpoint and evaluate on test set
        # (best.pt may not exist if val AUC never improved — then the final
        # weights from training are used as-is).
        checkpoint_path = save_dir / f"{fold_run_name}_best.pt"
        if checkpoint_path.exists():
            model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))

        model.eval()

        records = predict_rows(
            model, test_dataset, raw_dataset, test_idx,
            batch_size, device, num_workers=num_workers,
        )

        # Compute aggregate and per-source test metrics
        test_metrics = calc_metrics(records)
        src_metrics = source_metrics(records)
        pairwise = pair_metrics(records)

        fold_result = {
            "fold": fold_idx,
            "train_size": len(train_idx),
            "val_size": len(val_idx),
            "test_size": len(test_idx),
            "history": history,
            "test_metrics": test_metrics,
            "source_metrics": src_metrics,
            "pair_metrics": pairwise,
        }
        fold_results.append(fold_result)
        all_records.extend(records)

        if eval_dir is not None:
            fold_dir = eval_dir / f"fold{fold_idx}"
            save_preds(records, fold_dir / "preds.csv")
            save_errors(records, fold_dir / "errors.json")

        print(f"\nFold {fold_idx + 1} Test Metrics:")
        for key, value in test_metrics.items():
            if key != "confusion_matrix":
                print(f" {key}: {value}")
        print(f" Per-source AUC:")
        for source, sm in sorted(src_metrics.items()):
            pa = sm.get("pairwise_auc")
            dr = sm.get("detection_rate")
            label = f"pairwise_auc={pa:.4f}" if pa is not None else f"detection_rate={dr:.4f}" if dr is not None else ""
            print(f" {source}: {label}")

    # Aggregate metrics across folds
    test_metrics_list = [f["test_metrics"] for f in fold_results]
    aggregated = aggregate_fold_metrics(test_metrics_list)

    # Aggregate per-source metrics across folds
    all_sources = sorted({s for f in fold_results for s in f["source_metrics"]})
    aggregated_per_source = {}
    for source in all_sources:
        source_fold_metrics = []
        for f in fold_results:
            sm = f["source_metrics"].get(source)
            if sm:
                # Only keep scalar numeric fields for aggregation
                source_fold_metrics.append({
                    k: v for k, v in sm.items()
                    if isinstance(v, (int, float)) and k != "fold"
                })
        if source_fold_metrics:
            aggregated_per_source[source] = aggregate_fold_metrics(source_fold_metrics)

    # Aggregate pairwise source metrics across folds
    all_pairs = sorted({p for f in fold_results for p in f["pair_metrics"]})
    aggregated_pairwise = {}
    for pair in all_pairs:
        pair_fold_metrics = []
        for f in fold_results:
            pm = f["pair_metrics"].get(pair)
            if pm:
                pair_fold_metrics.append({
                    k: v for k, v in pm.items()
                    if isinstance(v, (int, float)) and k not in ("fold", "n")
                })
        if pair_fold_metrics:
            aggregated_pairwise[pair] = aggregate_fold_metrics(pair_fold_metrics)

    results = {
        "run_name": run_name,
        "n_folds": len(splits),
        "fold_results": fold_results,
        "aggregated_metrics": aggregated,
        "aggregated_per_source": aggregated_per_source,
        "aggregated_pairwise": aggregated_pairwise,
    }

    if eval_dir is not None:
        save_hists(all_records, eval_dir / "hists")
        save_summary(results, eval_dir / "summary.json")

    print(f"\n{'='*60}")
    print("Cross-Validation Results (Aggregated)")
    print(f"{'='*60}")
    for key, value in aggregated.items():
        print(f" {key}:")
        print(f" mean: {value['mean']:.4f}")
        print(f" std: {value['std']:.4f}")
        print(f" 95% CI: ±{value['ci_95']:.4f}")

    if aggregated_per_source:
        print(f"\nPer-Source Pairwise AUC (wiki vs. fake source):")
        for source in sorted(aggregated_per_source):
            ps = aggregated_per_source[source]
            pa = ps.get("pairwise_auc", {})
            if pa:
                print(f" {source}: {pa['mean']:.4f} ± {pa['std']:.4f}")
            # Sources with no pairwise AUC fall back to detection rate.
            dr = ps.get("detection_rate")
            if dr and not pa:
                print(f" {source}: detection_rate={dr['mean']:.4f} ± {dr['std']:.4f}")

    return results
|
||||
@@ -0,0 +1,13 @@
|
||||
from src.utils.config import load_config
|
||||
from src.utils.cross_validation import (
|
||||
aggregate_fold_metrics,
|
||||
create_group_kfold_splits,
|
||||
get_basename,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"load_config",
|
||||
"aggregate_fold_metrics",
|
||||
"create_group_kfold_splits",
|
||||
"get_basename",
|
||||
]
|
||||
@@ -0,0 +1,60 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
# ── Loading ────────────────────────────────────────────────────────────────

# Resolves the extends chain first, then overlays shared.json underneath so
# experiment-level keys always win over shared defaults
def load_config(config_path: str, shared_path: Optional[str] = None) -> Dict[str, Any]:
    """Load an experiment config, resolving ``extends`` and shared defaults."""
    config_path = Path(config_path)
    cfg = _load_extends(config_path)

    # Default shared config lives two directory levels above the experiment file.
    shared_path = (
        config_path.parent.parent / "shared.json"
        if shared_path is None
        else Path(shared_path)
    )

    if not shared_path.exists():
        return cfg

    with open(shared_path) as f:
        shared_cfg = json.load(f)
    # Shared defaults sit underneath: experiment-level keys win on conflict.
    return _deep_merge(shared_cfg, cfg)
|
||||
|
||||
|
||||
# Pops the "extends" key and recursively merges the parent config underneath;
|
||||
# the seen set catches circular inheritance before it recurses infinitely
|
||||
def _load_extends(config_path: Path, seen: Optional[set[Path]] = None) -> Dict[str, Any]:
|
||||
if seen is None:
|
||||
seen = set()
|
||||
resolved_path = config_path.resolve()
|
||||
if resolved_path in seen:
|
||||
chain = " -> ".join(str(p) for p in [*seen, resolved_path])
|
||||
raise ValueError(f"Circular config inheritance detected: {chain}")
|
||||
seen.add(resolved_path)
|
||||
|
||||
with open(config_path) as f:
|
||||
cfg = json.load(f)
|
||||
|
||||
base_ref = cfg.pop("extends", None)
|
||||
if not base_ref:
|
||||
seen.remove(resolved_path)
|
||||
return cfg
|
||||
|
||||
base_path = (config_path.parent / base_ref).resolve()
|
||||
base_cfg = _load_extends(base_path, seen=seen)
|
||||
seen.remove(resolved_path)
|
||||
return _deep_merge(base_cfg, cfg)
|
||||
|
||||
|
||||
# override always wins; nested dicts are merged recursively rather than replaced
|
||||
def _deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
|
||||
result = base.copy()
|
||||
for key, value in override.items():
|
||||
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
|
||||
result[key] = _deep_merge(result[key], value)
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
@@ -0,0 +1,85 @@
|
||||
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from sklearn.model_selection import StratifiedGroupKFold
|
||||
|
||||
|
||||
# Groups by filename stem so sibling images of the same identity (same name
# across wiki/inpainting/text2img/insight) always stay in the same fold
def get_basename(path: str) -> str:
    """Return *path*'s filename stem (no directories, no extension)."""
    stem = Path(path).stem
    return stem
|
||||
|
||||
|
||||
# ── Splits ─────────────────────────────────────────────────────────────────

# Outer fold: StratifiedGroupKFold holds out one fold as test.
# Inner val: 10% of remaining groups are held out randomly — no per-class
# stratification needed since every DFF basename is multi-source (mixed label).
def create_group_kfold_splits(
    samples: List[Tuple[str, int]],
    n_splits: int = 5,
    seed: int = 42,
) -> List[Tuple[List[int], List[int], List[int]]]:
    """Build ``n_splits`` (train, val, test) index triples, grouped by basename."""
    paths = [path for path, _ in samples]
    labels = np.array([label for _, label in samples])
    groups = np.array([get_basename(p) for p in paths])

    sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    folds = []

    for fold_idx, (train_val_idx, test_idx) in enumerate(sgkf.split(paths, labels, groups)):
        train_val_groups = groups[train_val_idx]
        unique_groups = np.unique(train_val_groups)
        # At least one group goes to val, even for tiny datasets.
        n_val_groups = max(1, int(len(unique_groups) * 0.1))

        # Deterministic per-fold RNG so each fold draws distinct val groups.
        rng = np.random.RandomState(seed + fold_idx)
        val_groups = set(rng.choice(unique_groups, n_val_groups, replace=False))

        train_idx = [idx for idx, g in zip(train_val_idx, train_val_groups) if g not in val_groups]
        val_idx = [idx for idx, g in zip(train_val_idx, train_val_groups) if g in val_groups]

        folds.append((train_idx, val_idx, test_idx.tolist()))

    return folds
|
||||
|
||||
|
||||
# ── Aggregation ────────────────────────────────────────────────────────────

# Infers numeric keys from the first fold if metric_keys is not supplied;
# uses sample std (ddof=1) and normal-approximation 95% CI
def aggregate_fold_metrics(
    fold_metrics: List[Dict[str, Any]],
    metric_keys: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Aggregate per-fold scalar metrics into mean/std/95% CI summaries.

    Args:
        fold_metrics: One metrics dict per fold. When inferring keys, bools
            and non-numeric values are skipped.
        metric_keys: Explicit keys to aggregate; defaults to every numeric
            key present in the first fold's dict.

    Returns:
        Mapping of metric key -> {"mean", "std", "ci_95", "values"}.
        Empty dict when *fold_metrics* is empty.
    """
    # Guard: fold_metrics[0] below would raise IndexError on an empty list.
    if not fold_metrics:
        return {}

    if metric_keys is None:
        metric_keys = [
            k for k, v in fold_metrics[0].items()
            if isinstance(v, (int, float)) and not isinstance(v, bool)
        ]

    aggregated = {}
    for key in metric_keys:
        values = [fold[key] for fold in fold_metrics if key in fold]
        if not values:
            continue
        values = np.array(values)
        mean = np.mean(values)
        n = len(values)
        # Sample std (ddof=1) is undefined for a single fold: np.std would
        # return NaN, which propagates into ci_95's numerator and breaks
        # JSON serialization of the summary — report 0.0 instead.
        std = np.std(values, ddof=1) if n > 1 else 0.0
        ci_95 = 1.96 * std / np.sqrt(n) if n > 1 else 0.0
        aggregated[key] = {
            "mean": float(mean),
            "std": float(std),
            "ci_95": float(ci_95),
            "values": values.tolist(),
        }

    return aggregated
|
||||
|
||||
|
||||
Reference in New Issue
Block a user