Files
DRL_PROJ/classifier/tools/gradcam.py
T
Johnny Fernandes bb3dfb92d5 Clean state
2026-04-30 01:25:39 +01:00

105 lines
3.5 KiB
Python

"""
Grad-CAM visualization for trained classifiers.
Generates heatmaps showing which image regions the model focused on,
overlaid on the original image. Targets the last Conv2d layer automatically.
Usage (from notebook or script):
from tools.gradcam import save_overlays
from src.evaluation.evaluate import predict_rows
records = predict_rows(model, test_dataset, raw_ds, test_idx, batch_size=32, device="cpu")
save_overlays(model, records, cfg, output_dir=Path("outputs/gradcam"), device="cpu")
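Output: one PNG per selected sample, named 01_<stem>.png, 02_<stem>.png, ...
The selected samples are up to top_k//2 of the most confident false positives
(real predicted fake) and up to top_k//2 of the most confident false negatives
(fake predicted real).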
"""
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from src.preprocessing import get_transforms
def find_conv(model):
    """Return the ``nn.Conv2d`` module used as the Grad-CAM target layer.

    The target is the last ``nn.Conv2d`` yielded by ``model.modules()``.
    That is module registration order, which is not guaranteed to match
    the order in which layers execute in ``forward``.

    Raises:
        ValueError: if ``model`` contains no ``nn.Conv2d`` at all.
    """
    conv_layers = [layer for layer in model.modules() if isinstance(layer, nn.Conv2d)]
    if not conv_layers:
        raise ValueError("Could not find a Conv2d layer for Grad-CAM.")
    return conv_layers[-1]
# Build a normalized Grad-CAM heatmap for one image tensor (shape: 1xCxHxW)
def gradcam_map(model, image_tensor, device):
activations = []
gradients = []
target = find_conv(model)
def on_fwd(_, __, output):
activations.append(output.detach())
def on_bwd(_, __, grad_output):
gradients.append(grad_output[0].detach())
h_fwd = target.register_forward_hook(on_fwd)
h_bwd = target.register_full_backward_hook(on_bwd)
model.zero_grad(set_to_none=True)
logits = model(image_tensor.to(device)).squeeze()
logits.backward()
h_fwd.remove()
h_bwd.remove()
grads = gradients[0][0]
acts = activations[0][0]
weights = grads.mean(dim=(1, 2), keepdim=True)
cam = torch.relu((weights * acts).sum(dim=0))
cam = cam - cam.min()
cam = cam / cam.max().clamp(min=1e-8)
return cam.cpu().numpy()
def _select_top_errors(records, top_k):
    """Return the most confident misclassifications, false positives first.

    Up to ``top_k // 2`` false positives (label 0, pred 1), sorted by
    descending ``prob_fake``. These are followed by up to ``top_k // 2``
    false negatives (label 1, pred 0), sorted by ascending ``prob_fake``.
    """
    per_kind = top_k // 2
    false_pos = sorted(
        (r for r in records if r["label"] == 0 and r["pred"] == 1),
        key=lambda r: r["prob_fake"],
        reverse=True,
    )[:per_kind]
    false_neg = sorted(
        (r for r in records if r["label"] == 1 and r["pred"] == 0),
        key=lambda r: r["prob_fake"],
    )[:per_kind]
    return [*false_pos, *false_neg]


def save_overlays(model, records, cfg, output_dir, device, *, top_k=8):
    """Save side-by-side input / Grad-CAM overlay PNGs for confident errors.

    Args:
        model: classifier passed to ``gradcam_map``.
        records: dicts with ``"path"``, ``"label"``, ``"pred"`` and
            ``"prob_fake"`` keys. Labels use 0 = real and 1 = fake.
        cfg: config mapping. Only ``cfg["image_size"]`` is read here, for the
            evaluation transform.
        output_dir: destination directory (``Path`` or ``str``). It is created
            if missing.
        device: device forwarded to ``gradcam_map``.
        top_k: total budget of samples. ``top_k // 2`` false positives plus
            ``top_k // 2`` false negatives are rendered at most.

    Files are written as ``NN_<source stem>.png`` in selection order.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    transform = get_transforms(train=False, image_size=cfg["image_size"])
    selected = _select_top_errors(records, top_k)
    total = len(selected)
    for idx, record in enumerate(selected, start=1):
        src_path = Path(record["path"])
        print(f"Grad-CAM: rendering {idx}/{total} for {src_path.name}")
        # The context manager closes the file handle. convert() returns an
        # independent, fully loaded image, so it stays usable after close.
        with Image.open(src_path) as raw:
            image = raw.convert("RGB")
        image_tensor = transform(image).unsqueeze(0)
        heatmap = gradcam_map(model, image_tensor, device)
        pixels = np.asarray(image)

        fig, axes = plt.subplots(1, 2, figsize=(8, 4))
        try:
            axes[0].imshow(pixels)
            axes[0].set_title("Input")
            axes[0].axis("off")
            axes[1].imshow(pixels)
            # Stretch the low-resolution CAM over the full original image.
            axes[1].imshow(heatmap, cmap="jet", alpha=0.4, extent=(0, image.width, image.height, 0))
            axes[1].set_title(
                f"{src_path.name}\ntrue={record['label']} pred={record['pred']} p={record['prob_fake']:.3f}"
            )
            axes[1].axis("off")
            fig.tight_layout()
            fig.savefig(output_dir / f"{idx:02d}_{src_path.stem}.png", dpi=160)
        finally:
            # Release the figure even if rendering/saving fails.
            plt.close(fig)