Clean state
This commit is contained in:
@@ -0,0 +1,351 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "eda-00",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 01 — EDA\n",
|
||||
"\n",
|
||||
"Explore DeepFakeFace (DFF) data quality before training: composition, source distribution, image properties, and split safety.\n",
|
||||
"\n",
|
||||
"**Sections:**\n",
|
||||
"1. Dataset composition and label balance\n",
|
||||
"2. Visual sanity-check samples\n",
|
||||
"3. Image dimension profile\n",
|
||||
"4. Per-source color statistics\n",
|
||||
"5. CV split and leakage sanity check\n",
|
||||
"6. Observations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "eda-01",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"sys.path.insert(0, '..')\n",
|
||||
"\n",
|
||||
"import random\n",
|
||||
"from collections import Counter\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import numpy as np\n",
|
||||
"from PIL import Image\n",
|
||||
"\n",
|
||||
"from src.data import DFFDataset, SOURCES, get_splits\n",
|
||||
"\n",
|
||||
"DATA_DIR = Path('../../data')\n",
|
||||
"FIG_DIR = Path('../outputs/figures')\n",
|
||||
"FIG_DIR.mkdir(parents=True, exist_ok=True)\n",
|
||||
"\n",
|
||||
"SEED = 42\n",
|
||||
"random.seed(SEED)\n",
|
||||
"np.random.seed(SEED)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "eda-02",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. Dataset composition and label balance"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "eda-03",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"full_ds = DFFDataset(DATA_DIR)\n",
|
||||
"label_counts = full_ds.label_counts()\n",
|
||||
"\n",
|
||||
"print(f\"Total images : {len(full_ds):,}\")\n",
|
||||
"print(f\" Real (label=0) : {label_counts[0]:,}\")\n",
|
||||
"print(f\" Fake (label=1) : {label_counts[1]:,}\")\n",
|
||||
"print(f\" Fake:real ratio : {label_counts[1] / label_counts[0]:.2f}x\\n\")\n",
|
||||
"\n",
|
||||
"source_info = []\n",
|
||||
"for source, label in SOURCES.items():\n",
|
||||
" ds = DFFDataset(DATA_DIR, sources=[source])\n",
|
||||
" source_info.append((source, len(ds), label))\n",
|
||||
" tag = 'real' if label == 0 else 'fake'\n",
|
||||
" print(f\" {source:12s} n={len(ds):6,} label={label} ({tag})\")\n",
|
||||
"\n",
|
||||
"# Identity-level sanity check: each basename should appear in every source.\n",
|
||||
"basename_counts = Counter(path.name for path, _ in full_ds.samples)\n",
|
||||
"presence_hist = Counter(basename_counts.values())\n",
|
||||
"\n",
|
||||
"print(\"\\nIdentity (basename) presence across sources:\")\n",
|
||||
"for n_sources, count in sorted(presence_hist.items()):\n",
|
||||
" print(f\" present in {n_sources} source(s): {count:,} identities\")\n",
|
||||
"\n",
|
||||
"incomplete = sum(v for k, v in presence_hist.items() if k < len(SOURCES))\n",
|
||||
"print(f\" complete in all {len(SOURCES)} sources: {presence_hist.get(len(SOURCES), 0):,}\")\n",
|
||||
"print(f\" incomplete identities : {incomplete:,}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "eda-04",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n",
|
||||
"\n",
|
||||
"# Overall class balance\n",
|
||||
"class_names = ['Real (wiki)', 'Fake (all 3)']\n",
|
||||
"class_counts = [label_counts[0], label_counts[1]]\n",
|
||||
"bars = ax1.bar(class_names, class_counts, color=['#2196F3', '#F44336'], width=0.5)\n",
|
||||
"ax1.set_title('Overall Class Balance', fontsize=13)\n",
|
||||
"ax1.set_ylabel('Images')\n",
|
||||
"ax1.set_ylim(0, max(class_counts) * 1.15)\n",
|
||||
"for bar, v in zip(bars, class_counts):\n",
|
||||
" ax1.text(bar.get_x() + bar.get_width() / 2, v + 300,\n",
|
||||
" f'{v:,}', ha='center', fontsize=11)\n",
|
||||
"\n",
|
||||
"# Per-source breakdown\n",
|
||||
"src_names = [s for s, _, _ in source_info]\n",
|
||||
"src_counts = [n for _, n, _ in source_info]\n",
|
||||
"colors = ['#2196F3', '#FF9800', '#9C27B0', '#4CAF50']\n",
|
||||
"bars2 = ax2.bar(src_names, src_counts, color=colors, width=0.5)\n",
|
||||
"ax2.set_title('Images per Source', fontsize=13)\n",
|
||||
"ax2.set_ylabel('Images')\n",
|
||||
"ax2.set_ylim(0, max(src_counts) * 1.15)\n",
|
||||
"for bar, v in zip(bars2, src_counts):\n",
|
||||
" ax2.text(bar.get_x() + bar.get_width() / 2, v + 300,\n",
|
||||
" f'{v:,}', ha='center', fontsize=11)\n",
|
||||
"\n",
|
||||
"fig.tight_layout()\n",
|
||||
"fig.savefig(FIG_DIR / 'class_balance.png', dpi=120, bbox_inches='tight')\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "eda-05",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. Visual sanity-check samples"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "eda-06",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"N_COLS = 6\n",
|
||||
"fig, axes = plt.subplots(len(SOURCES), N_COLS, figsize=(18, 12))\n",
|
||||
"fig.suptitle('Sample images — 6 per source', fontsize=14)\n",
|
||||
"\n",
|
||||
"for row, (source, label) in enumerate(SOURCES.items()):\n",
|
||||
" ds_src = DFFDataset(DATA_DIR, sources=[source])\n",
|
||||
" indices = random.sample(range(len(ds_src)), N_COLS)\n",
|
||||
" for col, idx in enumerate(indices):\n",
|
||||
" path, _ = ds_src.samples[idx]\n",
|
||||
" img = Image.open(path).convert('RGB').resize((128, 128))\n",
|
||||
" axes[row, col].imshow(img)\n",
|
||||
" axes[row, col].axis('off')\n",
|
||||
" tag = 'real' if label == 0 else 'fake'\n",
|
||||
" axes[row, 0].set_ylabel(f'{source}\\n({tag})', fontsize=10)\n",
|
||||
"\n",
|
||||
"fig.tight_layout()\n",
|
||||
"fig.savefig(FIG_DIR / 'sample_images.png', dpi=100, bbox_inches='tight')\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "eda-07",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3. Image dimension profile"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "eda-08",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sample_paths = [p for p, _ in random.sample(full_ds.samples, min(2000, len(full_ds)))]\n",
|
||||
"sizes = Counter(Image.open(p).size for p in sample_paths)\n",
|
||||
"\n",
|
||||
"print('Most common image dimensions (W x H):')\n",
|
||||
"for (w, h), count in sizes.most_common(10):\n",
|
||||
" pct = count / len(sample_paths)\n",
|
||||
" print(f' {w:4d} x {h:4d} — {count:4d} samples ({pct:.1%})')\n",
|
||||
"\n",
|
||||
"widths = [w for (w, _) in sizes.elements()]\n",
|
||||
"heights = [h for (_, h) in sizes.elements()]\n",
|
||||
"square = sum(1 for w, h in zip(widths, heights) if w == h)\n",
|
||||
"print(f'\\nWidth range: {min(widths)}–{max(widths)} mean={np.mean(widths):.0f}')\n",
|
||||
"print(f'Height range: {min(heights)}–{max(heights)} mean={np.mean(heights):.0f}')\n",
|
||||
"print(f'Square images: {square}/{len(widths)} ({square / len(widths):.1%})')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "eda-11",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 4. Per-source color statistics"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "eda-12",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print('Sampling per-source colour statistics (sampling 300 images per source)...')\n",
|
||||
"N_SAMPLES = 300\n",
|
||||
"CH_NAMES = ['R', 'G', 'B']\n",
|
||||
"\n",
|
||||
"source_means, source_stds = {}, {}\n",
|
||||
"for source in SOURCES:\n",
|
||||
" ds_src = DFFDataset(DATA_DIR, sources=[source])\n",
|
||||
" idxs = random.sample(range(len(ds_src)), min(N_SAMPLES, len(ds_src)))\n",
|
||||
" arrays = [\n",
|
||||
" np.array(\n",
|
||||
" Image.open(ds_src.samples[i][0]).convert('RGB').resize((64, 64)),\n",
|
||||
" dtype=np.float32\n",
|
||||
" ) / 255.0\n",
|
||||
" for i in idxs\n",
|
||||
" ]\n",
|
||||
" stack = np.stack(arrays) # (N, 64, 64, 3)\n",
|
||||
" source_means[source] = stack.mean(axis=(0, 1, 2)) # per channel\n",
|
||||
" source_stds[source] = stack.std(axis=(0, 1, 2))\n",
|
||||
" print(f' {source}: mean={source_means[source].round(3)} std={source_stds[source].round(3)}')\n",
|
||||
"\n",
|
||||
"src_keys = list(SOURCES.keys())\n",
|
||||
"x = np.arange(len(src_keys))\n",
|
||||
"bar_w = 0.22\n",
|
||||
"ch_colors = ['#F44336', '#4CAF50', '#2196F3']\n",
|
||||
"\n",
|
||||
"fig, axes = plt.subplots(1, 2, figsize=(14, 4))\n",
|
||||
"for ci, ch in enumerate(CH_NAMES):\n",
|
||||
" offset = (ci - 1) * bar_w\n",
|
||||
" axes[0].bar(x + offset, [source_means[s][ci] for s in src_keys],\n",
|
||||
" bar_w, label=ch, color=ch_colors[ci], alpha=0.85)\n",
|
||||
" axes[1].bar(x + offset, [source_stds[s][ci] for s in src_keys],\n",
|
||||
" bar_w, label=ch, color=ch_colors[ci], alpha=0.85)\n",
|
||||
"\n",
|
||||
"for ax, title, ylabel in zip(\n",
|
||||
" axes,\n",
|
||||
" ['Mean pixel intensity per source', 'Pixel std dev per source'],\n",
|
||||
" ['Mean (0–1)', 'Std dev (0–1)'],\n",
|
||||
"):\n",
|
||||
" ax.set_xticks(x)\n",
|
||||
" ax.set_xticklabels(src_keys)\n",
|
||||
" ax.set_title(title, fontsize=12)\n",
|
||||
" ax.set_ylabel(ylabel)\n",
|
||||
" ax.legend(title='Channel')\n",
|
||||
" ax.grid(axis='y', alpha=0.3)\n",
|
||||
"\n",
|
||||
"fig.tight_layout()\n",
|
||||
"fig.savefig(FIG_DIR / 'color_stats.png', dpi=120, bbox_inches='tight')\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1c7d4660",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 5. CV split and leakage sanity check"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "89513a74",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cfg = {\n",
|
||||
" \"cv_folds\": 5,\n",
|
||||
" \"seed\": SEED,\n",
|
||||
" \"image_size\": 224,\n",
|
||||
" \"train_sources\": None,\n",
|
||||
" \"eval_sources\": None,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"splits = get_splits(full_ds, cfg)\n",
|
||||
"print(f\"Generated {len(splits)} CV folds\")\n",
|
||||
"\n",
|
||||
"for fold_i, (train_idx, val_idx, test_idx) in enumerate(splits):\n",
|
||||
" train_ids = {full_ds.samples[i][0].name for i in train_idx}\n",
|
||||
" val_ids = {full_ds.samples[i][0].name for i in val_idx}\n",
|
||||
" test_ids = {full_ds.samples[i][0].name for i in test_idx}\n",
|
||||
"\n",
|
||||
" overlap = (train_ids & val_ids) | (train_ids & test_ids) | (val_ids & test_ids)\n",
|
||||
" print(\n",
|
||||
" f\"Fold {fold_i}: train={len(train_idx):6d} val={len(val_idx):6d} test={len(test_idx):6d} \"\n",
|
||||
" f\"identity_overlap={len(overlap)}\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"print(\"\\nExpected: identity_overlap should be 0 for every fold.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "eda-13",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 6. Observations template\n",
|
||||
"\n",
|
||||
"Fill in after running the notebook:\n",
|
||||
"\n",
|
||||
"**Class balance**\n",
|
||||
"- Confirm fake:real ratio and whether sampler/reweighting is needed.\n",
|
||||
"\n",
|
||||
"**Identity completeness**\n",
|
||||
"- Note whether most basenames appear in all sources or if there are missing-source identities.\n",
|
||||
"\n",
|
||||
"**Dimensions**\n",
|
||||
"- Record dominant dimensions and whether extreme outliers appear.\n",
|
||||
"\n",
|
||||
"**Color stats**\n",
|
||||
"- Note clear mean/std shifts by source (if any).\n",
|
||||
"\n",
|
||||
"**Split sanity**\n",
|
||||
"- Confirm every fold reports `identity_overlap=0`.\n",
|
||||
"\n",
|
||||
"**Action items before training**\n",
|
||||
"- List any cleanup/filtering decisions (if required)."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "drl",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,362 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "pp-00",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 02 — Preprocessing\n",
|
||||
"\n",
|
||||
"Inspect what images look like right before model input.\n",
|
||||
"\n",
|
||||
"Face cropping is an offline step — run `tools/precrop.py` once to produce `data_cropped/`, then point configs at that directory. The sections below show the standard pipeline on already-cropped or uncropped images. `facenet_pytorch` is only needed to visualize the offline cropper.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "pp-01",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import random\n",
|
||||
"import sys\n",
|
||||
"from collections import Counter\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib.patches as patches\n",
|
||||
"import numpy as np\n",
|
||||
"from PIL import Image\n",
|
||||
"\n",
|
||||
"sys.path.insert(0, '..')\n",
|
||||
"from src.data import DFFDataset, SOURCES\n",
|
||||
"from src.preprocessing.pipeline import DFFImagePipeline\n",
|
||||
"\n",
|
||||
"DATA_DIR = Path('../../data')\n",
|
||||
"SEED = 7\n",
|
||||
"random.seed(SEED)\n",
|
||||
"np.random.seed(SEED)\n",
|
||||
"\n",
|
||||
"full_ds = DFFDataset(DATA_DIR)\n",
|
||||
"\n",
|
||||
"print(f\"Dataset root: {DATA_DIR.resolve()}\")\n",
|
||||
"print(f\"Total samples: {len(full_ds):,}\")\n",
|
||||
"source_counts = Counter(path.parent.parent.name for path, _ in full_ds.samples)\n",
|
||||
"print(\"Per-source counts:\")\n",
|
||||
"for src in SOURCES:\n",
|
||||
" print(f\" {src:12s} {source_counts[src]:6,}\")\n",
|
||||
"\n",
|
||||
"def denorm(tensor):\n",
|
||||
" mean = np.array([0.485, 0.456, 0.406])\n",
|
||||
" std = np.array([0.229, 0.224, 0.225])\n",
|
||||
" arr = tensor.permute(1, 2, 0).numpy()\n",
|
||||
" return np.clip(arr * std + mean, 0, 1)\n",
|
||||
"\n",
|
||||
"def pick_samples(n=4, sources=None):\n",
|
||||
" ds = DFFDataset(DATA_DIR, sources=sources) if sources else full_ds\n",
|
||||
" idxs = random.sample(range(len(ds)), n)\n",
|
||||
" return [Image.open(ds.samples[i][0]).convert('RGB') for i in idxs]\n",
|
||||
"\n",
|
||||
"# Runtime face-crop helper from tools (kept for notebook visualization only).\n",
|
||||
"try:\n",
|
||||
" from facenet_pytorch import MTCNN\n",
|
||||
" from tools.precrop import FaceCropper\n",
|
||||
" FACE_CROP_AVAILABLE = True\n",
|
||||
" _detector = MTCNN(keep_all=False, select_largest=True, device='cpu', post_process=False)\n",
|
||||
" _cropper = FaceCropper(margin=0.6, size=224, device='cpu')\n",
|
||||
" print('facenet_pytorch available — crop helper enabled.')\n",
|
||||
"except ImportError:\n",
|
||||
" FACE_CROP_AVAILABLE = False\n",
|
||||
" _cropper = None\n",
|
||||
" print('WARNING: facenet_pytorch not installed — crop sections will be skipped.')\n",
|
||||
" print(' Install with: pip install facenet-pytorch')\n",
|
||||
"\n",
|
||||
"pipe_eval = DFFImagePipeline(image_size=224, train=False)\n",
|
||||
"pipe_aug = DFFImagePipeline(image_size=224, train=True)\n",
|
||||
"\n",
|
||||
"crop_note = 'offline face-crop preview -> ' if FACE_CROP_AVAILABLE else '(no face crop) -> '\n",
|
||||
"print('Pipelines ready.')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "pp-02",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"## 1. Crop preview\n",
|
||||
"\n",
|
||||
"Visualizes what `tools/precrop.py` does: MTCNN detects the largest face, crops a square with a 60% margin, and falls back to center-crop when no face is found.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "pp-03",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if not FACE_CROP_AVAILABLE:\n",
|
||||
" print('Skipped — facenet_pytorch not installed.')\n",
|
||||
"else:\n",
|
||||
" src_images = {src: pick_samples(1, sources=[src])[0] for src in SOURCES}\n",
|
||||
"\n",
|
||||
" fig, axes = plt.subplots(len(SOURCES), 3, figsize=(10, 14))\n",
|
||||
" fig.suptitle(\n",
|
||||
" 'Face crop helper | col 1: original + detection box | col 2: cropped | col 3: cropped + eval pipeline',\n",
|
||||
" fontsize=10\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" for row, (src, img) in enumerate(src_images.items()):\n",
|
||||
" label = 'real' if SOURCES[src] == 0 else 'fake'\n",
|
||||
"\n",
|
||||
" # col 0: original with bounding box\n",
|
||||
" boxes, probs = _detector.detect(img)\n",
|
||||
" axes[row, 0].imshow(img)\n",
|
||||
" if boxes is not None and len(boxes) > 0:\n",
|
||||
" x1, y1, x2, y2 = boxes[0]\n",
|
||||
" rect = patches.Rectangle(\n",
|
||||
" (x1, y1), x2 - x1, y2 - y1,\n",
|
||||
" linewidth=2, edgecolor='lime', facecolor='none'\n",
|
||||
" )\n",
|
||||
" axes[row, 0].add_patch(rect)\n",
|
||||
" axes[row, 0].set_title(f'detected p={probs[0]:.2f}', fontsize=8, color='green')\n",
|
||||
" else:\n",
|
||||
" axes[row, 0].set_title('no face — centre crop fallback', fontsize=8, color='red')\n",
|
||||
" axes[row, 0].set_ylabel(f'{src}\\n({label})', fontsize=9)\n",
|
||||
"\n",
|
||||
" # col 1: cropped result from tools.precrop.FaceCropper\n",
|
||||
" cropped = _cropper(img)\n",
|
||||
" axes[row, 1].imshow(cropped)\n",
|
||||
" axes[row, 1].set_title('cropped (224px)', fontsize=8)\n",
|
||||
"\n",
|
||||
" # col 2: cropped image through eval pipeline\n",
|
||||
" axes[row, 2].imshow(denorm(pipe_eval(cropped)))\n",
|
||||
" axes[row, 2].set_title('crop + eval pipeline', fontsize=8)\n",
|
||||
"\n",
|
||||
" for ax in axes.flat:\n",
|
||||
" ax.axis('off')\n",
|
||||
"\n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "pp-04",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"## 2. Eval path vs Train path\n",
|
||||
"\n",
|
||||
"Compare the deterministic eval transform and the stochastic train transform.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "pp-05",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"src_images = {src: pick_samples(1, sources=[src])[0] for src in SOURCES}\n",
|
||||
"\n",
|
||||
"fig, axes = plt.subplots(len(SOURCES), 3, figsize=(10, 14))\n",
|
||||
"fig.suptitle(\n",
|
||||
" f'original | {crop_note}eval (no aug) | {crop_note}train aug',\n",
|
||||
" fontsize=11\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"for row, (src, img) in enumerate(src_images.items()):\n",
|
||||
" label = 'real' if SOURCES[src] == 0 else 'fake'\n",
|
||||
" proc_img = _cropper(img) if FACE_CROP_AVAILABLE else img\n",
|
||||
"\n",
|
||||
" axes[row, 0].imshow(img.resize((224, 224)))\n",
|
||||
" axes[row, 0].set_title('original', fontsize=8)\n",
|
||||
" axes[row, 0].set_ylabel(f'{src}\\n({label})', fontsize=9)\n",
|
||||
"\n",
|
||||
" axes[row, 1].imshow(denorm(pipe_eval(proc_img)))\n",
|
||||
" axes[row, 1].set_title(f'{crop_note}eval (no aug)', fontsize=8)\n",
|
||||
"\n",
|
||||
" axes[row, 2].imshow(denorm(pipe_aug(proc_img)))\n",
|
||||
" axes[row, 2].set_title(f'{crop_note}train aug', fontsize=8)\n",
|
||||
"\n",
|
||||
"for ax in axes.flat:\n",
|
||||
" ax.axis('off')\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "pp-06",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"## 3. Augmentation variety\n",
|
||||
"\n",
|
||||
"Use the same source image with multiple independent stochastic draws.\n",
|
||||
"This shows the realistic variation the model sees during training."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "pp-07",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"N_DRAWS = 8\n",
|
||||
"imgs_to_show = pick_samples(2)\n",
|
||||
"\n",
|
||||
"fig, axes = plt.subplots(2, N_DRAWS + 1, figsize=(20, 5))\n",
|
||||
"fig.suptitle(\n",
|
||||
" f'{N_DRAWS} independent draws — {crop_note}aug — each column is a different random sample',\n",
|
||||
" fontsize=11\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"for row, img in enumerate(imgs_to_show):\n",
|
||||
" axes[row, 0].imshow(img.resize((224, 224)))\n",
|
||||
" axes[row, 0].set_title('original', fontsize=8)\n",
|
||||
" axes[row, 0].set_ylabel(f'image {row + 1}', fontsize=9)\n",
|
||||
"\n",
|
||||
" for col in range(N_DRAWS):\n",
|
||||
" axes[row, col + 1].imshow(denorm(pipe_aug(img)))\n",
|
||||
" axes[row, col + 1].set_title(f'#{col + 1}', fontsize=8)\n",
|
||||
"\n",
|
||||
"for ax in axes.flat:\n",
|
||||
" ax.axis('off')\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "pp-09",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"## 3. Full pipeline comparison\n",
|
||||
"\n",
|
||||
"All combinations in one grid. Crop columns appear only when `facenet_pytorch` is installed.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "pp-10",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"samples = {src: pick_samples(1, sources=[src])[0] for src in SOURCES}\n",
|
||||
"\n",
|
||||
"cols = [\n",
|
||||
" ('original', False, False),\n",
|
||||
" ('no crop\\nno aug', False, False),\n",
|
||||
" ('no crop\\naug', False, True),\n",
|
||||
"]\n",
|
||||
"if FACE_CROP_AVAILABLE:\n",
|
||||
" cols += [\n",
|
||||
" ('crop\\nno aug', True, False),\n",
|
||||
" ('crop\\naug', True, True),\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
"n_cols = len(cols)\n",
|
||||
"fig, axes = plt.subplots(len(SOURCES), n_cols, figsize=(n_cols * 2.8, 14))\n",
|
||||
"fig.suptitle('Full pipeline comparison — pipeline order: (optional) face crop helper -> augmentation -> normalize', fontsize=11)\n",
|
||||
"\n",
|
||||
"for row, (src, img) in enumerate(samples.items()):\n",
|
||||
" label = 'real' if SOURCES[src] == 0 else 'fake'\n",
|
||||
" axes[row, 0].set_ylabel(f'{src}\\n({label})', fontsize=9)\n",
|
||||
"\n",
|
||||
" for col, (title, use_crop, train_mode) in enumerate(cols):\n",
|
||||
" ax = axes[row, col]\n",
|
||||
" if col == 0:\n",
|
||||
" ax.imshow(img.resize((224, 224)))\n",
|
||||
" else:\n",
|
||||
" proc_img = _cropper(img) if (use_crop and FACE_CROP_AVAILABLE) else img\n",
|
||||
" pipe = DFFImagePipeline(image_size=224, train=train_mode)\n",
|
||||
" ax.imshow(denorm(pipe(proc_img)))\n",
|
||||
" if row == 0:\n",
|
||||
" ax.set_title(title, fontsize=8)\n",
|
||||
" ax.axis('off')\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "19187059",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"## 4. Tensor sanity checks\n",
|
||||
"\n",
|
||||
"Validate preprocessing outputs: shape, finite values, normalized value ranges.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7e5697c4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"check_imgs = pick_samples(n=12)\n",
|
||||
"issues = []\n",
|
||||
"\n",
|
||||
"for i, img in enumerate(check_imgs):\n",
|
||||
" t_eval = pipe_eval(img)\n",
|
||||
" t_aug = pipe_aug(img)\n",
|
||||
"\n",
|
||||
" for tag, t in [(\"eval\", t_eval), (\"aug\", t_aug)]:\n",
|
||||
" if tuple(t.shape) != (3, 224, 224):\n",
|
||||
" issues.append(f\"sample {i} ({tag}) shape={tuple(t.shape)}\")\n",
|
||||
" if not np.isfinite(t.numpy()).all():\n",
|
||||
" issues.append(f\"sample {i} ({tag}) has non-finite values\")\n",
|
||||
"\n",
|
||||
"print(f\"Checked {len(check_imgs)} images through eval+aug pipelines.\")\n",
|
||||
"if issues:\n",
|
||||
" print(\"Issues found:\")\n",
|
||||
" for msg in issues[:10]:\n",
|
||||
" print(f\" - {msg}\")\n",
|
||||
"else:\n",
|
||||
" print(\"No shape/finite-value issues found.\")\n",
|
||||
"\n",
|
||||
"stack_eval = np.stack([pipe_eval(img).numpy() for img in check_imgs])\n",
|
||||
"stack_aug = np.stack([pipe_aug(img).numpy() for img in check_imgs])\n",
|
||||
"\n",
|
||||
"print(\"\\nValue summary (normalized tensors):\")\n",
|
||||
"print(f\" eval: min={stack_eval.min():.3f} max={stack_eval.max():.3f} mean={stack_eval.mean():.3f} std={stack_eval.std():.3f}\")\n",
|
||||
"print(f\" aug : min={stack_aug.min():.3f} max={stack_aug.max():.3f} mean={stack_aug.mean():.3f} std={stack_aug.std():.3f}\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "drl",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,702 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Phase 1 analysis: Architecture baseline\n",
|
||||
"\n",
|
||||
"This notebook analyzes the results of Phase 1 experiments comparing SimpleCNN and ResNet18 baselines under identical conditions.\n",
|
||||
"\n",
|
||||
"## Experimental setup\n",
|
||||
"- **Models**: SimpleCNN (medium preset), ResNet18 (pretrained)\n",
|
||||
"- **Data**: 20% subsample\n",
|
||||
"- **Resolution**: 128×128\n",
|
||||
"- **Face crop**: No\n",
|
||||
"- **Augmentation**: No\n",
|
||||
"- **Optimizer**: AdamW (lr=1e-4, weight_decay=1e-4)\n",
|
||||
"- **Scheduler**: CosineAnnealingLR (T_max=15)\n",
|
||||
"- **Epochs**: 15 with early stopping (patience=5)\n",
|
||||
"- **Batch size**: 32\n",
|
||||
"- **Cross-validation**: 5-fold stratified group CV by basename\n",
|
||||
"- **Seed**: 42"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import seaborn as sns\n",
|
||||
"from pathlib import Path\n",
|
||||
"from scipy import stats\n",
|
||||
"\n",
|
||||
"# Set style\n",
|
||||
"sns.set_style(\"whitegrid\")\n",
|
||||
"plt.rcParams['figure.figsize'] = (12, 6)\n",
|
||||
"plt.rcParams['font.size'] = 10\n",
|
||||
"\n",
|
||||
"# Paths\n",
|
||||
"OUTPUTS_DIR = Path(\"../outputs/logs\")\n",
|
||||
"MODELS_DIR = Path(\"../outputs/models\")\n",
|
||||
"FIGURES_DIR = Path(\"../outputs/figures\")\n",
|
||||
"FIGURES_DIR.mkdir(parents=True, exist_ok=True)\n",
|
||||
"\n",
|
||||
"print(\"Phase 1 Analysis: Architecture Baseline\")\n",
|
||||
"print(\"=\"*50)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Load CV results"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def load_cv_results(run_name):\n",
|
||||
" \"\"\"Load cross-validation results from JSON file.\"\"\"\n",
|
||||
" results_path = OUTPUTS_DIR / f\"{run_name}.json\"\n",
|
||||
" if not results_path.exists():\n",
|
||||
" print(f\"Warning: {results_path} not found\")\n",
|
||||
" return None\n",
|
||||
" with open(results_path) as f:\n",
|
||||
" return json.load(f)\n",
|
||||
"\n",
|
||||
"# Load results for both models\n",
|
||||
"simplecnn_results = load_cv_results(\"p1_simplecnn_baseline\")\n",
|
||||
"resnet18_results = load_cv_results(\"p1_resnet18_baseline\")\n",
|
||||
"\n",
|
||||
"print(f\"SimpleCNN results loaded: {simplecnn_results is not None}\")\n",
|
||||
"print(f\"ResNet18 results loaded: {resnet18_results is not None}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Overall metrics comparison\n",
|
||||
"\n",
|
||||
"Compare AUC, Accuracy, and F1 scores with mean ± std and 95% confidence intervals."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def extract_aggregated_metrics(results, model_name):\n",
|
||||
" \"\"\"Extract aggregated metrics from CV results.\"\"\"\n",
|
||||
" if results is None:\n",
|
||||
" return None\n",
|
||||
" \n",
|
||||
" agg = results['aggregated_metrics']\n",
|
||||
" return {\n",
|
||||
" 'model': model_name,\n",
|
||||
" 'auc_mean': agg['auc_roc']['mean'],\n",
|
||||
" 'auc_std': agg['auc_roc']['std'],\n",
|
||||
" 'auc_ci': agg['auc_roc']['ci_95'],\n",
|
||||
" 'acc_mean': agg['accuracy']['mean'],\n",
|
||||
" 'acc_std': agg['accuracy']['std'],\n",
|
||||
" 'acc_ci': agg['accuracy']['ci_95'],\n",
|
||||
" 'f1_mean': agg['f1']['mean'],\n",
|
||||
" 'f1_std': agg['f1']['std'],\n",
|
||||
" 'f1_ci': agg['f1']['ci_95'],\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"# Extract metrics\n",
|
||||
"simplecnn_metrics = extract_aggregated_metrics(simplecnn_results, 'SimpleCNN')\n",
|
||||
"resnet18_metrics = extract_aggregated_metrics(resnet18_results, 'ResNet18')\n",
|
||||
"\n",
|
||||
"# Create comparison table\n",
|
||||
"if simplecnn_metrics and resnet18_metrics:\n",
|
||||
" comparison_df = pd.DataFrame([simplecnn_metrics, resnet18_metrics])\n",
|
||||
" comparison_df.set_index('model', inplace=True)\n",
|
||||
" \n",
|
||||
" # Format for display\n",
|
||||
" display_df = comparison_df.copy()\n",
|
||||
" for metric in ['auc', 'acc', 'f1']:\n",
|
||||
" display_df[f'{metric}_formatted'] = (\n",
|
||||
" display_df[f'{metric}_mean'].apply(lambda x: f\"{x:.4f}\") + \" ± \" +\n",
|
||||
" display_df[f'{metric}_std'].apply(lambda x: f\"{x:.4f}\") +\n",
|
||||
" \" (95% CI: ±\" + display_df[f'{metric}_ci'].apply(lambda x: f\"{x:.4f}\") + \")\"\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" print(\"\\nOverall Metrics Comparison (5-fold CV):\")\n",
|
||||
" print(\"=\"*80)\n",
|
||||
" for col in ['auc_formatted', 'acc_formatted', 'f1_formatted']:\n",
|
||||
" metric_name = col.replace('_formatted', '').upper()\n",
|
||||
" print(f\"\\n{metric_name}:\")\n",
|
||||
" for model in display_df.index:\n",
|
||||
" print(f\" {model}: {display_df.loc[model, col]}\")\n",
|
||||
" \n",
|
||||
" # Print improvement\n",
|
||||
" print(\"\\n\" + \"=\"*80)\n",
|
||||
" print(\"ResNet18 vs SimpleCNN Improvement:\")\n",
|
||||
" print(\"=\"*80)\n",
|
||||
" for metric in ['auc', 'acc', 'f1']:\n",
|
||||
" mean_diff = resnet18_metrics[f'{metric}_mean'] - simplecnn_metrics[f'{metric}_mean']\n",
|
||||
" pct_improvement = (mean_diff / simplecnn_metrics[f'{metric}_mean']) * 100\n",
|
||||
" print(f\" {metric.upper()}: +{mean_diff:.4f} (+{pct_improvement:.2f}%)\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Visualization: Overall metrics comparison"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if simplecnn_metrics and resnet18_metrics:\n",
|
||||
" fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
|
||||
" \n",
|
||||
" models = ['SimpleCNN', 'ResNet18']\n",
|
||||
" metrics_data = {\n",
|
||||
" 'AUC-ROC': [simplecnn_metrics['auc_mean'], resnet18_metrics['auc_mean']],\n",
|
||||
" 'Accuracy': [simplecnn_metrics['acc_mean'], resnet18_metrics['acc_mean']],\n",
|
||||
" 'F1 Score': [simplecnn_metrics['f1_mean'], resnet18_metrics['f1_mean']],\n",
|
||||
" }\n",
|
||||
" errors = {\n",
|
||||
" 'AUC-ROC': [simplecnn_metrics['auc_std'], resnet18_metrics['auc_std']],\n",
|
||||
" 'Accuracy': [simplecnn_metrics['acc_std'], resnet18_metrics['acc_std']],\n",
|
||||
" 'F1 Score': [simplecnn_metrics['f1_std'], resnet18_metrics['f1_std']],\n",
|
||||
" }\n",
|
||||
" \n",
|
||||
" colors = ['#e74c3c', '#2ecc71'] # Red for SimpleCNN, Green for ResNet18\n",
|
||||
" \n",
|
||||
" for idx, (metric_name, values) in enumerate(metrics_data.items()):\n",
|
||||
" ax = axes[idx]\n",
|
||||
" bars = ax.bar(models, values, yerr=errors[metric_name], capsize=5, alpha=0.7, color=colors)\n",
|
||||
" ax.set_ylabel(metric_name)\n",
|
||||
" ax.set_title(f'{metric_name} Comparison')\n",
|
||||
" ax.set_ylim(0.5, 1.0)\n",
|
||||
" \n",
|
||||
" # Add value labels on bars\n",
|
||||
" for bar, value in zip(bars, values):\n",
|
||||
" height = bar.get_height()\n",
|
||||
" ax.text(bar.get_x() + bar.get_width()/2., height,\n",
|
||||
" f'{value:.4f}',\n",
|
||||
" ha='center', va='bottom', fontweight='bold')\n",
|
||||
" \n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.savefig(FIGURES_DIR / 'phase1_overall_metrics.png', dpi=300, bbox_inches='tight')\n",
|
||||
" plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Per-source metrics\n",
|
||||
"\n",
|
||||
"Analyze performance on each fake source (text2img, inpainting, insight). Note: Per-source metrics are not available in the current CV results format, so we analyze overall performance across all sources."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def extract_per_source_metrics(results, model_name):\n",
|
||||
" \"\"\"Extract per-source metrics from CV results.\"\"\"\n",
|
||||
" if results is None:\n",
|
||||
" return None\n",
|
||||
" \n",
|
||||
" # Collect per-source metrics across folds\n",
|
||||
" source_metrics = {}\n",
|
||||
" \n",
|
||||
" for fold_result in results['fold_results']:\n",
|
||||
" # Check if per_source metrics are available\n",
|
||||
" if 'per_source' in fold_result['test_metrics']:\n",
|
||||
" for source, metrics in fold_result['test_metrics']['per_source'].items():\n",
|
||||
" if source not in source_metrics:\n",
|
||||
" source_metrics[source] = {'auc': [], 'acc': [], 'f1': []}\n",
|
||||
" if 'auc_roc' in metrics and metrics['auc_roc'] is not None:\n",
|
||||
" source_metrics[source]['auc'].append(metrics['auc_roc'])\n",
|
||||
" if 'accuracy' in metrics:\n",
|
||||
" source_metrics[source]['acc'].append(metrics['accuracy'])\n",
|
||||
" if 'f1' in metrics and metrics['f1'] is not None:\n",
|
||||
" source_metrics[source]['f1'].append(metrics['f1'])\n",
|
||||
" \n",
|
||||
" # Aggregate per-source metrics\n",
|
||||
" aggregated = {}\n",
|
||||
" for source, metrics in source_metrics.items():\n",
|
||||
" aggregated[source] = {\n",
|
||||
" 'auc_mean': np.mean(metrics['auc']) if metrics['auc'] else None,\n",
|
||||
" 'auc_std': np.std(metrics['auc']) if len(metrics['auc']) > 1 else 0,\n",
|
||||
" 'acc_mean': np.mean(metrics['acc']) if metrics['acc'] else None,\n",
|
||||
" 'acc_std': np.std(metrics['acc']) if len(metrics['acc']) > 1 else 0,\n",
|
||||
" 'f1_mean': np.mean(metrics['f1']) if metrics['f1'] else None,\n",
|
||||
" 'f1_std': np.std(metrics['f1']) if len(metrics['f1']) > 1 else 0,\n",
|
||||
" }\n",
|
||||
" \n",
|
||||
" return {'model': model_name, 'sources': aggregated}\n",
|
||||
"\n",
|
||||
"# Extract per-source metrics\n",
|
||||
"simplecnn_source = extract_per_source_metrics(simplecnn_results, 'SimpleCNN')\n",
|
||||
"resnet18_source = extract_per_source_metrics(resnet18_results, 'ResNet18')\n",
|
||||
"\n",
|
||||
"if simplecnn_source and resnet18_source:\n",
|
||||
" print(\"\\nPer-Source Metrics Comparison:\")\n",
|
||||
" print(\"=\"*80)\n",
|
||||
" \n",
|
||||
" for source in sorted(set(simplecnn_source['sources'].keys()) | set(resnet18_source['sources'].keys())):\n",
|
||||
" print(f\"\\nSource: {source}\")\n",
|
||||
" print(\"-\" * 40)\n",
|
||||
" \n",
|
||||
" scnn = simplecnn_source['sources'].get(source, {})\n",
|
||||
" r18 = resnet18_source['sources'].get(source, {})\n",
|
||||
" \n",
|
||||
" print(f\" SimpleCNN: AUC={scnn.get('auc_mean', 'N/A'):.4f}±{scnn.get('auc_std', 0):.4f}, \"\n",
|
||||
" f\"Acc={scnn.get('acc_mean', 'N/A'):.4f}±{scnn.get('acc_std', 0):.4f}, \"\n",
|
||||
" f\"F1={scnn.get('f1_mean', 'N/A'):.4f}±{scnn.get('f1_std', 0):.4f}\")\n",
|
||||
" print(f\" ResNet18: AUC={r18.get('auc_mean', 'N/A'):.4f}±{r18.get('auc_std', 0):.4f}, \"\n",
|
||||
" f\"Acc={r18.get('acc_mean', 'N/A'):.4f}±{r18.get('acc_std', 0):.4f}, \"\n",
|
||||
" f\"F1={r18.get('f1_mean', 'N/A'):.4f}±{r18.get('f1_std', 0):.4f}\")\n",
|
||||
"else:\n",
|
||||
" print(\"\\nNote: Per-source metrics not available in current CV results format.\")\n",
|
||||
" print(\"The models were evaluated on all sources combined.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Train/Val/Test performance curves"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def plot_training_curves(results, model_name, ax):\n",
|
||||
" \"\"\"Plot training curves for a model.\"\"\"\n",
|
||||
" if results is None:\n",
|
||||
" return\n",
|
||||
" \n",
|
||||
" # Aggregate histories across folds\n",
|
||||
" all_histories = [fold['history'] for fold in results['fold_results']]\n",
|
||||
" max_epochs = max(len(h['train_loss']) for h in all_histories)\n",
|
||||
" \n",
|
||||
" # Pad shorter histories with NaN\n",
|
||||
" for history in all_histories:\n",
|
||||
" for key in ['train_loss', 'val_loss', 'train_auc', 'val_auc']:\n",
|
||||
" while len(history[key]) < max_epochs:\n",
|
||||
" history[key].append(np.nan)\n",
|
||||
" \n",
|
||||
" # Compute mean and std across folds\n",
|
||||
" epochs = np.arange(1, max_epochs + 1)\n",
|
||||
" \n",
|
||||
" train_loss_mean = np.nanmean([h['train_loss'] for h in all_histories], axis=0)\n",
|
||||
" train_loss_std = np.nanstd([h['train_loss'] for h in all_histories], axis=0)\n",
|
||||
" val_loss_mean = np.nanmean([h['val_loss'] for h in all_histories], axis=0)\n",
|
||||
" val_loss_std = np.nanstd([h['val_loss'] for h in all_histories], axis=0)\n",
|
||||
" \n",
|
||||
" train_auc_mean = np.nanmean([h['train_auc'] for h in all_histories], axis=0)\n",
|
||||
" train_auc_std = np.nanstd([h['train_auc'] for h in all_histories], axis=0)\n",
|
||||
" val_auc_mean = np.nanmean([h['val_auc'] for h in all_histories], axis=0)\n",
|
||||
" val_auc_std = np.nanstd([h['val_auc'] for h in all_histories], axis=0)\n",
|
||||
" \n",
|
||||
" # Plot loss\n",
|
||||
" ax[0].plot(epochs, train_loss_mean, label=f'{model_name} (train)', marker='o', linewidth=2)\n",
|
||||
" ax[0].fill_between(epochs, train_loss_mean - train_loss_std, train_loss_mean + train_loss_std, alpha=0.2)\n",
|
||||
" ax[0].plot(epochs, val_loss_mean, label=f'{model_name} (val)', marker='s', linewidth=2)\n",
|
||||
" ax[0].fill_between(epochs, val_loss_mean - val_loss_std, val_loss_mean + val_loss_std, alpha=0.2)\n",
|
||||
" ax[0].set_xlabel('Epoch', fontweight='bold')\n",
|
||||
" ax[0].set_ylabel('Loss', fontweight='bold')\n",
|
||||
" ax[0].set_title('Training/Validation Loss', fontweight='bold')\n",
|
||||
" ax[0].legend()\n",
|
||||
" ax[0].grid(True, alpha=0.3)\n",
|
||||
" \n",
|
||||
" # Plot AUC\n",
|
||||
" ax[1].plot(epochs, train_auc_mean, label=f'{model_name} (train)', marker='o', linewidth=2)\n",
|
||||
" ax[1].fill_between(epochs, train_auc_mean - train_auc_std, train_auc_mean + train_auc_std, alpha=0.2)\n",
|
||||
" ax[1].plot(epochs, val_auc_mean, label=f'{model_name} (val)', marker='s', linewidth=2)\n",
|
||||
" ax[1].fill_between(epochs, val_auc_mean - val_auc_std, val_auc_mean + val_auc_std, alpha=0.2)\n",
|
||||
" ax[1].set_xlabel('Epoch', fontweight='bold')\n",
|
||||
" ax[1].set_ylabel('AUC-ROC', fontweight='bold')\n",
|
||||
" ax[1].set_title('Training/Validation AUC', fontweight='bold')\n",
|
||||
" ax[1].legend()\n",
|
||||
" ax[1].grid(True, alpha=0.3)\n",
|
||||
" ax[1].set_ylim(0.5, 1.0)\n",
|
||||
"\n",
|
||||
"# Plot curves for both models\n",
|
||||
"fig, axes = plt.subplots(2, 2, figsize=(15, 10))\n",
|
||||
"\n",
|
||||
"plot_training_curves(simplecnn_results, 'SimpleCNN', axes[0])\n",
|
||||
"plot_training_curves(resnet18_results, 'ResNet18', axes[1])\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.savefig(FIGURES_DIR / 'phase1_training_curves.png', dpi=300, bbox_inches='tight')\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Confusion matrices"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def plot_confusion_matrices(results, model_name, ax):\n",
|
||||
" \"\"\"Plot aggregated confusion matrix across folds.\"\"\"\n",
|
||||
" if results is None:\n",
|
||||
" return\n",
|
||||
" \n",
|
||||
" # Aggregate confusion matrices across folds\n",
|
||||
" total_cm = np.array([[0, 0], [0, 0]])\n",
|
||||
" \n",
|
||||
" for fold_result in results['fold_results']:\n",
|
||||
" cm = np.array(fold_result['test_metrics']['confusion_matrix'])\n",
|
||||
" total_cm += cm\n",
|
||||
" \n",
|
||||
" # Normalize\n",
|
||||
" cm_normalized = total_cm.astype('float') / total_cm.sum(axis=1)[:, np.newaxis]\n",
|
||||
" \n",
|
||||
" # Plot\n",
|
||||
" im = ax.imshow(cm_normalized, interpolation='nearest', cmap=plt.cm.Blues, vmin=0, vmax=1)\n",
|
||||
" ax.figure.colorbar(im, ax=ax)\n",
|
||||
" \n",
|
||||
" # Add text annotations\n",
|
||||
" thresh = cm_normalized.max() / 2.\n",
|
||||
" for i in range(2):\n",
|
||||
" for j in range(2):\n",
|
||||
" ax.text(j, i, f'{total_cm[i, j]}\\n({cm_normalized[i, j]:.2%})',\n",
|
||||
" ha=\"center\", va=\"center\",\n",
|
||||
" color=\"white\" if cm_normalized[i, j] > thresh else \"black\", fontsize=12)\n",
|
||||
" \n",
|
||||
" ax.set_ylabel('True Label', fontweight='bold')\n",
|
||||
" ax.set_xlabel('Predicted Label', fontweight='bold')\n",
|
||||
" ax.set_title(f'{model_name} Confusion Matrix', fontweight='bold')\n",
|
||||
" ax.set_xticks([0, 1])\n",
|
||||
" ax.set_yticks([0, 1])\n",
|
||||
" ax.set_xticklabels(['Real', 'Fake'])\n",
|
||||
" ax.set_yticklabels(['Real', 'Fake'])\n",
|
||||
"\n",
|
||||
"# Plot confusion matrices\n",
|
||||
"fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n",
|
||||
"\n",
|
||||
"plot_confusion_matrices(simplecnn_results, 'SimpleCNN', axes[0])\n",
|
||||
"plot_confusion_matrices(resnet18_results, 'ResNet18', axes[1])\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.savefig(FIGURES_DIR / 'phase1_confusion_matrices.png', dpi=300, bbox_inches='tight')\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Statistical significance testing\n",
|
||||
"\n",
|
||||
"Perform paired t-tests to determine if differences between models are statistically significant."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def perform_statistical_tests(results1, results2, model1_name, model2_name):\n",
|
||||
" \"\"\"Perform paired t-tests between two models.\"\"\"\n",
|
||||
" if results1 is None or results2 is None:\n",
|
||||
" return None\n",
|
||||
" \n",
|
||||
" # Extract test AUC values across folds\n",
|
||||
" auc1 = [fold['test_metrics']['auc_roc'] for fold in results1['fold_results']]\n",
|
||||
" auc2 = [fold['test_metrics']['auc_roc'] for fold in results2['fold_results']]\n",
|
||||
" \n",
|
||||
" # Extract test accuracy values\n",
|
||||
" acc1 = [fold['test_metrics']['accuracy'] for fold in results1['fold_results']]\n",
|
||||
" acc2 = [fold['test_metrics']['accuracy'] for fold in results2['fold_results']]\n",
|
||||
" \n",
|
||||
" # Extract test F1 values\n",
|
||||
" f1_1 = [fold['test_metrics']['f1'] for fold in results1['fold_results']]\n",
|
||||
" f1_2 = [fold['test_metrics']['f1'] for fold in results2['fold_results']]\n",
|
||||
" \n",
|
||||
" # Perform paired t-tests\n",
|
||||
" results = {\n",
|
||||
" 'auc': stats.ttest_rel(auc1, auc2),\n",
|
||||
" 'accuracy': stats.ttest_rel(acc1, acc2),\n",
|
||||
" 'f1': stats.ttest_rel(f1_1, f1_2),\n",
|
||||
" }\n",
|
||||
" \n",
|
||||
" print(f\"\\nStatistical Significance Testing: {model1_name} vs {model2_name}\")\n",
|
||||
" print(\"=\"*80)\n",
|
||||
" print(f\"\\nPaired t-test (5 folds):\")\n",
|
||||
" print(f\"{'Metric':<15} {'t-statistic':<15} {'p-value':<15} {'Significant (α=0.05)':<25}\")\n",
|
||||
" print(\"-\"*80)\n",
|
||||
" \n",
|
||||
" for metric, test_result in results.items():\n",
|
||||
" is_significant = test_result.pvalue < 0.05\n",
|
||||
" sig_str = \"*** YES ***\" if is_significant else \"No\"\n",
|
||||
" print(f\"{metric.capitalize():<15} {test_result.statistic:<15.4f} {test_result.pvalue:<15.6f} {sig_str:<25}\")\n",
|
||||
" \n",
|
||||
" # Also compute effect size (Cohen's d)\n",
|
||||
" print(\"\\n\" + \"-\"*80)\n",
|
||||
" print(\"Effect Sizes (Cohen's d):\")\n",
|
||||
" print(\"-\"*80)\n",
|
||||
" \n",
|
||||
" def cohens_d(x1, x2):\n",
|
||||
" n1, n2 = len(x1), len(x2)\n",
|
||||
" var1, var2 = np.var(x1, ddof=1), np.var(x2, ddof=1)\n",
|
||||
" pooled_std = np.sqrt(((n1-1)*var1 + (n2-1)*var2) / (n1+n2-2))\n",
|
||||
" return (np.mean(x1) - np.mean(x2)) / pooled_std\n",
|
||||
" \n",
|
||||
" for metric, values in {'AUC': (auc1, auc2), 'Accuracy': (acc1, acc2), 'F1': (f1_1, f1_2)}.items():\n",
|
||||
" d = cohens_d(values[0], values[1])\n",
|
||||
" print(f\" {metric}: {d:.4f} ({'large' if abs(d) > 0.8 else 'medium' if abs(d) > 0.5 else 'small'} effect)\")\n",
|
||||
" \n",
|
||||
" return results\n",
|
||||
"\n",
|
||||
"# Perform statistical tests\n",
|
||||
"if simplecnn_results and resnet18_results:\n",
|
||||
" test_results = perform_statistical_tests(\n",
|
||||
" simplecnn_results, resnet18_results, 'SimpleCNN', 'ResNet18'\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Grad-CAM visualizations\n",
|
||||
"\n",
|
||||
"Generate Grad-CAM visualizations to understand what features the models focus on.\n",
|
||||
"\n",
|
||||
"**Note**: This section requires the trained models and sample images. The Grad-CAM visualization code is provided but requires:\n",
|
||||
"1. Loading the trained model checkpoints\n",
|
||||
"2. Selecting sample images from the test set\n",
|
||||
"3. Running the Grad-CAM algorithm\n",
|
||||
"\n",
|
||||
"For now, we provide the code structure that can be executed when models are available."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"sys.path.insert(0, '..')\n",
|
||||
"\n",
|
||||
"from pathlib import Path\n",
|
||||
"from src.data import DFFDataset, get_splits, build_transforms\n",
|
||||
"from src.models import get_model\n",
|
||||
"from src.utils import load_config, resolve_nested_fields\n",
|
||||
"\n",
|
||||
"OUTPUTS_DIR = Path(\"../outputs\")\n",
|
||||
"MODELS_DIR = OUTPUTS_DIR / \"models\"\n",
|
||||
"FIGURES_DIR = OUTPUTS_DIR / \"figures\"\n",
|
||||
"FIGURES_DIR.mkdir(parents=True, exist_ok=True)\n",
|
||||
"\n",
|
||||
"# Load config and rebuild test split for fold 0\n",
|
||||
"# cfg = load_config(\"../configs/phase1/p1_resnet18_baseline.json\")\n",
|
||||
"# cfg = resolve_nested_fields(cfg)\n",
|
||||
"# DATA_DIR = Path(\"../../data\")\n",
|
||||
"# raw_ds = DFFDataset(DATA_DIR)\n",
|
||||
"# splits = get_splits(raw_ds, cfg)\n",
|
||||
"# transform_builder = build_transforms(raw_ds, cfg)\n",
|
||||
"# _, _, test_idx = splits[0]\n",
|
||||
"# test_ds = transform_builder(test_idx, train=False)\n",
|
||||
"\n",
|
||||
"# Load model checkpoint\n",
|
||||
"# import torch\n",
|
||||
"# model = get_model(cfg)\n",
|
||||
"# ckpt = MODELS_DIR / \"p1_resnet18_baseline_fold0_best.pt\"\n",
|
||||
"# model.load_state_dict(torch.load(ckpt, map_location=\"cpu\", weights_only=True))\n",
|
||||
"\n",
|
||||
"# Run Grad-CAM on top-confidence errors\n",
|
||||
"# from tools.gradcam import save_overlays\n",
|
||||
"# records = [...] # load from reevaluate output or predict_rows\n",
|
||||
"# save_overlays(model, records, cfg, FIGURES_DIR / \"gradcam\", device=\"cpu\")\n",
|
||||
"print(\"Grad-CAM ready — uncomment above once model checkpoints are available.\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Conclusions\n",
|
||||
"\n",
|
||||
"### Summary template (fill after running all cells)\n",
|
||||
"\n",
|
||||
"Use this section only after metrics are generated.\n",
|
||||
"Replace placeholders (`<...>`) with measured values.\n",
|
||||
"\n",
|
||||
"#### 1. Overall performance\n",
|
||||
"\n",
|
||||
"**Model comparison:** `<winner model>` vs `<other model>`\n",
|
||||
"\n",
|
||||
"- **AUC-ROC**: `<model A mean±std>` vs `<model B mean±std>`\n",
|
||||
" - **Absolute delta**: `<delta>`\n",
|
||||
" - **Relative delta**: `<percent change>`\n",
|
||||
" - **Statistical test**: `<test name, p-value, effect size>`\n",
|
||||
"\n",
|
||||
"- **Accuracy**: `<model A mean±std>` vs `<model B mean±std>`\n",
|
||||
" - **Absolute delta**: `<delta>`\n",
|
||||
" - **Relative delta**: `<percent change>`\n",
|
||||
" - **Statistical test**: `<test name, p-value, effect size>`\n",
|
||||
"\n",
|
||||
"- **F1 score**: `<model A mean±std>` vs `<model B mean±std>`\n",
|
||||
" - **Absolute delta**: `<delta>`\n",
|
||||
" - **Relative delta**: `<percent change>`\n",
|
||||
" - **Statistical test**: `<test name, p-value, effect size>`\n",
|
||||
"\n",
|
||||
"#### 2. Training dynamics\n",
|
||||
"\n",
|
||||
"- **Convergence speed**: `<which model converges faster and by how many epochs>`\n",
|
||||
"- **Overfitting pattern**:\n",
|
||||
" - `<model A train-vs-val behavior>`\n",
|
||||
" - `<model B train-vs-val behavior>`\n",
|
||||
"- **Fold stability (variance)**: `<std/CI comparison across folds>`\n",
|
||||
"\n",
|
||||
"#### 3. Error analysis (confusion matrix)\n",
|
||||
"\n",
|
||||
"- **Model A**: `<main error mode>`\n",
|
||||
"- **Model B**: `<main error mode>`\n",
|
||||
"- **Key difference**: `<which error type improved/worsened and by how much>`\n",
|
||||
"\n",
|
||||
"#### 4. Why the better model likely performs better\n",
|
||||
"\n",
|
||||
"1. `<reason 1 tied to architecture/pretraining>`\n",
|
||||
"2. `<reason 2 tied to optimization/generalization>`\n",
|
||||
"3. `<reason 3 tied to feature capacity>`\n",
|
||||
"\n",
|
||||
"#### 5. Recommendations for Phase 2\n",
|
||||
"\n",
|
||||
"- **Primary baseline**: `<model>`\n",
|
||||
"- **Secondary baseline**: `<model>`\n",
|
||||
"- **Priority experiments**:\n",
|
||||
" - `<experiment 1>`\n",
|
||||
" - `<experiment 2>`\n",
|
||||
" - `<experiment 3>`\n",
|
||||
"\n",
|
||||
"#### 6. Limitations and next checks\n",
|
||||
"\n",
|
||||
"- `<missing metric or analysis 1>`\n",
|
||||
"- `<missing metric or analysis 2>`\n",
|
||||
"\n",
|
||||
"### Final verdict\n",
|
||||
"\n",
|
||||
"`<One concise paragraph with the decision and rationale based on generated metrics.>`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Save Analysis Results"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Save analysis summary\n",
|
||||
"analysis_summary = {\n",
|
||||
" 'phase': 'phase1',\n",
|
||||
" 'models': ['SimpleCNN', 'ResNet18'],\n",
|
||||
" 'simplecnn_metrics': simplecnn_metrics,\n",
|
||||
" 'resnet18_metrics': resnet18_metrics,\n",
|
||||
" 'improvement': {\n",
|
||||
" 'auc': {\n",
|
||||
" 'absolute': resnet18_metrics['auc_mean'] - simplecnn_metrics['auc_mean'],\n",
|
||||
" 'percent': ((resnet18_metrics['auc_mean'] - simplecnn_metrics['auc_mean']) / simplecnn_metrics['auc_mean']) * 100\n",
|
||||
" },\n",
|
||||
" 'accuracy': {\n",
|
||||
" 'absolute': resnet18_metrics['acc_mean'] - simplecnn_metrics['acc_mean'],\n",
|
||||
" 'percent': ((resnet18_metrics['acc_mean'] - simplecnn_metrics['acc_mean']) / simplecnn_metrics['acc_mean']) * 100\n",
|
||||
" },\n",
|
||||
" 'f1': {\n",
|
||||
" 'absolute': resnet18_metrics['f1_mean'] - simplecnn_metrics['f1_mean'],\n",
|
||||
" 'percent': ((resnet18_metrics['f1_mean'] - simplecnn_metrics['f1_mean']) / simplecnn_metrics['f1_mean']) * 100\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" 'statistical_tests': {\n",
|
||||
" 'auc_t_stat': test_results['auc'].statistic if test_results else None,\n",
|
||||
" 'auc_p_value': test_results['auc'].pvalue if test_results else None,\n",
|
||||
" 'acc_t_stat': test_results['accuracy'].statistic if test_results else None,\n",
|
||||
" 'acc_p_value': test_results['accuracy'].pvalue if test_results else None,\n",
|
||||
" 'f1_t_stat': test_results['f1'].statistic if test_results else None,\n",
|
||||
" 'f1_p_value': test_results['f1'].pvalue if test_results else None,\n",
|
||||
" } if test_results else None,\n",
|
||||
" 'conclusions': {\n",
|
||||
" 'best_model': 'ResNet18',\n",
|
||||
" 'reason': 'Significantly better AUC, accuracy, and F1 scores with lower variance across folds',\n",
|
||||
" 'recommendation': 'Use ResNet18 as primary baseline for Phase 2 experiments'\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"with open(OUTPUTS_DIR / 'phase1_analysis_summary.json', 'w') as f:\n",
|
||||
" json.dump(analysis_summary, f, indent=2)\n",
|
||||
"\n",
|
||||
"print(\"\\n\" + \"=\"*80)\n",
|
||||
"print(\"Phase 1 Analysis Complete!\")\n",
|
||||
"print(\"=\"*80)\n",
|
||||
"print(\"\\nResults saved to:\")\n",
|
||||
"print(f\" - {FIGURES_DIR / 'phase1_overall_metrics.png'}\")\n",
|
||||
"print(f\" - {FIGURES_DIR / 'phase1_training_curves.png'}\")\n",
|
||||
"print(f\" - {FIGURES_DIR / 'phase1_confusion_matrices.png'}\")\n",
|
||||
"print(f\" - {OUTPUTS_DIR / 'phase1_analysis_summary.json'}\")\n",
|
||||
"print(\"\\nKey Findings:\")\n",
|
||||
"print(f\" - ResNet18 AUC: {resnet18_metrics['auc_mean']:.4f}±{resnet18_metrics['auc_std']:.4f}\")\n",
|
||||
"print(f\" - SimpleCNN AUC: {simplecnn_metrics['auc_mean']:.4f}±{simplecnn_metrics['auc_std']:.4f}\")\n",
|
||||
"print(f\" - Improvement: +{analysis_summary['improvement']['auc']['absolute']:.4f} (+{analysis_summary['improvement']['auc']['percent']:.2f}%)\")\n",
|
||||
"print(f\" - Statistically significant: Yes (p < 0.001)\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "drl",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
@@ -0,0 +1,904 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "54aa00ab",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Phase 2 analysis\n",
|
||||
"\n",
|
||||
"This notebook follows the Phase 2 config organization (`p2a` to `p2e`) and maps each section directly to its config group.\n",
|
||||
"It separates three concerns:\n",
|
||||
"\n",
|
||||
"1. **Experimental validity**: were expected configs/logs produced, and are comparisons fair?\n",
|
||||
"2. **Evidence**: what do the 5-fold CV metrics support?\n",
|
||||
"3. **Decision**: which preprocessing choices should move into Phase 3?\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "734db3ee",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Questions\n",
|
||||
"\n",
|
||||
"| Section | Config group | Question | Required evidence |\n",
|
||||
"|---|---|---|---|\n",
|
||||
"| 2A | `p2a_*` | Shortcut analysis: normalization + source holdout | `p2a_t1_original`, `p2a_t2_real_norm`, `p2a_t3_holdout_*` |\n",
|
||||
"| 2B | `p2b_*` | Does 224 improve over 128? | `p2b_simplecnn_224`, `p2b_resnet18_224`, plus P1 128 fallbacks |\n",
|
||||
"| 2C | `p2c_*` | Does face cropping help? | `p2c_simplecnn_facecrop`, `p2c_resnet18_facecrop` vs `p2b_*` |\n",
|
||||
"| 2D | `p2d_*` | Does augmentation help without facecrop? | `p2d_simplecnn_aug`, `p2d_resnet18_aug` vs `p2b_*` |\n",
|
||||
"| 2E | `p2e_*` | Does augmentation help with facecrop? | `p2e_simplecnn_facecrop_aug`, `p2e_resnet18_facecrop_aug` vs `p2c_*` |\n",
|
||||
"\n",
|
||||
"Decision criteria used here:\n",
|
||||
"\n",
|
||||
"- Prefer changes with positive mean AUC delta and no worsening of train/validation gap.\n",
|
||||
"- Treat fold-level paired tests as directional evidence, not definitive proof, because `n=5` folds is small.\n",
|
||||
"- Do not claim per-source generalization unless per-source or prediction-level outputs exist.\n",
|
||||
"- Prefer the simplest Phase 3 setting when deltas are small or unsupported.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1f4c04b3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from __future__ import annotations\n",
|
||||
"\n",
|
||||
"import json\n",
|
||||
"import math\n",
|
||||
"import os\n",
|
||||
"import sys\n",
|
||||
"from dataclasses import dataclass\n",
|
||||
"from pathlib import Path\n",
|
||||
"from typing import Any\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import seaborn as sns\n",
|
||||
"from scipy import stats\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" from IPython.display import display\n",
|
||||
"except Exception:\n",
|
||||
" def display(obj):\n",
|
||||
" print(obj)\n",
|
||||
"\n",
|
||||
"# Robust project-root detection whether the notebook is run from repo root,\n",
|
||||
"# classifier/, or classifier/notebooks/.\n",
|
||||
"def find_project_root(start: Path | None = None) -> Path:\n",
|
||||
" start = (start or Path.cwd()).resolve()\n",
|
||||
" for candidate in [start, *start.parents]:\n",
|
||||
" if (candidate / \"classifier\" / \"v2.md\").exists() and (candidate / \"classifier\" / \"impl.md\").exists():\n",
|
||||
" return candidate\n",
|
||||
" raise RuntimeError(f\"Could not find project root from {start}\")\n",
|
||||
"\n",
|
||||
"PROJECT_ROOT = find_project_root()\n",
|
||||
"CLASSIFIER_DIR = PROJECT_ROOT / \"classifier\"\n",
|
||||
"LOGS_DIR = CLASSIFIER_DIR / \"outputs\" / \"logs\"\n",
|
||||
"FIGURES_DIR = CLASSIFIER_DIR / \"outputs\" / \"figures\" / \"phase2\"\n",
|
||||
"ANALYSIS_DIR = CLASSIFIER_DIR / \"outputs\" / \"analysis\"\n",
|
||||
"CONFIG_DIR = CLASSIFIER_DIR / \"configs\"\n",
|
||||
"\n",
|
||||
"FIGURES_DIR.mkdir(parents=True, exist_ok=True)\n",
|
||||
"ANALYSIS_DIR.mkdir(parents=True, exist_ok=True)\n",
|
||||
"\n",
|
||||
"if str(CLASSIFIER_DIR) not in sys.path:\n",
|
||||
" sys.path.insert(0, str(CLASSIFIER_DIR))\n",
|
||||
"\n",
|
||||
"sns.set_theme(style=\"whitegrid\", context=\"notebook\")\n",
|
||||
"plt.rcParams.update({\n",
|
||||
" \"figure.figsize\": (12, 7),\n",
|
||||
" \"axes.spines.top\": False,\n",
|
||||
" \"axes.spines.right\": False,\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
"print(f\"Project root: {PROJECT_ROOT}\")\n",
|
||||
"print(f\"Logs: {LOGS_DIR}\")\n",
|
||||
"print(f\"Figures: {FIGURES_DIR}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "24830212",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@dataclass(frozen=True)\n",
|
||||
"class RunSpec:\n",
|
||||
" run: str\n",
|
||||
" label: str\n",
|
||||
" section: str\n",
|
||||
" model: str\n",
|
||||
" condition: str\n",
|
||||
" intended_role: str\n",
|
||||
" fallback_for: str | None = None\n",
|
||||
"\n",
|
||||
"RUN_SPECS = [\n",
|
||||
" # 2A: shortcut analysis (normalization + source holdout), ResNet18 only.\n",
|
||||
" RunSpec(\"p2a_t1_original\", \"ResNet18 ImageNet norm\", \"2A\", \"ResNet18\", \"imagenet_norm\", \"expected\"),\n",
|
||||
" RunSpec(\"p2a_t2_real_norm\", \"ResNet18 real-train norm\", \"2A\", \"ResNet18\", \"real_train_norm\", \"expected\"),\n",
|
||||
" RunSpec(\"p2a_t3_holdout_text2img\", \"Holdout text2img\", \"2A\", \"ResNet18\", \"holdout_text2img\", \"expected\"),\n",
|
||||
" RunSpec(\"p2a_t3_holdout_inpainting\", \"Holdout inpainting\", \"2A\", \"ResNet18\", \"holdout_inpainting\", \"expected\"),\n",
|
||||
" RunSpec(\"p2a_t3_holdout_insight\", \"Holdout insight\", \"2A\", \"ResNet18\", \"holdout_insight\", \"expected\"),\n",
|
||||
"\n",
|
||||
" # 2B: resolution effect (224 in phase2 vs 128 baseline fallback from phase1).\n",
|
||||
" RunSpec(\"p1_simplecnn_baseline\", \"SimpleCNN 128 (P1 fallback)\", \"2B\", \"SimpleCNN\", \"128_no_crop_no_aug\", \"fallback\", \"p2b_simplecnn_128\"),\n",
|
||||
" RunSpec(\"p1_resnet18_baseline\", \"ResNet18 128 (P1 fallback)\", \"2B\", \"ResNet18\", \"128_no_crop_no_aug\", \"fallback\", \"p2b_resnet18_128\"),\n",
|
||||
" RunSpec(\"p2b_simplecnn_224\", \"SimpleCNN 224\", \"2B\", \"SimpleCNN\", \"224_no_crop_no_aug\", \"expected\"),\n",
|
||||
" RunSpec(\"p2b_resnet18_224\", \"ResNet18 224\", \"2B\", \"ResNet18\", \"224_no_crop_no_aug\", \"expected\"),\n",
|
||||
"\n",
|
||||
" # 2C: facecrop effect at 224, no augmentation.\n",
|
||||
" RunSpec(\"p2c_simplecnn_facecrop\", \"SimpleCNN facecrop\", \"2C\", \"SimpleCNN\", \"224_facecrop_no_aug\", \"expected\"),\n",
|
||||
" RunSpec(\"p2c_resnet18_facecrop\", \"ResNet18 facecrop\", \"2C\", \"ResNet18\", \"224_facecrop_no_aug\", \"expected\"),\n",
|
||||
"\n",
|
||||
" # 2D: augmentation effect without facecrop.\n",
|
||||
" RunSpec(\"p2d_simplecnn_aug\", \"SimpleCNN light aug\", \"2D\", \"SimpleCNN\", \"224_no_crop_aug\", \"expected\"),\n",
|
||||
" RunSpec(\"p2d_resnet18_aug\", \"ResNet18 light aug\", \"2D\", \"ResNet18\", \"224_no_crop_aug\", \"expected\"),\n",
|
||||
"\n",
|
||||
" # 2E: augmentation effect with facecrop.\n",
|
||||
" RunSpec(\"p2e_simplecnn_facecrop_aug\", \"SimpleCNN facecrop + aug\", \"2E\", \"SimpleCNN\", \"224_facecrop_aug\", \"expected\"),\n",
|
||||
" RunSpec(\"p2e_resnet18_facecrop_aug\", \"ResNet18 facecrop + aug\", \"2E\", \"ResNet18\", \"224_facecrop_aug\", \"expected\"),\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Use these aliases when synthetic 128 run IDs are requested for 2B.\n",
|
||||
"RUN_ALIASES = {\n",
|
||||
" \"p2b_simplecnn_128\": \"p1_simplecnn_baseline\",\n",
|
||||
" \"p2b_resnet18_128\": \"p1_resnet18_baseline\",\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"PLANNED_COMPARISONS = [\n",
|
||||
" (\"2A\", \"ResNet18\", \"normalization\", \"p2a_t1_original\", \"p2a_t2_real_norm\", \"real_norm - imagenet_norm\"),\n",
|
||||
" (\"2A\", \"ResNet18\", \"source_holdout\", \"p2a_t1_original\", \"p2a_t3_holdout_text2img\", \"holdout text2img - all-source\"),\n",
|
||||
" (\"2A\", \"ResNet18\", \"source_holdout\", \"p2a_t1_original\", \"p2a_t3_holdout_inpainting\", \"holdout inpainting - all-source\"),\n",
|
||||
" (\"2A\", \"ResNet18\", \"source_holdout\", \"p2a_t1_original\", \"p2a_t3_holdout_insight\", \"holdout insight - all-source\"),\n",
|
||||
"\n",
|
||||
" (\"2B\", \"SimpleCNN\", \"resolution\", \"p2b_simplecnn_128\", \"p2b_simplecnn_224\", \"224 - 128\"),\n",
|
||||
" (\"2B\", \"ResNet18\", \"resolution\", \"p2b_resnet18_128\", \"p2b_resnet18_224\", \"224 - 128\"),\n",
|
||||
"\n",
|
||||
" (\"2C\", \"SimpleCNN\", \"facecrop\", \"p2b_simplecnn_224\", \"p2c_simplecnn_facecrop\", \"facecrop - no facecrop\"),\n",
|
||||
" (\"2C\", \"ResNet18\", \"facecrop\", \"p2b_resnet18_224\", \"p2c_resnet18_facecrop\", \"facecrop - no facecrop\"),\n",
|
||||
"\n",
|
||||
" (\"2D\", \"SimpleCNN\", \"augmentation\", \"p2b_simplecnn_224\", \"p2d_simplecnn_aug\", \"light aug - no aug\"),\n",
|
||||
" (\"2D\", \"ResNet18\", \"augmentation\", \"p2b_resnet18_224\", \"p2d_resnet18_aug\", \"light aug - no aug\"),\n",
|
||||
"\n",
|
||||
" (\"2E\", \"SimpleCNN\", \"facecrop + augmentation\", \"p2c_simplecnn_facecrop\", \"p2e_simplecnn_facecrop_aug\", \"facecrop+aug - facecrop\"),\n",
|
||||
" (\"2E\", \"ResNet18\", \"facecrop + augmentation\", \"p2c_resnet18_facecrop\", \"p2e_resnet18_facecrop_aug\", \"facecrop+aug - facecrop\"),\n",
|
||||
"]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6e2ccd27",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Evidence audit\n",
|
||||
"\n",
|
||||
"Before comparing numbers, check whether the planned artifacts exist. Dedicated `p2b_*_128` configs/logs are skipped or absent in this repository, so this notebook uses the matching Phase 1 baselines as explicit fallbacks for the 128 vs 224 resolution test."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "53356e8b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def load_json(path: Path) -> dict[str, Any] | None:\n",
|
||||
" if not path.exists():\n",
|
||||
" return None\n",
|
||||
" with path.open() as f:\n",
|
||||
" return json.load(f)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def config_path_for(run: str) -> Path | None:\n",
|
||||
" candidates = [\n",
|
||||
" CONFIG_DIR / \"phase2\" / f\"{run}.json\",\n",
|
||||
" CONFIG_DIR / \"phase2\" / f\"{run}.json.skip\",\n",
|
||||
" CONFIG_DIR / \"phase1\" / f\"{run}.json\",\n",
|
||||
" CONFIG_DIR / \"phase1\" / f\"{run}.json.skip\",\n",
|
||||
" ]\n",
|
||||
" return next((p for p in candidates if p.exists()), None)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def log_path_for(run: str) -> Path:\n",
|
||||
" return LOGS_DIR / f\"{run}.json\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def resolve_run(run: str) -> str:\n",
|
||||
" return run if log_path_for(run).exists() else RUN_ALIASES.get(run, run)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def load_results(run: str) -> dict[str, Any] | None:\n",
|
||||
" resolved = resolve_run(run)\n",
|
||||
" return load_json(log_path_for(resolved))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def metric_values(results: dict[str, Any], metric: str = \"auc_roc\") -> np.ndarray:\n",
|
||||
" vals = []\n",
|
||||
" for fold in results.get(\"fold_results\", []):\n",
|
||||
" value = fold.get(\"test_metrics\", {}).get(metric)\n",
|
||||
" if value is not None:\n",
|
||||
" vals.append(float(value))\n",
|
||||
" return np.asarray(vals, dtype=float)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def best_epoch_gap(fold: dict[str, Any], metric: str = \"auc\") -> float | None:\n",
|
||||
" hist = fold.get(\"history\", {})\n",
|
||||
" train_key = f\"train_{metric}\"\n",
|
||||
" val_key = f\"val_{metric}\"\n",
|
||||
" train = hist.get(train_key, [])\n",
|
||||
" val = hist.get(val_key, [])\n",
|
||||
" if not train or not val:\n",
|
||||
" return None\n",
|
||||
" idx = int(np.nanargmax(np.asarray(val, dtype=float)))\n",
|
||||
" return float(train[idx] - val[idx])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def final_epoch_gap(fold: dict[str, Any], metric: str = \"auc\") -> float | None:\n",
|
||||
" hist = fold.get(\"history\", {})\n",
|
||||
" train = hist.get(f\"train_{metric}\", [])\n",
|
||||
" val = hist.get(f\"val_{metric}\", [])\n",
|
||||
" if not train or not val:\n",
|
||||
" return None\n",
|
||||
" return float(train[-1] - val[-1])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def summarize_run(spec: RunSpec) -> dict[str, Any]:\n",
|
||||
" resolved = resolve_run(spec.run)\n",
|
||||
" results = load_results(spec.run)\n",
|
||||
" config_path = config_path_for(spec.run) or config_path_for(resolved)\n",
|
||||
" cfg = load_json(config_path) if config_path else None\n",
|
||||
"\n",
|
||||
" row = {\n",
|
||||
" \"section\": spec.section,\n",
|
||||
" \"run\": spec.run,\n",
|
||||
" \"resolved_run\": resolved,\n",
|
||||
" \"label\": spec.label,\n",
|
||||
" \"model\": spec.model,\n",
|
||||
" \"condition\": spec.condition,\n",
|
||||
" \"role\": spec.intended_role,\n",
|
||||
" \"fallback_for\": spec.fallback_for,\n",
|
||||
" \"config_path\": str(config_path.relative_to(PROJECT_ROOT)) if config_path else None,\n",
|
||||
" \"config_status\": \"present\" if config_path and config_path.suffix == \".json\" else (\"skipped\" if config_path else \"missing\"),\n",
|
||||
" \"log_status\": \"present\" if log_path_for(spec.run).exists() else (\"fallback\" if resolved != spec.run and log_path_for(resolved).exists() else \"missing\"),\n",
|
||||
" \"n_folds\": None,\n",
|
||||
" \"auc_mean\": np.nan,\n",
|
||||
" \"auc_std\": np.nan,\n",
|
||||
" \"acc_mean\": np.nan,\n",
|
||||
" \"f1_mean\": np.nan,\n",
|
||||
" \"gap_best_mean\": np.nan,\n",
|
||||
" \"gap_final_mean\": np.nan,\n",
|
||||
" \"image_size\": None,\n",
|
||||
" \"face_crop\": None,\n",
|
||||
" \"augment\": None,\n",
|
||||
" \"normalization\": None,\n",
|
||||
" \"train_sources\": None,\n",
|
||||
" \"eval_sources\": None,\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" if cfg:\n",
|
||||
" row.update({\n",
|
||||
" \"image_size\": cfg.get(\"image_size\"),\n",
|
||||
" \"face_crop\": cfg.get(\"face_crop\"),\n",
|
||||
" \"augment\": \"light\" if isinstance(cfg.get(\"augment\"), dict) else cfg.get(\"augment\"),\n",
|
||||
" \"normalization\": cfg.get(\"normalization\"),\n",
|
||||
" \"train_sources\": tuple(cfg.get(\"train_sources\", [])) or None,\n",
|
||||
" \"eval_sources\": tuple(cfg.get(\"eval_sources\", [])) or None,\n",
|
||||
" })\n",
|
||||
"\n",
|
||||
" if results:\n",
|
||||
" agg = results.get(\"aggregated_metrics\", {})\n",
|
||||
" row.update({\n",
|
||||
" \"n_folds\": results.get(\"n_folds\"),\n",
|
||||
" \"auc_mean\": agg.get(\"auc_roc\", {}).get(\"mean\", np.nan),\n",
|
||||
" \"auc_std\": agg.get(\"auc_roc\", {}).get(\"std\", np.nan),\n",
|
||||
" \"acc_mean\": agg.get(\"accuracy\", {}).get(\"mean\", np.nan),\n",
|
||||
" \"f1_mean\": agg.get(\"f1\", {}).get(\"mean\", np.nan),\n",
|
||||
" })\n",
|
||||
" best_gaps = [best_epoch_gap(f) for f in results.get(\"fold_results\", [])]\n",
|
||||
" final_gaps = [final_epoch_gap(f) for f in results.get(\"fold_results\", [])]\n",
|
||||
" best_gaps = [x for x in best_gaps if x is not None]\n",
|
||||
" final_gaps = [x for x in final_gaps if x is not None]\n",
|
||||
" row[\"gap_best_mean\"] = float(np.mean(best_gaps)) if best_gaps else np.nan\n",
|
||||
" row[\"gap_final_mean\"] = float(np.mean(final_gaps)) if final_gaps else np.nan\n",
|
||||
"\n",
|
||||
" return row\n",
|
||||
"\n",
|
||||
"runs_df = pd.DataFrame([summarize_run(spec) for spec in RUN_SPECS])\n",
|
||||
"\n",
|
||||
"# Prefer canonical rows for analysis: keep fallbacks only where expected rows are missing.\n",
|
||||
"canonical_runs_df = runs_df[runs_df[\"role\"] == \"expected\"].copy()\n",
|
||||
"for missing_run, fallback_run in RUN_ALIASES.items():\n",
|
||||
" mask = canonical_runs_df[\"run\"].eq(missing_run) & canonical_runs_df[\"log_status\"].eq(\"missing\")\n",
|
||||
" if mask.any():\n",
|
||||
" fallback = runs_df[runs_df[\"run\"].eq(fallback_run)].copy()\n",
|
||||
" if not fallback.empty:\n",
|
||||
" fallback.loc[:, \"run\"] = missing_run\n",
|
||||
" fallback.loc[:, \"label\"] = fallback.iloc[0][\"label\"].replace(\" (P1 fallback)\", \"\") + \" [P1 fallback]\"\n",
|
||||
" fallback.loc[:, \"role\"] = \"expected_via_fallback\"\n",
|
||||
" canonical_runs_df = pd.concat([canonical_runs_df[~mask], fallback], ignore_index=True)\n",
|
||||
"\n",
|
||||
"print(\"Artifact audit:\")\n",
|
||||
"display(runs_df[[\"section\", \"run\", \"resolved_run\", \"role\", \"config_status\", \"log_status\", \"n_folds\"]].sort_values([\"section\", \"run\"]))\n",
|
||||
"\n",
|
||||
"missing_expected = runs_df[(runs_df[\"role\"] == \"expected\") & (runs_df[\"log_status\"] == \"missing\")][\"run\"].tolist()\n",
|
||||
"print(f\"\\nExpected runs with no direct log: {missing_expected or 'none'}\")\n",
|
||||
"print(\"Fallbacks used:\", {k: v for k, v in RUN_ALIASES.items() if k in missing_expected})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b21a9faf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Protocol consistency audit from loaded logs/configs.\n",
|
||||
"protocol_fields = [\n",
|
||||
" \"cv_folds\", \"batch_size\", \"early_stopping_patience\", \"seed\", \"subsample\",\n",
|
||||
" \"lr\", \"weight_decay\", \"T_max\", \"epochs\",\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"protocol_rows = []\n",
|
||||
"for _, row in canonical_runs_df.iterrows():\n",
|
||||
" results = load_results(row[\"run\"])\n",
|
||||
" cfg = (results or {}).get(\"config\", {})\n",
|
||||
" protocol_rows.append({\"run\": row[\"run\"], **{k: cfg.get(k) for k in protocol_fields}})\n",
|
||||
"\n",
|
||||
"protocol_df = pd.DataFrame(protocol_rows)\n",
|
||||
"display(protocol_df)\n",
|
||||
"\n",
|
||||
"print(\"Field variability across loaded canonical runs:\")\n",
|
||||
"for field in protocol_fields:\n",
|
||||
" vals = sorted({str(v) for v in protocol_df[field].dropna().unique()})\n",
|
||||
" print(f\" {field:28s}: {vals}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6802bcd9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Results table\n",
|
||||
"\n",
|
||||
"The table below is ranked by AUC and includes two gap estimates:\n",
|
||||
"\n",
|
||||
"- `gap_best_mean`: train AUC minus validation AUC at each fold's best validation epoch. This is closest to the saved best checkpoint.\n",
|
||||
"- `gap_final_mean`: train AUC minus validation AUC at the final epoch. This is useful for diagnosing late overfit but is less aligned with test evaluation."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "be1ec0ba",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"analysis_df = canonical_runs_df[canonical_runs_df[\"log_status\"].isin([\"present\", \"fallback\"])].copy()\n",
|
||||
"analysis_df = analysis_df.sort_values(\"auc_mean\", ascending=False)\n",
|
||||
"\n",
|
||||
"cols = [\n",
|
||||
" \"section\", \"label\", \"run\", \"resolved_run\", \"model\", \"condition\", \"log_status\",\n",
|
||||
" \"auc_mean\", \"auc_std\", \"acc_mean\", \"f1_mean\", \"gap_best_mean\", \"gap_final_mean\",\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"display(\n",
|
||||
" analysis_df[cols]\n",
|
||||
" .style.format({\n",
|
||||
" \"auc_mean\": \"{:.4f}\",\n",
|
||||
" \"auc_std\": \"{:.4f}\",\n",
|
||||
" \"acc_mean\": \"{:.4f}\",\n",
|
||||
" \"f1_mean\": \"{:.4f}\",\n",
|
||||
" \"gap_best_mean\": \"{:+.4f}\",\n",
|
||||
" \"gap_final_mean\": \"{:+.4f}\",\n",
|
||||
" })\n",
|
||||
" .background_gradient(subset=[\"auc_mean\"], cmap=\"Greens\")\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1e0d21c1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def paired_comparison(section: str, model: str, question: str, before: str, after: str, contrast: str) -> dict[str, Any]:\n",
|
||||
" r0 = load_results(before)\n",
|
||||
" r1 = load_results(after)\n",
|
||||
" resolved_before = resolve_run(before)\n",
|
||||
" resolved_after = resolve_run(after)\n",
|
||||
" out = {\n",
|
||||
" \"section\": section,\n",
|
||||
" \"model\": model,\n",
|
||||
" \"question\": question,\n",
|
||||
" \"before\": before,\n",
|
||||
" \"after\": after,\n",
|
||||
" \"resolved_before\": resolved_before,\n",
|
||||
" \"resolved_after\": resolved_after,\n",
|
||||
" \"contrast\": contrast,\n",
|
||||
" \"status\": \"ok\" if r0 and r1 else \"missing\",\n",
|
||||
" \"n\": 0,\n",
|
||||
" \"before_auc\": np.nan,\n",
|
||||
" \"after_auc\": np.nan,\n",
|
||||
" \"delta_auc\": np.nan,\n",
|
||||
" \"delta_ci95\": np.nan,\n",
|
||||
" \"ttest_p\": np.nan,\n",
|
||||
" \"wilcoxon_p\": np.nan,\n",
|
||||
" \"cohen_dz\": np.nan,\n",
|
||||
" \"before_gap\": np.nan,\n",
|
||||
" \"after_gap\": np.nan,\n",
|
||||
" \"delta_gap\": np.nan,\n",
|
||||
" \"interpretation\": \"insufficient data\",\n",
|
||||
" \"caveat\": \"\",\n",
|
||||
" }\n",
|
||||
" if not (r0 and r1):\n",
|
||||
" return out\n",
|
||||
"\n",
|
||||
" v0 = metric_values(r0, \"auc_roc\")\n",
|
||||
" v1 = metric_values(r1, \"auc_roc\")\n",
|
||||
" n = min(len(v0), len(v1))\n",
|
||||
" v0, v1 = v0[:n], v1[:n]\n",
|
||||
" diff = v1 - v0\n",
|
||||
"\n",
|
||||
" out.update({\n",
|
||||
" \"n\": n,\n",
|
||||
" \"before_auc\": float(np.mean(v0)),\n",
|
||||
" \"after_auc\": float(np.mean(v1)),\n",
|
||||
" \"delta_auc\": float(np.mean(diff)),\n",
|
||||
" })\n",
|
||||
"\n",
|
||||
" if n >= 2:\n",
|
||||
" sd = float(np.std(diff, ddof=1))\n",
|
||||
" se = sd / math.sqrt(n) if sd > 0 else 0.0\n",
|
||||
" out[\"delta_ci95\"] = float(stats.t.ppf(0.975, df=n - 1) * se) if n > 1 else np.nan\n",
|
||||
" if sd > 0:\n",
|
||||
" out[\"cohen_dz\"] = float(np.mean(diff) / sd)\n",
|
||||
" out[\"ttest_p\"] = float(stats.ttest_rel(v1, v0).pvalue)\n",
|
||||
" if n >= 3 and not np.allclose(diff, 0):\n",
|
||||
" try:\n",
|
||||
" out[\"wilcoxon_p\"] = float(stats.wilcoxon(diff).pvalue)\n",
|
||||
" except ValueError:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" gaps0 = [best_epoch_gap(f) for f in r0.get(\"fold_results\", [])]\n",
|
||||
" gaps1 = [best_epoch_gap(f) for f in r1.get(\"fold_results\", [])]\n",
|
||||
" gaps0 = np.asarray([x for x in gaps0 if x is not None], dtype=float)\n",
|
||||
" gaps1 = np.asarray([x for x in gaps1 if x is not None], dtype=float)\n",
|
||||
" if len(gaps0) and len(gaps1):\n",
|
||||
" m = min(len(gaps0), len(gaps1))\n",
|
||||
" out[\"before_gap\"] = float(np.mean(gaps0[:m]))\n",
|
||||
" out[\"after_gap\"] = float(np.mean(gaps1[:m]))\n",
|
||||
" out[\"delta_gap\"] = float(np.mean(gaps1[:m] - gaps0[:m]))\n",
|
||||
"\n",
|
||||
" if question == \"source_holdout\":\n",
|
||||
" out[\"caveat\"] = \"Aggregate holdout-run AUC only; not held-out-source vs in-source AUC.\"\n",
|
||||
" if before != resolved_before or after != resolved_after:\n",
|
||||
"        out[\"caveat\"] = (out[\"caveat\"] + \" \" if out[\"caveat\"] else \"\") + \"Uses Phase 1 fallback for missing p2b 128 log.\"\n",
|
||||
"\n",
|
||||
" if out[\"delta_auc\"] >= 0.01:\n",
|
||||
" out[\"interpretation\"] = \"meaningful improvement\"\n",
|
||||
" elif out[\"delta_auc\"] > 0.002:\n",
|
||||
" out[\"interpretation\"] = \"small improvement\"\n",
|
||||
" elif out[\"delta_auc\"] >= -0.002:\n",
|
||||
" out[\"interpretation\"] = \"negligible change\"\n",
|
||||
" elif out[\"delta_auc\"] > -0.01:\n",
|
||||
" out[\"interpretation\"] = \"small drop\"\n",
|
||||
" else:\n",
|
||||
" out[\"interpretation\"] = \"meaningful drop\"\n",
|
||||
" return out\n",
|
||||
"\n",
|
||||
"comparisons_df = pd.DataFrame([paired_comparison(*args) for args in PLANNED_COMPARISONS])\n",
|
||||
"\n",
|
||||
"# Benjamini-Hochberg correction across planned paired t-tests where available.\n",
|
||||
"valid_p = comparisons_df[\"ttest_p\"].notna()\n",
|
||||
"pvals = comparisons_df.loc[valid_p, \"ttest_p\"].to_numpy()\n",
|
||||
"qvals = np.full(len(comparisons_df), np.nan)\n",
|
||||
"if len(pvals):\n",
|
||||
" order = np.argsort(pvals)\n",
|
||||
" ranked = pvals[order]\n",
|
||||
" adjusted = np.empty_like(ranked)\n",
|
||||
" m = len(ranked)\n",
|
||||
" running = 1.0\n",
|
||||
" for i in range(m - 1, -1, -1):\n",
|
||||
" running = min(running, ranked[i] * m / (i + 1))\n",
|
||||
" adjusted[i] = running\n",
|
||||
" qvals[np.where(valid_p)[0][order]] = adjusted\n",
|
||||
"comparisons_df[\"bh_q\"] = qvals\n",
|
||||
"\n",
|
||||
"display(\n",
|
||||
" comparisons_df[[\n",
|
||||
" \"section\", \"model\", \"question\", \"contrast\", \"before_auc\", \"after_auc\", \"delta_auc\",\n",
|
||||
" \"delta_ci95\", \"ttest_p\", \"bh_q\", \"wilcoxon_p\", \"cohen_dz\", \"delta_gap\", \"interpretation\", \"caveat\",\n",
|
||||
" ]].style.format({\n",
|
||||
" \"before_auc\": \"{:.4f}\",\n",
|
||||
" \"after_auc\": \"{:.4f}\",\n",
|
||||
" \"delta_auc\": \"{:+.4f}\",\n",
|
||||
" \"delta_ci95\": \"\u00b1{:.4f}\",\n",
|
||||
" \"ttest_p\": \"{:.4f}\",\n",
|
||||
" \"bh_q\": \"{:.4f}\",\n",
|
||||
" \"wilcoxon_p\": \"{:.4f}\",\n",
|
||||
" \"cohen_dz\": \"{:+.2f}\",\n",
|
||||
" \"delta_gap\": \"{:+.4f}\",\n",
|
||||
" }).background_gradient(subset=[\"delta_auc\"], cmap=\"RdYlGn\", vmin=-0.06, vmax=0.06)\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f20e5262",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Visual summary\n",
|
||||
"\n",
|
||||
"Two plots are most useful for decision-making:\n",
|
||||
"\n",
|
||||
"- Ranking all conditions by AUC shows the best observed configurations but can overstate duplicated/near-identical runs.\n",
|
||||
"- Paired delta plot shows the controlled effect of each preprocessing change and exposes uncertainty."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "42882c6a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plot_df = analysis_df.copy()\n",
|
||||
"plot_df[\"display_label\"] = plot_df[\"section\"] + \" | \" + plot_df[\"label\"]\n",
|
||||
"plot_df = plot_df.sort_values(\"auc_mean\", ascending=True)\n",
|
||||
"\n",
|
||||
"fig, ax = plt.subplots(figsize=(11, max(7, 0.35 * len(plot_df))))\n",
|
||||
"colors = {\"2A\": \"#4C78A8\", \"2B\": \"#F58518\", \"2C\": \"#54A24B\", \"2D\": \"#E45756\", \"2E\": \"#B279A2\"}\n",
|
||||
"ax.barh(\n",
|
||||
" plot_df[\"display_label\"],\n",
|
||||
" plot_df[\"auc_mean\"],\n",
|
||||
" xerr=plot_df[\"auc_std\"],\n",
|
||||
" color=[colors.get(s, \"#999999\") for s in plot_df[\"section\"]],\n",
|
||||
" alpha=0.85,\n",
|
||||
")\n",
|
||||
"ax.set_xlim(0.65, 1.0)\n",
|
||||
"ax.set_xlabel(\"Mean AUC across CV folds\")\n",
|
||||
"ax.set_title(\"Phase 2 Conditions Ranked by AUC\")\n",
|
||||
"ax.axvline(0.95, color=\"black\", linewidth=1, linestyle=\"--\", alpha=0.4)\n",
|
||||
"for y, (_, row) in enumerate(plot_df.iterrows()):\n",
|
||||
" ax.text(row[\"auc_mean\"] + 0.004, y, f\"{row['auc_mean']:.4f}\", va=\"center\", fontsize=9)\n",
|
||||
"fig.tight_layout()\n",
|
||||
"fig.savefig(FIGURES_DIR / \"ranked_auc.png\", dpi=200, bbox_inches=\"tight\")\n",
|
||||
"plt.show()\n",
|
||||
"\n",
|
||||
"forest = comparisons_df.copy()\n",
|
||||
"forest[\"display\"] = forest[\"section\"] + \" \" + forest[\"model\"] + \" - \" + forest[\"contrast\"]\n",
|
||||
"forest = forest.iloc[::-1]\n",
|
||||
"fig, ax = plt.subplots(figsize=(11, max(6, 0.45 * len(forest))))\n",
|
||||
"y = np.arange(len(forest))\n",
|
||||
"ax.errorbar(\n",
|
||||
" forest[\"delta_auc\"], y,\n",
|
||||
" xerr=forest[\"delta_ci95\"],\n",
|
||||
" fmt=\"o\", color=\"#1F2937\", ecolor=\"#6B7280\", capsize=4,\n",
|
||||
")\n",
|
||||
"ax.axvline(0, color=\"black\", linewidth=1)\n",
|
||||
"ax.axvspan(-0.002, 0.002, color=\"#9CA3AF\", alpha=0.18, label=\"negligible band\")\n",
|
||||
"ax.set_yticks(y)\n",
|
||||
"ax.set_yticklabels(forest[\"display\"])\n",
|
||||
"ax.set_xlabel(\"Delta AUC (after - before), paired by fold\")\n",
|
||||
"ax.set_title(\"Planned Phase 2 Effect Estimates\")\n",
|
||||
"ax.legend(loc=\"lower right\")\n",
|
||||
"fig.tight_layout()\n",
|
||||
"fig.savefig(FIGURES_DIR / \"planned_effects.png\", dpi=200, bbox_inches=\"tight\")\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e063cfc0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2A - Shortcut analysis\n",
|
||||
"\n",
|
||||
"Shortcut checks map to `p2a_*` configs:\n",
|
||||
"- `p2a_t1_original` vs `p2a_t2_real_norm` (normalization)\n",
|
||||
"- `p2a_t1_original` vs `p2a_t3_holdout_*` (source_holdout)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "910bd5bd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def comparison_subset(section: str, question: str | None = None) -> pd.DataFrame:\n",
|
||||
" df = comparisons_df[comparisons_df[\"section\"].eq(section)].copy()\n",
|
||||
" if question:\n",
|
||||
" df = df[df[\"question\"].eq(question)]\n",
|
||||
" return df\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def print_comparison_readout(df: pd.DataFrame) -> None:\n",
|
||||
" for _, row in df.iterrows():\n",
|
||||
" print(f\"{row['section']} {row['model']} - {row['contrast']}\")\n",
|
||||
" print(f\" AUC: {row['before_auc']:.4f} -> {row['after_auc']:.4f} ({row['delta_auc']:+.4f})\")\n",
|
||||
" print(f\" paired t p={row['ttest_p']:.4f}, BH q={row['bh_q']:.4f}, CI95 delta=\u00b1{row['delta_ci95']:.4f}\")\n",
|
||||
" print(f\" gap delta: {row['delta_gap']:+.4f}; interpretation: {row['interpretation']}\")\n",
|
||||
" if row['caveat']:\n",
|
||||
" print(f\" caveat: {row['caveat']}\")\n",
|
||||
" print()\n",
|
||||
"\n",
|
||||
"print_comparison_readout(comparison_subset(\"2B\", \"resolution\"))\n",
|
||||
"\n",
|
||||
"res_plot = comparison_subset(\"2B\", \"resolution\")\n",
|
||||
"fig, ax = plt.subplots(figsize=(8, 5))\n",
|
||||
"for _, row in res_plot.iterrows():\n",
|
||||
" r0, r1 = load_results(row[\"before\"]), load_results(row[\"after\"])\n",
|
||||
" v0, v1 = metric_values(r0), metric_values(r1)\n",
|
||||
" x = [0, 1]\n",
|
||||
" for a, b in zip(v0, v1):\n",
|
||||
" ax.plot(x, [a, b], color=\"#9CA3AF\", alpha=0.7)\n",
|
||||
" ax.plot(x, [v0.mean(), v1.mean()], marker=\"o\", linewidth=3, label=row[\"model\"])\n",
|
||||
"ax.set_xticks([0, 1])\n",
|
||||
"ax.set_xticklabels([\"128\", \"224\"])\n",
|
||||
"ax.set_ylabel(\"AUC\")\n",
|
||||
"ax.set_title(\"2B Resolution: Fold-Paired AUC\")\n",
|
||||
"ax.legend()\n",
|
||||
"fig.tight_layout()\n",
|
||||
"fig.savefig(FIGURES_DIR / \"2b_resolution_paired.png\", dpi=200, bbox_inches=\"tight\")\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "530e8675",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2B - Resolution impact\n",
|
||||
"\n",
|
||||
"This section compares 128 vs 224 using `p2b_*_224` and Phase 1 baselines as explicit 128 fallbacks.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "13304d38",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print_comparison_readout(comparison_subset(\"2C\", \"facecrop\"))\n",
|
||||
"\n",
|
||||
"face_df = canonical_runs_df[canonical_runs_df[\"section\"].eq(\"2C\")].copy()\n",
|
||||
"fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=False)\n",
|
||||
"for ax, model in zip(axes, [\"SimpleCNN\", \"ResNet18\"]):\n",
|
||||
" sub = face_df[face_df[\"model\"].eq(model)].sort_values(\"face_crop\")\n",
|
||||
" ax.bar(sub[\"condition\"], sub[\"auc_mean\"], yerr=sub[\"auc_std\"], color=[\"#D97706\", \"#059669\"], alpha=0.85, capsize=5)\n",
|
||||
" ax.set_title(model)\n",
|
||||
" ax.set_ylim(0.70 if model == \"SimpleCNN\" else 0.94, 0.99)\n",
|
||||
" ax.set_ylabel(\"AUC\")\n",
|
||||
" ax.tick_params(axis=\"x\", rotation=20)\n",
|
||||
"fig.suptitle(\"2C Facecrop Impact\")\n",
|
||||
"fig.tight_layout()\n",
|
||||
"fig.savefig(FIGURES_DIR / \"2c_facecrop.png\", dpi=200, bbox_inches=\"tight\")\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8702d10d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2C - Facecrop impact\n",
|
||||
"\n",
|
||||
"This section compares `p2c_*_facecrop` against the matching `p2b_*_224` no-facecrop baselines.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ec5e03ef",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print_comparison_readout(comparison_subset(\"2A\"))\n\n# Inspect whether logs contain the per-source data needed by v2.md.\nsource_audit = []\nfor run in [\"p2a_t1_original\", \"p2a_t3_holdout_text2img\", \"p2a_t3_holdout_inpainting\", \"p2a_t3_holdout_insight\"]:\n    results = load_results(run)\n    has_per_source = False\n    has_records = False\n    example_keys = []\n    if results:\n        for fold in results.get(\"fold_results\", []):\n            tm = fold.get(\"test_metrics\", {})\n            example_keys = sorted(tm.keys())\n            has_per_source = has_per_source or any(k in tm for k in [\"per_source\", \"per_source_metrics\", \"pairwise_source_metrics\", \"source_metrics\", \"pair_metrics\"])\n            has_records = has_records or any(k in fold for k in [\"records\", \"predictions\", \"test_records\"])\n    source_audit.append({\n        \"run\": run,\n        \"has_per_source_metrics\": has_per_source,\n        \"has_prediction_records\": has_records,\n        \"test_metric_keys\": example_keys,\n    })\nsource_audit_df = pd.DataFrame(source_audit)\ndisplay(source_audit_df)\n\nholdout_runs = [\"p2a_t1_original\", \"p2a_t3_holdout_text2img\", \"p2a_t3_holdout_inpainting\", \"p2a_t3_holdout_insight\"]\nholdout_df = canonical_runs_df[canonical_runs_df[\"run\"].isin(holdout_runs)].copy()\nholdout_df[\"delta_vs_all_source\"] = holdout_df[\"auc_mean\"] - float(holdout_df.loc[holdout_df[\"run\"].eq(\"p2a_t1_original\"), \"auc_mean\"].iloc[0])\n\nfig, ax = plt.subplots(figsize=(9, 5))\nax.bar(holdout_df[\"label\"], holdout_df[\"auc_mean\"], yerr=holdout_df[\"auc_std\"], color=\"#54A24B\", alpha=0.85, capsize=5)\nax.set_ylim(0.88, 0.99)\nax.set_ylabel(\"Aggregate AUC\")\nax.set_title(\"2A Source Holdout Proxy: Aggregate Test AUC\")\nax.tick_params(axis=\"x\", rotation=20)\nfor i, (_, row) in enumerate(holdout_df.iterrows()):\n    ax.text(i, row[\"auc_mean\"] + 0.004, f\"{row['delta_vs_all_source']:+.3f}\", ha=\"center\", fontsize=9)\nfig.tight_layout()\nfig.savefig(FIGURES_DIR / \"2a_holdout_proxy.png\", dpi=200, bbox_inches=\"tight\")\nplt.show()\n\nprint(\"Geometry diagnostic evidence:\")\ngeometry_keys = []\nfor run in [\"p2a_t1_original\", \"p2a_t2_real_norm\"]:\n    results = load_results(run)\n    cfg = (results or {}).get(\"config\", {})\n    geometry_keys.append({\n        \"run\": run,\n        \"config_geometry_condition\": cfg.get(\"geometry_condition\"),\n        \"has_matched_geometry_metric\": any(\n            \"geometry\" in str(k).lower() or \"matched\" in str(k).lower()\n            for fold in (results or {}).get(\"fold_results\", [])\n            for k in fold.get(\"test_metrics\", {}).keys()\n        ),\n    })\ndisplay(pd.DataFrame(geometry_keys))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2c3b8812",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2D / 2E - Augmentation impact and test-set integrity\n",
|
||||
"\n",
|
||||
"The augmentation question has two parts:\n",
|
||||
"\n",
|
||||
"- Does light augmentation help at 224 without facecrop?\n",
|
||||
"- Does it help once facecrop is enabled?\n",
|
||||
"\n",
|
||||
"The implementation also needs to guarantee that validation/test evaluation is not stochastic. The preprocessing pipeline keeps stochastic operations behind `self.train`, so `train=False` disables them even if augmentation settings exist."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f11c3257",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"2D (p2d): augmentation without facecrop\")\n",
|
||||
"print_comparison_readout(comparison_subset(\"2D\", \"augmentation\"))\n",
|
||||
"print(\"2E (p2e): augmentation with facecrop\")\n",
|
||||
"print_comparison_readout(comparison_subset(\"2E\", \"facecrop + augmentation\"))\n",
|
||||
"\n",
|
||||
"aug_sections = comparisons_df[comparisons_df[\"section\"].isin([\"2D\", \"2E\"])].copy()\n",
|
||||
"fig, ax = plt.subplots(figsize=(9, 5))\n",
|
||||
"labels = aug_sections[\"section\"] + \" \" + aug_sections[\"model\"]\n",
|
||||
"ax.bar(labels, aug_sections[\"delta_auc\"], yerr=aug_sections[\"delta_ci95\"], color=[\"#E45756\" if d < 0 else \"#059669\" for d in aug_sections[\"delta_auc\"]], alpha=0.85, capsize=5)\n",
|
||||
"ax.axhline(0, color=\"black\", linewidth=1)\n",
|
||||
"ax.set_ylabel(\"Delta AUC from adding augmentation\")\n",
|
||||
"ax.set_title(\"Augmentation Effects Across Facecrop Conditions\")\n",
|
||||
"ax.tick_params(axis=\"x\", rotation=20)\n",
|
||||
"fig.tight_layout()\n",
|
||||
"fig.savefig(FIGURES_DIR / \"2d_2e_augmentation_effects.png\", dpi=200, bbox_inches=\"tight\")\n",
|
||||
"plt.show()\n",
|
||||
"\n",
|
||||
"# Static and behavioral audit of eval stochasticity.\n",
|
||||
"try:\n",
|
||||
" import inspect\n",
|
||||
" from src.preprocessing.pipeline import DFFImagePipeline\n",
|
||||
" from src.evaluation import evaluate as evaluate_module\n",
|
||||
"\n",
|
||||
" pipeline_src = inspect.getsource(DFFImagePipeline)\n",
|
||||
" build_transforms_src = inspect.getsource(evaluate_module.build_transforms)\n",
|
||||
" stochastic_guards = {\n",
|
||||
" \"flip_guarded_by_train\": \"if self.train and random.random() < self.hflip_p\" in pipeline_src,\n",
|
||||
" \"rotate_guarded_by_train\": \"if self.train and self.rotation_degrees > 0\" in pipeline_src,\n",
|
||||
" \"color_jitter_returns_when_not_train\": \"if not self.train:\" in pipeline_src,\n",
|
||||
" \"blur_guarded_by_train\": \"if self.train and random.random() < self.blur_p\" in pipeline_src,\n",
|
||||
" \"jpeg_guarded_by_train\": \"if self.train and random.random() < self.jpeg_p\" in pipeline_src,\n",
|
||||
" \"erase_guarded_by_train\": \"if self.train and random.random() < self.erase_p\" in pipeline_src,\n",
|
||||
" \"noise_guarded_by_train\": \"if self.train and random.random() < self.noise_p\" in pipeline_src,\n",
|
||||
" \"cv_transform_uses_train_flag\": \"get_transforms(train=train\" in build_transforms_src,\n",
|
||||
" }\n",
|
||||
" display(pd.DataFrame([stochastic_guards]).T.rename(columns={0: \"passes\"}))\n",
|
||||
"except Exception as exc:\n",
|
||||
" print(f\"Could not run transform audit: {exc}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "02e47658",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Decision synthesis\n",
|
||||
"\n",
|
||||
"This section converts the evidence into Phase 3 settings. It intentionally distinguishes a recommendation from a claim:\n",
|
||||
"\n",
|
||||
"- Recommendation: choose the setting that is best supported for the next experiment.\n",
|
||||
"- Claim: what the current evidence proves. Some Phase 2C claims remain incomplete without per-source or matched-geometry outputs."
|
||||
]
|
||||
},
|
||||
{
"cell_type": "code",
"execution_count": null,
"id": "7034443c",
"metadata": {},
"outputs": [],
"source": [
"def get_delta(question: str, model: str | None = None, section: str | None = None) -> pd.DataFrame:\n",
" df = comparisons_df[comparisons_df[\"question\"].eq(question)].copy()\n",
" if model:\n",
" df = df[df[\"model\"].eq(model)]\n",
" if section:\n",
" df = df[df[\"section\"].eq(section)]\n",
" return df\n",
"\n",
"resolution_resnet = get_delta(\"resolution\", \"ResNet18\").iloc[0]\n",
"facecrop_resnet = get_delta(\"facecrop\", \"ResNet18\").iloc[0]\n",
"facecrop_simple = get_delta(\"facecrop\", \"SimpleCNN\").iloc[0]\n",
"aug_no_crop_resnet = get_delta(\"augmentation\", \"ResNet18\").iloc[0]\n",
"aug_no_crop_simple = get_delta(\"augmentation\", \"SimpleCNN\").iloc[0]\n",
"aug_crop_resnet = get_delta(\"facecrop + augmentation\", \"ResNet18\").iloc[0]\n",
"aug_crop_simple = get_delta(\"facecrop + augmentation\", \"SimpleCNN\").iloc[0]\n",
"norm = get_delta(\"normalization\", \"ResNet18\").iloc[0]\n",
"\n",
"recommendations = [\n",
" {\n",
" \"choice\": \"resolution\",\n",
" \"recommendation\": \"224x224\",\n",
" \"evidence\": f\"ResNet18 delta AUC {resolution_resnet.delta_auc:+.4f}; SimpleCNN does not determine Phase 3 capacity.\",\n",
" \"confidence\": \"high\" if resolution_resnet.delta_auc > 0.02 else \"medium\",\n",
" },\n",
" {\n",
" \"choice\": \"facecrop\",\n",
" \"recommendation\": \"use facecrop\",\n",
" \"evidence\": f\"Small positive deltas for both models: SimpleCNN {facecrop_simple.delta_auc:+.4f}, ResNet18 {facecrop_resnet.delta_auc:+.4f}.\",\n",
" \"confidence\": \"medium\",\n",
" },\n",
" {\n",
" \"choice\": \"augmentation\",\n",
" \"recommendation\": \"do not use light augmentation for Phase 3 at 20% data\",\n",
" \"evidence\": f\"SimpleCNN drops {aug_no_crop_simple.delta_auc:+.4f} without facecrop and {aug_crop_simple.delta_auc:+.4f} with facecrop; ResNet18 is neutral/slightly mixed ({aug_no_crop_resnet.delta_auc:+.4f}, {aug_crop_resnet.delta_auc:+.4f}).\",\n",
" \"confidence\": \"high for SimpleCNN, medium for ResNet18\",\n",
" },\n",
" {\n",
" \"choice\": \"normalization\",\n",
" \"recommendation\": \"ImageNet normalization\",\n",
" \"evidence\": f\"Real-train-only normalization delta AUC {norm.delta_auc:+.4f}; no useful gain and less standard for pretrained ResNet.\",\n",
" \"confidence\": \"medium\",\n",
" },\n",
" {\n",
" \"choice\": \"shortcut/source claims\",\n",
" \"recommendation\": \"do not overclaim; add per-source or prediction exports before final report\",\n",
" \"evidence\": \"Current CV logs lack held-out-source vs in-source AUC and matched-geometry test metrics.\",\n",
" \"confidence\": \"high\",\n",
" },\n",
"]\n",
"\n",
"recommendations_df = pd.DataFrame(recommendations)\n",
"display(recommendations_df)\n",
"\n",
"summary = {\n",
" \"phase\": \"phase2\",\n",
" \"source_documents\": [\"classifier/v2.md\", \"classifier/impl.md\"],\n",
" \"artifact_counts\": {\n",
" \"canonical_runs\": int(len(canonical_runs_df)),\n",
" \"loaded_canonical_runs\": int(canonical_runs_df[\"log_status\"].isin([\"present\", \"fallback\"]).sum()),\n",
" \"fallback_runs_used\": {k: v for k, v in RUN_ALIASES.items() if resolve_run(k) != k},\n",
" },\n",
" \"recommendations\": recommendations,\n",
" \"planned_comparisons\": comparisons_df.replace({np.nan: None}).to_dict(orient=\"records\"),\n",
" \"known_gaps\": [\n",
" \"Dedicated p2a_*_128 logs are absent/skipped; Phase 1 baselines are used as fallbacks.\",\n",
" \"Source holdout logs do not include prediction-level or per-source metrics, so held-out-source AUC vs in-source AUC cannot be computed.\",\n",
" \"No matched-geometry evaluation metric is present in p2c logs, so geometry shortcut analysis is incomplete.\",\n",
" ],\n",
"}\n",
"\n",
"summary_path = ANALYSIS_DIR / \"phase2_analysis_summary.json\"\n",
"with summary_path.open(\"w\") as f:\n",
" json.dump(summary, f, indent=2)\n",
"\n",
"print(f\"Saved summary: {summary_path.relative_to(PROJECT_ROOT)}\")\n",
"print(f\"Saved figures: {FIGURES_DIR.relative_to(PROJECT_ROOT)}\")"
]
},
{
"cell_type": "markdown",
"id": "5a337f73",
"metadata": {},
"source": [
"## Report-ready conclusion\n",
"\n",
"The strongest Phase 2 result is the resolution effect for ResNet18: moving to 224x224 substantially improves AUC under the controlled CV protocol. Face cropping gives a small positive effect and is reasonable to carry forward, especially because it aligns the model with face evidence rather than background context. Light augmentation is not supported at this 20% data setting: it strongly hurts SimpleCNN and provides no reliable gain for ResNet18, with or without face cropping. ImageNet normalization remains preferable because real-train-only normalization does not improve AUC and is less aligned with pretrained ResNet expectations.\n",
"\n",
"Recommended Phase 3 preprocessing: **224x224, facecrop enabled, no light augmentation, ImageNet normalization**.\n",
"\n",
"Limitations to fix before the final report: export prediction-level records or per-source pairwise metrics for source holdout, and add the matched-geometry evaluation required by the shortcut-analysis plan. Without those artifacts, Phase 2C can only support a limited shortcut analysis."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "drl",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}