Preview of phase 2-5 implementation; needs a full check
This commit is contained in:
@@ -0,0 +1,366 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a0000001",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Phase 2 — GAN Evolution Analysis\n",
|
||||
"\n",
|
||||
"Traces the GAN improvement story — each step motivated by the failure of the previous:\n",
|
||||
"\n",
|
||||
"| Step | Model | Key change | Expected failure |\n",
|
||||
"|------|-------|------------|------------------|\n",
|
||||
"| 2.1 | DCGAN 64×64 | Baseline on best pipeline | Mode collapse, training instability |\n",
|
||||
"| 2.2 | WGAN-GP | Wasserstein loss + GP | Texture artifacts, limited coherence |\n",
|
||||
"| 2.3 | WGAN-GP + SN + GroupNorm + Attn | Principled Lipschitz + long-range deps | Possible underfitting at 64×64 |\n",
|
||||
"| 2.4 | 2.3 @ 128×128 | Scale resolution | ? |\n",
|
||||
"\n",
|
||||
"All runs use the best pipeline from Phase 1: MTCNN-aligned crops, H-flip + rotation + colour jitter, aligned-only dataset."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a0000002",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib.image as mpimg\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n",
|
||||
"\n",
|
||||
"OUTPUTS = Path(\"../outputs\")\n",
|
||||
"LOGS = OUTPUTS / \"logs\"\n",
|
||||
"SAMPLES = OUTPUTS / \"samples\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a0000003",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. Load experiment logs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a0000004",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"run_names = [\"p2_1_dcgan\", \"p2_2_wgan\", \"p2_3_wgan_sn_attn\", \"p2_4_wgan_sn_attn_128\"]\n",
|
||||
"run_labels = {\n",
|
||||
" \"p2_1_dcgan\": \"2.1 DCGAN\",\n",
|
||||
" \"p2_2_wgan\": \"2.2 WGAN-GP\",\n",
|
||||
" \"p2_3_wgan_sn_attn\": \"2.3 WGAN-GP+SN+Attn\",\n",
|
||||
" \"p2_4_wgan_sn_attn_128\": \"2.4 +128×128\",\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"runs = {}\n",
|
||||
"for name in run_names:\n",
|
||||
" log_path = LOGS / f\"{name}.json\"\n",
|
||||
" if log_path.exists():\n",
|
||||
" with open(log_path) as f:\n",
|
||||
" runs[name] = json.load(f)\n",
|
||||
" else:\n",
|
||||
" print(f\" Missing: {log_path}\")\n",
|
||||
"\n",
|
||||
"print(f\"Loaded {len(runs)}/{len(run_names)} experiments:\")\n",
|
||||
"for name in run_names:\n",
|
||||
" status = \"✓\" if name in runs else \"✗\"\n",
|
||||
" print(f\" {status} {name}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a0000005",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. FID Comparison Table"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a0000006",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_fid(run, epoch):\n",
"    \"\"\"Return the FID logged for `epoch` in the run's FID history, or None if absent.\n",
"\n",
"    The string form of the epoch is tried first (json.load yields string keys),\n",
"    then the raw `epoch` value.\n",
"    \"\"\"\n",
"    fid = run[\"history\"][\"fid\"]\n",
"    return fid.get(str(epoch), fid.get(epoch, None))\n",
"\n",
"rows = []\n",
"for name in run_names:\n",
"    if name not in runs:\n",
"        continue\n",
"    r = runs[name]\n",
"    cfg = r[\"config\"]\n",
"    size = cfg.get(\"image_size\", 64)\n",
"    rows.append({\n",
"        \"Step\": run_labels[name],\n",
"        \"Model\": cfg.get(\"model\"),\n",
"        \"Size\": f\"{size}×{size}\",\n",
"        \"FID@25\": get_fid(r, 25),\n",
"        \"FID@50\": get_fid(r, 50),\n",
"        \"FID@75\": get_fid(r, 75),\n",
"        \"FID@100\": get_fid(r, 100),\n",
"    })\n",
"\n",
"fid_cols = [\"FID@25\", \"FID@50\", \"FID@75\", \"FID@100\"]\n",
"df = pd.DataFrame(rows)\n",
"# na_rep renders epochs with no logged FID as a dash. Without it, a column that holds\n",
"# None (no run has reached that epoch yet) makes \"{:.1f}\" raise TypeError.\n",
"df.style.format({c: \"{:.1f}\" for c in fid_cols if c in df}, na_rep=\"—\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a0000007",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3. FID Curves — Evolution Story"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a0000008",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig, ax = plt.subplots(figsize=(11, 5))\n",
"colors = [\"#5B8DB8\", \"#E8705A\", \"#6ABF69\", \"#B86FB8\"]\n",
"\n",
"for i, name in enumerate(run_names):\n",
"    if name not in runs:\n",
"        continue\n",
"    fid_dict = runs[name][\"history\"][\"fid\"]\n",
"    # Keys are looked up as strings below, so sort them numerically for the x-axis.\n",
"    epochs = sorted(int(k) for k in fid_dict)\n",
"    fids = [fid_dict[str(e)] for e in epochs]\n",
"    # A run without an epoch-100 entry must not push the \"?\" placeholder through \":.1f\"\n",
"    # (that raises ValueError), so the number is formatted only when it exists.\n",
"    fid100 = fid_dict.get(\"100\")\n",
"    fid100_text = f\"{fid100:.1f}\" if fid100 is not None else \"?\"\n",
"    ax.plot(epochs, fids, \"o-\", label=f\"{run_labels[name]} (FID@100={fid100_text})\",\n",
"            color=colors[i], linewidth=2, markersize=8)\n",
"\n",
"ax.set_xlabel(\"Epoch\")\n",
"ax.set_ylabel(\"FID (lower is better)\")\n",
"ax.set_title(\"Phase 2 — FID Curves: DCGAN → WGAN-GP → +SN+Attn → 128×128\")\n",
"ax.legend()\n",
"plt.tight_layout()\n",
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a0000009",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 4. Training Dynamics"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a0000010",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Separate DCGAN (has g_loss/d_loss/d_real/d_fake) from WGAN (has g_loss/w_dist/gp)\n",
|
||||
"dcgan_names = [n for n in run_names if n in runs and runs[n][\"config\"].get(\"model\") == \"dcgan\"]\n",
|
||||
"wgan_names = [n for n in run_names if n in runs and runs[n][\"config\"].get(\"model\") != \"dcgan\"]\n",
|
||||
"\n",
|
||||
"if dcgan_names:\n",
|
||||
" fig, axes = plt.subplots(1, 2, figsize=(13, 4))\n",
|
||||
" for name in dcgan_names:\n",
|
||||
" h = runs[name][\"history\"]\n",
|
||||
" epochs = range(1, len(h[\"g_loss\"]) + 1)\n",
|
||||
" axes[0].plot(epochs, h[\"g_loss\"], label=run_labels[name], linewidth=1.2)\n",
|
||||
" axes[1].plot(epochs, h[\"d_loss\"], label=run_labels[name], linewidth=1.2)\n",
|
||||
" axes[0].set_title(\"DCGAN — Generator Loss (BCE)\")\n",
|
||||
" axes[1].set_title(\"DCGAN — Discriminator Loss\")\n",
|
||||
" for ax in axes:\n",
|
||||
" ax.set_xlabel(\"Epoch\"); ax.set_ylabel(\"Loss\"); ax.legend(fontsize=8)\n",
|
||||
" plt.suptitle(\"Phase 2.1 — DCGAN Training Dynamics\", fontweight=\"bold\")\n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.show()\n",
|
||||
"\n",
|
||||
"if wgan_names:\n",
|
||||
" fig, axes = plt.subplots(1, 3, figsize=(16, 4))\n",
|
||||
" cmap = plt.cm.Set1\n",
|
||||
" for i, name in enumerate(wgan_names):\n",
|
||||
" h = runs[name][\"history\"]\n",
|
||||
" epochs = range(1, len(h[\"g_loss\"]) + 1)\n",
|
||||
" c = cmap(i / max(len(wgan_names), 1))\n",
|
||||
" axes[0].plot(epochs, h[\"g_loss\"], label=run_labels[name], color=c, linewidth=1.2)\n",
|
||||
" axes[1].plot(epochs, h[\"w_dist\"], label=run_labels[name], color=c, linewidth=1.2)\n",
|
||||
" axes[2].plot(epochs, h[\"gp\"], label=run_labels[name], color=c, linewidth=1.2)\n",
|
||||
" axes[0].set_title(\"Generator Loss (−E[D(G(z))])\")\n",
|
||||
" axes[1].set_title(\"Wasserstein Distance Est. (↑ better)\")\n",
|
||||
" axes[2].set_title(\"Gradient Penalty\")\n",
|
||||
" for ax in axes:\n",
|
||||
" ax.set_xlabel(\"Epoch\"); ax.legend(fontsize=8)\n",
|
||||
" plt.suptitle(\"Phase 2.2–2.4 — WGAN-GP Training Dynamics\", fontweight=\"bold\")\n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a0000011",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 5. Sample Image Grids — Epoch 100"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a0000012",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig, axes = plt.subplots(1, 4, figsize=(18, 5))\n",
|
||||
"\n",
|
||||
"for idx, name in enumerate(run_names):\n",
|
||||
" ax = axes[idx]\n",
|
||||
" img_path = SAMPLES / name / \"epoch_0100.png\"\n",
|
||||
" if img_path.exists():\n",
|
||||
" fid = get_fid(runs[name], 100) if name in runs else None\n",
|
||||
" ax.imshow(mpimg.imread(str(img_path)))\n",
|
||||
" ax.set_title(f\"{run_labels[name]}\\nFID@100={fid:.1f}\" if fid else run_labels[name], fontsize=8)\n",
|
||||
" else:\n",
|
||||
" ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
|
||||
" ax.set_title(run_labels[name], fontsize=8)\n",
|
||||
" ax.axis(\"off\")\n",
|
||||
"\n",
|
||||
"fig.suptitle(\"Phase 2 — Epoch 100 Sample Grids (4×4)\", fontsize=13, fontweight=\"bold\")\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a0000013",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 6. Progression: Epoch 10 → 50 → 100"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a0000014",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"check_epochs = [10, 50, 100]\n",
|
||||
"\n",
|
||||
"for name in run_names:\n",
|
||||
" if name not in runs:\n",
|
||||
" continue\n",
|
||||
" fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))\n",
|
||||
" for ax, ep in zip(axes, check_epochs):\n",
|
||||
" img_path = SAMPLES / name / f\"epoch_{ep:04d}.png\"\n",
|
||||
" if img_path.exists():\n",
|
||||
" ax.imshow(mpimg.imread(str(img_path)))\n",
|
||||
" fid = get_fid(runs[name], ep)\n",
|
||||
" ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid else \"\"), fontsize=9)\n",
|
||||
" else:\n",
|
||||
" ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
|
||||
" ax.axis(\"off\")\n",
|
||||
" fig.suptitle(f\"{run_labels[name]} — Training Progression\", fontsize=11, fontweight=\"bold\")\n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a0000015",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 7. Step-by-step Pairwise Comparisons"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a0000016",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"transitions = [\n",
|
||||
" (\"2.1→2.2: BCE→Wasserstein\", \"p2_1_dcgan\", \"p2_2_wgan\"),\n",
|
||||
" (\"2.2→2.3: +SN+GroupNorm+Attn\", \"p2_2_wgan\", \"p2_3_wgan_sn_attn\"),\n",
|
||||
" (\"2.3→2.4: 64→128 resolution\", \"p2_3_wgan_sn_attn\", \"p2_4_wgan_sn_attn_128\"),\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"for title, name_a, name_b in transitions:\n",
|
||||
" fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n",
|
||||
" for ax, name in zip(axes, [name_a, name_b]):\n",
|
||||
" img_path = SAMPLES / name / \"epoch_0100.png\"\n",
|
||||
" if img_path.exists():\n",
|
||||
" fid = get_fid(runs[name], 100) if name in runs else None\n",
|
||||
" ax.imshow(mpimg.imread(str(img_path)))\n",
|
||||
" ax.set_title(f\"{run_labels[name]}\\nFID@100 = {fid:.1f}\" if fid else run_labels[name], fontsize=10)\n",
|
||||
" else:\n",
|
||||
" ax.text(0.5, 0.5, \"Pending\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
|
||||
" ax.set_title(run_labels[name], fontsize=10)\n",
|
||||
" ax.axis(\"off\")\n",
|
||||
" fig.suptitle(title, fontsize=12, fontweight=\"bold\")\n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a0000017",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 8. Conclusions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a0000018",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"=\" * 70)\n",
|
||||
"print(\"PHASE 2 — GAN EVOLUTION SUMMARY\")\n",
|
||||
"print(\"=\" * 70)\n",
|
||||
"\n",
|
||||
"for name in run_names:\n",
|
||||
" if name not in runs:\n",
|
||||
" print(f\"\\n {run_labels[name]}: NOT YET RUN\")\n",
|
||||
" continue\n",
|
||||
" fid100 = get_fid(runs[name], 100)\n",
|
||||
" fid50 = get_fid(runs[name], 50)\n",
|
||||
" print(f\"\\n {run_labels[name]}:\")\n",
|
||||
" print(f\" FID@50 = {fid50:.1f}\" if fid50 else \" FID@50 = ?\")\n",
|
||||
" print(f\" FID@100 = {fid100:.1f}\" if fid100 else \" FID@100 = ?\")\n",
|
||||
"\n",
|
||||
"print(\"\\n\" + \"=\" * 70)\n",
|
||||
"print(\"Best model for Phase 3/4 comparison: fill in after runs complete.\")\n",
|
||||
"print(\"=\" * 70)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.10.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,396 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b0000001",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Phase 3 — VAE Evolution Analysis\n",
|
||||
"\n",
|
||||
"Traces the VAE improvement story — each step motivated by the failure of the previous:\n",
|
||||
"\n",
|
||||
"| Step | Model | Key change | Expected failure |\n",
|
||||
"|------|-------|------------|------------------|\n",
|
||||
"| 3.1 | Vanilla VAE (MSE+KL) | Baseline | Blurry samples — MSE minimises pixel average |\n",
|
||||
"| 3.2 | + Perceptual loss (VGG) | Feature-space reconstruction | Residual texture blur |\n",
|
||||
"| 3.3 | + PatchGAN (VQGAN-lite) | Local texture adversarial | — |\n",
|
||||
"\n",
|
||||
"All runs use H-flip-only augmentation and MTCNN-aligned 64×64 crops.\n",
|
||||
"FID is computed from prior samples (`z ~ N(0, I) → decode`), same metric as GAN.\n",
|
||||
"Reconstructions are shown separately to diagnose encoder quality."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b0000002",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib.image as mpimg\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n",
|
||||
"\n",
|
||||
"OUTPUTS = Path(\"../outputs\")\n",
|
||||
"LOGS = OUTPUTS / \"logs\"\n",
|
||||
"SAMPLES = OUTPUTS / \"samples\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b0000003",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. Load experiment logs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b0000004",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"run_names = [\"p3_1_vae\", \"p3_2_vae_perceptual\", \"p3_3_vae_patchgan\"]\n",
|
||||
"run_labels = {\n",
|
||||
" \"p3_1_vae\": \"3.1 VAE (MSE+KL)\",\n",
|
||||
" \"p3_2_vae_perceptual\": \"3.2 +Perceptual\",\n",
|
||||
" \"p3_3_vae_patchgan\": \"3.3 +PatchGAN\",\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"runs = {}\n",
|
||||
"for name in run_names:\n",
|
||||
" log_path = LOGS / f\"{name}.json\"\n",
|
||||
" if log_path.exists():\n",
|
||||
" with open(log_path) as f:\n",
|
||||
" runs[name] = json.load(f)\n",
|
||||
" else:\n",
|
||||
" print(f\" Missing: {log_path}\")\n",
|
||||
"\n",
|
||||
"print(f\"Loaded {len(runs)}/{len(run_names)} experiments:\")\n",
|
||||
"for name in run_names:\n",
|
||||
" status = \"✓\" if name in runs else \"✗\"\n",
|
||||
" print(f\" {status} {name}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b0000005",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. FID Comparison Table (prior samples)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b0000006",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_fid(run, epoch):\n",
"    \"\"\"Return the FID logged for `epoch` in the run's FID history, or None if absent.\n",
"\n",
"    The string form of the epoch is tried first (json.load yields string keys),\n",
"    then the raw `epoch` value.\n",
"    \"\"\"\n",
"    fid = run[\"history\"][\"fid\"]\n",
"    return fid.get(str(epoch), fid.get(epoch, None))\n",
"\n",
"rows = []\n",
"for name in run_names:\n",
"    if name not in runs:\n",
"        continue\n",
"    r = runs[name]\n",
"    cfg = r[\"config\"]\n",
"    rows.append({\n",
"        \"Step\": run_labels[name],\n",
"        \"λ_perc\": cfg.get(\"lambda_perceptual\", 0),\n",
"        \"λ_adv\": cfg.get(\"lambda_adversarial\", 0),\n",
"        \"β_kl\": cfg.get(\"beta_kl\", 1.0),\n",
"        \"FID@25 (prior)\": get_fid(r, 25),\n",
"        \"FID@50 (prior)\": get_fid(r, 50),\n",
"        \"FID@75 (prior)\": get_fid(r, 75),\n",
"        \"FID@100 (prior)\": get_fid(r, 100),\n",
"    })\n",
"\n",
"df = pd.DataFrame(rows)\n",
"# na_rep renders epochs with no logged FID as a dash. Without it, a column that holds\n",
"# None (no run has reached that epoch yet) makes \"{:.1f}\" raise TypeError.\n",
"df.style.format({c: \"{:.1f}\" for c in df.columns if \"FID\" in c}, na_rep=\"—\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b0000007",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3. FID Curves — Evolution Story"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b0000008",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig, ax = plt.subplots(figsize=(10, 5))\n",
"colors = [\"#5B8DB8\", \"#E8705A\", \"#6ABF69\"]\n",
"\n",
"for i, name in enumerate(run_names):\n",
"    if name not in runs:\n",
"        continue\n",
"    fid_dict = runs[name][\"history\"][\"fid\"]\n",
"    # Keys are looked up as strings below, so sort them numerically for the x-axis.\n",
"    epochs = sorted(int(k) for k in fid_dict)\n",
"    fids = [fid_dict[str(e)] for e in epochs]\n",
"    # A run without an epoch-100 entry must not push the \"?\" placeholder through \":.1f\"\n",
"    # (that raises ValueError), so the number is formatted only when it exists.\n",
"    fid100 = fid_dict.get(\"100\")\n",
"    fid100_text = f\"{fid100:.1f}\" if fid100 is not None else \"?\"\n",
"    label = f\"{run_labels[name]} (FID@100={fid100_text})\"\n",
"    ax.plot(epochs, fids, \"o-\", label=label, color=colors[i], linewidth=2, markersize=8)\n",
"\n",
"ax.set_xlabel(\"Epoch\")\n",
"ax.set_ylabel(\"FID (lower is better) — prior samples\")\n",
"ax.set_title(\"Phase 3 — FID Curves: VAE → +Perceptual → +PatchGAN\")\n",
"ax.legend()\n",
"plt.tight_layout()\n",
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b0000009",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 4. Training Loss Curves"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b0000010",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig, axes = plt.subplots(2, 3, figsize=(16, 9))\n",
|
||||
"axes = axes.flatten()\n",
|
||||
"keys = [\"recon_loss\", \"kl_loss\", \"perc_loss\", \"adv_g_loss\", \"adv_d_loss\"]\n",
|
||||
"titles = [\"MSE Reconstruction\", \"KL Divergence\", \"Perceptual (VGG)\", \"Adv G Loss\", \"Adv D Loss (PatchGAN)\"]\n",
|
||||
"\n",
|
||||
"for ax, key, title in zip(axes, keys, titles):\n",
|
||||
" for i, name in enumerate(run_names):\n",
|
||||
" if name not in runs:\n",
|
||||
" continue\n",
|
||||
" h = runs[name][\"history\"].get(key, [])\n",
|
||||
" if any(v != 0.0 for v in h):\n",
|
||||
" ax.plot(range(1, len(h)+1), h, label=run_labels[name],\n",
|
||||
" color=colors[i], linewidth=1.2, alpha=0.9)\n",
|
||||
" ax.set_title(title)\n",
|
||||
" ax.set_xlabel(\"Epoch\")\n",
|
||||
" ax.legend(fontsize=8)\n",
|
||||
"\n",
|
||||
"axes[-1].axis(\"off\") # empty sixth panel\n",
|
||||
"fig.suptitle(\"Phase 3 — Training Dynamics\", fontsize=13, fontweight=\"bold\")\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b0000011",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 5. Prior Samples — Epoch 100"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b0000012",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
|
||||
"\n",
|
||||
"for idx, name in enumerate(run_names):\n",
|
||||
" ax = axes[idx]\n",
|
||||
" img_path = SAMPLES / name / \"epoch_0100.png\"\n",
|
||||
" if img_path.exists():\n",
|
||||
" fid = get_fid(runs[name], 100) if name in runs else None\n",
|
||||
" ax.imshow(mpimg.imread(str(img_path)))\n",
|
||||
" ax.set_title(f\"{run_labels[name]}\\nFID@100={fid:.1f}\" if fid else run_labels[name], fontsize=9)\n",
|
||||
" else:\n",
|
||||
" ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
|
||||
" ax.set_title(run_labels[name], fontsize=9)\n",
|
||||
" ax.axis(\"off\")\n",
|
||||
"\n",
|
||||
"fig.suptitle(\"Phase 3 — Epoch 100 Prior Samples (4×4 grids)\", fontsize=12, fontweight=\"bold\")\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b0000013",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 6. Reconstructions — Epoch 100\n",
|
||||
"\n",
|
||||
"Left half = real images, right half = reconstructions (interleaved pairs)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b0000014",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
|
||||
"\n",
|
||||
"for idx, name in enumerate(run_names):\n",
|
||||
" ax = axes[idx]\n",
|
||||
" img_path = SAMPLES / name / \"epoch_0100_recon.png\"\n",
|
||||
" if img_path.exists():\n",
|
||||
" ax.imshow(mpimg.imread(str(img_path)))\n",
|
||||
" ax.set_title(f\"{run_labels[name]}\\nreal | recon\", fontsize=9)\n",
|
||||
" else:\n",
|
||||
" ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
|
||||
" ax.set_title(run_labels[name], fontsize=9)\n",
|
||||
" ax.axis(\"off\")\n",
|
||||
"\n",
|
||||
"fig.suptitle(\"Phase 3 — Epoch 100 Reconstructions\", fontsize=12, fontweight=\"bold\")\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b0000015",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 7. Step-by-step Pairwise Comparisons"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b0000016",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"transitions = [\n",
|
||||
" (\"3.1→3.2: MSE→+Perceptual\", \"p3_1_vae\", \"p3_2_vae_perceptual\"),\n",
|
||||
" (\"3.2→3.3: +PatchGAN adversarial\", \"p3_2_vae_perceptual\", \"p3_3_vae_patchgan\"),\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"for title, name_a, name_b in transitions:\n",
|
||||
" fig, axes = plt.subplots(2, 2, figsize=(12, 10))\n",
|
||||
" for col, name in enumerate([name_a, name_b]):\n",
|
||||
" for row, suffix in enumerate([\"\", \"_recon\"]):\n",
|
||||
" ax = axes[row][col]\n",
|
||||
" ep = 100\n",
|
||||
" img_path = SAMPLES / name / f\"epoch_{ep:04d}{suffix}.png\"\n",
|
||||
" label = run_labels[name]\n",
|
||||
" kind = \"prior\" if suffix == \"\" else \"recon\"\n",
|
||||
" if img_path.exists():\n",
|
||||
" fid = get_fid(runs[name], ep) if (suffix == \"\" and name in runs) else None\n",
|
||||
" ax.imshow(mpimg.imread(str(img_path)))\n",
|
||||
" ax.set_title(f\"{label}\\n{kind}\" + (f\" FID={fid:.1f}\" if fid else \"\"), fontsize=9)\n",
|
||||
" else:\n",
|
||||
" ax.text(0.5, 0.5, \"Pending\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
|
||||
" ax.set_title(f\"{label} ({kind})\", fontsize=9)\n",
|
||||
" ax.axis(\"off\")\n",
|
||||
" fig.suptitle(title, fontsize=12, fontweight=\"bold\")\n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b0000017",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 8. Progression: Epoch 10 → 50 → 100 (prior samples)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b0000018",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"check_epochs = [10, 50, 100]\n",
|
||||
"\n",
|
||||
"for name in run_names:\n",
|
||||
" if name not in runs:\n",
|
||||
" continue\n",
|
||||
" fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))\n",
|
||||
" for ax, ep in zip(axes, check_epochs):\n",
|
||||
" img_path = SAMPLES / name / f\"epoch_{ep:04d}.png\"\n",
|
||||
" if img_path.exists():\n",
|
||||
" ax.imshow(mpimg.imread(str(img_path)))\n",
|
||||
" fid = get_fid(runs[name], ep)\n",
|
||||
" ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid else \"\"), fontsize=9)\n",
|
||||
" else:\n",
|
||||
" ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\",\n",
|
||||
" transform=ax.transAxes)\n",
|
||||
" ax.axis(\"off\")\n",
|
||||
" fig.suptitle(f\"{run_labels[name]} — Prior Sample Progression\", fontsize=11, fontweight=\"bold\")\n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b0000019",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 9. Conclusions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b0000020",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"=\" * 70)\n",
"print(\"PHASE 3 — VAE EVOLUTION SUMMARY\")\n",
"print(\"=\" * 70)\n",
"\n",
"for name in run_names:\n",
"    if name not in runs:\n",
"        print(f\"\\n {run_labels[name]}: NOT YET RUN\")\n",
"        continue\n",
"    h = runs[name][\"history\"]\n",
"    fid100 = get_fid(runs[name], 100)\n",
"    fid50 = get_fid(runs[name], 50)\n",
"    # A missing loss series is treated as empty instead of raising KeyError.\n",
"    recon = h.get(\"recon_loss\", [])\n",
"    kl = h.get(\"kl_loss\", [])\n",
"    # Element 49 is reported as the epoch-50 value (lists are indexed from 0).\n",
"    mse50 = recon[49] if len(recon) > 49 else None\n",
"    kl50 = kl[49] if len(kl) > 49 else None\n",
"    print(f\"\\n {run_labels[name]}:\")\n",
"    # Test against None rather than truthiness: a logged value of exactly 0.0\n",
"    # is a real measurement and must not be printed as \"?\".\n",
"    print(f\" FID@50 = {fid50:.1f}\" if fid50 is not None else \" FID@50 = ?\")\n",
"    print(f\" FID@100 = {fid100:.1f}\" if fid100 is not None else \" FID@100 = ?\")\n",
"    print(f\" MSE@50 = {mse50:.4f}\" if mse50 is not None else \" MSE@50 = ?\")\n",
"    print(f\" KL@50 = {kl50:.2f}\" if kl50 is not None else \" KL@50 = ?\")\n",
"\n",
"print(\"\\n\" + \"=\" * 70)\n",
"print(\"Best VAE model for Phase 5 cross-family comparison: fill in after runs.\")\n",
"print(\"=\" * 70)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.10.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,403 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c0000001",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Phase 4 — DDPM Evolution Analysis\n",
|
||||
"\n",
|
||||
"Traces the DDPM improvement story:\n",
|
||||
"\n",
|
||||
"| Step | Model | Key change | Expected failure |\n",
|
||||
"|------|-------|------------|------------------|\n",
|
||||
"| 4.1 | DDPM linear + ε-pred | Baseline | Noise prediction unstable at very low t (linear schedule over-denoises) |\n",
|
||||
"| 4.2 | + cosine schedule | Less noise wasted at low timesteps | Residual instability from ε parameterisation |\n",
|
||||
"| 4.3 | + v-prediction | Numerically stable across full trajectory | Possible underfitting at 64×64 |\n",
|
||||
"| 4.4 | + wider U-Net (192ch) + 32×32 attention | More capacity and longer-range context | — |\n",
|
||||
"\n",
|
||||
"FID is computed via DDIM (100 steps, deterministic) from the EMA model.\n",
|
||||
"H-flip-only augmentation, MTCNN-aligned 64×64 crops, T=1000."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c0000002",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib.image as mpimg\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n",
|
||||
"\n",
|
||||
"OUTPUTS = Path(\"../outputs\")\n",
|
||||
"LOGS = OUTPUTS / \"logs\"\n",
|
||||
"SAMPLES = OUTPUTS / \"samples\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c0000003",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. Load experiment logs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c0000004",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"run_names = [\"p4_1_ddpm_linear\", \"p4_2_ddpm_cosine\", \"p4_3_ddpm_vpred\", \"p4_4_ddpm_wider\"]\n",
|
||||
"run_labels = {\n",
|
||||
" \"p4_1_ddpm_linear\": \"4.1 linear + ε\",\n",
|
||||
" \"p4_2_ddpm_cosine\": \"4.2 cosine + ε\",\n",
|
||||
" \"p4_3_ddpm_vpred\": \"4.3 cosine + v\",\n",
|
||||
" \"p4_4_ddpm_wider\": \"4.4 wider + 32×32 attn\",\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"runs = {}\n",
|
||||
"for name in run_names:\n",
|
||||
" log_path = LOGS / f\"{name}.json\"\n",
|
||||
" if log_path.exists():\n",
|
||||
" with open(log_path) as f:\n",
|
||||
" runs[name] = json.load(f)\n",
|
||||
" else:\n",
|
||||
" print(f\" Missing: {log_path}\")\n",
|
||||
"\n",
|
||||
"print(f\"Loaded {len(runs)}/{len(run_names)} experiments:\")\n",
|
||||
"for name in run_names:\n",
|
||||
" print(f\" {'✓' if name in runs else '✗'} {name}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c0000005",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. FID Comparison Table"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c0000006",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_fid(run, epoch):\n",
|
||||
" fid = run[\"history\"][\"fid\"]\n",
|
||||
" return fid.get(str(epoch), fid.get(epoch, None))\n",
|
||||
"\n",
|
||||
"rows = []\n",
|
||||
"for name in run_names:\n",
|
||||
" if name not in runs:\n",
|
||||
" continue\n",
|
||||
" r = runs[name]\n",
|
||||
" cfg = r[\"config\"]\n",
|
||||
" rows.append({\n",
|
||||
" \"Step\": run_labels[name],\n",
|
||||
" \"Schedule\": cfg.get(\"noise_schedule\"),\n",
|
||||
" \"Pred\": cfg.get(\"pred_type\"),\n",
|
||||
" \"base_ch\": cfg.get(\"base_ch\"),\n",
|
||||
" \"FID@25\": get_fid(r, 25),\n",
|
||||
" \"FID@50\": get_fid(r, 50),\n",
|
||||
" \"FID@75\": get_fid(r, 75),\n",
|
||||
" \"FID@100\": get_fid(r, 100),\n",
|
||||
" })\n",
|
||||
"\n",
|
||||
"df = pd.DataFrame(rows)\n",
|
||||
"df.style.format({c: \"{:.1f}\" for c in df.columns if \"FID\" in c})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c0000007",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3. FID Curves — Evolution Story"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c0000008",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"colors = [\"#5B8DB8\", \"#E8705A\", \"#6ABF69\", \"#B86FB8\"]\n",
|
||||
"\n",
|
||||
"fig, ax = plt.subplots(figsize=(11, 5))\n",
|
||||
"for i, name in enumerate(run_names):\n",
|
||||
" if name not in runs:\n",
|
||||
" continue\n",
|
||||
" fid_dict = runs[name][\"history\"][\"fid\"]\n",
|
||||
" epochs = sorted(int(k) for k in fid_dict)\n",
|
||||
" fids = [fid_dict[str(e)] for e in epochs]\n",
|
||||
" fid100 = fid_dict.get(\"100\", \"?\")\n",
|
||||
" label = f\"{run_labels[name]} (FID@100={fid100:.1f})\" if isinstance(fid100, float) else run_labels[name]\n",
|
||||
" ax.plot(epochs, fids, \"o-\", label=label, color=colors[i], linewidth=2, markersize=8)\n",
|
||||
"\n",
|
||||
"ax.set_xlabel(\"Epoch\")\n",
|
||||
"ax.set_ylabel(\"FID (DDIM 100 steps, lower is better)\")\n",
|
||||
"ax.set_title(\"Phase 4 — FID Curves: linear·ε → cosine·ε → cosine·v → wider\")\n",
|
||||
"ax.legend()\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c0000009",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 4. Training Loss Curves (MSE on predicted target)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c0000010",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig, ax = plt.subplots(figsize=(11, 4))\n",
|
||||
"for i, name in enumerate(run_names):\n",
|
||||
" if name not in runs:\n",
|
||||
" continue\n",
|
||||
" losses = runs[name][\"history\"][\"loss\"]\n",
|
||||
" ax.plot(range(1, len(losses)+1), losses, label=run_labels[name],\n",
|
||||
" color=colors[i], linewidth=1.2, alpha=0.9)\n",
|
||||
"\n",
|
||||
"ax.set_xlabel(\"Epoch\")\n",
|
||||
"ax.set_ylabel(\"MSE (noise / v prediction loss)\")\n",
|
||||
"ax.set_title(\"Phase 4 — Training Loss\")\n",
|
||||
"ax.legend(fontsize=9)\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c0000011",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 5. Sample Grids — Epoch 100 (DDIM 50 steps)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c0000012",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig, axes = plt.subplots(1, 4, figsize=(20, 5))\n",
|
||||
"\n",
|
||||
"for idx, name in enumerate(run_names):\n",
|
||||
" ax = axes[idx]\n",
|
||||
" img_path = SAMPLES / name / \"epoch_0100.png\"\n",
|
||||
" if img_path.exists():\n",
|
||||
" fid = get_fid(runs[name], 100) if name in runs else None\n",
|
||||
" ax.imshow(mpimg.imread(str(img_path)))\n",
|
||||
" ax.set_title(f\"{run_labels[name]}\\nFID@100={fid:.1f}\" if fid else run_labels[name], fontsize=8)\n",
|
||||
" else:\n",
|
||||
" ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
|
||||
" ax.set_title(run_labels[name], fontsize=8)\n",
|
||||
" ax.axis(\"off\")\n",
|
||||
"\n",
|
||||
"fig.suptitle(\"Phase 4 — Epoch 100 DDIM Samples (4×4 grids)\", fontsize=12, fontweight=\"bold\")\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c0000013",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 6. Step-by-step Pairwise Comparisons"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c0000014",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"transitions = [\n",
|
||||
" (\"4.1→4.2: linear→cosine schedule\", \"p4_1_ddpm_linear\", \"p4_2_ddpm_cosine\"),\n",
|
||||
" (\"4.2→4.3: ε-pred→v-pred\", \"p4_2_ddpm_cosine\", \"p4_3_ddpm_vpred\"),\n",
|
||||
" (\"4.3→4.4: 128ch→192ch + 32×32 attn\", \"p4_3_ddpm_vpred\", \"p4_4_ddpm_wider\"),\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"for title, name_a, name_b in transitions:\n",
|
||||
" fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n",
|
||||
" for ax, name in zip(axes, [name_a, name_b]):\n",
|
||||
" img_path = SAMPLES / name / \"epoch_0100.png\"\n",
|
||||
" if img_path.exists():\n",
|
||||
" fid = get_fid(runs[name], 100) if name in runs else None\n",
|
||||
" ax.imshow(mpimg.imread(str(img_path)))\n",
|
||||
" ax.set_title(f\"{run_labels[name]}\\nFID@100={fid:.1f}\" if fid else run_labels[name], fontsize=9)\n",
|
||||
" else:\n",
|
||||
" ax.text(0.5, 0.5, \"Pending\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
|
||||
" ax.set_title(run_labels[name], fontsize=9)\n",
|
||||
" ax.axis(\"off\")\n",
|
||||
" fig.suptitle(title, fontsize=12, fontweight=\"bold\")\n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c0000015",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 7. Progression: Epoch 10 → 50 → 100"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c0000016",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"check_epochs = [10, 50, 100]\n",
|
||||
"\n",
|
||||
"for name in run_names:\n",
|
||||
" if name not in runs:\n",
|
||||
" continue\n",
|
||||
" fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))\n",
|
||||
" for ax, ep in zip(axes, check_epochs):\n",
|
||||
" img_path = SAMPLES / name / f\"epoch_{ep:04d}.png\"\n",
|
||||
" if img_path.exists():\n",
|
||||
" ax.imshow(mpimg.imread(str(img_path)))\n",
|
||||
" fid = get_fid(runs[name], ep)\n",
|
||||
" ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid else \"\"), fontsize=9)\n",
|
||||
" else:\n",
|
||||
" ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\",\n",
|
||||
" transform=ax.transAxes)\n",
|
||||
" ax.axis(\"off\")\n",
|
||||
" fig.suptitle(f\"{run_labels[name]} — Progression\", fontsize=11, fontweight=\"bold\")\n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c0000017",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 8. Noise Schedule Visualisation\n",
|
||||
"\n",
|
||||
"Illustrates why cosine outperforms linear: the linear schedule allocates many timesteps near t=T where the image is already near-pure noise, wasting model capacity."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c0000018",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import math\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"T = 1000\n",
|
||||
"t = np.arange(T)\n",
|
||||
"\n",
|
||||
"# Linear betas\n",
|
||||
"betas_lin = np.linspace(1e-4, 0.02, T)\n",
|
||||
"ab_lin = np.cumprod(1 - betas_lin)\n",
|
||||
"\n",
|
||||
"# Cosine betas\n",
|
||||
"s = 0.008\n",
|
||||
"f = np.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2\n",
|
||||
"f = f / f[0]\n",
|
||||
"betas_cos = np.clip(1 - f[1:] / f[:-1], 0, 0.999)\n",
|
||||
"ab_cos = np.cumprod(1 - betas_cos)\n",
|
||||
"\n",
|
||||
"fig, axes = plt.subplots(1, 2, figsize=(13, 4))\n",
|
||||
"\n",
|
||||
"axes[0].plot(t, ab_lin, label=\"linear\", color=\"#5B8DB8\", linewidth=2)\n",
|
||||
"axes[0].plot(t, ab_cos, label=\"cosine\", color=\"#E8705A\", linewidth=2)\n",
|
||||
"axes[0].set_xlabel(\"Timestep t\")\n",
|
||||
"axes[0].set_ylabel(\"ᾱ_t (signal fraction)\")\n",
|
||||
"axes[0].set_title(\"ᾱ_t vs t — cosine stays informative longer\")\n",
|
||||
"axes[0].legend()\n",
|
||||
"\n",
|
||||
"axes[1].plot(t, np.sqrt(ab_lin), label=\"linear √ᾱ\", color=\"#5B8DB8\", linewidth=2)\n",
|
||||
"axes[1].plot(t, np.sqrt(ab_cos), label=\"cosine √ᾱ\", color=\"#E8705A\", linewidth=2)\n",
|
||||
"axes[1].set_xlabel(\"Timestep t\")\n",
|
||||
"axes[1].set_ylabel(\"√ᾱ_t (signal amplitude)\")\n",
|
||||
"axes[1].set_title(\"Signal amplitude — cosine is more uniform\")\n",
|
||||
"axes[1].legend()\n",
|
||||
"\n",
|
||||
"plt.suptitle(\"Noise Schedule Comparison\", fontsize=12, fontweight=\"bold\")\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c0000019",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 9. Conclusions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c0000020",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"=\" * 70)\n",
|
||||
"print(\"PHASE 4 — DDPM EVOLUTION SUMMARY\")\n",
|
||||
"print(\"=\" * 70)\n",
|
||||
"\n",
|
||||
"for name in run_names:\n",
|
||||
" if name not in runs:\n",
|
||||
" print(f\"\\n {run_labels[name]}: NOT YET RUN\")\n",
|
||||
" continue\n",
|
||||
" fid100 = get_fid(runs[name], 100)\n",
|
||||
" fid50 = get_fid(runs[name], 50)\n",
|
||||
" h = runs[name][\"history\"]\n",
|
||||
" loss50 = h[\"loss\"][49] if len(h[\"loss\"]) > 49 else None\n",
|
||||
" print(f\"\\n {run_labels[name]}:\")\n",
|
||||
" print(f\" FID@50 = {fid50:.1f}\" if fid50 else \" FID@50 = ?\")\n",
|
||||
" print(f\" FID@100 = {fid100:.1f}\" if fid100 else \" FID@100 = ?\")\n",
|
||||
" print(f\" Loss@50 = {loss50:.5f}\" if loss50 else \" Loss@50 = ?\")\n",
|
||||
"\n",
|
||||
"print(\"\\n\" + \"=\" * 70)\n",
|
||||
"print(\"Best DDPM model for Phase 5 comparison: fill in after runs complete.\")\n",
|
||||
"print(\"=\" * 70)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.10.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,669 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Phase 5 — Cross-Family Comparison\n",
|
||||
"\n",
|
||||
"Best-of-each finalist retrained for **200 epochs** under identical data conditions.\n",
|
||||
"\n",
|
||||
"| Family | Config | Resolution | Key design |\n",
|
||||
"|--------|--------|-----------|------------|\n",
|
||||
"| GAN | `p5_gan` | 128×128 | WGAN-GP + SpectralNorm + GroupNorm + Self-Attention |\n",
|
||||
"| VAE | `p5_vae` | 64×64 | Convolutional VAE + VGG perceptual + PatchGAN |\n",
|
||||
"| DDPM | `p5_ddpm` | 64×64 | Wider U-Net (192ch) + cosine schedule + v-prediction |\n",
|
||||
"\n",
|
||||
"> **Resolution note**: GAN runs at 128×128 (best architecture from Phase 2.4), while VAE and DDPM run at 64×64. FID is measured at each model's native resolution against real images at that resolution, so scores are not directly numerically comparable across families — they are indicators of within-family improvement."
|
||||
],
|
||||
"id": "d0000001"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"import json\n",
|
||||
"import sys\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib.image as mpimg\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n",
|
||||
"\n",
|
||||
"OUTPUTS = Path(\"../outputs\")\n",
|
||||
"LOGS = OUTPUTS / \"logs\"\n",
|
||||
"SAMPLES = OUTPUTS / \"samples\"\n",
|
||||
"\n",
|
||||
"# Phase 5 finalists and their best phase-2/3/4 counterparts (100-epoch comparison)\n",
|
||||
"FAMILIES = {\n",
|
||||
" \"GAN\": {\"p5\": \"p5_gan\", \"p4ep\": \"p2_4_wgan_sn_attn_128\", \"label\": \"WGAN-GP+SN+Attn\", \"color\": \"#5B8DB8\"},\n",
|
||||
" \"VAE\": {\"p5\": \"p5_vae\", \"p4ep\": \"p3_3_vae_patchgan\", \"label\": \"VAE+Perc+PatchGAN\",\"color\": \"#E8705A\"},\n",
|
||||
" \"DDPM\": {\"p5\": \"p5_ddpm\", \"p4ep\": \"p4_4_ddpm_wider\", \"label\": \"DDPM wider 192ch\", \"color\": \"#6ABF69\"},\n",
|
||||
"}"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"id": "d0000002"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. Load logs"
|
||||
],
|
||||
"id": "d0000003"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"def load_log(run_name):\n",
|
||||
" p = LOGS / f\"{run_name}.json\"\n",
|
||||
" if p.exists():\n",
|
||||
" with open(p) as f:\n",
|
||||
" return json.load(f)\n",
|
||||
" return None\n",
|
||||
"\n",
|
||||
"def get_fid(log, epoch):\n",
|
||||
" if log is None:\n",
|
||||
" return None\n",
|
||||
" fid = log[\"history\"][\"fid\"]\n",
|
||||
" return fid.get(str(epoch), fid.get(epoch, None))\n",
|
||||
"\n",
|
||||
"logs_p5 = {fam: load_log(info[\"p5\"]) for fam, info in FAMILIES.items()}\n",
|
||||
"logs_p4 = {fam: load_log(info[\"p4ep\"]) for fam, info in FAMILIES.items()}\n",
|
||||
"\n",
|
||||
"for fam in FAMILIES:\n",
|
||||
" p5_ok = \"✓\" if logs_p5[fam] else \"✗\"\n",
|
||||
" p4_ok = \"✓\" if logs_p4[fam] else \"✗\"\n",
|
||||
" print(f\" {fam}: 200ep={p5_ok} 100ep={p4_ok}\")"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"id": "d0000004"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. Quantitative Summary Table"
|
||||
],
|
||||
"id": "d0000005"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"rows = []\n",
|
||||
"for fam, info in FAMILIES.items():\n",
|
||||
" log = logs_p5[fam]\n",
|
||||
" train_time = log.get(\"history\", {}).get(\"train_time_s\") if log else None\n",
|
||||
" train_min = f\"{train_time / 60:.1f}\" if train_time else \"?\"\n",
|
||||
" rows.append({\n",
|
||||
" \"Family\": fam,\n",
|
||||
" \"Model\": info[\"label\"],\n",
|
||||
" \"Res\": log.get(\"config\", log).get(\"image_size\", \"?\") if log else \"?\",\n",
|
||||
" \"Params\": log.get(\"n_params\") if log else None,\n",
|
||||
" \"Train (min)\": train_min,\n",
|
||||
" \"FID@100\": get_fid(log, 100),\n",
|
||||
" \"FID@150\": get_fid(log, 150),\n",
|
||||
" \"FID@200\": get_fid(log, 200),\n",
|
||||
" # IS and LPIPS filled in by Section 6\n",
|
||||
" \"IS ↑\": None,\n",
|
||||
" \"LPIPS ↑\": None,\n",
|
||||
" })\n",
|
||||
"\n",
|
||||
"df = pd.DataFrame(rows).set_index(\"Family\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def fmt_params(v):\n",
|
||||
" if v is None:\n",
|
||||
" return \"?\"\n",
|
||||
" if v >= 1_000_000:\n",
|
||||
" return f\"{v / 1_000_000:.1f}M\"\n",
|
||||
" if v >= 1_000:\n",
|
||||
" return f\"{v / 1_000:.0f}K\"\n",
|
||||
" return str(v)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"df_display = df.copy()\n",
|
||||
"df_display[\"Params\"] = df_display[\"Params\"].apply(fmt_params)\n",
|
||||
"df_display.style.format({c: \"{:.1f}\" for c in [\"FID@100\", \"FID@150\", \"FID@200\"] if c in df_display})"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"id": "d0000006"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3. FID Curves — All Three Families"
|
||||
],
|
||||
"id": "d0000007"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"fig, axes = plt.subplots(1, 2, figsize=(15, 5))\n",
|
||||
"\n",
|
||||
"for fam, info in FAMILIES.items():\n",
|
||||
" c = info[\"color\"]\n",
|
||||
" # Phase 5 (200ep) — solid\n",
|
||||
" log = logs_p5[fam]\n",
|
||||
" if log:\n",
|
||||
" fid_dict = log[\"history\"][\"fid\"]\n",
|
||||
" eps = sorted(int(k) for k in fid_dict)\n",
|
||||
" fids = [fid_dict[str(e)] for e in eps]\n",
|
||||
" axes[0].plot(eps, fids, \"-o\", color=c, linewidth=2, markersize=6,\n",
|
||||
" label=f\"{fam} 200ep (FID@200={fid_dict.get('200','?'):.1f})\" if isinstance(fid_dict.get('200'), float) else f\"{fam} 200ep\")\n",
|
||||
"\n",
|
||||
" # Phase 4/3/2 best (100ep) — dashed, same colour\n",
|
||||
" log4 = logs_p4[fam]\n",
|
||||
" if log4:\n",
|
||||
" fid4 = log4[\"history\"][\"fid\"]\n",
|
||||
" eps4 = sorted(int(k) for k in fid4)\n",
|
||||
" fids4 = [fid4[str(e)] for e in eps4]\n",
|
||||
" axes[0].plot(eps4, fids4, \"--\", color=c, linewidth=1.2, alpha=0.55,\n",
|
||||
" label=f\"{fam} 100ep\")\n",
|
||||
"\n",
|
||||
"axes[0].set_xlabel(\"Epoch\")\n",
|
||||
"axes[0].set_ylabel(\"FID (lower is better)\")\n",
|
||||
"axes[0].set_title(\"FID Curves — Phase 5 (solid) vs best 100-ep (dashed)\")\n",
|
||||
"axes[0].legend(fontsize=8)\n",
|
||||
"\n",
|
||||
"# Bar chart: FID at 100 vs 200 epochs per family\n",
|
||||
"fams = list(FAMILIES.keys())\n",
|
||||
"fid100 = [get_fid(logs_p5[f], 100) for f in fams]\n",
|
||||
"fid200 = [get_fid(logs_p5[f], 200) for f in fams]\n",
|
||||
"x = np.arange(len(fams))\n",
|
||||
"w = 0.35\n",
|
||||
"bars1 = axes[1].bar(x - w/2, [v or 0 for v in fid100], w, label=\"FID@100\", color=[FAMILIES[f][\"color\"] for f in fams], alpha=0.6)\n",
|
||||
"bars2 = axes[1].bar(x + w/2, [v or 0 for v in fid200], w, label=\"FID@200\", color=[FAMILIES[f][\"color\"] for f in fams], alpha=1.0)\n",
|
||||
"axes[1].set_xticks(x); axes[1].set_xticklabels(fams)\n",
|
||||
"axes[1].set_ylabel(\"FID\"); axes[1].set_title(\"FID: 100ep vs 200ep per family\")\n",
|
||||
"axes[1].legend()\n",
|
||||
"for bar in list(bars1) + list(bars2):\n",
|
||||
" h = bar.get_height()\n",
|
||||
" if h > 0:\n",
|
||||
" axes[1].text(bar.get_x() + bar.get_width()/2, h + 1, f\"{h:.0f}\", ha=\"center\", va=\"bottom\", fontsize=8)\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"id": "d0000008"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 4. Sample Grids — Epoch 200"
|
||||
],
|
||||
"id": "d0000009"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"fig, axes = plt.subplots(1, 3, figsize=(16, 6))\n",
|
||||
"\n",
|
||||
"for idx, (fam, info) in enumerate(FAMILIES.items()):\n",
|
||||
" ax = axes[idx]\n",
|
||||
" img_path = SAMPLES / info[\"p5\"] / \"epoch_0200.png\"\n",
|
||||
" log = logs_p5[fam]\n",
|
||||
" fid200 = get_fid(log, 200)\n",
|
||||
" if img_path.exists():\n",
|
||||
" ax.imshow(mpimg.imread(str(img_path)))\n",
|
||||
" title = f\"{fam}: {info['label']}\\n\"\n",
|
||||
" if fid200:\n",
|
||||
" res = log[\"config\"].get(\"image_size\", \"?\")\n",
|
||||
" title += f\"FID@200={fid200:.1f} ({res}×{res})\"\n",
|
||||
" ax.set_title(title, fontsize=9)\n",
|
||||
" else:\n",
|
||||
" ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes, fontsize=12)\n",
|
||||
" ax.set_title(f\"{fam}: {info['label']}\", fontsize=9)\n",
|
||||
" ax.axis(\"off\")\n",
|
||||
"\n",
|
||||
"fig.suptitle(\"Phase 5 — Epoch 200 Sample Grids (4×4, prior samples)\", fontsize=13, fontweight=\"bold\")\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"id": "d0000010"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 5. Training Progression — Epoch 10 → 50 → 100 → 200"
|
||||
],
|
||||
"id": "d0000011"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"check_epochs = [10, 50, 100, 200]\n",
|
||||
"\n",
|
||||
"for fam, info in FAMILIES.items():\n",
|
||||
" run = info[\"p5\"]\n",
|
||||
" log = logs_p5[fam]\n",
|
||||
" if log is None:\n",
|
||||
" print(f\"{fam}: not yet run\")\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" fig, axes = plt.subplots(1, len(check_epochs), figsize=(16, 4))\n",
|
||||
" for ax, ep in zip(axes, check_epochs):\n",
|
||||
" img_path = SAMPLES / run / f\"epoch_{ep:04d}.png\"\n",
|
||||
" if img_path.exists():\n",
|
||||
" ax.imshow(mpimg.imread(str(img_path)))\n",
|
||||
" fid = get_fid(log, ep)\n",
|
||||
" ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid else \"\"), fontsize=9)\n",
|
||||
" else:\n",
|
||||
" ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\",\n",
|
||||
" transform=ax.transAxes)\n",
|
||||
" ax.axis(\"off\")\n",
|
||||
" fig.suptitle(f\"{fam} ({info['label']}) — 200-epoch progression\", fontsize=11, fontweight=\"bold\")\n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.show()"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"id": "d0000012"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 6. Extended Metrics — IS and LPIPS Diversity\n",
|
||||
"\n",
|
||||
"Requires loading trained model weights. Run after all phase 5 models have finished.\n",
|
||||
"Generates 5 000 samples per model and computes IS and LPIPS over 200 random pairs."
|
||||
],
|
||||
"id": "d0000013"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"sys.path.insert(0, \"..\")\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"from src.utils import load_config\n",
|
||||
"from src.models import get_model\n",
|
||||
"from src.training.metrics import compute_is, compute_lpips_diversity\n",
|
||||
"from src.training.diffusion import cosine_betas, make_alpha_bars, ddim_sample\n",
|
||||
"from src.training.ema import EMA\n",
|
||||
"\n",
|
||||
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
||||
"N_SAMPLE = 5_000\n",
|
||||
"\n",
|
||||
"def load_ema_model(run_name, config_path):\n",
|
||||
" \"\"\"Load the best EMA weights for a given phase-5 run.\"\"\"\n",
|
||||
" cfg = load_config(str(Path(\"../configs/phase5\") / config_path))\n",
|
||||
" model_obj, kind = get_model(cfg)\n",
|
||||
" ema_path = Path(\"../outputs/models\") / f\"{run_name}_best_ema.pt\"\n",
|
||||
" if not ema_path.exists():\n",
|
||||
" ema_path = Path(\"../outputs/models\") / f\"{run_name}_final_ema.pt\"\n",
|
||||
" if isinstance(model_obj, tuple):\n",
|
||||
" model = model_obj[0] # generator for GAN\n",
|
||||
" else:\n",
|
||||
" model = model_obj\n",
|
||||
" model.load_state_dict(torch.load(ema_path, map_location=DEVICE))\n",
|
||||
" return model.to(DEVICE).eval(), cfg, kind\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@torch.no_grad()\n",
|
||||
"def generate_samples(run_name, config_path, n=N_SAMPLE):\n",
|
||||
" model, cfg, kind = load_ema_model(run_name, config_path)\n",
|
||||
" image_size = cfg.get(\"image_size\", 64)\n",
|
||||
"\n",
|
||||
" if kind == \"wgan\":\n",
|
||||
" latent_dim = cfg.get(\"latent_dim\", 128)\n",
|
||||
" imgs = torch.cat([\n",
|
||||
" model(torch.randn(min(64, n - i), latent_dim, 1, 1, device=DEVICE))\n",
|
||||
" for i in range(0, n, 64)\n",
|
||||
" ])[:n].cpu()\n",
|
||||
"\n",
|
||||
" elif kind == \"vae\":\n",
|
||||
" imgs = torch.cat([\n",
|
||||
" model.sample(min(64, n - i), DEVICE)\n",
|
||||
" for i in range(0, n, 64)\n",
|
||||
" ])[:n].cpu()\n",
|
||||
"\n",
|
||||
" elif kind == \"ddpm\":\n",
|
||||
" schedule = cfg.get(\"noise_schedule\", \"cosine\")\n",
|
||||
" pred_type = cfg.get(\"pred_type\", \"v\")\n",
|
||||
" T = cfg.get(\"T\", 1000)\n",
|
||||
" from src.training.diffusion import cosine_betas, linear_betas, make_alpha_bars, ddim_sample\n",
|
||||
" betas = (cosine_betas(T) if schedule == \"cosine\" else linear_betas(T)).to(DEVICE)\n",
|
||||
" ab = make_alpha_bars(betas)\n",
|
||||
" imgs = ddim_sample(model, n, image_size, ab, n_steps=100,\n",
|
||||
" pred_type=pred_type, device=DEVICE, batch_size=32)\n",
|
||||
" return imgs\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(\"Run this cell once all phase-5 models have completed.\")\n",
|
||||
"print(\"Expected to take ~5–10 minutes per family on an RTX 3090.\")"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"id": "d0000014"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# ── Compute IS and LPIPS per family ──────────────────────────────────────────\n",
|
||||
"# Uncomment when models are ready.\n",
|
||||
"\n",
|
||||
"extended_metrics = {}\n",
|
||||
"\n",
|
||||
"for fam, info in FAMILIES.items():\n",
|
||||
" run = info[\"p5\"]\n",
|
||||
" config = f\"{run}.json\"\n",
|
||||
" try:\n",
|
||||
" print(f\"\\n{fam}: generating {N_SAMPLE} samples...\")\n",
|
||||
" imgs = generate_samples(run, config)\n",
|
||||
"\n",
|
||||
" print(f\" Computing IS...\")\n",
|
||||
" is_mean, is_std = compute_is(imgs, device=DEVICE)\n",
|
||||
"\n",
|
||||
" print(f\" Computing LPIPS diversity...\")\n",
|
||||
" lpips = compute_lpips_diversity(imgs, n_pairs=200, device=DEVICE)\n",
|
||||
"\n",
|
||||
" extended_metrics[fam] = {\"IS_mean\": is_mean, \"IS_std\": is_std, \"LPIPS\": lpips}\n",
|
||||
" print(f\" IS = {is_mean:.2f} ± {is_std:.2f} LPIPS = {lpips:.4f}\")\n",
|
||||
"\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\" Skipped ({e})\")\n",
|
||||
" extended_metrics[fam] = {}\n",
|
||||
"\n",
|
||||
"# Merge with FID table\n",
|
||||
"for fam in FAMILIES:\n",
|
||||
" em = extended_metrics.get(fam, {})\n",
|
||||
" idx = list(FAMILIES.keys()).index(fam)\n",
|
||||
" df.loc[fam, \"IS ↑\"] = f\"{em['IS_mean']:.2f}±{em['IS_std']:.2f}\" if \"IS_mean\" in em else \"—\"\n",
|
||||
" df.loc[fam, \"LPIPS ↑\"] = f\"{em['LPIPS']:.4f}\" if \"LPIPS\" in em else \"—\"\n",
|
||||
"\n",
|
||||
"df.style.format({c: \"{:.1f}\" for c in [\"FID@100\", \"FID@150\", \"FID@200\"] if c in df})"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"id": "d0000015"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 7. Latent Interpolation — GAN and VAE\n",
|
||||
"\n",
|
||||
"Smooth interpolation between two latent codes reveals whether the generator has learned a\n",
|
||||
"continuous manifold. DDPM has no encoder, so interpolation is done by different noise seeds."
|
||||
],
|
||||
"id": "d0000016"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Spherical linear interpolation (slerp)\n",
|
||||
"def slerp(z1, z2, t):\n",
|
||||
" z1_n = z1 / z1.norm()\n",
|
||||
" z2_n = z2 / z2.norm()\n",
|
||||
" omega = torch.acos((z1_n * z2_n).sum().clamp(-1, 1))\n",
|
||||
" if omega.abs() < 1e-6:\n",
|
||||
" return (1 - t) * z1 + t * z2\n",
|
||||
" return (torch.sin((1-t)*omega)/torch.sin(omega)) * z1 + \\\n",
|
||||
" (torch.sin(t*omega)/torch.sin(omega)) * z2\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def gan_interpolation(model, latent_dim, n_steps=10, device=DEVICE):\n",
|
||||
" z1 = torch.randn(1, latent_dim, 1, 1, device=device)\n",
|
||||
" z2 = torch.randn(1, latent_dim, 1, 1, device=device)\n",
|
||||
" alphas = torch.linspace(0, 1, n_steps)\n",
|
||||
" imgs = []\n",
|
||||
" with torch.no_grad():\n",
|
||||
" for a in alphas:\n",
|
||||
" z = slerp(z1.flatten(), z2.flatten(), a.item()).view_as(z1)\n",
|
||||
" imgs.append(model(z).cpu())\n",
|
||||
" return torch.cat(imgs)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def vae_interpolation(model, real_imgs, n_steps=10, device=DEVICE):\n",
|
||||
" \"\"\"Encode two real images, interpolate in latent space, decode.\"\"\"\n",
|
||||
" img1, img2 = real_imgs[:1].to(device), real_imgs[1:2].to(device)\n",
|
||||
" with torch.no_grad():\n",
|
||||
" mu1, _ = model.encode(img1)\n",
|
||||
" mu2, _ = model.encode(img2)\n",
|
||||
" alphas = torch.linspace(0, 1, n_steps, device=device)\n",
|
||||
" imgs = [model.decode((1-a)*mu1 + a*mu2).cpu() for a in alphas]\n",
|
||||
" return torch.cat(imgs)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(\"Interpolation helpers defined. Run cells below after loading models.\")"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"id": "d0000017"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# ── GAN interpolation ─────────────────────────────────────────────────────────\n",
|
||||
"try:\n",
|
||||
" gan_model, gan_cfg, _ = load_ema_model(\"p5_gan\", \"p5_gan.json\")\n",
|
||||
" latent_dim = gan_cfg.get(\"latent_dim\", 128)\n",
|
||||
" interp_imgs = gan_interpolation(gan_model, latent_dim, n_steps=10)\n",
|
||||
" interp_imgs = (interp_imgs.clamp(-1, 1) + 1) / 2\n",
|
||||
"\n",
|
||||
" fig, axes = plt.subplots(1, 10, figsize=(20, 2.5))\n",
|
||||
" for ax, img in zip(axes, interp_imgs):\n",
|
||||
" ax.imshow(img.permute(1, 2, 0).numpy())\n",
|
||||
" ax.axis(\"off\")\n",
|
||||
" fig.suptitle(\"GAN — Slerp latent interpolation (z₁ → z₂)\", fontsize=11, fontweight=\"bold\")\n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.show()\n",
|
||||
"except Exception as e:\n",
|
||||
" print(f\"GAN interpolation: {e}\")"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"id": "d0000018"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# ── VAE interpolation ─────────────────────────────────────────────────────────\n",
|
||||
"try:\n",
|
||||
" from src.data import GeneratorDataset, get_transform\n",
|
||||
"\n",
|
||||
" vae_model, vae_cfg, _ = load_ema_model(\"p5_vae\", \"p5_vae.json\")\n",
|
||||
" ds = GeneratorDataset(\"../../\" + vae_cfg[\"data_dir\"],\n",
|
||||
" sources=vae_cfg.get(\"sources\", [\"wiki\"]),\n",
|
||||
" transform=get_transform(vae_cfg[\"image_size\"], augment=False))\n",
|
||||
" sample_real = torch.stack([ds[i] for i in range(2)])\n",
|
||||
"\n",
|
||||
" interp_imgs = vae_interpolation(vae_model, sample_real, n_steps=10)\n",
|
||||
" interp_imgs = (interp_imgs.clamp(-1, 1) + 1) / 2\n",
|
||||
"\n",
|
||||
" fig, axes = plt.subplots(1, 10, figsize=(20, 2.5))\n",
|
||||
" for ax, img in zip(axes, interp_imgs):\n",
|
||||
" ax.imshow(img.permute(1, 2, 0).numpy())\n",
|
||||
" ax.axis(\"off\")\n",
|
||||
" fig.suptitle(\"VAE — μ-space linear interpolation (image₁ → image₂)\", fontsize=11, fontweight=\"bold\")\n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.show()\n",
|
||||
"except Exception as e:\n",
|
||||
" print(f\"VAE interpolation: {e}\")"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"id": "d0000019"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 8. Failure Mode Analysis\n",
|
||||
"\n",
|
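||||
"Identify the worst-generated images per model: those farthest from their nearest real neighbour in pixel space, or simply those with highest reconstruction error (VAE) or highest DDPM loss. As a rough first pass, the helper below instead ranks generated samples by mean pixel value and keeps the lowest-scoring ones."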
|
||||
],
|
||||
"id": "d0000020"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# For GAN and DDPM: generate 256 images, pick the 8 with lowest mean pixel value\n",
|
||||
"# (a proxy for less-coherent images — very rough)\n",
|
||||
"\n",
|
||||
"def worst_samples(imgs, n=8):\n",
|
||||
" \"\"\"Heuristic: pick samples with lowest per-pixel mean (often darker / less structured).\"\"\"\n",
|
||||
" scores = imgs.mean(dim=[1, 2, 3]) # mean brightness per image\n",
|
||||
" worst_idx = scores.argsort()[:n]\n",
|
||||
" return imgs[worst_idx]\n",
|
||||
"\n",
|
||||
"print(\"Failure mode analysis requires generated samples — run after model loading (Section 6).\")"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"id": "d0000021"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Optionally run with already-generated `imgs` from Section 6:\n",
|
||||
"# worst = worst_samples(imgs.cpu())\n",
|
||||
"# worst = (worst.clamp(-1,1) + 1) / 2\n",
|
||||
"# fig, axes = plt.subplots(1, 8, figsize=(18, 2.5))\n",
|
||||
"# for ax, img in zip(axes, worst):\n",
|
||||
"# ax.imshow(img.permute(1,2,0)); ax.axis('off')\n",
|
||||
"# plt.suptitle('Failure modes (lowest-brightness samples)', fontsize=11)\n",
|
||||
"# plt.tight_layout(); plt.show()"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"id": "d0000022"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 9. Training Loss Overview (all families)"
|
||||
],
|
||||
"id": "d0000023"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"fig, axes = plt.subplots(1, 3, figsize=(18, 4))\n",
|
||||
"\n",
|
||||
"for ax, (fam, info) in zip(axes, FAMILIES.items()):\n",
|
||||
" log = logs_p5[fam]\n",
|
||||
" if log is None:\n",
|
||||
" ax.set_title(f\"{fam} (not yet run)\"); ax.axis(\"off\"); continue\n",
|
||||
" h = log[\"history\"]\n",
|
||||
" c = info[\"color\"]\n",
|
||||
"\n",
|
||||
" if fam == \"GAN\":\n",
|
||||
" ax.plot(h[\"g_loss\"], label=\"G loss\", color=c, linewidth=1.2)\n",
|
||||
" ax.plot(h[\"w_dist\"], label=\"W-dist\", color=c, linewidth=1.2, linestyle=\"--\")\n",
|
||||
" ax.set_ylabel(\"Loss / W-distance\")\n",
|
||||
" elif fam == \"VAE\":\n",
|
||||
" ax.plot(h[\"recon_loss\"], label=\"MSE\", color=c)\n",
|
||||
" ax2 = ax.twinx()\n",
|
||||
" ax2.plot(h[\"kl_loss\"], label=\"KL\", color=\"grey\", linestyle=\"--\")\n",
|
||||
" ax2.set_ylabel(\"KL\", color=\"grey\")\n",
|
||||
" ax.set_ylabel(\"MSE\")\n",
|
||||
" elif fam == \"DDPM\":\n",
|
||||
" ax.plot(h[\"loss\"], label=\"MSE (v-pred)\", color=c)\n",
|
||||
" ax.set_ylabel(\"MSE loss\")\n",
|
||||
"\n",
|
||||
" ax.set_xlabel(\"Epoch\")\n",
|
||||
" ax.set_title(f\"{fam}: {info['label']}\")\n",
|
||||
" ax.legend(fontsize=8)\n",
|
||||
"\n",
|
||||
"fig.suptitle(\"Phase 5 — Training Dynamics (200 epochs)\", fontsize=12, fontweight=\"bold\")\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"id": "d0000024"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 10. Conclusions"
|
||||
],
|
||||
"id": "d0000025"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"print(\"=\" * 72)\n",
|
||||
"print(\"PHASE 5 — CROSS-FAMILY COMPARISON (200 epochs)\")\n",
|
||||
"print(\"=\" * 72)\n",
|
||||
"\n",
|
||||
"for fam, info in FAMILIES.items():\n",
|
||||
" log = logs_p5[fam]\n",
|
||||
" em = extended_metrics.get(fam, {})\n",
|
||||
" print(f\"\\n ── {fam}: {info['label']} ──\")\n",
|
||||
" if log:\n",
|
||||
" res = log.get(\"config\", log).get(\"image_size\", \"?\")\n",
|
||||
" print(f\" Resolution : {res}×{res}\")\n",
|
||||
" n_p = log.get(\"n_params\")\n",
|
||||
" print(f\" Params : {n_p:,}\" if n_p else \" Params : ?\")\n",
|
||||
" tt = log.get(\"history\", {}).get(\"train_time_s\")\n",
|
||||
" print(f\" Train time : {tt / 60:.1f} min\" if tt else \" Train time : ?\")\n",
|
||||
" for ep in (100, 150, 200):\n",
|
||||
" fid = get_fid(log, ep)\n",
|
||||
" print(f\" FID@{ep:<3} : {fid:.1f}\" if fid else f\" FID@{ep:<3} : ?\")\n",
|
||||
" if em:\n",
|
||||
" print(f\" IS : {em.get('IS_mean','?'):.2f} ± {em.get('IS_std','?'):.2f}\" if 'IS_mean' in em else \" IS : ?\")\n",
|
||||
" print(f\" LPIPS div : {em.get('LPIPS','?'):.4f}\" if 'LPIPS' in em else \" LPIPS div : ?\")\n",
|
||||
" else:\n",
|
||||
" print(\" IS / LPIPS : (run Section 6 to compute)\")\n",
|
||||
"\n",
|
||||
"print(\"\\n\" + \"=\" * 72)\n",
|
||||
"print(\"Narrative to fill in after results:\")\n",
|
||||
"print(\" - Which family achieves best FID?\")\n",
|
||||
"print(\" - GAN: fast convergence but mode collapse risk?\")\n",
|
||||
"print(\" - VAE: blurry priors improved by perceptual+adversarial loss?\")\n",
|
||||
"print(\" - DDPM: highest quality but slowest inference (100 DDIM steps)?\")\n",
|
||||
"print(\"=\" * 72)"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"id": "d0000026"
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.10.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user