{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Phase 4 - DDPM Progression\n", "\n", "Phase 4 applies the same report structure as the earlier phases to diffusion\n", "models. The data pipeline (aligned 64x64 images) is held constant, so this\n", "notebook isolates the DDPM recipe: noise schedule, prediction target, and\n", "backbone width.\n", "\n", "The story is stepwise. The cosine schedule gives a small gain, v-prediction is\n", "the major gain, and the wider backbone adds a further improvement on top of the\n", "cosine schedule and v-prediction target.\n", "\n", "## What this phase changes\n", "\n", "Changes are cumulative: each run keeps every change from the run above it.\n", "\n", "| Run | Recipe change |\n", "|---|---|\n", "| `p4_1_ddpm_linear` | Baseline: linear noise schedule, epsilon-prediction |\n", "| `p4_2_ddpm_cosine` | Cosine noise schedule |\n", "| `p4_3_ddpm_vpred` | v-prediction target |\n", "| `p4_4_ddpm_wider` | Wider U-Net: 192 base channels, attention at the 32x32, 16x16, and 8x8 resolutions |\n", "\n", "Sample previews use 50-step DDIM sampling (DDIM-50). Logged FID uses 100-step\n", "DDIM sampling (DDIM-100) against the saved real reference set.\n", "\n", "**Headline result:** `p4_4_ddpm_wider` reaches **best FID = 30.0**.\n", "\n", "## How to read DDPM sample grids\n", "\n", "The DDPM grids should not be read as the same faces improving from epoch to\n", "epoch. GAN and VAE previews can reuse a fixed latent grid, so each position can\n", "look like the same latent code becoming sharper over training. A DDPM preview\n", "starts from Gaussian noise and runs the reverse-diffusion sampler (here DDIM-50).\n", "Unless the initial noise, and any per-step sampler noise, are fixed and stored,\n", "each epoch's preview is a fresh draw from the model.\n", "\n", "So for DDPM, the progression panels show distribution-level improvement:\n", "cleaner faces, fewer artifacts, and better global structure. They are not\n", "identity-by-identity refinements of the same preview images.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Reference: Phase 0 baseline from the same family\n", "\n", "`p0_ddpm` was a vanilla DDPM (linear schedule, epsilon-prediction, base_ch=128) on raw\n", "unaligned data. 
Outputs were noisy, face-shaped textures. Phase 4 uses the corrected pipeline\n", "(aligned 64x64 images) and walks through a standard set of post-2020 DDPM\n", "improvements one at a time.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "from pathlib import Path\n", "\n", "import matplotlib.pyplot as plt\n", "import matplotlib.image as mpimg\n", "import numpy as np\n", "import pandas as pd\n", "\n", "plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n", "\n", "# Fall back to print() when running outside IPython/Jupyter.\n", "try:\n", "    display\n", "except NameError:\n", "    def display(obj):\n", "        print(obj)\n", "\n", "def find_generator_root():\n", "    # Walk up from the CWD until a directory containing outputs/logs and outputs/samples is found.\n", "    for base in [Path.cwd(), *Path.cwd().parents]:\n", "        for candidate in [base, base / \"generator\"]:\n", "            if (candidate / \"outputs\" / \"logs\").exists() and (candidate / \"outputs\" / \"samples\").exists():\n", "                return candidate.resolve()\n", "    raise FileNotFoundError(\"Could not locate generator/outputs from the current working directory\")\n", "\n", "GENERATOR_ROOT = find_generator_root()\n", "PROJECT_ROOT = GENERATOR_ROOT.parent\n", "OUTPUTS = GENERATOR_ROOT / \"outputs\"\n", "LOGS = OUTPUTS / \"logs\"\n", "SAMPLES = OUTPUTS / \"samples\"\n", "\n", "\n", "def load_log(name):\n", "    # Parsed JSON log for a run, or None if the file does not exist.\n", "    p = LOGS / f\"{name}.json\"\n", "    if not p.exists():\n", "        return None\n", "    with open(p) as fh:\n", "        return json.load(fh)\n", "\n", "def get_fid(log, epoch):\n", "    # FID logged at `epoch`; history keys are stored as strings.\n", "    fid = log.get(\"history\", {}).get(\"fid\", {})\n", "    return fid.get(str(epoch))\n", "\n", "def fid_series(log):\n", "    # (epochs, FID values), sorted by epoch.\n", "    fid = log.get(\"history\", {}).get(\"fid\", {})\n", "    items = sorted((int(k), v) for k, v in fid.items())\n", "    return [e for e, _ in items], [v for _, v in items]\n", "\n", "def show_image_or_missing(ax, path, title=None):\n", "    # Draw a saved sample grid, or a placeholder if the artifact is missing.\n", "    if path.exists():\n", "        ax.imshow(mpimg.imread(str(path)))\n", "    else:\n", "        ax.text(0.5, 0.5, f\"missing artifact\\n{path.name}\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n", "    if title:\n", "        ax.set_title(title, fontsize=9)\n", "    ax.axis(\"off\")\n", "\n" ] 
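}, { "cell_type": "markdown", "metadata": {}, "source": [ "### Reference: epsilon and v prediction targets\n", "\n", "Runs 4.1 and 4.2 train the network to predict the added noise $\\epsilon$; runs 4.3 and 4.4 predict $v$. For reference, the relations below are the standard v-parameterization of Salimans & Ho (2022) for a variance-preserving forward process, with $\\bar\\alpha_t$ the cumulative signal fraction plotted in section 7. That this project's training code defines $v$ exactly this way is an assumption.\n", "\n", "$$x_t = \\sqrt{\\bar\\alpha_t}\\,x_0 + \\sqrt{1-\\bar\\alpha_t}\\,\\epsilon, \\qquad v = \\sqrt{\\bar\\alpha_t}\\,\\epsilon - \\sqrt{1-\\bar\\alpha_t}\\,x_0$$\n", "\n", "Because $\\bar\\alpha_t + (1 - \\bar\\alpha_t) = 1$, a predicted $\\hat v$ converts back to the other two targets:\n", "\n", "$$\\hat x_0 = \\sqrt{\\bar\\alpha_t}\\,x_t - \\sqrt{1-\\bar\\alpha_t}\\,\\hat v, \\qquad \\hat\\epsilon = \\sqrt{1-\\bar\\alpha_t}\\,x_t + \\sqrt{\\bar\\alpha_t}\\,\\hat v$$\n", "\n", "Substituting these into the errors shows how each loss weights a timestep:\n", "\n", "$$\\lVert \\hat\\epsilon - \\epsilon \\rVert^2 = \\bar\\alpha_t\\,\\lVert \\hat v - v \\rVert^2, \\qquad \\lVert \\hat x_0 - x_0 \\rVert^2 = (1-\\bar\\alpha_t)\\,\\lVert \\hat v - v \\rVert^2$$\n", "\n", "For the same epsilon error, v-MSE is larger at noisy timesteps (small $\\bar\\alpha_t$), which is one reason the epsilon-MSE and v-MSE curves in section 4 are not on a common scale." ]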
}, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Load experiment logs\n", "\n", "This notebook only reads existing DDPM logs. No sampling or FID computation is re-run; every value comes from the saved logs." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "run_names = [\"p4_1_ddpm_linear\", \"p4_2_ddpm_cosine\", \"p4_3_ddpm_vpred\", \"p4_4_ddpm_wider\"]\n", "run_labels = {\n", "    \"p4_1_ddpm_linear\": \"4.1 linear / epsilon\",\n", "    \"p4_2_ddpm_cosine\": \"4.2 cosine / epsilon\",\n", "    \"p4_3_ddpm_vpred\": \"4.3 cosine / v\",\n", "    \"p4_4_ddpm_wider\": \"4.4 wider net\",\n", "}\n", "# Keep only runs whose log file exists.\n", "runs = {n: load_log(n) for n in run_names}\n", "runs = {k: v for k, v in runs.items() if v}\n", "for n in run_names:\n", "    print(f\"  {n}: {'OK' if n in runs else 'MISSING'}\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. FID comparison table\n", "\n", "The table shows whether each recipe change improves generation quality under the saved DDIM-100 FID protocol. Lower FID is better. The final training loss is listed for reference only, because epsilon-MSE and v-MSE are not on a common scale." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "rows = []\n", "for name in run_names:\n", "    if name not in runs:\n", "        continue\n", "    r = runs[name]\n", "    h = r[\"history\"]\n", "    _, fid_vals = fid_series(r)\n", "    train_s = h.get(\"train_time_s\")\n", "    rows.append({\n", "        \"Run\": run_labels[name],\n", "        \"FID@25\": get_fid(r, 25),\n", "        \"FID@50\": get_fid(r, 50),\n", "        \"FID@100\": get_fid(r, 100),\n", "        \"Best FID\": min(fid_vals) if fid_vals else None,\n", "        # Loss at the last logged epoch (epsilon-MSE or v-MSE, depending on the run).\n", "        \"Final loss\": h[\"loss\"][-1] if h.get(\"loss\") else None,\n", "        \"Train (min)\": train_s / 60 if train_s is not None else None,\n", "    })\n", "df = pd.DataFrame(rows).sort_values(\"Best FID\")\n", "df.style.format({\"FID@25\": \"{:.1f}\", \"FID@50\": \"{:.1f}\", \"FID@100\": \"{:.1f}\",\n", "                 \"Best FID\": \"{:.1f}\", \"Final loss\": \"{:.4f}\", \"Train (min)\": \"{:.1f}\"},\n", "                na_rep=\"n/a\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. 
FID curves - progression" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots(figsize=(10, 5))\n", "cmap = plt.cm.cividis\n", "for i, name in enumerate(run_names):\n", "    if name not in runs:\n", "        continue\n", "    epochs, fid_vals = fid_series(runs[name])\n", "    ax.plot(epochs, fid_vals, \"o-\", label=run_labels[name],\n", "            color=cmap(i / len(run_names)), linewidth=2, markersize=7)\n", "ax.set_xlabel(\"Epoch\"); ax.set_ylabel(\"FID (DDIM-100, lower is better)\")\n", "ax.set_title(\"Phase 4 - FID curves\"); ax.legend()\n", "plt.tight_layout(); plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Training loss\n", "\n", "The loss curves help spot divergence or plateaus within a run, but runs 4.1 and 4.2 minimize epsilon-MSE while 4.3 and 4.4 minimize v-MSE. Absolute loss values are therefore not comparable across that split, so FID and the sample grids carry the decision." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots(figsize=(10, 4))\n", "cmap = plt.cm.cividis\n", "for i, name in enumerate(run_names):\n", "    if name not in runs:\n", "        continue\n", "    h = runs[name][\"history\"]\n", "    # Assumes one loss value per epoch, numbered from 1.\n", "    epochs = range(1, len(h[\"loss\"]) + 1)\n", "    ax.plot(epochs, h[\"loss\"], color=cmap(i / len(run_names)), label=run_labels[name], linewidth=1.3)\n", "ax.set_xlabel(\"Epoch\"); ax.set_ylabel(\"MSE on prediction target\")\n", "ax.set_title(\"Loss (epsilon-MSE and v-MSE are not directly comparable)\")\n", "ax.legend(); plt.tight_layout(); plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. 
Sample grids - epoch 100" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, axes = plt.subplots(1, 4, figsize=(16, 4.5))\n", "for ax, name in zip(axes, run_names):\n", "    f = get_fid(runs[name], 100) if name in runs else None\n", "    title = f\"{run_labels[name]}\\nFID@100={f:.1f}\" if f is not None else run_labels[name]\n", "    show_image_or_missing(ax, SAMPLES / name / \"epoch_0100.png\", title)\n", "plt.tight_layout(); plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Progression - epoch 10 -> 50 -> 100\n", "\n", "Read these as fresh samples from each checkpoint, not the same DDPM images being refined over time." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "check_epochs = [10, 50, 100]\n", "for name in run_names:\n", "    if name not in runs:\n", "        continue\n", "    fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))\n", "    for ax, e in zip(axes, check_epochs):\n", "        f = get_fid(runs[name], e)\n", "        title = f\"epoch {e}\" + (f\"\\nFID={f:.1f}\" if f is not None else \"\")\n", "        show_image_or_missing(ax, SAMPLES / name / f\"epoch_{e:04d}.png\", title)\n", "    fig.suptitle(run_labels[name], fontsize=11, fontweight=\"bold\")\n", "    plt.tight_layout(); plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. Noise schedule visualization\n", "\n", "The cosine schedule removes signal more gradually than the linear schedule. Under the linear schedule, alpha_bar_t falls close to zero well before the final timestep, so many late timesteps are nearly pure noise; the cosine schedule spreads the decay more evenly across t. On its own this change moved best FID only from 134.5 to 132.3, but it is the schedule that v-prediction and the wider network were then trained on." 
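, "\n", "\n", "For reference, the cosine curve in the next cell is the Nichol & Dhariwal (2021) schedule, with offset $s = 0.008$ and $T = 1000$ as in the plotting code; that the training runs used exactly these constants, including the $0.999$ clip on $\\beta_t$, is an assumption:\n", "\n", "$$f(t) = \\cos^2\\left(\\frac{t/T + s}{1 + s}\\cdot\\frac{\\pi}{2}\\right), \\qquad \\bar\\alpha_t = \\frac{f(t)}{f(0)}, \\qquad \\beta_t = \\min\\left(1 - \\frac{\\bar\\alpha_t}{\\bar\\alpha_{t-1}},\\ 0.999\\right)$$\n", "\n", "The linear schedule instead spaces $\\beta_t$ evenly from $10^{-4}$ to $0.02$, with $\\bar\\alpha_t = \\prod_{i=1}^{t}(1 - \\beta_i)$. At $t = T/2$ that leaves $\\bar\\alpha_t \\approx 0.08$ for the linear schedule versus $\\approx 0.49$ for the cosine schedule."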
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import math\n", "\n", "T = 1000\n", "t = np.arange(T)\n", "\n", "# Linear schedule: beta_t evenly spaced from 1e-4 to 0.02.\n", "betas_lin = np.linspace(1e-4, 0.02, T)\n", "ab_lin = np.cumprod(1 - betas_lin)\n", "\n", "# Cosine schedule: alpha_bar(t) = f(t) / f(0), evaluated at t = 0..T so that T betas result.\n", "s = 0.008\n", "steps = np.arange(T + 1)\n", "f = np.cos((steps / T + s) / (1 + s) * math.pi / 2) ** 2\n", "alpha_bar = f / f[0]\n", "betas_cos = np.clip(1 - alpha_bar[1:] / alpha_bar[:-1], 0, 0.999)\n", "ab_cos = np.cumprod(1 - betas_cos)\n", "\n", "fig, axes = plt.subplots(1, 2, figsize=(13, 4))\n", "axes[0].plot(t, ab_lin, label=\"linear\", color=\"#5B8DB8\", linewidth=2)\n", "axes[0].plot(t, ab_cos, label=\"cosine\", color=\"#E8705A\", linewidth=2)\n", "axes[0].set_xlabel(\"Timestep t\"); axes[0].set_ylabel(\"alpha_bar_t (signal variance fraction)\")\n", "axes[0].set_title(\"Cumulative signal preservation\"); axes[0].legend()\n", "axes[1].plot(t, betas_lin, label=\"linear beta\", color=\"#5B8DB8\", linewidth=2)\n", "axes[1].plot(t, betas_cos, label=\"cosine beta\", color=\"#E8705A\", linewidth=2)\n", "axes[1].set_xlabel(\"Timestep t\"); axes[1].set_ylabel(\"beta_t\"); axes[1].set_title(\"beta schedule\"); axes[1].legend()\n", "plt.tight_layout(); plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 8. What this phase shows\n", "\n", "| Step | Run | Best FID | Delta vs previous |\n", "|---|---|---:|---:|\n", "| 4.1 linear / epsilon | `p4_1_ddpm_linear` | 134.5 | n/a |\n", "| 4.2 cosine / epsilon | `p4_2_ddpm_cosine` | 132.3 | -2.2 |\n", "| 4.3 cosine / v | `p4_3_ddpm_vpred` | 34.5 | -97.8 |\n", "| 4.4 wider net | `p4_4_ddpm_wider` | 30.0 | -4.5 |\n", "\n", "The largest improvement is v-prediction. 
The wider network then adds a further 4.5 FID improvement. Width was only\n", "tested on top of the cosine schedule and v-prediction, so this phase does not\n", "show whether the width gain depends on those earlier changes.\n", "\n", "**Decision:** select the DDPM recipe with cosine schedule, v-prediction,\n", "base_ch=192, and attention at 32/16/8 for the final comparison.\n", "\n", "**Report conclusion:** Phase 4 turns DDPM from the textured but noisy baseline\n", "into the strongest quality candidate for Phase 5.\n" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.12.10" } }, "nbformat": 4, "nbformat_minor": 5 }