"""Grad-CAM visualization for trained classifiers.

Generates heatmaps showing which image regions the model focused on,
overlaid on the original image. Targets the last Conv2d layer automatically.

Usage (from notebook or script):

    from tools.gradcam import save_overlays
    from src.evaluation.evaluate import predict_rows

    records = predict_rows(model, test_dataset, raw_ds, test_idx,
                           batch_size=32, device="cpu")
    save_overlays(model, records, cfg, output_dir=Path("outputs/gradcam"),
                  device="cpu")

Output: one PNG per selected sample, named 01_<stem>.png, 02_<stem>.png, ...
top_k//2 false positives (real predicted fake) and top_k//2 false negatives.
"""

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from src.preprocessing import get_transforms


def find_conv(model: nn.Module) -> nn.Conv2d:
    """Return the last ``Conv2d`` module in *model* as the Grad-CAM target.

    "Last" means last in module registration order, which matches execution
    order for typical sequential CNN backbones.

    Raises:
        ValueError: if the model contains no ``Conv2d`` layer.
    """
    for module in reversed(list(model.modules())):
        if isinstance(module, nn.Conv2d):
            return module
    raise ValueError("Could not find a Conv2d layer for Grad-CAM.")


def gradcam_map(model: nn.Module, image_tensor: torch.Tensor, device) -> np.ndarray:
    """Build a normalized Grad-CAM heatmap for one image tensor.

    Args:
        model: classifier containing at least one ``Conv2d`` layer.
        image_tensor: input of shape ``1xCxHxW`` (single image batch).
        device: device on which the model lives; the input is moved there.

    Returns:
        A 2-D ``numpy`` array scaled to ``[0, 1]`` with the spatial size of
        the target conv layer's feature map (upsampling is left to the
        renderer, e.g. ``imshow`` with ``extent``).

    Raises:
        ValueError: if no ``Conv2d`` layer exists (via :func:`find_conv`).
        RuntimeError: if the target layer is never reached in the forward
            pass, so no activations/gradients were captured.
    """
    activations: list[torch.Tensor] = []
    gradients: list[torch.Tensor] = []
    target = find_conv(model)

    def on_fwd(_, __, output):
        activations.append(output.detach())

    def on_bwd(_, __, grad_output):
        gradients.append(grad_output[0].detach())

    h_fwd = target.register_forward_hook(on_fwd)
    h_bwd = target.register_full_backward_hook(on_bwd)
    # try/finally so the hooks never leak onto the layer, even if the
    # forward or backward pass raises (the original removed them only on
    # the happy path).
    try:
        model.zero_grad(set_to_none=True)
        logits = model(image_tensor.to(device)).squeeze()
        # backward() needs a scalar; a single-logit head squeezes to a
        # 0-d tensor already.  For multi-logit heads, backprop the max
        # (predicted-class) logit instead of crashing.
        if logits.dim() > 0:
            logits = logits.max()
        logits.backward()
    finally:
        h_fwd.remove()
        h_bwd.remove()

    if not activations or not gradients:
        raise RuntimeError(
            "Grad-CAM target layer was not used in the forward pass; "
            "no activations/gradients captured."
        )

    grads = gradients[0][0]   # CxHxW gradient of the loss w.r.t. the feature map
    acts = activations[0][0]  # CxHxW feature map
    # Classic Grad-CAM: channel weights = spatially averaged gradients,
    # then a ReLU over the weighted sum keeps only positive evidence.
    weights = grads.mean(dim=(1, 2), keepdim=True)
    cam = torch.relu((weights * acts).sum(dim=0))
    # Min-max normalize to [0, 1]; clamp guards an all-zero map.
    cam = cam - cam.min()
    cam = cam / cam.max().clamp(min=1e-8)
    return cam.cpu().numpy()


def save_overlays(model, records, cfg, output_dir, device, *, top_k: int = 8) -> None:
    """Save side-by-side input and Grad-CAM overlays for top-confidence errors.

    Selects the ``top_k // 2`` most confident false positives (real images
    predicted fake, highest ``prob_fake``) and the ``top_k // 2`` most
    confident false negatives (fake images predicted real, lowest
    ``prob_fake``), and writes one PNG per sample to *output_dir*.

    Args:
        model: trained classifier passed through to :func:`gradcam_map`.
        records: iterables of dicts with keys ``label``, ``pred``,
            ``prob_fake``, and ``path`` (as produced by ``predict_rows``).
        cfg: config mapping; only ``cfg["image_size"]`` is read.
        output_dir: destination directory (created if missing).
        device: device for model inference.
        top_k: total number of samples to render (half FP, half FN).
    """
    output_dir = Path(output_dir)  # accept str paths too
    output_dir.mkdir(parents=True, exist_ok=True)
    transform = get_transforms(train=False, image_size=cfg["image_size"])

    # Most confident false positives: real (label 0) predicted fake, by
    # descending fake probability.
    false_pos = sorted(
        (r for r in records if r["label"] == 0 and r["pred"] == 1),
        key=lambda r: r["prob_fake"],
        reverse=True,
    )[: top_k // 2]
    # Most confident false negatives: fake (label 1) predicted real, by
    # ascending fake probability (lowest = most confidently wrong).
    false_neg = sorted(
        (r for r in records if r["label"] == 1 and r["pred"] == 0),
        key=lambda r: r["prob_fake"],
    )[: top_k // 2]
    selected = [*false_pos, *false_neg]
    total = len(selected)

    for idx, record in enumerate(selected, start=1):
        print(f"Grad-CAM: rendering {idx}/{total} for {Path(record['path']).name}")
        image = Image.open(record["path"]).convert("RGB")
        image_tensor = transform(image).unsqueeze(0)
        heatmap = gradcam_map(model, image_tensor, device)

        fig, axes = plt.subplots(1, 2, figsize=(8, 4))
        axes[0].imshow(np.asarray(image))
        axes[0].set_title("Input")
        axes[0].axis("off")
        axes[1].imshow(np.asarray(image))
        # extent stretches the low-res CAM over the full original image;
        # origin is top-left, hence (0, w, h, 0).
        axes[1].imshow(
            heatmap, cmap="jet", alpha=0.4, extent=(0, image.width, image.height, 0)
        )
        axes[1].set_title(
            f"{Path(record['path']).name}\n"
            f"true={record['label']} pred={record['pred']} p={record['prob_fake']:.3f}"
        )
        axes[1].axis("off")
        fig.tight_layout()
        fig.savefig(output_dir / f"{idx:02d}_{Path(record['path']).stem}.png", dpi=160)
        plt.close(fig)