Clean state

Johnny Fernandes
2026-04-30 01:25:39 +01:00
commit bb3dfb92d5
266 changed files with 37043 additions and 0 deletions
+62
tools/artifact_chk.py
@@ -0,0 +1,62 @@
#!/usr/bin/env python3
"""
Check expected classifier result artifacts.
Usage:
python tools/artifact_chk.py
python tools/artifact_chk.py --output-root outputs
"""
import argparse
import json
from pathlib import Path
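# Expected artifacts for every run_name found in configs/phase1|phase2/*.json:
#   <output-root>/logs/<run_name>.json
#   <output-root>/models/<run_name>_fold<k>_best.pt  (at least one fold)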
def parse_args():
p = argparse.ArgumentParser(description=__doc__)
p.add_argument("--output-root", default="outputs", help="Root with logs/models folders")
return p.parse_args()
def iter_config_paths(config_root: Path):
for sub in ("phase1", "phase2"):
yield from sorted((config_root / sub).glob("*.json"))
def main():
args = parse_args()
root = Path(__file__).resolve().parent.parent
config_root = root / "configs"
logs_dir = root / args.output_root / "logs"
models_dir = root / args.output_root / "models"
expected = []
for cfg_path in iter_config_paths(config_root):
with open(cfg_path) as f:
cfg = json.load(f)
run_name = cfg.get("run_name")
if run_name:
expected.append((run_name, cfg_path))
missing_logs = []
missing_models = []
for run_name, cfg_path in expected:
if not (logs_dir / f"{run_name}.json").exists():
missing_logs.append((run_name, cfg_path))
if not any(models_dir.glob(f"{run_name}_fold*_best.pt")):
missing_models.append((run_name, cfg_path))
print(f"Expected runs from configs: {len(expected)}")
print(f"Missing logs: {len(missing_logs)}")
for run_name, cfg_path in missing_logs:
print(f" - {run_name} ({cfg_path.relative_to(root)})")
print(f"Missing checkpoints: {len(missing_models)}")
for run_name, cfg_path in missing_models:
print(f" - {run_name} ({cfg_path.relative_to(root)})")
if missing_logs or missing_models:
raise SystemExit(1)
print("All expected artifacts found.")
if __name__ == "__main__":
main()
+262
tools/facecrop.py
@@ -0,0 +1,262 @@
#!/usr/bin/env python3
"""
Pre-crop face images using MTCNN and save to a new directory.
Runs face detection once over the dataset and saves cropped images to disk.
Training configs can then point at the pre-cropped directory — no per-epoch
MTCNN overhead during training.
The output mirrors the source structure exactly:
data/wiki/14/37591914.jpg -> cropped/classifier/wiki/14/37591914.jpg
Resumable: already-cropped images are skipped by default.
Usage:
python tools/facecrop.py
python tools/facecrop.py --data-dir data --output-dir cropped/classifier
python tools/facecrop.py --sources wiki inpainting --device cpu
python tools/facecrop.py --no-skip-existing # reprocess everything
"""
import argparse
import sys
import warnings
from pathlib import Path
# Suppress facenet_pytorch's torch.load FutureWarning — not fixable externally.
warnings.filterwarnings("ignore", message=".*weights_only.*", category=FutureWarning)
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))
SOURCES = ["wiki", "inpainting", "text2img", "insight"]
_DETECTORS: dict[tuple[str, str], object] = {}
def parse_args():
p = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
p.add_argument("--data-dir", default="data", help="Source dataset root (default: data)")
p.add_argument("--output-dir", default="cropped/classifier", help="Output root (default: cropped/classifier)")
p.add_argument("--margin", type=float, default=0.6, help="Face box margin as fraction of box size (default: 0.6)")
p.add_argument("--size", type=int, default=224, help="Output image size in px, square (default: 224)")
p.add_argument("--device", default=None, help="'cpu' or 'cuda'. Default: auto-detect")
p.add_argument("--sources", nargs="+", default=None, metavar="SOURCE",
help=f"Only process these sources. Default: all ({', '.join(SOURCES)})")
p.add_argument("--skip-existing", dest="skip_existing", action="store_true", default=True,
help="Skip images already present in output-dir (default: on, resumable)")
p.add_argument("--no-skip-existing", dest="skip_existing", action="store_false",
help="Re-process all images even if already cropped")
return p.parse_args()
# ── crop helpers ──────────────────────────────────────────────────────────────
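# Crop strategy: expand the detected box by `margin` (a fraction of the box size),
# square it around the box centre, clamp to the image bounds, then resize to
# `size` x `size`. Images with no detectable face fall back to a centre crop.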
def _crop_face(img, box, margin: float, size: int):
from PIL import Image as PILImage
x1, y1, x2, y2 = [float(v) for v in box]
bw, bh = x2 - x1, y2 - y1
mx, my = bw * margin / 2, bh * margin / 2
x1 -= mx; y1 -= my; x2 += mx; y2 += my
cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
side = max(x2 - x1, y2 - y1)
x1, y1 = cx - side / 2, cy - side / 2
x2, y2 = x1 + side, y1 + side
w, h = img.size
x1, y1 = max(0, x1), max(0, y1)
x2, y2 = min(w, x2), min(h, y2)
return img.crop((int(x1), int(y1), int(x2), int(y2))).resize((size, size), PILImage.BILINEAR)
def _center_crop(img, size: int):
from PIL import Image as PILImage
w, h = img.size
side = min(w, h)
left, top = (w - side) // 2, (h - side) // 2
return img.crop((left, top, left + side, top + side)).resize((size, size), PILImage.BILINEAR)
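# Two MTCNN instances are cached per device: a standard detector for pass 1, and a
# relaxed one (smaller min_face_size, lower thresholds) used in pass 2 on a
# 2x-upscaled copy of any image where pass 1 finds no face.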
def _get_detectors(device: str):
key_std = ("std", device)
key_relaxed = ("relaxed", device)
if key_std in _DETECTORS and key_relaxed in _DETECTORS:
return _DETECTORS[key_std], _DETECTORS[key_relaxed]
from facenet_pytorch import MTCNN
detector = MTCNN(
keep_all=False, select_largest=True,
min_face_size=15,
device=device, post_process=False,
)
detector_relaxed = MTCNN(
keep_all=False, select_largest=True,
min_face_size=10,
thresholds=[0.5, 0.6, 0.6],
device=device, post_process=False,
)
_DETECTORS[key_std] = detector
_DETECTORS[key_relaxed] = detector_relaxed
return detector, detector_relaxed
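# Illustrative use from a notebook (the image path is a placeholder):
#   from PIL import Image
#   cropper = FaceCropper(margin=0.6, size=224, device="cpu")
#   face = cropper(Image.open("example.jpg").convert("RGB"))  # PIL in, cropped PIL out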
class FaceCropper:
"""Reusable face cropper for notebooks/tools (not training pipeline)."""
def __init__(self, margin: float = 0.6, size: int = 224, device: str | None = None):
import torch
self.margin = margin
self.size = size
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
def __call__(self, img):
from PIL import Image as PILImage
detector, detector_relaxed = _get_detectors(self.device)
boxes, _ = detector.detect(img)
if boxes is not None and len(boxes) > 0:
return _crop_face(img, boxes[0], self.margin, self.size)
w, h = img.size
img2x = img.resize((w * 2, h * 2), PILImage.BILINEAR)
boxes2, _ = detector_relaxed.detect(img2x)
if boxes2 is not None and len(boxes2) > 0:
box_orig = [v / 2 for v in boxes2[0]]
return _crop_face(img, box_orig, self.margin, self.size)
return _center_crop(img, self.size)
# ── main ──────────────────────────────────────────────────────────────────────
def main():
args = parse_args()
import torch
from PIL import Image
from tqdm import tqdm
data_dir = Path(args.data_dir)
output_dir = Path(args.output_dir)
device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
sources = args.sources or SOURCES
if not data_dir.exists():
print(f"Error: data directory not found: {data_dir}")
sys.exit(1)
# Validate requested sources
for src in sources:
if not (data_dir / src).exists():
print(f"Error: source directory not found: {data_dir / src}")
sys.exit(1)
try:
import facenet_pytorch # noqa: F401
except ImportError:
print("Error: facenet_pytorch not installed.")
print(" Run: pip install facenet-pytorch")
sys.exit(1)
print(f"Data dir: {data_dir.resolve()}")
print(f"Output dir: {output_dir.resolve()}")
print(f"Sources: {', '.join(sources)}")
print(f"Device: {device}")
print(f"Margin: {args.margin} | Size: {args.size}px")
print(f"Skip exist: {args.skip_existing}")
detector, detector_relaxed = _get_detectors(device)
# Collect all image paths, grouped by source for per-source stats
all_paths: list[Path] = []
for src in sources:
for subdir in sorted((data_dir / src).iterdir()):
if subdir.is_dir():
all_paths.extend(sorted(subdir.glob("*.jpg")))
print(f"\nTotal images: {len(all_paths):,}\n")
n_processed = n_skipped = n_error = 0
# track per-source: detected / retry_detected / fallback
src_stats: dict[str, dict] = {s: {"detected": 0, "retry": 0, "fallback": 0} for s in sources}
for img_path in tqdm(all_paths, desc="Pre-cropping", unit="img"):
rel = img_path.relative_to(data_dir)
out_path = output_dir / rel
src_name = img_path.parent.parent.name # data/wiki/14/file.jpg -> wiki
if args.skip_existing and out_path.exists():
n_skipped += 1
continue
out_path.parent.mkdir(parents=True, exist_ok=True)
try:
img = Image.open(img_path).convert("RGB")
except Exception as exc:
tqdm.write(f"[WARN] Cannot open {img_path.name}: {exc}")
n_error += 1
continue
cropped = None
try:
# Pass 1: detect on original image
boxes, _ = detector.detect(img)
if boxes is not None and len(boxes) > 0:
cropped = _crop_face(img, boxes[0], args.margin, args.size)
src_stats[src_name]["detected"] += 1
else:
# Pass 2: upscale 2x and retry with relaxed thresholds
w, h = img.size
img2x = img.resize((w * 2, h * 2), Image.BILINEAR)
boxes2, _ = detector_relaxed.detect(img2x)
if boxes2 is not None and len(boxes2) > 0:
# boxes are in upscaled coords — divide by 2 to get original coords
box_orig = [v / 2 for v in boxes2[0]]
cropped = _crop_face(img, box_orig, args.margin, args.size)
src_stats[src_name]["retry"] += 1
else:
cropped = _center_crop(img, args.size)
src_stats[src_name]["fallback"] += 1
except Exception as exc:
tqdm.write(f"[WARN] Detection failed for {img_path.name}: {exc}")
cropped = _center_crop(img, args.size)
src_stats[src_name]["fallback"] += 1
cropped.save(out_path, quality=95)
n_processed += 1
total = n_processed + n_skipped
n_detected = sum(s["detected"] for s in src_stats.values())
n_retry = sum(s["retry"] for s in src_stats.values())
n_fallback = sum(s["fallback"] for s in src_stats.values())
denom = max(n_processed, 1)
print(f"\n{'' * 55}")
print(f" Total images : {total:>8,}")
print(f" Processed : {n_processed:>8,}")
print(f" Skipped (existed) : {n_skipped:>8,}")
print(f" Errors : {n_error:>8,}")
print(f" Pass-1 detected : {n_detected:>8,} ({n_detected / denom:.1%})")
print(f" Pass-2 detected : {n_retry:>8,} ({n_retry / denom:.1%}) ← 2x upscale retry")
print(f" Centre fallback : {n_fallback:>8,} ({n_fallback / denom:.1%})")
print()
print(f" {'Source':<12} {'pass-1':>8} {'pass-2':>8} {'fallback':>8} {'fallback%':>10}")
print(f" {''*12} {''*8} {''*8} {''*8} {''*10}")
for src in sources:
s = src_stats[src]
total_src = s["detected"] + s["retry"] + s["fallback"]
fb_pct = s["fallback"] / max(total_src, 1)
print(f" {src:<12} {s['detected']:>8,} {s['retry']:>8,} {s['fallback']:>8,} {fb_pct:>9.1%}")
print(f"{'' * 55}")
print(f" Output: {output_dir.resolve()}")
print()
print("Next step — update your config:")
print(f' "data_dir": "{output_dir}"')
print(f' remove "face_crop": true (images are already cropped)')
if __name__ == "__main__":
main()
+56
tools/download_data.py
@@ -0,0 +1,56 @@
"""
Download the DeepFakeFace dataset from HuggingFace and extract it.
Usage:
python tools/download_data.py
python tools/download_data.py --data-dir /mnt/data/DFF
"""
import argparse
import zipfile
from pathlib import Path
from huggingface_hub import snapshot_download
SOURCES = ["wiki", "inpainting", "text2img", "insight"]
def download(data_dir: Path) -> None:
print(f"Downloading dataset from HuggingFace into {data_dir}...")
snapshot_download(
repo_id="OpenRL/DeepFakeFace",
repo_type="dataset",
local_dir=data_dir,
)
for source in SOURCES:
zip_path = data_dir / f"{source}.zip"
target_dir = data_dir / source
if target_dir.exists():
print(f" {source}/ already extracted, skipping")
continue
if not zip_path.exists():
print(f" WARNING: {zip_path} not found, skipping")
continue
print(f" Extracting {zip_path.name}...")
with zipfile.ZipFile(zip_path, "r") as z:
z.extractall(data_dir)
print(f" Done -> {target_dir}")
print("\nVerifying...")
for source in SOURCES:
d = data_dir / source
count = sum(1 for _ in d.rglob("*.jpg")) if d.exists() else 0
print(f" {source}: {count} images")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--data-dir", default="data",
help="Directory to download into. Default: data",
)
args = parser.parse_args()
download(Path(args.data_dir))
+104
tools/gradcam.py
@@ -0,0 +1,104 @@
"""
Grad-CAM visualization for trained classifiers.
Generates heatmaps showing which image regions the model focused on,
overlaid on the original image. Targets the last Conv2d layer automatically.
Usage (from notebook or script):
from tools.gradcam import save_overlays
from src.evaluation.evaluate import predict_rows
records = predict_rows(model, test_dataset, raw_ds, test_idx, batch_size=32, device="cpu")
save_overlays(model, records, cfg, output_dir=Path("outputs/gradcam"), device="cpu")
Output: one PNG per selected sample, named 01_<stem>.png, 02_<stem>.png, ...
Selection: top_k//2 false positives (real predicted fake) and top_k//2 false negatives (fake predicted real).
"""
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from src.preprocessing import get_transforms
# Pick the last convolution layer as the Grad-CAM target
def find_conv(model):
for module in reversed(list(model.modules())):
if isinstance(module, nn.Conv2d):
return module
raise ValueError("Could not find a Conv2d layer for Grad-CAM.")
# Build a normalized Grad-CAM heatmap for one image tensor (shape: 1xCxHxW)
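# Standard Grad-CAM weighting: each channel weight is the spatial mean of the
# gradient of the logit w.r.t. that activation map, and the heatmap is
#   cam = ReLU(sum_k mean(d logit / d A_k) * A_k),
# min-max normalised to [0, 1] before being returned as a numpy array.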
def gradcam_map(model, image_tensor, device):
activations = []
gradients = []
target = find_conv(model)
def on_fwd(_, __, output):
activations.append(output.detach())
def on_bwd(_, __, grad_output):
gradients.append(grad_output[0].detach())
h_fwd = target.register_forward_hook(on_fwd)
h_bwd = target.register_full_backward_hook(on_bwd)
model.zero_grad(set_to_none=True)
logits = model(image_tensor.to(device)).squeeze()
logits.backward()
h_fwd.remove()
h_bwd.remove()
grads = gradients[0][0]
acts = activations[0][0]
weights = grads.mean(dim=(1, 2), keepdim=True)
cam = torch.relu((weights * acts).sum(dim=0))
cam = cam - cam.min()
cam = cam / cam.max().clamp(min=1e-8)
return cam.cpu().numpy()
# Save side-by-side input and Grad-CAM overlays for top-confidence errors
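# Labels follow the convention 0 = real, 1 = fake. The most confident errors are
# chosen: false positives sorted by prob_fake descending, false negatives ascending.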
def save_overlays(model, records, cfg, output_dir, device, *, top_k=8):
output_dir.mkdir(parents=True, exist_ok=True)
transform = get_transforms(train=False, image_size=cfg["image_size"])
false_pos = sorted(
(r for r in records if r["label"] == 0 and r["pred"] == 1),
key=lambda r: r["prob_fake"],
reverse=True,
)[: top_k // 2]
false_neg = sorted(
(r for r in records if r["label"] == 1 and r["pred"] == 0),
key=lambda r: r["prob_fake"],
)[: top_k // 2]
selected = [*false_pos, *false_neg]
total = len(selected)
for idx, record in enumerate(selected, start=1):
print(f"Grad-CAM: rendering {idx}/{total} for {Path(record['path']).name}")
image = Image.open(record["path"]).convert("RGB")
image_tensor = transform(image).unsqueeze(0)
heatmap = gradcam_map(model, image_tensor, device)
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(np.asarray(image))
axes[0].set_title("Input")
axes[0].axis("off")
axes[1].imshow(np.asarray(image))
axes[1].imshow(heatmap, cmap="jet", alpha=0.4, extent=(0, image.width, image.height, 0))
axes[1].set_title(
f"{Path(record['path']).name}\ntrue={record['label']} pred={record['pred']} p={record['prob_fake']:.3f}"
)
axes[1].axis("off")
fig.tight_layout()
fig.savefig(output_dir / f"{idx:02d}_{Path(record['path']).stem}.png", dpi=160)
plt.close(fig)
+98
tools/inference.py
@@ -0,0 +1,98 @@
"""
Run inference on a single image using a trained classifier.
Usage:
python tools/inference.py <image_path> <config.json>
python tools/inference.py <image_path> <config.json> --checkpoint <path>
"""
import argparse
import json
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
import torch
from PIL import Image
from src.models import get_model, load_checkpoint
from src.preprocessing import get_transforms
# Defaults checkpoint to outputs/models/{run_name}_best.pt when not supplied
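# Config keys used below: "run_name" (default checkpoint name), "image_size"
# (eval transforms), "backbone" (reporting); the full dict is passed to get_model().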
def predict(image_path, config_path, checkpoint_path=None):
image_path = Path(image_path)
config_path = Path(config_path)
if not image_path.exists():
print(f"Error: Image not found: {image_path}")
sys.exit(1)
if not config_path.exists():
print(f"Error: Config not found: {config_path}")
sys.exit(1)
try:
with open(config_path) as f:
cfg = json.load(f)
except json.JSONDecodeError as e:
print(f"Error: Invalid JSON in config: {e}")
sys.exit(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
# pretrained=False — we're loading a saved checkpoint, not ImageNet weights
model = get_model({**cfg, "pretrained": False})
except Exception as e:
print(f"Error: Failed to build model: {e}")
sys.exit(1)
if checkpoint_path is None:
checkpoint_path = ROOT / "outputs" / "models" / f"{cfg['run_name']}_best.pt"
else:
checkpoint_path = Path(checkpoint_path)
if not checkpoint_path.exists():
print(f"Error: Checkpoint not found: {checkpoint_path}")
sys.exit(1)
try:
load_checkpoint(model, checkpoint_path, device)
except Exception as e:
print(f"Error: Failed to load checkpoint: {e}")
sys.exit(1)
model.eval().to(device)
try:
transform = get_transforms(train=False, image_size=cfg["image_size"])
image = Image.open(image_path).convert("RGB")
tensor = transform(image).unsqueeze(0).to(device)
except Exception as e:
print(f"Error: Failed to load/preprocess image: {e}")
sys.exit(1)
with torch.no_grad():
logit = model(tensor).squeeze()
prob = torch.sigmoid(logit).item()
label = "FAKE" if prob >= 0.5 else "REAL"
confidence = prob if prob >= 0.5 else 1 - prob
print(f"Image : {image_path}")
print(f"Model : {cfg['run_name']} ({cfg['backbone']})")
print(f"Device: {device}")
print(f"Result: {label} (confidence: {confidence:.1%})")
print(f"P(fake): {prob:.4f} P(real): {1-prob:.4f}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("image_path", help="Path to the input image")
parser.add_argument("config_path", help="Path to the model config JSON")
parser.add_argument("--checkpoint", help="Optional path to model checkpoint")
args = parser.parse_args()
predict(args.image_path, args.config_path, args.checkpoint)
+288
tools/reevaluate.py
@@ -0,0 +1,288 @@
"""
Re-evaluate existing trained checkpoints with per-source metrics.
Loads each config, rebuilds CV splits (deterministic), loads the _best.pt
checkpoint per fold, runs predict_rows, and writes updated log files
with aggregate + per-source + pairwise metrics.
Usage:
python tools/reevaluate.py # re-evaluate all experiments
python tools/reevaluate.py p1_resnet18_baseline # specific experiments
python tools/reevaluate.py --data-dir /mnt/data/DFF
python tools/reevaluate.py --use-gpu
"""
import argparse
import json
import sys
import warnings
from pathlib import Path
warnings.filterwarnings("ignore", message="Corrupt EXIF data", category=UserWarning)
warnings.filterwarnings("ignore", message=".*weights_only.*", category=FutureWarning)
# Ensure classifier/ is on sys.path so `src.*` imports work
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
# Config paths (data_dir) are relative to the project root
PROJECT_ROOT = ROOT.parent
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"run_names", nargs="*",
help="Run names to re-evaluate (matches log filenames). Default: all.",
)
parser.add_argument(
"--data-dir", default=None,
help="Override cfg['data_dir'] for this run.",
)
parser.add_argument(
"--output-root", default="outputs",
help="Directory where models/logs live. Default: outputs",
)
parser.add_argument(
"--use-gpu", action="store_true",
help="Use GPU for evaluation.",
)
return parser.parse_args()
# Map run_name -> config path (relative to classifier/)
CONFIG_MAP = {
# Phase 1
"p1_resnet18_baseline": "configs/phase1/p1_resnet18_baseline.json",
"p1_simplecnn_baseline": "configs/phase1/p1_simplecnn_baseline.json",
# Phase 2a shortcut / holdout
"p2a_t1_original": "configs/phase2/p2a_t1_original.json",
"p2a_t2_real_norm": "configs/phase2/p2a_t2_real_norm.json",
"p2a_t3_holdout_text2img": "configs/phase2/p2a_t3_holdout_text2img.json",
"p2a_t3_holdout_inpainting": "configs/phase2/p2a_t3_holdout_inpainting.json",
"p2a_t3_holdout_insight": "configs/phase2/p2a_t3_holdout_insight.json",
# Phase 2b resolution
"p2b_resnet18_224": "configs/phase2/p2b_resnet18_224.json",
"p2b_simplecnn_224": "configs/phase2/p2b_simplecnn_224.json",
# Phase 2c face crop
"p2c_resnet18_facecrop": "configs/phase2/p2c_resnet18_facecrop.json",
"p2c_simplecnn_facecrop": "configs/phase2/p2c_simplecnn_facecrop.json",
# Phase 2d augmentation
"p2d_resnet18_aug": "configs/phase2/p2d_resnet18_aug.json",
"p2d_simplecnn_aug": "configs/phase2/p2d_simplecnn_aug.json",
# Phase 2e face crop + aug
"p2e_resnet18_facecrop_aug": "configs/phase2/p2e_resnet18_facecrop_aug.json",
"p2e_simplecnn_facecrop_aug": "configs/phase2/p2e_simplecnn_facecrop_aug.json",
}
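# Each run's log at <output-root>/logs/<run_name>.json is rewritten with per-fold
# results plus aggregated_metrics, aggregated_per_source and aggregated_pairwise;
# the original training history and config are preserved when the old log exists.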
def main():
args = parse_args()
import numpy as np
import torch
from src.data import DFFDataset, apply_subsample, build_transforms, get_splits
from src.evaluation.evaluate import predict_rows
from src.evaluation.metrics import calc_metrics, source_metrics, pair_metrics
from src.models import get_model
from src.utils import load_config
from src.utils.cross_validation import aggregate_fold_metrics
output_root = Path(args.output_root)
logs_dir = output_root / "logs"
models_dir = output_root / "models"
# Determine which experiments to re-evaluate
if args.run_names:
run_names = args.run_names
else:
run_names = sorted(
p.stem for p in logs_dir.glob("*.json")
if p.stem in CONFIG_MAP
)
device = "cuda" if args.use_gpu and torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
for run_name in run_names:
config_rel = CONFIG_MAP.get(run_name)
if config_rel is None:
print(f"\nSkipping {run_name}: no config mapping")
continue
config_path = ROOT / config_rel
if not config_path.exists():
print(f"\nSkipping {run_name}: config not found ({config_rel})")
continue
# Check that at least one checkpoint exists
checkpoints = sorted(models_dir.glob(f"{run_name}_fold*_best.pt"))
if not checkpoints:
print(f"\nSkipping {run_name}: no checkpoints found")
continue
print(f"\n{'='*60}")
print(f"Re-evaluating: {run_name}")
print(f" Config: {config_rel}")
print(f" Checkpoints: {len(checkpoints)}")
print(f"{'='*60}")
# Load config
cfg = load_config(config_path)
seed = cfg.get("seed", 42)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
deterministic = cfg.get("deterministic", False)
torch.backends.cudnn.deterministic = deterministic
torch.backends.cudnn.benchmark = not deterministic
data_dir = args.data_dir or cfg.get("data_dir", "data")
# Config paths are relative to the project root, not classifier/
if not Path(data_dir).is_absolute():
data_dir = str(PROJECT_ROOT / data_dir)
# Build dataset
raw_ds = DFFDataset(data_dir, sources=cfg.get("dataset_sources"))
sampled = apply_subsample(raw_ds, cfg)
if sampled is not None:
n_samples, total = sampled
print(f" Subsampled to {n_samples}/{total} samples")
# Build CV splits and transforms (deterministic same as training)
splits = get_splits(raw_ds, cfg)
transform_builder = build_transforms(raw_ds, cfg, augment=cfg.get("augment"))
n_folds = len(splits)
fold_results = []
for fold_idx in range(n_folds):
train_idx, val_idx, test_idx = splits[fold_idx]
checkpoint_path = models_dir / f"{run_name}_fold{fold_idx}_best.pt"
if not checkpoint_path.exists():
print(f" Fold {fold_idx}: checkpoint missing, skipping")
continue
# Rebuild model and load checkpoint
model = get_model(cfg)
model.load_state_dict(
torch.load(checkpoint_path, map_location=device, weights_only=True)
)
model.to(device).eval()
# Build test dataset for this fold
if cfg.get("normalization") == "real_norm":
from src.preprocessing.pipeline import compute_real_stats
norm_mean, norm_std = compute_real_stats(raw_ds, train_idx)
else:
norm_mean = norm_std = None
test_dataset = transform_builder(
test_idx, train=False,
normalize_mean=norm_mean, normalize_std=norm_std,
)
records = predict_rows(
model, test_dataset, raw_ds, test_idx,
cfg["batch_size"], device, num_workers=4,
)
# Compute metrics
test_metrics = calc_metrics(records)
src_metrics = source_metrics(records)
pairwise = pair_metrics(records)
fold_result = {
"fold": fold_idx,
"train_size": len(train_idx),
"val_size": len(val_idx),
"test_size": len(test_idx),
"test_metrics": test_metrics,
"source_metrics": src_metrics,
"pair_metrics": pairwise,
}
fold_results.append(fold_result)
print(f" Fold {fold_idx}: auc={test_metrics.get('auc_roc', '?'):.4f} "
f"acc={test_metrics.get('accuracy', '?'):.4f} "
f"f1={test_metrics.get('f1', '?'):.4f}")
for source, sm in sorted(src_metrics.items()):
pa = sm.get("pairwise_auc")
dr = sm.get("detection_rate")
label = (f"pairwise_auc={pa:.4f}" if pa is not None
else f"detection_rate={dr:.4f}" if dr is not None else "")
print(f" {source}: {label}")
if not fold_results:
print(f" No folds evaluated for {run_name}")
continue
# Aggregate across folds
test_metrics_list = [f["test_metrics"] for f in fold_results]
aggregated = aggregate_fold_metrics(test_metrics_list)
# Aggregate per-source metrics
all_sources = sorted({s for f in fold_results for s in f["source_metrics"]})
aggregated_per_source = {}
for source in all_sources:
source_fold_metrics = []
for f in fold_results:
sm = f["source_metrics"].get(source)
if sm:
source_fold_metrics.append({
k: v for k, v in sm.items()
if isinstance(v, (int, float)) and k != "fold"
})
if source_fold_metrics:
aggregated_per_source[source] = aggregate_fold_metrics(source_fold_metrics)
# Aggregate pairwise metrics
all_pairs = sorted({p for f in fold_results for p in f["pair_metrics"]})
aggregated_pairwise = {}
for pair in all_pairs:
pair_fold_metrics = []
for f in fold_results:
pm = f["pair_metrics"].get(pair)
if pm:
pair_fold_metrics.append({
k: v for k, v in pm.items()
if isinstance(v, (int, float)) and k not in ("fold", "n")
})
if pair_fold_metrics:
aggregated_pairwise[pair] = aggregate_fold_metrics(pair_fold_metrics)
# Load existing log to preserve training history
log_path = logs_dir / f"{run_name}.json"
if log_path.exists():
with open(log_path) as f:
existing = json.load(f)
# Keep the training history from the original log
for fr_new in fold_results:
for fr_old in existing.get("fold_results", []):
if fr_old["fold"] == fr_new["fold"]:
fr_new["history"] = fr_old.get("history")
break
else:
existing = {}
results = {
"run_name": run_name,
"n_folds": n_folds,
"fold_results": fold_results,
"aggregated_metrics": aggregated,
"aggregated_per_source": aggregated_per_source,
"aggregated_pairwise": aggregated_pairwise,
"config": existing.get("config", cfg),
}
with open(log_path, "w") as f:
json.dump(results, f, indent=2)
print(f" Saved: {log_path}")
print("\nDone.")
if __name__ == "__main__":
main()