Clean state
This commit is contained in:
@@ -0,0 +1,62 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Check expected classifier result artifacts.
|
||||
|
||||
Usage:
|
||||
python tools/artifact_chk.py
|
||||
python tools/artifact_chk.py --output-root outputs
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def parse_args():
    """Parse command-line options for the artifact checker."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--output-root",
        default="outputs",
        help="Root with logs/models folders",
    )
    return parser.parse_args()
|
||||
|
||||
|
||||
def iter_config_paths(config_root: Path):
    """Yield every phase1/phase2 JSON config under *config_root*, sorted per phase."""
    for phase in ("phase1", "phase2"):
        phase_dir = config_root / phase
        for cfg_path in sorted(phase_dir.glob("*.json")):
            yield cfg_path
|
||||
|
||||
|
||||
def main():
    """Verify that every config's expected log and checkpoint artifacts exist."""
    args = parse_args()
    root = Path(__file__).resolve().parent.parent
    config_root = root / "configs"
    logs_dir = root / args.output_root / "logs"
    models_dir = root / args.output_root / "models"

    # Collect (run_name, config_path) pairs for every config that names a run.
    expected = []
    for cfg_path in iter_config_paths(config_root):
        with open(cfg_path) as fh:
            run_name = json.load(fh).get("run_name")
        if run_name:
            expected.append((run_name, cfg_path))

    # A run is complete when its log file and at least one fold checkpoint exist.
    missing_logs = [
        (name, path) for name, path in expected
        if not (logs_dir / f"{name}.json").exists()
    ]
    missing_models = [
        (name, path) for name, path in expected
        if not any(models_dir.glob(f"{name}_fold*_best.pt"))
    ]

    print(f"Expected runs from configs: {len(expected)}")
    print(f"Missing logs: {len(missing_logs)}")
    for name, path in missing_logs:
        print(f" - {name} ({path.relative_to(root)})")
    print(f"Missing checkpoints: {len(missing_models)}")
    for name, path in missing_models:
        print(f" - {name} ({path.relative_to(root)})")

    # Non-zero exit so CI can gate on missing artifacts.
    if missing_logs or missing_models:
        raise SystemExit(1)
    print("All expected artifacts found.")


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,262 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Pre-crop face images using MTCNN and save to a new directory.
|
||||
|
||||
Runs face detection once over the dataset and saves cropped images to disk.
|
||||
Training configs can then point at the pre-cropped directory — no per-epoch
|
||||
MTCNN overhead during training.
|
||||
|
||||
The output mirrors the source structure exactly:
|
||||
data/wiki/14/37591914.jpg -> cropped/classifier/wiki/14/37591914.jpg
|
||||
|
||||
Resumable: already-cropped images are skipped by default.
|
||||
|
||||
Usage:
|
||||
python tools/facecrop.py
|
||||
python tools/facecrop.py --data-dir data --output-dir cropped/classifier
|
||||
python tools/facecrop.py --sources wiki inpainting --device cpu
|
||||
python tools/facecrop.py --no-skip-existing # reprocess everything
|
||||
"""
|
||||
import argparse
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
# Suppress facenet_pytorch's torch.load FutureWarning — not fixable externally.
|
||||
warnings.filterwarnings("ignore", message=".*weights_only.*", category=FutureWarning)
|
||||
|
||||
ROOT = Path(__file__).resolve().parent.parent
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
SOURCES = ["wiki", "inpainting", "text2img", "insight"]
# Per-device MTCNN cache keyed by ("std" | "relaxed", device).
_DETECTORS: dict[tuple[str, str], object] = {}


def parse_args():
    """Parse CLI options for the pre-cropping run."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--data-dir", default="data",
                        help="Source dataset root (default: data)")
    parser.add_argument("--output-dir", default="cropped/classifier",
                        help="Output root (default: cropped/classifier)")
    parser.add_argument("--margin", type=float, default=0.6,
                        help="Face box margin as fraction of box size (default: 0.6)")
    parser.add_argument("--size", type=int, default=224,
                        help="Output image size in px, square (default: 224)")
    parser.add_argument("--device", default=None,
                        help="'cpu' or 'cuda'. Default: auto-detect")
    parser.add_argument("--sources", nargs="+", default=None, metavar="SOURCE",
                        help=f"Only process these sources. Default: all ({', '.join(SOURCES)})")
    # Paired flags: --skip-existing (default, resumable) / --no-skip-existing.
    parser.add_argument("--skip-existing", dest="skip_existing", action="store_true",
                        default=True,
                        help="Skip images already present in output-dir (default: on, resumable)")
    parser.add_argument("--no-skip-existing", dest="skip_existing", action="store_false",
                        help="Re-process all images even if already cropped")
    return parser.parse_args()
