DRL_PROJ/classifier/notebooks/01_eda.ipynb (commit bb3dfb92d5 "Clean state" by Johnny Fernandes, 2026-04-30 01:25:39 +01:00)

{
"cells": [
{
"cell_type": "markdown",
"id": "eda-00",
"metadata": {},
"source": [
"# 01 — EDA\n",
"\n",
"Explore DeepFakeFace (DFF) data quality before training: dataset composition, per-source distribution, image properties, and identity leakage across CV splits.\n",
"\n",
"**Sections:**\n",
"1. Dataset composition and label balance\n",
"2. Visual sanity-check samples\n",
"3. Image dimension profile\n",
"4. Per-source color statistics\n",
"5. CV split and leakage sanity check\n",
"6. Observations template"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eda-01",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.insert(0, '..')\n",
"\n",
"import random\n",
"from collections import Counter\n",
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from PIL import Image\n",
"\n",
"from src.data import DFFDataset, SOURCES, get_splits\n",
"\n",
"DATA_DIR = Path('../../data')\n",
"FIG_DIR = Path('../outputs/figures')\n",
"FIG_DIR.mkdir(parents=True, exist_ok=True)\n",
"\n",
"SEED = 42\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)"
]
},
{
"cell_type": "markdown",
"id": "eda-02",
"metadata": {},
"source": [
"## 1. Dataset composition and label balance"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eda-03",
"metadata": {},
"outputs": [],
"source": [
"full_ds = DFFDataset(DATA_DIR)\n",
"label_counts = full_ds.label_counts()\n",
"\n",
"print(f\"Total images : {len(full_ds):,}\")\n",
"print(f\" Real (label=0) : {label_counts[0]:,}\")\n",
"print(f\" Fake (label=1) : {label_counts[1]:,}\")\n",
"print(f\" Fake:real ratio : {label_counts[1] / label_counts[0]:.2f}x\\n\")\n",
"\n",
"source_info = []\n",
"for source, label in SOURCES.items():\n",
" ds = DFFDataset(DATA_DIR, sources=[source])\n",
" source_info.append((source, len(ds), label))\n",
" tag = 'real' if label == 0 else 'fake'\n",
" print(f\" {source:12s} n={len(ds):6,} label={label} ({tag})\")\n",
"\n",
"# Identity-level sanity check: each basename should appear in every source.\n",
"basename_counts = Counter(path.name for path, _ in full_ds.samples)\n",
"presence_hist = Counter(basename_counts.values())\n",
"\n",
"print(\"\\nIdentity (basename) presence across sources:\")\n",
"for n_sources, count in sorted(presence_hist.items()):\n",
" print(f\" present in {n_sources} source(s): {count:,} identities\")\n",
"\n",
"incomplete = sum(v for k, v in presence_hist.items() if k < len(SOURCES))\n",
"print(f\" complete in all {len(SOURCES)} sources: {presence_hist.get(len(SOURCES), 0):,}\")\n",
"print(f\" incomplete identities : {incomplete:,}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eda-04",
"metadata": {},
"outputs": [],
"source": [
"fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n",
"\n",
"# Overall class balance\n",
"class_names = ['Real (wiki)', 'Fake (all 3)']\n",
"class_counts = [label_counts[0], label_counts[1]]\n",
"bars = ax1.bar(class_names, class_counts, color=['#2196F3', '#F44336'], width=0.5)\n",
"ax1.set_title('Overall Class Balance', fontsize=13)\n",
"ax1.set_ylabel('Images')\n",
"ax1.set_ylim(0, max(class_counts) * 1.15)\n",
"for bar, v in zip(bars, class_counts):\n",
" ax1.text(bar.get_x() + bar.get_width() / 2, v + 300,\n",
" f'{v:,}', ha='center', fontsize=11)\n",
"\n",
"# Per-source breakdown\n",
"src_names = [s for s, _, _ in source_info]\n",
"src_counts = [n for _, n, _ in source_info]\n",
"colors = ['#2196F3', '#FF9800', '#9C27B0', '#4CAF50']\n",
"bars2 = ax2.bar(src_names, src_counts, color=colors, width=0.5)\n",
"ax2.set_title('Images per Source', fontsize=13)\n",
"ax2.set_ylabel('Images')\n",
"ax2.set_ylim(0, max(src_counts) * 1.15)\n",
"for bar, v in zip(bars2, src_counts):\n",
" ax2.text(bar.get_x() + bar.get_width() / 2, v + 300,\n",
" f'{v:,}', ha='center', fontsize=11)\n",
"\n",
"fig.tight_layout()\n",
"fig.savefig(FIG_DIR / 'class_balance.png', dpi=120, bbox_inches='tight')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "eda-05",
"metadata": {},
"source": [
"## 2. Visual sanity-check samples"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eda-06",
"metadata": {},
"outputs": [],
"source": [
"N_COLS = 6\n",
"fig, axes = plt.subplots(len(SOURCES), N_COLS, figsize=(18, 12), squeeze=False)\n",
"fig.suptitle(f'Sample images ({N_COLS} per source)', fontsize=14)\n",
"\n",
"for row, (source, label) in enumerate(SOURCES.items()):\n",
"    ds_src = DFFDataset(DATA_DIR, sources=[source])\n",
"    indices = random.sample(range(len(ds_src)), min(N_COLS, len(ds_src)))\n",
"    for col, idx in enumerate(indices):\n",
"        path, _ = ds_src.samples[idx]\n",
"        with Image.open(path) as im:\n",
"            img = im.convert('RGB').resize((128, 128))\n",
"        axes[row, col].imshow(img)\n",
"    # Hide ticks and spines instead of calling axis('off'): axis('off') also hides the row label set below.\n",
"    for ax in axes[row]:\n",
"        ax.set_xticks([])\n",
"        ax.set_yticks([])\n",
"        for spine in ax.spines.values():\n",
"            spine.set_visible(False)\n",
"    tag = 'real' if label == 0 else 'fake'\n",
"    axes[row, 0].set_ylabel(f'{source}\\n({tag})', fontsize=10)\n",
"\n",
"fig.tight_layout()\n",
"fig.savefig(FIG_DIR / 'sample_images.png', dpi=100, bbox_inches='tight')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "eda-07",
"metadata": {},
"source": [
"## 3. Image dimension profile"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eda-08",
"metadata": {},
"outputs": [],
"source": [
"sample_paths = [p for p, _ in random.sample(full_ds.samples, min(2000, len(full_ds)))]\n",
"sizes = Counter(Image.open(p).size for p in sample_paths)\n",
"\n",
"print('Most common image dimensions (W x H):')\n",
"for (w, h), count in sizes.most_common(10):\n",
" pct = count / len(sample_paths)\n",
" print(f' {w:4d} x {h:4d} — {count:4d} samples ({pct:.1%})')\n",
"\n",
"widths = [w for (w, _) in sizes.elements()]\n",
"heights = [h for (_, h) in sizes.elements()]\n",
"square = sum(1 for w, h in zip(widths, heights) if w == h)\n",
"print(f'\\nWidth range: {min(widths)} to {max(widths)}, mean={np.mean(widths):.0f}')\n",
"print(f'Height range: {min(heights)} to {max(heights)}, mean={np.mean(heights):.0f}')\n",
"print(f'Square images: {square}/{len(widths)} ({square / len(widths):.1%})')"
]
},
{
"cell_type": "markdown",
"id": "eda-11",
"metadata": {},
"source": [
"## 4. Per-source color statistics"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eda-12",
"metadata": {},
"outputs": [],
"source": [
"N_SAMPLES = 300\n",
"print(f'Computing per-source color statistics ({N_SAMPLES} images per source)...')\n",
"CH_NAMES = ['R', 'G', 'B']\n",
"\n",
"source_means, source_stds = {}, {}\n",
"for source in SOURCES:\n",
" ds_src = DFFDataset(DATA_DIR, sources=[source])\n",
" idxs = random.sample(range(len(ds_src)), min(N_SAMPLES, len(ds_src)))\n",
" arrays = [\n",
" np.array(\n",
" Image.open(ds_src.samples[i][0]).convert('RGB').resize((64, 64)),\n",
" dtype=np.float32\n",
" ) / 255.0\n",
" for i in idxs\n",
" ]\n",
" stack = np.stack(arrays) # (N, 64, 64, 3)\n",
" source_means[source] = stack.mean(axis=(0, 1, 2)) # per channel\n",
" source_stds[source] = stack.std(axis=(0, 1, 2))\n",
" print(f' {source}: mean={source_means[source].round(3)} std={source_stds[source].round(3)}')\n",
"\n",
"src_keys = list(SOURCES.keys())\n",
"x = np.arange(len(src_keys))\n",
"bar_w = 0.22\n",
"ch_colors = ['#F44336', '#4CAF50', '#2196F3']\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(14, 4))\n",
"for ci, ch in enumerate(CH_NAMES):\n",
" offset = (ci - 1) * bar_w\n",
" axes[0].bar(x + offset, [source_means[s][ci] for s in src_keys],\n",
" bar_w, label=ch, color=ch_colors[ci], alpha=0.85)\n",
" axes[1].bar(x + offset, [source_stds[s][ci] for s in src_keys],\n",
" bar_w, label=ch, color=ch_colors[ci], alpha=0.85)\n",
"\n",
"for ax, title, ylabel in zip(\n",
" axes,\n",
" ['Mean pixel intensity per source', 'Pixel std dev per source'],\n",
" ['Mean (0-1)', 'Std dev (0-1)'],\n",
"):\n",
" ax.set_xticks(x)\n",
" ax.set_xticklabels(src_keys)\n",
" ax.set_title(title, fontsize=12)\n",
" ax.set_ylabel(ylabel)\n",
" ax.legend(title='Channel')\n",
" ax.grid(axis='y', alpha=0.3)\n",
"\n",
"fig.tight_layout()\n",
"fig.savefig(FIG_DIR / 'color_stats.png', dpi=120, bbox_inches='tight')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "1c7d4660",
"metadata": {},
"source": [
"## 5. CV split and leakage sanity check"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "89513a74",
"metadata": {},
"outputs": [],
"source": [
"cfg = {\n",
" \"cv_folds\": 5,\n",
" \"seed\": SEED,\n",
" \"image_size\": 224,\n",
" \"train_sources\": None,\n",
" \"eval_sources\": None,\n",
"}\n",
"\n",
"splits = get_splits(full_ds, cfg)\n",
"print(f\"Generated {len(splits)} CV folds\")\n",
"\n",
"for fold_i, (train_idx, val_idx, test_idx) in enumerate(splits):\n",
" train_ids = {full_ds.samples[i][0].name for i in train_idx}\n",
" val_ids = {full_ds.samples[i][0].name for i in val_idx}\n",
" test_ids = {full_ds.samples[i][0].name for i in test_idx}\n",
"\n",
" overlap = (train_ids & val_ids) | (train_ids & test_ids) | (val_ids & test_ids)\n",
" print(\n",
" f\"Fold {fold_i}: train={len(train_idx):6d} val={len(val_idx):6d} test={len(test_idx):6d} \"\n",
" f\"identity_overlap={len(overlap)}\"\n",
" )\n",
"\n",
"print(\"\\nExpected: identity_overlap should be 0 for every fold.\")"
]
},
{
"cell_type": "markdown",
"id": "eda-13",
"metadata": {},
"source": [
"## 6. Observations template\n",
"\n",
"Fill in after running the notebook:\n",
"\n",
"**Class balance**\n",
"- Record the fake:real ratio and decide whether a weighted sampler or loss reweighting is needed.\n",
"\n",
"**Identity completeness**\n",
"- Note whether most basenames appear in all sources, or whether some identities are missing from one or more sources.\n",
"\n",
"**Dimensions**\n",
"- Record dominant dimensions and whether extreme outliers appear.\n",
"\n",
"**Color stats**\n",
"- Note any clear per-source shifts in channel mean or std.\n",
"\n",
"**Split sanity**\n",
"- Confirm every fold reports `identity_overlap=0`.\n",
"\n",
"**Action items before training**\n",
"- List any cleanup/filtering decisions (if required)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "drl",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
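
Section 5 checks leakage on folds produced by `get_splits` from `src.data`, whose implementation is not part of this notebook. As an illustration only, here is a minimal sketch of one way to build identity-grouped folds so that `identity_overlap=0` holds by construction. It assumes, as the notebook's check does, that a file basename identifies the same person across sources. `identity_kfold` and `identity_overlap` are hypothetical helpers, not the project's API. The fold rotation (fold j as test, fold j+1 as validation, the rest as training) is also an assumed scheme.

```python
import random
from collections import defaultdict
from pathlib import Path


def identity_kfold(paths, k=5, seed=42):
    """Yield (train_idx, val_idx, test_idx) index lists for k folds.

    Hypothetical sketch: every image sharing a basename (assumed identity)
    lands in the same partition, so no identity crosses train/val/test.
    """
    if k < 3:
        raise ValueError("need k >= 3 so train, val and test folds differ")

    # Group image indices by basename (the assumed identity key).
    groups = defaultdict(list)
    for i, p in enumerate(paths):
        groups[Path(p).name].append(i)

    # Sort first so the shuffle is reproducible for a given seed,
    # and use a local RNG so the global `random` state is untouched.
    ids = sorted(groups)
    random.Random(seed).shuffle(ids)
    fold_of = {name: j % k for j, name in enumerate(ids)}  # round-robin identity -> fold

    for j in range(k):
        test_fold, val_fold = j, (j + 1) % k
        train_idx, val_idx, test_idx = [], [], []
        for name, idxs in groups.items():
            if fold_of[name] == test_fold:
                test_idx.extend(idxs)
            elif fold_of[name] == val_fold:
                val_idx.extend(idxs)
            else:
                train_idx.extend(idxs)
        yield train_idx, val_idx, test_idx


def identity_overlap(paths, train_idx, val_idx, test_idx):
    """Return basenames present in more than one partition (empty means no leakage)."""
    def names(idxs):
        return {Path(paths[i]).name for i in idxs}

    tr, va, te = names(train_idx), names(val_idx), names(test_idx)
    return (tr & va) | (tr & te) | (va & te)


if __name__ == "__main__":
    # Hypothetical toy layout: 4 sources x 20 identities, same basename in each source.
    toy_paths = [f"data/src{s}/{i:05d}.jpg" for s in range(4) for i in range(20)]
    for fold, (tr, va, te) in enumerate(identity_kfold(toy_paths, k=5, seed=42)):
        overlap = identity_overlap(toy_paths, tr, va, te)
        print(f"fold {fold}: train={len(tr)} val={len(va)} test={len(te)} overlap={len(overlap)}")
```

On the toy layout, each of the 5 folds should report 48 train, 16 validation and 16 test images with zero overlap. Assigning folds per basename rather than per image is what guarantees this: splitting image indices directly would scatter the four copies of an identity across partitions.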