367 lines
12 KiB
Plaintext
367 lines
12 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "a0000001",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Phase 2 — GAN Evolution Analysis\n",
|
||
"\n",
|
"Traces how the GAN was improved step by step, with each step motivated by the failure of the previous one:\n",
|
||
"\n",
|
||
"| Step | Model | Key change | Expected failure |\n",
|
||
"|------|-------|------------|------------------|\n",
|
||
"| 2.1 | DCGAN 64×64 | Baseline on best pipeline | Mode collapse, training instability |\n",
|
||
"| 2.2 | WGAN-GP | Wasserstein loss + GP | Texture artifacts, limited coherence |\n",
|
"| 2.3 | WGAN-GP + SN + GroupNorm + Attn | Principled Lipschitz constraint (spectral norm) + attention for long-range dependencies | Possible underfitting at 64×64 |\n",
|
||
"| 2.4 | 2.3 @ 128×128 | Scale resolution | ? |\n",
|
||
"\n",
|
||
"All runs use the best pipeline from Phase 1: MTCNN-aligned crops, H-flip + rotation + colour jitter, aligned-only dataset."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "a0000002",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import json\n",
|
||
"from pathlib import Path\n",
|
||
"\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import matplotlib.image as mpimg\n",
|
||
"import numpy as np\n",
|
||
"import pandas as pd\n",
|
||
"\n",
|
"# Matplotlib defaults applied to every figure in this notebook\n",
"plt.rcParams[\"figure.dpi\"] = 120\n",
"plt.rcParams[\"font.size\"] = 10\n",
"\n",
"# Experiment artefacts; the relative path is resolved against the kernel's working directory\n",
"OUTPUTS = Path(\"../outputs\")\n",
"LOGS = OUTPUTS.joinpath(\"logs\")        # per-run JSON logs (config + history)\n",
"SAMPLES = OUTPUTS.joinpath(\"samples\")  # per-run sample grids saved as epoch_NNNN.png"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "a0000003",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 1. Load experiment logs"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "a0000004",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"run_names = [\"p2_1_dcgan\", \"p2_2_wgan\", \"p2_3_wgan_sn_attn\", \"p2_4_wgan_sn_attn_128\"]\n",
"run_labels = {\n",
"    \"p2_1_dcgan\": \"2.1 DCGAN\",\n",
"    \"p2_2_wgan\": \"2.2 WGAN-GP\",\n",
"    \"p2_3_wgan_sn_attn\": \"2.3 WGAN-GP+SN+Attn\",\n",
"    \"p2_4_wgan_sn_attn_128\": \"2.4 +128×128\",\n",
"}\n",
"\n",
"# Runs without a log file yet are reported and skipped rather than raising.\n",
"runs = {}\n",
"for name in run_names:\n",
"    log_path = LOGS / f\"{name}.json\"\n",
"    if not log_path.exists():\n",
"        print(f\" Missing: {log_path}\")\n",
"        continue\n",
"    # Explicit UTF-8: JSON is UTF-8 by spec, while open()'s default encoding depends on the OS locale.\n",
"    with open(log_path, encoding=\"utf-8\") as f:\n",
"        runs[name] = json.load(f)\n",
"\n",
"print(f\"Loaded {len(runs)}/{len(run_names)} experiments:\")\n",
"for name in run_names:\n",
"    status = \"✓\" if name in runs else \"✗\"\n",
"    print(f\" {status} {name}\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "a0000005",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 2. FID Comparison Table"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "a0000006",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"def get_fid(run, epoch):\n",
"    \"\"\"Return the FID logged for `epoch` in a run log, or None if it was not evaluated.\n",
"\n",
"    `json.load` turns every object key into a string, so the string key is tried\n",
"    first; the int key is a fallback for histories built in memory.\n",
"    \"\"\"\n",
"    fid = run[\"history\"].get(\"fid\", {})\n",
"    return fid.get(str(epoch), fid.get(epoch))\n",
"\n",
"\n",
"fid_epochs = [25, 50, 75, 100]\n",
"fid_cols = [f\"FID@{e}\" for e in fid_epochs]\n",
"\n",
"rows = []\n",
"for name in run_names:\n",
"    if name not in runs:\n",
"        continue\n",
"    r = runs[name]\n",
"    cfg = r[\"config\"]\n",
"    size = cfg.get(\"image_size\", 64)\n",
"    rows.append({\n",
"        \"Step\": run_labels[name],\n",
"        \"Model\": cfg.get(\"model\"),\n",
"        \"Size\": f\"{size}×{size}\",\n",
"        **{col: get_fid(r, e) for col, e in zip(fid_cols, fid_epochs)},\n",
"    })\n",
"\n",
"df = pd.DataFrame(rows)\n",
"# na_rep renders epochs that were never evaluated as \"—\" instead of \"nan\", and avoids\n",
"# the TypeError from \"{:.1f}\".format(None) when a column is missing for every run.\n",
"df.style.format({c: \"{:.1f}\" for c in fid_cols if c in df}, na_rep=\"—\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "a0000007",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 3. FID Curves — Evolution Story"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "a0000008",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"fig, ax = plt.subplots(figsize=(11, 5))\n",
"colors = [\"#5B8DB8\", \"#E8705A\", \"#6ABF69\", \"#B86FB8\"]\n",
"\n",
"for i, name in enumerate(run_names):\n",
"    if name not in runs:\n",
"        continue\n",
"    # JSON keys are strings; convert once so epochs sort numerically (\"100\" < \"25\" lexically).\n",
"    fid_by_epoch = {int(k): v for k, v in runs[name][\"history\"].get(\"fid\", {}).items()}\n",
"    epochs = sorted(fid_by_epoch)\n",
"    fids = [fid_by_epoch[e] for e in epochs]\n",
"    # Runs that never reached epoch 100 are labelled \"?\" (formatting \"?\" with :.1f raises ValueError).\n",
"    fid100 = fid_by_epoch.get(100)\n",
"    fid100_txt = \"?\" if fid100 is None else f\"{fid100:.1f}\"\n",
"    ax.plot(epochs, fids, \"o-\", label=f\"{run_labels[name]} (FID@100={fid100_txt})\",\n",
"            color=colors[i % len(colors)], linewidth=2, markersize=8)\n",
"\n",
"ax.set_xlabel(\"Epoch\")\n",
"ax.set_ylabel(\"FID (lower is better)\")\n",
"ax.set_title(\"Phase 2 — FID Curves: DCGAN → WGAN-GP → +SN+Attn → 128×128\")\n",
"ax.legend()\n",
"plt.tight_layout()\n",
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "a0000009",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 4. Training Dynamics"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "a0000010",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"def plot_metric_panels(names, panels, suptitle, figsize, ylabel=None, colors=None):\n",
"    \"\"\"Draw one subplot per (history_key, title) in `panels`, with one line per run in `names`.\n",
"\n",
"    Metrics a run did not log are skipped instead of raising KeyError. `colors`, if given,\n",
"    is indexed in step with `names`; otherwise matplotlib's default colour cycle is used.\n",
"    \"\"\"\n",
"    fig, axes = plt.subplots(1, len(panels), figsize=figsize, squeeze=False)\n",
"    axes = axes[0]\n",
"    for i, name in enumerate(names):\n",
"        h = runs[name][\"history\"]\n",
"        style = {\"color\": colors[i]} if colors else {}\n",
"        for ax, (key, _) in zip(axes, panels):\n",
"            if key not in h:\n",
"                continue\n",
"            # Each metric gets its own x-range, so histories of unequal length cannot misalign.\n",
"            ax.plot(range(1, len(h[key]) + 1), h[key], label=run_labels[name], linewidth=1.2, **style)\n",
"    for ax, (_, title) in zip(axes, panels):\n",
"        ax.set_title(title)\n",
"        ax.set_xlabel(\"Epoch\")\n",
"        if ylabel:\n",
"            ax.set_ylabel(ylabel)\n",
"        ax.legend(fontsize=8)\n",
"    fig.suptitle(suptitle, fontweight=\"bold\")\n",
"    fig.tight_layout()\n",
"    plt.show()\n",
"\n",
"\n",
"# Separate DCGAN (has g_loss/d_loss/d_real/d_fake) from WGAN (has g_loss/w_dist/gp)\n",
"dcgan_names = [n for n in run_names if n in runs and runs[n][\"config\"].get(\"model\") == \"dcgan\"]\n",
"wgan_names = [n for n in run_names if n in runs and runs[n][\"config\"].get(\"model\") != \"dcgan\"]\n",
"\n",
"if dcgan_names:\n",
"    plot_metric_panels(\n",
"        dcgan_names,\n",
"        [(\"g_loss\", \"DCGAN — Generator Loss (BCE)\"), (\"d_loss\", \"DCGAN — Discriminator Loss\")],\n",
"        \"Phase 2.1 — DCGAN Training Dynamics\",\n",
"        figsize=(13, 4),\n",
"        ylabel=\"Loss\",\n",
"    )\n",
"\n",
"if wgan_names:\n",
"    # Set1 is a qualitative (listed) colormap: index it by integer to get distinct colours.\n",
"    set1 = plt.cm.Set1\n",
"    plot_metric_panels(\n",
"        wgan_names,\n",
"        [\n",
"            (\"g_loss\", \"Generator Loss (−E[D(G(z))])\"),\n",
"            (\"w_dist\", \"Wasserstein Distance Est. (↑ better)\"),\n",
"            (\"gp\", \"Gradient Penalty\"),\n",
"        ],\n",
"        \"Phase 2.2–2.4 — WGAN-GP Training Dynamics\",\n",
"        figsize=(16, 4),\n",
"        colors=[set1(i % set1.N) for i in range(len(wgan_names))],\n",
"    )"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "a0000011",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 5. Sample Image Grids — Epoch 100"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "a0000012",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# One panel per run, sized to the run list instead of a hard-coded 4.\n",
"fig, axes = plt.subplots(1, len(run_names), figsize=(4.5 * len(run_names), 5), squeeze=False)\n",
"\n",
"for ax, name in zip(axes[0], run_names):\n",
"    img_path = SAMPLES / name / \"epoch_0100.png\"\n",
"    if img_path.exists():\n",
"        fid = get_fid(runs[name], 100) if name in runs else None\n",
"        ax.imshow(mpimg.imread(str(img_path)))\n",
"        # `is None` rather than truthiness, so a FID of 0.0 would still be shown.\n",
"        title = run_labels[name] if fid is None else f\"{run_labels[name]}\\nFID@100={fid:.1f}\"\n",
"        ax.set_title(title, fontsize=8)\n",
"    else:\n",
"        ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
"        ax.set_title(run_labels[name], fontsize=8)\n",
"    ax.axis(\"off\")\n",
"\n",
"fig.suptitle(\"Phase 2 — Epoch 100 Sample Grids (4×4)\", fontsize=13, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "a0000013",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 6. Progression: Epoch 10 → 50 → 100"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "a0000014",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"check_epochs = [10, 50, 100]\n",
"\n",
"for name in run_names:\n",
"    if name not in runs:\n",
"        continue\n",
"    # squeeze=False keeps `axes` 2-D even if check_epochs has a single entry.\n",
"    fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4), squeeze=False)\n",
"    for ax, ep in zip(axes[0], check_epochs):\n",
"        img_path = SAMPLES / name / f\"epoch_{ep:04d}.png\"\n",
"        if img_path.exists():\n",
"            ax.imshow(mpimg.imread(str(img_path)))\n",
"            fid = get_fid(runs[name], ep)\n",
"            ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid is not None else \"\"), fontsize=9)\n",
"        else:\n",
"            ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
"        ax.axis(\"off\")\n",
"    fig.suptitle(f\"{run_labels[name]} — Training Progression\", fontsize=11, fontweight=\"bold\")\n",
"    plt.tight_layout()\n",
"    plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "a0000015",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 7. Step-by-step Pairwise Comparisons"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "a0000016",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"transitions = [\n",
"    (\"2.1→2.2: BCE→Wasserstein\", \"p2_1_dcgan\", \"p2_2_wgan\"),\n",
"    (\"2.2→2.3: +SN+GroupNorm+Attn\", \"p2_2_wgan\", \"p2_3_wgan_sn_attn\"),\n",
"    (\"2.3→2.4: 64→128 resolution\", \"p2_3_wgan_sn_attn\", \"p2_4_wgan_sn_attn_128\"),\n",
"]\n",
"\n",
"for title, name_a, name_b in transitions:\n",
"    fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n",
"    for ax, name in zip(axes, [name_a, name_b]):\n",
"        img_path = SAMPLES / name / \"epoch_0100.png\"\n",
"        if img_path.exists():\n",
"            fid = get_fid(runs[name], 100) if name in runs else None\n",
"            ax.imshow(mpimg.imread(str(img_path)))\n",
"            # `is None` rather than truthiness, so a FID of 0.0 would still be shown.\n",
"            panel_title = run_labels[name] if fid is None else f\"{run_labels[name]}\\nFID@100 = {fid:.1f}\"\n",
"            ax.set_title(panel_title, fontsize=10)\n",
"        else:\n",
"            ax.text(0.5, 0.5, \"Pending\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
"            ax.set_title(run_labels[name], fontsize=10)\n",
"        ax.axis(\"off\")\n",
"    fig.suptitle(title, fontsize=12, fontweight=\"bold\")\n",
"    plt.tight_layout()\n",
"    plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "a0000017",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 8. Conclusions"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "a0000018",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"print(\"=\" * 70)\n",
"print(\"PHASE 2 — GAN EVOLUTION SUMMARY\")\n",
"print(\"=\" * 70)\n",
"\n",
"fid100_by_run = {}\n",
"for name in run_names:\n",
"    if name not in runs:\n",
"        print(f\"\\n {run_labels[name]}: NOT YET RUN\")\n",
"        continue\n",
"    fid100 = get_fid(runs[name], 100)\n",
"    fid50 = get_fid(runs[name], 50)\n",
"    print(f\"\\n {run_labels[name]}:\")\n",
"    # Compare against None so a genuine 0.0 is printed instead of \"?\".\n",
"    print(f\" FID@50 = {fid50:.1f}\" if fid50 is not None else \" FID@50 = ?\")\n",
"    print(f\" FID@100 = {fid100:.1f}\" if fid100 is not None else \" FID@100 = ?\")\n",
"    if fid100 is not None:\n",
"        fid100_by_run[name] = fid100\n",
"\n",
"print(\"\\n\" + \"=\" * 70)\n",
"if fid100_by_run:\n",
"    # Lower FID is better; this is only a hint, the final choice is still recorded by hand.\n",
"    best = min(fid100_by_run, key=fid100_by_run.get)\n",
"    print(f\"Lowest FID@100: {run_labels[best]} ({fid100_by_run[best]:.1f})\")\n",
"print(\"Best model for Phase 3/4 comparison: fill in after runs complete.\")\n",
"print(\"=\" * 70)"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"name": "python",
|
||
"version": "3.10.0"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|