404 lines
13 KiB
Plaintext
404 lines
13 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "c0000001",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Phase 4 — DDPM Evolution Analysis\n",
|
||
"\n",
|
||
"Traces the DDPM improvement story:\n",
|
||
"\n",
|
||
"| Step | Model | Key change | Expected failure |\n",
|
||
"|------|-------|------------|------------------|\n",
|
||
"| 4.1 | DDPM linear + ε-pred | Baseline | Noise prediction unstable at very low t (linear schedule over-denoises) |\n",
|
"| 4.2 | + cosine schedule | Fewer timesteps wasted near t=T, where the linear schedule is already near-pure noise | Residual instability from ε parameterisation |\n",
|
||
"| 4.3 | + v-prediction | Numerically stable across full trajectory | Possible underfitting at 64×64 |\n",
|
||
"| 4.4 | + wider U-Net (192ch) + 32×32 attention | More capacity and longer-range context | — |\n",
|
||
"\n",
|
||
"FID is computed via DDIM (100 steps, deterministic) from the EMA model.\n",
|
"Training setup: horizontal-flip-only augmentation, MTCNN-aligned 64×64 crops, T = 1000 diffusion timesteps."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "c0000002",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Standard library\n",
"import json\n",
"from pathlib import Path\n",
"\n",
"# Third-party\n",
"import matplotlib.image as mpimg\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Figure defaults shared by every plot in this notebook\n",
"plt.rcParams[\"figure.dpi\"] = 120\n",
"plt.rcParams[\"font.size\"] = 10\n",
"\n",
"# Experiment artefacts: per-run JSON logs and per-run sample-grid PNGs\n",
"OUTPUTS = Path(\"../outputs\")\n",
"LOGS, SAMPLES = OUTPUTS / \"logs\", OUTPUTS / \"samples\""
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "c0000003",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 1. Load experiment logs"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "c0000004",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Display label for each Phase 4 run, listed in the order of the evolution story.\n",
"run_labels = {\n",
"    \"p4_1_ddpm_linear\": \"4.1 linear + ε\",\n",
"    \"p4_2_ddpm_cosine\": \"4.2 cosine + ε\",\n",
"    \"p4_3_ddpm_vpred\": \"4.3 cosine + v\",\n",
"    \"p4_4_ddpm_wider\": \"4.4 wider + 32×32 attn\",\n",
"}\n",
"run_names = list(run_labels)\n",
"\n",
"# Load each run's JSON log; runs without a log file are reported and left out of `runs`.\n",
"runs = {}\n",
"for run_name in run_names:\n",
"    log_file = LOGS / f\"{run_name}.json\"\n",
"    if not log_file.exists():\n",
"        print(f\" Missing: {log_file}\")\n",
"        continue\n",
"    with open(log_file) as fh:\n",
"        runs[run_name] = json.load(fh)\n",
"\n",
"print(f\"Loaded {len(runs)}/{len(run_names)} experiments:\")\n",
"for run_name in run_names:\n",
"    status = \"✓\" if run_name in runs else \"✗\"\n",
"    print(f\" {status} {run_name}\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "c0000005",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 2. FID Comparison Table"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "c0000006",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"def get_fid(run, epoch):\n",
"    \"\"\"Return the FID logged for `epoch` in a run log, or None if it was not logged.\n",
"\n",
"    Logs read with json.load always have string keys, so the str(epoch) lookup is\n",
"    the normal path; the int lookup is a fallback for dicts built with int keys.\n",
"    A log whose history has no \"fid\" entry is treated as having no FIDs at all.\n",
"    \"\"\"\n",
"    fid_by_epoch = run[\"history\"].get(\"fid\", {})\n",
"    return fid_by_epoch.get(str(epoch), fid_by_epoch.get(epoch))\n",
"\n",
"\n",
"FID_EPOCHS = (25, 50, 75, 100)\n",
"\n",
"rows = []\n",
"for name in run_names:\n",
"    if name not in runs:\n",
"        continue\n",
"    r = runs[name]\n",
"    cfg = r[\"config\"]\n",
"    row = {\n",
"        \"Step\": run_labels[name],\n",
"        \"Schedule\": cfg.get(\"noise_schedule\"),\n",
"        \"Pred\": cfg.get(\"pred_type\"),\n",
"        \"base_ch\": cfg.get(\"base_ch\"),\n",
"    }\n",
"    row.update({f\"FID@{ep}\": get_fid(r, ep) for ep in FID_EPOCHS})\n",
"    rows.append(row)\n",
"\n",
"df = pd.DataFrame(rows)\n",
"# na_rep is required: an epoch with no logged FID is None, and \"{:.1f}\".format(None)\n",
"# raises TypeError when the Styler renders (e.g. while no run has reached epoch 100).\n",
"df.style.format({c: \"{:.1f}\" for c in df.columns if \"FID\" in c}, na_rep=\"—\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "c0000007",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 3. FID Curves — Evolution Story"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "c0000008",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"colors = [\"#5B8DB8\", \"#E8705A\", \"#6ABF69\", \"#B86FB8\"]\n",
"\n",
"fig, ax = plt.subplots(figsize=(11, 5))\n",
"for i, name in enumerate(run_names):\n",
"    if name not in runs:\n",
"        continue\n",
"    fid_dict = runs[name][\"history\"][\"fid\"]\n",
"    # JSON keys are strings: sort epochs numerically and skip FIDs logged as null.\n",
"    points = sorted((int(k), v) for k, v in fid_dict.items() if v is not None)\n",
"    if not points:\n",
"        continue\n",
"    epochs, fids = zip(*points)\n",
"    # get_fid returns None when epoch 100 was not evaluated; int and float FIDs both get a label.\n",
"    fid100 = get_fid(runs[name], 100)\n",
"    label = run_labels[name] if fid100 is None else f\"{run_labels[name]} (FID@100={fid100:.1f})\"\n",
"    # Cycle the palette so adding a fifth run does not raise IndexError.\n",
"    ax.plot(epochs, fids, \"o-\", label=label, color=colors[i % len(colors)], linewidth=2, markersize=8)\n",
"\n",
"ax.set_xlabel(\"Epoch\")\n",
"ax.set_ylabel(\"FID (DDIM 100 steps, lower is better)\")\n",
"ax.set_title(\"Phase 4 — FID Curves: linear·ε → cosine·ε → cosine·v → wider\")\n",
"ax.legend()\n",
"plt.tight_layout()\n",
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "c0000009",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 4. Training Loss Curves (MSE on predicted target)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "c0000010",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"fig, ax = plt.subplots(figsize=(11, 4))\n",
"for i, name in enumerate(run_names):\n",
"    if name in runs:\n",
"        # history[\"loss\"] is plotted against 1-based epochs.\n",
"        epoch_losses = runs[name][\"history\"][\"loss\"]\n",
"        epoch_axis = range(1, len(epoch_losses) + 1)\n",
"        ax.plot(epoch_axis, epoch_losses, label=run_labels[name], color=colors[i], linewidth=1.2, alpha=0.9)\n",
"\n",
"ax.set(\n",
"    xlabel=\"Epoch\",\n",
"    ylabel=\"MSE (noise / v prediction loss)\",\n",
"    title=\"Phase 4 — Training Loss\",\n",
")\n",
"ax.legend(fontsize=9)\n",
"plt.tight_layout()\n",
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "c0000011",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 5. Sample Grids — Epoch 100 (DDIM 50 steps)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "c0000012",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# One panel per run: sized from run_names instead of a hard-coded 4, and squeeze=False\n",
"# keeps `axes` 2-D so the loop also works when only one run is listed.\n",
"fig, axes = plt.subplots(1, len(run_names), figsize=(5 * len(run_names), 5), squeeze=False)\n",
"\n",
"for ax, name in zip(axes[0], run_names):\n",
"    img_path = SAMPLES / name / \"epoch_0100.png\"\n",
"    if img_path.exists():\n",
"        fid = get_fid(runs[name], 100) if name in runs else None\n",
"        ax.imshow(mpimg.imread(str(img_path)))\n",
"        # Test for None explicitly so a numeric FID is never mistaken for a missing one.\n",
"        title = run_labels[name] if fid is None else f\"{run_labels[name]}\\nFID@100={fid:.1f}\"\n",
"    else:\n",
"        ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
"        title = run_labels[name]\n",
"    ax.set_title(title, fontsize=8)\n",
"    ax.axis(\"off\")\n",
"\n",
"fig.suptitle(\"Phase 4 — Epoch 100 DDIM Samples (4×4 grids)\", fontsize=12, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "c0000013",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 6. Step-by-step Pairwise Comparisons"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "c0000014",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"transitions = [\n",
"    (\"4.1→4.2: linear→cosine schedule\", \"p4_1_ddpm_linear\", \"p4_2_ddpm_cosine\"),\n",
"    (\"4.2→4.3: ε-pred→v-pred\", \"p4_2_ddpm_cosine\", \"p4_3_ddpm_vpred\"),\n",
"    (\"4.3→4.4: 128ch→192ch + 32×32 attn\", \"p4_3_ddpm_vpred\", \"p4_4_ddpm_wider\"),\n",
"]\n",
"\n",
"# Side-by-side epoch-100 sample grids for each consecutive pair of runs.\n",
"for caption, before, after in transitions:\n",
"    fig, pair_axes = plt.subplots(1, 2, figsize=(10, 5))\n",
"    for panel, run in zip(pair_axes, (before, after)):\n",
"        sample_png = SAMPLES / run / \"epoch_0100.png\"\n",
"        panel_title = run_labels[run]\n",
"        if sample_png.exists():\n",
"            fid = get_fid(runs[run], 100) if run in runs else None\n",
"            panel.imshow(mpimg.imread(str(sample_png)))\n",
"            if fid:\n",
"                panel_title = f\"{run_labels[run]}\\nFID@100={fid:.1f}\"\n",
"        else:\n",
"            panel.text(0.5, 0.5, \"Pending\", ha=\"center\", va=\"center\", transform=panel.transAxes)\n",
"        panel.set_title(panel_title, fontsize=9)\n",
"        panel.axis(\"off\")\n",
"    fig.suptitle(caption, fontsize=12, fontweight=\"bold\")\n",
"    plt.tight_layout()\n",
"    plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "c0000015",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 7. Progression: Epoch 10 → 50 → 100"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "c0000016",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"check_epochs = [10, 50, 100]\n",
"\n",
"# One row of sample grids per loaded run, showing how samples evolve during training.\n",
"for name in (n for n in run_names if n in runs):\n",
"    fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))\n",
"    for ax, epoch in zip(axes, check_epochs):\n",
"        grid_png = SAMPLES / name / f\"epoch_{epoch:04d}.png\"\n",
"        if not grid_png.exists():\n",
"            ax.text(0.5, 0.5, f\"Ep {epoch}\\n(pending)\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
"        else:\n",
"            ax.imshow(mpimg.imread(str(grid_png)))\n",
"            fid = get_fid(runs[name], epoch)\n",
"            fid_suffix = f\"\\nFID={fid:.1f}\" if fid else \"\"\n",
"            ax.set_title(f\"Ep {epoch}{fid_suffix}\", fontsize=9)\n",
"        ax.axis(\"off\")\n",
"    fig.suptitle(f\"{run_labels[name]} — Progression\", fontsize=11, fontweight=\"bold\")\n",
"    plt.tight_layout()\n",
"    plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "c0000017",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 8. Noise Schedule Visualisation\n",
|
||
"\n",
|
||
"Illustrates why cosine outperforms linear: the linear schedule allocates many timesteps near t=T where the image is already near-pure noise, wasting model capacity."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "c0000018",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# numpy (np) and matplotlib (plt) come from the setup cell; no re-imports needed here.\n",
"T = 1000\n",
"t = np.arange(T)\n",
"\n",
"# Linear betas from 1e-4 to 0.02 over T steps\n",
"betas_lin = np.linspace(1e-4, 0.02, T)\n",
"ab_lin = np.cumprod(1 - betas_lin)\n",
"\n",
"# Cosine betas. f is evaluated on the T + 1 grid points 0..T so that the ratios\n",
"# f[1:] / f[:-1] give exactly T betas. With only T grid points there would be T - 1\n",
"# betas, and plotting ab_cos against t would fail with an x/y length mismatch.\n",
"s = 0.008\n",
"f = np.cos((np.arange(T + 1) / T + s) / (1 + s) * np.pi / 2) ** 2\n",
"f = f / f[0]\n",
"betas_cos = np.clip(1 - f[1:] / f[:-1], 0, 0.999)\n",
"ab_cos = np.cumprod(1 - betas_cos)\n",
"assert ab_cos.shape == ab_lin.shape == t.shape, (ab_cos.shape, ab_lin.shape, t.shape)\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(13, 4))\n",
"\n",
"axes[0].plot(t, ab_lin, label=\"linear\", color=\"#5B8DB8\", linewidth=2)\n",
"axes[0].plot(t, ab_cos, label=\"cosine\", color=\"#E8705A\", linewidth=2)\n",
"axes[0].set_xlabel(\"Timestep t\")\n",
"axes[0].set_ylabel(\"ᾱ_t (signal fraction)\")\n",
"axes[0].set_title(\"ᾱ_t vs t — cosine stays informative longer\")\n",
"axes[0].legend()\n",
"\n",
"axes[1].plot(t, np.sqrt(ab_lin), label=\"linear √ᾱ\", color=\"#5B8DB8\", linewidth=2)\n",
"axes[1].plot(t, np.sqrt(ab_cos), label=\"cosine √ᾱ\", color=\"#E8705A\", linewidth=2)\n",
"axes[1].set_xlabel(\"Timestep t\")\n",
"axes[1].set_ylabel(\"√ᾱ_t (signal amplitude)\")\n",
"axes[1].set_title(\"Signal amplitude — cosine is more uniform\")\n",
"axes[1].legend()\n",
"\n",
"fig.suptitle(\"Noise Schedule Comparison\", fontsize=12, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "c0000019",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 9. Conclusions"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "c0000020",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"print(\"=\" * 70)\n",
"print(\"PHASE 4 — DDPM EVOLUTION SUMMARY\")\n",
"print(\"=\" * 70)\n",
"\n",
"fid100_by_run = {}  # run name -> FID@100, only for runs that logged it\n",
"for name in run_names:\n",
"    if name not in runs:\n",
"        print(f\"\\n {run_labels[name]}: NOT YET RUN\")\n",
"        continue\n",
"    fid100 = get_fid(runs[name], 100)\n",
"    fid50 = get_fid(runs[name], 50)\n",
"    h = runs[name][\"history\"]\n",
"    # history[\"loss\"] is indexed from epoch 1, so index 49 is epoch 50.\n",
"    loss50 = h[\"loss\"][49] if len(h[\"loss\"]) > 49 else None\n",
"    print(f\"\\n {run_labels[name]}:\")\n",
"    # Compare with None so a logged value of 0.0 is not printed as \"?\".\n",
"    print(f\" FID@50 = {fid50:.1f}\" if fid50 is not None else \" FID@50 = ?\")\n",
"    print(f\" FID@100 = {fid100:.1f}\" if fid100 is not None else \" FID@100 = ?\")\n",
"    print(f\" Loss@50 = {loss50:.5f}\" if loss50 is not None else \" Loss@50 = ?\")\n",
"    if fid100 is not None:\n",
"        fid100_by_run[name] = fid100\n",
"\n",
"print(\"\\n\" + \"=\" * 70)\n",
"if fid100_by_run:\n",
"    best = min(fid100_by_run, key=fid100_by_run.get)\n",
"    print(f\"Best DDPM model for Phase 5 comparison (lowest FID@100): {run_labels[best]} (FID@100={fid100_by_run[best]:.1f})\")\n",
"else:\n",
"    print(\"Best DDPM model for Phase 5 comparison: fill in after runs complete.\")\n",
"print(\"=\" * 70)"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"name": "python",
|
||
"version": "3.10.0"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|