Preview of phase 2-5 implementation; needs a full check

This commit is contained in:
Johnny Fernandes
2026-04-30 13:10:33 +01:00
parent 6e32001ebc
commit 7417267117
35 changed files with 3605 additions and 115 deletions
+366
View File
@@ -0,0 +1,366 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a0000001",
"metadata": {},
"source": [
"# Phase 2 — GAN Evolution Analysis\n",
"\n",
"Traces the GAN improvement story, each step motivated by the failure of the previous:\n",
"\n",
"| Step | Model | Key change | Expected failure |\n",
"|------|-------|------------|------------------|\n",
"| 2.1 | DCGAN 64×64 | Baseline on best pipeline | Mode collapse, training instability |\n",
"| 2.2 | WGAN-GP | Wasserstein loss + GP | Texture artifacts, limited coherence |\n",
"| 2.3 | WGAN-GP + SN + GroupNorm + Attn | Principled Lipschitz + long-range deps | Possible underfitting at 64×64 |\n",
"| 2.4 | 2.3 @ 128×128 | Scale resolution | ? |\n",
"\n",
"All runs use the best pipeline from Phase 1: MTCNN-aligned crops, H-flip + rotation + colour jitter, aligned-only dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000002",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n",
"\n",
"OUTPUTS = Path(\"../outputs\")\n",
"LOGS = OUTPUTS / \"logs\"\n",
"SAMPLES = OUTPUTS / \"samples\""
]
},
{
"cell_type": "markdown",
"id": "a0000003",
"metadata": {},
"source": [
"## 1. Load experiment logs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000004",
"metadata": {},
"outputs": [],
"source": [
"run_names = [\"p2_1_dcgan\", \"p2_2_wgan\", \"p2_3_wgan_sn_attn\", \"p2_4_wgan_sn_attn_128\"]\n",
"run_labels = {\n",
" \"p2_1_dcgan\": \"2.1 DCGAN\",\n",
" \"p2_2_wgan\": \"2.2 WGAN-GP\",\n",
" \"p2_3_wgan_sn_attn\": \"2.3 WGAN-GP+SN+Attn\",\n",
" \"p2_4_wgan_sn_attn_128\": \"2.4 +128×128\",\n",
"}\n",
"\n",
"runs = {}\n",
"for name in run_names:\n",
" log_path = LOGS / f\"{name}.json\"\n",
" if log_path.exists():\n",
" with open(log_path) as f:\n",
" runs[name] = json.load(f)\n",
" else:\n",
" print(f\" Missing: {log_path}\")\n",
"\n",
"print(f\"Loaded {len(runs)}/{len(run_names)} experiments:\")\n",
"for name in run_names:\n",
" status = \"✓\" if name in runs else \"✗\"\n",
" print(f\" {status} {name}\")"
]
},
{
"cell_type": "markdown",
"id": "a0000005",
"metadata": {},
"source": [
"## 2. FID Comparison Table"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000006",
"metadata": {},
"outputs": [],
"source": [
"def get_fid(run, epoch):\n",
" fid = run[\"history\"][\"fid\"]\n",
" return fid.get(str(epoch), fid.get(epoch, None))\n",
"\n",
"rows = []\n",
"for name in run_names:\n",
" if name not in runs:\n",
" continue\n",
" r = runs[name]\n",
" cfg = r[\"config\"]\n",
" h = r[\"history\"]\n",
" rows.append({\n",
" \"Step\": run_labels[name],\n",
" \"Model\": cfg.get(\"model\"),\n",
" \"Size\": f\"{cfg.get('image_size', 64)}×{cfg.get('image_size', 64)}\",\n",
" \"FID@25\": get_fid(r, 25),\n",
" \"FID@50\": get_fid(r, 50),\n",
" \"FID@75\": get_fid(r, 75),\n",
" \"FID@100\": get_fid(r, 100),\n",
" })\n",
"\n",
"df = pd.DataFrame(rows)\n",
"df.style.format({c: \"{:.1f}\" for c in [\"FID@25\", \"FID@50\", \"FID@75\", \"FID@100\"] if c in df})"
]
},
{
"cell_type": "markdown",
"id": "a0000007",
"metadata": {},
"source": [
"## 3. FID Curves — Evolution Story"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000008",
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(11, 5))\n",
"colors = [\"#5B8DB8\", \"#E8705A\", \"#6ABF69\", \"#B86FB8\"]\n",
"\n",
"for i, name in enumerate(run_names):\n",
" if name not in runs:\n",
" continue\n",
" fid_dict = runs[name][\"history\"][\"fid\"]\n",
" epochs = sorted(int(k) for k in fid_dict)\n",
" fids = [fid_dict[str(e)] for e in epochs]\n",
" ax.plot(epochs, fids, \"o-\", label=f\"{run_labels[name]} (FID@100={fid_dict.get('100', '?'):.1f})\",\n",
" color=colors[i], linewidth=2, markersize=8)\n",
"\n",
"ax.set_xlabel(\"Epoch\")\n",
"ax.set_ylabel(\"FID (lower is better)\")\n",
"ax.set_title(\"Phase 2 — FID Curves: DCGAN → WGAN-GP → +SN+Attn → 128×128\")\n",
"ax.legend()\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "a0000009",
"metadata": {},
"source": [
"## 4. Training Dynamics"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000010",
"metadata": {},
"outputs": [],
"source": [
"# Separate DCGAN (has g_loss/d_loss/d_real/d_fake) from WGAN (has g_loss/w_dist/gp)\n",
"dcgan_names = [n for n in run_names if n in runs and runs[n][\"config\"].get(\"model\") == \"dcgan\"]\n",
"wgan_names = [n for n in run_names if n in runs and runs[n][\"config\"].get(\"model\") != \"dcgan\"]\n",
"\n",
"if dcgan_names:\n",
" fig, axes = plt.subplots(1, 2, figsize=(13, 4))\n",
" for name in dcgan_names:\n",
" h = runs[name][\"history\"]\n",
" epochs = range(1, len(h[\"g_loss\"]) + 1)\n",
" axes[0].plot(epochs, h[\"g_loss\"], label=run_labels[name], linewidth=1.2)\n",
" axes[1].plot(epochs, h[\"d_loss\"], label=run_labels[name], linewidth=1.2)\n",
" axes[0].set_title(\"DCGAN — Generator Loss (BCE)\")\n",
" axes[1].set_title(\"DCGAN — Discriminator Loss\")\n",
" for ax in axes:\n",
" ax.set_xlabel(\"Epoch\"); ax.set_ylabel(\"Loss\"); ax.legend(fontsize=8)\n",
" plt.suptitle(\"Phase 2.1 — DCGAN Training Dynamics\", fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
"if wgan_names:\n",
" fig, axes = plt.subplots(1, 3, figsize=(16, 4))\n",
" cmap = plt.cm.Set1\n",
" for i, name in enumerate(wgan_names):\n",
" h = runs[name][\"history\"]\n",
" epochs = range(1, len(h[\"g_loss\"]) + 1)\n",
" c = cmap(i / max(len(wgan_names), 1))\n",
" axes[0].plot(epochs, h[\"g_loss\"], label=run_labels[name], color=c, linewidth=1.2)\n",
" axes[1].plot(epochs, h[\"w_dist\"], label=run_labels[name], color=c, linewidth=1.2)\n",
" axes[2].plot(epochs, h[\"gp\"], label=run_labels[name], color=c, linewidth=1.2)\n",
" axes[0].set_title(\"Generator Loss (E[D(G(z))])\")\n",
" axes[1].set_title(\"Wasserstein Distance Est. (↑ better)\")\n",
" axes[2].set_title(\"Gradient Penalty\")\n",
" for ax in axes:\n",
" ax.set_xlabel(\"Epoch\"); ax.legend(fontsize=8)\n",
" plt.suptitle(\"Phase 2.22.4 — WGAN-GP Training Dynamics\", fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "a0000011",
"metadata": {},
"source": [
"## 5. Sample Image Grids — Epoch 100"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000012",
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(1, 4, figsize=(18, 5))\n",
"\n",
"for idx, name in enumerate(run_names):\n",
" ax = axes[idx]\n",
" img_path = SAMPLES / name / \"epoch_0100.png\"\n",
" if img_path.exists():\n",
" fid = get_fid(runs[name], 100) if name in runs else None\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" ax.set_title(f\"{run_labels[name]}\\nFID@100={fid:.1f}\" if fid else run_labels[name], fontsize=8)\n",
" else:\n",
" ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
" ax.set_title(run_labels[name], fontsize=8)\n",
" ax.axis(\"off\")\n",
"\n",
"fig.suptitle(\"Phase 2 — Epoch 100 Sample Grids (4×4)\", fontsize=13, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "a0000013",
"metadata": {},
"source": [
"## 6. Progression: Epoch 10 → 50 → 100"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000014",
"metadata": {},
"outputs": [],
"source": [
"check_epochs = [10, 50, 100]\n",
"\n",
"for name in run_names:\n",
" if name not in runs:\n",
" continue\n",
" fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))\n",
" for ax, ep in zip(axes, check_epochs):\n",
" img_path = SAMPLES / name / f\"epoch_{ep:04d}.png\"\n",
" if img_path.exists():\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" fid = get_fid(runs[name], ep)\n",
" ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid else \"\"), fontsize=9)\n",
" else:\n",
" ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
" ax.axis(\"off\")\n",
" fig.suptitle(f\"{run_labels[name]} — Training Progression\", fontsize=11, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "a0000015",
"metadata": {},
"source": [
"## 7. Step-by-step Pairwise Comparisons"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000016",
"metadata": {},
"outputs": [],
"source": [
"transitions = [\n",
" (\"2.1→2.2: BCE→Wasserstein\", \"p2_1_dcgan\", \"p2_2_wgan\"),\n",
" (\"2.2→2.3: +SN+GroupNorm+Attn\", \"p2_2_wgan\", \"p2_3_wgan_sn_attn\"),\n",
" (\"2.3→2.4: 64→128 resolution\", \"p2_3_wgan_sn_attn\", \"p2_4_wgan_sn_attn_128\"),\n",
"]\n",
"\n",
"for title, name_a, name_b in transitions:\n",
" fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n",
" for ax, name in zip(axes, [name_a, name_b]):\n",
" img_path = SAMPLES / name / \"epoch_0100.png\"\n",
" if img_path.exists():\n",
" fid = get_fid(runs[name], 100) if name in runs else None\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" ax.set_title(f\"{run_labels[name]}\\nFID@100 = {fid:.1f}\" if fid else run_labels[name], fontsize=10)\n",
" else:\n",
" ax.text(0.5, 0.5, \"Pending\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
" ax.set_title(run_labels[name], fontsize=10)\n",
" ax.axis(\"off\")\n",
" fig.suptitle(title, fontsize=12, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "a0000017",
"metadata": {},
"source": [
"## 8. Conclusions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000018",
"metadata": {},
"outputs": [],
"source": [
"print(\"=\" * 70)\n",
"print(\"PHASE 2 — GAN EVOLUTION SUMMARY\")\n",
"print(\"=\" * 70)\n",
"\n",
"for name in run_names:\n",
" if name not in runs:\n",
" print(f\"\\n {run_labels[name]}: NOT YET RUN\")\n",
" continue\n",
" fid100 = get_fid(runs[name], 100)\n",
" fid50 = get_fid(runs[name], 50)\n",
" print(f\"\\n {run_labels[name]}:\")\n",
" print(f\" FID@50 = {fid50:.1f}\" if fid50 else \" FID@50 = ?\")\n",
" print(f\" FID@100 = {fid100:.1f}\" if fid100 else \" FID@100 = ?\")\n",
"\n",
"print(\"\\n\" + \"=\" * 70)\n",
"print(\"Best model for Phase 3/4 comparison: fill in after runs complete.\")\n",
"print(\"=\" * 70)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}