105 lines
3.5 KiB
Python
105 lines
3.5 KiB
Python
"""
|
|
Grad-CAM visualization for trained classifiers.
|
|
|
|
Generates heatmaps showing which image regions the model focused on,
|
|
overlaid on the original image. Targets the last Conv2d layer automatically.
|
|
|
|
Usage (from notebook or script):
|
|
from tools.gradcam import save_overlays
|
|
from src.evaluation.evaluate import predict_rows
|
|
|
|
records = predict_rows(model, test_dataset, raw_ds, test_idx, batch_size=32, device="cpu")
|
|
save_overlays(model, records, cfg, output_dir=Path("outputs/gradcam"), device="cpu")
|
|
|
|
Output: one PNG per selected sample, named 01_<stem>.png, 02_<stem>.png, ...
|
|
Selected samples: the top_k//2 most confident false positives (real predicted fake) and the top_k//2 most confident false negatives (fake predicted real).
|
|
"""
|
|
from pathlib import Path
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from PIL import Image
|
|
|
|
from src.preprocessing import get_transforms
|
|
|
|
|
|
# Pick the last convolution layer as the Grad-CAM target
|
|
def find_conv(model):
    """Return the last ``nn.Conv2d`` found when traversing ``model.modules()``.

    Raises:
        ValueError: If the model contains no ``nn.Conv2d`` module.
    """
    # Collect every convolution in traversal order; the final entry is the
    # Grad-CAM target.
    conv_layers = [layer for layer in model.modules() if isinstance(layer, nn.Conv2d)]
    if not conv_layers:
        raise ValueError("Could not find a Conv2d layer for Grad-CAM.")
    return conv_layers[-1]
|
|
|
|
|
|
# Build a normalized Grad-CAM heatmap for one image tensor (shape: 1xCxHxW)
|
|
def gradcam_map(model, image_tensor, device):
    """Compute a Grad-CAM heatmap scaled to [0, 1] for a single image.

    Temporary hooks on the layer returned by ``find_conv`` capture that
    layer's output and the gradient of the model output with respect to it.
    Each activation channel is weighted by its spatially averaged gradient,
    the weighted channels are summed, passed through ReLU and min-max scaled.

    Args:
        model: Module to explain. ``model(image_tensor).squeeze()`` must be a
            scalar, because ``backward()`` is called on it without an explicit
            gradient argument.
        image_tensor: Input batch holding one image (1xCxHxW).
        device: Device the input is moved to before the forward pass.

    Returns:
        2-D ``numpy`` array with the spatial size of the target layer's
        output and values in [0, 1] (all zeros when the map is flat).

    Raises:
        ValueError: If the model has no ``nn.Conv2d`` layer.

    Note:
        Parameter gradients are cleared before the pass and left populated by
        the backward pass afterwards.
    """
    activations = []
    gradients = []
    target = find_conv(model)

    def on_fwd(_, __, output):
        activations.append(output.detach())

    def on_bwd(_, __, grad_output):
        gradients.append(grad_output[0].detach())

    handles = [
        target.register_forward_hook(on_fwd),
        target.register_full_backward_hook(on_bwd),
    ]
    try:
        # Grad-CAM needs a graph even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            model.zero_grad(set_to_none=True)
            logits = model(image_tensor.to(device)).squeeze()
            logits.backward()
    finally:
        # Always detach the hooks so a failed pass does not leave them
        # registered on the model.
        for handle in handles:
            handle.remove()

    # Drop the batch dimension: both tensors become CxHxW.
    grads = gradients[0][0]
    acts = activations[0][0]

    # Channel importance = gradient averaged over the spatial dimensions.
    weights = grads.mean(dim=(1, 2), keepdim=True)
    cam = torch.relu((weights * acts).sum(dim=0))

    # Min-max scale; the clamp avoids dividing by zero for a flat map.
    cam = cam - cam.min()
    cam = cam / cam.max().clamp(min=1e-8)
    return cam.cpu().numpy()
|
|
|
|
|
|
# Save side-by-side input and Grad-CAM overlays for top-confidence errors
|
|
def save_overlays(model, records, cfg, output_dir, device, *, top_k=8):
    """Save input / Grad-CAM overlay figures for the most confident errors.

    Picks up to ``top_k // 2`` false positives (label 0 predicted 1) with the
    highest ``prob_fake`` and up to ``top_k // 2`` false negatives (label 1
    predicted 0) with the lowest ``prob_fake``. For each pick a two-panel
    figure (input image, image with heatmap overlay) is written to
    ``output_dir`` as ``<NN>_<stem>.png``, numbered from 01 with false
    positives first.

    Args:
        model: Classifier forwarded to ``gradcam_map``.
        records: Iterable of dicts with the keys ``"path"``, ``"label"``,
            ``"pred"`` and ``"prob_fake"``. It is consumed in a single pass.
        cfg: Mapping providing ``"image_size"`` for the evaluation transform.
        output_dir: Destination directory (``Path`` or ``str``); created if
            it does not exist.
        device: Device forwarded to ``gradcam_map``.
        top_k: Upper bound on the number of figures, split evenly between the
            two error types.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    transform = get_transforms(train=False, image_size=cfg["image_size"])

    # Partition in one pass so one-shot iterables (generators) also work.
    false_pos = []
    false_neg = []
    for rec in records:
        if rec["label"] == 0 and rec["pred"] == 1:
            false_pos.append(rec)
        elif rec["label"] == 1 and rec["pred"] == 0:
            false_neg.append(rec)

    per_kind = top_k // 2
    # Most confident mistakes first: highest prob_fake for false positives,
    # lowest prob_fake for false negatives.
    false_pos = sorted(false_pos, key=lambda r: r["prob_fake"], reverse=True)[:per_kind]
    false_neg = sorted(false_neg, key=lambda r: r["prob_fake"])[:per_kind]
    selected = [*false_pos, *false_neg]

    total = len(selected)
    for idx, record in enumerate(selected, start=1):
        src = Path(record["path"])
        print(f"Grad-CAM: rendering {idx}/{total} for {src.name}")

        # convert() returns an independent image, so the file handle can be
        # released immediately.
        with Image.open(record["path"]) as raw:
            image = raw.convert("RGB")
        image_tensor = transform(image).unsqueeze(0)
        heatmap = gradcam_map(model, image_tensor, device)
        pixels = np.asarray(image)

        fig, axes = plt.subplots(1, 2, figsize=(8, 4))
        try:
            axes[0].imshow(pixels)
            axes[0].set_title("Input")
            axes[0].axis("off")

            axes[1].imshow(pixels)
            # extent stretches the low-resolution heatmap over the full image.
            axes[1].imshow(
                heatmap,
                cmap="jet",
                alpha=0.4,
                extent=(0, image.width, image.height, 0),
            )
            axes[1].set_title(
                f"{src.name}\ntrue={record['label']} pred={record['pred']} p={record['prob_fake']:.3f}"
            )
            axes[1].axis("off")

            fig.tight_layout()
            fig.savefig(output_dir / f"{idx:02d}_{src.stem}.png", dpi=160)
        finally:
            # Close the figure even when plotting or saving fails.
            plt.close(fig)
|