|
||||
|
||||
|
||||
# ── crop helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _crop_face(img, box, margin: float, size: int):
    """Crop a square region around *box* (expanded by *margin*) and resize to *size*."""
    from PIL import Image as PILImage

    left, top, right, bottom = (float(v) for v in box)
    # Grow the detector box by the margin fraction, split evenly per side.
    pad_x = (right - left) * margin / 2
    pad_y = (bottom - top) * margin / 2
    left, right = left - pad_x, right + pad_x
    top, bottom = top - pad_y, bottom + pad_y
    # Re-centre as a square around the padded box.
    center_x = (left + right) / 2
    center_y = (top + bottom) / 2
    side = max(right - left, bottom - top)
    left, top = center_x - side / 2, center_y - side / 2
    right, bottom = left + side, top + side
    # Clamp to the image bounds (crop may end up non-square at the edges).
    width, height = img.size
    left, top = max(0, left), max(0, top)
    right, bottom = min(width, right), min(height, bottom)
    region = img.crop((int(left), int(top), int(right), int(bottom)))
    return region.resize((size, size), PILImage.BILINEAR)
|
||||
|
||||
|
||||
def _center_crop(img, size: int):
    """Fallback: crop the largest centred square from *img* and resize to *size*."""
    from PIL import Image as PILImage

    width, height = img.size
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    square = img.crop((left, top, left + side, top + side))
    return square.resize((size, size), PILImage.BILINEAR)
|
||||
|
||||
|
||||
def _get_detectors(device: str):
    """Return (standard, relaxed) MTCNN detectors for *device*, cached per device."""
    std_key = ("std", device)
    relaxed_key = ("relaxed", device)
    if std_key in _DETECTORS and relaxed_key in _DETECTORS:
        return _DETECTORS[std_key], _DETECTORS[relaxed_key]

    # Lazy import: facenet_pytorch is only needed when detection actually runs.
    from facenet_pytorch import MTCNN

    # Standard pass: default thresholds, slightly larger minimum face size.
    standard = MTCNN(
        keep_all=False, select_largest=True,
        min_face_size=15,
        device=device, post_process=False,
    )
    # Relaxed pass: smaller faces + lower thresholds, used for the 2x-upscale retry.
    relaxed = MTCNN(
        keep_all=False, select_largest=True,
        min_face_size=10,
        thresholds=[0.5, 0.6, 0.6],
        device=device, post_process=False,
    )
    _DETECTORS[std_key] = standard
    _DETECTORS[relaxed_key] = relaxed
    return standard, relaxed
