352 lines
12 KiB
Plaintext
352 lines
12 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "eda-00",
|
||
"metadata": {},
|
||
"source": [
|
||
"# 01 — EDA\n",
|
||
"\n",
|
||
"Explore DeepFakeFace (DFF) data quality before training: composition, source distribution, image properties, and split safety.\n",
|
||
"\n",
|
||
"**Sections:**\n",
|
||
"1. Dataset composition and label balance\n",
|
||
"2. Visual sanity-check samples\n",
|
||
"3. Image dimension profile\n",
|
||
"4. Per-source color statistics\n",
|
||
"5. CV split and leakage sanity check\n",
|
||
"6. Observations template"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "eda-01",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import sys\n",
|
||
"sys.path.insert(0, '..')\n",
|
||
"\n",
|
||
"import random\n",
|
||
"from collections import Counter\n",
|
||
"from pathlib import Path\n",
|
||
"\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import numpy as np\n",
|
||
"from PIL import Image\n",
|
||
"\n",
|
||
"from src.data import DFFDataset, SOURCES, get_splits\n",
|
||
"\n",
|
||
"DATA_DIR = Path('../../data')\n",
|
||
"FIG_DIR = Path('../outputs/figures')\n",
|
||
"FIG_DIR.mkdir(parents=True, exist_ok=True)\n",
|
||
"\n",
|
||
"SEED = 42\n",
|
||
"random.seed(SEED)\n",
|
||
"np.random.seed(SEED)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "eda-02",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 1. Dataset composition and label balance"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "eda-03",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"full_ds = DFFDataset(DATA_DIR)\n",
|
||
"label_counts = full_ds.label_counts()\n",
|
||
"\n",
|
||
"print(f\"Total images : {len(full_ds):,}\")\n",
|
||
"print(f\" Real (label=0) : {label_counts[0]:,}\")\n",
|
||
"print(f\" Fake (label=1) : {label_counts[1]:,}\")\n",
|
||
"print(f\" Fake:real ratio : {label_counts[1] / label_counts[0]:.2f}x\\n\")\n",
|
||
"\n",
|
||
"source_info = []\n",
|
||
"for source, label in SOURCES.items():\n",
|
||
" ds = DFFDataset(DATA_DIR, sources=[source])\n",
|
||
" source_info.append((source, len(ds), label))\n",
|
||
" tag = 'real' if label == 0 else 'fake'\n",
|
||
" print(f\" {source:12s} n={len(ds):6,} label={label} ({tag})\")\n",
|
||
"\n",
|
||
"# Identity-level sanity check: each basename should appear in every source.\n",
|
||
"basename_counts = Counter(path.name for path, _ in full_ds.samples)\n",
|
||
"presence_hist = Counter(basename_counts.values())\n",
|
||
"\n",
|
||
"print(\"\\nIdentity (basename) presence across sources:\")\n",
|
||
"for n_sources, count in sorted(presence_hist.items()):\n",
|
||
" print(f\" present in {n_sources} source(s): {count:,} identities\")\n",
|
||
"\n",
|
||
"incomplete = sum(v for k, v in presence_hist.items() if k < len(SOURCES))\n",
|
||
"print(f\" complete in all {len(SOURCES)} sources: {presence_hist.get(len(SOURCES), 0):,}\")\n",
|
||
"print(f\" incomplete identities : {incomplete:,}\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "eda-04",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n",
|
||
"\n",
|
||
"# Overall class balance\n",
|
||
"class_names = ['Real (wiki)', 'Fake (all 3)']\n",
|
||
"class_counts = [label_counts[0], label_counts[1]]\n",
|
||
"bars = ax1.bar(class_names, class_counts, color=['#2196F3', '#F44336'], width=0.5)\n",
|
||
"ax1.set_title('Overall Class Balance', fontsize=13)\n",
|
||
"ax1.set_ylabel('Images')\n",
|
||
"ax1.set_ylim(0, max(class_counts) * 1.15)\n",
|
||
"for bar, v in zip(bars, class_counts):\n",
|
||
" ax1.text(bar.get_x() + bar.get_width() / 2, v + 300,\n",
|
||
" f'{v:,}', ha='center', fontsize=11)\n",
|
||
"\n",
|
||
"# Per-source breakdown\n",
|
||
"src_names = [s for s, _, _ in source_info]\n",
|
||
"src_counts = [n for _, n, _ in source_info]\n",
|
||
"colors = ['#2196F3', '#FF9800', '#9C27B0', '#4CAF50']\n",
|
||
"bars2 = ax2.bar(src_names, src_counts, color=colors, width=0.5)\n",
|
||
"ax2.set_title('Images per Source', fontsize=13)\n",
|
||
"ax2.set_ylabel('Images')\n",
|
||
"ax2.set_ylim(0, max(src_counts) * 1.15)\n",
|
||
"for bar, v in zip(bars2, src_counts):\n",
|
||
" ax2.text(bar.get_x() + bar.get_width() / 2, v + 300,\n",
|
||
" f'{v:,}', ha='center', fontsize=11)\n",
|
||
"\n",
|
||
"fig.tight_layout()\n",
|
||
"fig.savefig(FIG_DIR / 'class_balance.png', dpi=120, bbox_inches='tight')\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "eda-05",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 2. Visual sanity-check samples"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "eda-06",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"N_COLS = 6\n",
|
||
"fig, axes = plt.subplots(len(SOURCES), N_COLS, figsize=(18, 12))\n",
|
||
"fig.suptitle('Sample images — 6 per source', fontsize=14)\n",
|
||
"\n",
|
||
"for row, (source, label) in enumerate(SOURCES.items()):\n",
|
||
" ds_src = DFFDataset(DATA_DIR, sources=[source])\n",
|
||
" indices = random.sample(range(len(ds_src)), N_COLS)\n",
|
||
" for col, idx in enumerate(indices):\n",
|
||
" path, _ = ds_src.samples[idx]\n",
|
||
" img = Image.open(path).convert('RGB').resize((128, 128))\n",
|
||
" axes[row, col].imshow(img)\n",
|
||
" axes[row, col].axis('off')\n",
|
||
" tag = 'real' if label == 0 else 'fake'\n",
|
||
" axes[row, 0].set_ylabel(f'{source}\\n({tag})', fontsize=10)\n",
|
||
"\n",
|
||
"fig.tight_layout()\n",
|
||
"fig.savefig(FIG_DIR / 'sample_images.png', dpi=100, bbox_inches='tight')\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "eda-07",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 3. Image dimension profile"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "eda-08",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"sample_paths = [p for p, _ in random.sample(full_ds.samples, min(2000, len(full_ds)))]\n",
|
||
"sizes = Counter(Image.open(p).size for p in sample_paths)\n",
|
||
"\n",
|
||
"print('Most common image dimensions (W x H):')\n",
|
||
"for (w, h), count in sizes.most_common(10):\n",
|
||
" pct = count / len(sample_paths)\n",
|
||
" print(f' {w:4d} x {h:4d} — {count:4d} samples ({pct:.1%})')\n",
|
||
"\n",
|
||
"widths = [w for (w, _) in sizes.elements()]\n",
|
||
"heights = [h for (_, h) in sizes.elements()]\n",
|
||
"square = sum(1 for w, h in zip(widths, heights) if w == h)\n",
|
||
"print(f'\\nWidth range: {min(widths)}–{max(widths)} mean={np.mean(widths):.0f}')\n",
|
||
"print(f'Height range: {min(heights)}–{max(heights)} mean={np.mean(heights):.0f}')\n",
|
||
"print(f'Square images: {square}/{len(widths)} ({square / len(widths):.1%})')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "eda-11",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 4. Per-source color statistics"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "eda-12",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"print('Sampling per-source colour statistics (sampling 300 images per source)...')\n",
|
||
"N_SAMPLES = 300\n",
|
||
"CH_NAMES = ['R', 'G', 'B']\n",
|
||
"\n",
|
||
"source_means, source_stds = {}, {}\n",
|
||
"for source in SOURCES:\n",
|
||
" ds_src = DFFDataset(DATA_DIR, sources=[source])\n",
|
||
" idxs = random.sample(range(len(ds_src)), min(N_SAMPLES, len(ds_src)))\n",
|
||
" arrays = [\n",
|
||
" np.array(\n",
|
||
" Image.open(ds_src.samples[i][0]).convert('RGB').resize((64, 64)),\n",
|
||
" dtype=np.float32\n",
|
||
" ) / 255.0\n",
|
||
" for i in idxs\n",
|
||
" ]\n",
|
||
" stack = np.stack(arrays) # (N, 64, 64, 3)\n",
|
||
" source_means[source] = stack.mean(axis=(0, 1, 2)) # per channel\n",
|
||
" source_stds[source] = stack.std(axis=(0, 1, 2))\n",
|
||
" print(f' {source}: mean={source_means[source].round(3)} std={source_stds[source].round(3)}')\n",
|
||
"\n",
|
||
"src_keys = list(SOURCES.keys())\n",
|
||
"x = np.arange(len(src_keys))\n",
|
||
"bar_w = 0.22\n",
|
||
"ch_colors = ['#F44336', '#4CAF50', '#2196F3']\n",
|
||
"\n",
|
||
"fig, axes = plt.subplots(1, 2, figsize=(14, 4))\n",
|
||
"for ci, ch in enumerate(CH_NAMES):\n",
|
||
" offset = (ci - 1) * bar_w\n",
|
||
" axes[0].bar(x + offset, [source_means[s][ci] for s in src_keys],\n",
|
||
" bar_w, label=ch, color=ch_colors[ci], alpha=0.85)\n",
|
||
" axes[1].bar(x + offset, [source_stds[s][ci] for s in src_keys],\n",
|
||
" bar_w, label=ch, color=ch_colors[ci], alpha=0.85)\n",
|
||
"\n",
|
||
"for ax, title, ylabel in zip(\n",
|
||
" axes,\n",
|
||
" ['Mean pixel intensity per source', 'Pixel std dev per source'],\n",
|
||
" ['Mean (0–1)', 'Std dev (0–1)'],\n",
|
||
"):\n",
|
||
" ax.set_xticks(x)\n",
|
||
" ax.set_xticklabels(src_keys)\n",
|
||
" ax.set_title(title, fontsize=12)\n",
|
||
" ax.set_ylabel(ylabel)\n",
|
||
" ax.legend(title='Channel')\n",
|
||
" ax.grid(axis='y', alpha=0.3)\n",
|
||
"\n",
|
||
"fig.tight_layout()\n",
|
||
"fig.savefig(FIG_DIR / 'color_stats.png', dpi=120, bbox_inches='tight')\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "1c7d4660",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 5. CV split and leakage sanity check"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "89513a74",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"cfg = {\n",
|
||
" \"cv_folds\": 5,\n",
|
||
" \"seed\": SEED,\n",
|
||
" \"image_size\": 224,\n",
|
||
" \"train_sources\": None,\n",
|
||
" \"eval_sources\": None,\n",
|
||
"}\n",
|
||
"\n",
|
||
"splits = get_splits(full_ds, cfg)\n",
|
||
"print(f\"Generated {len(splits)} CV folds\")\n",
|
||
"\n",
|
||
"for fold_i, (train_idx, val_idx, test_idx) in enumerate(splits):\n",
|
||
" train_ids = {full_ds.samples[i][0].name for i in train_idx}\n",
|
||
" val_ids = {full_ds.samples[i][0].name for i in val_idx}\n",
|
||
" test_ids = {full_ds.samples[i][0].name for i in test_idx}\n",
|
||
"\n",
|
||
" overlap = (train_ids & val_ids) | (train_ids & test_ids) | (val_ids & test_ids)\n",
|
||
" print(\n",
|
||
" f\"Fold {fold_i}: train={len(train_idx):6d} val={len(val_idx):6d} test={len(test_idx):6d} \"\n",
|
||
" f\"identity_overlap={len(overlap)}\"\n",
|
||
" )\n",
|
||
"\n",
|
||
"print(\"\\nExpected: identity_overlap should be 0 for every fold.\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "eda-13",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 6. Observations template\n",
|
||
"\n",
|
||
"Fill in after running the notebook:\n",
|
||
"\n",
|
||
"**Class balance**\n",
|
||
"- Confirm fake:real ratio and whether sampler/reweighting is needed.\n",
|
||
"\n",
|
||
"**Identity completeness**\n",
|
||
"- Note whether every basename appears in all sources, or how many identities are missing from one or more sources.\n",
|
||
"\n",
|
||
"**Dimensions**\n",
|
||
"- Record dominant dimensions and whether extreme outliers appear.\n",
|
||
"\n",
|
||
"**Color stats**\n",
|
||
"- Note clear mean/std shifts by source (if any).\n",
|
||
"\n",
|
||
"**Split sanity**\n",
|
||
"- Confirm every fold reports `identity_overlap=0`.\n",
|
||
"\n",
|
||
"**Action items before training**\n",
|
||
"- List any cleanup/filtering decisions (if required)."
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "drl",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.12.13"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|