DRL_PROJ/generator/notebooks/phase4_analysis.ipynb
2026-05-11 17:36:08 +01:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Phase 4 - DDPM Progression\n",
"\n",
"Phase 4 applies the same analysis and reporting structure to diffusion models.\n",
"The data pipeline is already fixed (aligned 64x64 images), so this notebook\n",
"isolates the DDPM recipe: noise schedule, prediction target, and backbone width.\n",
"\n",
"The story is stepwise. A cosine schedule gives a small gain, v-prediction gives\n",
"by far the largest gain, and the wider backbone adds a further gain once the\n",
"schedule and target have been improved.\n",
"\n",
"## What this phase changes\n",
"\n",
"Each run keeps the previous run's recipe and adds the change listed.\n",
"\n",
"| Run | Recipe change |\n",
"|---|---|\n",
"| `p4_1_ddpm_linear` | Linear noise schedule, epsilon-prediction |\n",
"| `p4_2_ddpm_cosine` | Cosine noise schedule |\n",
"| `p4_3_ddpm_vpred` | v-prediction target |\n",
"| `p4_4_ddpm_wider` | Wider U-Net: base channels 192 with attention at 32/16/8 |\n",
"\n",
"Sample previews use 50-step DDIM sampling (DDIM-50). Logged FID uses 100-step\n",
"DDIM sampling (DDIM-100), scored against the saved real-image reference set.\n",
"\n",
"**Headline result:** `p4_4_ddpm_wider` reaches **best FID = 30.0**.\n",
"\n",
"## How to read DDPM sample grids\n",
"\n",
"The DDPM grids should not be read as the same faces improving from epoch to\n",
"epoch. GAN and VAE previews can reuse a fixed batch of latent vectors, so each\n",
"grid position shows the same latent code passed through a progressively better\n",
"trained generator or decoder. A DDPM preview instead starts from random Gaussian\n",
"noise and runs the reverse-diffusion sampler. Unless that initial noise (and, for\n",
"a stochastic sampler, the noise injected at each step) is fixed and stored, each\n",
"epoch's preview is a fresh draw from the model.\n",
"\n",
"So for DDPM, the progression panels show distribution-level improvement:\n",
"cleaner faces, fewer artifacts, and better global structure. They are not\n",
"identity-by-identity refinements of the same preview images.\n"
]
},
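The fixed-noise caveat above can be illustrated with a minimal NumPy sketch. It is not the project's sampler: `cosine_alpha_bar`, `ddim_timesteps`, `ddim_sample` and the `dummy_eps` stand-in for a trained network are hypothetical names. The sketch also assumes deterministic DDIM (eta = 0), where the initial noise x_T is the only source of randomness. Reusing one stored x_T reproduces the same preview, while drawing new noise produces a different sample from the same model.

```python
import numpy as np


def cosine_alpha_bar(T=1000, s=0.008):
    """alpha_bar_t for t = 0..T-1 under a cosine noise schedule."""
    steps = np.arange(T + 1)
    f = np.cos((steps / T + s) / (1 + s) * np.pi / 2) ** 2
    betas = np.clip(1 - f[1:] / f[:-1], 0, 0.999)
    return np.cumprod(1 - betas)


def ddim_timesteps(T, n_steps):
    """Evenly strided subset of timesteps, ordered from high noise to low noise."""
    stride = T // n_steps
    return np.arange(0, T, stride)[::-1]


def ddim_sample(eps_model, x_T, alpha_bar, n_steps):
    """Deterministic DDIM (eta = 0): the output depends only on x_T and the model."""
    ts = ddim_timesteps(len(alpha_bar), n_steps)
    x = x_T
    for i, t in enumerate(ts):
        ab_t = alpha_bar[t]
        ab_prev = alpha_bar[ts[i + 1]] if i + 1 < len(ts) else 1.0
        eps = eps_model(x, t)
        # Predict the clean image, then jump to the previous (less noisy) timestep
        x0_hat = (x - np.sqrt(1 - ab_t) * eps) / np.sqrt(ab_t)
        x = np.sqrt(ab_prev) * x0_hat + np.sqrt(1 - ab_prev) * eps
    return x


alpha_bar = cosine_alpha_bar()
dummy_eps = lambda x, t: 0.5 * x  # hypothetical stand-in for a trained network

# Store the preview noise once (fixed seed) and reuse it at every epoch
fixed_noise = np.random.default_rng(seed=0).standard_normal((4, 3, 64, 64))
preview_a = ddim_sample(dummy_eps, fixed_noise, alpha_bar, n_steps=50)
preview_b = ddim_sample(dummy_eps, fixed_noise, alpha_bar, n_steps=50)

# Fresh noise gives a different draw from the same model
fresh_noise = np.random.default_rng(seed=1).standard_normal(fixed_noise.shape)
preview_fresh = ddim_sample(dummy_eps, fresh_noise, alpha_bar, n_steps=50)

print(np.array_equal(preview_a, preview_b))   # True: same x_T, identical preview
print(np.allclose(preview_a, preview_fresh))  # False: new x_T, different preview
```

The notebook only states that previews use DDIM-50. It does not say whether the project's sampler sets eta = 0. With eta > 0, the per-step noise would also need a seeded generator for previews to be comparable across epochs.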
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Reference: Phase 0 baseline from the same family\n",
"\n",
"`p0_ddpm` was a vanilla DDPM (linear schedule, epsilon-prediction, base_ch=128)\n",
"trained on raw, unaligned data. Its outputs were noisy, face-shaped textures.\n",
"Phase 4 fixes the data pipeline (aligned 64x64 images) and then adds common\n",
"post-2020 DDPM improvements one at a time.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n",
"\n",
"try:\n",
"    display\n",
"except NameError:\n",
"    # Fallback when running outside IPython/Jupyter\n",
"    def display(obj):\n",
"        print(obj)\n",
"\n",
"def find_generator_root():\n",
"    # Walk up from the CWD until a directory with outputs/logs and outputs/samples is found\n",
"    for base in [Path.cwd(), *Path.cwd().parents]:\n",
"        for candidate in [base, base / \"generator\"]:\n",
"            if (candidate / \"outputs\" / \"logs\").exists() and (candidate / \"outputs\" / \"samples\").exists():\n",
"                return candidate.resolve()\n",
"    raise FileNotFoundError(\"Could not locate generator/outputs from the current working directory\")\n",
"\n",
"GENERATOR_ROOT = find_generator_root()\n",
"PROJECT_ROOT = GENERATOR_ROOT.parent\n",
"OUTPUTS = GENERATOR_ROOT / \"outputs\"\n",
"LOGS = OUTPUTS / \"logs\"\n",
"SAMPLES = OUTPUTS / \"samples\"\n",
"\n",
"\n",
"def load_log(name):\n",
"    # Return the parsed JSON log for a run, or None if the log file does not exist\n",
"    p = LOGS / f\"{name}.json\"\n",
"    if not p.exists():\n",
"        return None\n",
"    with open(p) as fh:\n",
"        return json.load(fh)\n",
"\n",
"def get_fid(log, epoch):\n",
"    # FID entries are keyed by epoch as a string, e.g. \"100\"\n",
"    fid = log.get(\"history\", {}).get(\"fid\", {})\n",
"    return fid.get(str(epoch))\n",
"\n",
"def fid_series(log):\n",
"    # Sort numerically: sorting the string keys would put \"100\" before \"25\"\n",
"    fid = log.get(\"history\", {}).get(\"fid\", {})\n",
"    items = sorted((int(k), v) for k, v in fid.items())\n",
"    return [e for e, _ in items], [v for _, v in items]\n",
"\n",
"def show_image_or_missing(ax, path, title=None):\n",
"    # Draw a saved PNG, or a placeholder label if the artifact is missing\n",
"    if path.exists():\n",
"        ax.imshow(mpimg.imread(str(path)))\n",
"    else:\n",
"        ax.text(0.5, 0.5, f\"missing artifact\\n{path.name}\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
"    if title:\n",
"        ax.set_title(title, fontsize=9)\n",
"    ax.axis(\"off\")\n",
"\n"
]
},
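The helpers above imply a particular log layout. The sketch below assumes that layout, inferred from the helper code rather than from the training script: `history.fid` maps epoch strings to FID values, `history.loss` is a per-epoch list, and `history.train_time_s` is optional. The numbers are placeholders, not project results. The sketch also shows why `fid_series` converts the keys to `int` before sorting.

```python
import json

# Hypothetical log in the layout the helpers read; the numbers are placeholders, not results
example_log = json.loads("""
{
  "history": {
    "fid": {"100": 3.0, "25": 1.0, "50": 2.0},
    "loss": [0.9, 0.5, 0.4],
    "train_time_s": 120.0
  }
}
""")


def fid_series(log):
    # Same logic as the notebook helper: convert epoch keys to int before sorting
    fid = log.get("history", {}).get("fid", {})
    items = sorted((int(k), v) for k, v in fid.items())
    return [e for e, _ in items], [v for _, v in items]


fid = example_log["history"]["fid"]
print(sorted(fid))              # ['100', '25', '50']  (string order)
print(fid_series(example_log))  # ([25, 50, 100], [1.0, 2.0, 3.0])
print(fid.get(str(50)))         # 2.0  (lookups must use the string key)
```

Without the `int` conversion, the FID curves would connect epochs out of order.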
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Load experiment logs\n",
"\n",
"This notebook only reads existing DDPM logs and saved sample images; it does not sample from the models or recompute FID."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"run_names = [\"p4_1_ddpm_linear\", \"p4_2_ddpm_cosine\", \"p4_3_ddpm_vpred\", \"p4_4_ddpm_wider\"]\n",
"run_labels = {\n",
" \"p4_1_ddpm_linear\": \"4.1 linear / epsilon\",\n",
" \"p4_2_ddpm_cosine\": \"4.2 cosine / epsilon\",\n",
" \"p4_3_ddpm_vpred\": \"4.3 cosine / v\",\n",
" \"p4_4_ddpm_wider\": \"4.4 wider net\",\n",
"}\n",
"runs = {n: load_log(n) for n in run_names}\n",
"runs = {k: v for k, v in runs.items() if v}\n",
"for n in run_names: print(f\" {n}: {'OK' if n in runs else 'MISSING'}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. FID comparison table\n",
"\n",
"The table shows whether each recipe change improves generation quality under the saved DDIM-100 FID protocol. Lower FID is better; rows are sorted by best FID."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rows = []\n",
"for name in run_names:\n",
"    if name not in runs:\n",
"        continue\n",
"    r = runs[name]\n",
"    _, fid_vals = fid_series(r)\n",
"    train_s = r[\"history\"].get(\"train_time_s\")\n",
"    rows.append({\n",
"        \"Run\": run_labels[name],\n",
"        \"FID@25\": get_fid(r, 25),\n",
"        \"FID@50\": get_fid(r, 50),\n",
"        \"FID@100\": get_fid(r, 100),\n",
"        \"Best FID\": min(fid_vals) if fid_vals else None,\n",
"        \"Loss@100\": r[\"history\"][\"loss\"][-1],  # loss at the last logged epoch\n",
"        \"Train (min)\": train_s / 60 if train_s else None,  # None rather than a misleading 0.0\n",
"    })\n",
"df = pd.DataFrame(rows).sort_values(\"Best FID\")  # lower FID is better\n",
"df.style.format({\"FID@25\": \"{:.1f}\", \"FID@50\": \"{:.1f}\", \"FID@100\": \"{:.1f}\",\n",
"                 \"Best FID\": \"{:.1f}\", \"Loss@100\": \"{:.4f}\", \"Train (min)\": \"{:.1f}\"},\n",
"                na_rep=\"n/a\")\n"
]
},
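FID is the Fréchet distance between two Gaussians fitted to image features of the real and generated sets: ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 (Sigma_r Sigma_g)^(1/2)). Lower is better, and identical statistics give 0. The sketch below covers only that final distance computation. It assumes the features were already extracted; the project's feature extractor and reference statistics are not shown in this notebook. The random arrays stand in for real features, and `frechet_distance`, `psd_sqrt` and `feature_stats` are hypothetical helpers. Many FID implementations call `scipy.linalg.sqrtm`. This version instead takes the matrix-square-root trace from a symmetric eigendecomposition, so it needs only NumPy.

```python
import numpy as np


def psd_sqrt(a):
    """Symmetric square root of a positive semi-definite matrix via eigh."""
    w, v = np.linalg.eigh(a)
    return (v * np.sqrt(np.clip(w, 0, None))) @ v.T


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Fréchet distance between Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    s1_half = psd_sqrt(sigma1)
    # Tr((S1 S2)^(1/2)) equals the trace of the square root of S1^(1/2) S2 S1^(1/2),
    # which is symmetric PSD, so its eigenvalues are real and non-negative
    inner = s1_half @ sigma2 @ s1_half
    tr_covmean = np.sqrt(np.clip(np.linalg.eigvalsh(inner), 0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)


def feature_stats(features):
    """Mean and covariance of an (N, D) feature matrix."""
    return features.mean(axis=0), np.cov(features, rowvar=False)


# Random stand-ins for extracted features (standard FID uses 2048-d Inception-v3 features)
rng = np.random.default_rng(0)
real = rng.standard_normal((2000, 8))
fake_similar = rng.standard_normal((2000, 8))        # same distribution as `real`
fake_shifted = rng.standard_normal((2000, 8)) + 1.0  # mean moved by 1 in every dimension

mu_r, sig_r = feature_stats(real)
fid_self = frechet_distance(mu_r, sig_r, mu_r, sig_r)
fid_similar = frechet_distance(mu_r, sig_r, *feature_stats(fake_similar))
fid_shifted = frechet_distance(mu_r, sig_r, *feature_stats(fake_shifted))
print(abs(fid_self) < 1e-6)       # True: identical statistics give (numerically) 0
print(fid_similar < fid_shifted)  # True: the shifted set is much further away
```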
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. FID curves - progression"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(10, 5))\n",
"cmap = plt.cm.cividis\n",
"for i, name in enumerate(run_names):\n",
"    if name not in runs:\n",
"        continue\n",
"    epochs, fid_vals = fid_series(runs[name])\n",
"    ax.plot(epochs, fid_vals, \"o-\", label=run_labels[name],\n",
"            color=cmap(i / len(run_names)), linewidth=2, markersize=7)\n",
"ax.set_xlabel(\"Epoch\"); ax.set_ylabel(\"FID (DDIM-100)\")\n",
"ax.set_title(\"Phase 4 - FID curves\"); ax.legend()\n",
"plt.tight_layout(); plt.show()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Training loss\n",
"\n",
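"The loss plot is a training diagnostic only. Runs 4.1 and 4.2 minimize epsilon-MSE while 4.3 and 4.4 minimize v-MSE, so the absolute loss values of the two groups are not comparable. FID and the sample grids drive the decision."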
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(10, 4))\n",
"cmap = plt.cm.cividis\n",
"for i, name in enumerate(run_names):\n",
"    if name not in runs:\n",
"        continue\n",
"    h = runs[name][\"history\"]\n",
"    epochs = range(1, len(h[\"loss\"]) + 1)\n",
"    ax.plot(epochs, h[\"loss\"], color=cmap(i / len(run_names)), label=run_labels[name], linewidth=1.3)\n",
"ax.set_xlabel(\"Epoch\"); ax.set_ylabel(\"MSE on prediction target\")\n",
"ax.set_title(\"Loss (epsilon-MSE and v-MSE are not directly comparable)\")\n",
"ax.legend(); plt.tight_layout(); plt.show()\n"
]
},
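The two loss curves sit on different scales for a concrete reason. The forward process is x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps, and the standard v target is v = sqrt(alpha_bar_t) eps - sqrt(1 - alpha_bar_t) x_0. The project's training code is not shown here, so this sketch assumes that standard parameterization. Any v-prediction implies an epsilon-prediction, and the errors satisfy eps - eps_hat = sqrt(alpha_bar_t) (v - v_hat). For the same model, v-MSE therefore equals epsilon-MSE / alpha_bar_t at each timestep, which makes it larger, especially at high noise. The NumPy check below verifies the identity on random tensors. `v_hat` is a hypothetical imperfect prediction, and alpha_bar_t = 0.3 is an arbitrary example value.

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.standard_normal((16, 3, 8, 8))  # stand-in for clean images
eps = rng.standard_normal(x0.shape)      # forward-process noise
alpha_bar_t = 0.3                        # example signal fraction at some timestep t
a, s = np.sqrt(alpha_bar_t), np.sqrt(1 - alpha_bar_t)

# Forward process and the v-prediction target
x_t = a * x0 + s * eps
v = a * eps - s * x0

# Hypothetical imperfect v-prediction from a network
v_hat = v + 0.1 * rng.standard_normal(x0.shape)

# The implied epsilon and x0 predictions, recovered from x_t and v_hat
eps_hat = s * x_t + a * v_hat
x0_hat = a * x_t - s * v_hat

v_mse = np.mean((v - v_hat) ** 2)
eps_mse = np.mean((eps - eps_hat) ** 2)
print(np.isclose(eps_mse, alpha_bar_t * v_mse))                            # True
print(np.isclose(np.mean((x0 - x0_hat) ** 2), (1 - alpha_bar_t) * v_mse))  # True
```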
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Sample grids - epoch 100"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(1, len(run_names), figsize=(16, 4.5))\n",
"for ax, name in zip(axes, run_names):\n",
"    f = get_fid(runs[name], 100) if name in runs else None\n",
"    title = f\"{run_labels[name]}\\nFID@100={f:.1f}\" if f is not None else run_labels[name]\n",
"    # Label missing PNGs instead of leaving a blank panel\n",
"    show_image_or_missing(ax, SAMPLES / name / \"epoch_0100.png\", title)\n",
"plt.tight_layout(); plt.show()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Progression - epoch 10 -> 50 -> 100\n",
"\n",
"Read these as fresh samples from each checkpoint, not the same DDPM images being refined over time."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"check_epochs = [10, 50, 100]\n",
"for name in run_names:\n",
"    if name not in runs:\n",
"        continue\n",
"    fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))\n",
"    for ax, e in zip(axes, check_epochs):\n",
"        f = get_fid(runs[name], e)\n",
"        title = f\"epoch {e}\" + (f\"\\nFID={f:.1f}\" if f is not None else \"\")\n",
"        # Each panel is a fresh sample grid from that checkpoint\n",
"        show_image_or_missing(ax, SAMPLES / name / f\"epoch_{e:04d}.png\", title)\n",
"    fig.suptitle(run_labels[name], fontsize=11, fontweight=\"bold\")\n",
"    plt.tight_layout(); plt.show()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Noise schedule visualization\n",
"\n",
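"Under the linear schedule, alpha_bar_t (the fraction of signal variance kept at step t) falls close to zero well before t = T, so the final part of the chain is almost pure noise and contributes little to sample quality. The cosine schedule lowers alpha_bar_t more gradually, so more timesteps keep a usable signal. In this ablation it gives only a small FID gain on its own (134.5 to 132.3), but it is the schedule kept for the v-prediction and width steps."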
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"T = 1000\n",
"t = np.arange(T)\n",
"\n",
"# Linear schedule (Ho et al., 2020): beta_t from 1e-4 to 0.02\n",
"betas_lin = np.linspace(1e-4, 0.02, T)\n",
"ab_lin = np.cumprod(1 - betas_lin)\n",
"\n",
"# Cosine schedule (Nichol & Dhariwal, 2021): defined on T + 1 points so there are exactly T betas\n",
"s = 0.008\n",
"steps = np.arange(T + 1)\n",
"f = np.cos((steps / T + s) / (1 + s) * math.pi / 2) ** 2\n",
"f = f / f[0]\n",
"betas_cos = np.clip(1 - f[1:] / f[:-1], 0, 0.999)\n",
"ab_cos = np.cumprod(1 - betas_cos)\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(13, 4))\n",
"axes[0].plot(t, ab_lin, label=\"linear\", color=\"#5B8DB8\", linewidth=2)\n",
"axes[0].plot(t, ab_cos, label=\"cosine\", color=\"#E8705A\", linewidth=2)\n",
"axes[0].set_xlabel(\"Timestep t\"); axes[0].set_ylabel(\"alpha_bar_t (signal fraction)\")\n",
"axes[0].set_title(\"Cumulative signal preservation\"); axes[0].legend()\n",
"axes[1].plot(t, betas_lin, label=\"linear beta\", color=\"#5B8DB8\", linewidth=2)\n",
"axes[1].plot(t, betas_cos, label=\"cosine beta\", color=\"#E8705A\", linewidth=2)\n",
"axes[1].set_yscale(\"log\")  # cosine betas reach 0.999 at the end; log scale keeps the linear curve visible\n",
"axes[1].set_xlabel(\"Timestep t\"); axes[1].set_ylabel(\"beta_t (log scale)\"); axes[1].set_title(\"beta schedule\"); axes[1].legend()\n",
"plt.tight_layout(); plt.show()\n"
]
},
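The signal-fraction argument can be made quantitative by counting the timesteps where alpha_bar_t falls below a small threshold. This sketch rebuilds both schedules with the same formulas as the cell above. The 0.01 cutoff for "almost pure noise" is an arbitrary illustrative choice, not a value the project uses. It also prints the range of the log signal-to-noise ratio, log(alpha_bar_t / (1 - alpha_bar_t)).

```python
import numpy as np

T, s = 1000, 0.008

# Linear schedule: beta_t from 1e-4 to 0.02
ab_lin = np.cumprod(1 - np.linspace(1e-4, 0.02, T))

# Cosine schedule on T + 1 points, giving T betas
steps = np.arange(T + 1)
f = np.cos((steps / T + s) / (1 + s) * np.pi / 2) ** 2
ab_cos = np.cumprod(1 - np.clip(1 - f[1:] / f[:-1], 0, 0.999))

threshold = 0.01  # illustrative cutoff for "almost pure noise"
for name, ab in [("linear", ab_lin), ("cosine", ab_cos)]:
    n_noise = int(np.sum(ab < threshold))
    log_snr = np.log(ab / (1 - ab))  # log signal-to-noise ratio per timestep
    print(f"{name}: {n_noise} of {T} steps with alpha_bar < {threshold}, "
          f"log-SNR from {log_snr[0]:.1f} to {log_snr[-1]:.1f}")
```

Under the linear schedule, several times as many steps fall below the cutoff as under the cosine schedule. This matches the flatter cosine curve in the plot above.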
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. What this phase proves\n",
"\n",
"| Step | Run | Best FID | Delta vs previous |\n",
"|---|---|---:|---:|\n",
"| 4.1 linear / epsilon | `p4_1_ddpm_linear` | 134.5 | n/a |\n",
"| 4.2 cosine / epsilon | `p4_2_ddpm_cosine` | 132.3 | -2.2 |\n",
"| 4.3 cosine / v | `p4_3_ddpm_vpred` | 34.5 | -97.8 |\n",
"| 4.4 wider net | `p4_4_ddpm_wider` | 30.0 | -4.5 |\n",
"\n",
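"The largest improvement is v-prediction (97.8 FID lower than step 4.2). The\n",
"wider network then lowers FID by a further 4.5. A plausible reading is that the\n",
"improved schedule and target let the extra capacity pay off, but width was only\n",
"tested in the final configuration, so this ablation cannot confirm that.\n",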
"\n",
"**Decision:** select the DDPM recipe with cosine schedule, v-prediction,\n",
"base_ch=192, and attention at 32/16/8 for the final comparison.\n",
"\n",
"**Report conclusion:** Phase 4 turns DDPM from the textured but noisy baseline\n",
"into the strongest quality candidate for Phase 5.\n"
]
}
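The "Delta vs previous" column can be recomputed from the Best FID values rather than typed by hand. This plain-Python sketch uses the four Best FID numbers from the summary table. It also adds a relative change, which makes the size of the v-prediction step easier to compare with the other steps. The variable names are illustrative only.

```python
# Best FID per step, copied from the summary table
steps = ["4.1 linear / epsilon", "4.2 cosine / epsilon", "4.3 cosine / v", "4.4 wider net"]
best_fid = [134.5, 132.3, 34.5, 30.0]

rows = []
for i, (step, fid) in enumerate(zip(steps, best_fid)):
    if i == 0:
        rows.append((step, fid, None, None))  # no previous step to compare against
        continue
    prev = best_fid[i - 1]
    delta = round(fid - prev, 1)               # negative = lower (better) FID
    rel = round(100 * (fid - prev) / prev, 1)  # percent change vs the previous step
    rows.append((step, fid, delta, rel))

for step, fid, delta, rel in rows:
    d = "n/a" if delta is None else delta
    r = "n/a" if rel is None else rel
    print(f"{step:<22} {fid:>6.1f} {d:>7} {r:>7}")
```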
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}