|
||||
|
||||
|
||||
class FaceCropper:
|
||||
"""Reusable face cropper for notebooks/tools (not training pipeline)."""
|
||||
|
||||
def __init__(self, margin: float = 0.6, size: int = 224, device: str | None = None):
|
||||
import torch
|
||||
|
||||
self.margin = margin
|
||||
self.size = size
|
||||
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
def __call__(self, img):
|
||||
from PIL import Image as PILImage
|
||||
|
||||
detector, detector_relaxed = _get_detectors(self.device)
|
||||
|
||||
boxes, _ = detector.detect(img)
|
||||
if boxes is not None and len(boxes) > 0:
|
||||
return _crop_face(img, boxes[0], self.margin, self.size)
|
||||
|
||||
w, h = img.size
|
||||
img2x = img.resize((w * 2, h * 2), PILImage.BILINEAR)
|
||||
boxes2, _ = detector_relaxed.detect(img2x)
|
||||
if boxes2 is not None and len(boxes2) > 0:
|
||||
box_orig = [v / 2 for v in boxes2[0]]
|
||||
return _crop_face(img, box_orig, self.margin, self.size)
|
||||
return _center_crop(img, self.size)
|
||||
|
||||
|
||||
# ── main ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main():
    """Pre-crop every source image once and mirror the tree under --output-dir.

    For each image: pass-1 MTCNN detection, then a 2x-upscale retry with
    relaxed thresholds, then a centre-crop fallback. Prints aggregate and
    per-source statistics at the end. Resumable via --skip-existing.
    """
    args = parse_args()

    # Heavy imports deferred so `--help` stays fast.
    import torch
    from PIL import Image
    from tqdm import tqdm

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    sources = args.sources or SOURCES

    if not data_dir.exists():
        print(f"Error: data directory not found: {data_dir}")
        sys.exit(1)

    # Validate requested sources
    for src in sources:
        if not (data_dir / src).exists():
            print(f"Error: source directory not found: {data_dir / src}")
            sys.exit(1)

    # Fail early with an actionable message if the detector package is missing.
    try:
        import facenet_pytorch  # noqa: F401
    except ImportError:
        print("Error: facenet_pytorch not installed.")
        print(" Run: pip install facenet-pytorch")
        sys.exit(1)

    print(f"Data dir: {data_dir.resolve()}")
    print(f"Output dir: {output_dir.resolve()}")
    print(f"Sources: {', '.join(sources)}")
    print(f"Device: {device}")
    print(f"Margin: {args.margin} | Size: {args.size}px")
    print(f"Skip exist: {args.skip_existing}")

    detector, detector_relaxed = _get_detectors(device)

    # Collect all image paths, grouped by source for per-source stats
    all_paths: list[Path] = []
    for src in sources:
        for subdir in sorted((data_dir / src).iterdir()):
            if subdir.is_dir():
                all_paths.extend(sorted(subdir.glob("*.jpg")))

    print(f"\nTotal images: {len(all_paths):,}\n")

    n_processed = n_skipped = n_error = 0
    # track per-source: detected / retry_detected / fallback
    src_stats: dict[str, dict] = {s: {"detected": 0, "retry": 0, "fallback": 0} for s in sources}

    for img_path in tqdm(all_paths, desc="Pre-cropping", unit="img"):
        # Output path mirrors the source layout exactly.
        rel = img_path.relative_to(data_dir)
        out_path = output_dir / rel
        src_name = img_path.parent.parent.name  # data/wiki/14/file.jpg -> wiki

        if args.skip_existing and out_path.exists():
            n_skipped += 1
            continue

        out_path.parent.mkdir(parents=True, exist_ok=True)

        # Unreadable files are counted and skipped, not fatal.
        try:
            img = Image.open(img_path).convert("RGB")
        except Exception as exc:
            tqdm.write(f"[WARN] Cannot open {img_path.name}: {exc}")
            n_error += 1
            continue

        cropped = None
        try:
            # Pass 1: detect on original image
            boxes, _ = detector.detect(img)
            if boxes is not None and len(boxes) > 0:
                cropped = _crop_face(img, boxes[0], args.margin, args.size)
                src_stats[src_name]["detected"] += 1
            else:
                # Pass 2: upscale 2x and retry with relaxed thresholds
                w, h = img.size
                img2x = img.resize((w * 2, h * 2), Image.BILINEAR)
                boxes2, _ = detector_relaxed.detect(img2x)
                if boxes2 is not None and len(boxes2) > 0:
                    # boxes are in upscaled coords — divide by 2 to get original coords
                    box_orig = [v / 2 for v in boxes2[0]]
                    cropped = _crop_face(img, box_orig, args.margin, args.size)
                    src_stats[src_name]["retry"] += 1
                else:
                    cropped = _center_crop(img, args.size)
                    src_stats[src_name]["fallback"] += 1
        except Exception as exc:
            # Detection failure degrades to a centre crop so the run keeps going.
            tqdm.write(f"[WARN] Detection failed for {img_path.name}: {exc}")
            cropped = _center_crop(img, args.size)
            src_stats[src_name]["fallback"] += 1

        cropped.save(out_path, quality=95)
        n_processed += 1

    # ── summary report ────────────────────────────────────────────────────
    total = n_processed + n_skipped
    n_detected = sum(s["detected"] for s in src_stats.values())
    n_retry = sum(s["retry"] for s in src_stats.values())
    n_fallback = sum(s["fallback"] for s in src_stats.values())
    # Percentages are over processed images only; guard against div-by-zero.
    denom = max(n_processed, 1)

    print(f"\n{'─' * 55}")
    print(f" Total images : {total:>8,}")
    print(f" Processed : {n_processed:>8,}")
    print(f" Skipped (existed) : {n_skipped:>8,}")
    print(f" Errors : {n_error:>8,}")
    print(f" Pass-1 detected : {n_detected:>8,} ({n_detected / denom:.1%})")
    print(f" Pass-2 detected : {n_retry:>8,} ({n_retry / denom:.1%}) ← 2x upscale retry")
    print(f" Centre fallback : {n_fallback:>8,} ({n_fallback / denom:.1%})")
    print()
    print(f" {'Source':<12} {'pass-1':>8} {'pass-2':>8} {'fallback':>8} {'fallback%':>10}")
    print(f" {'─'*12} {'─'*8} {'─'*8} {'─'*8} {'─'*10}")
    for src in sources:
        s = src_stats[src]
        total_src = s["detected"] + s["retry"] + s["fallback"]
        fb_pct = s["fallback"] / max(total_src, 1)
        print(f" {src:<12} {s['detected']:>8,} {s['retry']:>8,} {s['fallback']:>8,} {fb_pct:>9.1%}")
    print(f"{'─' * 55}")
    print(f" Output: {output_dir.resolve()}")
    print()
    print("Next step — update your config:")
    print(f' "data_dir": "{output_dir}"')
    print(f' remove "face_crop": true (images are already cropped)')


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
Download the DeepFakeFace dataset from HuggingFace and extract it.
|
||||
|
||||
Usage:
|
||||
python tools/download_data.py
|
||||
python tools/download_data.py --data-dir /mnt/data/DFF
|
||||
"""
|
||||
import argparse
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
SOURCES = ["wiki", "inpainting", "text2img", "insight"]
|
||||
|
||||
|
||||
def download(data_dir: Path) -> None:
    """Fetch the DeepFakeFace snapshot, unpack each source zip, and verify counts."""
    print(f"Downloading dataset from HuggingFace into {data_dir}...")
    snapshot_download(
        repo_id="OpenRL/DeepFakeFace",
        repo_type="dataset",
        local_dir=data_dir,
    )

    for source in SOURCES:
        target_dir = data_dir / source
        zip_path = data_dir / f"{source}.zip"

        # Resumable: a previously extracted source is left untouched.
        if target_dir.exists():
            print(f" {source}/ already extracted, skipping")
        elif not zip_path.exists():
            print(f" WARNING: {zip_path} not found, skipping")
        else:
            print(f" Extracting {zip_path.name}...")
            with zipfile.ZipFile(zip_path, "r") as archive:
                archive.extractall(data_dir)
            print(f" Done -> {target_dir}")

    # Sanity check: count the jpgs that ended up under each source directory.
    print("\nVerifying...")
    for source in SOURCES:
        source_dir = data_dir / source
        if source_dir.exists():
            count = sum(1 for _ in source_dir.rglob("*.jpg"))
        else:
            count = 0
        print(f" {source}: {count} images")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: choose where to download and extract the dataset.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--data-dir", default="data",
        help="Directory to download into. Default: data",
    )
    args = parser.parse_args()
    download(Path(args.data_dir))
|
||||
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
Grad-CAM visualization for trained classifiers.
|
||||
|
||||
Generates heatmaps showing which image regions the model focused on,
|
||||
overlaid on the original image. Targets the last Conv2d layer automatically.
|
||||
|
||||
Usage (from notebook or script):
|
||||
from tools.gradcam import save_overlays
|
||||
from src.evaluation.evaluate import predict_rows
|
||||
|
||||
records = predict_rows(model, test_dataset, raw_ds, test_idx, batch_size=32, device="cpu")
|
||||
save_overlays(model, records, cfg, output_dir=Path("outputs/gradcam"), device="cpu")
|
||||
|
||||
Output: one PNG per selected sample, named 01_<stem>.png, 02_<stem>.png, ...
|
||||
top_k//2 false positives (real predicted fake) and top_k//2 false negatives.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
|
||||
from src.preprocessing import get_transforms
|
||||
|
||||
|
||||
# Pick the last convolution layer as the Grad-CAM target
|
||||
def find_conv(model):
    """Return the last Conv2d module in *model* — the Grad-CAM target layer."""
    convs = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
    if not convs:
        raise ValueError("Could not find a Conv2d layer for Grad-CAM.")
    return convs[-1]
|
||||
|
||||
|
||||
# Build a normalized Grad-CAM heatmap for one image tensor (shape: 1xCxHxW)
|
||||
def gradcam_map(model, image_tensor, device):
    """Compute a [0, 1]-normalized Grad-CAM heatmap for one 1xCxHxW tensor."""
    target_layer = find_conv(model)
    captured_acts = []
    captured_grads = []

    def capture_activation(_module, _inputs, output):
        captured_acts.append(output.detach())

    def capture_gradient(_module, _grad_in, grad_out):
        captured_grads.append(grad_out[0].detach())

    forward_handle = target_layer.register_forward_hook(capture_activation)
    backward_handle = target_layer.register_full_backward_hook(capture_gradient)

    # One forward + backward pass populates both capture lists.
    model.zero_grad(set_to_none=True)
    model(image_tensor.to(device)).squeeze().backward()

    forward_handle.remove()
    backward_handle.remove()

    # Batch item 0: channel weights = spatial mean of the gradients.
    acts = captured_acts[0][0]
    grads = captured_grads[0][0]
    channel_weights = grads.mean(dim=(1, 2), keepdim=True)
    cam = torch.relu((channel_weights * acts).sum(dim=0))
    # Min-max normalize, guarding against an all-zero map.
    cam = cam - cam.min()
    cam = cam / cam.max().clamp(min=1e-8)
    return cam.cpu().numpy()
|
||||
|
||||
|
||||
# Save side-by-side input and Grad-CAM overlays for top-confidence errors
|
||||
def save_overlays(model, records, cfg, output_dir, device, *, top_k=8):
    """Render Grad-CAM overlays for the most confident misclassifications.

    Selects up to top_k // 2 false positives (label 0 predicted 1, highest
    prob_fake first) and top_k // 2 false negatives (label 1 predicted 0,
    lowest prob_fake first) from *records*, then writes one side-by-side PNG
    per sample into *output_dir*, named NN_<stem>.png.

    records: dicts with "path", "label", "pred", "prob_fake" keys — assumes
    the shape produced by predict_rows (see module docstring). cfg must
    provide "image_size".
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    # Same eval-time transform the model was validated with.
    transform = get_transforms(train=False, image_size=cfg["image_size"])

    # Most confidently wrong "fake" calls on real images.
    false_pos = sorted(
        (r for r in records if r["label"] == 0 and r["pred"] == 1),
        key=lambda r: r["prob_fake"],
        reverse=True,
    )[: top_k // 2]
    # Most confidently wrong "real" calls on fake images (lowest prob_fake).
    false_neg = sorted(
        (r for r in records if r["label"] == 1 and r["pred"] == 0),
        key=lambda r: r["prob_fake"],
    )[: top_k // 2]
    selected = [*false_pos, *false_neg]

    total = len(selected)
    for idx, record in enumerate(selected, start=1):
        print(f"Grad-CAM: rendering {idx}/{total} for {Path(record['path']).name}")
        image = Image.open(record["path"]).convert("RGB")
        image_tensor = transform(image).unsqueeze(0)
        heatmap = gradcam_map(model, image_tensor, device)

        fig, axes = plt.subplots(1, 2, figsize=(8, 4))
        axes[0].imshow(np.asarray(image))
        axes[0].set_title("Input")
        axes[0].axis("off")

        # Overlay heatmap stretched to the original image extent.
        axes[1].imshow(np.asarray(image))
        axes[1].imshow(heatmap, cmap="jet", alpha=0.4, extent=(0, image.width, image.height, 0))
        axes[1].set_title(
            f"{Path(record['path']).name}\ntrue={record['label']} pred={record['pred']} p={record['prob_fake']:.3f}"
        )
        axes[1].axis("off")

        fig.tight_layout()
        fig.savefig(output_dir / f"{idx:02d}_{Path(record['path']).stem}.png", dpi=160)
        # Close explicitly so long runs don't accumulate open figures.
        plt.close(fig)
|
||||
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Run inference on a single image using a trained classifier.
|
||||
|
||||
Usage:
|
||||
python tools/inference.py <image_path> <config.json>
|
||||
python tools/inference.py <image_path> <config.json> --checkpoint <path>
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from src.models import get_model, load_checkpoint
|
||||
from src.preprocessing import get_transforms
|
||||
|
||||
|
||||
# Defaults checkpoint to outputs/models/{run_name}_best.pt when not supplied
|
||||
# Defaults checkpoint to outputs/models/{run_name}_best.pt when not supplied
def predict(image_path, config_path, checkpoint_path=None):
    """Classify one image as REAL/FAKE with a trained checkpoint and print the verdict.

    Prints an error and exits with status 1 on any missing file, invalid
    config JSON, or model/checkpoint/preprocessing failure — this is a CLI
    helper, not a library function.
    """
    image_path = Path(image_path)
    config_path = Path(config_path)

    if not image_path.exists():
        print(f"Error: Image not found: {image_path}")
        sys.exit(1)

    if not config_path.exists():
        print(f"Error: Config not found: {config_path}")
        sys.exit(1)

    try:
        with open(config_path) as f:
            cfg = json.load(f)
    except json.JSONDecodeError as e:
        print(f"Error: Invalid JSON in config: {e}")
        sys.exit(1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        # pretrained=False — we're loading a saved checkpoint, not ImageNet weights
        model = get_model({**cfg, "pretrained": False})
    except Exception as e:
        print(f"Error: Failed to build model: {e}")
        sys.exit(1)

    if checkpoint_path is None:
        # NOTE(review): training elsewhere writes {run_name}_fold*_best.pt —
        # confirm a plain {run_name}_best.pt exists for single-image inference.
        checkpoint_path = ROOT / "outputs" / "models" / f"{cfg['run_name']}_best.pt"
    else:
        checkpoint_path = Path(checkpoint_path)

    if not checkpoint_path.exists():
        print(f"Error: Checkpoint not found: {checkpoint_path}")
        sys.exit(1)

    try:
        load_checkpoint(model, checkpoint_path, device)
    except Exception as e:
        print(f"Error: Failed to load checkpoint: {e}")
        sys.exit(1)

    model.eval().to(device)

    try:
        # Eval-time transform only (no augmentation), matching training config.
        transform = get_transforms(train=False, image_size=cfg["image_size"])
        image = Image.open(image_path).convert("RGB")
        tensor = transform(image).unsqueeze(0).to(device)
    except Exception as e:
        print(f"Error: Failed to load/preprocess image: {e}")
        sys.exit(1)

    # Single-logit binary head: sigmoid gives P(fake).
    with torch.no_grad():
        logit = model(tensor).squeeze()
        prob = torch.sigmoid(logit).item()

    label = "FAKE" if prob >= 0.5 else "REAL"
    confidence = prob if prob >= 0.5 else 1 - prob

    print(f"Image : {image_path}")
    print(f"Model : {cfg['run_name']} ({cfg['backbone']})")
    print(f"Device: {device}")
    print(f"Result: {label} (confidence: {confidence:.1%})")
    print(f"P(fake): {prob:.4f} P(real): {1-prob:.4f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: image + config are positional, checkpoint optional.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("image_path", help="Path to the input image")
    parser.add_argument("config_path", help="Path to the model config JSON")
    parser.add_argument("--checkpoint", help="Optional path to model checkpoint")
    args = parser.parse_args()
    predict(args.image_path, args.config_path, args.checkpoint)
|
||||
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
Re-evaluate existing trained checkpoints with per-source metrics.
|
||||
|
||||
Loads each config, rebuilds CV splits (deterministic), loads the _best.pt
|
||||
checkpoint per fold, runs predict_rows, and writes updated log files
|
||||
with aggregate + per-source + pairwise metrics.
|
||||
|
||||
Usage:
|
||||
python tools/reevaluate.py # re-evaluate all experiments
|
||||
python tools/reevaluate.py p1_resnet18_baseline # specific experiments
|
||||
python tools/reevaluate.py --data-dir /mnt/data/DFF
|
||||
python tools/reevaluate.py --use-gpu
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
warnings.filterwarnings("ignore", message="Corrupt EXIF data", category=UserWarning)
|
||||
warnings.filterwarnings("ignore", message=".*weights_only.*", category=FutureWarning)
|
||||
|
||||
# Ensure classifier/ is on sys.path so `src.*` imports work
|
||||
ROOT = Path(__file__).resolve().parent.parent
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
# Config paths (data_dir) are relative to the project root
|
||||
PROJECT_ROOT = ROOT.parent
|
||||
|
||||
|
||||
def parse_args():
    """Parse CLI options: run-name filter, data/output overrides, GPU flag."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("run_names", nargs="*",
                        help="Run names to re-evaluate (matches log filenames). Default: all.")
    parser.add_argument("--data-dir", default=None,
                        help="Override cfg['data_dir'] for this run.")
    parser.add_argument("--output-root", default="outputs",
                        help="Directory where models/logs live. Default: outputs")
    parser.add_argument("--use-gpu", action="store_true",
                        help="Use GPU for evaluation.")
    return parser.parse_args()
|
||||
|
||||
|
||||
# Map run_name -> config path (relative to classifier/)
|
||||
# NOTE(review): keep this table in sync with configs/phase1 + configs/phase2 —
# runs without an entry here are silently skipped by main().
CONFIG_MAP = {
    # Phase 1
    "p1_resnet18_baseline": "configs/phase1/p1_resnet18_baseline.json",
    "p1_simplecnn_baseline": "configs/phase1/p1_simplecnn_baseline.json",
    # Phase 2a – shortcut / holdout
    "p2a_t1_original": "configs/phase2/p2a_t1_original.json",
    "p2a_t2_real_norm": "configs/phase2/p2a_t2_real_norm.json",
    "p2a_t3_holdout_text2img": "configs/phase2/p2a_t3_holdout_text2img.json",
    "p2a_t3_holdout_inpainting": "configs/phase2/p2a_t3_holdout_inpainting.json",
    "p2a_t3_holdout_insight": "configs/phase2/p2a_t3_holdout_insight.json",
    # Phase 2b – resolution
    "p2b_resnet18_224": "configs/phase2/p2b_resnet18_224.json",
    "p2b_simplecnn_224": "configs/phase2/p2b_simplecnn_224.json",
    # Phase 2c – face crop
    "p2c_resnet18_facecrop": "configs/phase2/p2c_resnet18_facecrop.json",
    "p2c_simplecnn_facecrop": "configs/phase2/p2c_simplecnn_facecrop.json",
    # Phase 2d – augmentation
    "p2d_resnet18_aug": "configs/phase2/p2d_resnet18_aug.json",
    "p2d_simplecnn_aug": "configs/phase2/p2d_simplecnn_aug.json",
    # Phase 2e – face crop + aug
    "p2e_resnet18_facecrop_aug": "configs/phase2/p2e_resnet18_facecrop_aug.json",
    "p2e_simplecnn_facecrop_aug": "configs/phase2/p2e_simplecnn_facecrop_aug.json",
}
|
||||
|
||||
|
||||
def main():
    """Re-evaluate each selected run's fold checkpoints and rewrite its log.

    For every run: rebuild the dataset and the deterministic CV splits, load
    each fold's _best.pt checkpoint, run predict_rows on the fold's test set,
    compute aggregate/per-source/pairwise metrics, and write the merged
    results (preserving any training history from the old log) back to
    outputs/logs/{run_name}.json.
    """
    args = parse_args()
    # Heavy imports deferred so argparse errors stay fast.
    import numpy as np
    import torch
    from src.data import DFFDataset, apply_subsample, build_transforms, get_splits
    from src.evaluation.evaluate import predict_rows
    from src.evaluation.metrics import calc_metrics, source_metrics, pair_metrics
    from src.models import get_model
    from src.utils import load_config
    from src.utils.cross_validation import aggregate_fold_metrics

    output_root = Path(args.output_root)
    logs_dir = output_root / "logs"
    models_dir = output_root / "models"

    # Determine which experiments to re-evaluate
    if args.run_names:
        run_names = args.run_names
    else:
        # Default: every existing log that has a known config mapping.
        run_names = sorted(
            p.stem for p in logs_dir.glob("*.json")
            if p.stem in CONFIG_MAP
        )

    device = "cuda" if args.use_gpu and torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    for run_name in run_names:
        config_rel = CONFIG_MAP.get(run_name)
        if config_rel is None:
            print(f"\nSkipping {run_name}: no config mapping")
            continue

        config_path = ROOT / config_rel
        if not config_path.exists():
            print(f"\nSkipping {run_name}: config not found ({config_rel})")
            continue

        # Check that at least one checkpoint exists
        checkpoints = sorted(models_dir.glob(f"{run_name}_fold*_best.pt"))
        if not checkpoints:
            print(f"\nSkipping {run_name}: no checkpoints found")
            continue

        print(f"\n{'='*60}")
        print(f"Re-evaluating: {run_name}")
        print(f" Config: {config_rel}")
        print(f" Checkpoints: {len(checkpoints)}")
        print(f"{'='*60}")

        # Load config
        cfg = load_config(config_path)
        # Re-seed everything so get_splits reproduces the training-time splits.
        seed = cfg.get("seed", 42)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        deterministic = cfg.get("deterministic", False)
        torch.backends.cudnn.deterministic = deterministic
        torch.backends.cudnn.benchmark = not deterministic

        data_dir = args.data_dir or cfg.get("data_dir", "data")
        # Config paths are relative to the project root, not classifier/
        if not Path(data_dir).is_absolute():
            data_dir = str(PROJECT_ROOT / data_dir)

        # Build dataset
        raw_ds = DFFDataset(data_dir, sources=cfg.get("dataset_sources"))

        sampled = apply_subsample(raw_ds, cfg)
        if sampled is not None:
            n_samples, total = sampled
            print(f" Subsampled to {n_samples}/{total} samples")

        # Build CV splits and transforms (deterministic – same as training)
        splits = get_splits(raw_ds, cfg)
        transform_builder = build_transforms(raw_ds, cfg, augment=cfg.get("augment"))

        n_folds = len(splits)
        fold_results = []

        for fold_idx in range(n_folds):
            train_idx, val_idx, test_idx = splits[fold_idx]

            checkpoint_path = models_dir / f"{run_name}_fold{fold_idx}_best.pt"
            if not checkpoint_path.exists():
                print(f" Fold {fold_idx}: checkpoint missing, skipping")
                continue

            # Rebuild model and load checkpoint
            model = get_model(cfg)
            model.load_state_dict(
                torch.load(checkpoint_path, map_location=device, weights_only=True)
            )
            model.to(device).eval()

            # Build test dataset for this fold
            if cfg.get("normalization") == "real_norm":
                # real_norm stats depend on this fold's train split.
                from src.preprocessing.pipeline import compute_real_stats
                norm_mean, norm_std = compute_real_stats(raw_ds, train_idx)
            else:
                norm_mean = norm_std = None

            test_dataset = transform_builder(
                test_idx, train=False,
                normalize_mean=norm_mean, normalize_std=norm_std,
            )

            records = predict_rows(
                model, test_dataset, raw_ds, test_idx,
                cfg["batch_size"], device, num_workers=4,
            )

            # Compute metrics
            test_metrics = calc_metrics(records)
            src_metrics = source_metrics(records)
            pairwise = pair_metrics(records)

            fold_result = {
                "fold": fold_idx,
                "train_size": len(train_idx),
                "val_size": len(val_idx),
                "test_size": len(test_idx),
                "test_metrics": test_metrics,
                "source_metrics": src_metrics,
                "pair_metrics": pairwise,
            }
            fold_results.append(fold_result)

            # NOTE(review): the '?' fallback raises ValueError under :.4f if a
            # metric key is missing — confirm calc_metrics always sets these keys.
            print(f" Fold {fold_idx}: auc={test_metrics.get('auc_roc', '?'):.4f} "
                  f"acc={test_metrics.get('accuracy', '?'):.4f} "
                  f"f1={test_metrics.get('f1', '?'):.4f}")
            for source, sm in sorted(src_metrics.items()):
                pa = sm.get("pairwise_auc")
                dr = sm.get("detection_rate")
                label = (f"pairwise_auc={pa:.4f}" if pa is not None
                         else f"detection_rate={dr:.4f}" if dr is not None else "")
                print(f" {source}: {label}")

        if not fold_results:
            print(f" No folds evaluated for {run_name}")
            continue

        # Aggregate across folds
        test_metrics_list = [f["test_metrics"] for f in fold_results]
        aggregated = aggregate_fold_metrics(test_metrics_list)

        # Aggregate per-source metrics
        all_sources = sorted({s for f in fold_results for s in f["source_metrics"]})
        aggregated_per_source = {}
        for source in all_sources:
            source_fold_metrics = []
            for f in fold_results:
                sm = f["source_metrics"].get(source)
                if sm:
                    # Only numeric entries are aggregatable; drop bookkeeping keys.
                    source_fold_metrics.append({
                        k: v for k, v in sm.items()
                        if isinstance(v, (int, float)) and k != "fold"
                    })
            if source_fold_metrics:
                aggregated_per_source[source] = aggregate_fold_metrics(source_fold_metrics)

        # Aggregate pairwise metrics
        all_pairs = sorted({p for f in fold_results for p in f["pair_metrics"]})
        aggregated_pairwise = {}
        for pair in all_pairs:
            pair_fold_metrics = []
            for f in fold_results:
                pm = f["pair_metrics"].get(pair)
                if pm:
                    pair_fold_metrics.append({
                        k: v for k, v in pm.items()
                        if isinstance(v, (int, float)) and k not in ("fold", "n")
                    })
            if pair_fold_metrics:
                aggregated_pairwise[pair] = aggregate_fold_metrics(pair_fold_metrics)

        # Load existing log to preserve training history
        log_path = logs_dir / f"{run_name}.json"
        if log_path.exists():
            with open(log_path) as f:
                existing = json.load(f)
            # Keep the training history from the original log
            for fr_new in fold_results:
                for fr_old in existing.get("fold_results", []):
                    if fr_old["fold"] == fr_new["fold"]:
                        fr_new["history"] = fr_old.get("history")
                        break
        else:
            existing = {}

        results = {
            "run_name": run_name,
            "n_folds": n_folds,
            "fold_results": fold_results,
            "aggregated_metrics": aggregated,
            "aggregated_per_source": aggregated_per_source,
            "aggregated_pairwise": aggregated_pairwise,
            # Prefer the config snapshot stored at training time, if any.
            "config": existing.get("config", cfg),
        }

        with open(log_path, "w") as f:
            json.dump(results, f, indent=2)
        print(f" Saved: {log_path}")

    print("\nDone.")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user