Files
DRL_PROJ/generator/notebooks/phase4_analysis.ipynb
T
2026-04-30 13:10:33 +01:00

404 lines
13 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "markdown",
"id": "c0000001",
"metadata": {},
"source": [
"# Phase 4 — DDPM Evolution Analysis\n",
"\n",
"Traces the DDPM improvement story:\n",
"\n",
"| Step | Model | Key change | Expected failure |\n",
"|------|-------|------------|------------------|\n",
"| 4.1 | DDPM linear + ε-pred | Baseline | Noise prediction unstable at very low t (linear schedule over-denoises) |\n",
"| 4.2 | + cosine schedule | Less noise wasted at low timesteps | Residual instability from ε parameterisation |\n",
"| 4.3 | + v-prediction | Numerically stable across full trajectory | Possible underfitting at 64×64 |\n",
"| 4.4 | + wider U-Net (192ch) + 32×32 attention | More capacity and longer-range context | — |\n",
"\n",
"FID is computed via DDIM (100 steps, deterministic) from the EMA model.\n",
"H-flip-only augmentation, MTCNN-aligned 64×64 crops, T=1000."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000002",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n",
"\n",
"OUTPUTS = Path(\"../outputs\")\n",
"LOGS = OUTPUTS / \"logs\"\n",
"SAMPLES = OUTPUTS / \"samples\""
]
},
{
"cell_type": "markdown",
"id": "c0000003",
"metadata": {},
"source": [
"## 1. Load experiment logs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000004",
"metadata": {},
"outputs": [],
"source": [
"run_names = [\"p4_1_ddpm_linear\", \"p4_2_ddpm_cosine\", \"p4_3_ddpm_vpred\", \"p4_4_ddpm_wider\"]\n",
"run_labels = {\n",
" \"p4_1_ddpm_linear\": \"4.1 linear + ε\",\n",
" \"p4_2_ddpm_cosine\": \"4.2 cosine + ε\",\n",
" \"p4_3_ddpm_vpred\": \"4.3 cosine + v\",\n",
" \"p4_4_ddpm_wider\": \"4.4 wider + 32×32 attn\",\n",
"}\n",
"\n",
"runs = {}\n",
"for name in run_names:\n",
" log_path = LOGS / f\"{name}.json\"\n",
" if log_path.exists():\n",
" with open(log_path) as f:\n",
" runs[name] = json.load(f)\n",
" else:\n",
" print(f\" Missing: {log_path}\")\n",
"\n",
"print(f\"Loaded {len(runs)}/{len(run_names)} experiments:\")\n",
"for name in run_names:\n",
" print(f\" {'✓' if name in runs else '✗'} {name}\")"
]
},
{
"cell_type": "markdown",
"id": "c0000005",
"metadata": {},
"source": [
"## 2. FID Comparison Table"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000006",
"metadata": {},
"outputs": [],
"source": [
"def get_fid(run, epoch):\n",
" fid = run[\"history\"][\"fid\"]\n",
" return fid.get(str(epoch), fid.get(epoch, None))\n",
"\n",
"rows = []\n",
"for name in run_names:\n",
" if name not in runs:\n",
" continue\n",
" r = runs[name]\n",
" cfg = r[\"config\"]\n",
" rows.append({\n",
" \"Step\": run_labels[name],\n",
" \"Schedule\": cfg.get(\"noise_schedule\"),\n",
" \"Pred\": cfg.get(\"pred_type\"),\n",
" \"base_ch\": cfg.get(\"base_ch\"),\n",
" \"FID@25\": get_fid(r, 25),\n",
" \"FID@50\": get_fid(r, 50),\n",
" \"FID@75\": get_fid(r, 75),\n",
" \"FID@100\": get_fid(r, 100),\n",
" })\n",
"\n",
"df = pd.DataFrame(rows)\n",
"df.style.format({c: \"{:.1f}\" for c in df.columns if \"FID\" in c})"
]
},
{
"cell_type": "markdown",
"id": "c0000007",
"metadata": {},
"source": [
"## 3. FID Curves — Evolution Story"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000008",
"metadata": {},
"outputs": [],
"source": [
"colors = [\"#5B8DB8\", \"#E8705A\", \"#6ABF69\", \"#B86FB8\"]\n",
"\n",
"fig, ax = plt.subplots(figsize=(11, 5))\n",
"for i, name in enumerate(run_names):\n",
" if name not in runs:\n",
" continue\n",
" fid_dict = runs[name][\"history\"][\"fid\"]\n",
" epochs = sorted(int(k) for k in fid_dict)\n",
" fids = [fid_dict[str(e)] for e in epochs]\n",
" fid100 = fid_dict.get(\"100\", \"?\")\n",
" label = f\"{run_labels[name]} (FID@100={fid100:.1f})\" if isinstance(fid100, float) else run_labels[name]\n",
" ax.plot(epochs, fids, \"o-\", label=label, color=colors[i], linewidth=2, markersize=8)\n",
"\n",
"ax.set_xlabel(\"Epoch\")\n",
"ax.set_ylabel(\"FID (DDIM 100 steps, lower is better)\")\n",
"ax.set_title(\"Phase 4 — FID Curves: linear·ε → cosine·ε → cosine·v → wider\")\n",
"ax.legend()\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c0000009",
"metadata": {},
"source": [
"## 4. Training Loss Curves (MSE on predicted target)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000010",
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(11, 4))\n",
"for i, name in enumerate(run_names):\n",
" if name not in runs:\n",
" continue\n",
" losses = runs[name][\"history\"][\"loss\"]\n",
" ax.plot(range(1, len(losses)+1), losses, label=run_labels[name],\n",
" color=colors[i], linewidth=1.2, alpha=0.9)\n",
"\n",
"ax.set_xlabel(\"Epoch\")\n",
"ax.set_ylabel(\"MSE (noise / v prediction loss)\")\n",
"ax.set_title(\"Phase 4 — Training Loss\")\n",
"ax.legend(fontsize=9)\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c0000011",
"metadata": {},
"source": [
"## 5. Sample Grids — Epoch 100 (DDIM 50 steps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000012",
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(1, 4, figsize=(20, 5))\n",
"\n",
"for idx, name in enumerate(run_names):\n",
" ax = axes[idx]\n",
" img_path = SAMPLES / name / \"epoch_0100.png\"\n",
" if img_path.exists():\n",
" fid = get_fid(runs[name], 100) if name in runs else None\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" ax.set_title(f\"{run_labels[name]}\\nFID@100={fid:.1f}\" if fid else run_labels[name], fontsize=8)\n",
" else:\n",
" ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
" ax.set_title(run_labels[name], fontsize=8)\n",
" ax.axis(\"off\")\n",
"\n",
"fig.suptitle(\"Phase 4 — Epoch 100 DDIM Samples (4×4 grids)\", fontsize=12, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c0000013",
"metadata": {},
"source": [
"## 6. Step-by-step Pairwise Comparisons"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000014",
"metadata": {},
"outputs": [],
"source": [
"transitions = [\n",
" (\"4.1→4.2: linear→cosine schedule\", \"p4_1_ddpm_linear\", \"p4_2_ddpm_cosine\"),\n",
" (\"4.2→4.3: ε-pred→v-pred\", \"p4_2_ddpm_cosine\", \"p4_3_ddpm_vpred\"),\n",
" (\"4.3→4.4: 128ch→192ch + 32×32 attn\", \"p4_3_ddpm_vpred\", \"p4_4_ddpm_wider\"),\n",
"]\n",
"\n",
"for title, name_a, name_b in transitions:\n",
" fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n",
" for ax, name in zip(axes, [name_a, name_b]):\n",
" img_path = SAMPLES / name / \"epoch_0100.png\"\n",
" if img_path.exists():\n",
" fid = get_fid(runs[name], 100) if name in runs else None\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" ax.set_title(f\"{run_labels[name]}\\nFID@100={fid:.1f}\" if fid else run_labels[name], fontsize=9)\n",
" else:\n",
" ax.text(0.5, 0.5, \"Pending\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
" ax.set_title(run_labels[name], fontsize=9)\n",
" ax.axis(\"off\")\n",
" fig.suptitle(title, fontsize=12, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c0000015",
"metadata": {},
"source": [
"## 7. Progression: Epoch 10 → 50 → 100"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000016",
"metadata": {},
"outputs": [],
"source": [
"check_epochs = [10, 50, 100]\n",
"\n",
"for name in run_names:\n",
" if name not in runs:\n",
" continue\n",
" fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))\n",
" for ax, ep in zip(axes, check_epochs):\n",
" img_path = SAMPLES / name / f\"epoch_{ep:04d}.png\"\n",
" if img_path.exists():\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" fid = get_fid(runs[name], ep)\n",
" ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid else \"\"), fontsize=9)\n",
" else:\n",
" ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\",\n",
" transform=ax.transAxes)\n",
" ax.axis(\"off\")\n",
" fig.suptitle(f\"{run_labels[name]} — Progression\", fontsize=11, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c0000017",
"metadata": {},
"source": [
"## 8. Noise Schedule Visualisation\n",
"\n",
"Illustrates why cosine outperforms linear: the linear schedule allocates many timesteps near t=T where the image is already near-pure noise, wasting model capacity."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000018",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import numpy as np\n",
"\n",
"T = 1000\n",
"t = np.arange(T)\n",
"\n",
"# Linear betas\n",
"betas_lin = np.linspace(1e-4, 0.02, T)\n",
"ab_lin = np.cumprod(1 - betas_lin)\n",
"\n",
"# Cosine betas\n",
"s = 0.008\n",
"f = np.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2\n",
"f = f / f[0]\n",
"betas_cos = np.clip(1 - f[1:] / f[:-1], 0, 0.999)\n",
"ab_cos = np.cumprod(1 - betas_cos)\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(13, 4))\n",
"\n",
"axes[0].plot(t, ab_lin, label=\"linear\", color=\"#5B8DB8\", linewidth=2)\n",
"axes[0].plot(t, ab_cos, label=\"cosine\", color=\"#E8705A\", linewidth=2)\n",
"axes[0].set_xlabel(\"Timestep t\")\n",
"axes[0].set_ylabel(\"ᾱ_t (signal fraction)\")\n",
"axes[0].set_title(\"ᾱ_t vs t — cosine stays informative longer\")\n",
"axes[0].legend()\n",
"\n",
"axes[1].plot(t, np.sqrt(ab_lin), label=\"linear √ᾱ\", color=\"#5B8DB8\", linewidth=2)\n",
"axes[1].plot(t, np.sqrt(ab_cos), label=\"cosine √ᾱ\", color=\"#E8705A\", linewidth=2)\n",
"axes[1].set_xlabel(\"Timestep t\")\n",
"axes[1].set_ylabel(\"√ᾱ_t (signal amplitude)\")\n",
"axes[1].set_title(\"Signal amplitude — cosine is more uniform\")\n",
"axes[1].legend()\n",
"\n",
"plt.suptitle(\"Noise Schedule Comparison\", fontsize=12, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c0000019",
"metadata": {},
"source": [
"## 9. Conclusions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000020",
"metadata": {},
"outputs": [],
"source": [
"print(\"=\" * 70)\n",
"print(\"PHASE 4 — DDPM EVOLUTION SUMMARY\")\n",
"print(\"=\" * 70)\n",
"\n",
"for name in run_names:\n",
" if name not in runs:\n",
" print(f\"\\n {run_labels[name]}: NOT YET RUN\")\n",
" continue\n",
" fid100 = get_fid(runs[name], 100)\n",
" fid50 = get_fid(runs[name], 50)\n",
" h = runs[name][\"history\"]\n",
" loss50 = h[\"loss\"][49] if len(h[\"loss\"]) > 49 else None\n",
" print(f\"\\n {run_labels[name]}:\")\n",
" print(f\" FID@50 = {fid50:.1f}\" if fid50 else \" FID@50 = ?\")\n",
" print(f\" FID@100 = {fid100:.1f}\" if fid100 else \" FID@100 = ?\")\n",
" print(f\" Loss@50 = {loss50:.5f}\" if loss50 else \" Loss@50 = ?\")\n",
"\n",
"print(\"\\n\" + \"=\" * 70)\n",
"print(\"Best DDPM model for Phase 5 comparison: fill in after runs complete.\")\n",
"print(\"=\" * 70)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}