Clean state
@@ -0,0 +1,2 @@
*.pt filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
@@ -0,0 +1,48 @@
# Pipeline
pipeline/.env

# Data
data/*
!data/.gitkeep

# Cropped faces
#cropped_classifier.zip
#cropped_generator.zip
cropped/*
!cropped/.gitkeep
!cropped/classifier/.gitkeep
!cropped/classifier/README.md
!cropped/generator/.gitkeep
!cropped/generator/README.md

# Classifier outputs
classifier/outputs/*
# Analysis
!classifier/outputs/analysis/
!classifier/outputs/analysis/**
# Figures
!classifier/outputs/figures/
!classifier/outputs/figures/**
# Models
!classifier/outputs/models/
classifier/outputs/models/*
!classifier/outputs/models/.gitkeep
!classifier/outputs/models/*.pt
# Logs
!classifier/outputs/logs/
classifier/outputs/logs/*
!classifier/outputs/logs/.gitkeep
!classifier/outputs/logs/*.json
# Pipeline
!classifier/outputs/pipeline/
classifier/outputs/pipeline/*
!classifier/outputs/pipeline/.gitkeep
!classifier/outputs/pipeline/*.json

# Generator outputs (all local-only)
generator/outputs/*

# Python
.venv/
.ipynb_checkpoints/
__pycache__/
@@ -0,0 +1,132 @@
# DRL_PROJ — DeepFake Detection

Deep learning project for binary deepfake detection on the DeepFakeFace dataset.

## Project structure

```
DRL_PROJ/
  classifier/          ← discriminative model (real vs. fake classifier)
    src/               ← model definitions, training, evaluation, preprocessing
    configs/           ← experiment configs organised by phase
      phase1/          ← baseline models (SimpleCNN, ResNet18)
      phase2/          ← architecture sweep (ResNet variants, face-crop)
      phase3/          ← EfficientNet, ViT, frequency-aware training
      phase4/          ← ensemble strategies
    tools/             ← analyze.py, ensemble.py, inference.py, precrop.py
    notebooks/         ← EDA, preprocessing, evaluation, GradCAM
    outputs/           ← models, logs, figures (gitignored except .pt/.json)
    run.py             ← main training entry point
  generator/           ← generative model (GAN / VAE / diffusion) — in progress
  pipeline/            ← Vast.ai ephemeral GPU orchestration
  data/                ← dataset root (gitignored)
  cropped/             ← MTCNN pre-cropped faces (gitignored)
    classifier/        ← bbox crops for the classifier
    generator/         ← landmark-aligned crops for the generator
```

## Setup

Create a local environment when you want to run the code directly on a machine you control:

```bash
python3 -m venv .venv
source .venv/bin/activate
python -m pip install --upgrade pip setuptools wheel
python -m pip install -r requirements.txt
```

## Local Training

```bash
python3 classifier/run.py classifier/configs/phase2/p2_resnet18_facecrop.json
python3 classifier/run.py classifier/configs/phase3/p3_efficientnet_b0.json
```

## Ephemeral Vast.ai Pipeline

The deployment/orchestration path now lives under [`pipeline/`](pipeline/README.md).

One-time setup:

```bash
cat > pipeline/.env <<'EOF'
VAST_API_KEY=<your-api-key>
VAST_SSH_PRIVATE_KEY=/home/your-user/.ssh/id_ed25519
EOF
```

End-to-end ephemeral run:

```bash
python3 -m pipeline run classifier/configs/phase2/p2_resnet18_facecrop.json --upload-data
```

Interactive offer selection:

```bash
python3 -m pipeline offers --select-offer
```

You can override the ranking mode per run:

```bash
python3 -m pipeline offers --sort price
python3 -m pipeline offers --sort performance
python3 -m pipeline offers --sort performance --price 0.14
```

You can also filter by region:

```bash
python3 -m pipeline offers --select-offer --region europe
python3 -m pipeline offers --select-offer --region Portugal
python3 -m pipeline offers --select-offer --region US
python3 -m pipeline offers --select-offer --region europe --price 0.14
```

To inspect which region strings are currently available from the search results:

```bash
python3 -m pipeline offers --list-regions
```

The end-to-end `run` command:
- ensures your SSH public key is registered with Vast.ai
- searches offers using the filters in `pipeline/defaults/vast.json`
- creates an instance
- waits for SSH readiness
- syncs the repo
- uploads `data/` when `--upload-data` is set
- runs `python3 classifier/run.py ...`
- downloads `classifier/outputs/`
- for generator runs, rsyncs `generator/outputs/` back every 50 epochs and again at completion
- destroys the instance automatically unless `--keep-on-failure` is set

Useful commands:

```bash
python3 -m pipeline up
python3 -m pipeline status <instance_id>
python3 -m pipeline down <instance_id>
```

To override the default Vast search/runtime settings, copy `pipeline/defaults/vast.json`, edit it, and pass:

```bash
python3 -m pipeline run classifier/configs/phase3/p3_efficientnet_b0.json --pipeline-config /path/to/vast.override.json
```

The default policy in `pipeline/defaults/vast.json` now targets:
- `1x` GPU
- `RTX 3090` or `RTX 3090 Ti`
- `<= $0.20/hour`
- sorted by `dlperf` descending
- `vastai/pytorch:latest` as the default image

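For orientation, a minimal sketch of what a `vast.json` override could look like under that policy. The field names here are assumptions for illustration only, not the actual schema — check `pipeline/defaults/vast.json` for the real keys:

```json
{
  "num_gpus": 1,
  "gpu_names": ["RTX 3090", "RTX 3090 Ti"],
  "max_price_per_hour": 0.20,
  "sort_by": "dlperf",
  "sort_order": "desc",
  "image": "vastai/pytorch:latest"
}
```

Copy the defaults file, adjust the limits, and pass the edited copy via `--pipeline-config`.
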
## Diagnostics

```bash
python3 classifier/tools/analyze.py classifier/configs/phase2/p2_resnet18_facecrop.json
python3 classifier/tools/ensemble.py classifier/configs/phase4/p4_ensemble.json
```
@@ -0,0 +1,10 @@
{
  "run_name": "p1_resnet18_baseline",
  "backbone": "resnet18",
  "pretrained": true,
  "epochs": 15,
  "image_size": 128,
  "subsample": 0.2,
  "augment": false,
  "data_dir": "data"
}
@@ -0,0 +1,11 @@
{
  "run_name": "p1_simplecnn_baseline",
  "backbone": "simple_cnn",
  "cnn_preset": "medium",
  "dropout": 0.0,
  "epochs": 15,
  "image_size": 128,
  "subsample": 0.2,
  "augment": false,
  "data_dir": "data"
}
@@ -0,0 +1,11 @@
{
  "run_name": "p2a_t1_original",
  "backbone": "resnet18",
  "pretrained": true,
  "epochs": 15,
  "image_size": 224,
  "subsample": 0.2,
  "augment": false,
  "data_dir": "data",
  "normalization": "imagenet"
}
@@ -0,0 +1,5 @@
{
  "extends": "p2a_t1_original.json",
  "run_name": "p2a_t2_real_norm",
  "normalization": "real_norm"
}
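The `"extends"` field above implies single-parent config inheritance: the child file is merged over the named parent. A minimal resolver sketch under that assumption (the function name and merge rules are illustrative, not necessarily what `classifier/run.py` actually does):

```python
import json
from pathlib import Path


def load_config(path):
    """Load a JSON config, recursively resolving its 'extends' parent.

    The parent path is taken relative to the child's directory, and
    top-level child keys override parent keys. (Illustrative sketch.)
    """
    path = Path(path)
    cfg = json.loads(path.read_text())
    parent_name = cfg.pop("extends", None)
    if parent_name is None:
        return cfg
    merged = load_config(path.parent / parent_name)  # parent may itself extend
    merged.update(cfg)  # child values win
    return merged
```

With `p2a_t1_original.json` next to `p2a_t2_real_norm.json`, a call like `load_config('p2a_t2_real_norm.json')` would return the baseline settings with `normalization` swapped to `real_norm`.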
@@ -0,0 +1,15 @@
{
  "extends": "p2a_t1_original.json",
  "run_name": "p2a_t3_holdout_inpainting",
  "train_sources": [
    "wiki",
    "text2img",
    "insight"
  ],
  "eval_sources": [
    "wiki",
    "text2img",
    "insight",
    "inpainting"
  ]
}
@@ -0,0 +1,15 @@
{
  "extends": "p2a_t1_original.json",
  "run_name": "p2a_t3_holdout_insight",
  "train_sources": [
    "wiki",
    "text2img",
    "inpainting"
  ],
  "eval_sources": [
    "wiki",
    "text2img",
    "inpainting",
    "insight"
  ]
}
@@ -0,0 +1,15 @@
{
  "extends": "p2a_t1_original.json",
  "run_name": "p2a_t3_holdout_text2img",
  "train_sources": [
    "wiki",
    "inpainting",
    "insight"
  ],
  "eval_sources": [
    "wiki",
    "inpainting",
    "insight",
    "text2img"
  ]
}
@@ -0,0 +1,10 @@
{
  "run_name": "p2b_resnet18_224",
  "backbone": "resnet18",
  "pretrained": true,
  "epochs": 15,
  "image_size": 224,
  "subsample": 0.2,
  "augment": false,
  "data_dir": "data"
}
@@ -0,0 +1,11 @@
{
  "run_name": "p2b_simplecnn_224",
  "backbone": "simple_cnn",
  "cnn_preset": "medium",
  "dropout": 0.0,
  "epochs": 15,
  "image_size": 224,
  "subsample": 0.2,
  "augment": false,
  "data_dir": "data"
}
@@ -0,0 +1,10 @@
{
  "run_name": "p2c_resnet18_facecrop",
  "backbone": "resnet18",
  "pretrained": true,
  "epochs": 15,
  "image_size": 224,
  "subsample": 0.2,
  "augment": false,
  "data_dir": "cropped/classifier"
}
@@ -0,0 +1,11 @@
{
  "run_name": "p2c_simplecnn_facecrop",
  "backbone": "simple_cnn",
  "cnn_preset": "medium",
  "dropout": 0.0,
  "epochs": 15,
  "image_size": 224,
  "subsample": 0.2,
  "augment": false,
  "data_dir": "cropped/classifier"
}
@@ -0,0 +1,22 @@
{
  "run_name": "p2d_resnet18_aug",
  "backbone": "resnet18",
  "pretrained": true,
  "epochs": 15,
  "image_size": 224,
  "subsample": 0.2,
  "data_dir": "data",
  "augment": {
    "hflip_p": 0.5,
    "rotation_degrees": 10,
    "brightness": 0.2,
    "contrast": 0.2,
    "saturation": 0.1,
    "hue": 0.02,
    "grayscale_p": 0.1,
    "blur_p": 0.1,
    "erase_p": 0.2,
    "noise_p": 0.3,
    "noise_std": 0.04
  }
}
@@ -0,0 +1,23 @@
{
  "run_name": "p2d_simplecnn_aug",
  "backbone": "simple_cnn",
  "cnn_preset": "medium",
  "dropout": 0.0,
  "epochs": 15,
  "image_size": 224,
  "subsample": 0.2,
  "data_dir": "data",
  "augment": {
    "hflip_p": 0.5,
    "rotation_degrees": 10,
    "brightness": 0.2,
    "contrast": 0.2,
    "saturation": 0.1,
    "hue": 0.02,
    "grayscale_p": 0.1,
    "blur_p": 0.1,
    "erase_p": 0.2,
    "noise_p": 0.3,
    "noise_std": 0.04
  }
}
@@ -0,0 +1,22 @@
{
  "run_name": "p2e_resnet18_facecrop_aug",
  "backbone": "resnet18",
  "pretrained": true,
  "epochs": 15,
  "image_size": 224,
  "subsample": 0.2,
  "data_dir": "cropped/classifier",
  "augment": {
    "hflip_p": 0.5,
    "rotation_degrees": 10,
    "brightness": 0.2,
    "contrast": 0.2,
    "saturation": 0.1,
    "hue": 0.02,
    "grayscale_p": 0.1,
    "blur_p": 0.1,
    "erase_p": 0.2,
    "noise_p": 0.3,
    "noise_std": 0.04
  }
}
@@ -0,0 +1,23 @@
{
  "run_name": "p2e_simplecnn_facecrop_aug",
  "backbone": "simple_cnn",
  "cnn_preset": "medium",
  "dropout": 0.0,
  "epochs": 15,
  "image_size": 224,
  "subsample": 0.2,
  "data_dir": "cropped/classifier",
  "augment": {
    "hflip_p": 0.5,
    "rotation_degrees": 10,
    "brightness": 0.2,
    "contrast": 0.2,
    "saturation": 0.1,
    "hue": 0.02,
    "grayscale_p": 0.1,
    "blur_p": 0.1,
    "erase_p": 0.2,
    "noise_p": 0.3,
    "noise_std": 0.04
  }
}
@@ -0,0 +1,13 @@
{
  "seed": 42,
  "cv_folds": 5,
  "batch_size": 32,
  "num_workers": 4,
  "early_stopping_patience": 5,

  "lr": 1e-4,
  "weight_decay": 1e-4,
  "T_max": 15,

  "data_dir": "data"
}
@@ -0,0 +1,351 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "eda-00",
   "metadata": {},
   "source": [
    "# 01 — EDA\n",
    "\n",
    "Explore DeepFakeFace (DFF) data quality before training: composition, source distribution, image properties, and split safety.\n",
    "\n",
    "**Sections:**\n",
    "1. Dataset composition and label balance\n",
    "2. Visual sanity-check samples\n",
    "3. Image dimension profile\n",
    "4. Per-source color statistics\n",
    "5. CV split and leakage sanity check\n",
    "6. Observations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eda-01",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, '..')\n",
    "\n",
    "import random\n",
    "from collections import Counter\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "from src.data import DFFDataset, SOURCES, get_splits\n",
    "\n",
    "DATA_DIR = Path('../../data')\n",
    "FIG_DIR = Path('../outputs/figures')\n",
    "FIG_DIR.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eda-02",
   "metadata": {},
   "source": [
    "## 1. Dataset composition and label balance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eda-03",
   "metadata": {},
   "outputs": [],
   "source": [
    "full_ds = DFFDataset(DATA_DIR)\n",
    "label_counts = full_ds.label_counts()\n",
    "\n",
    "print(f\"Total images : {len(full_ds):,}\")\n",
    "print(f\" Real (label=0) : {label_counts[0]:,}\")\n",
    "print(f\" Fake (label=1) : {label_counts[1]:,}\")\n",
    "print(f\" Fake:real ratio : {label_counts[1] / label_counts[0]:.2f}x\\n\")\n",
    "\n",
    "source_info = []\n",
    "for source, label in SOURCES.items():\n",
    "    ds = DFFDataset(DATA_DIR, sources=[source])\n",
    "    source_info.append((source, len(ds), label))\n",
    "    tag = 'real' if label == 0 else 'fake'\n",
    "    print(f\" {source:12s} n={len(ds):6,} label={label} ({tag})\")\n",
    "\n",
    "# Identity-level sanity check: each basename should appear in every source.\n",
    "basename_counts = Counter(path.name for path, _ in full_ds.samples)\n",
    "presence_hist = Counter(basename_counts.values())\n",
    "\n",
    "print(\"\\nIdentity (basename) presence across sources:\")\n",
    "for n_sources, count in sorted(presence_hist.items()):\n",
    "    print(f\" present in {n_sources} source(s): {count:,} identities\")\n",
    "\n",
    "incomplete = sum(v for k, v in presence_hist.items() if k < len(SOURCES))\n",
    "print(f\" complete in all {len(SOURCES)} sources: {presence_hist.get(len(SOURCES), 0):,}\")\n",
    "print(f\" incomplete identities : {incomplete:,}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eda-04",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n",
    "\n",
    "# Overall class balance\n",
    "class_names = ['Real (wiki)', 'Fake (all 3)']\n",
    "class_counts = [label_counts[0], label_counts[1]]\n",
    "bars = ax1.bar(class_names, class_counts, color=['#2196F3', '#F44336'], width=0.5)\n",
    "ax1.set_title('Overall Class Balance', fontsize=13)\n",
    "ax1.set_ylabel('Images')\n",
    "ax1.set_ylim(0, max(class_counts) * 1.15)\n",
    "for bar, v in zip(bars, class_counts):\n",
    "    ax1.text(bar.get_x() + bar.get_width() / 2, v + 300,\n",
    "             f'{v:,}', ha='center', fontsize=11)\n",
    "\n",
    "# Per-source breakdown\n",
    "src_names = [s for s, _, _ in source_info]\n",
    "src_counts = [n for _, n, _ in source_info]\n",
    "colors = ['#2196F3', '#FF9800', '#9C27B0', '#4CAF50']\n",
    "bars2 = ax2.bar(src_names, src_counts, color=colors, width=0.5)\n",
    "ax2.set_title('Images per Source', fontsize=13)\n",
    "ax2.set_ylabel('Images')\n",
    "ax2.set_ylim(0, max(src_counts) * 1.15)\n",
    "for bar, v in zip(bars2, src_counts):\n",
    "    ax2.text(bar.get_x() + bar.get_width() / 2, v + 300,\n",
    "             f'{v:,}', ha='center', fontsize=11)\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(FIG_DIR / 'class_balance.png', dpi=120, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eda-05",
   "metadata": {},
   "source": [
    "## 2. Visual sanity-check samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eda-06",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_COLS = 6\n",
    "fig, axes = plt.subplots(len(SOURCES), N_COLS, figsize=(18, 12))\n",
    "fig.suptitle('Sample images — 6 per source', fontsize=14)\n",
    "\n",
    "for row, (source, label) in enumerate(SOURCES.items()):\n",
    "    ds_src = DFFDataset(DATA_DIR, sources=[source])\n",
    "    indices = random.sample(range(len(ds_src)), N_COLS)\n",
    "    for col, idx in enumerate(indices):\n",
    "        path, _ = ds_src.samples[idx]\n",
    "        img = Image.open(path).convert('RGB').resize((128, 128))\n",
    "        axes[row, col].imshow(img)\n",
    "        axes[row, col].axis('off')\n",
    "    tag = 'real' if label == 0 else 'fake'\n",
    "    axes[row, 0].set_ylabel(f'{source}\\n({tag})', fontsize=10)\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(FIG_DIR / 'sample_images.png', dpi=100, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eda-07",
   "metadata": {},
   "source": [
    "## 3. Image dimension profile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eda-08",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_paths = [p for p, _ in random.sample(full_ds.samples, min(2000, len(full_ds)))]\n",
    "sizes = Counter(Image.open(p).size for p in sample_paths)\n",
    "\n",
    "print('Most common image dimensions (W x H):')\n",
    "for (w, h), count in sizes.most_common(10):\n",
    "    pct = count / len(sample_paths)\n",
    "    print(f' {w:4d} x {h:4d} — {count:4d} samples ({pct:.1%})')\n",
    "\n",
    "widths = [w for (w, _) in sizes.elements()]\n",
    "heights = [h for (_, h) in sizes.elements()]\n",
    "square = sum(1 for w, h in zip(widths, heights) if w == h)\n",
    "print(f'\\nWidth range: {min(widths)}–{max(widths)} mean={np.mean(widths):.0f}')\n",
    "print(f'Height range: {min(heights)}–{max(heights)} mean={np.mean(heights):.0f}')\n",
    "print(f'Square images: {square}/{len(widths)} ({square / len(widths):.1%})')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eda-11",
   "metadata": {},
   "source": [
    "## 4. Per-source color statistics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eda-12",
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Computing per-source color statistics (300 images per source)...')\n",
    "N_SAMPLES = 300\n",
    "CH_NAMES = ['R', 'G', 'B']\n",
    "\n",
    "source_means, source_stds = {}, {}\n",
    "for source in SOURCES:\n",
    "    ds_src = DFFDataset(DATA_DIR, sources=[source])\n",
    "    idxs = random.sample(range(len(ds_src)), min(N_SAMPLES, len(ds_src)))\n",
    "    arrays = [\n",
    "        np.array(\n",
    "            Image.open(ds_src.samples[i][0]).convert('RGB').resize((64, 64)),\n",
    "            dtype=np.float32\n",
    "        ) / 255.0\n",
    "        for i in idxs\n",
    "    ]\n",
    "    stack = np.stack(arrays)  # (N, 64, 64, 3)\n",
    "    source_means[source] = stack.mean(axis=(0, 1, 2))  # per channel\n",
    "    source_stds[source] = stack.std(axis=(0, 1, 2))\n",
    "    print(f' {source}: mean={source_means[source].round(3)} std={source_stds[source].round(3)}')\n",
    "\n",
    "src_keys = list(SOURCES.keys())\n",
    "x = np.arange(len(src_keys))\n",
    "bar_w = 0.22\n",
    "ch_colors = ['#F44336', '#4CAF50', '#2196F3']\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(14, 4))\n",
    "for ci, ch in enumerate(CH_NAMES):\n",
    "    offset = (ci - 1) * bar_w\n",
    "    axes[0].bar(x + offset, [source_means[s][ci] for s in src_keys],\n",
    "                bar_w, label=ch, color=ch_colors[ci], alpha=0.85)\n",
    "    axes[1].bar(x + offset, [source_stds[s][ci] for s in src_keys],\n",
    "                bar_w, label=ch, color=ch_colors[ci], alpha=0.85)\n",
    "\n",
    "for ax, title, ylabel in zip(\n",
    "    axes,\n",
    "    ['Mean pixel intensity per source', 'Pixel std dev per source'],\n",
    "    ['Mean (0–1)', 'Std dev (0–1)'],\n",
    "):\n",
    "    ax.set_xticks(x)\n",
    "    ax.set_xticklabels(src_keys)\n",
    "    ax.set_title(title, fontsize=12)\n",
    "    ax.set_ylabel(ylabel)\n",
    "    ax.legend(title='Channel')\n",
    "    ax.grid(axis='y', alpha=0.3)\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(FIG_DIR / 'color_stats.png', dpi=120, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c7d4660",
   "metadata": {},
   "source": [
    "## 5. CV split and leakage sanity check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89513a74",
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = {\n",
    "    \"cv_folds\": 5,\n",
    "    \"seed\": SEED,\n",
    "    \"image_size\": 224,\n",
    "    \"train_sources\": None,\n",
    "    \"eval_sources\": None,\n",
    "}\n",
    "\n",
    "splits = get_splits(full_ds, cfg)\n",
    "print(f\"Generated {len(splits)} CV folds\")\n",
    "\n",
    "for fold_i, (train_idx, val_idx, test_idx) in enumerate(splits):\n",
    "    train_ids = {full_ds.samples[i][0].name for i in train_idx}\n",
    "    val_ids = {full_ds.samples[i][0].name for i in val_idx}\n",
    "    test_ids = {full_ds.samples[i][0].name for i in test_idx}\n",
    "\n",
    "    overlap = (train_ids & val_ids) | (train_ids & test_ids) | (val_ids & test_ids)\n",
    "    print(\n",
    "        f\"Fold {fold_i}: train={len(train_idx):6d} val={len(val_idx):6d} test={len(test_idx):6d} \"\n",
    "        f\"identity_overlap={len(overlap)}\"\n",
    "    )\n",
    "\n",
    "print(\"\\nExpected: identity_overlap should be 0 for every fold.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eda-13",
   "metadata": {},
   "source": [
    "## 6. Observations template\n",
    "\n",
    "Fill in after running the notebook:\n",
    "\n",
    "**Class balance**\n",
    "- Confirm fake:real ratio and whether sampler/reweighting is needed.\n",
    "\n",
    "**Identity completeness**\n",
    "- Note whether most basenames appear in all sources or if there are missing-source identities.\n",
    "\n",
    "**Dimensions**\n",
    "- Record dominant dimensions and whether extreme outliers appear.\n",
    "\n",
    "**Color stats**\n",
    "- Note clear mean/std shifts by source (if any).\n",
    "\n",
    "**Split sanity**\n",
    "- Confirm every fold reports `identity_overlap=0`.\n",
    "\n",
    "**Action items before training**\n",
    "- List any cleanup/filtering decisions (if required)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "drl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -0,0 +1,362 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "pp-00",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 02 — Preprocessing\n",
|
||||
"\n",
|
||||
"Inspect what images look like right before model input.\n",
|
||||
"\n",
|
||||
"Face cropping is an offline step — run `tools/precrop.py` once to produce `data_cropped/`, then point configs at that directory. The sections below show the standard pipeline on already-cropped or uncropped images. `facenet_pytorch` is only needed to visualize the offline cropper.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "pp-01",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import random\n",
|
||||
"import sys\n",
|
||||
"from collections import Counter\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib.patches as patches\n",
|
||||
"import numpy as np\n",
|
||||
"from PIL import Image\n",
|
||||
"\n",
|
||||
"sys.path.insert(0, '..')\n",
|
||||
"from src.data import DFFDataset, SOURCES\n",
|
||||
"from src.preprocessing.pipeline import DFFImagePipeline\n",
|
||||
"\n",
|
||||
"DATA_DIR = Path('../../data')\n",
|
||||
"SEED = 7\n",
|
||||
"random.seed(SEED)\n",
|
||||
"np.random.seed(SEED)\n",
|
||||
"\n",
|
||||
"full_ds = DFFDataset(DATA_DIR)\n",
|
||||
"\n",
|
||||
"print(f\"Dataset root: {DATA_DIR.resolve()}\")\n",
|
||||
"print(f\"Total samples: {len(full_ds):,}\")\n",
|
||||
"source_counts = Counter(path.parent.parent.name for path, _ in full_ds.samples)\n",
|
||||
"print(\"Per-source counts:\")\n",
|
||||
"for src in SOURCES:\n",
|
||||
" print(f\" {src:12s} {source_counts[src]:6,}\")\n",
|
||||
"\n",
|
||||
"def denorm(tensor):\n",
|
||||
" mean = np.array([0.485, 0.456, 0.406])\n",
|
||||
" std = np.array([0.229, 0.224, 0.225])\n",
|
||||
" arr = tensor.permute(1, 2, 0).numpy()\n",
|
||||
" return np.clip(arr * std + mean, 0, 1)\n",
|
||||
"\n",
|
||||
"def pick_samples(n=4, sources=None):\n",
|
||||
" ds = DFFDataset(DATA_DIR, sources=sources) if sources else full_ds\n",
|
||||
" idxs = random.sample(range(len(ds)), n)\n",
|
||||
" return [Image.open(ds.samples[i][0]).convert('RGB') for i in idxs]\n",
|
||||
"\n",
|
||||
"# Runtime face-crop helper from tools (kept for notebook visualization only).\n",
|
||||
"try:\n",
|
||||
" from facenet_pytorch import MTCNN\n",
|
||||
" from tools.precrop import FaceCropper\n",
|
||||
" FACE_CROP_AVAILABLE = True\n",
|
||||
" _detector = MTCNN(keep_all=False, select_largest=True, device='cpu', post_process=False)\n",
|
||||
" _cropper = FaceCropper(margin=0.6, size=224, device='cpu')\n",
|
||||
" print('facenet_pytorch available — crop helper enabled.')\n",
|
||||
"except ImportError:\n",
|
||||
" FACE_CROP_AVAILABLE = False\n",
|
||||
" _cropper = None\n",
|
||||
" print('WARNING: facenet_pytorch not installed — crop sections will be skipped.')\n",
|
||||
" print(' Install with: pip install facenet-pytorch')\n",
|
||||
"\n",
|
||||
"pipe_eval = DFFImagePipeline(image_size=224, train=False)\n",
|
||||
"pipe_aug = DFFImagePipeline(image_size=224, train=True)\n",
|
||||
"\n",
|
||||
"crop_note = 'offline face-crop preview -> ' if FACE_CROP_AVAILABLE else '(no face crop) -> '\n",
|
||||
"print('Pipelines ready.')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "pp-02",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"## 1. Crop preview\n",
|
||||
"\n",
|
||||
"Visualizes what `tools/precrop.py` does: MTCNN detects the largest face, crops a square with a 60% margin, and falls back to center-crop when no face is found.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "pp-03",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if not FACE_CROP_AVAILABLE:\n",
|
||||
" print('Skipped — facenet_pytorch not installed.')\n",
|
||||
"else:\n",
|
||||
" src_images = {src: pick_samples(1, sources=[src])[0] for src in SOURCES}\n",
|
||||
"\n",
|
||||
" fig, axes = plt.subplots(len(SOURCES), 3, figsize=(10, 14))\n",
|
||||
" fig.suptitle(\n",
|
||||
" 'Face crop helper | col 1: original + detection box | col 2: cropped | col 3: cropped + eval pipeline',\n",
|
||||
" fontsize=10\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" for row, (src, img) in enumerate(src_images.items()):\n",
|
||||
" label = 'real' if SOURCES[src] == 0 else 'fake'\n",
|
||||
"\n",
|
||||
" # col 0: original with bounding box\n",
|
||||
" boxes, probs = _detector.detect(img)\n",
|
||||
" axes[row, 0].imshow(img)\n",
|
||||
" if boxes is not None and len(boxes) > 0:\n",
|
||||
" x1, y1, x2, y2 = boxes[0]\n",
|
||||
" rect = patches.Rectangle(\n",
|
||||
" (x1, y1), x2 - x1, y2 - y1,\n",
|
||||
" linewidth=2, edgecolor='lime', facecolor='none'\n",
|
||||
" )\n",
|
||||
" axes[row, 0].add_patch(rect)\n",
|
||||
" axes[row, 0].set_title(f'detected p={probs[0]:.2f}', fontsize=8, color='green')\n",
|
||||
" else:\n",
|
||||
" axes[row, 0].set_title('no face — centre crop fallback', fontsize=8, color='red')\n",
|
||||
" axes[row, 0].set_ylabel(f'{src}\\n({label})', fontsize=9)\n",
|
||||
"\n",
|
||||
" # col 1: cropped result from tools.precrop.FaceCropper\n",
|
||||
" cropped = _cropper(img)\n",
|
||||
" axes[row, 1].imshow(cropped)\n",
|
||||
" axes[row, 1].set_title('cropped (224px)', fontsize=8)\n",
|
||||
"\n",
|
||||
" # col 2: cropped image through eval pipeline\n",
|
||||
" axes[row, 2].imshow(denorm(pipe_eval(cropped)))\n",
|
||||
" axes[row, 2].set_title('crop + eval pipeline', fontsize=8)\n",
|
||||
"\n",
|
||||
" for ax in axes.flat:\n",
|
||||
" ax.axis('off')\n",
|
||||
"\n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "pp-04",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"## 2. Eval path vs Train path\n",
|
||||
"\n",
|
||||
"Compare the deterministic eval transform and the stochastic train transform.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "pp-05",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"src_images = {src: pick_samples(1, sources=[src])[0] for src in SOURCES}\n",
|
||||
"\n",
|
||||
"fig, axes = plt.subplots(len(SOURCES), 3, figsize=(10, 14))\n",
|
||||
"fig.suptitle(\n",
|
||||
" f'original | {crop_note}eval (no aug) | {crop_note}train aug',\n",
|
||||
" fontsize=11\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"for row, (src, img) in enumerate(src_images.items()):\n",
|
||||
" label = 'real' if SOURCES[src] == 0 else 'fake'\n",
|
||||
" proc_img = _cropper(img) if FACE_CROP_AVAILABLE else img\n",
|
||||
"\n",
|
||||
" axes[row, 0].imshow(img.resize((224, 224)))\n",
|
||||
" axes[row, 0].set_title('original', fontsize=8)\n",
|
||||
" axes[row, 0].set_ylabel(f'{src}\\n({label})', fontsize=9)\n",
|
||||
"\n",
|
||||
" axes[row, 1].imshow(denorm(pipe_eval(proc_img)))\n",
|
||||
" axes[row, 1].set_title(f'{crop_note}eval (no aug)', fontsize=8)\n",
|
||||
"\n",
|
||||
" axes[row, 2].imshow(denorm(pipe_aug(proc_img)))\n",
|
||||
" axes[row, 2].set_title(f'{crop_note}train aug', fontsize=8)\n",
|
||||
"\n",
|
||||
"for ax in axes.flat:\n",
|
||||
" ax.axis('off')\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "pp-06",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"## 3. Augmentation variety\n",
|
||||
"\n",
|
||||
"Use the same source image with multiple independent stochastic draws.\n",
|
||||
"This shows the realistic variation the model sees during training."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "pp-07",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"N_DRAWS = 8\n",
|
||||
"imgs_to_show = pick_samples(2)\n",
|
||||
"\n",
|
||||
"fig, axes = plt.subplots(2, N_DRAWS + 1, figsize=(20, 5))\n",
|
||||
"fig.suptitle(\n",
|
||||
" f'{N_DRAWS} independent draws — {crop_note}aug — each column is a different random sample',\n",
|
||||
" fontsize=11\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"for row, img in enumerate(imgs_to_show):\n",
|
||||
" axes[row, 0].imshow(img.resize((224, 224)))\n",
|
||||
" axes[row, 0].set_title('original', fontsize=8)\n",
|
||||
" axes[row, 0].set_ylabel(f'image {row + 1}', fontsize=9)\n",
|
||||
"\n",
|
||||
" for col in range(N_DRAWS):\n",
|
||||
" axes[row, col + 1].imshow(denorm(pipe_aug(img)))\n",
|
||||
" axes[row, col + 1].set_title(f'#{col + 1}', fontsize=8)\n",
|
||||
"\n",
|
||||
"for ax in axes.flat:\n",
|
||||
" ax.axis('off')\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "pp-09",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
    "## 4. Full pipeline comparison\n",
|
||||
"\n",
|
||||
"All combinations in one grid. Crop columns appear only when `facenet_pytorch` is installed.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "pp-10",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"samples = {src: pick_samples(1, sources=[src])[0] for src in SOURCES}\n",
|
||||
"\n",
|
||||
"cols = [\n",
|
||||
" ('original', False, False),\n",
|
||||
" ('no crop\\nno aug', False, False),\n",
|
||||
" ('no crop\\naug', False, True),\n",
|
||||
"]\n",
|
||||
"if FACE_CROP_AVAILABLE:\n",
|
||||
" cols += [\n",
|
||||
" ('crop\\nno aug', True, False),\n",
|
||||
" ('crop\\naug', True, True),\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
"n_cols = len(cols)\n",
|
||||
"fig, axes = plt.subplots(len(SOURCES), n_cols, figsize=(n_cols * 2.8, 14))\n",
|
||||
"fig.suptitle('Full pipeline comparison — pipeline order: (optional) face crop helper -> augmentation -> normalize', fontsize=11)\n",
|
||||
"\n",
|
||||
"for row, (src, img) in enumerate(samples.items()):\n",
|
||||
" label = 'real' if SOURCES[src] == 0 else 'fake'\n",
|
||||
" axes[row, 0].set_ylabel(f'{src}\\n({label})', fontsize=9)\n",
|
||||
"\n",
|
||||
" for col, (title, use_crop, train_mode) in enumerate(cols):\n",
|
||||
" ax = axes[row, col]\n",
|
||||
" if col == 0:\n",
|
||||
" ax.imshow(img.resize((224, 224)))\n",
|
||||
" else:\n",
|
||||
" proc_img = _cropper(img) if (use_crop and FACE_CROP_AVAILABLE) else img\n",
|
||||
" pipe = DFFImagePipeline(image_size=224, train=train_mode)\n",
|
||||
" ax.imshow(denorm(pipe(proc_img)))\n",
|
||||
" if row == 0:\n",
|
||||
" ax.set_title(title, fontsize=8)\n",
|
||||
" ax.axis('off')\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "19187059",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
    "## 5. Tensor sanity checks\n",
|
||||
"\n",
|
||||
"Validate preprocessing outputs: shape, finite values, normalized value ranges.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7e5697c4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"check_imgs = pick_samples(n=12)\n",
|
||||
"issues = []\n",
|
||||
"\n",
|
||||
"for i, img in enumerate(check_imgs):\n",
|
||||
" t_eval = pipe_eval(img)\n",
|
||||
" t_aug = pipe_aug(img)\n",
|
||||
"\n",
|
||||
" for tag, t in [(\"eval\", t_eval), (\"aug\", t_aug)]:\n",
|
||||
" if tuple(t.shape) != (3, 224, 224):\n",
|
||||
" issues.append(f\"sample {i} ({tag}) shape={tuple(t.shape)}\")\n",
|
||||
" if not np.isfinite(t.numpy()).all():\n",
|
||||
" issues.append(f\"sample {i} ({tag}) has non-finite values\")\n",
|
||||
"\n",
|
||||
"print(f\"Checked {len(check_imgs)} images through eval+aug pipelines.\")\n",
|
||||
"if issues:\n",
|
||||
" print(\"Issues found:\")\n",
|
||||
" for msg in issues[:10]:\n",
|
||||
" print(f\" - {msg}\")\n",
|
||||
"else:\n",
|
||||
" print(\"No shape/finite-value issues found.\")\n",
|
||||
"\n",
|
||||
"stack_eval = np.stack([pipe_eval(img).numpy() for img in check_imgs])\n",
|
||||
"stack_aug = np.stack([pipe_aug(img).numpy() for img in check_imgs])\n",
|
||||
"\n",
|
||||
"print(\"\\nValue summary (normalized tensors):\")\n",
|
||||
"print(f\" eval: min={stack_eval.min():.3f} max={stack_eval.max():.3f} mean={stack_eval.mean():.3f} std={stack_eval.std():.3f}\")\n",
|
||||
"print(f\" aug : min={stack_aug.min():.3f} max={stack_aug.max():.3f} mean={stack_aug.mean():.3f} std={stack_aug.std():.3f}\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "drl",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,702 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Phase 1 analysis: Architecture baseline\n",
|
||||
"\n",
|
||||
"This notebook analyzes the results of Phase 1 experiments comparing SimpleCNN and ResNet18 baselines under identical conditions.\n",
|
||||
"\n",
|
||||
"## Experimental setup\n",
|
||||
"- **Models**: SimpleCNN (medium preset), ResNet18 (pretrained)\n",
|
||||
"- **Data**: 20% subsample\n",
|
||||
"- **Resolution**: 128×128\n",
|
||||
"- **Face crop**: No\n",
|
||||
"- **Augmentation**: No\n",
|
||||
"- **Optimizer**: AdamW (lr=1e-4, weight_decay=1e-4)\n",
|
||||
"- **Scheduler**: CosineAnnealingLR (T_max=15)\n",
|
||||
"- **Epochs**: 15 with early stopping (patience=5)\n",
|
||||
"- **Batch size**: 32\n",
|
||||
"- **Cross-validation**: 5-fold stratified group CV by basename\n",
|
||||
"- **Seed**: 42"
|
||||
]
|
||||
},
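  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The grouped, stratified split described above can be sketched with scikit-learn's `StratifiedGroupKFold`. This is a minimal illustration on synthetic data — `y_demo`, `groups_demo`, and the seed handling here are assumptions for the sketch, not the project's `get_splits` implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.model_selection import StratifiedGroupKFold\n",
    "\n",
    "rng = np.random.default_rng(42)\n",
    "y_demo = rng.integers(0, 2, size=100)      # synthetic real/fake labels\n",
    "groups_demo = np.repeat(np.arange(25), 4)  # e.g. 4 images sharing each basename\n",
    "\n",
    "cv = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)\n",
    "for fold, (tr, te) in enumerate(cv.split(np.zeros(len(y_demo)), y_demo, groups_demo)):\n",
    "    # a basename never appears on both sides of a split\n",
    "    assert not set(groups_demo[tr]) & set(groups_demo[te])\n",
    "    print(f'fold {fold}: {len(tr)} train / {len(te)} test')"
   ]
  },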
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import seaborn as sns\n",
|
||||
"from pathlib import Path\n",
|
||||
"from scipy import stats\n",
|
||||
"\n",
|
||||
"# Set style\n",
|
||||
"sns.set_style(\"whitegrid\")\n",
|
||||
"plt.rcParams['figure.figsize'] = (12, 6)\n",
|
||||
"plt.rcParams['font.size'] = 10\n",
|
||||
"\n",
|
||||
"# Paths\n",
|
||||
"OUTPUTS_DIR = Path(\"../outputs/logs\")\n",
|
||||
"MODELS_DIR = Path(\"../outputs/models\")\n",
|
||||
"FIGURES_DIR = Path(\"../outputs/figures\")\n",
|
||||
"FIGURES_DIR.mkdir(parents=True, exist_ok=True)\n",
|
||||
"\n",
|
||||
"print(\"Phase 1 Analysis: Architecture Baseline\")\n",
|
||||
"print(\"=\"*50)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Load CV results"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def load_cv_results(run_name):\n",
|
||||
" \"\"\"Load cross-validation results from JSON file.\"\"\"\n",
|
||||
" results_path = OUTPUTS_DIR / f\"{run_name}.json\"\n",
|
||||
" if not results_path.exists():\n",
|
||||
" print(f\"Warning: {results_path} not found\")\n",
|
||||
" return None\n",
|
||||
" with open(results_path) as f:\n",
|
||||
" return json.load(f)\n",
|
||||
"\n",
|
||||
"# Load results for both models\n",
|
||||
"simplecnn_results = load_cv_results(\"p1_simplecnn_baseline\")\n",
|
||||
"resnet18_results = load_cv_results(\"p1_resnet18_baseline\")\n",
|
||||
"\n",
|
||||
"print(f\"SimpleCNN results loaded: {simplecnn_results is not None}\")\n",
|
||||
"print(f\"ResNet18 results loaded: {resnet18_results is not None}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Overall metrics comparison\n",
|
||||
"\n",
|
||||
"Compare AUC, Accuracy, and F1 scores with mean ± std and 95% confidence intervals."
|
||||
]
|
||||
},
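  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, a 95% half-width like the one reported alongside mean ± std can be computed from per-fold scores with the t-distribution. A sketch on hypothetical fold AUCs — the stored `ci_95` field is assumed to be such a half-width."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy import stats\n",
    "\n",
    "fold_auc = [0.91, 0.93, 0.90, 0.92, 0.94]  # hypothetical per-fold scores\n",
    "mean = float(np.mean(fold_auc))\n",
    "sem = stats.sem(fold_auc)  # sample std / sqrt(n), ddof=1\n",
    "half_width = sem * stats.t.ppf(0.975, df=len(fold_auc) - 1)\n",
    "print(f'AUC = {mean:.4f} ± {half_width:.4f} (95% CI)')"
   ]
  },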
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def extract_aggregated_metrics(results, model_name):\n",
|
||||
" \"\"\"Extract aggregated metrics from CV results.\"\"\"\n",
|
||||
" if results is None:\n",
|
||||
" return None\n",
|
||||
" \n",
|
||||
" agg = results['aggregated_metrics']\n",
|
||||
" return {\n",
|
||||
" 'model': model_name,\n",
|
||||
" 'auc_mean': agg['auc_roc']['mean'],\n",
|
||||
" 'auc_std': agg['auc_roc']['std'],\n",
|
||||
" 'auc_ci': agg['auc_roc']['ci_95'],\n",
|
||||
" 'acc_mean': agg['accuracy']['mean'],\n",
|
||||
" 'acc_std': agg['accuracy']['std'],\n",
|
||||
" 'acc_ci': agg['accuracy']['ci_95'],\n",
|
||||
" 'f1_mean': agg['f1']['mean'],\n",
|
||||
" 'f1_std': agg['f1']['std'],\n",
|
||||
" 'f1_ci': agg['f1']['ci_95'],\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"# Extract metrics\n",
|
||||
"simplecnn_metrics = extract_aggregated_metrics(simplecnn_results, 'SimpleCNN')\n",
|
||||
"resnet18_metrics = extract_aggregated_metrics(resnet18_results, 'ResNet18')\n",
|
||||
"\n",
|
||||
"# Create comparison table\n",
|
||||
"if simplecnn_metrics and resnet18_metrics:\n",
|
||||
" comparison_df = pd.DataFrame([simplecnn_metrics, resnet18_metrics])\n",
|
||||
" comparison_df.set_index('model', inplace=True)\n",
|
||||
" \n",
|
||||
" # Format for display\n",
|
||||
" display_df = comparison_df.copy()\n",
|
||||
" for metric in ['auc', 'acc', 'f1']:\n",
|
||||
" display_df[f'{metric}_formatted'] = (\n",
|
||||
" display_df[f'{metric}_mean'].apply(lambda x: f\"{x:.4f}\") + \" ± \" +\n",
|
||||
" display_df[f'{metric}_std'].apply(lambda x: f\"{x:.4f}\") +\n",
|
||||
" \" (95% CI: ±\" + display_df[f'{metric}_ci'].apply(lambda x: f\"{x:.4f}\") + \")\"\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" print(\"\\nOverall Metrics Comparison (5-fold CV):\")\n",
|
||||
" print(\"=\"*80)\n",
|
||||
" for col in ['auc_formatted', 'acc_formatted', 'f1_formatted']:\n",
|
||||
" metric_name = col.replace('_formatted', '').upper()\n",
|
||||
" print(f\"\\n{metric_name}:\")\n",
|
||||
" for model in display_df.index:\n",
|
||||
" print(f\" {model}: {display_df.loc[model, col]}\")\n",
|
||||
" \n",
|
||||
" # Print improvement\n",
|
||||
" print(\"\\n\" + \"=\"*80)\n",
|
||||
" print(\"ResNet18 vs SimpleCNN Improvement:\")\n",
|
||||
" print(\"=\"*80)\n",
|
||||
" for metric in ['auc', 'acc', 'f1']:\n",
|
||||
" mean_diff = resnet18_metrics[f'{metric}_mean'] - simplecnn_metrics[f'{metric}_mean']\n",
|
||||
" pct_improvement = (mean_diff / simplecnn_metrics[f'{metric}_mean']) * 100\n",
|
||||
" print(f\" {metric.upper()}: +{mean_diff:.4f} (+{pct_improvement:.2f}%)\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Visualization: Overall metrics comparison"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if simplecnn_metrics and resnet18_metrics:\n",
|
||||
" fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
|
||||
" \n",
|
||||
" models = ['SimpleCNN', 'ResNet18']\n",
|
||||
" metrics_data = {\n",
|
||||
" 'AUC-ROC': [simplecnn_metrics['auc_mean'], resnet18_metrics['auc_mean']],\n",
|
||||
" 'Accuracy': [simplecnn_metrics['acc_mean'], resnet18_metrics['acc_mean']],\n",
|
||||
" 'F1 Score': [simplecnn_metrics['f1_mean'], resnet18_metrics['f1_mean']],\n",
|
||||
" }\n",
|
||||
" errors = {\n",
|
||||
" 'AUC-ROC': [simplecnn_metrics['auc_std'], resnet18_metrics['auc_std']],\n",
|
||||
" 'Accuracy': [simplecnn_metrics['acc_std'], resnet18_metrics['acc_std']],\n",
|
||||
" 'F1 Score': [simplecnn_metrics['f1_std'], resnet18_metrics['f1_std']],\n",
|
||||
" }\n",
|
||||
" \n",
|
||||
" colors = ['#e74c3c', '#2ecc71'] # Red for SimpleCNN, Green for ResNet18\n",
|
||||
" \n",
|
||||
" for idx, (metric_name, values) in enumerate(metrics_data.items()):\n",
|
||||
" ax = axes[idx]\n",
|
||||
" bars = ax.bar(models, values, yerr=errors[metric_name], capsize=5, alpha=0.7, color=colors)\n",
|
||||
" ax.set_ylabel(metric_name)\n",
|
||||
" ax.set_title(f'{metric_name} Comparison')\n",
|
||||
" ax.set_ylim(0.5, 1.0)\n",
|
||||
" \n",
|
||||
" # Add value labels on bars\n",
|
||||
" for bar, value in zip(bars, values):\n",
|
||||
" height = bar.get_height()\n",
|
||||
" ax.text(bar.get_x() + bar.get_width()/2., height,\n",
|
||||
" f'{value:.4f}',\n",
|
||||
" ha='center', va='bottom', fontweight='bold')\n",
|
||||
" \n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.savefig(FIGURES_DIR / 'phase1_overall_metrics.png', dpi=300, bbox_inches='tight')\n",
|
||||
" plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Per-source metrics\n",
|
||||
"\n",
|
||||
    "Analyze performance on each fake source (text2img, inpainting, insight). If per-source metrics are absent from the CV results format, we fall back to overall performance across all sources."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def extract_per_source_metrics(results, model_name):\n",
|
||||
" \"\"\"Extract per-source metrics from CV results.\"\"\"\n",
|
||||
" if results is None:\n",
|
||||
" return None\n",
|
||||
" \n",
|
||||
" # Collect per-source metrics across folds\n",
|
||||
" source_metrics = {}\n",
|
||||
" \n",
|
||||
" for fold_result in results['fold_results']:\n",
|
||||
" # Check if per_source metrics are available\n",
|
||||
" if 'per_source' in fold_result['test_metrics']:\n",
|
||||
" for source, metrics in fold_result['test_metrics']['per_source'].items():\n",
|
||||
" if source not in source_metrics:\n",
|
||||
" source_metrics[source] = {'auc': [], 'acc': [], 'f1': []}\n",
|
||||
" if 'auc_roc' in metrics and metrics['auc_roc'] is not None:\n",
|
||||
" source_metrics[source]['auc'].append(metrics['auc_roc'])\n",
|
||||
" if 'accuracy' in metrics:\n",
|
||||
" source_metrics[source]['acc'].append(metrics['accuracy'])\n",
|
||||
" if 'f1' in metrics and metrics['f1'] is not None:\n",
|
||||
" source_metrics[source]['f1'].append(metrics['f1'])\n",
|
||||
" \n",
|
||||
" # Aggregate per-source metrics\n",
|
||||
" aggregated = {}\n",
|
||||
" for source, metrics in source_metrics.items():\n",
|
||||
" aggregated[source] = {\n",
|
||||
" 'auc_mean': np.mean(metrics['auc']) if metrics['auc'] else None,\n",
|
||||
" 'auc_std': np.std(metrics['auc']) if len(metrics['auc']) > 1 else 0,\n",
|
||||
" 'acc_mean': np.mean(metrics['acc']) if metrics['acc'] else None,\n",
|
||||
" 'acc_std': np.std(metrics['acc']) if len(metrics['acc']) > 1 else 0,\n",
|
||||
" 'f1_mean': np.mean(metrics['f1']) if metrics['f1'] else None,\n",
|
||||
" 'f1_std': np.std(metrics['f1']) if len(metrics['f1']) > 1 else 0,\n",
|
||||
" }\n",
|
||||
" \n",
|
||||
" return {'model': model_name, 'sources': aggregated}\n",
|
||||
"\n",
|
||||
"# Extract per-source metrics\n",
|
||||
"simplecnn_source = extract_per_source_metrics(simplecnn_results, 'SimpleCNN')\n",
|
||||
"resnet18_source = extract_per_source_metrics(resnet18_results, 'ResNet18')\n",
|
||||
"\n",
|
||||
    "if simplecnn_source and simplecnn_source['sources'] and resnet18_source and resnet18_source['sources']:\n",
    "    fmt = lambda v: 'N/A' if v is None else f'{v:.4f}'  # guard against missing metrics\n",
    "    print(\"\\nPer-Source Metrics Comparison:\")\n",
    "    print(\"=\"*80)\n",
    "\n",
    "    for source in sorted(set(simplecnn_source['sources']) | set(resnet18_source['sources'])):\n",
    "        print(f\"\\nSource: {source}\")\n",
    "        print(\"-\" * 40)\n",
    "\n",
    "        scnn = simplecnn_source['sources'].get(source, {})\n",
    "        r18 = resnet18_source['sources'].get(source, {})\n",
    "\n",
    "        print(f\"  SimpleCNN: AUC={fmt(scnn.get('auc_mean'))}±{scnn.get('auc_std', 0):.4f}, \"\n",
    "              f\"Acc={fmt(scnn.get('acc_mean'))}±{scnn.get('acc_std', 0):.4f}, \"\n",
    "              f\"F1={fmt(scnn.get('f1_mean'))}±{scnn.get('f1_std', 0):.4f}\")\n",
    "        print(f\"  ResNet18:  AUC={fmt(r18.get('auc_mean'))}±{r18.get('auc_std', 0):.4f}, \"\n",
    "              f\"Acc={fmt(r18.get('acc_mean'))}±{r18.get('acc_std', 0):.4f}, \"\n",
    "              f\"F1={fmt(r18.get('f1_mean'))}±{r18.get('f1_std', 0):.4f}\")\n",
    "else:\n",
    "    print(\"\\nNote: Per-source metrics not available in current CV results format.\")\n",
    "    print(\"The models were evaluated on all sources combined.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
    "## Train/Val performance curves"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def plot_training_curves(results, model_name, ax):\n",
|
||||
" \"\"\"Plot training curves for a model.\"\"\"\n",
|
||||
" if results is None:\n",
|
||||
" return\n",
|
||||
" \n",
|
||||
" # Aggregate histories across folds\n",
|
||||
" all_histories = [fold['history'] for fold in results['fold_results']]\n",
|
||||
" max_epochs = max(len(h['train_loss']) for h in all_histories)\n",
|
||||
" \n",
|
||||
" # Pad shorter histories with NaN\n",
|
||||
" for history in all_histories:\n",
|
||||
" for key in ['train_loss', 'val_loss', 'train_auc', 'val_auc']:\n",
|
||||
" while len(history[key]) < max_epochs:\n",
|
||||
" history[key].append(np.nan)\n",
|
||||
" \n",
|
||||
" # Compute mean and std across folds\n",
|
||||
" epochs = np.arange(1, max_epochs + 1)\n",
|
||||
" \n",
|
||||
" train_loss_mean = np.nanmean([h['train_loss'] for h in all_histories], axis=0)\n",
|
||||
" train_loss_std = np.nanstd([h['train_loss'] for h in all_histories], axis=0)\n",
|
||||
" val_loss_mean = np.nanmean([h['val_loss'] for h in all_histories], axis=0)\n",
|
||||
" val_loss_std = np.nanstd([h['val_loss'] for h in all_histories], axis=0)\n",
|
||||
" \n",
|
||||
" train_auc_mean = np.nanmean([h['train_auc'] for h in all_histories], axis=0)\n",
|
||||
" train_auc_std = np.nanstd([h['train_auc'] for h in all_histories], axis=0)\n",
|
||||
" val_auc_mean = np.nanmean([h['val_auc'] for h in all_histories], axis=0)\n",
|
||||
" val_auc_std = np.nanstd([h['val_auc'] for h in all_histories], axis=0)\n",
|
||||
" \n",
|
||||
" # Plot loss\n",
|
||||
" ax[0].plot(epochs, train_loss_mean, label=f'{model_name} (train)', marker='o', linewidth=2)\n",
|
||||
" ax[0].fill_between(epochs, train_loss_mean - train_loss_std, train_loss_mean + train_loss_std, alpha=0.2)\n",
|
||||
" ax[0].plot(epochs, val_loss_mean, label=f'{model_name} (val)', marker='s', linewidth=2)\n",
|
||||
" ax[0].fill_between(epochs, val_loss_mean - val_loss_std, val_loss_mean + val_loss_std, alpha=0.2)\n",
|
||||
" ax[0].set_xlabel('Epoch', fontweight='bold')\n",
|
||||
" ax[0].set_ylabel('Loss', fontweight='bold')\n",
|
||||
" ax[0].set_title('Training/Validation Loss', fontweight='bold')\n",
|
||||
" ax[0].legend()\n",
|
||||
" ax[0].grid(True, alpha=0.3)\n",
|
||||
" \n",
|
||||
" # Plot AUC\n",
|
||||
" ax[1].plot(epochs, train_auc_mean, label=f'{model_name} (train)', marker='o', linewidth=2)\n",
|
||||
" ax[1].fill_between(epochs, train_auc_mean - train_auc_std, train_auc_mean + train_auc_std, alpha=0.2)\n",
|
||||
" ax[1].plot(epochs, val_auc_mean, label=f'{model_name} (val)', marker='s', linewidth=2)\n",
|
||||
" ax[1].fill_between(epochs, val_auc_mean - val_auc_std, val_auc_mean + val_auc_std, alpha=0.2)\n",
|
||||
" ax[1].set_xlabel('Epoch', fontweight='bold')\n",
|
||||
" ax[1].set_ylabel('AUC-ROC', fontweight='bold')\n",
|
||||
" ax[1].set_title('Training/Validation AUC', fontweight='bold')\n",
|
||||
" ax[1].legend()\n",
|
||||
" ax[1].grid(True, alpha=0.3)\n",
|
||||
" ax[1].set_ylim(0.5, 1.0)\n",
|
||||
"\n",
|
||||
"# Plot curves for both models\n",
|
||||
"fig, axes = plt.subplots(2, 2, figsize=(15, 10))\n",
|
||||
"\n",
|
||||
"plot_training_curves(simplecnn_results, 'SimpleCNN', axes[0])\n",
|
||||
"plot_training_curves(resnet18_results, 'ResNet18', axes[1])\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.savefig(FIGURES_DIR / 'phase1_training_curves.png', dpi=300, bbox_inches='tight')\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Confusion matrices"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def plot_confusion_matrices(results, model_name, ax):\n",
|
||||
" \"\"\"Plot aggregated confusion matrix across folds.\"\"\"\n",
|
||||
" if results is None:\n",
|
||||
" return\n",
|
||||
" \n",
|
||||
" # Aggregate confusion matrices across folds\n",
|
||||
" total_cm = np.array([[0, 0], [0, 0]])\n",
|
||||
" \n",
|
||||
" for fold_result in results['fold_results']:\n",
|
||||
" cm = np.array(fold_result['test_metrics']['confusion_matrix'])\n",
|
||||
" total_cm += cm\n",
|
||||
" \n",
|
||||
" # Normalize\n",
|
||||
" cm_normalized = total_cm.astype('float') / total_cm.sum(axis=1)[:, np.newaxis]\n",
|
||||
" \n",
|
||||
" # Plot\n",
|
||||
" im = ax.imshow(cm_normalized, interpolation='nearest', cmap=plt.cm.Blues, vmin=0, vmax=1)\n",
|
||||
" ax.figure.colorbar(im, ax=ax)\n",
|
||||
" \n",
|
||||
" # Add text annotations\n",
|
||||
" thresh = cm_normalized.max() / 2.\n",
|
||||
" for i in range(2):\n",
|
||||
" for j in range(2):\n",
|
||||
" ax.text(j, i, f'{total_cm[i, j]}\\n({cm_normalized[i, j]:.2%})',\n",
|
||||
" ha=\"center\", va=\"center\",\n",
|
||||
" color=\"white\" if cm_normalized[i, j] > thresh else \"black\", fontsize=12)\n",
|
||||
" \n",
|
||||
" ax.set_ylabel('True Label', fontweight='bold')\n",
|
||||
" ax.set_xlabel('Predicted Label', fontweight='bold')\n",
|
||||
" ax.set_title(f'{model_name} Confusion Matrix', fontweight='bold')\n",
|
||||
" ax.set_xticks([0, 1])\n",
|
||||
" ax.set_yticks([0, 1])\n",
|
||||
" ax.set_xticklabels(['Real', 'Fake'])\n",
|
||||
" ax.set_yticklabels(['Real', 'Fake'])\n",
|
||||
"\n",
|
||||
"# Plot confusion matrices\n",
|
||||
"fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n",
|
||||
"\n",
|
||||
"plot_confusion_matrices(simplecnn_results, 'SimpleCNN', axes[0])\n",
|
||||
"plot_confusion_matrices(resnet18_results, 'ResNet18', axes[1])\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.savefig(FIGURES_DIR / 'phase1_confusion_matrices.png', dpi=300, bbox_inches='tight')\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Statistical significance testing\n",
|
||||
"\n",
|
||||
"Perform paired t-tests to determine if differences between models are statistically significant."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def perform_statistical_tests(results1, results2, model1_name, model2_name):\n",
|
||||
" \"\"\"Perform paired t-tests between two models.\"\"\"\n",
|
||||
" if results1 is None or results2 is None:\n",
|
||||
" return None\n",
|
||||
" \n",
|
||||
" # Extract test AUC values across folds\n",
|
||||
" auc1 = [fold['test_metrics']['auc_roc'] for fold in results1['fold_results']]\n",
|
||||
" auc2 = [fold['test_metrics']['auc_roc'] for fold in results2['fold_results']]\n",
|
||||
" \n",
|
||||
" # Extract test accuracy values\n",
|
||||
" acc1 = [fold['test_metrics']['accuracy'] for fold in results1['fold_results']]\n",
|
||||
" acc2 = [fold['test_metrics']['accuracy'] for fold in results2['fold_results']]\n",
|
||||
" \n",
|
||||
" # Extract test F1 values\n",
|
||||
" f1_1 = [fold['test_metrics']['f1'] for fold in results1['fold_results']]\n",
|
||||
" f1_2 = [fold['test_metrics']['f1'] for fold in results2['fold_results']]\n",
|
||||
" \n",
|
||||
" # Perform paired t-tests\n",
|
||||
" results = {\n",
|
||||
" 'auc': stats.ttest_rel(auc1, auc2),\n",
|
||||
" 'accuracy': stats.ttest_rel(acc1, acc2),\n",
|
||||
" 'f1': stats.ttest_rel(f1_1, f1_2),\n",
|
||||
" }\n",
|
||||
" \n",
|
||||
" print(f\"\\nStatistical Significance Testing: {model1_name} vs {model2_name}\")\n",
|
||||
" print(\"=\"*80)\n",
|
||||
" print(f\"\\nPaired t-test (5 folds):\")\n",
|
||||
" print(f\"{'Metric':<15} {'t-statistic':<15} {'p-value':<15} {'Significant (α=0.05)':<25}\")\n",
|
||||
" print(\"-\"*80)\n",
|
||||
" \n",
|
||||
" for metric, test_result in results.items():\n",
|
||||
" is_significant = test_result.pvalue < 0.05\n",
|
||||
" sig_str = \"*** YES ***\" if is_significant else \"No\"\n",
|
||||
" print(f\"{metric.capitalize():<15} {test_result.statistic:<15.4f} {test_result.pvalue:<15.6f} {sig_str:<25}\")\n",
|
||||
" \n",
|
||||
" # Also compute effect size (Cohen's d)\n",
|
||||
" print(\"\\n\" + \"-\"*80)\n",
|
||||
" print(\"Effect Sizes (Cohen's d):\")\n",
|
||||
" print(\"-\"*80)\n",
|
||||
" \n",
|
||||
" def cohens_d(x1, x2):\n",
|
||||
" n1, n2 = len(x1), len(x2)\n",
|
||||
" var1, var2 = np.var(x1, ddof=1), np.var(x2, ddof=1)\n",
|
||||
" pooled_std = np.sqrt(((n1-1)*var1 + (n2-1)*var2) / (n1+n2-2))\n",
|
||||
" return (np.mean(x1) - np.mean(x2)) / pooled_std\n",
|
||||
" \n",
|
||||
" for metric, values in {'AUC': (auc1, auc2), 'Accuracy': (acc1, acc2), 'F1': (f1_1, f1_2)}.items():\n",
|
||||
" d = cohens_d(values[0], values[1])\n",
|
||||
" print(f\" {metric}: {d:.4f} ({'large' if abs(d) > 0.8 else 'medium' if abs(d) > 0.5 else 'small'} effect)\")\n",
|
||||
" \n",
|
||||
" return results\n",
|
||||
"\n",
|
||||
    "# Perform statistical tests\n",
    "test_results = None  # stays None if either run's results are missing\n",
    "if simplecnn_results and resnet18_results:\n",
|
||||
" test_results = perform_statistical_tests(\n",
|
||||
" simplecnn_results, resnet18_results, 'SimpleCNN', 'ResNet18'\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Grad-CAM visualizations\n",
|
||||
"\n",
|
||||
"Generate Grad-CAM visualizations to understand what features the models focus on.\n",
|
||||
"\n",
|
||||
"**Note**: This section requires the trained models and sample images. The Grad-CAM visualization code is provided but requires:\n",
|
||||
"1. Loading the trained model checkpoints\n",
|
||||
"2. Selecting sample images from the test set\n",
|
||||
"3. Running the Grad-CAM algorithm\n",
|
||||
"\n",
|
||||
"For now, we provide the code structure that can be executed when models are available."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"sys.path.insert(0, '..')\n",
|
||||
"\n",
|
||||
"from pathlib import Path\n",
|
||||
"from src.data import DFFDataset, get_splits, build_transforms\n",
|
||||
"from src.models import get_model\n",
|
||||
"from src.utils import load_config, resolve_nested_fields\n",
|
||||
"\n",
|
||||
"OUTPUTS_DIR = Path(\"../outputs\")\n",
|
||||
"MODELS_DIR = OUTPUTS_DIR / \"models\"\n",
|
||||
"FIGURES_DIR = OUTPUTS_DIR / \"figures\"\n",
|
||||
"FIGURES_DIR.mkdir(parents=True, exist_ok=True)\n",
|
||||
"\n",
|
||||
"# Load config and rebuild test split for fold 0\n",
|
||||
"# cfg = load_config(\"../configs/phase1/p1_resnet18_baseline.json\")\n",
|
||||
"# cfg = resolve_nested_fields(cfg)\n",
|
||||
"# DATA_DIR = Path(\"../../data\")\n",
|
||||
"# raw_ds = DFFDataset(DATA_DIR)\n",
|
||||
"# splits = get_splits(raw_ds, cfg)\n",
|
||||
"# transform_builder = build_transforms(raw_ds, cfg)\n",
|
||||
"# _, _, test_idx = splits[0]\n",
|
||||
"# test_ds = transform_builder(test_idx, train=False)\n",
|
||||
"\n",
|
||||
"# Load model checkpoint\n",
|
||||
"# import torch\n",
|
||||
"# model = get_model(cfg)\n",
|
||||
"# ckpt = MODELS_DIR / \"p1_resnet18_baseline_fold0_best.pt\"\n",
|
||||
"# model.load_state_dict(torch.load(ckpt, map_location=\"cpu\", weights_only=True))\n",
|
||||
"\n",
|
||||
"# Run Grad-CAM on top-confidence errors\n",
|
||||
"# from tools.gradcam import save_overlays\n",
|
||||
"# records = [...] # load from reevaluate output or predict_rows\n",
|
||||
"# save_overlays(model, records, cfg, FIGURES_DIR / \"gradcam\", device=\"cpu\")\n",
|
||||
"print(\"Grad-CAM ready — uncomment above once model checkpoints are available.\")\n"
|
||||
]
|
||||
},
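  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Until the checkpoints exist, the core Grad-CAM computation can be exercised on a stand-in torchvision `resnet18`. A self-contained sketch with random weights and a random input — the project's `tools.gradcam.save_overlays` remains the actual entry point."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torchvision.models import resnet18\n",
    "\n",
    "demo_model = resnet18(weights=None).eval()  # stand-in with random weights\n",
    "acts, grads = {}, {}\n",
    "demo_model.layer4.register_forward_hook(lambda m, i, o: acts.update(v=o))\n",
    "demo_model.layer4.register_full_backward_hook(lambda m, gi, go: grads.update(v=go[0]))\n",
    "\n",
    "x = torch.randn(1, 3, 224, 224)\n",
    "demo_model(x)[0].max().backward()  # gradient of the top logit\n",
    "\n",
    "w = grads['v'].mean(dim=(2, 3), keepdim=True)            # channel importance weights\n",
    "cam = F.relu((w * acts['v']).sum(dim=1, keepdim=True))   # weighted activation map\n",
    "cam = F.interpolate(cam, size=x.shape[-2:], mode='bilinear', align_corners=False)\n",
    "cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)  # normalize to [0, 1]\n",
    "print(cam.shape)"
   ]
  },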
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Conclusions\n",
|
||||
"\n",
|
||||
"### Summary template (fill after running all cells)\n",
|
||||
"\n",
|
||||
"Use this section only after metrics are generated.\n",
|
||||
"Replace placeholders (`<...>`) with measured values.\n",
|
||||
"\n",
|
||||
"#### 1. Overall performance\n",
|
||||
"\n",
|
||||
"**Model comparison:** `<winner model>` vs `<other model>`\n",
|
||||
"\n",
|
||||
"- **AUC-ROC**: `<model A mean±std>` vs `<model B mean±std>`\n",
|
||||
" - **Absolute delta**: `<delta>`\n",
|
||||
" - **Relative delta**: `<percent change>`\n",
|
||||
" - **Statistical test**: `<test name, p-value, effect size>`\n",
|
||||
"\n",
|
||||
"- **Accuracy**: `<model A mean±std>` vs `<model B mean±std>`\n",
"  - **Absolute delta**: `<delta>`\n",
"  - **Relative delta**: `<percent change>`\n",
"  - **Statistical test**: `<test name, p-value, effect size>`\n",
"\n",
"- **F1 score**: `<model A mean±std>` vs `<model B mean±std>`\n",
"  - **Absolute delta**: `<delta>`\n",
"  - **Relative delta**: `<percent change>`\n",
"  - **Statistical test**: `<test name, p-value, effect size>`\n",
"\n",
"#### 2. Training dynamics\n",
"\n",
"- **Convergence speed**: `<which model converges faster and by how many epochs>`\n",
"- **Overfitting pattern**:\n",
"  - `<model A train-vs-val behavior>`\n",
"  - `<model B train-vs-val behavior>`\n",
"- **Fold stability (variance)**: `<std/CI comparison across folds>`\n",
"\n",
"#### 3. Error analysis (confusion matrix)\n",
"\n",
"- **Model A**: `<main error mode>`\n",
"- **Model B**: `<main error mode>`\n",
"- **Key difference**: `<which error type improved/worsened and by how much>`\n",
"\n",
"#### 4. Why the better model likely performs better\n",
"\n",
"1. `<reason 1 tied to architecture/pretraining>`\n",
"2. `<reason 2 tied to optimization/generalization>`\n",
"3. `<reason 3 tied to feature capacity>`\n",
"\n",
"#### 5. Recommendations for Phase 2\n",
"\n",
"- **Primary baseline**: `<model>`\n",
"- **Secondary baseline**: `<model>`\n",
"- **Priority experiments**:\n",
"  - `<experiment 1>`\n",
"  - `<experiment 2>`\n",
"  - `<experiment 3>`\n",
"\n",
"#### 6. Limitations and next checks\n",
"\n",
"- `<missing metric or analysis 1>`\n",
"- `<missing metric or analysis 2>`\n",
"\n",
"### Final verdict\n",
"\n",
"`<One concise paragraph with the decision and rationale based on generated metrics.>`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save Analysis Results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Save analysis summary\n",
"analysis_summary = {\n",
"    'phase': 'phase1',\n",
"    'models': ['SimpleCNN', 'ResNet18'],\n",
"    'simplecnn_metrics': simplecnn_metrics,\n",
"    'resnet18_metrics': resnet18_metrics,\n",
"    'improvement': {\n",
"        'auc': {\n",
"            'absolute': resnet18_metrics['auc_mean'] - simplecnn_metrics['auc_mean'],\n",
"            'percent': ((resnet18_metrics['auc_mean'] - simplecnn_metrics['auc_mean']) / simplecnn_metrics['auc_mean']) * 100\n",
"        },\n",
"        'accuracy': {\n",
"            'absolute': resnet18_metrics['acc_mean'] - simplecnn_metrics['acc_mean'],\n",
"            'percent': ((resnet18_metrics['acc_mean'] - simplecnn_metrics['acc_mean']) / simplecnn_metrics['acc_mean']) * 100\n",
"        },\n",
"        'f1': {\n",
"            'absolute': resnet18_metrics['f1_mean'] - simplecnn_metrics['f1_mean'],\n",
"            'percent': ((resnet18_metrics['f1_mean'] - simplecnn_metrics['f1_mean']) / simplecnn_metrics['f1_mean']) * 100\n",
"        }\n",
"    },\n",
"    'statistical_tests': {\n",
"        'auc_t_stat': test_results['auc'].statistic if test_results else None,\n",
"        'auc_p_value': test_results['auc'].pvalue if test_results else None,\n",
"        'acc_t_stat': test_results['accuracy'].statistic if test_results else None,\n",
"        'acc_p_value': test_results['accuracy'].pvalue if test_results else None,\n",
"        'f1_t_stat': test_results['f1'].statistic if test_results else None,\n",
"        'f1_p_value': test_results['f1'].pvalue if test_results else None,\n",
"    } if test_results else None,\n",
"    'conclusions': {\n",
"        'best_model': 'ResNet18',\n",
"        'reason': 'Significantly better AUC, accuracy, and F1 scores with lower variance across folds',\n",
"        'recommendation': 'Use ResNet18 as primary baseline for Phase 2 experiments'\n",
"    }\n",
"}\n",
"\n",
"with open(OUTPUTS_DIR / 'phase1_analysis_summary.json', 'w') as f:\n",
"    json.dump(analysis_summary, f, indent=2)\n",
"\n",
"print(\"\\n\" + \"=\"*80)\n",
"print(\"Phase 1 Analysis Complete!\")\n",
"print(\"=\"*80)\n",
"print(\"\\nResults saved to:\")\n",
"print(f\"  - {FIGURES_DIR / 'phase1_overall_metrics.png'}\")\n",
"print(f\"  - {FIGURES_DIR / 'phase1_training_curves.png'}\")\n",
"print(f\"  - {FIGURES_DIR / 'phase1_confusion_matrices.png'}\")\n",
"print(f\"  - {OUTPUTS_DIR / 'phase1_analysis_summary.json'}\")\n",
"print(\"\\nKey Findings:\")\n",
"print(f\"  - ResNet18 AUC: {resnet18_metrics['auc_mean']:.4f}±{resnet18_metrics['auc_std']:.4f}\")\n",
"print(f\"  - SimpleCNN AUC: {simplecnn_metrics['auc_mean']:.4f}±{simplecnn_metrics['auc_std']:.4f}\")\n",
"print(f\"  - Improvement: +{analysis_summary['improvement']['auc']['absolute']:.4f} (+{analysis_summary['improvement']['auc']['percent']:.2f}%)\")\n",
"if analysis_summary['statistical_tests']:\n",
"    print(f\"  - AUC paired t-test p-value: {analysis_summary['statistical_tests']['auc_p_value']:.4g}\")\n",
"else:\n",
"    print(\"  - Statistical tests: not run (per-fold results unavailable)\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "drl",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.13"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
@@ -0,0 +1,904 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "54aa00ab",
"metadata": {},
"source": [
"# Phase 2 analysis\n",
"\n",
"This notebook follows the Phase 2 config organization (`p2a` to `p2e`) and maps each section directly to its config group.\n",
"It separates three concerns:\n",
"\n",
"1. **Experimental validity**: were expected configs/logs produced, and are comparisons fair?\n",
"2. **Evidence**: what do the 5-fold CV metrics support?\n",
"3. **Decision**: which preprocessing choices should move into Phase 3?\n"
]
},
{
"cell_type": "markdown",
"id": "734db3ee",
"metadata": {},
"source": [
"## Questions\n",
"\n",
"| Section | Config group | Question | Required evidence |\n",
"|---|---|---|---|\n",
"| 2A | `p2a_*` | Shortcut analysis: normalization + source holdout | `p2a_t1_original`, `p2a_t2_real_norm`, `p2a_t3_holdout_*` |\n",
"| 2B | `p2b_*` | Does 224 improve over 128? | `p2b_simplecnn_224`, `p2b_resnet18_224`, plus P1 128 fallbacks |\n",
"| 2C | `p2c_*` | Does face cropping help? | `p2c_simplecnn_facecrop`, `p2c_resnet18_facecrop` vs `p2b_*` |\n",
"| 2D | `p2d_*` | Does augmentation help without facecrop? | `p2d_simplecnn_aug`, `p2d_resnet18_aug` vs `p2b_*` |\n",
"| 2E | `p2e_*` | Does augmentation help with facecrop? | `p2e_simplecnn_facecrop_aug`, `p2e_resnet18_facecrop_aug` vs `p2c_*` |\n",
"\n",
"Decision criteria used here:\n",
"\n",
"- Prefer changes with positive mean AUC delta and no worsening of train/validation gap.\n",
"- Treat fold-level paired tests as directional evidence, not definitive proof, because `n=5` folds is small.\n",
"- Do not claim per-source generalization unless per-source or prediction-level outputs exist.\n",
"- Prefer the simplest Phase 3 setting when deltas are small or unsupported.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f4c04b3",
"metadata": {},
"outputs": [],
"source": [
"from __future__ import annotations\n",
"\n",
"import json\n",
"import math\n",
"import os\n",
"import sys\n",
"from dataclasses import dataclass\n",
"from pathlib import Path\n",
"from typing import Any\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from scipy import stats\n",
"\n",
"try:\n",
"    from IPython.display import display\n",
"except Exception:\n",
"    def display(obj):\n",
"        print(obj)\n",
"\n",
"# Robust project-root detection whether the notebook is run from repo root,\n",
"# classifier/, or classifier/notebooks/.\n",
"def find_project_root(start: Path | None = None) -> Path:\n",
"    start = (start or Path.cwd()).resolve()\n",
"    for candidate in [start, *start.parents]:\n",
"        if (candidate / \"classifier\" / \"v2.md\").exists() and (candidate / \"classifier\" / \"impl.md\").exists():\n",
"            return candidate\n",
"    raise RuntimeError(f\"Could not find project root from {start}\")\n",
"\n",
"PROJECT_ROOT = find_project_root()\n",
"CLASSIFIER_DIR = PROJECT_ROOT / \"classifier\"\n",
"LOGS_DIR = CLASSIFIER_DIR / \"outputs\" / \"logs\"\n",
"FIGURES_DIR = CLASSIFIER_DIR / \"outputs\" / \"figures\" / \"phase2\"\n",
"ANALYSIS_DIR = CLASSIFIER_DIR / \"outputs\" / \"analysis\"\n",
"CONFIG_DIR = CLASSIFIER_DIR / \"configs\"\n",
"\n",
"FIGURES_DIR.mkdir(parents=True, exist_ok=True)\n",
"ANALYSIS_DIR.mkdir(parents=True, exist_ok=True)\n",
"\n",
"if str(CLASSIFIER_DIR) not in sys.path:\n",
"    sys.path.insert(0, str(CLASSIFIER_DIR))\n",
"\n",
"sns.set_theme(style=\"whitegrid\", context=\"notebook\")\n",
"plt.rcParams.update({\n",
"    \"figure.figsize\": (12, 7),\n",
"    \"axes.spines.top\": False,\n",
"    \"axes.spines.right\": False,\n",
"})\n",
"\n",
"print(f\"Project root: {PROJECT_ROOT}\")\n",
"print(f\"Logs: {LOGS_DIR}\")\n",
"print(f\"Figures: {FIGURES_DIR}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "24830212",
"metadata": {},
"outputs": [],
"source": [
"@dataclass(frozen=True)\n",
"class RunSpec:\n",
"    run: str\n",
"    label: str\n",
"    section: str\n",
"    model: str\n",
"    condition: str\n",
"    intended_role: str\n",
"    fallback_for: str | None = None\n",
"\n",
"RUN_SPECS = [\n",
"    # 2A: shortcut analysis (normalization + source holdout), ResNet18 only.\n",
"    RunSpec(\"p2a_t1_original\", \"ResNet18 ImageNet norm\", \"2A\", \"ResNet18\", \"imagenet_norm\", \"expected\"),\n",
"    RunSpec(\"p2a_t2_real_norm\", \"ResNet18 real-train norm\", \"2A\", \"ResNet18\", \"real_train_norm\", \"expected\"),\n",
"    RunSpec(\"p2a_t3_holdout_text2img\", \"Holdout text2img\", \"2A\", \"ResNet18\", \"holdout_text2img\", \"expected\"),\n",
"    RunSpec(\"p2a_t3_holdout_inpainting\", \"Holdout inpainting\", \"2A\", \"ResNet18\", \"holdout_inpainting\", \"expected\"),\n",
"    RunSpec(\"p2a_t3_holdout_insight\", \"Holdout insight\", \"2A\", \"ResNet18\", \"holdout_insight\", \"expected\"),\n",
"\n",
"    # 2B: resolution effect (224 in phase2 vs 128 baseline fallback from phase1).\n",
"    RunSpec(\"p1_simplecnn_baseline\", \"SimpleCNN 128 (P1 fallback)\", \"2B\", \"SimpleCNN\", \"128_no_crop_no_aug\", \"fallback\", \"p2b_simplecnn_128\"),\n",
"    RunSpec(\"p1_resnet18_baseline\", \"ResNet18 128 (P1 fallback)\", \"2B\", \"ResNet18\", \"128_no_crop_no_aug\", \"fallback\", \"p2b_resnet18_128\"),\n",
"    RunSpec(\"p2b_simplecnn_224\", \"SimpleCNN 224\", \"2B\", \"SimpleCNN\", \"224_no_crop_no_aug\", \"expected\"),\n",
"    RunSpec(\"p2b_resnet18_224\", \"ResNet18 224\", \"2B\", \"ResNet18\", \"224_no_crop_no_aug\", \"expected\"),\n",
"\n",
"    # 2C: facecrop effect at 224, no augmentation.\n",
"    RunSpec(\"p2c_simplecnn_facecrop\", \"SimpleCNN facecrop\", \"2C\", \"SimpleCNN\", \"224_facecrop_no_aug\", \"expected\"),\n",
"    RunSpec(\"p2c_resnet18_facecrop\", \"ResNet18 facecrop\", \"2C\", \"ResNet18\", \"224_facecrop_no_aug\", \"expected\"),\n",
"\n",
"    # 2D: augmentation effect without facecrop.\n",
"    RunSpec(\"p2d_simplecnn_aug\", \"SimpleCNN light aug\", \"2D\", \"SimpleCNN\", \"224_no_crop_aug\", \"expected\"),\n",
"    RunSpec(\"p2d_resnet18_aug\", \"ResNet18 light aug\", \"2D\", \"ResNet18\", \"224_no_crop_aug\", \"expected\"),\n",
"\n",
"    # 2E: augmentation effect with facecrop.\n",
"    RunSpec(\"p2e_simplecnn_facecrop_aug\", \"SimpleCNN facecrop + aug\", \"2E\", \"SimpleCNN\", \"224_facecrop_aug\", \"expected\"),\n",
"    RunSpec(\"p2e_resnet18_facecrop_aug\", \"ResNet18 facecrop + aug\", \"2E\", \"ResNet18\", \"224_facecrop_aug\", \"expected\"),\n",
"]\n",
"\n",
"# Use these aliases when synthetic 128 run IDs are requested for 2B.\n",
"RUN_ALIASES = {\n",
"    \"p2b_simplecnn_128\": \"p1_simplecnn_baseline\",\n",
"    \"p2b_resnet18_128\": \"p1_resnet18_baseline\",\n",
"}\n",
"\n",
"PLANNED_COMPARISONS = [\n",
"    (\"2A\", \"ResNet18\", \"normalization\", \"p2a_t1_original\", \"p2a_t2_real_norm\", \"real_norm - imagenet_norm\"),\n",
"    (\"2A\", \"ResNet18\", \"source_holdout\", \"p2a_t1_original\", \"p2a_t3_holdout_text2img\", \"holdout text2img - all-source\"),\n",
"    (\"2A\", \"ResNet18\", \"source_holdout\", \"p2a_t1_original\", \"p2a_t3_holdout_inpainting\", \"holdout inpainting - all-source\"),\n",
"    (\"2A\", \"ResNet18\", \"source_holdout\", \"p2a_t1_original\", \"p2a_t3_holdout_insight\", \"holdout insight - all-source\"),\n",
"\n",
"    (\"2B\", \"SimpleCNN\", \"resolution\", \"p2b_simplecnn_128\", \"p2b_simplecnn_224\", \"224 - 128\"),\n",
"    (\"2B\", \"ResNet18\", \"resolution\", \"p2b_resnet18_128\", \"p2b_resnet18_224\", \"224 - 128\"),\n",
"\n",
"    (\"2C\", \"SimpleCNN\", \"facecrop\", \"p2b_simplecnn_224\", \"p2c_simplecnn_facecrop\", \"facecrop - no facecrop\"),\n",
"    (\"2C\", \"ResNet18\", \"facecrop\", \"p2b_resnet18_224\", \"p2c_resnet18_facecrop\", \"facecrop - no facecrop\"),\n",
"\n",
"    (\"2D\", \"SimpleCNN\", \"augmentation\", \"p2b_simplecnn_224\", \"p2d_simplecnn_aug\", \"light aug - no aug\"),\n",
"    (\"2D\", \"ResNet18\", \"augmentation\", \"p2b_resnet18_224\", \"p2d_resnet18_aug\", \"light aug - no aug\"),\n",
"\n",
"    (\"2E\", \"SimpleCNN\", \"facecrop + augmentation\", \"p2c_simplecnn_facecrop\", \"p2e_simplecnn_facecrop_aug\", \"facecrop+aug - facecrop\"),\n",
"    (\"2E\", \"ResNet18\", \"facecrop + augmentation\", \"p2c_resnet18_facecrop\", \"p2e_resnet18_facecrop_aug\", \"facecrop+aug - facecrop\"),\n",
"]\n"
]
},
{
"cell_type": "markdown",
"id": "6e2ccd27",
"metadata": {},
"source": [
"## Evidence audit\n",
"\n",
"Before comparing numbers, check whether the planned artifacts exist. Dedicated `p2b_*_128` configs/logs are skipped or absent in this repository, so this notebook uses the matching Phase 1 baselines as explicit fallbacks for the 128 vs 224 resolution test."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53356e8b",
"metadata": {},
"outputs": [],
"source": [
"def load_json(path: Path) -> dict[str, Any] | None:\n",
"    if not path.exists():\n",
"        return None\n",
"    with path.open() as f:\n",
"        return json.load(f)\n",
"\n",
"\n",
"def config_path_for(run: str) -> Path | None:\n",
"    candidates = [\n",
"        CONFIG_DIR / \"phase2\" / f\"{run}.json\",\n",
"        CONFIG_DIR / \"phase2\" / f\"{run}.json.skip\",\n",
"        CONFIG_DIR / \"phase1\" / f\"{run}.json\",\n",
"        CONFIG_DIR / \"phase1\" / f\"{run}.json.skip\",\n",
"    ]\n",
"    return next((p for p in candidates if p.exists()), None)\n",
"\n",
"\n",
"def log_path_for(run: str) -> Path:\n",
"    return LOGS_DIR / f\"{run}.json\"\n",
"\n",
"\n",
"def resolve_run(run: str) -> str:\n",
"    return run if log_path_for(run).exists() else RUN_ALIASES.get(run, run)\n",
"\n",
"\n",
"def load_results(run: str) -> dict[str, Any] | None:\n",
"    resolved = resolve_run(run)\n",
"    return load_json(log_path_for(resolved))\n",
"\n",
"\n",
"def metric_values(results: dict[str, Any], metric: str = \"auc_roc\") -> np.ndarray:\n",
"    vals = []\n",
"    for fold in results.get(\"fold_results\", []):\n",
"        value = fold.get(\"test_metrics\", {}).get(metric)\n",
"        if value is not None:\n",
"            vals.append(float(value))\n",
"    return np.asarray(vals, dtype=float)\n",
"\n",
"\n",
"def best_epoch_gap(fold: dict[str, Any], metric: str = \"auc\") -> float | None:\n",
"    hist = fold.get(\"history\", {})\n",
"    train_key = f\"train_{metric}\"\n",
"    val_key = f\"val_{metric}\"\n",
"    train = hist.get(train_key, [])\n",
"    val = hist.get(val_key, [])\n",
"    if not train or not val:\n",
"        return None\n",
"    idx = int(np.nanargmax(np.asarray(val, dtype=float)))\n",
"    return float(train[idx] - val[idx])\n",
"\n",
"\n",
"def final_epoch_gap(fold: dict[str, Any], metric: str = \"auc\") -> float | None:\n",
"    hist = fold.get(\"history\", {})\n",
"    train = hist.get(f\"train_{metric}\", [])\n",
"    val = hist.get(f\"val_{metric}\", [])\n",
"    if not train or not val:\n",
"        return None\n",
"    return float(train[-1] - val[-1])\n",
"\n",
"\n",
"def summarize_run(spec: RunSpec) -> dict[str, Any]:\n",
"    resolved = resolve_run(spec.run)\n",
"    results = load_results(spec.run)\n",
"    config_path = config_path_for(spec.run) or config_path_for(resolved)\n",
"    cfg = load_json(config_path) if config_path else None\n",
"\n",
"    row = {\n",
"        \"section\": spec.section,\n",
"        \"run\": spec.run,\n",
"        \"resolved_run\": resolved,\n",
"        \"label\": spec.label,\n",
"        \"model\": spec.model,\n",
"        \"condition\": spec.condition,\n",
"        \"role\": spec.intended_role,\n",
"        \"fallback_for\": spec.fallback_for,\n",
"        \"config_path\": str(config_path.relative_to(PROJECT_ROOT)) if config_path else None,\n",
"        \"config_status\": \"present\" if config_path and config_path.suffix == \".json\" else (\"skipped\" if config_path else \"missing\"),\n",
"        \"log_status\": \"present\" if log_path_for(spec.run).exists() else (\"fallback\" if resolved != spec.run and log_path_for(resolved).exists() else \"missing\"),\n",
"        \"n_folds\": None,\n",
"        \"auc_mean\": np.nan,\n",
"        \"auc_std\": np.nan,\n",
"        \"acc_mean\": np.nan,\n",
"        \"f1_mean\": np.nan,\n",
"        \"gap_best_mean\": np.nan,\n",
"        \"gap_final_mean\": np.nan,\n",
"        \"image_size\": None,\n",
"        \"face_crop\": None,\n",
"        \"augment\": None,\n",
"        \"normalization\": None,\n",
"        \"train_sources\": None,\n",
"        \"eval_sources\": None,\n",
"    }\n",
"\n",
"    if cfg:\n",
"        row.update({\n",
"            \"image_size\": cfg.get(\"image_size\"),\n",
"            \"face_crop\": cfg.get(\"face_crop\"),\n",
"            \"augment\": \"light\" if isinstance(cfg.get(\"augment\"), dict) else cfg.get(\"augment\"),\n",
"            \"normalization\": cfg.get(\"normalization\"),\n",
"            \"train_sources\": tuple(cfg.get(\"train_sources\", [])) or None,\n",
"            \"eval_sources\": tuple(cfg.get(\"eval_sources\", [])) or None,\n",
"        })\n",
"\n",
"    if results:\n",
"        agg = results.get(\"aggregated_metrics\", {})\n",
"        row.update({\n",
"            \"n_folds\": results.get(\"n_folds\"),\n",
"            \"auc_mean\": agg.get(\"auc_roc\", {}).get(\"mean\", np.nan),\n",
"            \"auc_std\": agg.get(\"auc_roc\", {}).get(\"std\", np.nan),\n",
"            \"acc_mean\": agg.get(\"accuracy\", {}).get(\"mean\", np.nan),\n",
"            \"f1_mean\": agg.get(\"f1\", {}).get(\"mean\", np.nan),\n",
"        })\n",
"        best_gaps = [best_epoch_gap(f) for f in results.get(\"fold_results\", [])]\n",
"        final_gaps = [final_epoch_gap(f) for f in results.get(\"fold_results\", [])]\n",
"        best_gaps = [x for x in best_gaps if x is not None]\n",
"        final_gaps = [x for x in final_gaps if x is not None]\n",
"        row[\"gap_best_mean\"] = float(np.mean(best_gaps)) if best_gaps else np.nan\n",
"        row[\"gap_final_mean\"] = float(np.mean(final_gaps)) if final_gaps else np.nan\n",
"\n",
"    return row\n",
"\n",
"runs_df = pd.DataFrame([summarize_run(spec) for spec in RUN_SPECS])\n",
"\n",
"# Prefer canonical rows for analysis: keep fallbacks only where expected rows are missing.\n",
"canonical_runs_df = runs_df[runs_df[\"role\"] == \"expected\"].copy()\n",
"for missing_run, fallback_run in RUN_ALIASES.items():\n",
"    mask = canonical_runs_df[\"run\"].eq(missing_run) & canonical_runs_df[\"log_status\"].eq(\"missing\")\n",
"    if mask.any():\n",
"        fallback = runs_df[runs_df[\"run\"].eq(fallback_run)].copy()\n",
"        if not fallback.empty:\n",
"            fallback.loc[:, \"run\"] = missing_run\n",
"            fallback.loc[:, \"label\"] = fallback.iloc[0][\"label\"].replace(\" (P1 fallback)\", \"\") + \" [P1 fallback]\"\n",
"            fallback.loc[:, \"role\"] = \"expected_via_fallback\"\n",
"            canonical_runs_df = pd.concat([canonical_runs_df[~mask], fallback], ignore_index=True)\n",
"\n",
"print(\"Artifact audit:\")\n",
"display(runs_df[[\"section\", \"run\", \"resolved_run\", \"role\", \"config_status\", \"log_status\", \"n_folds\"]].sort_values([\"section\", \"run\"]))\n",
"\n",
"missing_expected = runs_df[(runs_df[\"role\"] == \"expected\") & (runs_df[\"log_status\"] == \"missing\")][\"run\"].tolist()\n",
"print(f\"\\nExpected runs with no direct log: {missing_expected or 'none'}\")\n",
"print(\"Fallbacks used:\", {k: v for k, v in RUN_ALIASES.items() if k in missing_expected})"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b21a9faf",
"metadata": {},
"outputs": [],
"source": [
"# Protocol consistency audit from loaded logs/configs.\n",
"protocol_fields = [\n",
"    \"cv_folds\", \"batch_size\", \"early_stopping_patience\", \"seed\", \"subsample\",\n",
"    \"lr\", \"weight_decay\", \"T_max\", \"epochs\",\n",
"]\n",
"\n",
"protocol_rows = []\n",
"for _, row in canonical_runs_df.iterrows():\n",
"    results = load_results(row[\"run\"])\n",
"    cfg = (results or {}).get(\"config\", {})\n",
"    protocol_rows.append({\"run\": row[\"run\"], **{k: cfg.get(k) for k in protocol_fields}})\n",
"\n",
"protocol_df = pd.DataFrame(protocol_rows)\n",
"display(protocol_df)\n",
"\n",
"print(\"Field variability across loaded canonical runs:\")\n",
"for field in protocol_fields:\n",
"    vals = sorted({str(v) for v in protocol_df[field].dropna().unique()})\n",
"    print(f\"  {field:28s}: {vals}\")"
]
},
{
"cell_type": "markdown",
"id": "6802bcd9",
"metadata": {},
"source": [
"## Results table\n",
"\n",
"The table below is ranked by AUC and includes two gap estimates:\n",
"\n",
"- `gap_best_mean`: train AUC minus validation AUC at each fold's best validation epoch. This is closest to the saved best checkpoint.\n",
"- `gap_final_mean`: train AUC minus validation AUC at the final epoch. This is useful for diagnosing late overfit but is less aligned with test evaluation."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "be1ec0ba",
"metadata": {},
"outputs": [],
"source": [
"analysis_df = canonical_runs_df[canonical_runs_df[\"log_status\"].isin([\"present\", \"fallback\"])].copy()\n",
"analysis_df = analysis_df.sort_values(\"auc_mean\", ascending=False)\n",
"\n",
"cols = [\n",
"    \"section\", \"label\", \"run\", \"resolved_run\", \"model\", \"condition\", \"log_status\",\n",
"    \"auc_mean\", \"auc_std\", \"acc_mean\", \"f1_mean\", \"gap_best_mean\", \"gap_final_mean\",\n",
"]\n",
"\n",
"display(\n",
"    analysis_df[cols]\n",
"    .style.format({\n",
"        \"auc_mean\": \"{:.4f}\",\n",
"        \"auc_std\": \"{:.4f}\",\n",
"        \"acc_mean\": \"{:.4f}\",\n",
"        \"f1_mean\": \"{:.4f}\",\n",
"        \"gap_best_mean\": \"{:+.4f}\",\n",
"        \"gap_final_mean\": \"{:+.4f}\",\n",
"    })\n",
"    .background_gradient(subset=[\"auc_mean\"], cmap=\"Greens\")\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e0d21c1",
"metadata": {},
"outputs": [],
"source": [
"def paired_comparison(section: str, model: str, question: str, before: str, after: str, contrast: str) -> dict[str, Any]:\n",
"    r0 = load_results(before)\n",
"    r1 = load_results(after)\n",
"    resolved_before = resolve_run(before)\n",
"    resolved_after = resolve_run(after)\n",
"    out = {\n",
"        \"section\": section,\n",
"        \"model\": model,\n",
"        \"question\": question,\n",
"        \"before\": before,\n",
"        \"after\": after,\n",
"        \"resolved_before\": resolved_before,\n",
"        \"resolved_after\": resolved_after,\n",
"        \"contrast\": contrast,\n",
"        \"status\": \"ok\" if r0 and r1 else \"missing\",\n",
"        \"n\": 0,\n",
"        \"before_auc\": np.nan,\n",
"        \"after_auc\": np.nan,\n",
"        \"delta_auc\": np.nan,\n",
"        \"delta_ci95\": np.nan,\n",
"        \"ttest_p\": np.nan,\n",
"        \"wilcoxon_p\": np.nan,\n",
"        \"cohen_dz\": np.nan,\n",
"        \"before_gap\": np.nan,\n",
"        \"after_gap\": np.nan,\n",
"        \"delta_gap\": np.nan,\n",
"        \"interpretation\": \"insufficient data\",\n",
"        \"caveat\": \"\",\n",
"    }\n",
"    if not (r0 and r1):\n",
"        return out\n",
"\n",
"    v0 = metric_values(r0, \"auc_roc\")\n",
"    v1 = metric_values(r1, \"auc_roc\")\n",
"    n = min(len(v0), len(v1))\n",
"    v0, v1 = v0[:n], v1[:n]\n",
"    diff = v1 - v0\n",
"\n",
"    out.update({\n",
"        \"n\": n,\n",
"        \"before_auc\": float(np.mean(v0)),\n",
"        \"after_auc\": float(np.mean(v1)),\n",
"        \"delta_auc\": float(np.mean(diff)),\n",
"    })\n",
"\n",
"    if n >= 2:\n",
"        sd = float(np.std(diff, ddof=1))\n",
"        se = sd / math.sqrt(n) if sd > 0 else 0.0\n",
"        out[\"delta_ci95\"] = float(stats.t.ppf(0.975, df=n - 1) * se) if n > 1 else np.nan\n",
"        if sd > 0:\n",
"            out[\"cohen_dz\"] = float(np.mean(diff) / sd)\n",
"            out[\"ttest_p\"] = float(stats.ttest_rel(v1, v0).pvalue)\n",
"        if n >= 3 and not np.allclose(diff, 0):\n",
"            try:\n",
"                out[\"wilcoxon_p\"] = float(stats.wilcoxon(diff).pvalue)\n",
"            except ValueError:\n",
"                pass\n",
"\n",
"    gaps0 = [best_epoch_gap(f) for f in r0.get(\"fold_results\", [])]\n",
"    gaps1 = [best_epoch_gap(f) for f in r1.get(\"fold_results\", [])]\n",
"    gaps0 = np.asarray([x for x in gaps0 if x is not None], dtype=float)\n",
"    gaps1 = np.asarray([x for x in gaps1 if x is not None], dtype=float)\n",
"    if len(gaps0) and len(gaps1):\n",
"        m = min(len(gaps0), len(gaps1))\n",
"        out[\"before_gap\"] = float(np.mean(gaps0[:m]))\n",
"        out[\"after_gap\"] = float(np.mean(gaps1[:m]))\n",
"        out[\"delta_gap\"] = float(np.mean(gaps1[:m] - gaps0[:m]))\n",
"\n",
"    if question == \"source_holdout\":\n",
"        out[\"caveat\"] = \"Aggregate holdout-run AUC only; not held-out-source vs in-source AUC.\"\n",
"    if before != resolved_before or after != resolved_after:\n",
"        out[\"caveat\"] = (out[\"caveat\"] + \" \" if out[\"caveat\"] else \"\") + \"Uses Phase 1 fallback for missing p2b 128 log.\"\n",
"\n",
"    if out[\"delta_auc\"] >= 0.01:\n",
"        out[\"interpretation\"] = \"meaningful improvement\"\n",
"    elif out[\"delta_auc\"] > 0.002:\n",
"        out[\"interpretation\"] = \"small improvement\"\n",
"    elif out[\"delta_auc\"] >= -0.002:\n",
"        out[\"interpretation\"] = \"negligible change\"\n",
"    elif out[\"delta_auc\"] > -0.01:\n",
"        out[\"interpretation\"] = \"small drop\"\n",
"    else:\n",
"        out[\"interpretation\"] = \"meaningful drop\"\n",
"    return out\n",
"\n",
"comparisons_df = pd.DataFrame([paired_comparison(*args) for args in PLANNED_COMPARISONS])\n",
"\n",
"# Benjamini-Hochberg correction across planned paired t-tests where available.\n",
"valid_p = comparisons_df[\"ttest_p\"].notna()\n",
"pvals = comparisons_df.loc[valid_p, \"ttest_p\"].to_numpy()\n",
"qvals = np.full(len(comparisons_df), np.nan)\n",
"if len(pvals):\n",
"    order = np.argsort(pvals)\n",
"    ranked = pvals[order]\n",
"    adjusted = np.empty_like(ranked)\n",
"    m = len(ranked)\n",
"    running = 1.0\n",
"    for i in range(m - 1, -1, -1):\n",
"        running = min(running, ranked[i] * m / (i + 1))\n",
"        adjusted[i] = running\n",
"    qvals[np.where(valid_p)[0][order]] = adjusted\n",
"comparisons_df[\"bh_q\"] = qvals\n",
"\n",
"display(\n",
"    comparisons_df[[\n",
"        \"section\", \"model\", \"question\", \"contrast\", \"before_auc\", \"after_auc\", \"delta_auc\",\n",
"        \"delta_ci95\", \"ttest_p\", \"bh_q\", \"wilcoxon_p\", \"cohen_dz\", \"delta_gap\", \"interpretation\", \"caveat\",\n",
"    ]].style.format({\n",
"        \"before_auc\": \"{:.4f}\",\n",
"        \"after_auc\": \"{:.4f}\",\n",
"        \"delta_auc\": \"{:+.4f}\",\n",
"        \"delta_ci95\": \"\u00b1{:.4f}\",\n",
"        \"ttest_p\": \"{:.4f}\",\n",
"        \"bh_q\": \"{:.4f}\",\n",
"        \"wilcoxon_p\": \"{:.4f}\",\n",
"        \"cohen_dz\": \"{:+.2f}\",\n",
"        \"delta_gap\": \"{:+.4f}\",\n",
"    }).background_gradient(subset=[\"delta_auc\"], cmap=\"RdYlGn\", vmin=-0.06, vmax=0.06)\n",
")"
]
},
{
"cell_type": "markdown",
"id": "f20e5262",
"metadata": {},
"source": [
"## Visual summary\n",
"\n",
"Two plots are most useful for decision-making:\n",
"\n",
"- Ranking all conditions by AUC shows the best observed configurations but can overstate duplicated/near-identical runs.\n",
"- The paired delta plot shows the controlled effect of each preprocessing change and exposes its uncertainty."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "42882c6a",
"metadata": {},
"outputs": [],
"source": [
"plot_df = analysis_df.copy()\n",
"plot_df[\"display_label\"] = plot_df[\"section\"] + \" | \" + plot_df[\"label\"]\n",
"plot_df = plot_df.sort_values(\"auc_mean\", ascending=True)\n",
"\n",
"fig, ax = plt.subplots(figsize=(11, max(7, 0.35 * len(plot_df))))\n",
"colors = {\"2A\": \"#4C78A8\", \"2B\": \"#F58518\", \"2C\": \"#54A24B\", \"2D\": \"#E45756\", \"2E\": \"#B279A2\"}\n",
"ax.barh(\n",
"    plot_df[\"display_label\"],\n",
"    plot_df[\"auc_mean\"],\n",
"    xerr=plot_df[\"auc_std\"],\n",
"    color=[colors.get(s, \"#999999\") for s in plot_df[\"section\"]],\n",
"    alpha=0.85,\n",
")\n",
"ax.set_xlim(0.65, 1.0)\n",
"ax.set_xlabel(\"Mean AUC across CV folds\")\n",
"ax.set_title(\"Phase 2 Conditions Ranked by AUC\")\n",
"ax.axvline(0.95, color=\"black\", linewidth=1, linestyle=\"--\", alpha=0.4)\n",
"for y, (_, row) in enumerate(plot_df.iterrows()):\n",
"    ax.text(row[\"auc_mean\"] + 0.004, y, f\"{row['auc_mean']:.4f}\", va=\"center\", fontsize=9)\n",
"fig.tight_layout()\n",
"fig.savefig(FIGURES_DIR / \"ranked_auc.png\", dpi=200, bbox_inches=\"tight\")\n",
"plt.show()\n",
"\n",
"forest = comparisons_df.copy()\n",
"forest[\"display\"] = forest[\"section\"] + \" \" + forest[\"model\"] + \" - \" + forest[\"contrast\"]\n",
"forest = forest.iloc[::-1]\n",
"fig, ax = plt.subplots(figsize=(11, max(6, 0.45 * len(forest))))\n",
"y = np.arange(len(forest))\n",
"ax.errorbar(\n",
"    forest[\"delta_auc\"], y,\n",
"    xerr=forest[\"delta_ci95\"],\n",
"    fmt=\"o\", color=\"#1F2937\", ecolor=\"#6B7280\", capsize=4,\n",
")\n",
"ax.axvline(0, color=\"black\", linewidth=1)\n",
"ax.axvspan(-0.002, 0.002, color=\"#9CA3AF\", alpha=0.18, label=\"negligible band\")\n",
"ax.set_yticks(y)\n",
"ax.set_yticklabels(forest[\"display\"])\n",
"ax.set_xlabel(\"Delta AUC (after - before), paired by fold\")\n",
"ax.set_title(\"Planned Phase 2 Effect Estimates\")\n",
"ax.legend(loc=\"lower right\")\n",
"fig.tight_layout()\n",
"fig.savefig(FIGURES_DIR / \"planned_effects.png\", dpi=200, bbox_inches=\"tight\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "e063cfc0",
"metadata": {},
"source": [
"## 2A - Shortcut analysis\n",
"\n",
"Shortcut checks map to `p2a_*` configs:\n",
"- `p2a_t1_original` vs `p2a_t2_real_norm` (normalization)\n",
"- `p2a_t1_original` vs `p2a_t3_holdout_*` (source_holdout)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "910bd5bd",
"metadata": {},
"outputs": [],
"source": [
"def comparison_subset(section: str, question: str | None = None) -> pd.DataFrame:\n",
"    df = comparisons_df[comparisons_df[\"section\"].eq(section)].copy()\n",
"    if question:\n",
"        df = df[df[\"question\"].eq(question)]\n",
"    return df\n",
"\n",
"\n",
"def print_comparison_readout(df: pd.DataFrame) -> None:\n",
"    for _, row in df.iterrows():\n",
"        print(f\"{row['section']} {row['model']} - {row['contrast']}\")\n",
"        print(f\"  AUC: {row['before_auc']:.4f} -> {row['after_auc']:.4f} ({row['delta_auc']:+.4f})\")\n",
"        print(f\"  paired t p={row['ttest_p']:.4f}, BH q={row['bh_q']:.4f}, CI95 delta=\u00b1{row['delta_ci95']:.4f}\")\n",
"        print(f\"  gap delta: {row['delta_gap']:+.4f}; interpretation: {row['interpretation']}\")\n",
"        if row['caveat']:\n",
"            print(f\"  caveat: {row['caveat']}\")\n",
"        print()\n",
"\n",
"print_comparison_readout(comparison_subset(\"2B\", \"resolution\"))\n",
"\n",
"res_plot = comparison_subset(\"2B\", \"resolution\")\n",
"fig, ax = plt.subplots(figsize=(8, 5))\n",
"for _, row in res_plot.iterrows():\n",
"    r0, r1 = load_results(row[\"before\"]), load_results(row[\"after\"])\n",
|
||||
" v0, v1 = metric_values(r0), metric_values(r1)\n",
|
||||
" x = [0, 1]\n",
|
||||
" for a, b in zip(v0, v1):\n",
|
||||
" ax.plot(x, [a, b], color=\"#9CA3AF\", alpha=0.7)\n",
|
||||
" ax.plot(x, [v0.mean(), v1.mean()], marker=\"o\", linewidth=3, label=row[\"model\"])\n",
|
||||
"ax.set_xticks([0, 1])\n",
|
||||
"ax.set_xticklabels([\"128\", \"224\"])\n",
|
||||
"ax.set_ylabel(\"AUC\")\n",
|
||||
"ax.set_title(\"2B Resolution: Fold-Paired AUC\")\n",
|
||||
"ax.legend()\n",
|
||||
"fig.tight_layout()\n",
|
||||
"fig.savefig(FIGURES_DIR / \"2b_resolution_paired.png\", dpi=200, bbox_inches=\"tight\")\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
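  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bh-qvalue-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only (not part of the project pipeline): how fields\n",
    "# such as `ttest_p` and `bh_q` in `comparisons_df` can be derived from\n",
    "# fold-paired AUCs. Assumes scipy is available; the AUC values are toy data.\n",
    "from scipy import stats\n",
    "\n",
    "before_auc = np.array([0.91, 0.92, 0.90, 0.93, 0.91])  # per-fold AUC, condition A\n",
    "after_auc = np.array([0.95, 0.96, 0.94, 0.96, 0.95])   # per-fold AUC, condition B\n",
    "t_stat, t_p = stats.ttest_rel(after_auc, before_auc)   # paired t-test across folds\n",
    "\n",
    "# Benjamini-Hochberg adjusted q-values across a family of m planned comparisons:\n",
    "pvals = np.array([t_p, 0.04, 0.20])\n",
    "order = np.argsort(pvals)\n",
    "m = len(pvals)\n",
    "ranked = pvals[order] * m / np.arange(1, m + 1)        # raw BH ratios, ascending p\n",
    "q = np.empty(m)\n",
    "q[order] = np.minimum.accumulate(ranked[::-1])[::-1]   # enforce monotone q-values\n",
    "print(f\"paired t p={t_p:.4f}, BH q={q[0]:.4f}\")\n"
   ]
  },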
  {
   "cell_type": "markdown",
   "id": "530e8675",
   "metadata": {},
   "source": [
    "## 2B - Resolution impact\n",
    "\n",
    "This section compares 128 vs 224 using `p2b_*_224` and Phase 1 baselines as explicit 128 fallbacks.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13304d38",
   "metadata": {},
   "outputs": [],
   "source": [
    "print_comparison_readout(comparison_subset(\"2C\", \"facecrop\"))\n",
    "\n",
    "face_df = canonical_runs_df[canonical_runs_df[\"section\"].eq(\"2C\")].copy()\n",
    "fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=False)\n",
    "for ax, model in zip(axes, [\"SimpleCNN\", \"ResNet18\"]):\n",
    "    sub = face_df[face_df[\"model\"].eq(model)].sort_values(\"face_crop\")\n",
    "    ax.bar(sub[\"condition\"], sub[\"auc_mean\"], yerr=sub[\"auc_std\"], color=[\"#D97706\", \"#059669\"], alpha=0.85, capsize=5)\n",
    "    ax.set_title(model)\n",
    "    ax.set_ylim(0.70 if model == \"SimpleCNN\" else 0.94, 0.99)\n",
    "    ax.set_ylabel(\"AUC\")\n",
    "    ax.tick_params(axis=\"x\", rotation=20)\n",
    "fig.suptitle(\"2C Facecrop Impact\")\n",
    "fig.tight_layout()\n",
    "fig.savefig(FIGURES_DIR / \"2c_facecrop.png\", dpi=200, bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8702d10d",
   "metadata": {},
   "source": [
    "## 2C - Facecrop impact\n",
    "\n",
    "This section compares `p2c_*_facecrop` against the matching `p2b_*_224` no-facecrop baselines.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec5e03ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "print_comparison_readout(comparison_subset(\"2A\"))\n",
    "\n",
    "# Inspect whether logs contain the per-source data needed by v2.md.\n",
    "source_audit = []\n",
    "for run in [\"p2a_t1_original\", \"p2a_t3_holdout_text2img\", \"p2a_t3_holdout_inpainting\", \"p2a_t3_holdout_insight\"]:\n",
    "    results = load_results(run)\n",
    "    has_per_source = False\n",
    "    has_records = False\n",
    "    example_keys = []\n",
    "    if results:\n",
    "        for fold in results.get(\"fold_results\", []):\n",
    "            tm = fold.get(\"test_metrics\", {})\n",
    "            example_keys = sorted(tm.keys())\n",
    "            has_per_source = has_per_source or any(k in tm for k in [\"per_source\", \"per_source_metrics\", \"pairwise_source_metrics\", \"source_metrics\", \"pair_metrics\"])\n",
    "            has_records = has_records or any(k in fold for k in [\"records\", \"predictions\", \"test_records\"])\n",
    "    source_audit.append({\n",
    "        \"run\": run,\n",
    "        \"has_per_source_metrics\": has_per_source,\n",
    "        \"has_prediction_records\": has_records,\n",
    "        \"test_metric_keys\": example_keys,\n",
    "    })\n",
    "source_audit_df = pd.DataFrame(source_audit)\n",
    "display(source_audit_df)\n",
    "\n",
    "holdout_runs = [\"p2a_t1_original\", \"p2a_t3_holdout_text2img\", \"p2a_t3_holdout_inpainting\", \"p2a_t3_holdout_insight\"]\n",
    "holdout_df = canonical_runs_df[canonical_runs_df[\"run\"].isin(holdout_runs)].copy()\n",
    "holdout_df[\"delta_vs_all_source\"] = holdout_df[\"auc_mean\"] - float(holdout_df.loc[holdout_df[\"run\"].eq(\"p2a_t1_original\"), \"auc_mean\"].iloc[0])\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(9, 5))\n",
    "ax.bar(holdout_df[\"label\"], holdout_df[\"auc_mean\"], yerr=holdout_df[\"auc_std\"], color=\"#54A24B\", alpha=0.85, capsize=5)\n",
    "ax.set_ylim(0.88, 0.99)\n",
    "ax.set_ylabel(\"Aggregate AUC\")\n",
    "ax.set_title(\"2A Source Holdout Proxy: Aggregate Test AUC\")\n",
    "ax.tick_params(axis=\"x\", rotation=20)\n",
    "for i, (_, row) in enumerate(holdout_df.iterrows()):\n",
    "    ax.text(i, row[\"auc_mean\"] + 0.004, f\"{row['delta_vs_all_source']:+.3f}\", ha=\"center\", fontsize=9)\n",
    "fig.tight_layout()\n",
    "fig.savefig(FIGURES_DIR / \"2a_holdout_proxy.png\", dpi=200, bbox_inches=\"tight\")\n",
    "plt.show()\n",
    "\n",
    "print(\"Geometry diagnostic evidence:\")\n",
    "geometry_keys = []\n",
    "for run in [\"p2a_t1_original\", \"p2a_t2_real_norm\"]:\n",
    "    results = load_results(run)\n",
    "    cfg = (results or {}).get(\"config\", {})\n",
    "    geometry_keys.append({\n",
    "        \"run\": run,\n",
    "        \"config_geometry_condition\": cfg.get(\"geometry_condition\"),\n",
    "        \"has_matched_geometry_metric\": any(\n",
    "            \"geometry\" in str(k).lower() or \"matched\" in str(k).lower()\n",
    "            for fold in (results or {}).get(\"fold_results\", [])\n",
    "            for k in fold.get(\"test_metrics\", {}).keys()\n",
    "        ),\n",
    "    })\n",
    "display(pd.DataFrame(geometry_keys))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c3b8812",
   "metadata": {},
   "source": [
    "## 2D / 2E - Augmentation impact and test-set integrity\n",
    "\n",
    "The augmentation question has two parts:\n",
    "\n",
    "- Does light augmentation help at 224 without facecrop?\n",
    "- Does it help once facecrop is enabled?\n",
    "\n",
    "The implementation also needs to guarantee that validation/test evaluation is not stochastic. The preprocessing pipeline keeps stochastic operations behind `self.train`, so `train=False` disables them even if augmentation settings exist."
   ]
},
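  {
   "cell_type": "code",
   "execution_count": null,
   "id": "train-guard-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch of the guard pattern described above, using a hypothetical\n",
    "# ToyPipeline (not the project's DFFImagePipeline): stochastic augmentations\n",
    "# only fire when train=True, so eval-time transforms stay deterministic.\n",
    "import random\n",
    "\n",
    "class ToyPipeline:\n",
    "    def __init__(self, train: bool, hflip_p: float = 0.5):\n",
    "        self.train = train\n",
    "        self.hflip_p = hflip_p\n",
    "\n",
    "    def __call__(self, x):\n",
    "        if self.train and random.random() < self.hflip_p:\n",
    "            x = x[:, ::-1]  # horizontal flip\n",
    "        return x\n",
    "\n",
    "probe = np.arange(6).reshape(2, 3)\n",
    "eval_pipe = ToyPipeline(train=False)\n",
    "assert all((eval_pipe(probe) == probe).all() for _ in range(10))  # deterministic eval\n",
    "print(\"train=False leaves inputs unchanged across repeated calls\")\n"
   ]
  },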
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f11c3257",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"2D (p2d): augmentation without facecrop\")\n",
    "print_comparison_readout(comparison_subset(\"2D\", \"augmentation\"))\n",
    "print(\"2E (p2e): augmentation with facecrop\")\n",
    "print_comparison_readout(comparison_subset(\"2E\", \"facecrop + augmentation\"))\n",
    "\n",
    "aug_sections = comparisons_df[comparisons_df[\"section\"].isin([\"2D\", \"2E\"])].copy()\n",
    "fig, ax = plt.subplots(figsize=(9, 5))\n",
    "labels = aug_sections[\"section\"] + \" \" + aug_sections[\"model\"]\n",
    "ax.bar(labels, aug_sections[\"delta_auc\"], yerr=aug_sections[\"delta_ci95\"], color=[\"#E45756\" if d < 0 else \"#059669\" for d in aug_sections[\"delta_auc\"]], alpha=0.85, capsize=5)\n",
    "ax.axhline(0, color=\"black\", linewidth=1)\n",
    "ax.set_ylabel(\"Delta AUC from adding augmentation\")\n",
    "ax.set_title(\"Augmentation Effects Across Facecrop Conditions\")\n",
    "ax.tick_params(axis=\"x\", rotation=20)\n",
    "fig.tight_layout()\n",
    "fig.savefig(FIGURES_DIR / \"2d_2e_augmentation_effects.png\", dpi=200, bbox_inches=\"tight\")\n",
    "plt.show()\n",
    "\n",
    "# Static and behavioral audit of eval stochasticity.\n",
    "try:\n",
    "    import inspect\n",
    "    from src.preprocessing.pipeline import DFFImagePipeline\n",
    "    from src.evaluation import evaluate as evaluate_module\n",
    "\n",
    "    pipeline_src = inspect.getsource(DFFImagePipeline)\n",
    "    build_transforms_src = inspect.getsource(evaluate_module.build_transforms)\n",
    "    stochastic_guards = {\n",
    "        \"flip_guarded_by_train\": \"if self.train and random.random() < self.hflip_p\" in pipeline_src,\n",
    "        \"rotate_guarded_by_train\": \"if self.train and self.rotation_degrees > 0\" in pipeline_src,\n",
    "        \"color_jitter_returns_when_not_train\": \"if not self.train:\" in pipeline_src,\n",
    "        \"blur_guarded_by_train\": \"if self.train and random.random() < self.blur_p\" in pipeline_src,\n",
    "        \"jpeg_guarded_by_train\": \"if self.train and random.random() < self.jpeg_p\" in pipeline_src,\n",
    "        \"erase_guarded_by_train\": \"if self.train and random.random() < self.erase_p\" in pipeline_src,\n",
    "        \"noise_guarded_by_train\": \"if self.train and random.random() < self.noise_p\" in pipeline_src,\n",
    "        \"cv_transform_uses_train_flag\": \"get_transforms(train=train\" in build_transforms_src,\n",
    "    }\n",
    "    display(pd.DataFrame([stochastic_guards]).T.rename(columns={0: \"passes\"}))\n",
    "except Exception as exc:\n",
    "    print(f\"Could not run transform audit: {exc}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02e47658",
   "metadata": {},
   "source": [
    "## Decision synthesis\n",
    "\n",
    "This section converts the evidence into Phase 3 settings. It intentionally distinguishes a recommendation from a claim:\n",
    "\n",
    "- Recommendation: choose the setting that is best supported for the next experiment.\n",
    "- Claim: what the current evidence proves. Some Phase 2C claims remain incomplete without per-source or matched-geometry outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7034443c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_delta(question: str, model: str | None = None, section: str | None = None) -> pd.DataFrame:\n",
    "    df = comparisons_df[comparisons_df[\"question\"].eq(question)].copy()\n",
    "    if model:\n",
    "        df = df[df[\"model\"].eq(model)]\n",
    "    if section:\n",
    "        df = df[df[\"section\"].eq(section)]\n",
    "    return df\n",
    "\n",
    "resolution_resnet = get_delta(\"resolution\", \"ResNet18\").iloc[0]\n",
    "facecrop_resnet = get_delta(\"facecrop\", \"ResNet18\").iloc[0]\n",
    "facecrop_simple = get_delta(\"facecrop\", \"SimpleCNN\").iloc[0]\n",
    "aug_no_crop_resnet = get_delta(\"augmentation\", \"ResNet18\").iloc[0]\n",
    "aug_no_crop_simple = get_delta(\"augmentation\", \"SimpleCNN\").iloc[0]\n",
    "aug_crop_resnet = get_delta(\"facecrop + augmentation\", \"ResNet18\").iloc[0]\n",
    "aug_crop_simple = get_delta(\"facecrop + augmentation\", \"SimpleCNN\").iloc[0]\n",
    "norm = get_delta(\"normalization\", \"ResNet18\").iloc[0]\n",
    "\n",
    "recommendations = [\n",
    "    {\n",
    "        \"choice\": \"resolution\",\n",
    "        \"recommendation\": \"224x224\",\n",
    "        \"evidence\": f\"ResNet18 delta AUC {resolution_resnet.delta_auc:+.4f}; SimpleCNN does not determine Phase 3 capacity.\",\n",
    "        \"confidence\": \"high\" if resolution_resnet.delta_auc > 0.02 else \"medium\",\n",
    "    },\n",
    "    {\n",
    "        \"choice\": \"facecrop\",\n",
    "        \"recommendation\": \"use facecrop\",\n",
    "        \"evidence\": f\"Small positive deltas for both models: SimpleCNN {facecrop_simple.delta_auc:+.4f}, ResNet18 {facecrop_resnet.delta_auc:+.4f}.\",\n",
    "        \"confidence\": \"medium\",\n",
    "    },\n",
    "    {\n",
    "        \"choice\": \"augmentation\",\n",
    "        \"recommendation\": \"do not use light augmentation for Phase 3 at 20% data\",\n",
    "        \"evidence\": f\"SimpleCNN drops {aug_no_crop_simple.delta_auc:+.4f} without facecrop and {aug_crop_simple.delta_auc:+.4f} with facecrop; ResNet18 is neutral/slightly mixed ({aug_no_crop_resnet.delta_auc:+.4f}, {aug_crop_resnet.delta_auc:+.4f}).\",\n",
    "        \"confidence\": \"high for SimpleCNN, medium for ResNet18\",\n",
    "    },\n",
    "    {\n",
    "        \"choice\": \"normalization\",\n",
    "        \"recommendation\": \"ImageNet normalization\",\n",
    "        \"evidence\": f\"Real-train-only normalization delta AUC {norm.delta_auc:+.4f}; no useful gain and less standard for pretrained ResNet.\",\n",
    "        \"confidence\": \"medium\",\n",
    "    },\n",
    "    {\n",
    "        \"choice\": \"shortcut/source claims\",\n",
    "        \"recommendation\": \"do not overclaim; add per-source or prediction exports before final report\",\n",
    "        \"evidence\": \"Current CV logs lack held-out-source vs in-source AUC and matched-geometry test metrics.\",\n",
    "        \"confidence\": \"high\",\n",
    "    },\n",
    "]\n",
    "\n",
    "recommendations_df = pd.DataFrame(recommendations)\n",
    "display(recommendations_df)\n",
    "\n",
    "summary = {\n",
    "    \"phase\": \"phase2\",\n",
    "    \"source_documents\": [\"classifier/v2.md\", \"classifier/impl.md\"],\n",
    "    \"artifact_counts\": {\n",
    "        \"canonical_runs\": int(len(canonical_runs_df)),\n",
    "        \"loaded_canonical_runs\": int(canonical_runs_df[\"log_status\"].isin([\"present\", \"fallback\"]).sum()),\n",
    "        \"fallback_runs_used\": {k: v for k, v in RUN_ALIASES.items() if resolve_run(k) != k},\n",
    "    },\n",
    "    \"recommendations\": recommendations,\n",
    "    \"planned_comparisons\": comparisons_df.replace({np.nan: None}).to_dict(orient=\"records\"),\n",
    "    \"known_gaps\": [\n",
    "        \"Dedicated p2a_*_128 logs are absent/skipped; Phase 1 baselines are used as fallbacks.\",\n",
    "        \"Source holdout logs do not include prediction-level or per-source metrics, so held-out-source AUC vs in-source AUC cannot be computed.\",\n",
    "        \"No matched-geometry evaluation metric is present in p2c logs, so geometry shortcut analysis is incomplete.\",\n",
    "    ],\n",
    "}\n",
    "\n",
    "summary_path = ANALYSIS_DIR / \"phase2_analysis_summary.json\"\n",
    "with summary_path.open(\"w\") as f:\n",
    "    json.dump(summary, f, indent=2)\n",
    "\n",
    "print(f\"Saved summary: {summary_path.relative_to(PROJECT_ROOT)}\")\n",
    "print(f\"Saved figures: {FIGURES_DIR.relative_to(PROJECT_ROOT)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a337f73",
   "metadata": {},
   "source": [
    "## Report-ready conclusion\n",
    "\n",
    "The strongest Phase 2 result is the resolution effect for ResNet18: moving to 224x224 substantially improves AUC under the controlled CV protocol. Face cropping gives a small positive effect and is reasonable to carry forward, especially because it aligns the model with face evidence rather than background context. Light augmentation is not supported at this 20% data setting: it strongly hurts SimpleCNN and provides no reliable gain for ResNet18, with or without face cropping. ImageNet normalization remains preferable because real-train-only normalization does not improve AUC and is less aligned with pretrained ResNet expectations.\n",
    "\n",
    "Recommended Phase 3 preprocessing: **224x224, facecrop enabled, no light augmentation, ImageNet normalization**.\n",
    "\n",
    "Limitations to fix before the final report: export prediction-level records or per-source pairwise metrics for source holdout, and add the matched-geometry evaluation required by the shortcut-analysis plan. Without those artifacts, Phase 2C can only support a limited shortcut analysis."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "drl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}