Clean state

This commit is contained in:
Johnny Fernandes
2026-04-30 01:25:39 +01:00
commit bb3dfb92d5
266 changed files with 37043 additions and 0 deletions
+104
View File
@@ -0,0 +1,104 @@
"""
Grad-CAM visualization for trained classifiers.
Generates heatmaps showing which image regions the model focused on,
overlaid on the original image. Targets the last Conv2d layer automatically.
Usage (from notebook or script):
from tools.gradcam import save_overlays
from src.evaluation.evaluate import predict_rows
records = predict_rows(model, test_dataset, raw_ds, test_idx, batch_size=32, device="cpu")
save_overlays(model, records, cfg, output_dir=Path("outputs/gradcam"), device="cpu")
Output: one PNG per selected sample, named 01_<stem>.png, 02_<stem>.png, ...
Selection: the top_k//2 most confident false positives (real predicted fake), followed by
the top_k//2 most confident false negatives (fake predicted real).
"""
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from src.preprocessing import get_transforms
# Locate the Grad-CAM target: the final Conv2d in the model's module order
def find_conv(model):
    """Return the last ``nn.Conv2d`` yielded by ``model.modules()``.

    Raises:
        ValueError: if the model contains no ``nn.Conv2d`` module.
    """
    last_conv = None
    # Walk forward and keep overwriting, so the final match wins.
    for layer in model.modules():
        if isinstance(layer, nn.Conv2d):
            last_conv = layer
    if last_conv is None:
        raise ValueError("Could not find a Conv2d layer for Grad-CAM.")
    return last_conv
# Build a normalized Grad-CAM heatmap for one image tensor (shape: 1xCxHxW)
def gradcam_map(model, image_tensor, device):
    """Compute a Grad-CAM heatmap for a single image.

    The target layer is the one returned by ``find_conv(model)``. Its output
    activations and the gradient of the model output with respect to them are
    captured with temporary hooks, which are always removed before returning
    (even if the forward or backward pass raises).

    Args:
        model: network containing at least one ``nn.Conv2d``. Its output for
            the given input must squeeze to a scalar, because ``backward()``
            is called on it without an explicit gradient argument.
        image_tensor: input batch holding one image (1xCxHxW).
        device: device the image is moved to before the forward pass.

    Returns:
        A 2-D ``numpy`` array with the target layer's spatial size, shifted
        and scaled so its values lie in [0, 1].

    Raises:
        ValueError: propagated from ``find_conv`` when no Conv2d exists.
        RuntimeError: if the hooks captured no activation or no gradient
            (e.g. nothing upstream of the target layer requires grad).
    """
    target = find_conv(model)
    activations = []
    gradients = []

    def on_fwd(_module, _inputs, output):
        activations.append(output.detach())

    def on_bwd(_module, _grad_input, grad_output):
        gradients.append(grad_output[0].detach())

    handles = [target.register_forward_hook(on_fwd)]
    try:
        handles.append(target.register_full_backward_hook(on_bwd))
        model.zero_grad(set_to_none=True)
        # Autograd is required for the backward pass; re-enable it in case the
        # caller is running under torch.no_grad().
        with torch.enable_grad():
            logits = model(image_tensor.to(device)).squeeze()
            logits.backward()
    finally:
        # Never leave hooks attached to the caller's model.
        for handle in handles:
            handle.remove()

    if not activations or not gradients:
        raise RuntimeError(
            "Grad-CAM hooks captured no activations/gradients for the target Conv2d layer."
        )

    # Drop the batch dimension: both become (channels, height, width).
    grads = gradients[0][0]
    acts = activations[0][0]
    # Channel weights are the spatially averaged gradients.
    weights = grads.mean(dim=(1, 2), keepdim=True)
    cam = torch.relu((weights * acts).sum(dim=0))
    cam = cam - cam.min()
    # The clamp guards against division by zero for an all-zero map.
    cam = cam / cam.max().clamp(min=1e-8)
    return cam.cpu().numpy()
# Save side-by-side input and Grad-CAM overlays for top-confidence errors
def save_overlays(model, records, cfg, output_dir, device, *, top_k=8):
    """Render Grad-CAM overlays for the most confident misclassifications.

    Records with ``label == 0`` and ``pred == 1`` are treated as false
    positives and ranked by descending ``prob_fake``; records with
    ``label == 1`` and ``pred == 0`` are treated as false negatives and ranked
    by ascending ``prob_fake``. Up to ``top_k // 2`` of each are rendered,
    false positives first, and written as ``NN_<stem>.png`` in ``output_dir``.

    Args:
        model: classifier passed through to ``gradcam_map``.
        records: iterable of mappings with the keys ``"path"``, ``"label"``,
            ``"pred"`` and ``"prob_fake"``.
        cfg: mapping providing ``"image_size"`` for ``get_transforms``.
        output_dir: destination directory (``str`` or ``Path``); created if
            missing.
        device: device passed through to ``gradcam_map``.
        top_k: total budget of samples; half goes to each error type.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    transform = get_transforms(train=False, image_size=cfg["image_size"])

    # Materialize once: records is scanned twice, which would silently drop
    # the false negatives if a one-shot iterator were passed in.
    records = list(records)
    # A negative slice bound would select "all but the last n" records.
    per_kind = max(top_k // 2, 0)

    false_pos = sorted(
        (r for r in records if r["label"] == 0 and r["pred"] == 1),
        key=lambda r: r["prob_fake"],
        reverse=True,
    )[:per_kind]
    false_neg = sorted(
        (r for r in records if r["label"] == 1 and r["pred"] == 0),
        key=lambda r: r["prob_fake"],
    )[:per_kind]
    selected = [*false_pos, *false_neg]
    total = len(selected)

    for idx, record in enumerate(selected, start=1):
        path = Path(record["path"])
        print(f"Grad-CAM: rendering {idx}/{total} for {path.name}")
        # convert() returns an independent, fully loaded copy, so the source
        # file can be closed right away.
        with Image.open(record["path"]) as source:
            image = source.convert("RGB")
        image_tensor = transform(image).unsqueeze(0)
        heatmap = gradcam_map(model, image_tensor, device)
        pixels = np.asarray(image)

        fig, axes = plt.subplots(1, 2, figsize=(8, 4))
        try:
            axes[0].imshow(pixels)
            axes[0].set_title("Input")
            axes[0].axis("off")
            axes[1].imshow(pixels)
            # extent stretches the heatmap over the full original image area.
            axes[1].imshow(heatmap, cmap="jet", alpha=0.4, extent=(0, image.width, image.height, 0))
            axes[1].set_title(
                f"{path.name}\ntrue={record['label']} pred={record['pred']} p={record['prob_fake']:.3f}"
            )
            axes[1].axis("off")
            fig.tight_layout()
            fig.savefig(output_dir / f"{idx:02d}_{path.stem}.png", dpi=160)
        finally:
            # Release the figure even if rendering or saving fails.
            plt.close(fig)