Clean state
This commit is contained in:
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
Grad-CAM visualization for trained classifiers.
|
||||
|
||||
Generates heatmaps showing which image regions the model focused on,
|
||||
overlaid on the original image. Targets the last Conv2d layer automatically.
|
||||
|
||||
Usage (from notebook or script):
|
||||
from tools.gradcam import save_overlays
|
||||
from src.evaluation.evaluate import predict_rows
|
||||
|
||||
records = predict_rows(model, test_dataset, raw_ds, test_idx, batch_size=32, device="cpu")
|
||||
save_overlays(model, records, cfg, output_dir=Path("outputs/gradcam"), device="cpu")
|
||||
|
||||
Output: one PNG per selected sample, named 01_<stem>.png, 02_<stem>.png, ...
|
||||
top_k//2 false positives (real predicted fake) and top_k//2 false negatives.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
|
||||
from src.preprocessing import get_transforms
|
||||
|
||||
|
||||
# Pick the last convolution layer as the Grad-CAM target
|
||||
def find_conv(model):
|
||||
for module in reversed(list(model.modules())):
|
||||
if isinstance(module, nn.Conv2d):
|
||||
return module
|
||||
raise ValueError("Could not find a Conv2d layer for Grad-CAM.")
|
||||
|
||||
|
||||
# Build a normalized Grad-CAM heatmap for one image tensor (shape: 1xCxHxW)
|
||||
def gradcam_map(model, image_tensor, device):
|
||||
activations = []
|
||||
gradients = []
|
||||
target = find_conv(model)
|
||||
|
||||
def on_fwd(_, __, output):
|
||||
activations.append(output.detach())
|
||||
|
||||
def on_bwd(_, __, grad_output):
|
||||
gradients.append(grad_output[0].detach())
|
||||
|
||||
h_fwd = target.register_forward_hook(on_fwd)
|
||||
h_bwd = target.register_full_backward_hook(on_bwd)
|
||||
|
||||
model.zero_grad(set_to_none=True)
|
||||
logits = model(image_tensor.to(device)).squeeze()
|
||||
logits.backward()
|
||||
|
||||
h_fwd.remove()
|
||||
h_bwd.remove()
|
||||
|
||||
grads = gradients[0][0]
|
||||
acts = activations[0][0]
|
||||
weights = grads.mean(dim=(1, 2), keepdim=True)
|
||||
cam = torch.relu((weights * acts).sum(dim=0))
|
||||
cam = cam - cam.min()
|
||||
cam = cam / cam.max().clamp(min=1e-8)
|
||||
return cam.cpu().numpy()
|
||||
|
||||
|
||||
# Save side-by-side input and Grad-CAM overlays for top-confidence errors
|
||||
def save_overlays(model, records, cfg, output_dir, device, *, top_k=8):
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
transform = get_transforms(train=False, image_size=cfg["image_size"])
|
||||
|
||||
false_pos = sorted(
|
||||
(r for r in records if r["label"] == 0 and r["pred"] == 1),
|
||||
key=lambda r: r["prob_fake"],
|
||||
reverse=True,
|
||||
)[: top_k // 2]
|
||||
false_neg = sorted(
|
||||
(r for r in records if r["label"] == 1 and r["pred"] == 0),
|
||||
key=lambda r: r["prob_fake"],
|
||||
)[: top_k // 2]
|
||||
selected = [*false_pos, *false_neg]
|
||||
|
||||
total = len(selected)
|
||||
for idx, record in enumerate(selected, start=1):
|
||||
print(f"Grad-CAM: rendering {idx}/{total} for {Path(record['path']).name}")
|
||||
image = Image.open(record["path"]).convert("RGB")
|
||||
image_tensor = transform(image).unsqueeze(0)
|
||||
heatmap = gradcam_map(model, image_tensor, device)
|
||||
|
||||
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
|
||||
axes[0].imshow(np.asarray(image))
|
||||
axes[0].set_title("Input")
|
||||
axes[0].axis("off")
|
||||
|
||||
axes[1].imshow(np.asarray(image))
|
||||
axes[1].imshow(heatmap, cmap="jet", alpha=0.4, extent=(0, image.width, image.height, 0))
|
||||
axes[1].set_title(
|
||||
f"{Path(record['path']).name}\ntrue={record['label']} pred={record['pred']} p={record['prob_fake']:.3f}"
|
||||
)
|
||||
axes[1].axis("off")
|
||||
|
||||
fig.tight_layout()
|
||||
fig.savefig(output_dir / f"{idx:02d}_{Path(record['path']).stem}.png", dpi=160)
|
||||
plt.close(fig)
|
||||
Reference in New Issue
Block a user