Files
DRL_PROJ/generator/notebooks/_build.py
T
2026-05-05 00:01:43 +01:00

1207 lines
54 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Build all phase analysis notebooks from a single source of truth.
Run from generator/notebooks/: python3 _build.py
Each phase notebook follows the same template (header → load → FID table →
FID curves → loss curves → samples → progression → conclusions). Phase 0 is
baseline-only (no FID, no iteration story). Phases 1–4 add a phase-0 same-family
reference. Phase 5 stays cross-family and adds latent interpolation.
Real metric values are pulled from outputs/logs/*.json at build time and
rendered into markdown headers and conclusions, so reports never drift from data.
"""
import json
from pathlib import Path
# Paths are anchored on this file rather than the cwd: ROOT is two directory
# levels above this script, training logs are read from ROOT/outputs/logs and
# the generated notebooks are written to ROOT/notebooks.
ROOT = Path(__file__).resolve().parent.parent
LOGS = ROOT.joinpath("outputs", "logs")
OUT = ROOT.joinpath("notebooks")
# ── notebook helpers ─────────────────────────────────────────────────────────
def md(text):
    """Build an nbformat-4 markdown cell from *text*.

    The source is stored as a list of lines with their trailing newlines kept.
    """
    source_lines = text.splitlines(keepends=True)
    return {"cell_type": "markdown", "metadata": {}, "source": source_lines}
def code(text):
    """Build an unexecuted nbformat-4 code cell from *text*.

    The cell has no outputs and a null execution count; the source is stored
    as a list of lines with their trailing newlines kept.
    """
    cell = {"cell_type": "code", "metadata": {}}
    cell["execution_count"] = None
    cell["outputs"] = []
    cell["source"] = text.splitlines(keepends=True)
    return cell
def write_nb(name, cells):
    """Serialise *cells* as the notebook ``OUT/<name>.ipynb`` and print its path.

    The notebook declares nbformat 4.5, which requires every cell to carry a
    unique ``id``. Cells built by ``md()``/``code()`` have none, so index-based
    ids (``cell-000``, ``cell-001``, ...) are assigned here. An id already
    present on a cell is kept. Index-based ids keep rebuilds of an unchanged
    notebook byte-identical. The caller's cell dicts are not mutated.

    Args:
        name: notebook file stem (without ``.ipynb``).
        cells: list of nbformat cell dicts.
    """
    cells = [{"id": cell.get("id", f"cell-{i:03d}"), **cell} for i, cell in enumerate(cells)]
    nb = {
        "cells": cells,
        "metadata": {
            "kernelspec": {"name": "python3", "display_name": "Python 3"},
            "language_info": {"name": "python"},
        },
        "nbformat": 4,
        "nbformat_minor": 5,
    }
    path = OUT / f"{name}.ipynb"
    path.write_text(json.dumps(nb, indent=1), encoding="utf-8")
    print(f" wrote {path.relative_to(ROOT)}")
# ── log-derived facts (computed once, baked into markdown) ───────────────────
def load(name):
    """Return the parsed JSON training log ``LOGS/<name>.json``, or None if absent.

    The file is read in a single call, so no file handle is left open.
    """
    p = LOGS / f"{name}.json"
    if not p.exists():
        return None
    return json.loads(p.read_text(encoding="utf-8"))
def best_fid(log):
    """Return ``(epoch, fid)`` for the lowest value in ``log["history"]["fid"]``.

    FID keys are epoch numbers stored as strings. Ties go to the earliest
    epoch. Returns ``(None, None)`` when *log* is falsy or has no FID entries.
    """
    history = log.get("history", {}) if log else {}
    fid = history.get("fid", {})
    if not fid:
        return None, None
    by_epoch = sorted((int(epoch), value) for epoch, value in fid.items())
    return min(by_epoch, key=lambda pair: pair[1])
def time_min(log):
    """Return ``history["train_time_s"]`` converted to minutes.

    Returns None when the log, the entry, or its value is missing or zero.
    """
    history = (log or {}).get("history", {})
    seconds = history.get("train_time_s")
    if not seconds:
        return None
    return seconds / 60
def get_fid(log, epoch):
    """Return the FID logged at *epoch*, or None if the log or entry is missing.

    Epoch keys are strings in the JSON log, so *epoch* goes through ``str``.
    This is the build-time counterpart of the ``get_fid`` in SHARED_IMPORTS.
    """
    if not log:
        return None
    return log.get("history", {}).get("fid", {}).get(str(epoch))
# ── shared imports cell (all phases use the same setup) ──────────────────────
# Setup cell shared by every phase notebook: imports, plotting defaults, output
# paths (relative to the notebooks/ directory) and small log-reading helpers.
# The helpers tolerate a missing (None) log, like the build-time get_fid above.
SHARED_IMPORTS = """\
import json
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import pandas as pd

plt.rcParams.update({"figure.dpi": 120, "font.size": 10})
OUTPUTS = Path("../outputs")
LOGS = OUTPUTS / "logs"
SAMPLES = OUTPUTS / "samples"

def load_log(name):
    p = LOGS / f"{name}.json"
    if not p.exists():
        return None
    with open(p) as f:
        return json.load(f)

def get_fid(log, epoch):
    fid = (log or {}).get("history", {}).get("fid", {})
    return fid.get(str(epoch))

def fid_series(log):
    fid = (log or {}).get("history", {}).get("fid", {})
    items = sorted((int(k), v) for k, v in fid.items())
    return [e for e, _ in items], [v for _, v in items]
"""
# ─────────────────────────────────────────────────────────────────────────────
# PHASE 0 — Baseline (new)
# ─────────────────────────────────────────────────────────────────────────────
def build_phase0():
    """Build ``phase0_analysis.ipynb``: baseline loss curves and sample grids.

    Phase 0 predates FID logging, so the notebook shows only training losses
    and preview images. Epoch counts in the header table are read from the
    logs at build time; a missing log is shown as "—".
    """
    p0 = {n: load(n) for n in ["p0_wgan", "p0_vae", "p0_ddpm", "p0_ddpm_small"]}

    def n_epochs(name, key):
        # Length of a per-epoch loss list in the log, or "—" when the log is missing.
        log = p0[name]
        return len(log["history"][key]) if log else "—"

    cells = [
        md(f"""\
# Phase 0 — Baseline

The starting point. Three model families trained at minimal scale on the raw (un-aligned)
dataset to establish a reference point and verify the training loops work end-to-end.
No FID was tracked at this stage — these runs predate the FID instrumentation added in
phase 1, so the only signals are training loss and visual sample quality.

| Run | Model | Epochs | Notes |
|-----------------------|--------------|--------|------------------------------------------|
| `p0_wgan` | WGAN-GP | {n_epochs('p0_wgan', 'g_loss')} | basic generator/critic, no SN/attention |
| `p0_vae` | VAE | {n_epochs('p0_vae', 'loss')} | MSE + KL only |
| `p0_ddpm` | DDPM | {n_epochs('p0_ddpm', 'loss')} | linear schedule, ε-prediction |
| `p0_ddpm_small` | DDPM (small) | {n_epochs('p0_ddpm_small', 'loss')} | reduced base channels — sanity check |

Phase 0 outputs are **deliberately rough**. The point is to have something to compare
against in phases 1–4 so improvements have a baseline.
"""),
        code(SHARED_IMPORTS),
        md("## 1. Training loss curves"),
        code("""\
runs = {n: load_log(n) for n in ["p0_wgan", "p0_vae", "p0_ddpm", "p0_ddpm_small"]}
runs = {k: v for k, v in runs.items() if v}

fig, axes = plt.subplots(1, 3, figsize=(16, 4))

# WGAN: generator vs critic loss
if "p0_wgan" in runs:
    h = runs["p0_wgan"]["history"]
    ep = range(1, len(h["g_loss"]) + 1)
    axes[0].plot(ep, h["g_loss"], label="G loss", color="#5B8DB8")
    axes[0].plot(ep, h["c_loss"], label="C loss", color="#E8705A", alpha=0.7)
    axes[0].legend()
axes[0].set_title("p0_wgan — generator vs critic loss")
axes[0].set_xlabel("Epoch")

# VAE: reconstruction vs KL components
if "p0_vae" in runs:
    h = runs["p0_vae"]["history"]
    ep = range(1, len(h["loss"]) + 1)
    axes[1].plot(ep, h["recon_loss"], label="Recon (MSE)", color="#5B8DB8")
    axes[1].plot(ep, h["kl_loss"], label="KL", color="#E8705A")
    axes[1].legend()
axes[1].set_title("p0_vae — recon vs KL")
axes[1].set_xlabel("Epoch")

# DDPM: noise prediction MSE (plus the small variant, if present)
if "p0_ddpm" in runs:
    h = runs["p0_ddpm"]["history"]
    ep = range(1, len(h["loss"]) + 1)
    axes[2].plot(ep, h["loss"], color="#5B8DB8", label="ε-MSE")
if "p0_ddpm_small" in runs:
    h2 = runs["p0_ddpm_small"]["history"]
    axes[2].plot(range(1, len(h2["loss"]) + 1), h2["loss"], color="#E8705A", linestyle="--", label="small variant")
axes[2].set_title("p0_ddpm — noise prediction loss")
axes[2].set_xlabel("Epoch")
if "p0_ddpm" in runs or "p0_ddpm_small" in runs:
    axes[2].legend()

plt.tight_layout(); plt.show()
"""),
        md("## 2. Final sample grids\n\nLast available preview from each run."),
        code("""\
last_epochs = {"p0_wgan": 200, "p0_vae": 100, "p0_ddpm": 200, "p0_ddpm_small": 100}
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for ax, (name, ep) in zip(axes, last_epochs.items()):
    img_path = SAMPLES / name / f"epoch_{ep:04d}.png"
    if img_path.exists():
        ax.imshow(mpimg.imread(str(img_path)))
        ax.set_title(f"{name}\\n(epoch {ep})", fontsize=10)
    else:
        ax.set_title(f"{name}\\n(missing)", fontsize=10)
    ax.axis("off")
plt.tight_layout(); plt.show()
"""),
        md("## 3. Progression — early vs late"),
        code("""\
checkpoints = {
    "p0_wgan": [50, 100, 200],
    "p0_vae": [25, 50, 100],
    "p0_ddpm": [50, 100, 200],
}
for name, eps in checkpoints.items():
    fig, axes = plt.subplots(1, len(eps), figsize=(12, 4))
    for ax, e in zip(axes, eps):
        p = SAMPLES / name / f"epoch_{e:04d}.png"
        if p.exists():
            ax.imshow(mpimg.imread(str(p))); ax.set_title(f"epoch {e}", fontsize=9)
        else:
            ax.text(0.5, 0.5, f"epoch {e}\\n(missing)", ha="center", va="center", transform=ax.transAxes)
        ax.axis("off")
    fig.suptitle(name, fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),
        md("""\
## 4. Conclusions

Phase 0 ran with the un-aligned raw dataset and bare-bones model variants. Outputs are
**face-shaped blobs** at best — enough to confirm the training loops converge, far from
recognisable.

- **WGAN** — generates colour blobs with rough oval shapes. No facial features resolved.
- **VAE** — heavily blurred mean-images; the MSE+KL objective pulls reconstructions toward
  the dataset average.
- **DDPM** — better local texture than the others but still very noisy faces; needs a
  stronger backbone and noise schedule (addressed in phase 4).

The clear takeaways feeding into phase 1 onward:

1. Data quality matters — un-aligned raw images caused most of the visual blur.
   Phase 1 ablates this directly (alignment, augmentation, dataset mixing).
2. The vanilla VAE objective alone collapses to averages; phase 3 adds perceptual + adversarial
   losses to restore detail.
3. The minimal DDPM is the strongest of the three at this scale; phase 4 extends it
   (cosine schedule, v-prediction, wider backbone).
"""),
    ]
    write_nb("phase0_analysis", cells)
# ─────────────────────────────────────────────────────────────────────────────
# PHASE 1 — Pipeline ablations (DCGAN proxy)
# ─────────────────────────────────────────────────────────────────────────────
def build_phase1():
    """Build ``phase1_analysis.ipynb``: DCGAN data-pipeline ablations 1A–1D.

    Headline and conclusion numbers are baked in from the logs at build time.
    A missing log or FID entry renders as "—" instead of aborting the build.
    """
    runs = {n: load(n) for n in ["p1a_dcgan_64", "p1a_dcgan_128", "p1b_dcgan_full",
                                 "p1b_dcgan_aligned", "p1c_dcgan_hflip",
                                 "p1c_dcgan_full_aug", "p1d_dcgan_combined"]}

    def rank(log):
        # Lowest logged FID; runs without any FID sort last.
        v = best_fid(log)[1]
        return v if v is not None else float("inf")

    def fid50(name):
        # FID at epoch 50 formatted for the conclusions table.
        v = get_fid(runs[name], 50)
        return f"{v:.1f}" if v is not None else "—"

    best_name = min(runs, key=lambda n: rank(runs[n]))
    best_ep, best_val = best_fid(runs[best_name])
    # best_fid() is a minimum over *all* logged epochs, so report the epoch it
    # came from instead of labelling it FID@50.
    best_str = (f"best FID = {best_val:.1f} (epoch {best_ep})"
                if best_val is not None else "no FID logged")
    p0_gan = load("p0_wgan")
    # NOTE(review): 200 is only a fallback when p0_wgan.json is absent; it matches
    # the last p0_wgan preview epoch used in the phase 0 notebook.
    p0_epochs = len(p0_gan["history"]["g_loss"]) if p0_gan else 200

    cells = [
        md(f"""\
# Phase 1 — Pipeline Selection (DCGAN ablations)

Goal: with a cheap proxy (vanilla DCGAN at 64×64, 50 epochs), isolate which **data-pipeline
choices** matter so phases 2–4 can train on the best preprocessing without burning compute
on dead-end variants.

Four ablations, one factor each:

- **1A** — Resolution: 64×64 vs 128×128
- **1B** — Face crop + alignment: full image vs MTCNN-aligned
- **1C** — Augmentation: H-flip only vs H-flip + rotation + colour jitter
- **1D** — Combined dataset: aligned only vs aligned + raw mixed

**Headline result:** `{best_name}` — **{best_str}**. The pipeline carried
forward into all later phases is the one this experiment selected: MTCNN-aligned crops,
64×64, full augmentation for GANs (H-flip-only kept as a safer default for VAE/DDPM),
aligned-only (no mixing).
"""),
        md(f"""\
### Reference: phase 0 baseline (same family)

The phase 0 WGAN-GP (`p0_wgan`), trained on raw un-aligned images for {p0_epochs} epochs
without any pipeline tuning, also collapsed. Phase 1 below uses the same model class
with the data pipeline systematically varied; the architecture limitation is constant.
"""),
        code(SHARED_IMPORTS),
        md("## 1. Load all experiment logs"),
        code("""\
run_names = sorted(p.stem for p in LOGS.glob("p1*.json"))
runs = {name: load_log(name) for name in run_names}
runs = {k: v for k, v in runs.items() if v}
run_names = [n for n in run_names if n in runs]  # drop empty/unreadable logs
print(f"Loaded {len(runs)} experiments:")
for name in run_names: print(f" {name}")
"""),
        code("""\
experiment_groups = {
    "1A — Resolution": {"p1a_dcgan_64": "64×64 (raw)",
                        "p1a_dcgan_128": "128×128 (raw)"},
    "1B — Alignment": {"p1b_dcgan_full": "Full image (raw)",
                       "p1b_dcgan_aligned": "MTCNN-aligned"},
    "1C — Augmentation": {"p1c_dcgan_hflip": "H-flip only",
                          "p1c_dcgan_full_aug": "H-flip + rot + colour"},
    "1D — Dataset mixing": {"p1b_dcgan_aligned": "Aligned only",
                            "p1d_dcgan_combined": "Aligned + raw mixed"},
}
"""),
        md("## 2. FID comparison table"),
        code("""\
rows = []
for name in run_names:
    r = runs[name]; cfg = r.get("config", {})
    rows.append({
        "Experiment": name,
        "Size": f"{cfg.get('image_size')}×{cfg.get('image_size')}",
        "Augment": cfg.get("augment", False),
        "FID@25": get_fid(r, 25),
        "FID@50": get_fid(r, 50),
        "G loss (ep50)": r["history"]["g_loss"][-1],
        "D loss (ep50)": r["history"]["d_loss"][-1],
    })
df = pd.DataFrame(rows).sort_values("FID@50")
df.style.format({"FID@25": "{:.1f}", "FID@50": "{:.1f}",
                 "G loss (ep50)": "{:.3f}", "D loss (ep50)": "{:.3f}"}, na_rep="—")
"""),
        code("""\
fig, ax = plt.subplots(figsize=(10, 5))
labels = df["Experiment"].values
fid25 = df["FID@25"].values
fid50 = df["FID@50"].values
x = np.arange(len(labels)); w = 0.35
ax.bar(x - w/2, fid25, w, label="FID @ 25", color="#5B8DB8", alpha=0.85)
ax.bar(x + w/2, fid50, w, label="FID @ 50", color="#E8705A", alpha=0.85)
ax.set_ylabel("FID (lower is better)")
ax.set_title("Phase 1 — FID across all pipeline ablations")
ax.set_xticks(x); ax.set_xticklabels(labels, rotation=30, ha="right")
ax.legend(); plt.tight_layout(); plt.show()
"""),
        md("## 3. Per-group comparisons"),
        code("""\
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()
colors = ["#5B8DB8", "#E8705A"]
for idx, (group_title, experiments) in enumerate(experiment_groups.items()):
    ax = axes[idx]
    for i, (run_name, label) in enumerate(experiments.items()):
        if run_name not in runs: continue
        epochs, fid_vals = fid_series(runs[run_name])
        f50 = get_fid(runs[run_name], 50)
        f50_txt = f"{f50:.1f}" if f50 is not None else "n/a"
        ax.plot(epochs, fid_vals, "o-",
                label=f"{label} (FID@50={f50_txt})",
                color=colors[i], linewidth=2, markersize=8)
    ax.set_xlabel("Epoch"); ax.set_ylabel("FID")
    ax.set_title(group_title); ax.legend(fontsize=9)
    ax.set_xlim(20, 55)
fig.suptitle("FID per ablation group", fontsize=14, fontweight="bold", y=1.01)
plt.tight_layout(); plt.show()
"""),
        md("""\
## 4. Data pipeline visualisation

What each ablation actually changes — shown on the input data the model sees.
"""),
        code("""\
import random
from PIL import Image
import torchvision.transforms as T

random.seed(0)
RAW = Path("../../data/wiki")
ALIGNED = Path("../../cropped/generator/wiki")

def sample_paths(root, k=4):
    # Sorted listings so random.seed(0) picks the same files on every filesystem.
    shards = sorted(d for d in root.iterdir() if d.is_dir())
    files = []
    for s in random.sample(shards, min(8, len(shards))):
        files += sorted(s.glob("*.jpg"))[:50]
    return random.sample(files, min(k, len(files)))

def matched_pairs(k=4):
    # (raw, aligned) paths of the same source file, so both rows show one face.
    shards = sorted(d.name for d in ALIGNED.iterdir() if d.is_dir() and (RAW / d.name).is_dir())
    pairs = []
    for shard in random.sample(shards, min(8, len(shards))):
        for ali in sorted((ALIGNED / shard).glob("*.jpg")):
            raw = RAW / shard / ali.name
            if raw.exists():
                pairs.append((raw, ali))
            if len(pairs) >= 50: break
        if len(pairs) >= 50: break
    return random.sample(pairs, min(k, len(pairs)))

def show(ax, img, title=None):
    ax.imshow(img); ax.axis("off")
    if title: ax.set_title(title, fontsize=9)
"""),
        md("### 4A — Resolution\n\nSame raw image at 64×64 and 128×128. 4× more pixels at 128 — too much for a vanilla DCGAN at 50 epochs."),
        code("""\
paths = sample_paths(RAW, k=4)
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for col, p in enumerate(paths):
    img = Image.open(p).convert("RGB")
    show(axes[0][col], T.CenterCrop(min(img.size))(img).resize((64, 64)), "64×64")
    show(axes[1][col], T.CenterCrop(min(img.size))(img).resize((128, 128)), "128×128")
fig.suptitle("1A — Resolution: same image at two scales", fontsize=12, fontweight="bold")
plt.tight_layout(); plt.show()
"""),
        md("### 4B — Alignment\n\nRaw vs MTCNN-aligned 64×64 crops. Alignment removes pose/scale/translation variance so the model only has to learn identity."),
        code("""\
pairs = matched_pairs(k=4)
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for col, (raw_p, ali_p) in enumerate(pairs):
    raw_img = Image.open(raw_p).convert("RGB")
    show(axes[0][col], T.CenterCrop(min(raw_img.size))(raw_img).resize((128, 128)), "raw")
    show(axes[1][col], Image.open(ali_p).convert("RGB"), "MTCNN-aligned")
fig.suptitle("1B — Alignment: same source image, raw vs MTCNN-aligned", fontsize=12, fontweight="bold")
plt.tight_layout(); plt.show()
"""),
        md("### 4C — Augmentation\n\nOne aligned image, three variants: original, hflip, and full augmentation (hflip + rotation ±5° + brightness/contrast/saturation jitter)."),
        code("""\
src = sample_paths(ALIGNED, k=1)
if src:
    img = Image.open(src[0]).convert("RGB").resize((128, 128))
    none = T.Compose([])
    hflip = T.Compose([T.RandomHorizontalFlip(p=1.0)])
    full = T.Compose([
        T.RandomHorizontalFlip(p=1.0),
        T.RandomRotation(degrees=5, interpolation=T.InterpolationMode.BILINEAR),
        T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05),
    ])
    fig, axes = plt.subplots(1, 6, figsize=(15, 3))
    show(axes[0], none(img), "original")
    show(axes[1], hflip(img), "hflip")
    for i, ax in enumerate(axes[2:]):
        show(ax, full(img), f"full aug #{i+1}")
    fig.suptitle("1C — Augmentation: original vs hflip vs full augmentation", fontsize=12, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),
        md("### 4D — Dataset mixing\n\nMixing raw + aligned roughly doubles within-batch variance, splitting generator capacity across two distributions."),
        code("""\
pairs = matched_pairs(k=4)
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for col, (raw_p, ali_p) in enumerate(pairs):
    raw_img = Image.open(raw_p).convert("RGB")
    show(axes[0][col], T.CenterCrop(min(raw_img.size))(raw_img).resize((128, 128)), "raw (mixed in)")
    show(axes[1][col], Image.open(ali_p).convert("RGB"), "aligned")
fig.suptitle("1D — Mixing: same source image, raw vs aligned", fontsize=12, fontweight="bold")
plt.tight_layout(); plt.show()
"""),
        md(f"""\
## 5. Conclusions

Lowest FID of phase 1: **`{best_name}`, {best_str}** — but every model
in this phase collapsed. The numbers below are pipeline-ranking signal, not model-quality
claims.

| Ablation | Winner | FID@50 | Loser FID@50 | Notes |
|---|---|---|---|---|
| 1A — Resolution | 64×64 | {fid50('p1a_dcgan_64')} | {fid50('p1a_dcgan_128')} (128×128) | DCGAN lacks capacity at 128×128 in 50 ep |
| 1B — Alignment | MTCNN-aligned | {fid50('p1b_dcgan_aligned')} | {fid50('p1b_dcgan_full')} (full) | **largest single lever** |
| 1C — Augmentation | H-flip + rot + colour | {fid50('p1c_dcgan_full_aug')} | {fid50('p1c_dcgan_hflip')} (H-flip) | moderate gain; per-family validation needed |
| 1D — Dataset mixing | Aligned only | {fid50('p1b_dcgan_aligned')} | {fid50('p1d_dcgan_combined')} (mixed) | mixing raw+aligned doubles variance |

**Pipeline carried forward to phases 2–5:** MTCNN-aligned crops, 64×64, full augmentation
for GANs (H-flip-only kept as a safer default for VAE/DDPM), aligned-only (no mixing).

The pipeline question is *answered*. The model-quality question is **not** — every
phase 1 run collapsed. Phase 2 fixes that by upgrading the architecture (Wasserstein-GP,
spectral norm, GroupNorm, self-attention) on the now-locked pipeline.
"""),
    ]
    write_nb("phase1_analysis", cells)
# ─────────────────────────────────────────────────────────────────────────────
# PHASE 2 — GAN architecture/objective evolution
# ─────────────────────────────────────────────────────────────────────────────
def build_phase2():
    """Build ``phase2_analysis.ipynb``: GAN objective/architecture evolution 2.1–2.4.

    Every number quoted in the markdown (headline, cross-phase warning, sample
    commentary, conclusions) is computed from the logs at build time, so the
    prose cannot drift from the data. Missing values render as "—".
    """
    run_names = ["p2_1_dcgan", "p2_2_wgan", "p2_3_wgan_sn_attn", "p2_4_wgan_sn_attn_128"]
    runs = {n: load(n) for n in run_names}
    # Phase 1 logs are only needed for the cross-phase FID warning.
    p1_runs = {n: load(n) for n in ["p1a_dcgan_64", "p1a_dcgan_128", "p1b_dcgan_full",
                                    "p1b_dcgan_aligned", "p1c_dcgan_hflip",
                                    "p1c_dcgan_full_aug", "p1d_dcgan_combined"]}

    def rank(log):
        # Lowest logged FID; runs without any FID sort last.
        v = best_fid(log)[1]
        return v if v is not None else float("inf")

    def fmt(v):
        return f"{v:.1f}" if v is not None else "—"

    def best_of(name):
        # Best (lowest) FID of a phase 2 run, formatted.
        return fmt(best_fid(runs[name])[1])

    def delta(prev, cur):
        # Change in best FID from `prev` to `cur` (negative = improvement).
        a, b = best_fid(runs[prev])[1], best_fid(runs[cur])[1]
        return f"{b - a:+.1f}" if a is not None and b is not None else "—"

    best_name = min(runs, key=lambda n: rank(runs[n]))
    best_ep, best_val = best_fid(runs[best_name])
    best_str = (f"best FID = {best_val:.1f} (epoch {best_ep})"
                if best_val is not None else "no FID logged")
    p1_best = min(p1_runs, key=lambda n: rank(p1_runs[n]))
    p1_val = best_fid(p1_runs[p1_best])[1]

    cells = [
        md(f"""\
# Phase 2 — GAN Evolution

With the data pipeline locked (phase 1), iterate on the **GAN itself**: objective,
normalisation, attention, resolution. All four runs use the aligned-64 pipeline; only
the model and training recipe change.

| Run | Step |
|-------------------------|------------------------------------------------|
| `p2_1_dcgan` | Phase 1 best, retrained 100 epochs |
| `p2_2_wgan` | BCE → Wasserstein-GP (n_critic=2, β=(0,0.9)) |
| `p2_3_wgan_sn_attn` | + spectral norm + GroupNorm + self-attention |
| `p2_4_wgan_sn_attn_128` | same as 2.3 but at 128×128 |

**Headline result:** `{best_name}` — **{best_str}** over a 100-epoch run.
"""),
        md(f"""\
> ### ⚠ FID is not comparable across phases
>
> Phase 1's "best" was FID {fmt(p1_val)} (`{p1_best}`). Phase 2's "best" is FID {fmt(best_val)}.
> **This is not a regression.** The two numbers were computed under different
> protocols:
>
> - Phase 1 used a quick proxy FID for fast pipeline ablation, with a smaller
>   real-image reference set, on the un-augmented validation split.
> - Phase 2 uses the project's standard FID protocol — 5000 aligned 64×64 real
>   images from the matched augmentation pipeline (`fid_n_real: 5000`).
>
> Within phase 2 the deltas are meaningful (2.2 → 2.3 = **{delta('p2_2_wgan', 'p2_3_wgan_sn_attn')} FID**
> is a real architecture jump). Don't compare phase 1 vs phase 2 numbers absolutely —
> only compare within a phase, or against phase 5 which uses the same protocol.
"""),
        md("""\
### Reference: phase 0 baseline (same family)

`p0_wgan` was the un-aligned, no-augmentation, basic-architecture WGAN-GP — face blobs
with no recognisable features (no FID logged). Phase 2 below shows what happens once
the pipeline is fixed and the model is allowed to evolve.
"""),
        code(SHARED_IMPORTS),
        md("## 1. Load experiment logs"),
        code("""\
run_names = ["p2_1_dcgan", "p2_2_wgan", "p2_3_wgan_sn_attn", "p2_4_wgan_sn_attn_128"]
run_labels = {
    "p2_1_dcgan": "2.1 DCGAN (BCE)",
    "p2_2_wgan": "2.2 WGAN-GP",
    "p2_3_wgan_sn_attn": "2.3 + SN + Attn",
    "p2_4_wgan_sn_attn_128": "2.4 + 128×128",
}
runs = {name: load_log(name) for name in run_names}
runs = {k: v for k, v in runs.items() if v}
for n in run_names:
    if n in runs: print(f" {n}: {len(runs[n]['history']['g_loss'])} epochs")
    else: print(f" {n}: MISSING")
"""),
        md("## 2. FID comparison table"),
        code("""\
rows = []
for name in run_names:
    if name not in runs: continue
    r = runs[name]
    epochs, fid_vals = fid_series(r)
    best = min(fid_vals) if fid_vals else None
    rows.append({
        "Run": run_labels[name],
        "FID@25": get_fid(r, 25),
        "FID@50": get_fid(r, 50),
        "FID@100": get_fid(r, 100),
        "Best FID": best,
        "Train (min)": (r['history'].get('train_time_s') or 0) / 60,
    })
df = pd.DataFrame(rows).sort_values("Best FID")
df.style.format({"FID@25": "{:.1f}", "FID@50": "{:.1f}", "FID@100": "{:.1f}",
                 "Best FID": "{:.1f}", "Train (min)": "{:.1f}"}, na_rep="—")
"""),
        md("## 3. FID curves — evolution"),
        code("""\
fig, ax = plt.subplots(figsize=(10, 5))
cmap = plt.cm.viridis
for i, name in enumerate(run_names):
    if name not in runs: continue
    epochs, fid_vals = fid_series(runs[name])
    ax.plot(epochs, fid_vals, "o-", label=run_labels[name],
            color=cmap(i / len(run_names)), linewidth=2, markersize=7)
ax.set_xlabel("Epoch"); ax.set_ylabel("FID")
ax.set_title("Phase 2 — FID curves")
ax.legend(); plt.tight_layout(); plt.show()
"""),
        md("## 4. Training dynamics"),
        code("""\
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
cmap = plt.cm.viridis
for i, name in enumerate(run_names):
    if name not in runs: continue
    h = runs[name]["history"]
    color = cmap(i / len(run_names))
    epochs = range(1, len(h["g_loss"]) + 1)
    axes[0].plot(epochs, h["g_loss"], color=color, label=run_labels[name], linewidth=1.2)
    if "w_dist" in h:
        axes[1].plot(epochs, h["w_dist"], color=color, label=run_labels[name], linewidth=1.2)
    elif "d_loss" in h:
        axes[1].plot(epochs, h["d_loss"], color=color, label=run_labels[name], linewidth=1.2, linestyle="--")
axes[0].set_title("Generator loss"); axes[0].set_xlabel("Epoch"); axes[0].legend(fontsize=8)
axes[1].set_title("Wasserstein distance / D loss"); axes[1].set_xlabel("Epoch"); axes[1].legend(fontsize=8)
plt.tight_layout(); plt.show()
"""),
        md(f"""\
## 5. Sample grids — epoch 100

- **2.1 and 2.2 collapsed** — vanilla DCGAN/WGAN-GP at this scale is still too weak,
  same failure mode as phase 1. Their best FIDs ({best_of('p2_1_dcgan')} and
  {best_of('p2_2_wgan')}) confirm this; the grids reflect it.
- **2.3 is the breakthrough** — spectral norm + GroupNorm + self-attention escape
  mode collapse and produce diverse, recognisable faces.
- **2.4 (128×128) regresses** — same architecture at higher resolution at fixed
  compute under-trains.
"""),
        code("""\
fig, axes = plt.subplots(1, 4, figsize=(16, 4.5))
for ax, name in zip(axes, run_names):
    img_path = SAMPLES / name / "epoch_0100.png"
    if img_path.exists():
        ax.imshow(mpimg.imread(str(img_path)))
    f100 = get_fid(runs[name], 100) if name in runs else None
    ax.set_title(f"{run_labels[name]}\\nFID@100={f100:.1f}" if f100 is not None else run_labels[name], fontsize=9)
    ax.axis("off")
plt.tight_layout(); plt.show()
"""),
        md("## 6. Progression — epoch 10 → 50 → 100"),
        code("""\
check_epochs = [10, 50, 100]
for name in run_names:
    if name not in runs: continue
    fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))
    for ax, e in zip(axes, check_epochs):
        p = SAMPLES / name / f"epoch_{e:04d}.png"
        if p.exists():
            ax.imshow(mpimg.imread(str(p)))
        f = get_fid(runs[name], e)
        ax.set_title(f"epoch {e}" + (f"\\nFID={f:.1f}" if f is not None else ""), fontsize=9)
        ax.axis("off")
    fig.suptitle(run_labels[name], fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),
        md("## 7. Pairwise comparison — what each step bought us"),
        code("""\
transitions = [
    ("2.1 → 2.2: BCE → Wasserstein", "p2_1_dcgan", "p2_2_wgan"),
    ("2.2 → 2.3: + SN + GroupNorm + Attn", "p2_2_wgan", "p2_3_wgan_sn_attn"),
    ("2.3 → 2.4: 64 → 128 resolution", "p2_3_wgan_sn_attn", "p2_4_wgan_sn_attn_128"),
]
for title, a, b in transitions:
    if a not in runs or b not in runs: continue
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    for ax, name in zip(axes, [a, b]):
        img_path = SAMPLES / name / "epoch_0100.png"
        if img_path.exists():
            ax.imshow(mpimg.imread(str(img_path)))
        f = get_fid(runs[name], 100)
        ax.set_title(f"{run_labels[name]}\\nFID@100={f:.1f}" if f is not None else run_labels[name], fontsize=10)
        ax.axis("off")
    fig.suptitle(title, fontsize=12, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),
        md(f"""\
## 8. Conclusions

| Step | Run | Best FID | Δ vs prev |
|---|---|---|---|
| 2.1 DCGAN baseline | `p2_1_dcgan` | {best_of('p2_1_dcgan')} | — |
| 2.2 + Wasserstein-GP | `p2_2_wgan` | {best_of('p2_2_wgan')} | {delta('p2_1_dcgan', 'p2_2_wgan')} |
| 2.3 + SN + Attn | `p2_3_wgan_sn_attn` | {best_of('p2_3_wgan_sn_attn')} | {delta('p2_2_wgan', 'p2_3_wgan_sn_attn')} |
| 2.4 + 128×128 | `p2_4_wgan_sn_attn_128` | {best_of('p2_4_wgan_sn_attn_128')} | {delta('p2_3_wgan_sn_attn', 'p2_4_wgan_sn_attn_128')} |

2.1 and 2.2 are extensions of the phase 1 collapse — vanilla DCGAN/WGAN-GP doesn't escape
mean-collapse just by changing the loss. The dramatic step is **2.2 → 2.3**: spectral
norm, GroupNorm and self-attention break out of collapse and start producing recognisable
faces (best FID drops from {best_of('p2_2_wgan')} to {best_of('p2_3_wgan_sn_attn')}).
2.4 (resolution bump) made things worse at fixed compute, so phase 5 retains 64×64.
**Selected GAN architecture for phase 5: WGAN-GP with spectral norm, GroupNorm and
self-attention at 64×64.**
"""),
    ]
    write_nb("phase2_analysis", cells)
# ─────────────────────────────────────────────────────────────────────────────
# PHASE 3 — VAE evolution
# ─────────────────────────────────────────────────────────────────────────────
def build_phase3():
    """Build ``phase3_analysis.ipynb``: VAE loss-stack evolution 3.1–3.3.

    FID in this phase is measured on prior samples. Numbers quoted in the
    markdown are baked in from the logs at build time; a missing log or FID
    entry renders as "—" instead of aborting the build.
    """
    runs = {n: load(n) for n in ["p3_1_vae", "p3_2_vae_perceptual", "p3_3_vae_patchgan"]}

    def rank(log):
        # Lowest logged FID; runs without any FID sort last.
        v = best_fid(log)[1]
        return v if v is not None else float("inf")

    def best_of(name):
        v = best_fid(runs[name])[1]
        return f"{v:.1f}" if v is not None else "—"

    def delta(prev, cur):
        # Change in best FID from `prev` to `cur` (negative = improvement).
        a, b = best_fid(runs[prev])[1], best_fid(runs[cur])[1]
        return f"{b - a:+.1f}" if a is not None and b is not None else "—"

    best_name = min(runs, key=lambda n: rank(runs[n]))
    best_ep, best_val = best_fid(runs[best_name])
    best_str = (f"best FID = {best_val:.1f} (epoch {best_ep})"
                if best_val is not None else "no FID logged")

    cells = [
        md(f"""\
# Phase 3 — VAE Evolution

Standard VAEs collapse to mean-blur. Phase 3 stacks losses on top of the basic
MSE+KL objective to recover detail.

| Run | Step |
|----------------------|------------------------------------------------------|
| `p3_1_vae` | MSE + KL only |
| `p3_2_vae_perceptual`| + VGG16 perceptual loss (`lambda_perceptual=0.1`) |
| `p3_3_vae_patchgan` | + PatchGAN adversarial loss (`lambda_adversarial=0.01`) |

**Headline result:** `{best_name}` — **{best_str}** (prior samples) over a 100-epoch run.
"""),
        md("""\
### Reference: phase 0 baseline (same family)

`p0_vae` was MSE+KL on raw un-aligned data. Prior samples were heavily blurred
mean-faces — the textbook VAE failure mode. Phase 3 keeps the encoder/decoder
architecture and pipeline fixed and shows that adding perceptual + adversarial
terms is what actually moves the needle.
"""),
        code(SHARED_IMPORTS),
        md("## 1. Load experiment logs"),
        code("""\
run_names = ["p3_1_vae", "p3_2_vae_perceptual", "p3_3_vae_patchgan"]
run_labels = {
    "p3_1_vae": "3.1 MSE + KL",
    "p3_2_vae_perceptual": "3.2 + Perceptual",
    "p3_3_vae_patchgan": "3.3 + PatchGAN",
}
runs = {name: load_log(name) for name in run_names}
runs = {k: v for k, v in runs.items() if v}
for n in run_names: print(f" {n}: {'OK' if n in runs else 'MISSING'}")
"""),
        md("## 2. FID comparison table (prior samples)"),
        code("""\
rows = []
for name in run_names:
    if name not in runs: continue
    r = runs[name]; h = r["history"]
    _, fid_vals = fid_series(r)
    rows.append({
        "Run": run_labels[name],
        "FID@50": get_fid(r, 50),
        "FID@100": get_fid(r, 100),
        "Best FID": min(fid_vals) if fid_vals else None,
        "Recon@100": h["recon_loss"][-1] if h.get("recon_loss") else None,
        "KL@100": h["kl_loss"][-1] if h.get("kl_loss") else None,
        "Train (min)": (h.get("train_time_s") or 0) / 60,
    })
df = pd.DataFrame(rows).sort_values("Best FID")
df.style.format({"FID@50": "{:.1f}", "FID@100": "{:.1f}", "Best FID": "{:.1f}",
                 "Recon@100": "{:.4f}", "KL@100": "{:.2f}", "Train (min)": "{:.1f}"}, na_rep="—")
"""),
        md("## 3. FID curves — evolution"),
        code("""\
fig, ax = plt.subplots(figsize=(10, 5))
cmap = plt.cm.plasma
for i, name in enumerate(run_names):
    if name not in runs: continue
    epochs, fid_vals = fid_series(runs[name])
    ax.plot(epochs, fid_vals, "o-", label=run_labels[name],
            color=cmap(i / len(run_names)), linewidth=2, markersize=7)
ax.set_xlabel("Epoch"); ax.set_ylabel("FID (prior samples)")
ax.set_title("Phase 3 — FID curves"); ax.legend()
plt.tight_layout(); plt.show()
"""),
        md("## 4. Training loss components"),
        code("""\
fig, axes = plt.subplots(1, 3, figsize=(16, 4))
cmap = plt.cm.plasma
for i, name in enumerate(run_names):
    if name not in runs: continue
    h = runs[name]["history"]; color = cmap(i / len(run_names))
    epochs = range(1, len(h["recon_loss"]) + 1)
    axes[0].plot(epochs, h["recon_loss"], color=color, label=run_labels[name])
    axes[1].plot(epochs, h["kl_loss"], color=color, label=run_labels[name])
    if "perc_loss" in h and any(h["perc_loss"]):
        axes[2].plot(epochs, h["perc_loss"], color=color, label=run_labels[name])
axes[0].set_title("Reconstruction (MSE)"); axes[0].set_xlabel("Epoch"); axes[0].legend(fontsize=8)
axes[1].set_title("KL divergence"); axes[1].set_xlabel("Epoch"); axes[1].legend(fontsize=8)
axes[2].set_title("Perceptual (VGG16)"); axes[2].set_xlabel("Epoch"); axes[2].legend(fontsize=8)
plt.tight_layout(); plt.show()
"""),
        md("## 5. Prior samples — epoch 100"),
        code("""\
fig, axes = plt.subplots(1, 3, figsize=(13, 4.5))
for ax, name in zip(axes, run_names):
    img_path = SAMPLES / name / "epoch_0100.png"
    if img_path.exists():
        ax.imshow(mpimg.imread(str(img_path)))
    f = get_fid(runs[name], 100) if name in runs else None
    ax.set_title(f"{run_labels[name]}\\nFID@100={f:.1f}" if f is not None else run_labels[name], fontsize=10)
    ax.axis("off")
fig.suptitle("Prior samples (decoded from N(0, I))", fontsize=12, fontweight="bold")
plt.tight_layout(); plt.show()
"""),
        md("## 6. Reconstructions — epoch 100"),
        code("""\
fig, axes = plt.subplots(1, 3, figsize=(13, 4.5))
for ax, name in zip(axes, run_names):
    img_path = SAMPLES / name / "epoch_0100_recon.png"
    if img_path.exists():
        ax.imshow(mpimg.imread(str(img_path)))
    ax.set_title(run_labels[name], fontsize=10)
    ax.axis("off")
fig.suptitle("Reconstructions (real / decoded interleaved)", fontsize=12, fontweight="bold")
plt.tight_layout(); plt.show()
"""),
        md("## 7. Progression — epoch 10 → 50 → 100 (prior samples)"),
        code("""\
check_epochs = [10, 50, 100]
for name in run_names:
    if name not in runs: continue
    fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))
    for ax, e in zip(axes, check_epochs):
        p = SAMPLES / name / f"epoch_{e:04d}.png"
        if p.exists():
            ax.imshow(mpimg.imread(str(p)))
        f = get_fid(runs[name], e)
        ax.set_title(f"epoch {e}" + (f"\\nFID={f:.1f}" if f is not None else ""), fontsize=9)
        ax.axis("off")
    fig.suptitle(run_labels[name], fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),
        md(f"""\
## 8. Conclusions

| Step | Run | Best FID | Δ vs prev |
|---|---|---|---|
| 3.1 MSE + KL | `p3_1_vae` | {best_of('p3_1_vae')} | — |
| 3.2 + Perceptual | `p3_2_vae_perceptual` | {best_of('p3_2_vae_perceptual')} | {delta('p3_1_vae', 'p3_2_vae_perceptual')} |
| 3.3 + PatchGAN | `p3_3_vae_patchgan` | {best_of('p3_3_vae_patchgan')} | {delta('p3_2_vae_perceptual', 'p3_3_vae_patchgan')} |

**Both additions help, monotonically.** Perceptual loss recovers high-frequency detail
(sharper hair/skin texture); PatchGAN adversarial loss further pushes the prior samples
toward the data manifold. The selected VAE recipe for phase 5 is
**MSE + 0.25·KL + 0.1·VGG-perceptual + 0.01·PatchGAN-adversarial**.
"""),
    ]
    write_nb("phase3_analysis", cells)
# ─────────────────────────────────────────────────────────────────────────────
# PHASE 4 — DDPM evolution
# ─────────────────────────────────────────────────────────────────────────────
def build_phase4():
    """Build ``phase4_analysis.ipynb``: the DDPM iteration story (runs 4.1 → 4.4).

    Best-FID values for the header and the conclusions table are read from the
    run logs at build time via ``load``/``best_fid``. A run whose log is missing,
    or which has no FID recorded yet, is rendered as "—" instead of aborting the
    whole build with a formatting ``TypeError``.
    """
    # (table label, run name) in iteration order; each row's Δ is vs the row above.
    steps = [
        ("4.1 linear / ε", "p4_1_ddpm_linear"),
        ("4.2 cosine / ε", "p4_2_ddpm_cosine"),
        ("4.3 cosine / v", "p4_3_ddpm_vpred"),
        ("4.4 wider net", "p4_4_ddpm_wider"),
    ]
    runs = {name: load(name) for _, name in steps}
    # best_fid(None) returns (None, None), so a missing log maps to a None FID here.
    fid = {name: best_fid(log)[1] for name, log in runs.items()}

    def fmt(v):
        """Format an FID with one decimal, or "—" when it is missing."""
        return "—" if v is None else f"{v:.1f}"

    best_name = min(runs, key=lambda n: fid[n] if fid[n] is not None else float("inf"))
    best_epoch, best_val = best_fid(runs[best_name])
    if best_val is None:
        headline = "**Headline result:** no FID has been logged yet for any phase-4 run."
    else:
        # The minimum FID is not necessarily at the last epoch, so report where it
        # occurred (first element of best_fid, unpacked as the epoch in phase 5 too).
        headline = (f"**Headline result:** `{best_name}` — **best FID = {best_val:.1f}** "
                    f"(epoch {best_epoch} of a 100-epoch run).")

    table_rows, prev = [], None
    for label, name in steps:
        if prev is None or fid[name] is None or fid[prev] is None:
            delta = "—"
        else:
            delta = f"{fid[name] - fid[prev]:+.1f}"
        table_rows.append(f"| {label} | `{name}` | {fmt(fid[name])} | {delta} |")
        prev = name
    conclusions_table = "\n".join(table_rows)

    cells = [
        md(f"""\
# Phase 4 — DDPM Evolution

Iterate on the diffusion side: noise schedule, prediction target, model width.
Sampling everywhere uses DDIM with 50 steps (matches training preview); FID
uses DDIM-100 against 5000 real images.

| Run                | Step                                                   |
|--------------------|--------------------------------------------------------|
| `p4_1_ddpm_linear` | Linear β-schedule, ε-prediction (DDPM baseline)        |
| `p4_2_ddpm_cosine` | Cosine β-schedule (Nichol & Dhariwal)                  |
| `p4_3_ddpm_vpred`  | + v-prediction target (Salimans & Ho)                  |
| `p4_4_ddpm_wider`  | + base_ch 128 → 192, num_res_blocks 2, attn at 32/16/8 |

{headline}
"""),
        md("""\
### Reference: phase 0 baseline (same family)

`p0_ddpm` was a vanilla DDPM (linear schedule, ε-prediction, base_ch=128) on raw
un-aligned data. Outputs were noisy face-shaped textures. Phase 4 fixes the
pipeline (aligned 64) and walks through the standard set of post-2020 DDPM
improvements one at a time.
"""),
        code(SHARED_IMPORTS),
        md("## 1. Load experiment logs"),
        code("""\
run_names = ["p4_1_ddpm_linear", "p4_2_ddpm_cosine", "p4_3_ddpm_vpred", "p4_4_ddpm_wider"]
run_labels = {
    "p4_1_ddpm_linear": "4.1 linear / ε",
    "p4_2_ddpm_cosine": "4.2 cosine / ε",
    "p4_3_ddpm_vpred":  "4.3 cosine / v",
    "p4_4_ddpm_wider":  "4.4 wider net",
}
runs = {n: load_log(n) for n in run_names}
runs = {k: v for k, v in runs.items() if v}
for n in run_names: print(f"  {n}: {'OK' if n in runs else 'MISSING'}")
"""),
        md("## 2. FID comparison table"),
        code("""\
rows = []
for name in run_names:
    if name not in runs: continue
    r = runs[name]; _, fid_vals = fid_series(r)
    losses = r["history"].get("loss") or [None]  # guard against an empty loss history
    rows.append({
        "Run": run_labels[name],
        "FID@25": get_fid(r, 25),
        "FID@50": get_fid(r, 50),
        "FID@100": get_fid(r, 100),
        "Best FID": min(fid_vals) if fid_vals else None,
        "Loss@100": losses[-1],
        "Train (min)": (r['history'].get('train_time_s') or 0) / 60,
    })
df = pd.DataFrame(rows).sort_values("Best FID")
df.style.format({"FID@25": "{:.1f}", "FID@50": "{:.1f}", "FID@100": "{:.1f}",
                 "Best FID": "{:.1f}", "Loss@100": "{:.4f}", "Train (min)": "{:.1f}"},
                na_rep="—")
"""),
        md("## 3. FID curves — evolution"),
        code("""\
fig, ax = plt.subplots(figsize=(10, 5))
cmap = plt.cm.cividis
for i, name in enumerate(run_names):
    if name not in runs: continue
    epochs, fid_vals = fid_series(runs[name])
    ax.plot(epochs, fid_vals, "o-", label=run_labels[name],
            color=cmap(i / len(run_names)), linewidth=2, markersize=7)
ax.set_xlabel("Epoch"); ax.set_ylabel("FID (DDIM-100)")
ax.set_title("Phase 4 — FID curves"); ax.legend()
plt.tight_layout(); plt.show()
"""),
        md("## 4. Training loss"),
        code("""\
fig, ax = plt.subplots(figsize=(10, 4))
cmap = plt.cm.cividis
for i, name in enumerate(run_names):
    if name not in runs: continue
    h = runs[name]["history"]
    epochs = range(1, len(h["loss"]) + 1)
    ax.plot(epochs, h["loss"], color=cmap(i / len(run_names)), label=run_labels[name], linewidth=1.3)
ax.set_xlabel("Epoch"); ax.set_ylabel("MSE on prediction target")
ax.set_title("Loss (note: ε-MSE and v-MSE are not directly comparable)")
ax.legend(); plt.tight_layout(); plt.show()
"""),
        md("## 5. Sample grids — epoch 100"),
        code("""\
fig, axes = plt.subplots(1, 4, figsize=(16, 4.5))
for ax, name in zip(axes, run_names):
    img_path = SAMPLES / name / "epoch_0100.png"
    if img_path.exists():
        ax.imshow(mpimg.imread(str(img_path)))
    f = get_fid(runs[name], 100) if name in runs else None
    ax.set_title(f"{run_labels[name]}\\nFID@100={f:.1f}" if f else run_labels[name], fontsize=9)
    ax.axis("off")
plt.tight_layout(); plt.show()
"""),
        md("## 6. Progression — epoch 10 → 50 → 100"),
        code("""\
check_epochs = [10, 50, 100]
for name in run_names:
    if name not in runs: continue
    fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))
    for ax, e in zip(axes, check_epochs):
        p = SAMPLES / name / f"epoch_{e:04d}.png"
        if p.exists():
            ax.imshow(mpimg.imread(str(p)))
        f = get_fid(runs[name], e)
        ax.set_title(f"epoch {e}" + (f"\\nFID={f:.1f}" if f else ""), fontsize=9)
        ax.axis("off")
    fig.suptitle(run_labels[name], fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),
        md("## 7. Noise schedule visualisation\n\nWhy cosine helps: linear allocates many timesteps near t=T where the image is already near-pure noise, wasting model capacity."),
        code("""\
import math
T = 1000; t = np.arange(T)
betas_lin = np.linspace(1e-4, 0.02, T)
ab_lin = np.cumprod(1 - betas_lin)
s = 0.008
# Evaluate f at t = 0..T (T + 1 points) so that beta_t = 1 - f(t) / f(t-1) yields
# exactly T betas, including the final timestep (Nichol & Dhariwal, 2021).
steps = np.arange(T + 1)
f = np.cos((steps / T + s) / (1 + s) * math.pi / 2) ** 2
f = f / f[0]
betas_cos = np.clip(1 - f[1:] / f[:-1], 0, 0.999)
ab_cos = np.cumprod(1 - betas_cos)
fig, axes = plt.subplots(1, 2, figsize=(13, 4))
axes[0].plot(t, ab_lin, label="linear", color="#5B8DB8", linewidth=2)
axes[0].plot(t, ab_cos, label="cosine", color="#E8705A", linewidth=2)
axes[0].set_xlabel("Timestep t"); axes[0].set_ylabel("ᾱ_t (signal fraction)")
axes[0].set_title("Cumulative signal preservation"); axes[0].legend()
axes[1].plot(betas_lin, label="linear β", color="#5B8DB8", linewidth=2)
axes[1].plot(betas_cos, label="cosine β", color="#E8705A", linewidth=2)
axes[1].set_xlabel("Timestep t"); axes[1].set_ylabel("β_t"); axes[1].set_title("β-schedule"); axes[1].legend()
plt.tight_layout(); plt.show()
"""),
        md(f"""\
## 8. Conclusions

| Step | Run | Best FID | Δ vs prev |
|---|---|---|---|
{conclusions_table}

The big jump is **4.2 → 4.3** (v-prediction): a cleaner training target for the same
model. The wider network in 4.4 buys a further drop at ~2× train time. Selected DDPM
recipe for phase 5: **cosine schedule, v-prediction, base_ch=192, attn at {{32,16,8}}**.
"""),
    ]
    write_nb("phase4_analysis", cells)
# ─────────────────────────────────────────────────────────────────────────────
# PHASE 5 — Cross-family
# ─────────────────────────────────────────────────────────────────────────────
def build_phase5():
    """Build ``phase5_analysis.ipynb``: head-to-head comparison of the best GAN, VAE and DDPM.

    Best FIDs and training times are read from the phase-5 logs at build time
    (``load``/``best_fid``/``time_min``) and baked into the headline and the
    conclusions table. A missing log or missing FID renders as "—" rather than
    aborting the build with a formatting ``TypeError``.
    """
    # Single source for the architecture descriptions used in the conclusions table.
    architectures = {
        "p5_gan": "WGAN-GP + SN + Attn",
        "p5_vae": "VAE + Perceptual + PatchGAN",
        "p5_ddpm": "DDPM cosine v-pred wider",
    }
    p5 = {n: load(n) for n in architectures}

    def fmt(v, unit=""):
        """Format a metric with one decimal (plus optional unit), or "—" when missing."""
        return "—" if v is None else f"{v:.1f}{unit}"

    rows = []
    for n, log in p5.items():
        e, v = best_fid(log)
        rows.append((n, v, e, time_min(log)))
    # Lowest FID first; runs without an FID sink to the bottom.
    rows.sort(key=lambda r: r[1] if r[1] is not None else float("inf"))
    headline = ", ".join(f"{n}={fmt(v)}" for n, v, _, _ in rows)
    summary_rows = "\n".join(
        f"| {n.split('_')[1].upper()} | {architectures[n]} | {fmt(v)} | {fmt(t, ' min')} |"
        for n, v, _, t in rows
    )

    cells = [
        md(f"""\
# Phase 5 — Cross-Family Comparison

Take the best recipe from each family (phases 2/3/4) and train each on identical data
to the same epoch budget. Per-family iteration analyses live in their own notebooks
(phase 2 for GAN, phase 3 for VAE, phase 4 for DDPM); this notebook is **only** about
comparing the three families head-to-head.

**Headline FIDs (best across training):** {headline}.
"""),
        code(SHARED_IMPORTS + """
FAMILIES = {
    "GAN":  {"p5": "p5_gan",  "color": "#5B8DB8", "label": "WGAN-GP + SN + Attn"},
    "VAE":  {"p5": "p5_vae",  "color": "#E8B85A", "label": "VAE + Perceptual + PatchGAN"},
    "DDPM": {"p5": "p5_ddpm", "color": "#E8705A", "label": "DDPM cosine v-pred wider"},
}
logs_p5 = {fam: load_log(info["p5"]) for fam, info in FAMILIES.items()}
"""),
        md("## 1. Quantitative summary"),
        code("""\
rows = []
for fam, info in FAMILIES.items():
    log = logs_p5[fam]
    if log is None: continue
    h = log["history"]; cfg = log["config"]
    _, fid_vals = fid_series(log)
    rows.append({
        "Family": fam,
        "Architecture": info["label"],
        "Resolution": f"{cfg.get('image_size')}×{cfg.get('image_size')}",
        "Epochs": len(h.get("loss") or h.get("g_loss") or h.get("recon_loss") or []),
        "Best FID": min(fid_vals) if fid_vals else None,
        "Last FID": fid_vals[-1] if fid_vals else None,
        "Train (min)": (h.get("train_time_s") or 0) / 60,
    })
df = pd.DataFrame(rows).sort_values("Best FID")
df.style.format({"Best FID": "{:.1f}", "Last FID": "{:.1f}", "Train (min)": "{:.1f}"}, na_rep="—")
"""),
        md("## 2. FID curves — all three families"),
        code("""\
fig, ax = plt.subplots(figsize=(10, 5))
for fam, info in FAMILIES.items():
    log = logs_p5[fam]
    if log is None: continue
    epochs, vals = fid_series(log)
    ax.plot(epochs, vals, "o-", color=info["color"], label=f"{fam} ({info['label']})",
            linewidth=2, markersize=7)
ax.set_xlabel("Epoch"); ax.set_ylabel("FID")
ax.set_title("Phase 5 — FID across families")
ax.legend(); plt.tight_layout(); plt.show()
"""),
        md("## 3. Final sample grids (last saved preview)"),
        code("""\
fig, axes = plt.subplots(1, 3, figsize=(16, 6))
for ax, (fam, info) in zip(axes, FAMILIES.items()):
    log = logs_p5[fam]
    samples_dir = SAMPLES / info["p5"]
    pngs = sorted(p for p in samples_dir.glob("epoch_*.png") if "_recon" not in p.stem)
    if not pngs:
        ax.set_title(f"{fam} (no samples)"); ax.axis("off"); continue
    img_path = pngs[-1]  # last preview = closest to final_ema
    ax.imshow(mpimg.imread(str(img_path)))
    best = None
    if log:
        _, fid_vals = fid_series(log)
        best = min(fid_vals) if fid_vals else None
    ax.set_title(f"{fam} — {info['label']}\\n{img_path.stem}" + (f"  best FID={best:.1f}" if best else ""), fontsize=10)
    ax.axis("off")
fig.suptitle("Final samples from each family", fontsize=13, fontweight="bold")
plt.tight_layout(); plt.show()
"""),
        md("## 4. Training progression — early → late"),
        code("""\
for fam, info in FAMILIES.items():
    log = logs_p5[fam]
    samples_dir = SAMPLES / info["p5"]
    pngs = sorted(p for p in samples_dir.glob("epoch_*.png") if "_recon" not in p.stem)
    if not pngs: continue
    # Pick ~4 evenly spaced previews
    n_pick = min(4, len(pngs))
    picks = [pngs[i * (len(pngs) - 1) // (n_pick - 1)] for i in range(n_pick)] if n_pick > 1 else pngs
    fig, axes = plt.subplots(1, len(picks), figsize=(4 * len(picks), 4))
    if len(picks) == 1: axes = [axes]
    for ax, p in zip(axes, picks):
        ax.imshow(mpimg.imread(str(p)))
        ep = int(p.stem.split("_")[1])
        f = get_fid(log, ep) if log else None
        ax.set_title(f"epoch {ep}" + (f"\\nFID={f:.1f}" if f else ""), fontsize=9)
        ax.axis("off")
    fig.suptitle(f"{fam} — {info['label']}", fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),
        md("## 5. Per-family training loss"),
        code("""\
fig, axes = plt.subplots(1, 3, figsize=(18, 4))
for ax, (fam, info) in zip(axes, FAMILIES.items()):
    log = logs_p5[fam]
    if not log: ax.set_title(f"{fam} (missing)"); ax.axis("off"); continue
    h = log["history"]; c = info["color"]
    legend_axes = [ax]
    if fam == "GAN":
        ax.plot(h["g_loss"], label="G loss", color=c, linewidth=1.2)
        if "w_dist" in h:
            ax.plot(h["w_dist"], label="W-dist", color=c, linewidth=1.2, linestyle="--")
        ax.set_ylabel("Loss / W-distance")
    elif fam == "VAE":
        ax.plot(h["recon_loss"], label="recon", color=c, linewidth=1.2)
        ax2 = ax.twinx()
        ax2.plot(h["kl_loss"], label="KL", color=c, alpha=0.5, linestyle="--")
        ax2.set_ylabel("KL")
        ax.set_ylabel("Recon")
        legend_axes.append(ax2)  # include the twin-axis KL curve in the legend
    else:  # DDPM
        ax.plot(h["loss"], label="loss", color=c, linewidth=1.2)
        ax.set_ylabel("MSE on v-prediction")
    handles, labels = [], []
    for a in legend_axes:
        a_handles, a_labels = a.get_legend_handles_labels()
        handles += a_handles; labels += a_labels
    ax.set_xlabel("Epoch"); ax.set_title(f"{fam}")
    ax.legend(handles, labels, loc="upper right", fontsize=8)
plt.tight_layout(); plt.show()
"""),
        md("""\
## 6. Latent interpolation — GAN and VAE

Smooth interpolation between two latent codes reveals whether the generator has
learned a continuous manifold rather than a sparse memorisation. The GAN and VAE
both decode a compact latent vector; DDPM has no such compact latent space (its
starting point is full-resolution noise), so this section is GAN/VAE only.

**Note on checkpoint loading:** the cell below uses the same checkpoint priority
as `tools/sampling.py` — `final_ema` first, then `best_ema` as fallback. This is
important: `best_ema` is the lowest-FID snapshot, which for slowly-converging
runs (e.g. DDPM) can be saved while the EMA shadow is still close to random init,
producing pure noise.
"""),
        code("""\
import sys
sys.path.insert(0, "..")
import torch
from src.models import get_model
from src.utils import load_config

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def load_ema_model(run_name, config_name):
    cfg = load_config(str(Path("../configs/phase5") / config_name))
    model_obj, _ = get_model(cfg)
    model = model_obj[0] if isinstance(model_obj, tuple) else model_obj
    # final_ema first — see note above
    for fname in [f"{run_name}_final_ema.pt", f"{run_name}_best_ema.pt"]:
        p = Path("../outputs/models") / fname
        if not p.exists(): continue
        sd = torch.load(p, map_location=DEVICE, weights_only=True)
        missing, unexpected = model.load_state_dict(sd, strict=False)
        if not missing and not unexpected:
            print(f"  loaded {fname}")
            return model.to(DEVICE).eval(), cfg
        print(f"  skipped {fname}: {len(missing)} missing / {len(unexpected)} unexpected keys")
    raise FileNotFoundError(f"No usable EMA checkpoint for {run_name}")

def slerp(z1, z2, t):
    z1n = z1 / z1.norm(); z2n = z2 / z2.norm()
    omega = torch.acos((z1n * z2n).sum().clamp(-1, 1))
    if omega.abs() < 1e-6: return (1 - t) * z1 + t * z2
    return (torch.sin((1 - t) * omega) / torch.sin(omega)) * z1 + \\
           (torch.sin(t * omega) / torch.sin(omega)) * z2
"""),
        code("""\
# ── GAN slerp interpolation ─────────────────────────────────────────────────
try:
    gan_model, gan_cfg = load_ema_model("p5_gan", "p5_gan.json")
    latent_dim = gan_cfg.get("latent_dim", 128)
    z1 = torch.randn(1, latent_dim, 1, 1, device=DEVICE)
    z2 = torch.randn(1, latent_dim, 1, 1, device=DEVICE)
    with torch.no_grad():
        zs = torch.cat([slerp(z1, z2, t) for t in torch.linspace(0, 1, 10)])
        imgs = gan_model(zs).clamp(-1, 1)
        imgs = (imgs + 1) / 2
    fig, axes = plt.subplots(1, 10, figsize=(20, 2.5))
    for ax, img in zip(axes, imgs.cpu()):
        ax.imshow(img.permute(1, 2, 0).numpy()); ax.axis("off")
    fig.suptitle("GAN — slerp latent interpolation z₁ → z₂", fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
except Exception as e:
    print(f"GAN interpolation skipped: {e}")
"""),
        code("""\
# ── VAE encode-interpolate-decode ───────────────────────────────────────────
try:
    from src.data import GeneratorDataset, get_transform
    vae_model, vae_cfg = load_ema_model("p5_vae", "p5_vae.json")
    ds = GeneratorDataset("../" + vae_cfg["data_dir"],
                          sources=vae_cfg.get("sources", ["wiki"]),
                          transform=get_transform(vae_cfg["image_size"], augment=False))
    real = torch.stack([ds[i] for i in range(2)]).to(DEVICE)
    with torch.no_grad():
        mu, logvar = vae_model.encode(real)
        z1, z2 = mu[0:1], mu[1:2]
        zs = torch.cat([(1 - t) * z1 + t * z2 for t in torch.linspace(0, 1, 10)])
        imgs = vae_model.decode(zs).clamp(-1, 1)
        imgs = (imgs + 1) / 2
    fig, axes = plt.subplots(1, 10, figsize=(20, 2.5))
    for ax, img in zip(axes, imgs.cpu()):
        ax.imshow(img.permute(1, 2, 0).numpy()); ax.axis("off")
    fig.suptitle("VAE — linear latent interpolation (encoded real images)", fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
except Exception as e:
    print(f"VAE interpolation skipped: {e}")
"""),
        md(f"""\
## 7. Conclusions

| Family | Architecture | Best FID | Train time |
|---|---|---|---|
{summary_rows}

The ranking is **DDPM ≈ GAN ≪ VAE** by FID (lower is better). DDPM and GAN trade
places depending on sampling settings (truncation, DDIM step count) but both clearly
beat the VAE here — the perceptual + adversarial losses help VAE reconstructions more
than they help its prior samples (which is what FID measures).

**Practical sampling notes** (encoded in `tools/sampling.py`):

- Load `*_final_ema.pt` rather than `*_best_ema.pt` — the latter can be saved very early
  for slowly-converging runs.
- DDPM sampling at DDIM-50 matches the training preview; DDIM-100 is for FID only.
- GAN truncation is **off by default** (matches training); enable for sharper but less
  diverse samples.
"""),
    ]
    write_nb("phase5_analysis", cells)
if __name__ == "__main__":
    # Regenerate every phase notebook, in phase order, from the current logs.
    print("Building notebooks...")
    for build in (build_phase0, build_phase1, build_phase2,
                  build_phase3, build_phase4, build_phase5):
        build()
    print("Done.")