{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Phase 2 - GAN Progression\n", "\n", "Phase 2 keeps the Phase 1 pipeline fixed and changes the GAN recipe. This makes\n", "the question narrow: once the data is aligned, what model changes are needed to\n", "escape collapse?\n", "\n", "The progression moves from the DCGAN proxy to Wasserstein training, then to the\n", "stability package that finally makes the samples recognizable.\n", "\n", "## What this phase changes\n", "\n", "| Run | Recipe change |\n", "|---|---|\n", "| `p2_1_dcgan` | DCGAN baseline under the Phase 2 protocol |\n", "| `p2_2_wgan` | BCE objective replaced by Wasserstein-GP |\n", "| `p2_3_wgan_sn_attn` | Spectral norm + GroupNorm + self-attention |\n", "| `p2_4_wgan_sn_attn_128` | Same stabilized recipe at 128x128 |\n", "\n", "**Headline result:** `p2_3_wgan_sn_attn` reaches **best FID = 110.1**. The\n", "critical step is not the objective change alone; it is the stabilized 64x64\n", "recipe with spectral normalization, GroupNorm, and self-attention.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> ### FID is not comparable across phases\n", ">\n", "> Phase 1's \"best\" was FID 33 (`p1c_dcgan_full_aug`). Phase 2's \"best\" is FID 110.\n", "> **This is not a regression.** The two numbers were computed under different\n", "> protocols:\n", ">\n", "> - Phase 1 used a quick proxy FID for fast pipeline ablation, with a smaller\n", "> real-image reference set, on the un-augmented validation split.\n", "> - Phase 2 uses the project's standard FID protocol: 5000 aligned 64x64 real\n", "> images from the matched augmentation pipeline (`fid_n_real: 5000`).\n", ">\n", "> Within Phase 2 the deltas are meaningful: 2.2 -> 2.3 is about **-311 FID**,\n", "> which is a real architecture jump. 
Don't compare Phase 1 and Phase 2 numbers in absolute terms;\n", "> compare only within a phase, or against Phase 5, which uses the same protocol.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Reference: Phase 0 baseline from the same family\n", "\n", "`p0_wgan` was the unaligned, no-augmentation, basic-architecture WGAN-GP. It produced face blobs\n", "with no recognizable features (no FID logged). Phase 2 below shows what happens once\n", "the pipeline is fixed and the model is allowed to evolve.\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "bf821370", "metadata": {}, "outputs": [], "source": [ "import json\n", "from pathlib import Path\n", "\n", "import matplotlib.pyplot as plt\n", "import matplotlib.image as mpimg\n", "import numpy as np\n", "import pandas as pd\n", "\n", "plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n", "\n", "def find_generator_root():\n", "    # Walk up from the working directory until a folder holding outputs/logs and outputs/samples is found.\n", "    for base in [Path.cwd(), *Path.cwd().parents]:\n", "        for candidate in [base, base / \"generator\"]:\n", "            if (candidate / \"outputs\" / \"logs\").exists() and (candidate / \"outputs\" / \"samples\").exists():\n", "                return candidate.resolve()\n", "    raise FileNotFoundError(\"Could not locate generator/outputs from the current working directory\")\n", "\n", "GENERATOR_ROOT = find_generator_root()\n", "PROJECT_ROOT = GENERATOR_ROOT.parent\n", "OUTPUTS = GENERATOR_ROOT / \"outputs\"\n", "LOGS = OUTPUTS / \"logs\"\n", "SAMPLES = OUTPUTS / \"samples\"\n", "\n", "\n", "def load_log(name):\n", "    # Return the parsed JSON log for a run, or None if the log file is missing.\n", "    p = LOGS / f\"{name}.json\"\n", "    return json.loads(p.read_text()) if p.exists() else None\n", "\n", "def get_fid(log, epoch):\n", "    # Look up the FID at one epoch; history keys are epoch numbers stored as strings.\n", "    fid = log.get(\"history\", {}).get(\"fid\", {})\n", "    return fid.get(str(epoch))\n", "\n", "def fid_series(log):\n", "    # Return (epochs, fid_values) sorted by epoch.\n", "    fid = log.get(\"history\", {}).get(\"fid\", {})\n", "    items = sorted((int(k), v) for k, v in fid.items())\n", "    return [e for e, _ in items], [v for _, v in items]\n", "\n", "def show_image_or_missing(ax, path, title=None):\n", "    if path.exists():\n", 
"        ax.imshow(mpimg.imread(str(path)))\n", "    else:\n", "        ax.text(0.5, 0.5, f\"missing artifact\\n{path.name}\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n", "    if title:\n", "        ax.set_title(title, fontsize=9)\n", "    ax.axis(\"off\")\n", "\n" ] }, { "cell_type": "markdown", "id": "f627af73", "metadata": {}, "source": [ "## 1. Load experiment logs\n", "\n", "Only existing Phase 2 logs are loaded here. No training or re-evaluation is launched." ] }, { "cell_type": "code", "execution_count": 10, "id": "59f61b4e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " p2_1_dcgan: 100 epochs\n", " p2_2_wgan: 100 epochs\n", " p2_3_wgan_sn_attn: 100 epochs\n", " p2_4_wgan_sn_attn_128: 100 epochs\n" ] } ], "source": [ "run_names = [\"p2_1_dcgan\", \"p2_2_wgan\", \"p2_3_wgan_sn_attn\", \"p2_4_wgan_sn_attn_128\"]\n", "run_labels = {\n", "    \"p2_1_dcgan\": \"2.1 DCGAN (BCE)\",\n", "    \"p2_2_wgan\": \"2.2 WGAN-GP\",\n", "    \"p2_3_wgan_sn_attn\": \"2.3 + SN + Attn\",\n", "    \"p2_4_wgan_sn_attn_128\": \"2.4 + 128x128\",\n", "}\n", "visual_notes = {\n", "    \"p2_1_dcgan\": \"collapsed gray output\",\n", "    \"p2_2_wgan\": \"collapsed gray output\",\n", "    \"p2_3_wgan_sn_attn\": \"recognizable faces\",\n", "    \"p2_4_wgan_sn_attn_128\": \"under-trained 128x128\",\n", "}\n", "runs = {name: load_log(name) for name in run_names}\n", "runs = {k: v for k, v in runs.items() if v}\n", "for n in run_names:\n", "    if n in runs:\n", "        print(f\" {n}: {len(runs[n]['history']['g_loss'])} epochs\")\n", "    else:\n", "        print(f\" {n}: MISSING\")\n" ] }, { "cell_type": "markdown", "id": "c1bad44a", "metadata": {}, "source": [ "## 2. FID comparison table\n", "\n", "This table is the quantitative spine of the GAN progression: lower FID means generated samples are closer to the saved real reference distribution." ] }, { "cell_type": "code", "execution_count": 11, "id": "528d3bb2", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
| | Run | FID@25 | FID@50 | FID@100 | Best FID | Train (min) |\n", "
|---|---|---|---|---|---|---|\n", "
| 2 | 2.3 + SN + Attn | 274.4 | 223.2 | 110.1 | 110.1 | 39.0 |\n", "
| 3 | 2.4 + 128x128 | 428.6 | 264.3 | 186.0 | 186.0 | 97.7 |\n", "
| 1 | 2.2 WGAN-GP | 489.6 | 474.6 | 421.3 | 421.3 | 27.1 |\n", "
| 0 | 2.1 DCGAN (BCE) | 444.3 | 438.9 | 429.3 | 429.3 | 17.8 |\n", "