{ "cells": [
 { "cell_type": "markdown", "id": "c0000001", "metadata": {}, "source": [
  "# Phase 4 — DDPM Evolution Analysis\n",
  "\n",
  "Traces the DDPM improvement story:\n",
  "\n",
  "| Step | Model | Key change | Expected failure |\n",
  "|------|-------|------------|------------------|\n",
  "| 4.1 | DDPM linear + ε-pred | Baseline | Noise prediction unstable at very low t (linear schedule over-denoises) |\n",
  "| 4.2 | + cosine schedule | Fewer timesteps wasted near t=T, where the linear schedule is already near-pure noise (see §8) | Residual instability from ε parameterisation |\n",
  "| 4.3 | + v-prediction | Numerically stable across full trajectory | Possible underfitting at 64×64 |\n",
  "| 4.4 | + wider U-Net (192ch) + 32×32 attention | More capacity and longer-range context | — |\n",
  "\n",
  "FID is computed via DDIM (100 steps, deterministic) from the EMA model.\n",
  "H-flip-only augmentation, MTCNN-aligned 64×64 crops, T=1000."
 ] },
 { "cell_type": "code", "execution_count": null, "id": "c0000002", "metadata": {}, "outputs": [], "source": [
  "import json\n",
  "from pathlib import Path\n",
  "\n",
  "import matplotlib.image as mpimg\n",
  "import matplotlib.pyplot as plt\n",
  "import numpy as np\n",
  "import pandas as pd\n",
  "\n",
  "plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n",
  "\n",
  "# Run artefacts, relative to this notebook: LOGS/<run>.json and SAMPLES/<run>/epoch_XXXX.png.\n",
  "OUTPUTS = Path(\"../outputs\")\n",
  "LOGS = OUTPUTS / \"logs\"\n",
  "SAMPLES = OUTPUTS / \"samples\""
 ] },
 { "cell_type": "markdown", "id": "c0000003", "metadata": {}, "source": [ "## 1. 
Load experiment logs" ] },
 { "cell_type": "code", "execution_count": null, "id": "c0000004", "metadata": {}, "outputs": [], "source": [
  "run_names = [\"p4_1_ddpm_linear\", \"p4_2_ddpm_cosine\", \"p4_3_ddpm_vpred\", \"p4_4_ddpm_wider\"]\n",
  "run_labels = {\n",
  "    \"p4_1_ddpm_linear\": \"4.1 linear + ε\",\n",
  "    \"p4_2_ddpm_cosine\": \"4.2 cosine + ε\",\n",
  "    \"p4_3_ddpm_vpred\": \"4.3 cosine + v\",\n",
  "    \"p4_4_ddpm_wider\": \"4.4 wider + 32×32 attn\",\n",
  "}\n",
  "\n",
  "# One JSON log per run; runs that have not produced a log yet are reported and skipped.\n",
  "runs = {}\n",
  "for name in run_names:\n",
  "    log_path = LOGS / f\"{name}.json\"\n",
  "    if log_path.exists():\n",
  "        with open(log_path, encoding=\"utf-8\") as f:\n",
  "            runs[name] = json.load(f)\n",
  "    else:\n",
  "        print(f\"  Missing: {log_path}\")\n",
  "\n",
  "print(f\"Loaded {len(runs)}/{len(run_names)} experiments:\")\n",
  "for name in run_names:\n",
  "    print(f\"  {'✓' if name in runs else '✗'} {name}\")"
 ] },
 { "cell_type": "markdown", "id": "c0000005", "metadata": {}, "source": [ "## 2. FID Comparison Table" ] },
 { "cell_type": "code", "execution_count": null, "id": "c0000006", "metadata": {}, "outputs": [], "source": [
  "FID_EPOCHS = (25, 50, 75, 100)\n",
  "\n",
  "\n",
  "def get_fid(run, epoch):\n",
  "    \"\"\"Return the FID logged for `run` at `epoch`, or None if none was recorded.\n",
  "\n",
  "    FIDs are read from run[\"history\"][\"fid\"], keyed by epoch. json.load always\n",
  "    produces string keys; integer keys are accepted as a fallback. A log without\n",
  "    a history/fid section yields None instead of raising KeyError.\n",
  "    \"\"\"\n",
  "    fid = run.get(\"history\", {}).get(\"fid\", {})\n",
  "    return fid.get(str(epoch), fid.get(epoch))\n",
  "\n",
  "\n",
  "rows = []\n",
  "for name in run_names:\n",
  "    if name not in runs:\n",
  "        continue\n",
  "    r = runs[name]\n",
  "    cfg = r.get(\"config\", {})\n",
  "    rows.append({\n",
  "        \"Step\": run_labels[name],\n",
  "        \"Schedule\": cfg.get(\"noise_schedule\"),\n",
  "        \"Pred\": cfg.get(\"pred_type\"),\n",
  "        \"base_ch\": cfg.get(\"base_ch\"),\n",
  "        **{f\"FID@{ep}\": get_fid(r, ep) for ep in FID_EPOCHS},\n",
  "    })\n",
  "\n",
  "df = pd.DataFrame(rows)\n",
  "# na_rep is required: a FID column with no logged values may stay object dtype holding None,\n",
  "# and \"{:.1f}\".format(None) raises TypeError when the Styler renders.\n",
  "df.style.format({c: \"{:.1f}\" for c in df.columns if \"FID\" in c}, na_rep=\"—\")"
 ] },
 { "cell_type": "markdown", "id": "c0000007", "metadata": {}, "source": [ "## 3. 
FID Curves — Evolution Story" ] },
 { "cell_type": "code", "execution_count": null, "id": "c0000008", "metadata": {}, "outputs": [], "source": [
  "colors = [\"#5B8DB8\", \"#E8705A\", \"#6ABF69\", \"#B86FB8\"]\n",
  "\n",
  "fig, ax = plt.subplots(figsize=(11, 5))\n",
  "for i, name in enumerate(run_names):\n",
  "    if name not in runs:\n",
  "        continue\n",
  "    fid_dict = runs[name].get(\"history\", {}).get(\"fid\", {})\n",
  "    # Keys are epoch strings: sort numerically and drop epochs whose FID was not recorded.\n",
  "    points = sorted((int(ep), fid) for ep, fid in fid_dict.items() if fid is not None)\n",
  "    if not points:\n",
  "        continue\n",
  "    epochs, fids = zip(*points)\n",
  "    fid100 = get_fid(runs[name], 100)\n",
  "    label = run_labels[name] if fid100 is None else f\"{run_labels[name]} (FID@100={fid100:.1f})\"\n",
  "    # Cycle the palette so adding runs to run_names cannot raise IndexError.\n",
  "    ax.plot(epochs, fids, \"o-\", label=label, color=colors[i % len(colors)], linewidth=2, markersize=8)\n",
  "\n",
  "ax.set_xlabel(\"Epoch\")\n",
  "ax.set_ylabel(\"FID (DDIM 100 steps, lower is better)\")\n",
  "ax.set_title(\"Phase 4 — FID Curves: linear·ε → cosine·ε → cosine·v → wider\")\n",
  "if ax.get_legend_handles_labels()[0]:  # no legend (and no warning) until some run has logged FID\n",
  "    ax.legend()\n",
  "plt.tight_layout()\n",
  "plt.show()"
 ] },
 { "cell_type": "markdown", "id": "c0000009", "metadata": {}, "source": [
  "## 4. Training Loss Curves (MSE on predicted target)\n",
  "\n",
  "ε-pred and v-pred runs regress different targets (see the `Pred` column in §2), so absolute loss values are only comparable between runs with the same target."
 ] },
 { "cell_type": "code", "execution_count": null, "id": "c0000010", "metadata": {}, "outputs": [], "source": [
  "fig, ax = plt.subplots(figsize=(11, 4))\n",
  "for i, name in enumerate(run_names):\n",
  "    if name not in runs:\n",
  "        continue\n",
  "    losses = runs[name].get(\"history\", {}).get(\"loss\", [])\n",
  "    # One logged value per epoch, so the x-axis runs 1..len(losses).\n",
  "    ax.plot(range(1, len(losses) + 1), losses, label=run_labels[name],\n",
  "            color=colors[i % len(colors)], linewidth=1.2, alpha=0.9)\n",
  "\n",
  "ax.set_xlabel(\"Epoch\")\n",
  "ax.set_ylabel(\"MSE (noise / v prediction loss)\")\n",
  "ax.set_title(\"Phase 4 — Training Loss\")\n",
  "if ax.get_legend_handles_labels()[0]:\n",
  "    ax.legend(fontsize=9)\n",
  "plt.tight_layout()\n",
  "plt.show()"
 ] },
 { "cell_type": "markdown", "id": "c0000011", "metadata": {}, "source": [ "## 5. 
Sample Grids — Epoch 100 (DDIM 50 steps)" ] },
 { "cell_type": "code", "execution_count": null, "id": "c0000012", "metadata": {}, "outputs": [], "source": [
  "FINAL_EPOCH = 100\n",
  "\n",
  "\n",
  "def show_sample_grid(ax, name, epoch, title, missing_text, fontsize=9):\n",
  "    \"\"\"Draw run `name`'s saved sample grid for `epoch` on `ax`.\n",
  "\n",
  "    Reads SAMPLES/<name>/epoch_XXXX.png. If it exists, the image is shown and the FID\n",
  "    logged for that epoch (when available) is appended to `title`; otherwise the\n",
  "    `missing_text` placeholder is drawn in the centre. The axis frame is always hidden.\n",
  "    \"\"\"\n",
  "    img_path = SAMPLES / name / f\"epoch_{epoch:04d}.png\"\n",
  "    if img_path.exists():\n",
  "        ax.imshow(mpimg.imread(str(img_path)))\n",
  "        fid = get_fid(runs[name], epoch) if name in runs else None\n",
  "        if fid is not None:\n",
  "            title = f\"{title}\\nFID@{epoch}={fid:.1f}\"\n",
  "    else:\n",
  "        ax.text(0.5, 0.5, missing_text, ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
  "    ax.set_title(title, fontsize=fontsize)\n",
  "    ax.axis(\"off\")\n",
  "\n",
  "\n",
  "# One panel per run; squeeze=False keeps `axes` 2-D even if run_names has a single entry.\n",
  "fig, axes = plt.subplots(1, len(run_names), figsize=(5 * len(run_names), 5), squeeze=False)\n",
  "for ax, name in zip(axes[0], run_names):\n",
  "    show_sample_grid(ax, name, FINAL_EPOCH, run_labels[name], \"Not yet run\", fontsize=8)\n",
  "\n",
  "fig.suptitle(f\"Phase 4 — Epoch {FINAL_EPOCH} DDIM Samples (4×4 grids)\", fontsize=12, fontweight=\"bold\")\n",
  "plt.tight_layout()\n",
  "plt.show()"
 ] },
 { "cell_type": "markdown", "id": "c0000013", "metadata": {}, "source": [ "## 6. Step-by-step Pairwise Comparisons" ] },
 { "cell_type": "code", "execution_count": null, "id": "c0000014", "metadata": {}, "outputs": [], "source": [
  "transitions = [\n",
  "    (\"4.1→4.2: linear→cosine schedule\", \"p4_1_ddpm_linear\", \"p4_2_ddpm_cosine\"),\n",
  "    (\"4.2→4.3: ε-pred→v-pred\", \"p4_2_ddpm_cosine\", \"p4_3_ddpm_vpred\"),\n",
  "    (\"4.3→4.4: 128ch→192ch + 32×32 attn\", \"p4_3_ddpm_vpred\", \"p4_4_ddpm_wider\"),\n",
  "]\n",
  "\n",
  "for title, name_a, name_b in transitions:\n",
  "    fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n",
  "    for ax, name in zip(axes, (name_a, name_b)):\n",
  "        show_sample_grid(ax, name, FINAL_EPOCH, run_labels[name], \"Pending\")\n",
  "    fig.suptitle(title, fontsize=12, fontweight=\"bold\")\n",
  "    plt.tight_layout()\n",
  "    plt.show()"
 ] },
 { "cell_type": "markdown", "id": "c0000015", "metadata": {}, "source": [ "## 7. Progression: Epoch 10 → 50 → 100" ] },
 { "cell_type": "code", "execution_count": null, "id": "c0000016", "metadata": {}, "outputs": [], "source": [
  "check_epochs = [10, 50, 100]\n",
  "\n",
  "for name in run_names:\n",
  "    if name not in runs:\n",
  "        continue\n",
  "    fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4), squeeze=False)\n",
  "    for ax, ep in zip(axes[0], check_epochs):\n",
  "        show_sample_grid(ax, name, ep, f\"Ep {ep}\", \"Pending\")\n",
  "    fig.suptitle(f\"{run_labels[name]} — Progression\", fontsize=11, fontweight=\"bold\")\n",
  "    plt.tight_layout()\n",
  "    plt.show()"
 ] },
 { "cell_type": "markdown", "id": "c0000017", "metadata": {}, "source": [
  "## 8. Noise Schedule Visualisation\n",
  "\n",
  "Illustrates why cosine outperforms linear: the linear schedule allocates many timesteps near t=T where the image is already near-pure noise, wasting model capacity."
] },
 { "cell_type": "code", "execution_count": null, "id": "c0000018", "metadata": {}, "outputs": [], "source": [
  "T = 1000\n",
  "t = np.arange(T)  # plot index i corresponds to diffusion step i + 1\n",
  "\n",
  "# Linear schedule: betas evenly spaced from 1e-4 to 0.02.\n",
  "betas_lin = np.linspace(1e-4, 0.02, T)\n",
  "ab_lin = np.cumprod(1.0 - betas_lin)\n",
  "\n",
  "# Cosine schedule (Nichol & Dhariwal, 2021): ᾱ(t) ∝ cos²(((t/T + s) / (1 + s)) · π/2).\n",
  "# It must be evaluated at the T + 1 points t = 0..T so the derived betas have exactly T\n",
  "# entries; evaluating at only T points yields T - 1 betas and a shape mismatch with `t`.\n",
  "s = 0.008\n",
  "cos_f = np.cos((np.arange(T + 1) / T + s) / (1 + s) * np.pi / 2) ** 2\n",
  "ab_cos_ideal = cos_f / cos_f[0]\n",
  "betas_cos = np.clip(1.0 - ab_cos_ideal[1:] / ab_cos_ideal[:-1], 0.0, 0.999)\n",
  "ab_cos = np.cumprod(1.0 - betas_cos)\n",
  "assert ab_lin.shape == ab_cos.shape == t.shape, \"each schedule needs one ᾱ per timestep\"\n",
  "\n",
  "fig, axes = plt.subplots(1, 2, figsize=(13, 4))\n",
  "\n",
  "axes[0].plot(t, ab_lin, label=\"linear\", color=\"#5B8DB8\", linewidth=2)\n",
  "axes[0].plot(t, ab_cos, label=\"cosine\", color=\"#E8705A\", linewidth=2)\n",
  "axes[0].set_xlabel(\"Timestep t\")\n",
  "axes[0].set_ylabel(\"ᾱ_t (signal fraction)\")\n",
  "axes[0].set_title(\"ᾱ_t vs t — cosine stays informative longer\")\n",
  "axes[0].legend()\n",
  "\n",
  "axes[1].plot(t, np.sqrt(ab_lin), label=\"linear √ᾱ\", color=\"#5B8DB8\", linewidth=2)\n",
  "axes[1].plot(t, np.sqrt(ab_cos), label=\"cosine √ᾱ\", color=\"#E8705A\", linewidth=2)\n",
  "axes[1].set_xlabel(\"Timestep t\")\n",
  "axes[1].set_ylabel(\"√ᾱ_t (signal amplitude)\")\n",
  "axes[1].set_title(\"Signal amplitude — cosine is more uniform\")\n",
  "axes[1].legend()\n",
  "\n",
  "fig.suptitle(\"Noise Schedule Comparison\", fontsize=12, fontweight=\"bold\")\n",
  "plt.tight_layout()\n",
  "plt.show()"
 ] },
 { "cell_type": "markdown", "id": "c0000019", "metadata": {}, "source": [ "## 9. 
Conclusions" ] },
 { "cell_type": "code", "execution_count": null, "id": "c0000020", "metadata": {}, "outputs": [], "source": [
  "def fmt_metric(value, spec):\n",
  "    \"\"\"Format a logged metric with `spec`, or return \"?\" if it was not recorded.\"\"\"\n",
  "    return \"?\" if value is None else format(value, spec)\n",
  "\n",
  "\n",
  "print(\"=\" * 70)\n",
  "print(\"PHASE 4 — DDPM EVOLUTION SUMMARY\")\n",
  "print(\"=\" * 70)\n",
  "\n",
  "final_fids = {}  # run name -> FID@100, for runs that have logged it\n",
  "for name in run_names:\n",
  "    if name not in runs:\n",
  "        print(f\"\\n  {run_labels[name]}: NOT YET RUN\")\n",
  "        continue\n",
  "    fid50 = get_fid(runs[name], 50)\n",
  "    fid100 = get_fid(runs[name], 100)\n",
  "    # Losses are logged once per epoch starting at epoch 1, so epoch 50 is index 49.\n",
  "    losses = runs[name].get(\"history\", {}).get(\"loss\", [])\n",
  "    loss50 = losses[49] if len(losses) > 49 else None\n",
  "    if fid100 is not None:\n",
  "        final_fids[name] = fid100\n",
  "    print(f\"\\n  {run_labels[name]}:\")\n",
  "    print(f\"    FID@50  = {fmt_metric(fid50, '.1f')}\")\n",
  "    print(f\"    FID@100 = {fmt_metric(fid100, '.1f')}\")\n",
  "    print(f\"    Loss@50 = {fmt_metric(loss50, '.5f')}\")\n",
  "\n",
  "print(\"\\n\" + \"=\" * 70)\n",
  "# Lower FID is better, so the best candidate is the run with the smallest FID@100.\n",
  "if final_fids:\n",
  "    best = min(final_fids, key=final_fids.get)\n",
  "    print(f\"Best DDPM model for Phase 5 comparison: {run_labels[best]} (FID@100={final_fids[best]:.1f})\")\n",
  "else:\n",
  "    print(\"Best DDPM model for Phase 5 comparison: pending (no run has logged FID@100 yet).\")\n",
  "print(\"=\" * 70)"
 ] }
 ],
 "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.0" } },
 "nbformat": 4,
 "nbformat_minor": 5
}