{ "cells": [ { "cell_type": "markdown", "id": "a0000001", "metadata": {}, "source": [ "# Phase 2 — GAN Evolution Analysis\n", "\n", "Traces the GAN improvement story, each step motivated by the failure of the previous:\n", "\n", "| Step | Model | Key change | Expected failure |\n", "|------|-------|------------|------------------|\n", "| 2.1 | DCGAN 64×64 | Baseline on best pipeline | Mode collapse, training instability |\n", "| 2.2 | WGAN-GP | Wasserstein loss + GP | Texture artifacts, limited coherence |\n", "| 2.3 | WGAN-GP + SN + GroupNorm + Attn | Principled Lipschitz constraint + long-range dependencies | Possible underfitting at 64×64 |\n", "| 2.4 | 2.3 @ 128×128 | Scale resolution | TBD — to be filled in from results |\n", "\n", "All runs use the best pipeline from Phase 1: MTCNN-aligned crops, H-flip + rotation + colour jitter, aligned-only dataset.\n", "\n", "Runs without a log file are listed as missing and skipped by the analysis cells below." ] }, { "cell_type": "code", "execution_count": null, "id": "a0000002", "metadata": {}, "outputs": [], "source": [ "# Standard library\n", "import json\n", "from pathlib import Path\n", "\n", "# Third-party\n", "import matplotlib.image as mpimg\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "\n", "# Figure defaults shared by every plot below.\n", "plt.rcParams[\"figure.dpi\"] = 120\n", "plt.rcParams[\"font.size\"] = 10\n", "\n", "# Artefact directories read by this notebook:\n", "#   logs/<run>.json               run config + training history\n", "#   samples/<run>/epoch_NNNN.png  sample grid for a given epoch\n", "OUTPUTS = Path(\"../outputs\")\n", "LOGS = OUTPUTS.joinpath(\"logs\")\n", "SAMPLES = OUTPUTS.joinpath(\"samples\")" ] }, { "cell_type": "markdown", "id": "a0000003", "metadata": {}, "source": [ "## 1. 
Load experiment logs" ] }, { "cell_type": "code", "execution_count": null, "id": "a0000004", "metadata": {}, "outputs": [], "source": [ "run_names = [\"p2_1_dcgan\", \"p2_2_wgan\", \"p2_3_wgan_sn_attn\", \"p2_4_wgan_sn_attn_128\"]\n", "run_labels = {\n", "    \"p2_1_dcgan\": \"2.1 DCGAN\",\n", "    \"p2_2_wgan\": \"2.2 WGAN-GP\",\n", "    \"p2_3_wgan_sn_attn\": \"2.3 WGAN-GP+SN+Attn\",\n", "    \"p2_4_wgan_sn_attn_128\": \"2.4 +128×128\",\n", "}\n", "\n", "# Load whichever logs exist; missing or unreadable logs are reported, not fatal,\n", "# so this notebook can be run while later steps are still training.\n", "runs = {}\n", "for name in run_names:\n", "    log_path = LOGS / f\"{name}.json\"\n", "    if not log_path.exists():\n", "        print(f\" Missing: {log_path}\")\n", "        continue\n", "    try:\n", "        with open(log_path, encoding=\"utf-8\") as f:\n", "            runs[name] = json.load(f)\n", "    except json.JSONDecodeError as err:\n", "        # e.g. a log file caught mid-write by a run that is still in progress\n", "        print(f\" Unreadable: {log_path} ({err})\")\n", "\n", "print(f\"Loaded {len(runs)}/{len(run_names)} experiments:\")\n", "for name in run_names:\n", "    status = \"✓\" if name in runs else \"✗\"\n", "    print(f\" {status} {name}\")" ] }, { "cell_type": "markdown", "id": "a0000005", "metadata": {}, "source": [ "## 2. FID Comparison Table" ] }, { "cell_type": "code", "execution_count": null, "id": "a0000006", "metadata": {}, "outputs": [], "source": [ "def get_fid(run, epoch):\n", "    \"\"\"Return the FID logged for `epoch` in `run`, or None if it was not recorded.\n", "\n", "    `json.load` always yields string keys, so the string form is tried first;\n", "    the int form covers histories built in memory.\n", "    \"\"\"\n", "    fid = run[\"history\"].get(\"fid\", {})\n", "    return fid.get(str(epoch), fid.get(epoch))\n", "\n", "\n", "FID_EPOCHS = [25, 50, 75, 100]\n", "\n", "rows = []\n", "for name in run_names:\n", "    if name not in runs:\n", "        continue\n", "    cfg = runs[name][\"config\"]\n", "    size = cfg.get(\"image_size\", 64)\n", "    row = {\"Step\": run_labels[name], \"Model\": cfg.get(\"model\"), \"Size\": f\"{size}×{size}\"}\n", "    row.update({f\"FID@{ep}\": get_fid(runs[name], ep) for ep in FID_EPOCHS})\n", "    rows.append(row)\n", "\n", "df = pd.DataFrame(rows)\n", "fid_cols = [c for c in df.columns if c.startswith(\"FID@\")]\n", "# Epochs not yet evaluated are None/NaN; a column that is entirely None has object\n", "# dtype and \"{:.1f}\".format(None) would raise, so render missing values as a dash.\n", "df.style.format({c: \"{:.1f}\" for c in fid_cols}, na_rep=\"—\")" ] }, { "cell_type": "markdown", "id": "a0000007", "metadata": {}, "source": [ "## 3. 
FID Curves — Evolution Story" ] }, { "cell_type": "code", "execution_count": null, "id": "a0000008", "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots(figsize=(11, 5))\n", "colors = [\"#5B8DB8\", \"#E8705A\", \"#6ABF69\", \"#B86FB8\"]\n", "# One fixed colour per run, reused by the training-dynamics plots below.\n", "run_colors = dict(zip(run_names, colors))\n", "\n", "for name in run_names:\n", "    if name not in runs:\n", "        continue\n", "    fid_dict = runs[name][\"history\"][\"fid\"]\n", "    epochs = sorted(int(k) for k in fid_dict)\n", "    fids = [fid_dict[str(e)] for e in epochs]\n", "    # A run still in progress has no epoch-100 FID yet; only format real numbers\n", "    # (applying \":.1f\" to the \"?\" placeholder raises ValueError).\n", "    fid100 = fid_dict.get(\"100\")\n", "    fid100_txt = \"?\" if fid100 is None else f\"{fid100:.1f}\"\n", "    ax.plot(epochs, fids, \"o-\", label=f\"{run_labels[name]} (FID@100={fid100_txt})\",\n", "            color=run_colors[name], linewidth=2, markersize=8)\n", "\n", "ax.set_xlabel(\"Epoch\")\n", "ax.set_ylabel(\"FID (lower is better)\")\n", "ax.set_title(\"Phase 2 — FID Curves: DCGAN → WGAN-GP → +SN+Attn → 128×128\")\n", "ax.legend()\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "a0000009", "metadata": {}, "source": [ "## 4. Training Dynamics" ] }, { "cell_type": "code", "execution_count": null, "id": "a0000010", "metadata": {}, "outputs": [], "source": [ "# Separate DCGAN (has g_loss/d_loss/d_real/d_fake) from WGAN (has g_loss/w_dist/gp)\n", "dcgan_names = [n for n in run_names if n in runs and runs[n][\"config\"].get(\"model\") == \"dcgan\"]\n", "wgan_names = [n for n in run_names if n in runs and runs[n][\"config\"].get(\"model\") != \"dcgan\"]\n", "\n", "if dcgan_names:\n", "    fig, axes = plt.subplots(1, 2, figsize=(13, 4))\n", "    for name in dcgan_names:\n", "        h = runs[name][\"history\"]\n", "        epochs = range(1, len(h[\"g_loss\"]) + 1)\n", "        c = run_colors[name]  # same colour as this run's FID curve\n", "        axes[0].plot(epochs, h[\"g_loss\"], label=run_labels[name], color=c, linewidth=1.2)\n", "        axes[1].plot(epochs, h[\"d_loss\"], label=run_labels[name], color=c, linewidth=1.2)\n", "    axes[0].set_title(\"DCGAN — Generator Loss (BCE)\")\n", "    axes[1].set_title(\"DCGAN — Discriminator Loss\")\n", "    for ax in axes:\n", "        ax.set_xlabel(\"Epoch\"); ax.set_ylabel(\"Loss\"); ax.legend(fontsize=8)\n", "    plt.suptitle(\"Phase 2.1 — DCGAN 
Training Dynamics\", fontweight=\"bold\")\n", "    plt.tight_layout()\n", "    plt.show()\n", "\n", "if wgan_names:\n", "    fig, axes = plt.subplots(1, 3, figsize=(16, 4))\n", "    for name in wgan_names:\n", "        h = runs[name][\"history\"]\n", "        epochs = range(1, len(h[\"g_loss\"]) + 1)\n", "        c = run_colors[name]  # same colour as this run's FID curve\n", "        axes[0].plot(epochs, h[\"g_loss\"], label=run_labels[name], color=c, linewidth=1.2)\n", "        axes[1].plot(epochs, h[\"w_dist\"], label=run_labels[name], color=c, linewidth=1.2)\n", "        axes[2].plot(epochs, h[\"gp\"], label=run_labels[name], color=c, linewidth=1.2)\n", "    axes[0].set_title(\"Generator Loss (−E[D(G(z))])\")\n", "    # NOTE(review): assumes w_dist is logged as E[D(real)] − E[D(fake)] — TODO confirm against\n", "    # the training script. As a distance estimate it shrinks as G(z) approaches the data.\n", "    axes[1].set_title(\"Wasserstein Distance Est. (lower = closer to real)\")\n", "    axes[2].set_title(\"Gradient Penalty\")\n", "    for ax, ylabel in zip(axes, [\"Loss\", \"W-distance estimate\", \"Penalty\"]):\n", "        ax.set_xlabel(\"Epoch\"); ax.set_ylabel(ylabel); ax.legend(fontsize=8)\n", "    plt.suptitle(\"Phase 2.2–2.4 — WGAN-GP Training Dynamics\", fontweight=\"bold\")\n", "    plt.tight_layout()\n", "    plt.show()" ] }, { "cell_type": "markdown", "id": "a0000011", "metadata": {}, "source": [ "## 5. 
Sample Image Grids — Epoch 100" ] }, { "cell_type": "code", "execution_count": null, "id": "a0000012", "metadata": {}, "outputs": [], "source": [ "# One panel per run, sized from run_names rather than a hard-coded 4.\n", "n_runs = len(run_names)\n", "fig, axes = plt.subplots(1, n_runs, figsize=(4.5 * n_runs, 5), squeeze=False)\n", "\n", "for ax, name in zip(axes[0], run_names):\n", "    img_path = SAMPLES / name / \"epoch_0100.png\"\n", "    if img_path.exists():\n", "        fid = get_fid(runs[name], 100) if name in runs else None\n", "        ax.imshow(mpimg.imread(str(img_path)))\n", "        # `is not None` rather than truthiness, so a legitimate 0.0 is still shown\n", "        title = f\"{run_labels[name]}\\nFID@100={fid:.1f}\" if fid is not None else run_labels[name]\n", "        ax.set_title(title, fontsize=8)\n", "    else:\n", "        ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n", "        ax.set_title(run_labels[name], fontsize=8)\n", "    ax.axis(\"off\")\n", "\n", "fig.suptitle(\"Phase 2 — Epoch 100 Sample Grids (4×4)\", fontsize=13, fontweight=\"bold\")\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "a0000013", "metadata": {}, "source": [ "## 6. Progression: Epoch 10 → 50 → 100" ] }, { "cell_type": "code", "execution_count": null, "id": "a0000014", "metadata": {}, "outputs": [], "source": [ "check_epochs = [10, 50, 100]\n", "\n", "for name in run_names:\n", "    if name not in runs:\n", "        continue\n", "    # squeeze=False keeps `axes` 2-D even if check_epochs has a single entry\n", "    fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4), squeeze=False)\n", "    for ax, ep in zip(axes[0], check_epochs):\n", "        img_path = SAMPLES / name / f\"epoch_{ep:04d}.png\"\n", "        if img_path.exists():\n", "            ax.imshow(mpimg.imread(str(img_path)))\n", "            fid = get_fid(runs[name], ep)\n", "            ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid is not None else \"\"), fontsize=9)\n", "        else:\n", "            ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n", "        ax.axis(\"off\")\n", "    fig.suptitle(f\"{run_labels[name]} — Training Progression\", fontsize=11, fontweight=\"bold\")\n", "    plt.tight_layout()\n", "    plt.show()" ] }, { "cell_type": "markdown", "id": "a0000015", "metadata": {}, "source": [ "## 7. 
Step-by-step Pairwise Comparisons" ] }, { "cell_type": "code", "execution_count": null, "id": "a0000016", "metadata": {}, "outputs": [], "source": [ "transitions = [\n", "    (\"2.1→2.2: BCE→Wasserstein\", \"p2_1_dcgan\", \"p2_2_wgan\"),\n", "    (\"2.2→2.3: +SN+GroupNorm+Attn\", \"p2_2_wgan\", \"p2_3_wgan_sn_attn\"),\n", "    (\"2.3→2.4: 64→128 resolution\", \"p2_3_wgan_sn_attn\", \"p2_4_wgan_sn_attn_128\"),\n", "]\n", "\n", "for title, name_a, name_b in transitions:\n", "    # FID@100 per side (None if the run or its epoch-100 FID is missing)\n", "    fids = {n: get_fid(runs[n], 100) if n in runs else None for n in (name_a, name_b)}\n", "    fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n", "    for ax, name in zip(axes, [name_a, name_b]):\n", "        img_path = SAMPLES / name / \"epoch_0100.png\"\n", "        fid = fids[name]\n", "        if img_path.exists():\n", "            ax.imshow(mpimg.imread(str(img_path)))\n", "            ax.set_title(f\"{run_labels[name]}\\nFID@100 = {fid:.1f}\" if fid is not None else run_labels[name], fontsize=10)\n", "        else:\n", "            ax.text(0.5, 0.5, \"Pending\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n", "            ax.set_title(run_labels[name], fontsize=10)\n", "        ax.axis(\"off\")\n", "    heading = title\n", "    # Quantify the step; lower FID is better, so a negative ΔFID means the later model improved.\n", "    if fids[name_a] is not None and fids[name_b] is not None:\n", "        heading += f\"  (ΔFID@100 = {fids[name_b] - fids[name_a]:+.1f})\"\n", "    fig.suptitle(heading, fontsize=12, fontweight=\"bold\")\n", "    plt.tight_layout()\n", "    plt.show()" ] }, { "cell_type": "markdown", "id": "a0000017", "metadata": {}, "source": [ "## 8. 
Conclusions" ] }, { "cell_type": "code", "execution_count": null, "id": "a0000018", "metadata": {}, "outputs": [], "source": [ "print(\"=\" * 70)\n", "print(\"PHASE 2 — GAN EVOLUTION SUMMARY\")\n", "print(\"=\" * 70)\n", "\n", "for name in run_names:\n", "    if name not in runs:\n", "        print(f\"\\n {run_labels[name]}: NOT YET RUN\")\n", "        continue\n", "    fid100 = get_fid(runs[name], 100)\n", "    fid50 = get_fid(runs[name], 50)\n", "    print(f\"\\n {run_labels[name]}:\")\n", "    print(f\" FID@50 = {fid50:.1f}\" if fid50 is not None else \" FID@50 = ?\")\n", "    print(f\" FID@100 = {fid100:.1f}\" if fid100 is not None else \" FID@100 = ?\")\n", "\n", "# Best model = lowest FID@100 among the runs that have reached epoch 100.\n", "final_fids = {n: get_fid(runs[n], 100) for n in run_names if n in runs}\n", "completed = {n: f for n, f in final_fids.items() if f is not None}\n", "\n", "print(\"\\n\" + \"=\" * 70)\n", "if completed:\n", "    best = min(completed, key=completed.get)\n", "    print(f\"Best model for Phase 3/4 comparison: {run_labels[best]} (FID@100 = {completed[best]:.1f})\")\n", "else:\n", "    print(\"Best model for Phase 3/4 comparison: fill in after runs complete.\")\n", "print(\"=\" * 70)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.0" } }, "nbformat": 4, "nbformat_minor": 5 }