Preview of phase 2-5 implementation; needs a full check

This commit is contained in:
Johnny Fernandes
2026-04-30 13:10:33 +01:00
parent 6e32001ebc
commit 7417267117
35 changed files with 3605 additions and 115 deletions
+366
View File
@@ -0,0 +1,366 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a0000001",
"metadata": {},
"source": [
"# Phase 2 — GAN Evolution Analysis\n",
"\n",
"Traces the GAN improvement story, each step motivated by the failure of the previous:\n",
"\n",
"| Step | Model | Key change | Expected failure |\n",
"|------|-------|------------|------------------|\n",
"| 2.1 | DCGAN 64×64 | Baseline on best pipeline | Mode collapse, training instability |\n",
"| 2.2 | WGAN-GP | Wasserstein loss + GP | Texture artifacts, limited coherence |\n",
"| 2.3 | WGAN-GP + SN + GroupNorm + Attn | Principled Lipschitz + long-range deps | Possible underfitting at 64×64 |\n",
"| 2.4 | 2.3 @ 128×128 | Scale resolution | ? |\n",
"\n",
"All runs use the best pipeline from Phase 1: MTCNN-aligned crops, H-flip + rotation + colour jitter, aligned-only dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000002",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n",
"\n",
"OUTPUTS = Path(\"../outputs\")\n",
"LOGS = OUTPUTS / \"logs\"\n",
"SAMPLES = OUTPUTS / \"samples\""
]
},
{
"cell_type": "markdown",
"id": "a0000003",
"metadata": {},
"source": [
"## 1. Load experiment logs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000004",
"metadata": {},
"outputs": [],
"source": [
"run_names = [\"p2_1_dcgan\", \"p2_2_wgan\", \"p2_3_wgan_sn_attn\", \"p2_4_wgan_sn_attn_128\"]\n",
"run_labels = {\n",
" \"p2_1_dcgan\": \"2.1 DCGAN\",\n",
" \"p2_2_wgan\": \"2.2 WGAN-GP\",\n",
" \"p2_3_wgan_sn_attn\": \"2.3 WGAN-GP+SN+Attn\",\n",
" \"p2_4_wgan_sn_attn_128\": \"2.4 +128×128\",\n",
"}\n",
"\n",
"runs = {}\n",
"for name in run_names:\n",
" log_path = LOGS / f\"{name}.json\"\n",
" if log_path.exists():\n",
" with open(log_path) as f:\n",
" runs[name] = json.load(f)\n",
" else:\n",
" print(f\" Missing: {log_path}\")\n",
"\n",
"print(f\"Loaded {len(runs)}/{len(run_names)} experiments:\")\n",
"for name in run_names:\n",
" status = \"✓\" if name in runs else \"✗\"\n",
" print(f\" {status} {name}\")"
]
},
{
"cell_type": "markdown",
"id": "a0000005",
"metadata": {},
"source": [
"## 2. FID Comparison Table"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000006",
"metadata": {},
"outputs": [],
"source": [
"def get_fid(run, epoch):\n",
" fid = run[\"history\"][\"fid\"]\n",
" return fid.get(str(epoch), fid.get(epoch, None))\n",
"\n",
"rows = []\n",
"for name in run_names:\n",
" if name not in runs:\n",
" continue\n",
" r = runs[name]\n",
" cfg = r[\"config\"]\n",
" h = r[\"history\"]\n",
" rows.append({\n",
" \"Step\": run_labels[name],\n",
" \"Model\": cfg.get(\"model\"),\n",
" \"Size\": f\"{cfg.get('image_size', 64)}×{cfg.get('image_size', 64)}\",\n",
" \"FID@25\": get_fid(r, 25),\n",
" \"FID@50\": get_fid(r, 50),\n",
" \"FID@75\": get_fid(r, 75),\n",
" \"FID@100\": get_fid(r, 100),\n",
" })\n",
"\n",
"df = pd.DataFrame(rows)\n",
"df.style.format({c: \"{:.1f}\" for c in [\"FID@25\", \"FID@50\", \"FID@75\", \"FID@100\"] if c in df})"
]
},
{
"cell_type": "markdown",
"id": "a0000007",
"metadata": {},
"source": [
"## 3. FID Curves — Evolution Story"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000008",
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(11, 5))\n",
"colors = [\"#5B8DB8\", \"#E8705A\", \"#6ABF69\", \"#B86FB8\"]\n",
"\n",
"for i, name in enumerate(run_names):\n",
" if name not in runs:\n",
" continue\n",
" fid_dict = runs[name][\"history\"][\"fid\"]\n",
" epochs = sorted(int(k) for k in fid_dict)\n",
" fids = [fid_dict[str(e)] for e in epochs]\n",
    "    ax.plot(epochs, fids, \"o-\", label=run_labels[name] + (f\" (FID@100={fid_dict['100']:.1f})\" if '100' in fid_dict else \"\"),\n",
" color=colors[i], linewidth=2, markersize=8)\n",
"\n",
"ax.set_xlabel(\"Epoch\")\n",
"ax.set_ylabel(\"FID (lower is better)\")\n",
"ax.set_title(\"Phase 2 — FID Curves: DCGAN → WGAN-GP → +SN+Attn → 128×128\")\n",
"ax.legend()\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "a0000009",
"metadata": {},
"source": [
"## 4. Training Dynamics"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000010",
"metadata": {},
"outputs": [],
"source": [
"# Separate DCGAN (has g_loss/d_loss/d_real/d_fake) from WGAN (has g_loss/w_dist/gp)\n",
"dcgan_names = [n for n in run_names if n in runs and runs[n][\"config\"].get(\"model\") == \"dcgan\"]\n",
"wgan_names = [n for n in run_names if n in runs and runs[n][\"config\"].get(\"model\") != \"dcgan\"]\n",
"\n",
"if dcgan_names:\n",
" fig, axes = plt.subplots(1, 2, figsize=(13, 4))\n",
" for name in dcgan_names:\n",
" h = runs[name][\"history\"]\n",
" epochs = range(1, len(h[\"g_loss\"]) + 1)\n",
" axes[0].plot(epochs, h[\"g_loss\"], label=run_labels[name], linewidth=1.2)\n",
" axes[1].plot(epochs, h[\"d_loss\"], label=run_labels[name], linewidth=1.2)\n",
" axes[0].set_title(\"DCGAN — Generator Loss (BCE)\")\n",
" axes[1].set_title(\"DCGAN — Discriminator Loss\")\n",
" for ax in axes:\n",
" ax.set_xlabel(\"Epoch\"); ax.set_ylabel(\"Loss\"); ax.legend(fontsize=8)\n",
" plt.suptitle(\"Phase 2.1 — DCGAN Training Dynamics\", fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
"if wgan_names:\n",
" fig, axes = plt.subplots(1, 3, figsize=(16, 4))\n",
" cmap = plt.cm.Set1\n",
" for i, name in enumerate(wgan_names):\n",
" h = runs[name][\"history\"]\n",
" epochs = range(1, len(h[\"g_loss\"]) + 1)\n",
" c = cmap(i / max(len(wgan_names), 1))\n",
" axes[0].plot(epochs, h[\"g_loss\"], label=run_labels[name], color=c, linewidth=1.2)\n",
" axes[1].plot(epochs, h[\"w_dist\"], label=run_labels[name], color=c, linewidth=1.2)\n",
" axes[2].plot(epochs, h[\"gp\"], label=run_labels[name], color=c, linewidth=1.2)\n",
" axes[0].set_title(\"Generator Loss (E[D(G(z))])\")\n",
" axes[1].set_title(\"Wasserstein Distance Est. (↑ better)\")\n",
" axes[2].set_title(\"Gradient Penalty\")\n",
" for ax in axes:\n",
" ax.set_xlabel(\"Epoch\"); ax.legend(fontsize=8)\n",
    "    plt.suptitle(\"Phase 2.2–2.4 — WGAN-GP Training Dynamics\", fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "a0000011",
"metadata": {},
"source": [
"## 5. Sample Image Grids — Epoch 100"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000012",
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(1, 4, figsize=(18, 5))\n",
"\n",
"for idx, name in enumerate(run_names):\n",
" ax = axes[idx]\n",
" img_path = SAMPLES / name / \"epoch_0100.png\"\n",
" if img_path.exists():\n",
" fid = get_fid(runs[name], 100) if name in runs else None\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" ax.set_title(f\"{run_labels[name]}\\nFID@100={fid:.1f}\" if fid else run_labels[name], fontsize=8)\n",
" else:\n",
" ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
" ax.set_title(run_labels[name], fontsize=8)\n",
" ax.axis(\"off\")\n",
"\n",
"fig.suptitle(\"Phase 2 — Epoch 100 Sample Grids (4×4)\", fontsize=13, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "a0000013",
"metadata": {},
"source": [
"## 6. Progression: Epoch 10 → 50 → 100"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000014",
"metadata": {},
"outputs": [],
"source": [
"check_epochs = [10, 50, 100]\n",
"\n",
"for name in run_names:\n",
" if name not in runs:\n",
" continue\n",
" fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))\n",
" for ax, ep in zip(axes, check_epochs):\n",
" img_path = SAMPLES / name / f\"epoch_{ep:04d}.png\"\n",
" if img_path.exists():\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" fid = get_fid(runs[name], ep)\n",
" ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid else \"\"), fontsize=9)\n",
" else:\n",
" ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
" ax.axis(\"off\")\n",
" fig.suptitle(f\"{run_labels[name]} — Training Progression\", fontsize=11, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "a0000015",
"metadata": {},
"source": [
"## 7. Step-by-step Pairwise Comparisons"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000016",
"metadata": {},
"outputs": [],
"source": [
"transitions = [\n",
" (\"2.1→2.2: BCE→Wasserstein\", \"p2_1_dcgan\", \"p2_2_wgan\"),\n",
" (\"2.2→2.3: +SN+GroupNorm+Attn\", \"p2_2_wgan\", \"p2_3_wgan_sn_attn\"),\n",
" (\"2.3→2.4: 64→128 resolution\", \"p2_3_wgan_sn_attn\", \"p2_4_wgan_sn_attn_128\"),\n",
"]\n",
"\n",
"for title, name_a, name_b in transitions:\n",
" fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n",
" for ax, name in zip(axes, [name_a, name_b]):\n",
" img_path = SAMPLES / name / \"epoch_0100.png\"\n",
" if img_path.exists():\n",
" fid = get_fid(runs[name], 100) if name in runs else None\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" ax.set_title(f\"{run_labels[name]}\\nFID@100 = {fid:.1f}\" if fid else run_labels[name], fontsize=10)\n",
" else:\n",
" ax.text(0.5, 0.5, \"Pending\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
" ax.set_title(run_labels[name], fontsize=10)\n",
" ax.axis(\"off\")\n",
" fig.suptitle(title, fontsize=12, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "a0000017",
"metadata": {},
"source": [
"## 8. Conclusions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000018",
"metadata": {},
"outputs": [],
"source": [
"print(\"=\" * 70)\n",
"print(\"PHASE 2 — GAN EVOLUTION SUMMARY\")\n",
"print(\"=\" * 70)\n",
"\n",
"for name in run_names:\n",
" if name not in runs:\n",
" print(f\"\\n {run_labels[name]}: NOT YET RUN\")\n",
" continue\n",
" fid100 = get_fid(runs[name], 100)\n",
" fid50 = get_fid(runs[name], 50)\n",
" print(f\"\\n {run_labels[name]}:\")\n",
" print(f\" FID@50 = {fid50:.1f}\" if fid50 else \" FID@50 = ?\")\n",
" print(f\" FID@100 = {fid100:.1f}\" if fid100 else \" FID@100 = ?\")\n",
"\n",
"print(\"\\n\" + \"=\" * 70)\n",
"print(\"Best model for Phase 3/4 comparison: fill in after runs complete.\")\n",
"print(\"=\" * 70)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+396
View File
@@ -0,0 +1,396 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "b0000001",
"metadata": {},
"source": [
"# Phase 3 — VAE Evolution Analysis\n",
"\n",
"Traces the VAE improvement story — each step motivated by the failure of the previous:\n",
"\n",
"| Step | Model | Key change | Expected failure |\n",
"|------|-------|------------|------------------|\n",
"| 3.1 | Vanilla VAE (MSE+KL) | Baseline | Blurry samples — MSE minimises pixel average |\n",
"| 3.2 | + Perceptual loss (VGG) | Feature-space reconstruction | Residual texture blur |\n",
"| 3.3 | + PatchGAN (VQGAN-lite) | Local texture adversarial | — |\n",
"\n",
"All runs use H-flip-only augmentation and MTCNN-aligned 64×64 crops.\n",
"FID is computed from prior samples (`z ~ N(0, I) → decode`), same metric as GAN.\n",
"Reconstructions are shown separately to diagnose encoder quality."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000002",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n",
"\n",
"OUTPUTS = Path(\"../outputs\")\n",
"LOGS = OUTPUTS / \"logs\"\n",
"SAMPLES = OUTPUTS / \"samples\""
]
},
{
"cell_type": "markdown",
"id": "b0000003",
"metadata": {},
"source": [
"## 1. Load experiment logs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000004",
"metadata": {},
"outputs": [],
"source": [
"run_names = [\"p3_1_vae\", \"p3_2_vae_perceptual\", \"p3_3_vae_patchgan\"]\n",
"run_labels = {\n",
" \"p3_1_vae\": \"3.1 VAE (MSE+KL)\",\n",
" \"p3_2_vae_perceptual\": \"3.2 +Perceptual\",\n",
" \"p3_3_vae_patchgan\": \"3.3 +PatchGAN\",\n",
"}\n",
"\n",
"runs = {}\n",
"for name in run_names:\n",
" log_path = LOGS / f\"{name}.json\"\n",
" if log_path.exists():\n",
" with open(log_path) as f:\n",
" runs[name] = json.load(f)\n",
" else:\n",
" print(f\" Missing: {log_path}\")\n",
"\n",
"print(f\"Loaded {len(runs)}/{len(run_names)} experiments:\")\n",
"for name in run_names:\n",
" status = \"✓\" if name in runs else \"✗\"\n",
" print(f\" {status} {name}\")"
]
},
{
"cell_type": "markdown",
"id": "b0000005",
"metadata": {},
"source": [
"## 2. FID Comparison Table (prior samples)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000006",
"metadata": {},
"outputs": [],
"source": [
"def get_fid(run, epoch):\n",
" fid = run[\"history\"][\"fid\"]\n",
" return fid.get(str(epoch), fid.get(epoch, None))\n",
"\n",
"rows = []\n",
"for name in run_names:\n",
" if name not in runs:\n",
" continue\n",
" r = runs[name]\n",
" cfg = r[\"config\"]\n",
" h = r[\"history\"]\n",
" rows.append({\n",
" \"Step\": run_labels[name],\n",
" \"λ_perc\": cfg.get(\"lambda_perceptual\", 0),\n",
" \"λ_adv\": cfg.get(\"lambda_adversarial\", 0),\n",
" \"β_kl\": cfg.get(\"beta_kl\", 1.0),\n",
" \"FID@25 (prior)\": get_fid(r, 25),\n",
" \"FID@50 (prior)\": get_fid(r, 50),\n",
" \"FID@75 (prior)\": get_fid(r, 75),\n",
" \"FID@100 (prior)\": get_fid(r, 100),\n",
" })\n",
"\n",
"df = pd.DataFrame(rows)\n",
"df.style.format({c: \"{:.1f}\" for c in df.columns if \"FID\" in c})"
]
},
{
"cell_type": "markdown",
"id": "b0000007",
"metadata": {},
"source": [
"## 3. FID Curves — Evolution Story"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000008",
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(10, 5))\n",
"colors = [\"#5B8DB8\", \"#E8705A\", \"#6ABF69\"]\n",
"\n",
"for i, name in enumerate(run_names):\n",
" if name not in runs:\n",
" continue\n",
" fid_dict = runs[name][\"history\"][\"fid\"]\n",
" epochs = sorted(int(k) for k in fid_dict)\n",
" fids = [fid_dict[str(e)] for e in epochs]\n",
    "    label = run_labels[name] + (f\" (FID@100={fid_dict['100']:.1f})\" if '100' in fid_dict else \"\")\n",
" ax.plot(epochs, fids, \"o-\", label=label, color=colors[i], linewidth=2, markersize=8)\n",
"\n",
"ax.set_xlabel(\"Epoch\")\n",
"ax.set_ylabel(\"FID (lower is better) — prior samples\")\n",
"ax.set_title(\"Phase 3 — FID Curves: VAE → +Perceptual → +PatchGAN\")\n",
"ax.legend()\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b0000009",
"metadata": {},
"source": [
"## 4. Training Loss Curves"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000010",
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(2, 3, figsize=(16, 9))\n",
"axes = axes.flatten()\n",
"keys = [\"recon_loss\", \"kl_loss\", \"perc_loss\", \"adv_g_loss\", \"adv_d_loss\"]\n",
"titles = [\"MSE Reconstruction\", \"KL Divergence\", \"Perceptual (VGG)\", \"Adv G Loss\", \"Adv D Loss (PatchGAN)\"]\n",
"\n",
"for ax, key, title in zip(axes, keys, titles):\n",
" for i, name in enumerate(run_names):\n",
" if name not in runs:\n",
" continue\n",
" h = runs[name][\"history\"].get(key, [])\n",
" if any(v != 0.0 for v in h):\n",
" ax.plot(range(1, len(h)+1), h, label=run_labels[name],\n",
" color=colors[i], linewidth=1.2, alpha=0.9)\n",
" ax.set_title(title)\n",
" ax.set_xlabel(\"Epoch\")\n",
" ax.legend(fontsize=8)\n",
"\n",
"axes[-1].axis(\"off\") # empty sixth panel\n",
"fig.suptitle(\"Phase 3 — Training Dynamics\", fontsize=13, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b0000011",
"metadata": {},
"source": [
"## 5. Prior Samples — Epoch 100"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000012",
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
"\n",
"for idx, name in enumerate(run_names):\n",
" ax = axes[idx]\n",
" img_path = SAMPLES / name / \"epoch_0100.png\"\n",
" if img_path.exists():\n",
" fid = get_fid(runs[name], 100) if name in runs else None\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" ax.set_title(f\"{run_labels[name]}\\nFID@100={fid:.1f}\" if fid else run_labels[name], fontsize=9)\n",
" else:\n",
" ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
" ax.set_title(run_labels[name], fontsize=9)\n",
" ax.axis(\"off\")\n",
"\n",
"fig.suptitle(\"Phase 3 — Epoch 100 Prior Samples (4×4 grids)\", fontsize=12, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b0000013",
"metadata": {},
"source": [
"## 6. Reconstructions — Epoch 100\n",
"\n",
"Left half = real images, right half = reconstructions (interleaved pairs)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000014",
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
"\n",
"for idx, name in enumerate(run_names):\n",
" ax = axes[idx]\n",
" img_path = SAMPLES / name / \"epoch_0100_recon.png\"\n",
" if img_path.exists():\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" ax.set_title(f\"{run_labels[name]}\\nreal | recon\", fontsize=9)\n",
" else:\n",
" ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
" ax.set_title(run_labels[name], fontsize=9)\n",
" ax.axis(\"off\")\n",
"\n",
"fig.suptitle(\"Phase 3 — Epoch 100 Reconstructions\", fontsize=12, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b0000015",
"metadata": {},
"source": [
"## 7. Step-by-step Pairwise Comparisons"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000016",
"metadata": {},
"outputs": [],
"source": [
"transitions = [\n",
" (\"3.1→3.2: MSE→+Perceptual\", \"p3_1_vae\", \"p3_2_vae_perceptual\"),\n",
" (\"3.2→3.3: +PatchGAN adversarial\", \"p3_2_vae_perceptual\", \"p3_3_vae_patchgan\"),\n",
"]\n",
"\n",
"for title, name_a, name_b in transitions:\n",
" fig, axes = plt.subplots(2, 2, figsize=(12, 10))\n",
" for col, name in enumerate([name_a, name_b]):\n",
" for row, suffix in enumerate([\"\", \"_recon\"]):\n",
" ax = axes[row][col]\n",
" ep = 100\n",
" img_path = SAMPLES / name / f\"epoch_{ep:04d}{suffix}.png\"\n",
" label = run_labels[name]\n",
" kind = \"prior\" if suffix == \"\" else \"recon\"\n",
" if img_path.exists():\n",
" fid = get_fid(runs[name], ep) if (suffix == \"\" and name in runs) else None\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" ax.set_title(f\"{label}\\n{kind}\" + (f\" FID={fid:.1f}\" if fid else \"\"), fontsize=9)\n",
" else:\n",
" ax.text(0.5, 0.5, \"Pending\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
" ax.set_title(f\"{label} ({kind})\", fontsize=9)\n",
" ax.axis(\"off\")\n",
" fig.suptitle(title, fontsize=12, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b0000017",
"metadata": {},
"source": [
"## 8. Progression: Epoch 10 → 50 → 100 (prior samples)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000018",
"metadata": {},
"outputs": [],
"source": [
"check_epochs = [10, 50, 100]\n",
"\n",
"for name in run_names:\n",
" if name not in runs:\n",
" continue\n",
" fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))\n",
" for ax, ep in zip(axes, check_epochs):\n",
" img_path = SAMPLES / name / f\"epoch_{ep:04d}.png\"\n",
" if img_path.exists():\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" fid = get_fid(runs[name], ep)\n",
" ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid else \"\"), fontsize=9)\n",
" else:\n",
" ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\",\n",
" transform=ax.transAxes)\n",
" ax.axis(\"off\")\n",
" fig.suptitle(f\"{run_labels[name]} — Prior Sample Progression\", fontsize=11, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b0000019",
"metadata": {},
"source": [
"## 9. Conclusions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000020",
"metadata": {},
"outputs": [],
"source": [
"print(\"=\" * 70)\n",
"print(\"PHASE 3 — VAE EVOLUTION SUMMARY\")\n",
"print(\"=\" * 70)\n",
"\n",
"for name in run_names:\n",
" if name not in runs:\n",
" print(f\"\\n {run_labels[name]}: NOT YET RUN\")\n",
" continue\n",
" h = runs[name][\"history\"]\n",
" fid100 = get_fid(runs[name], 100)\n",
" fid50 = get_fid(runs[name], 50)\n",
" mse50 = h[\"recon_loss\"][49] if len(h[\"recon_loss\"]) > 49 else None\n",
" kl50 = h[\"kl_loss\"][49] if len(h[\"kl_loss\"]) > 49 else None\n",
" print(f\"\\n {run_labels[name]}:\")\n",
" print(f\" FID@50 = {fid50:.1f}\" if fid50 else \" FID@50 = ?\")\n",
" print(f\" FID@100 = {fid100:.1f}\" if fid100 else \" FID@100 = ?\")\n",
" print(f\" MSE@50 = {mse50:.4f}\" if mse50 else \" MSE@50 = ?\")\n",
" print(f\" KL@50 = {kl50:.2f}\" if kl50 else \" KL@50 = ?\")\n",
"\n",
"print(\"\\n\" + \"=\" * 70)\n",
"print(\"Best VAE model for Phase 5 cross-family comparison: fill in after runs.\")\n",
"print(\"=\" * 70)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+403
View File
@@ -0,0 +1,403 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "c0000001",
"metadata": {},
"source": [
"# Phase 4 — DDPM Evolution Analysis\n",
"\n",
"Traces the DDPM improvement story:\n",
"\n",
"| Step | Model | Key change | Expected failure |\n",
"|------|-------|------------|------------------|\n",
"| 4.1 | DDPM linear + ε-pred | Baseline | Noise prediction unstable at very low t (linear schedule over-denoises) |\n",
"| 4.2 | + cosine schedule | Less noise wasted at low timesteps | Residual instability from ε parameterisation |\n",
"| 4.3 | + v-prediction | Numerically stable across full trajectory | Possible underfitting at 64×64 |\n",
"| 4.4 | + wider U-Net (192ch) + 32×32 attention | More capacity and longer-range context | — |\n",
"\n",
"FID is computed via DDIM (100 steps, deterministic) from the EMA model.\n",
"H-flip-only augmentation, MTCNN-aligned 64×64 crops, T=1000."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000002",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n",
"\n",
"OUTPUTS = Path(\"../outputs\")\n",
"LOGS = OUTPUTS / \"logs\"\n",
"SAMPLES = OUTPUTS / \"samples\""
]
},
{
"cell_type": "markdown",
"id": "c0000003",
"metadata": {},
"source": [
"## 1. Load experiment logs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000004",
"metadata": {},
"outputs": [],
"source": [
"run_names = [\"p4_1_ddpm_linear\", \"p4_2_ddpm_cosine\", \"p4_3_ddpm_vpred\", \"p4_4_ddpm_wider\"]\n",
"run_labels = {\n",
" \"p4_1_ddpm_linear\": \"4.1 linear + ε\",\n",
" \"p4_2_ddpm_cosine\": \"4.2 cosine + ε\",\n",
" \"p4_3_ddpm_vpred\": \"4.3 cosine + v\",\n",
" \"p4_4_ddpm_wider\": \"4.4 wider + 32×32 attn\",\n",
"}\n",
"\n",
"runs = {}\n",
"for name in run_names:\n",
" log_path = LOGS / f\"{name}.json\"\n",
" if log_path.exists():\n",
" with open(log_path) as f:\n",
" runs[name] = json.load(f)\n",
" else:\n",
" print(f\" Missing: {log_path}\")\n",
"\n",
"print(f\"Loaded {len(runs)}/{len(run_names)} experiments:\")\n",
"for name in run_names:\n",
" print(f\" {'✓' if name in runs else '✗'} {name}\")"
]
},
{
"cell_type": "markdown",
"id": "c0000005",
"metadata": {},
"source": [
"## 2. FID Comparison Table"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000006",
"metadata": {},
"outputs": [],
"source": [
"def get_fid(run, epoch):\n",
" fid = run[\"history\"][\"fid\"]\n",
" return fid.get(str(epoch), fid.get(epoch, None))\n",
"\n",
"rows = []\n",
"for name in run_names:\n",
" if name not in runs:\n",
" continue\n",
" r = runs[name]\n",
" cfg = r[\"config\"]\n",
" rows.append({\n",
" \"Step\": run_labels[name],\n",
" \"Schedule\": cfg.get(\"noise_schedule\"),\n",
" \"Pred\": cfg.get(\"pred_type\"),\n",
" \"base_ch\": cfg.get(\"base_ch\"),\n",
" \"FID@25\": get_fid(r, 25),\n",
" \"FID@50\": get_fid(r, 50),\n",
" \"FID@75\": get_fid(r, 75),\n",
" \"FID@100\": get_fid(r, 100),\n",
" })\n",
"\n",
"df = pd.DataFrame(rows)\n",
"df.style.format({c: \"{:.1f}\" for c in df.columns if \"FID\" in c})"
]
},
{
"cell_type": "markdown",
"id": "c0000007",
"metadata": {},
"source": [
"## 3. FID Curves — Evolution Story"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000008",
"metadata": {},
"outputs": [],
"source": [
"colors = [\"#5B8DB8\", \"#E8705A\", \"#6ABF69\", \"#B86FB8\"]\n",
"\n",
"fig, ax = plt.subplots(figsize=(11, 5))\n",
"for i, name in enumerate(run_names):\n",
" if name not in runs:\n",
" continue\n",
" fid_dict = runs[name][\"history\"][\"fid\"]\n",
" epochs = sorted(int(k) for k in fid_dict)\n",
" fids = [fid_dict[str(e)] for e in epochs]\n",
" fid100 = fid_dict.get(\"100\", \"?\")\n",
" label = f\"{run_labels[name]} (FID@100={fid100:.1f})\" if isinstance(fid100, float) else run_labels[name]\n",
" ax.plot(epochs, fids, \"o-\", label=label, color=colors[i], linewidth=2, markersize=8)\n",
"\n",
"ax.set_xlabel(\"Epoch\")\n",
"ax.set_ylabel(\"FID (DDIM 100 steps, lower is better)\")\n",
"ax.set_title(\"Phase 4 — FID Curves: linear·ε → cosine·ε → cosine·v → wider\")\n",
"ax.legend()\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c0000009",
"metadata": {},
"source": [
"## 4. Training Loss Curves (MSE on predicted target)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000010",
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(11, 4))\n",
"for i, name in enumerate(run_names):\n",
" if name not in runs:\n",
" continue\n",
" losses = runs[name][\"history\"][\"loss\"]\n",
" ax.plot(range(1, len(losses)+1), losses, label=run_labels[name],\n",
" color=colors[i], linewidth=1.2, alpha=0.9)\n",
"\n",
"ax.set_xlabel(\"Epoch\")\n",
"ax.set_ylabel(\"MSE (noise / v prediction loss)\")\n",
"ax.set_title(\"Phase 4 — Training Loss\")\n",
"ax.legend(fontsize=9)\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c0000011",
"metadata": {},
"source": [
"## 5. Sample Grids — Epoch 100 (DDIM 50 steps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000012",
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(1, 4, figsize=(20, 5))\n",
"\n",
"for idx, name in enumerate(run_names):\n",
" ax = axes[idx]\n",
" img_path = SAMPLES / name / \"epoch_0100.png\"\n",
" if img_path.exists():\n",
" fid = get_fid(runs[name], 100) if name in runs else None\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" ax.set_title(f\"{run_labels[name]}\\nFID@100={fid:.1f}\" if fid else run_labels[name], fontsize=8)\n",
" else:\n",
" ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
" ax.set_title(run_labels[name], fontsize=8)\n",
" ax.axis(\"off\")\n",
"\n",
"fig.suptitle(\"Phase 4 — Epoch 100 DDIM Samples (4×4 grids)\", fontsize=12, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c0000013",
"metadata": {},
"source": [
"## 6. Step-by-step Pairwise Comparisons"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000014",
"metadata": {},
"outputs": [],
"source": [
"transitions = [\n",
" (\"4.1→4.2: linear→cosine schedule\", \"p4_1_ddpm_linear\", \"p4_2_ddpm_cosine\"),\n",
" (\"4.2→4.3: ε-pred→v-pred\", \"p4_2_ddpm_cosine\", \"p4_3_ddpm_vpred\"),\n",
" (\"4.3→4.4: 128ch→192ch + 32×32 attn\", \"p4_3_ddpm_vpred\", \"p4_4_ddpm_wider\"),\n",
"]\n",
"\n",
"for title, name_a, name_b in transitions:\n",
" fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n",
" for ax, name in zip(axes, [name_a, name_b]):\n",
" img_path = SAMPLES / name / \"epoch_0100.png\"\n",
" if img_path.exists():\n",
" fid = get_fid(runs[name], 100) if name in runs else None\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" ax.set_title(f\"{run_labels[name]}\\nFID@100={fid:.1f}\" if fid else run_labels[name], fontsize=9)\n",
" else:\n",
" ax.text(0.5, 0.5, \"Pending\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
" ax.set_title(run_labels[name], fontsize=9)\n",
" ax.axis(\"off\")\n",
" fig.suptitle(title, fontsize=12, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c0000015",
"metadata": {},
"source": [
"## 7. Progression: Epoch 10 → 50 → 100"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000016",
"metadata": {},
"outputs": [],
"source": [
"check_epochs = [10, 50, 100]\n",
"\n",
"for name in run_names:\n",
" if name not in runs:\n",
" continue\n",
" fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))\n",
" for ax, ep in zip(axes, check_epochs):\n",
" img_path = SAMPLES / name / f\"epoch_{ep:04d}.png\"\n",
" if img_path.exists():\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" fid = get_fid(runs[name], ep)\n",
" ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid else \"\"), fontsize=9)\n",
" else:\n",
" ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\",\n",
" transform=ax.transAxes)\n",
" ax.axis(\"off\")\n",
" fig.suptitle(f\"{run_labels[name]} — Progression\", fontsize=11, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c0000017",
"metadata": {},
"source": [
"## 8. Noise Schedule Visualisation\n",
"\n",
"Illustrates why cosine outperforms linear: the linear schedule allocates many timesteps near t=T where the image is already near-pure noise, wasting model capacity."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000018",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import numpy as np\n",
"\n",
"T = 1000\n",
"t = np.arange(T)\n",
"\n",
"# Linear betas\n",
"betas_lin = np.linspace(1e-4, 0.02, T)\n",
"ab_lin = np.cumprod(1 - betas_lin)\n",
"\n",
"# Cosine betas\n",
    "s = 0.008\n",
    "steps = np.arange(T + 1)  # T+1 grid points -> exactly T betas, so ab_cos aligns with t\n",
    "f = np.cos((steps / T + s) / (1 + s) * math.pi / 2) ** 2\n",
    "f = f / f[0]\n",
    "betas_cos = np.clip(1 - f[1:] / f[:-1], 0, 0.999)\n",
    "ab_cos = np.cumprod(1 - betas_cos)\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(13, 4))\n",
"\n",
"axes[0].plot(t, ab_lin, label=\"linear\", color=\"#5B8DB8\", linewidth=2)\n",
"axes[0].plot(t, ab_cos, label=\"cosine\", color=\"#E8705A\", linewidth=2)\n",
"axes[0].set_xlabel(\"Timestep t\")\n",
"axes[0].set_ylabel(\"ᾱ_t (signal fraction)\")\n",
"axes[0].set_title(\"ᾱ_t vs t — cosine stays informative longer\")\n",
"axes[0].legend()\n",
"\n",
"axes[1].plot(t, np.sqrt(ab_lin), label=\"linear √ᾱ\", color=\"#5B8DB8\", linewidth=2)\n",
"axes[1].plot(t, np.sqrt(ab_cos), label=\"cosine √ᾱ\", color=\"#E8705A\", linewidth=2)\n",
"axes[1].set_xlabel(\"Timestep t\")\n",
"axes[1].set_ylabel(\"√ᾱ_t (signal amplitude)\")\n",
"axes[1].set_title(\"Signal amplitude — cosine is more uniform\")\n",
"axes[1].legend()\n",
"\n",
"plt.suptitle(\"Noise Schedule Comparison\", fontsize=12, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c0000019",
"metadata": {},
"source": [
"## 9. Conclusions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000020",
"metadata": {},
"outputs": [],
"source": [
"print(\"=\" * 70)\n",
"print(\"PHASE 4 — DDPM EVOLUTION SUMMARY\")\n",
"print(\"=\" * 70)\n",
"\n",
"for name in run_names:\n",
" if name not in runs:\n",
" print(f\"\\n {run_labels[name]}: NOT YET RUN\")\n",
" continue\n",
" fid100 = get_fid(runs[name], 100)\n",
" fid50 = get_fid(runs[name], 50)\n",
" h = runs[name][\"history\"]\n",
" loss50 = h[\"loss\"][49] if len(h[\"loss\"]) > 49 else None\n",
" print(f\"\\n {run_labels[name]}:\")\n",
" print(f\" FID@50 = {fid50:.1f}\" if fid50 else \" FID@50 = ?\")\n",
" print(f\" FID@100 = {fid100:.1f}\" if fid100 else \" FID@100 = ?\")\n",
" print(f\" Loss@50 = {loss50:.5f}\" if loss50 else \" Loss@50 = ?\")\n",
"\n",
"print(\"\\n\" + \"=\" * 70)\n",
"print(\"Best DDPM model for Phase 5 comparison: fill in after runs complete.\")\n",
"print(\"=\" * 70)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+669
View File
@@ -0,0 +1,669 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Phase 5 — Cross-Family Comparison\n",
"\n",
"Best-of-each finalist retrained for **200 epochs** under identical data conditions.\n",
"\n",
"| Family | Config | Resolution | Key design |\n",
"|--------|--------|-----------|------------|\n",
"| GAN | `p5_gan` | 128×128 | WGAN-GP + SpectralNorm + GroupNorm + Self-Attention |\n",
"| VAE | `p5_vae` | 64×64 | Convolutional VAE + VGG perceptual + PatchGAN |\n",
"| DDPM | `p5_ddpm` | 64×64 | Wider U-Net (192ch) + cosine schedule + v-prediction |\n",
"\n",
"> **Resolution note**: GAN runs at 128×128 (best architecture from Phase 2.4), while VAE and DDPM run at 64×64. FID is measured at each model's native resolution against real images at that resolution, so scores are not directly numerically comparable across families — they are indicators of within-family improvement."
],
"id": "d0000001"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import json\n",
"import sys\n",
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n",
"\n",
"OUTPUTS = Path(\"../outputs\")\n",
"LOGS = OUTPUTS / \"logs\"\n",
"SAMPLES = OUTPUTS / \"samples\"\n",
"\n",
"# Phase 5 finalists and their best phase-2/3/4 counterparts (100-epoch comparison)\n",
"FAMILIES = {\n",
" \"GAN\": {\"p5\": \"p5_gan\", \"p4ep\": \"p2_4_wgan_sn_attn_128\", \"label\": \"WGAN-GP+SN+Attn\", \"color\": \"#5B8DB8\"},\n",
" \"VAE\": {\"p5\": \"p5_vae\", \"p4ep\": \"p3_3_vae_patchgan\", \"label\": \"VAE+Perc+PatchGAN\",\"color\": \"#E8705A\"},\n",
" \"DDPM\": {\"p5\": \"p5_ddpm\", \"p4ep\": \"p4_4_ddpm_wider\", \"label\": \"DDPM wider 192ch\", \"color\": \"#6ABF69\"},\n",
"}"
],
"execution_count": null,
"outputs": [],
"id": "d0000002"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Load logs"
],
"id": "d0000003"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def load_log(run_name):\n",
" p = LOGS / f\"{run_name}.json\"\n",
" if p.exists():\n",
" with open(p) as f:\n",
" return json.load(f)\n",
" return None\n",
"\n",
"def get_fid(log, epoch):\n",
" if log is None:\n",
" return None\n",
" fid = log[\"history\"][\"fid\"]\n",
" return fid.get(str(epoch), fid.get(epoch, None))\n",
"\n",
"logs_p5 = {fam: load_log(info[\"p5\"]) for fam, info in FAMILIES.items()}\n",
"logs_p4 = {fam: load_log(info[\"p4ep\"]) for fam, info in FAMILIES.items()}\n",
"\n",
"for fam in FAMILIES:\n",
" p5_ok = \"✓\" if logs_p5[fam] else \"✗\"\n",
" p4_ok = \"✓\" if logs_p4[fam] else \"✗\"\n",
" print(f\" {fam}: 200ep={p5_ok} 100ep={p4_ok}\")"
],
"execution_count": null,
"outputs": [],
"id": "d0000004"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Quantitative Summary Table"
],
"id": "d0000005"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"rows = []\n",
"for fam, info in FAMILIES.items():\n",
" log = logs_p5[fam]\n",
" train_time = log.get(\"history\", {}).get(\"train_time_s\") if log else None\n",
" train_min = f\"{train_time / 60:.1f}\" if train_time else \"?\"\n",
" rows.append({\n",
" \"Family\": fam,\n",
" \"Model\": info[\"label\"],\n",
" \"Res\": log.get(\"config\", log).get(\"image_size\", \"?\") if log else \"?\",\n",
" \"Params\": log.get(\"n_params\") if log else None,\n",
" \"Train (min)\": train_min,\n",
" \"FID@100\": get_fid(log, 100),\n",
" \"FID@150\": get_fid(log, 150),\n",
" \"FID@200\": get_fid(log, 200),\n",
" # IS and LPIPS filled in by Section 6\n",
" \"IS ↑\": None,\n",
" \"LPIPS ↑\": None,\n",
" })\n",
"\n",
"df = pd.DataFrame(rows).set_index(\"Family\")\n",
"\n",
"\n",
"def fmt_params(v):\n",
" if v is None:\n",
" return \"?\"\n",
" if v >= 1_000_000:\n",
" return f\"{v / 1_000_000:.1f}M\"\n",
" if v >= 1_000:\n",
" return f\"{v / 1_000:.0f}K\"\n",
" return str(v)\n",
"\n",
"\n",
"df_display = df.copy()\n",
"df_display[\"Params\"] = df_display[\"Params\"].apply(fmt_params)\n",
"df_display.style.format({c: \"{:.1f}\" for c in [\"FID@100\", \"FID@150\", \"FID@200\"] if c in df_display})"
],
"execution_count": null,
"outputs": [],
"id": "d0000006"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. FID Curves — All Three Families"
],
"id": "d0000007"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"fig, axes = plt.subplots(1, 2, figsize=(15, 5))\n",
"\n",
"for fam, info in FAMILIES.items():\n",
" c = info[\"color\"]\n",
" # Phase 5 (200ep) — solid\n",
" log = logs_p5[fam]\n",
" if log:\n",
" fid_dict = log[\"history\"][\"fid\"]\n",
" eps = sorted(int(k) for k in fid_dict)\n",
" fids = [fid_dict[str(e)] for e in eps]\n",
" axes[0].plot(eps, fids, \"-o\", color=c, linewidth=2, markersize=6,\n",
" label=f\"{fam} 200ep (FID@200={fid_dict.get('200','?'):.1f})\" if isinstance(fid_dict.get('200'), float) else f\"{fam} 200ep\")\n",
"\n",
" # Phase 4/3/2 best (100ep) — dashed, same colour\n",
" log4 = logs_p4[fam]\n",
" if log4:\n",
" fid4 = log4[\"history\"][\"fid\"]\n",
" eps4 = sorted(int(k) for k in fid4)\n",
" fids4 = [fid4[str(e)] for e in eps4]\n",
" axes[0].plot(eps4, fids4, \"--\", color=c, linewidth=1.2, alpha=0.55,\n",
" label=f\"{fam} 100ep\")\n",
"\n",
"axes[0].set_xlabel(\"Epoch\")\n",
"axes[0].set_ylabel(\"FID (lower is better)\")\n",
"axes[0].set_title(\"FID Curves — Phase 5 (solid) vs best 100-ep (dashed)\")\n",
"axes[0].legend(fontsize=8)\n",
"\n",
"# Bar chart: FID at 100 vs 200 epochs per family\n",
"fams = list(FAMILIES.keys())\n",
"fid100 = [get_fid(logs_p5[f], 100) for f in fams]\n",
"fid200 = [get_fid(logs_p5[f], 200) for f in fams]\n",
"x = np.arange(len(fams))\n",
"w = 0.35\n",
"bars1 = axes[1].bar(x - w/2, [v or 0 for v in fid100], w, label=\"FID@100\", color=[FAMILIES[f][\"color\"] for f in fams], alpha=0.6)\n",
"bars2 = axes[1].bar(x + w/2, [v or 0 for v in fid200], w, label=\"FID@200\", color=[FAMILIES[f][\"color\"] for f in fams], alpha=1.0)\n",
"axes[1].set_xticks(x); axes[1].set_xticklabels(fams)\n",
"axes[1].set_ylabel(\"FID\"); axes[1].set_title(\"FID: 100ep vs 200ep per family\")\n",
"axes[1].legend()\n",
"for bar in list(bars1) + list(bars2):\n",
" h = bar.get_height()\n",
" if h > 0:\n",
" axes[1].text(bar.get_x() + bar.get_width()/2, h + 1, f\"{h:.0f}\", ha=\"center\", va=\"bottom\", fontsize=8)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
],
"execution_count": null,
"outputs": [],
"id": "d0000008"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Sample Grids — Epoch 200"
],
"id": "d0000009"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"fig, axes = plt.subplots(1, 3, figsize=(16, 6))\n",
"\n",
"for idx, (fam, info) in enumerate(FAMILIES.items()):\n",
" ax = axes[idx]\n",
" img_path = SAMPLES / info[\"p5\"] / \"epoch_0200.png\"\n",
" log = logs_p5[fam]\n",
" fid200 = get_fid(log, 200)\n",
" if img_path.exists():\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" title = f\"{fam}: {info['label']}\\n\"\n",
" if fid200:\n",
" res = log[\"config\"].get(\"image_size\", \"?\")\n",
" title += f\"FID@200={fid200:.1f} ({res}×{res})\"\n",
" ax.set_title(title, fontsize=9)\n",
" else:\n",
" ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes, fontsize=12)\n",
" ax.set_title(f\"{fam}: {info['label']}\", fontsize=9)\n",
" ax.axis(\"off\")\n",
"\n",
"fig.suptitle(\"Phase 5 — Epoch 200 Sample Grids (4×4, prior samples)\", fontsize=13, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
],
"execution_count": null,
"outputs": [],
"id": "d0000010"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Training Progression — Epoch 10 → 50 → 100 → 200"
],
"id": "d0000011"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"check_epochs = [10, 50, 100, 200]\n",
"\n",
"for fam, info in FAMILIES.items():\n",
" run = info[\"p5\"]\n",
" log = logs_p5[fam]\n",
" if log is None:\n",
" print(f\"{fam}: not yet run\")\n",
" continue\n",
"\n",
" fig, axes = plt.subplots(1, len(check_epochs), figsize=(16, 4))\n",
" for ax, ep in zip(axes, check_epochs):\n",
" img_path = SAMPLES / run / f\"epoch_{ep:04d}.png\"\n",
" if img_path.exists():\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" fid = get_fid(log, ep)\n",
" ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid else \"\"), fontsize=9)\n",
" else:\n",
" ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\",\n",
" transform=ax.transAxes)\n",
" ax.axis(\"off\")\n",
" fig.suptitle(f\"{fam} ({info['label']}) — 200-epoch progression\", fontsize=11, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()"
],
"execution_count": null,
"outputs": [],
"id": "d0000012"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Extended Metrics — IS and LPIPS Diversity\n",
"\n",
"Requires loading trained model weights. Run after all phase 5 models have finished.\n",
"Generates 5 000 samples per model and computes IS and LPIPS over 200 random pairs."
],
"id": "d0000013"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import sys\n",
"sys.path.insert(0, \"..\")\n",
"\n",
"import torch\n",
"from src.utils import load_config\n",
"from src.models import get_model\n",
"from src.training.metrics import compute_is, compute_lpips_diversity\n",
"from src.training.diffusion import cosine_betas, make_alpha_bars, ddim_sample\n",
"from src.training.ema import EMA\n",
"\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"N_SAMPLE = 5_000\n",
"\n",
"def load_ema_model(run_name, config_path):\n",
" \"\"\"Load the best EMA weights for a given phase-5 run.\"\"\"\n",
" cfg = load_config(str(Path(\"../configs/phase5\") / config_path))\n",
" model_obj, kind = get_model(cfg)\n",
" ema_path = Path(\"../outputs/models\") / f\"{run_name}_best_ema.pt\"\n",
" if not ema_path.exists():\n",
" ema_path = Path(\"../outputs/models\") / f\"{run_name}_final_ema.pt\"\n",
" if isinstance(model_obj, tuple):\n",
" model = model_obj[0] # generator for GAN\n",
" else:\n",
" model = model_obj\n",
" model.load_state_dict(torch.load(ema_path, map_location=DEVICE))\n",
" return model.to(DEVICE).eval(), cfg, kind\n",
"\n",
"\n",
"@torch.no_grad()\n",
"def generate_samples(run_name, config_path, n=N_SAMPLE):\n",
" model, cfg, kind = load_ema_model(run_name, config_path)\n",
" image_size = cfg.get(\"image_size\", 64)\n",
"\n",
" if kind == \"wgan\":\n",
" latent_dim = cfg.get(\"latent_dim\", 128)\n",
" imgs = torch.cat([\n",
" model(torch.randn(min(64, n - i), latent_dim, 1, 1, device=DEVICE))\n",
" for i in range(0, n, 64)\n",
" ])[:n].cpu()\n",
"\n",
" elif kind == \"vae\":\n",
" imgs = torch.cat([\n",
" model.sample(min(64, n - i), DEVICE)\n",
" for i in range(0, n, 64)\n",
" ])[:n].cpu()\n",
"\n",
" elif kind == \"ddpm\":\n",
" schedule = cfg.get(\"noise_schedule\", \"cosine\")\n",
" pred_type = cfg.get(\"pred_type\", \"v\")\n",
" T = cfg.get(\"T\", 1000)\n",
" from src.training.diffusion import cosine_betas, linear_betas, make_alpha_bars, ddim_sample\n",
" betas = (cosine_betas(T) if schedule == \"cosine\" else linear_betas(T)).to(DEVICE)\n",
" ab = make_alpha_bars(betas)\n",
" imgs = ddim_sample(model, n, image_size, ab, n_steps=100,\n",
" pred_type=pred_type, device=DEVICE, batch_size=32)\n",
" return imgs\n",
"\n",
"\n",
"print(\"Run this cell once all phase-5 models have completed.\")\n",
"print(\"Expected to take ~5–10 minutes per family on an RTX 3090.\")"
],
"execution_count": null,
"outputs": [],
"id": "d0000014"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# ── Compute IS and LPIPS per family ──────────────────────────────────────────\n",
"# Safe to run before all models are ready: any failure is caught and that family is skipped.\n",
"\n",
"extended_metrics = {}\n",
"\n",
"for fam, info in FAMILIES.items():\n",
" run = info[\"p5\"]\n",
" config = f\"{run}.json\"\n",
" try:\n",
" print(f\"\\n{fam}: generating {N_SAMPLE} samples...\")\n",
" imgs = generate_samples(run, config)\n",
"\n",
" print(f\" Computing IS...\")\n",
" is_mean, is_std = compute_is(imgs, device=DEVICE)\n",
"\n",
" print(f\" Computing LPIPS diversity...\")\n",
" lpips = compute_lpips_diversity(imgs, n_pairs=200, device=DEVICE)\n",
"\n",
" extended_metrics[fam] = {\"IS_mean\": is_mean, \"IS_std\": is_std, \"LPIPS\": lpips}\n",
" print(f\" IS = {is_mean:.2f} ± {is_std:.2f} LPIPS = {lpips:.4f}\")\n",
"\n",
" except Exception as e:\n",
" print(f\" Skipped ({e})\")\n",
" extended_metrics[fam] = {}\n",
"\n",
"# Merge with FID table\n",
"for fam in FAMILIES:\n",
" em = extended_metrics.get(fam, {})\n",
" idx = list(FAMILIES.keys()).index(fam)\n",
" df.loc[fam, \"IS ↑\"] = f\"{em['IS_mean']:.2f}±{em['IS_std']:.2f}\" if \"IS_mean\" in em else \"—\"\n",
" df.loc[fam, \"LPIPS ↑\"] = f\"{em['LPIPS']:.4f}\" if \"LPIPS\" in em else \"—\"\n",
"\n",
"df.style.format({c: \"{:.1f}\" for c in [\"FID@100\", \"FID@150\", \"FID@200\"] if c in df})"
],
"execution_count": null,
"outputs": [],
"id": "d0000015"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Latent Interpolation — GAN and VAE\n",
"\n",
"Smooth interpolation between two latent codes reveals whether the generator has learned a\n",
"continuous manifold. DDPM has no encoder, so true latent interpolation is not available; we instead compare samples generated from different noise seeds."
],
"id": "d0000016"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Spherical linear interpolation (slerp)\n",
"def slerp(z1, z2, t):\n",
" z1_n = z1 / z1.norm()\n",
" z2_n = z2 / z2.norm()\n",
" omega = torch.acos((z1_n * z2_n).sum().clamp(-1, 1))\n",
" if omega.abs() < 1e-6:\n",
" return (1 - t) * z1 + t * z2\n",
" return (torch.sin((1-t)*omega)/torch.sin(omega)) * z1 + \\\n",
" (torch.sin(t*omega)/torch.sin(omega)) * z2\n",
"\n",
"\n",
"def gan_interpolation(model, latent_dim, n_steps=10, device=DEVICE):\n",
" z1 = torch.randn(1, latent_dim, 1, 1, device=device)\n",
" z2 = torch.randn(1, latent_dim, 1, 1, device=device)\n",
" alphas = torch.linspace(0, 1, n_steps)\n",
" imgs = []\n",
" with torch.no_grad():\n",
" for a in alphas:\n",
" z = slerp(z1.flatten(), z2.flatten(), a.item()).view_as(z1)\n",
" imgs.append(model(z).cpu())\n",
" return torch.cat(imgs)\n",
"\n",
"\n",
"def vae_interpolation(model, real_imgs, n_steps=10, device=DEVICE):\n",
" \"\"\"Encode two real images, interpolate in latent space, decode.\"\"\"\n",
" img1, img2 = real_imgs[:1].to(device), real_imgs[1:2].to(device)\n",
" with torch.no_grad():\n",
" mu1, _ = model.encode(img1)\n",
" mu2, _ = model.encode(img2)\n",
" alphas = torch.linspace(0, 1, n_steps, device=device)\n",
" imgs = [model.decode((1-a)*mu1 + a*mu2).cpu() for a in alphas]\n",
" return torch.cat(imgs)\n",
"\n",
"\n",
"print(\"Interpolation helpers defined. Run cells below after loading models.\")"
],
"execution_count": null,
"outputs": [],
"id": "d0000017"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# ── GAN interpolation ─────────────────────────────────────────────────────────\n",
"try:\n",
" gan_model, gan_cfg, _ = load_ema_model(\"p5_gan\", \"p5_gan.json\")\n",
" latent_dim = gan_cfg.get(\"latent_dim\", 128)\n",
" interp_imgs = gan_interpolation(gan_model, latent_dim, n_steps=10)\n",
" interp_imgs = (interp_imgs.clamp(-1, 1) + 1) / 2\n",
"\n",
" fig, axes = plt.subplots(1, 10, figsize=(20, 2.5))\n",
" for ax, img in zip(axes, interp_imgs):\n",
" ax.imshow(img.permute(1, 2, 0).numpy())\n",
" ax.axis(\"off\")\n",
" fig.suptitle(\"GAN — Slerp latent interpolation (z₁ → z₂)\", fontsize=11, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()\n",
"except Exception as e:\n",
" print(f\"GAN interpolation: {e}\")"
],
"execution_count": null,
"outputs": [],
"id": "d0000018"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# ── VAE interpolation ─────────────────────────────────────────────────────────\n",
"try:\n",
" from src.data import GeneratorDataset, get_transform\n",
"\n",
" vae_model, vae_cfg, _ = load_ema_model(\"p5_vae\", \"p5_vae.json\")\n",
" ds = GeneratorDataset(\"../../\" + vae_cfg[\"data_dir\"],\n",
" sources=vae_cfg.get(\"sources\", [\"wiki\"]),\n",
" transform=get_transform(vae_cfg[\"image_size\"], augment=False))\n",
" sample_real = torch.stack([ds[i] for i in range(2)])\n",
"\n",
" interp_imgs = vae_interpolation(vae_model, sample_real, n_steps=10)\n",
" interp_imgs = (interp_imgs.clamp(-1, 1) + 1) / 2\n",
"\n",
" fig, axes = plt.subplots(1, 10, figsize=(20, 2.5))\n",
" for ax, img in zip(axes, interp_imgs):\n",
" ax.imshow(img.permute(1, 2, 0).numpy())\n",
" ax.axis(\"off\")\n",
" fig.suptitle(\"VAE — μ-space linear interpolation (image₁ → image₂)\", fontsize=11, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()\n",
"except Exception as e:\n",
" print(f\"VAE interpolation: {e}\")"
],
"execution_count": null,
"outputs": [],
"id": "d0000019"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. Failure Mode Analysis\n",
"\n",
"Identify the worst-generated images per model: ideally those farthest from their nearest real neighbour in pixel space, or those with highest reconstruction error (VAE) or highest DDPM loss. The cell below implements a much simpler proxy — lowest mean brightness."
],
"id": "d0000020"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# For GAN and DDPM: generate 256 images, pick the 8 with lowest mean activation\n",
"# (a proxy for less-coherent images — very rough)\n",
"\n",
"def worst_samples(imgs, n=8):\n",
" \"\"\"Heuristic: pick samples with lowest per-pixel mean (often darker / less structured).\"\"\"\n",
" scores = imgs.mean(dim=[1, 2, 3]) # mean brightness per image\n",
" worst_idx = scores.argsort()[:n]\n",
" return imgs[worst_idx]\n",
"\n",
"print(\"Failure mode analysis requires generated samples — run after model loading (Section 6).\")"
],
"execution_count": null,
"outputs": [],
"id": "d0000021"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Optionally run with `imgs` from Section 6 (note: the loop there leaves only the LAST family's samples in `imgs`):\n",
"# worst = worst_samples(imgs.cpu())\n",
"# worst = (worst.clamp(-1,1) + 1) / 2\n",
"# fig, axes = plt.subplots(1, 8, figsize=(18, 2.5))\n",
"# for ax, img in zip(axes, worst):\n",
"# ax.imshow(img.permute(1,2,0)); ax.axis('off')\n",
"# plt.suptitle('Failure modes (lowest-brightness samples)', fontsize=11)\n",
"# plt.tight_layout(); plt.show()"
],
"execution_count": null,
"outputs": [],
"id": "d0000022"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9. Training Loss Overview (all families)"
],
"id": "d0000023"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"fig, axes = plt.subplots(1, 3, figsize=(18, 4))\n",
"\n",
"for ax, (fam, info) in zip(axes, FAMILIES.items()):\n",
" log = logs_p5[fam]\n",
" if log is None:\n",
" ax.set_title(f\"{fam} (not yet run)\"); ax.axis(\"off\"); continue\n",
" h = log[\"history\"]\n",
" c = info[\"color\"]\n",
"\n",
" if fam == \"GAN\":\n",
" ax.plot(h[\"g_loss\"], label=\"G loss\", color=c, linewidth=1.2)\n",
" ax.plot(h[\"w_dist\"], label=\"W-dist\", color=c, linewidth=1.2, linestyle=\"--\")\n",
" ax.set_ylabel(\"Loss / W-distance\")\n",
" elif fam == \"VAE\":\n",
" ax.plot(h[\"recon_loss\"], label=\"MSE\", color=c)\n",
" ax2 = ax.twinx()\n",
" ax2.plot(h[\"kl_loss\"], label=\"KL\", color=\"grey\", linestyle=\"--\")\n",
" ax2.set_ylabel(\"KL\", color=\"grey\")\n",
" ax.set_ylabel(\"MSE\")\n",
" elif fam == \"DDPM\":\n",
" ax.plot(h[\"loss\"], label=\"MSE (v-pred)\", color=c)\n",
" ax.set_ylabel(\"MSE loss\")\n",
"\n",
" ax.set_xlabel(\"Epoch\")\n",
" ax.set_title(f\"{fam}: {info['label']}\")\n",
" ax.legend(fontsize=8)\n",
"\n",
"fig.suptitle(\"Phase 5 — Training Dynamics (200 epochs)\", fontsize=12, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
],
"execution_count": null,
"outputs": [],
"id": "d0000024"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 10. Conclusions"
],
"id": "d0000025"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"print(\"=\" * 72)\n",
"print(\"PHASE 5 — CROSS-FAMILY COMPARISON (200 epochs)\")\n",
"print(\"=\" * 72)\n",
"\n",
"for fam, info in FAMILIES.items():\n",
" log = logs_p5[fam]\n",
" em = extended_metrics.get(fam, {})\n",
" print(f\"\\n ── {fam}: {info['label']} ──\")\n",
" if log:\n",
" res = log.get(\"config\", log).get(\"image_size\", \"?\")\n",
" print(f\" Resolution : {res}×{res}\")\n",
" n_p = log.get(\"n_params\")\n",
" print(f\" Params : {n_p:,}\" if n_p else \" Params : ?\")\n",
" tt = log.get(\"history\", {}).get(\"train_time_s\")\n",
" print(f\" Train time : {tt / 60:.1f} min\" if tt else \" Train time : ?\")\n",
" for ep in (100, 150, 200):\n",
" fid = get_fid(log, ep)\n",
" print(f\" FID@{ep:<3} : {fid:.1f}\" if fid else f\" FID@{ep:<3} : ?\")\n",
" if em:\n",
" print(f\" IS : {em.get('IS_mean','?'):.2f} ± {em.get('IS_std','?'):.2f}\" if 'IS_mean' in em else \" IS : ?\")\n",
" print(f\" LPIPS div : {em.get('LPIPS','?'):.4f}\" if 'LPIPS' in em else \" LPIPS div : ?\")\n",
" else:\n",
" print(\" IS / LPIPS : (run Section 6 to compute)\")\n",
"\n",
"print(\"\\n\" + \"=\" * 72)\n",
"print(\"Narrative to fill in after results:\")\n",
"print(\" - Which family achieves best FID?\")\n",
"print(\" - GAN: fast convergence but mode collapse risk?\")\n",
"print(\" - VAE: blurry priors improved by perceptual+adversarial loss?\")\n",
"print(\" - DDPM: highest quality but slowest inference (100 DDIM steps)?\")\n",
"print(\"=\" * 72)"
],
"execution_count": null,
"outputs": [],
"id": "d0000026"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}