Files
DRL_PROJ/classifier/notebooks/02_preprocessing.ipynb
T
Johnny Fernandes bb3dfb92d5 Clean state
2026-04-30 01:25:39 +01:00

363 lines
12 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "pp-00",
"metadata": {},
"source": [
"# 02 — Preprocessing\n",
"\n",
"Inspect what images look like right before model input.\n",
"\n",
"Face cropping is an offline step — run `tools/precrop.py` once to produce `data_cropped/`, then point configs at that directory. The sections below show the standard pipeline on already-cropped or uncropped images. `facenet_pytorch` is only needed to visualize the offline cropper.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "pp-01",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import sys\n",
"from collections import Counter\n",
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.patches as patches\n",
"import numpy as np\n",
"from PIL import Image\n",
"\n",
"sys.path.insert(0, '..')\n",
"from src.data import DFFDataset, SOURCES\n",
"from src.preprocessing.pipeline import DFFImagePipeline\n",
"\n",
"DATA_DIR = Path('../../data')\n",
"SEED = 7\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"\n",
"full_ds = DFFDataset(DATA_DIR)\n",
"\n",
"print(f\"Dataset root: {DATA_DIR.resolve()}\")\n",
"print(f\"Total samples: {len(full_ds):,}\")\n",
"source_counts = Counter(path.parent.parent.name for path, _ in full_ds.samples)\n",
"print(\"Per-source counts:\")\n",
"for src in SOURCES:\n",
"    print(f\" {src:12s} {source_counts[src]:6,}\")\n",
"\n",
"def denorm(tensor):\n",
"    \"\"\"Invert ImageNet normalization: CHW tensor -> HWC float array clipped to [0, 1].\"\"\"\n",
"    mean = np.array([0.485, 0.456, 0.406])\n",
"    std = np.array([0.229, 0.224, 0.225])\n",
"    arr = tensor.permute(1, 2, 0).numpy()\n",
"    return np.clip(arr * std + mean, 0, 1)\n",
"\n",
"def pick_samples(n=4, sources=None):\n",
"    \"\"\"Return n random RGB PIL images, optionally restricted to the given sources.\"\"\"\n",
"    ds = DFFDataset(DATA_DIR, sources=sources) if sources else full_ds\n",
"    idxs = random.sample(range(len(ds)), n)\n",
"    return [Image.open(ds.samples[i][0]).convert('RGB') for i in idxs]\n",
"\n",
"# Runtime face-crop helper from tools (kept for notebook visualization only).\n",
"try:\n",
"    from facenet_pytorch import MTCNN\n",
"    from tools.precrop import FaceCropper\n",
"    FACE_CROP_AVAILABLE = True\n",
"    _detector = MTCNN(keep_all=False, select_largest=True, device='cpu', post_process=False)\n",
"    _cropper = FaceCropper(margin=0.6, size=224, device='cpu')\n",
"    print('facenet_pytorch available — crop helper enabled.')\n",
"except ImportError:\n",
"    FACE_CROP_AVAILABLE = False\n",
"    _cropper = None\n",
"    print('WARNING: facenet_pytorch not installed — crop sections will be skipped.')\n",
"    print(' Install with: pip install facenet-pytorch')\n",
"\n",
"pipe_eval = DFFImagePipeline(image_size=224, train=False)\n",
"pipe_aug = DFFImagePipeline(image_size=224, train=True)\n",
"\n",
"crop_note = 'offline face-crop preview -> ' if FACE_CROP_AVAILABLE else '(no face crop) -> '\n",
"print('Pipelines ready.')"
]
},
{
"cell_type": "markdown",
"id": "pp-02",
"metadata": {},
"source": [
"---\n",
"## 1. Crop preview\n",
"\n",
"Visualizes what `tools/precrop.py` does: MTCNN detects the largest face, crops a square with a 60% margin, and falls back to center-crop when no face is found.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "pp-03",
"metadata": {},
"outputs": [],
"source": [
"if not FACE_CROP_AVAILABLE:\n",
"    print('Skipped — facenet_pytorch not installed.')\n",
"else:\n",
"    src_images = {src: pick_samples(1, sources=[src])[0] for src in SOURCES}\n",
"\n",
"    fig, axes = plt.subplots(len(SOURCES), 3, figsize=(10, 14))\n",
"    fig.suptitle(\n",
"        'Face crop helper | col 1: original + detection box | col 2: cropped | col 3: cropped + eval pipeline',\n",
"        fontsize=10\n",
"    )\n",
"\n",
"    for row, (src, img) in enumerate(src_images.items()):\n",
"        label = 'real' if SOURCES[src] == 0 else 'fake'\n",
"\n",
"        # col 0: original with bounding box\n",
"        boxes, probs = _detector.detect(img)\n",
"        axes[row, 0].imshow(img)\n",
"        if boxes is not None and len(boxes) > 0:\n",
"            x1, y1, x2, y2 = boxes[0]\n",
"            rect = patches.Rectangle(\n",
"                (x1, y1), x2 - x1, y2 - y1,\n",
"                linewidth=2, edgecolor='lime', facecolor='none'\n",
"            )\n",
"            axes[row, 0].add_patch(rect)\n",
"            axes[row, 0].set_title(f'detected p={probs[0]:.2f}', fontsize=8, color='green')\n",
"        else:\n",
"            axes[row, 0].set_title('no face — centre crop fallback', fontsize=8, color='red')\n",
"        axes[row, 0].set_ylabel(f'{src}\\n({label})', fontsize=9)\n",
"\n",
"        # col 1: cropped result from tools.precrop.FaceCropper\n",
"        cropped = _cropper(img)\n",
"        axes[row, 1].imshow(cropped)\n",
"        axes[row, 1].set_title('cropped (224px)', fontsize=8)\n",
"\n",
"        # col 2: cropped image through eval pipeline\n",
"        axes[row, 2].imshow(denorm(pipe_eval(cropped)))\n",
"        axes[row, 2].set_title('crop + eval pipeline', fontsize=8)\n",
"\n",
"    # Hide ticks and frames but keep the row ylabels visible;\n",
"    # ax.axis('off') would hide the ylabels set above.\n",
"    for ax in axes.flat:\n",
"        ax.set_xticks([])\n",
"        ax.set_yticks([])\n",
"        for spine in ax.spines.values():\n",
"            spine.set_visible(False)\n",
"\n",
"    plt.tight_layout()\n",
"    plt.show()"
]
},
{
"cell_type": "markdown",
"id": "pp-04",
"metadata": {},
"source": [
"---\n",
"## 2. Eval path vs Train path\n",
"\n",
"Compare the deterministic eval transform and the stochastic train transform.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "pp-05",
"metadata": {},
"outputs": [],
"source": [
"src_images = {src: pick_samples(1, sources=[src])[0] for src in SOURCES}\n",
"\n",
"fig, axes = plt.subplots(len(SOURCES), 3, figsize=(10, 14))\n",
"fig.suptitle(\n",
"    f'original | {crop_note}eval (no aug) | {crop_note}train aug',\n",
"    fontsize=11\n",
")\n",
"\n",
"for row, (src, img) in enumerate(src_images.items()):\n",
"    label = 'real' if SOURCES[src] == 0 else 'fake'\n",
"    proc_img = _cropper(img) if FACE_CROP_AVAILABLE else img\n",
"\n",
"    axes[row, 0].imshow(img.resize((224, 224)))\n",
"    axes[row, 0].set_title('original', fontsize=8)\n",
"    axes[row, 0].set_ylabel(f'{src}\\n({label})', fontsize=9)\n",
"\n",
"    axes[row, 1].imshow(denorm(pipe_eval(proc_img)))\n",
"    axes[row, 1].set_title(f'{crop_note}eval (no aug)', fontsize=8)\n",
"\n",
"    axes[row, 2].imshow(denorm(pipe_aug(proc_img)))\n",
"    axes[row, 2].set_title(f'{crop_note}train aug', fontsize=8)\n",
"\n",
"# Hide ticks and frames but keep the row ylabels visible;\n",
"# ax.axis('off') would hide the ylabels set above.\n",
"for ax in axes.flat:\n",
"    ax.set_xticks([])\n",
"    ax.set_yticks([])\n",
"    for spine in ax.spines.values():\n",
"        spine.set_visible(False)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "pp-06",
"metadata": {},
"source": [
"---\n",
"## 3. Augmentation variety\n",
"\n",
"Use the same source image with multiple independent stochastic draws.\n",
"This shows the realistic variation the model sees during training."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "pp-07",
"metadata": {},
"outputs": [],
"source": [
"N_DRAWS = 8\n",
"imgs_to_show = pick_samples(2)\n",
"\n",
"fig, axes = plt.subplots(2, N_DRAWS + 1, figsize=(20, 5))\n",
"fig.suptitle(\n",
"    f'{N_DRAWS} independent draws — {crop_note}aug — each column is a different random sample',\n",
"    fontsize=11\n",
")\n",
"\n",
"for row, img in enumerate(imgs_to_show):\n",
"    axes[row, 0].imshow(img.resize((224, 224)))\n",
"    axes[row, 0].set_title('original', fontsize=8)\n",
"    axes[row, 0].set_ylabel(f'image {row + 1}', fontsize=9)\n",
"\n",
"    for col in range(N_DRAWS):\n",
"        axes[row, col + 1].imshow(denorm(pipe_aug(img)))\n",
"        axes[row, col + 1].set_title(f'#{col + 1}', fontsize=8)\n",
"\n",
"# Hide ticks and frames but keep the row ylabels visible;\n",
"# ax.axis('off') would hide the ylabels set above.\n",
"for ax in axes.flat:\n",
"    ax.set_xticks([])\n",
"    ax.set_yticks([])\n",
"    for spine in ax.spines.values():\n",
"        spine.set_visible(False)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "pp-09",
"metadata": {},
"source": [
"---\n",
"## 4. Full pipeline comparison\n",
"\n",
"All combinations in one grid. Crop columns appear only when `facenet_pytorch` is installed.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "pp-10",
"metadata": {},
"outputs": [],
"source": [
"samples = {src: pick_samples(1, sources=[src])[0] for src in SOURCES}\n",
"\n",
"cols = [\n",
"    ('original', False, False),\n",
"    ('no crop\\nno aug', False, False),\n",
"    ('no crop\\naug', False, True),\n",
"]\n",
"if FACE_CROP_AVAILABLE:\n",
"    cols += [\n",
"        ('crop\\nno aug', True, False),\n",
"        ('crop\\naug', True, True),\n",
"    ]\n",
"\n",
"n_cols = len(cols)\n",
"fig, axes = plt.subplots(len(SOURCES), n_cols, figsize=(n_cols * 2.8, 14))\n",
"fig.suptitle('Full pipeline comparison — pipeline order: (optional) face crop helper -> augmentation -> normalize', fontsize=11)\n",
"\n",
"for row, (src, img) in enumerate(samples.items()):\n",
"    label = 'real' if SOURCES[src] == 0 else 'fake'\n",
"    axes[row, 0].set_ylabel(f'{src}\\n({label})', fontsize=9)\n",
"\n",
"    for col, (title, use_crop, train_mode) in enumerate(cols):\n",
"        ax = axes[row, col]\n",
"        if col == 0:\n",
"            ax.imshow(img.resize((224, 224)))\n",
"        else:\n",
"            proc_img = _cropper(img) if (use_crop and FACE_CROP_AVAILABLE) else img\n",
"            # Reuse the shared pipelines instead of rebuilding one per subplot.\n",
"            pipe = pipe_aug if train_mode else pipe_eval\n",
"            ax.imshow(denorm(pipe(proc_img)))\n",
"        if row == 0:\n",
"            ax.set_title(title, fontsize=8)\n",
"        # Hide ticks and frames but keep the row ylabels visible;\n",
"        # ax.axis('off') would hide the ylabel set above.\n",
"        ax.set_xticks([])\n",
"        ax.set_yticks([])\n",
"        for spine in ax.spines.values():\n",
"            spine.set_visible(False)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "19187059",
"metadata": {},
"source": [
"---\n",
"## 5. Tensor sanity checks\n",
"\n",
"Validate preprocessing outputs: shape, finite values, normalized value ranges.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7e5697c4",
"metadata": {},
"outputs": [],
"source": [
"# numpy is already imported as np in the setup cell; no re-import needed.\n",
"check_imgs = pick_samples(n=12)\n",
"issues = []\n",
"eval_tensors = []\n",
"aug_tensors = []\n",
"\n",
"for i, img in enumerate(check_imgs):\n",
"    t_eval = pipe_eval(img)\n",
"    t_aug = pipe_aug(img)\n",
"    eval_tensors.append(t_eval)\n",
"    aug_tensors.append(t_aug)\n",
"\n",
"    for tag, t in [(\"eval\", t_eval), (\"aug\", t_aug)]:\n",
"        if tuple(t.shape) != (3, 224, 224):\n",
"            issues.append(f\"sample {i} ({tag}) shape={tuple(t.shape)}\")\n",
"        if not np.isfinite(t.numpy()).all():\n",
"            issues.append(f\"sample {i} ({tag}) has non-finite values\")\n",
"\n",
"print(f\"Checked {len(check_imgs)} images through eval+aug pipelines.\")\n",
"if issues:\n",
"    print(\"Issues found:\")\n",
"    for msg in issues[:10]:\n",
"        print(f\" - {msg}\")\n",
"else:\n",
"    print(\"No shape/finite-value issues found.\")\n",
"\n",
"# Reuse the tensors gathered above instead of running both pipelines a second time.\n",
"stack_eval = np.stack([t.numpy() for t in eval_tensors])\n",
"stack_aug = np.stack([t.numpy() for t in aug_tensors])\n",
"\n",
"print(\"\\nValue summary (normalized tensors):\")\n",
"print(f\" eval: min={stack_eval.min():.3f} max={stack_eval.max():.3f} mean={stack_eval.mean():.3f} std={stack_eval.std():.3f}\")\n",
"print(f\" aug : min={stack_aug.min():.3f} max={stack_aug.max():.3f} mean={stack_aug.mean():.3f} std={stack_aug.std():.3f}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "drl",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}