{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Phase 1 - Pipeline Selection\n", "\n", "Phase 1 answers the data-handling question left open by the baseline. Instead\n", "of changing the model family, it uses a cheap DCGAN proxy and varies one\n", "pipeline choice at a time.\n", "\n", "This phase is deliberately controlled. The output quality is still limited, but\n", "the relative differences tell us which input pipeline gives later recipes the\n", "best chance.\n", "\n", "## What this phase changes\n", "\n", "Four pipeline choices are tested as ablations:\n", "\n", "| Ablation | Question | Choices |\n", "|---|---|---|\n", "| 1A | How much resolution can the proxy handle? | 64x64 vs 128x128 |\n", "| 1B | Does alignment matter? | Full raw image vs MTCNN-aligned crop |\n", "| 1C | Does augmentation help the proxy? | H-flip only vs H-flip + rotation + color jitter |\n", "| 1D | Should raw and aligned images be mixed? | Aligned only vs aligned + raw mixed |\n", "\n", "**Headline result:** `p1c_dcgan_full_aug` reaches **FID@50 = 33.4**. The\n", "locked pipeline for the following phases is aligned face crops at 64x64, no\n", "raw/aligned mixing, with augmentation choices following the saved family configs.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Reference: Phase 0 baseline from the same family\n", "\n", "The phase 0 WGAN-GP (`p0_wgan`) trained on raw un-aligned images for 200 epochs\n", "without any pipeline tuning, and it also collapsed. 
Phase 1 below keeps the same model class\n", "and varies only the data pipeline, so the architecture limitation is constant across runs.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import json\n", "from pathlib import Path\n", "\n", "import matplotlib.pyplot as plt\n", "import matplotlib.image as mpimg\n", "import numpy as np\n", "import pandas as pd\n", "\n", "plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n", "\n", "# Fall back to print when running outside IPython/Jupyter.\n", "try:\n", "    display\n", "except NameError:\n", "    def display(obj):\n", "        print(obj)\n", "\n", "def find_generator_root():\n", "    # Walk up from the working directory until a folder with outputs/logs and outputs/samples is found.\n", "    for base in [Path.cwd(), *Path.cwd().parents]:\n", "        for candidate in [base, base / \"generator\"]:\n", "            if (candidate / \"outputs\" / \"logs\").exists() and (candidate / \"outputs\" / \"samples\").exists():\n", "                return candidate.resolve()\n", "    raise FileNotFoundError(\"Could not locate generator/outputs from the current working directory\")\n", "\n", "GENERATOR_ROOT = find_generator_root()\n", "PROJECT_ROOT = GENERATOR_ROOT.parent\n", "OUTPUTS = GENERATOR_ROOT / \"outputs\"\n", "LOGS = OUTPUTS / \"logs\"\n", "SAMPLES = OUTPUTS / \"samples\"\n", "\n", "\n", "def load_log(name):\n", "    # Return the parsed JSON log, or None if the file is missing.\n", "    p = LOGS / f\"{name}.json\"\n", "    if not p.exists():\n", "        return None\n", "    with open(p) as f:\n", "        return json.load(f)\n", "\n", "def get_fid(log, epoch):\n", "    # FID history is keyed by the epoch number stored as a string.\n", "    fid = log.get(\"history\", {}).get(\"fid\", {})\n", "    return fid.get(str(epoch))\n", "\n", "def fid_series(log):\n", "    # Return (epochs, values) sorted by epoch.\n", "    fid = log.get(\"history\", {}).get(\"fid\", {})\n", "    items = sorted((int(k), v) for k, v in fid.items())\n", "    return [e for e, _ in items], [v for _, v in items]\n", "\n", "def show_image_or_missing(ax, path, title=None):\n", "    # Draw the image if it exists, otherwise a placeholder naming the missing file.\n", "    if path.exists():\n", "        ax.imshow(mpimg.imread(str(path)))\n", "    else:\n", "        ax.text(0.5, 0.5, f\"missing artifact\\n{path.name}\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n", "    if title:\n", "        ax.set_title(title, fontsize=9)\n", "    ax.axis(\"off\")\n", "\n" ] }, { "cell_type": "markdown", "metadata": 
{}, "source": [ "## 1. Load all experiment logs\n", "\n", "All evidence in this notebook comes from the existing Phase 1 logs and sample folders." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loaded 7 experiments:\n", " p1a_dcgan_128\n", " p1a_dcgan_64\n", " p1b_dcgan_aligned\n", " p1b_dcgan_full\n", " p1c_dcgan_full_aug\n", " p1c_dcgan_hflip\n", " p1d_dcgan_combined\n" ] } ], "source": [ "run_names = sorted(p.stem for p in LOGS.glob(\"p1*.json\"))\n", "runs = {name: load_log(name) for name in run_names}\n", "# Drop logs that are missing or empty.\n", "runs = {k: v for k, v in runs.items() if v}\n", "\n", "print(f\"Loaded {len(runs)} experiments:\")\n", "# List only the runs that actually loaded, so the names match the count above.\n", "for name in runs:\n", "    print(f\" {name}\")\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Map each ablation to its runs and their display labels.\n", "experiment_groups = {\n", "    \"1A - Resolution\": {\n", "        \"p1a_dcgan_64\": \"64x64 (raw)\",\n", "        \"p1a_dcgan_128\": \"128x128 (raw)\",\n", "    },\n", "    \"1B - Alignment\": {\n", "        \"p1b_dcgan_full\": \"Full image (raw)\",\n", "        \"p1b_dcgan_aligned\": \"MTCNN-aligned\",\n", "    },\n", "    \"1C - Augmentation\": {\n", "        \"p1c_dcgan_hflip\": \"H-flip only\",\n", "        \"p1c_dcgan_full_aug\": \"H-flip + rot + colour\",\n", "    },\n", "    \"1D - Dataset mixing\": {\n", "        \"p1b_dcgan_aligned\": \"Aligned only\",\n", "        \"p1d_dcgan_combined\": \"Aligned + raw mixed\",\n", "    },\n", "}\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. FID comparison table\n", "\n", "The table ranks the proxy runs by FID@50 (lower is better). The values are useful within Phase 1, but they should not be compared directly with later FID protocols." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
| | Experiment | Size | Augment | FID@25 | FID@50 | G loss (ep50) | D loss (ep50) |\n", "
|---|---|---|---|---|---|---|---|\n", "
| 4 | p1c_dcgan_full_aug | 64x64 | True | 48.0 | 33.4 | 3.480 | 0.412 |\n", "
| 5 | p1c_dcgan_hflip | 64x64 | False | 48.9 | 37.9 | 3.739 | 0.392 |\n", "
| 2 | p1b_dcgan_aligned | 64x64 | False | 47.5 | 42.0 | 3.965 | 0.312 |\n", "
| 1 | p1a_dcgan_64 | 64x64 | False | 120.9 | 86.7 | 4.019 | 0.283 |\n", "
| 6 | p1d_dcgan_combined | 64x64 | False | 95.5 | 87.4 | 5.265 | 0.198 |\n", "
| 3 | p1b_dcgan_full | 64x64 | False | 109.6 | 89.0 | 3.960 | 0.370 |\n", "
| 0 | p1a_dcgan_128 | 128x128 | False | 143.1 | 115.0 | 5.013 | 0.185 |\n", "