{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Phase 4 - DDPM Progression\n",
"\n",
"Phase 4 applies the same report structure as the earlier phases to diffusion models. The data\n",
"pipeline is already fixed and is held constant across runs, so this notebook isolates the DDPM\n",
"recipe: noise schedule, prediction target, and backbone width.\n",
"\n",
"The story is stepwise. The cosine schedule gives a small gain, v-prediction is the major gain,\n",
"and the wider backbone adds a further improvement on top of the cosine / v-prediction recipe.\n",
"\n",
"## What this phase changes\n",
"\n",
"Each run keeps the previous run's changes and adds one more.\n",
"\n",
"| Run | Recipe change (cumulative) |\n",
"|---|---|\n",
"| `p4_1_ddpm_linear` | Linear noise schedule, epsilon-prediction |\n",
"| `p4_2_ddpm_cosine` | Cosine noise schedule |\n",
"| `p4_3_ddpm_vpred` | v-prediction target |\n",
"| `p4_4_ddpm_wider` | Wider U-Net: base channels 192, attention at 32/16/8 |\n",
"\n",
"Sample previews use 50 DDIM steps (DDIM-50). Logged FID uses 100 DDIM steps (DDIM-100) against\n",
"the saved set of real reference images.\n",
"\n",
"**Headline result:** `p4_4_ddpm_wider` reaches **best FID = 30.0**.\n",
"\n",
"## How to read DDPM sample grids\n",
"\n",
"The DDPM grids should not be read as the same faces improving from epoch to\n",
"epoch. GAN and VAE previews can reuse a fixed latent grid, so each position can\n",
"look like the same latent code becoming sharper over training. A DDPM preview\n",
"starts from fresh Gaussian noise and runs the reverse-diffusion sampler. Unless the\n",
"initial noise (and any per-step sampler noise) is fixed and stored, each epoch's\n",
"preview is a new draw from the model.\n",
"\n",
"So for DDPM, the progression panels show distribution-level improvement:\n",
"cleaner faces, fewer artifacts, and better global structure. They do not show\n",
"the same identities being refined from one preview to the next.\n"
]
},
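{
"cell_type": "markdown",
"metadata": {},
"source": [
"The fresh-draw point above can be checked with a small, self-contained sketch. It is not this project's\n",
"sampler: `toy_v_model` is a hypothetical stand-in for the trained U-Net, and `ddim_sample` is an\n",
"illustrative deterministic DDIM loop (eta = 0) for a v-prediction model on a cosine schedule. With\n",
"eta = 0 the only randomness is the initial noise $x_T$, so reusing one stored, seeded $x_T$ makes\n",
"previews from the same weights identical, while fresh noise gives a different draw. Whether the\n",
"project's preview sampler runs with eta = 0 is an assumption here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def cosine_alpha_bar(T, s=0.008):\n",
"    # alpha_bar at t = 0..T for the cosine schedule, normalised so alpha_bar(0) = 1\n",
"    steps = np.arange(T + 1)\n",
"    f = np.cos((steps / T + s) / (1 + s) * np.pi / 2) ** 2\n",
"    return f / f[0]\n",
"\n",
"def toy_v_model(x_t, t):\n",
"    # Hypothetical stand-in for the trained U-Net (any deterministic function of x_t and t)\n",
"    return 0.1 * np.tanh(x_t) * (t / 1000.0)\n",
"\n",
"def ddim_sample(model, x_T, alpha_bar, num_steps=50):\n",
"    # Deterministic DDIM (eta = 0) for a v-prediction model; illustrative only\n",
"    T = len(alpha_bar) - 1\n",
"    ts = np.linspace(T, 0, num_steps + 1).round().astype(int)\n",
"    x = x_T\n",
"    for t, t_prev in zip(ts[:-1], ts[1:]):\n",
"        a, s = np.sqrt(alpha_bar[t]), np.sqrt(1 - alpha_bar[t])\n",
"        v = model(x, t)\n",
"        x0_hat = a * x - s * v    # predicted clean image\n",
"        eps_hat = s * x + a * v   # predicted noise\n",
"        a_prev, s_prev = np.sqrt(alpha_bar[t_prev]), np.sqrt(1 - alpha_bar[t_prev])\n",
"        x = a_prev * x0_hat + s_prev * eps_hat\n",
"    return x\n",
"\n",
"alpha_bar = cosine_alpha_bar(1000)\n",
"shape = (4, 3, 8, 8)\n",
"fixed_noise = np.random.default_rng(0).standard_normal(shape)  # stored once, reused for every preview\n",
"a = ddim_sample(toy_v_model, fixed_noise, alpha_bar)\n",
"b = ddim_sample(toy_v_model, fixed_noise, alpha_bar)\n",
"c = ddim_sample(toy_v_model, np.random.default_rng(1).standard_normal(shape), alpha_bar)\n",
"print('same noise -> identical:', np.array_equal(a, b))\n",
"print('fresh noise -> identical:', np.array_equal(a, c))"
]
},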
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Reference: Phase 0 baseline from the same family\n",
"\n",
"`p0_ddpm` was a vanilla DDPM (linear schedule, epsilon-prediction, base_ch=128) trained on raw,\n",
"unaligned data. Its outputs were noisy, face-shaped textures. Phase 4 uses the fixed pipeline\n",
"(aligned, 64x64) and adds common post-2020 DDPM improvements one at a time.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n",
"\n",
"# Fallback so the notebook also runs as a plain script outside IPython\n",
"try:\n",
"    display\n",
"except NameError:\n",
"    def display(obj):\n",
"        print(obj)\n",
"\n",
"def find_generator_root():\n",
"    # Walk up from the working directory until generator/outputs/{logs,samples} is found\n",
"    for base in [Path.cwd(), *Path.cwd().parents]:\n",
"        for candidate in [base, base / \"generator\"]:\n",
"            if (candidate / \"outputs\" / \"logs\").exists() and (candidate / \"outputs\" / \"samples\").exists():\n",
"                return candidate.resolve()\n",
"    raise FileNotFoundError(\"Could not locate generator/outputs from the current working directory\")\n",
"\n",
"GENERATOR_ROOT = find_generator_root()\n",
"PROJECT_ROOT = GENERATOR_ROOT.parent\n",
"OUTPUTS = GENERATOR_ROOT / \"outputs\"\n",
"LOGS = OUTPUTS / \"logs\"\n",
"SAMPLES = OUTPUTS / \"samples\"\n",
"\n",
"\n",
"def load_log(name):\n",
"    # Parsed JSON log, or None if the run has no log file\n",
"    p = LOGS / f\"{name}.json\"\n",
"    return json.loads(p.read_text()) if p.exists() else None\n",
"\n",
"def get_fid(log, epoch):\n",
"    # history.fid is keyed by epoch number stored as a string (JSON object keys)\n",
"    fid = log.get(\"history\", {}).get(\"fid\", {})\n",
"    return fid.get(str(epoch))\n",
"\n",
"def fid_series(log):\n",
"    # Sorted (epochs, FID values); keys are cast to int so 100 sorts after 25\n",
"    fid = log.get(\"history\", {}).get(\"fid\", {})\n",
"    items = sorted((int(k), v) for k, v in fid.items())\n",
"    return [e for e, _ in items], [v for _, v in items]\n",
"\n",
"def show_image_or_missing(ax, path, title=None):\n",
"    # Draw a saved image, or a placeholder if the artifact does not exist\n",
"    if path.exists():\n",
"        ax.imshow(mpimg.imread(str(path)))\n",
"    else:\n",
"        ax.text(0.5, 0.5, f\"missing artifact\\n{path.name}\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
"    if title:\n",
"        ax.set_title(title, fontsize=9)\n",
"    ax.axis(\"off\")\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Load experiment logs\n",
"\n",
"This notebook only reads the saved DDPM logs and sample grids; it does not run the sampler or recompute FID."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"run_names = [\"p4_1_ddpm_linear\", \"p4_2_ddpm_cosine\", \"p4_3_ddpm_vpred\", \"p4_4_ddpm_wider\"]\n",
"run_labels = {\n",
"    \"p4_1_ddpm_linear\": \"4.1 linear / epsilon\",\n",
"    \"p4_2_ddpm_cosine\": \"4.2 cosine / epsilon\",\n",
"    \"p4_3_ddpm_vpred\": \"4.3 cosine / v\",\n",
"    \"p4_4_ddpm_wider\": \"4.4 wider net\",\n",
"}\n",
"runs = {n: load_log(n) for n in run_names}\n",
"runs = {k: v for k, v in runs.items() if v}  # drop runs without a log\n",
"for n in run_names:\n",
"    print(f\"  {n}: {'OK' if n in runs else 'MISSING'}\")\n"
]
},
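{
"cell_type": "markdown",
"metadata": {},
"source": [
"The helpers above assume a specific log layout: `history.fid` maps epoch numbers, stored as JSON string\n",
"keys, to FID values, and `history.loss` is a per-epoch list. The toy log below uses invented values (it is\n",
"not one of the Phase 4 runs) and repeats `fid_series` so the cell runs on its own. It shows why the helper\n",
"casts keys with `int()` before sorting: as strings, `'100'` sorts before `'25'`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"def fid_series(log):\n",
"    # Same logic as the helper defined at the top of this notebook\n",
"    fid = log.get('history', {}).get('fid', {})\n",
"    items = sorted((int(k), v) for k, v in fid.items())\n",
"    return [e for e, _ in items], [v for _, v in items]\n",
"\n",
"# Toy log with invented values; the JSON round-trip turns int keys into strings, as on disk\n",
"toy_log = json.loads(json.dumps({'history': {'fid': {100: 40.0, 25: 90.0, 50: 60.0},\n",
"                                             'loss': [0.3, 0.2, 0.1]}}))\n",
"\n",
"print(sorted(toy_log['history']['fid']))  # string order: ['100', '25', '50']\n",
"epochs, values = fid_series(toy_log)\n",
"print(epochs, values)  # [25, 50, 100] [90.0, 60.0, 40.0]"
]
},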
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. FID comparison table\n",
"\n",
"The table shows whether each recipe change improves generation quality under the saved DDIM-100 FID protocol (lower FID is better). Epoch-specific columns report FID at that training epoch, not the number of sampling steps."
]
},
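{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, FID is the Fréchet distance between Gaussians fitted to Inception-v3 features of real and\n",
"generated images: $\\lVert \\mu_r - \\mu_g \\rVert_2^2 + \\operatorname{Tr}\\big(\\Sigma_r + \\Sigma_g - 2(\\Sigma_r \\Sigma_g)^{1/2}\\big)$.\n",
"The NumPy sketch below computes that distance from two feature matrices. It is not the project's FID\n",
"implementation: random arrays stand in for Inception features, and the trace of the matrix square root is\n",
"taken from the eigenvalues of $\\Sigma_r \\Sigma_g$ (real and non-negative for covariance matrices) instead of\n",
"calling `scipy.linalg.sqrtm`. Logged values also depend on the feature extractor, the sample count, and the\n",
"sampler (DDIM-100 here)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def frechet_distance(feats_a, feats_b):\n",
"    # feats_*: (num_images, feature_dim) arrays, e.g. Inception pool features\n",
"    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)\n",
"    cov_a = np.cov(feats_a, rowvar=False)\n",
"    cov_b = np.cov(feats_b, rowvar=False)\n",
"    # Tr((cov_a cov_b)^(1/2)) = sum of square roots of the eigenvalues of cov_a @ cov_b\n",
"    eigvals = np.linalg.eigvals(cov_a @ cov_b).real\n",
"    tr_covmean = np.sqrt(np.clip(eigvals, 0, None)).sum()\n",
"    diff = mu_a - mu_b\n",
"    return float(diff @ diff + np.trace(cov_a) + np.trace(cov_b) - 2 * tr_covmean)\n",
"\n",
"rng = np.random.default_rng(0)\n",
"real = rng.standard_normal((2000, 16))         # stand-in for real-image features\n",
"fake_close = rng.standard_normal((2000, 16))   # same distribution as real\n",
"fake_shifted = fake_close + 1.0                # mean shifted by 1 in every dimension\n",
"d_close = frechet_distance(real, fake_close)\n",
"d_shifted = frechet_distance(real, fake_shifted)\n",
"print(f'same distribution: {d_close:.3f}, shifted: {d_shifted:.3f}')"
]
},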
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rows = []\n",
"for name in run_names:\n",
"    if name not in runs:\n",
"        continue\n",
"    r = runs[name]\n",
"    _, fid_vals = fid_series(r)\n",
"    train_s = r[\"history\"].get(\"train_time_s\")\n",
"    rows.append({\n",
"        \"Run\": run_labels[name],\n",
"        \"FID@ep25\": get_fid(r, 25),\n",
"        \"FID@ep50\": get_fid(r, 50),\n",
"        \"FID@ep100\": get_fid(r, 100),\n",
"        \"Best FID\": min(fid_vals) if fid_vals else None,\n",
"        \"Final loss\": r[\"history\"][\"loss\"][-1],\n",
"        \"Train (min)\": train_s / 60 if train_s else None,  # None instead of a misleading 0\n",
"    })\n",
"df = pd.DataFrame(rows).sort_values(\"Best FID\")\n",
"df.style.format({\"FID@ep25\": \"{:.1f}\", \"FID@ep50\": \"{:.1f}\", \"FID@ep100\": \"{:.1f}\",\n",
"                 \"Best FID\": \"{:.1f}\", \"Final loss\": \"{:.4f}\", \"Train (min)\": \"{:.1f}\"},\n",
"                na_rep=\"n/a\")\n"
]
},
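{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `Delta vs previous` column in the summary table at the end of this notebook can be recomputed instead of\n",
"typed by hand. `best_fid_deltas` below is a hypothetical helper, not part of the project code. The example\n",
"feeds it the best-FID values reported in that table; the commented line shows how the loaded logs could be\n",
"used instead."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def best_fid_deltas(best_fids):\n",
"    # best_fids: (label, best FID) pairs in recipe order -> (label, FID, change vs previous step)\n",
"    rows, prev = [], None\n",
"    for label, fid in best_fids:\n",
"        delta = None if prev is None else round(fid - prev, 1)\n",
"        rows.append((label, fid, delta))\n",
"        prev = fid\n",
"    return rows\n",
"\n",
"# Best-FID values as reported in the Phase 4 summary table\n",
"reported = [('4.1 linear / epsilon', 134.5), ('4.2 cosine / epsilon', 132.3),\n",
"            ('4.3 cosine / v', 34.5), ('4.4 wider net', 30.0)]\n",
"# From the loaded logs instead:\n",
"# reported = [(run_labels[n], min(fid_series(runs[n])[1])) for n in run_names if n in runs]\n",
"for row in best_fid_deltas(reported):\n",
"    print(row)"
]
},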
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. FID curves - progression"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(10, 5))\n",
"cmap = plt.cm.cividis\n",
"for i, name in enumerate(run_names):\n",
"    if name not in runs:\n",
"        continue\n",
"    epochs, fid_vals = fid_series(runs[name])\n",
"    ax.plot(epochs, fid_vals, \"o-\", label=run_labels[name],\n",
"            color=cmap(i / len(run_names)), linewidth=2, markersize=7)\n",
"ax.set_xlabel(\"Epoch\"); ax.set_ylabel(\"FID (DDIM-100)\")\n",
"ax.set_title(\"Phase 4 - FID curves\"); ax.legend()\n",
"plt.tight_layout(); plt.show()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Training loss\n",
"\n",
"The loss curves are diagnostic only. Runs 4.1 and 4.2 minimize epsilon-MSE, while 4.3 and 4.4 minimize v-MSE, so absolute loss values are not comparable across the two groups. FID and the sample grids drive the decision."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(10, 4))\n",
"cmap = plt.cm.cividis\n",
"for i, name in enumerate(run_names):\n",
"    if name not in runs:\n",
"        continue\n",
"    h = runs[name][\"history\"]\n",
"    epochs = range(1, len(h[\"loss\"]) + 1)\n",
"    ax.plot(epochs, h[\"loss\"], color=cmap(i / len(run_names)), label=run_labels[name], linewidth=1.3)\n",
"ax.set_xlabel(\"Epoch\"); ax.set_ylabel(\"MSE on prediction target\")\n",
"ax.set_title(\"Loss (epsilon-MSE and v-MSE are not directly comparable)\")\n",
"ax.legend(); plt.tight_layout(); plt.show()\n"
]
},
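{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why the two loss groups are not comparable follows from the standard v-parameterization, assumed here to\n",
"match the project's. With $x_t = \\sqrt{\\bar\\alpha_t}\\,x_0 + \\sqrt{1-\\bar\\alpha_t}\\,\\epsilon$ and\n",
"$v = \\sqrt{\\bar\\alpha_t}\\,\\epsilon - \\sqrt{1-\\bar\\alpha_t}\\,x_0$, a v-prediction implies\n",
"$\\hat\\epsilon = \\sqrt{1-\\bar\\alpha_t}\\,x_t + \\sqrt{\\bar\\alpha_t}\\,\\hat v$, so\n",
"$\\hat\\epsilon - \\epsilon = \\sqrt{\\bar\\alpha_t}\\,(\\hat v - v)$. Per timestep, v-MSE equals epsilon-MSE divided by\n",
"$\\bar\\alpha_t$, which up-weights noisy timesteps where $\\bar\\alpha_t$ is small. The check below verifies this\n",
"algebra on random arrays only; it says nothing about the logged runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x0 = rng.standard_normal((256, 3, 8, 8))   # stand-in for clean images\n",
"eps = rng.standard_normal(x0.shape)        # true noise\n",
"alpha_bar = 0.3                            # any value in (0, 1)\n",
"a, s = np.sqrt(alpha_bar), np.sqrt(1 - alpha_bar)\n",
"\n",
"x_t = a * x0 + s * eps   # forward diffusion\n",
"v = a * eps - s * x0     # v-prediction target\n",
"\n",
"v_hat = v + 0.05 * rng.standard_normal(v.shape)   # an imperfect v-prediction\n",
"eps_hat = s * x_t + a * v_hat                     # epsilon implied by v_hat\n",
"\n",
"v_mse = np.mean((v_hat - v) ** 2)\n",
"eps_mse = np.mean((eps_hat - eps) ** 2)\n",
"print(f'eps_mse / v_mse = {eps_mse / v_mse:.4f}, alpha_bar = {alpha_bar}')"
]
},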
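{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Sample grids - epoch 100\n",
"\n",
"Each panel is the saved preview grid (DDIM-50) from the epoch-100 checkpoint. The title gives the DDIM-100 FID logged at that epoch."
]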
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(1, len(run_names), figsize=(16, 4.5))\n",
"for ax, name in zip(axes, run_names):\n",
"    f = get_fid(runs[name], 100) if name in runs else None\n",
"    title = run_labels[name] + (f\"\\nFID (epoch 100) = {f:.1f}\" if f is not None else \"\")\n",
"    show_image_or_missing(ax, SAMPLES / name / \"epoch_0100.png\", title)\n",
"plt.tight_layout(); plt.show()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Progression - epoch 10 -> 50 -> 100\n",
"\n",
"Read these as fresh samples from each checkpoint, not the same DDPM images being refined over time."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"check_epochs = [10, 50, 100]\n",
"for name in run_names:\n",
"    if name not in runs:\n",
"        continue\n",
"    fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))\n",
"    for ax, e in zip(axes, check_epochs):\n",
"        f = get_fid(runs[name], e)\n",
"        title = f\"epoch {e}\" + (f\"\\nFID={f:.1f}\" if f is not None else \"\")\n",
"        show_image_or_missing(ax, SAMPLES / name / f\"epoch_{e:04d}.png\", title)\n",
"    fig.suptitle(run_labels[name], fontsize=11, fontweight=\"bold\")\n",
"    plt.tight_layout(); plt.show()\n"
]
},
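{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Noise schedule visualization\n",
"\n",
"The cosine schedule removes signal more gradually than the linear schedule. Under the linear schedule, alpha_bar_t (the fraction of signal variance left at step t) approaches zero well before t = T, so many late timesteps are nearly pure noise; under the cosine schedule it decays smoothly across the whole trajectory. That gives the model a better learning problem before v-prediction and width are added."
]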
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"T = 1000\n",
"t = np.arange(T)\n",
"\n",
"# Linear schedule: beta rises linearly from 1e-4 to 0.02\n",
"betas_lin = np.linspace(1e-4, 0.02, T)\n",
"ab_lin = np.cumprod(1 - betas_lin)\n",
"\n",
"# Cosine schedule: alpha_bar(t) = f(t) / f(0), evaluated at T + 1 points so there are T betas\n",
"s = 0.008\n",
"steps = np.arange(T + 1)\n",
"f_cos = np.cos((steps / T + s) / (1 + s) * math.pi / 2) ** 2\n",
"ab_cos_ref = f_cos / f_cos[0]\n",
"betas_cos = np.clip(1 - ab_cos_ref[1:] / ab_cos_ref[:-1], 0, 0.999)\n",
"ab_cos = np.cumprod(1 - betas_cos)  # length T, aligned with t\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(13, 4))\n",
"axes[0].plot(t, ab_lin, label=\"linear\", color=\"#5B8DB8\", linewidth=2)\n",
"axes[0].plot(t, ab_cos, label=\"cosine\", color=\"#E8705A\", linewidth=2)\n",
"axes[0].set_xlabel(\"Timestep t\"); axes[0].set_ylabel(\"alpha_bar_t (signal variance fraction)\")\n",
"axes[0].set_title(\"Cumulative signal preservation\"); axes[0].legend()\n",
"axes[1].plot(t, betas_lin, label=\"linear beta\", color=\"#5B8DB8\", linewidth=2)\n",
"axes[1].plot(t, betas_cos, label=\"cosine beta\", color=\"#E8705A\", linewidth=2)\n",
"axes[1].set_xlabel(\"Timestep t\"); axes[1].set_ylabel(\"beta_t\"); axes[1].set_title(\"beta schedule\"); axes[1].legend()\n",
"plt.tight_layout(); plt.show()\n"
]
},
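{
"cell_type": "markdown",
"metadata": {},
"source": [
"A numeric companion to the plots: the share of the T = 1000 timesteps at which less than 1% of the signal\n",
"variance remains ($\\bar\\alpha_t < 0.01$). The schedules are recomputed so the cell runs on its own, and the\n",
"0.01 threshold is an arbitrary choice for illustration. The linear schedule should come out with the larger\n",
"share, consistent with the left-hand plot."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"T, s = 1000, 0.008\n",
"ab_lin = np.cumprod(1 - np.linspace(1e-4, 0.02, T))\n",
"\n",
"steps = np.arange(T + 1)\n",
"f_cos = np.cos((steps / T + s) / (1 + s) * np.pi / 2) ** 2\n",
"ab_ref = f_cos / f_cos[0]\n",
"ab_cos = np.cumprod(1 - np.clip(1 - ab_ref[1:] / ab_ref[:-1], 0, 0.999))\n",
"\n",
"threshold = 0.01  # near-pure noise: under 1% of signal variance left\n",
"frac_lin = np.mean(ab_lin < threshold)\n",
"frac_cos = np.mean(ab_cos < threshold)\n",
"print(f'share of steps with alpha_bar < {threshold}: linear {frac_lin:.1%}, cosine {frac_cos:.1%}')"
]
},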
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. What this phase shows\n",
"\n",
"| Step | Run | Best FID | Delta vs previous |\n",
"|---|---|---:|---:|\n",
"| 4.1 linear / epsilon | `p4_1_ddpm_linear` | 134.5 | n/a |\n",
"| 4.2 cosine / epsilon | `p4_2_ddpm_cosine` | 132.3 | -2.2 |\n",
"| 4.3 cosine / v | `p4_3_ddpm_vpred` | 34.5 | -97.8 |\n",
"| 4.4 wider net | `p4_4_ddpm_wider` | 30.0 | -4.5 |\n",
"\n",
"The largest improvement comes from v-prediction (-97.8 FID). The wider network then adds a\n",
"further -4.5 FID, plausibly because the schedule and prediction target have already made the\n",
"training objective better aligned with sample quality; width was not tested on the earlier recipes.\n",
"\n",
"**Decision:** select the DDPM recipe with cosine schedule, v-prediction,\n",
"base_ch=192, and attention at 32/16/8 for the final comparison.\n",
"\n",
"**Report conclusion:** Phase 4 turns DDPM from the noisy, texture-like Phase 0 baseline\n",
"into the strongest quality candidate for Phase 5.\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}