1207 lines
54 KiB
Python
1207 lines
54 KiB
Python
"""
|
||
Build all phase analysis notebooks from a single source of truth.
|
||
|
||
Run from generator/notebooks/: python3 _build.py
|
||
|
||
Each phase notebook follows the same template (header → load → FID table →
|
||
FID curves → loss curves → samples → progression → conclusions). Phase 0 is
|
||
baseline-only (no FID, no iteration story). Phases 1–4 add a phase-0 same-family
|
||
reference. Phase 5 stays cross-family and adds latent interpolation.
|
||
|
||
Real metric values are pulled from outputs/logs/*.json at build time and
|
||
rendered into markdown headers and conclusions, so reports never drift from data.
|
||
"""
|
||
import json
|
||
from pathlib import Path
|
||
|
||
# ROOT is two directory levels above this file: the parent of the directory
# that holds this build script.
ROOT = Path(__file__).resolve().parents[1]
LOGS = ROOT.joinpath("outputs", "logs")  # per-run JSON logs read at build time
OUT = ROOT.joinpath("notebooks")  # destination of the generated .ipynb files
|
||
|
||
|
||
# ── notebook helpers ─────────────────────────────────────────────────────────
|
||
def md(text):
    """Return an nbformat-4 markdown cell whose source is *text*, split into lines with newlines kept."""
    cell = {"cell_type": "markdown", "metadata": {}}
    cell["source"] = text.splitlines(keepends=True)
    return cell
|
||
def code(text):
    """Return an unexecuted nbformat-4 code cell (no outputs) holding *text*."""
    lines = text.splitlines(keepends=True)
    return dict(cell_type="code", metadata={}, execution_count=None, outputs=[], source=lines)
|
||
|
||
def write_nb(name, cells):
    """Serialise *cells* as ``OUT/<name>.ipynb`` (nbformat 4.5) and print the path.

    nbformat 4.5 requires every cell to carry a unique ``id``. The ``md``/``code``
    helpers do not set one, so a deterministic index-based id is filled in here.
    An id a cell already has is kept. Deterministic ids keep regenerated
    notebooks diff-stable under version control.
    """
    nb = {
        # {"id": ..., **cell}: an id already present on the cell overrides the default.
        "cells": [{"id": f"cell-{i:03d}", **cell} for i, cell in enumerate(cells)],
        "metadata": {
            "kernelspec": {"name": "python3", "display_name": "Python 3"},
            "language_info": {"name": "python"},
        },
        "nbformat": 4,
        "nbformat_minor": 5,
    }
    OUT.mkdir(parents=True, exist_ok=True)
    path = OUT / f"{name}.ipynb"
    path.write_text(json.dumps(nb, indent=1) + "\n", encoding="utf-8")
    print(f" wrote {path.relative_to(ROOT)}")
|
||
|
||
|
||
# ── log-derived facts (computed once, baked into markdown) ───────────────────
|
||
def load(name):
    """Return the parsed JSON log ``LOGS/<name>.json``, or ``None`` if the file does not exist."""
    path = LOGS / f"{name}.json"
    if not path.exists():
        return None
    # read_text opens and closes the file itself; the previous json.load(open(p))
    # left the handle to the garbage collector.
    return json.loads(path.read_text(encoding="utf-8"))
|
||
|
||
def best_fid(log):
    """Return ``(epoch, fid)`` for the lowest FID recorded in *log*, else ``(None, None)``.

    Epoch keys are JSON strings and are compared as integers. When several
    epochs share the minimum, the earliest epoch is returned.
    """
    if not log:
        return None, None
    fid = log.get("history", {}).get("fid", {})
    if not fid:
        return None, None
    pairs = [(int(epoch), value) for epoch, value in fid.items()]
    # Rank by value first, then epoch, so ties resolve to the earliest epoch.
    return min(pairs, key=lambda pair: (pair[1], pair[0]))
|
||
|
||
def time_min(log):
    """Return the logged training time in minutes, or None when it is absent (or zero)."""
    seconds = (log or {}).get("history", {}).get("train_time_s")
    if not seconds:
        return None
    return seconds / 60
|
||
|
||
def get_fid(log, epoch):
    """Return the FID logged at *epoch*, or None when the log or that epoch is missing.

    Build-time counterpart of the notebook-side ``get_fid`` helper; accepts ``log=None``.
    """
    history = log.get("history", {}) if log else {}
    by_epoch = history.get("fid", {})
    return by_epoch.get(str(epoch))
|
||
|
||
|
||
# ── shared imports cell (all phases use the same setup) ──────────────────────
|
||
# Setup cell shared by every phase notebook: imports, output paths and the
# runtime log helpers. load_log closes its file handle, and get_fid/fid_series
# accept a missing (None) log, matching the build-time get_fid.
SHARED_IMPORTS = """\
import json
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import pandas as pd

plt.rcParams.update({"figure.dpi": 120, "font.size": 10})

OUTPUTS = Path("../outputs")
LOGS = OUTPUTS / "logs"
SAMPLES = OUTPUTS / "samples"


def load_log(name):
    p = LOGS / f"{name}.json"
    if not p.exists():
        return None
    with open(p, encoding="utf-8") as f:
        return json.load(f)

def get_fid(log, epoch):
    fid = (log or {}).get("history", {}).get("fid", {})
    return fid.get(str(epoch))

def fid_series(log):
    fid = (log or {}).get("history", {}).get("fid", {})
    items = sorted((int(k), v) for k, v in fid.items())
    return [e for e, _ in items], [v for _, v in items]
"""
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# PHASE 0 — Baseline
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
def build_phase0():
    """Build ``phase0_analysis.ipynb``: baseline loss curves, sample grids and conclusions.

    Phase 0 predates FID logging, so only per-epoch loss lengths are read from the
    logs at build time. A missing log renders as "—" in the run table, and the
    generated loss-curve cell skips panels whose log is absent instead of raising
    KeyError.
    """
    p0 = {n: load(n) for n in ["p0_wgan", "p0_vae", "p0_ddpm", "p0_ddpm_small"]}

    def n_epochs(name, key):
        # Length of one per-epoch loss series, or an em dash when the log is missing.
        log = p0[name]
        return len(log["history"][key]) if log else "—"

    cells = [
        md(f"""\
# Phase 0 — Baseline

The starting point. Three model families trained at minimal scale on the raw (un-aligned)
dataset to establish a reference point and verify the training loops work end-to-end.
No FID was tracked at this stage — these runs predate the FID instrumentation added in
phase 1, so the only signals are training loss and visual sample quality.

| Run | Model | Epochs | Notes |
|-----------------------|--------------|--------|------------------------------------------|
| `p0_wgan` | WGAN-GP | {n_epochs('p0_wgan', 'g_loss')} | basic generator/critic, no SN/attention |
| `p0_vae` | VAE | {n_epochs('p0_vae', 'loss')} | MSE + KL only |
| `p0_ddpm` | DDPM | {n_epochs('p0_ddpm', 'loss')} | linear schedule, ε-prediction |
| `p0_ddpm_small` | DDPM (small) | {n_epochs('p0_ddpm_small', 'loss')} | reduced base channels — sanity check |

Phase 0 outputs are **deliberately rough**. The point is to have something to compare
against in phases 1–4 so improvements have a baseline.
"""),
        code(SHARED_IMPORTS),
        md("## 1. Training loss curves"),
        code("""\
runs = {n: load_log(n) for n in ["p0_wgan", "p0_vae", "p0_ddpm", "p0_ddpm_small"]}
runs = {k: v for k, v in runs.items() if v}

fig, axes = plt.subplots(1, 3, figsize=(16, 4))

# WGAN: generator vs critic loss
if "p0_wgan" in runs:
    h = runs["p0_wgan"]["history"]
    ep = range(1, len(h["g_loss"]) + 1)
    axes[0].plot(ep, h["g_loss"], label="G loss", color="#5B8DB8")
    axes[0].plot(ep, h["c_loss"], label="C loss", color="#E8705A", alpha=0.7)
    axes[0].set_title("p0_wgan — generator vs critic loss")
    axes[0].set_xlabel("Epoch"); axes[0].legend()
else:
    axes[0].set_title("p0_wgan — (missing)")

# VAE: reconstruction vs KL components
if "p0_vae" in runs:
    h = runs["p0_vae"]["history"]
    ep = range(1, len(h["loss"]) + 1)
    axes[1].plot(ep, h["recon_loss"], label="Recon (MSE)", color="#5B8DB8")
    axes[1].plot(ep, h["kl_loss"], label="KL", color="#E8705A")
    axes[1].set_title("p0_vae — recon vs KL")
    axes[1].set_xlabel("Epoch"); axes[1].legend()
else:
    axes[1].set_title("p0_vae — (missing)")

# DDPM: noise prediction MSE (small variant overlaid when available)
if "p0_ddpm" in runs:
    h = runs["p0_ddpm"]["history"]
    ep = range(1, len(h["loss"]) + 1)
    axes[2].plot(ep, h["loss"], color="#5B8DB8", label="ε-MSE")
    if "p0_ddpm_small" in runs:
        h2 = runs["p0_ddpm_small"]["history"]
        axes[2].plot(range(1, len(h2["loss"]) + 1), h2["loss"], color="#E8705A", linestyle="--", label="small variant")
    axes[2].set_title("p0_ddpm — noise prediction loss")
    axes[2].set_xlabel("Epoch"); axes[2].legend()
else:
    axes[2].set_title("p0_ddpm — (missing)")

plt.tight_layout(); plt.show()
"""),
        md("## 2. Final sample grids\n\nLast available preview from each run."),
        code("""\
last_epochs = {"p0_wgan": 200, "p0_vae": 100, "p0_ddpm": 200, "p0_ddpm_small": 100}

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for ax, (name, ep) in zip(axes, last_epochs.items()):
    img_path = SAMPLES / name / f"epoch_{ep:04d}.png"
    if img_path.exists():
        ax.imshow(mpimg.imread(str(img_path)))
        ax.set_title(f"{name}\\n(epoch {ep})", fontsize=10)
    else:
        ax.set_title(f"{name}\\n(missing)", fontsize=10)
    ax.axis("off")
plt.tight_layout(); plt.show()
"""),
        md("## 3. Progression — early vs late"),
        code("""\
checkpoints = {
    "p0_wgan": [50, 100, 200],
    "p0_vae": [25, 50, 100],
    "p0_ddpm": [50, 100, 200],
}

for name, eps in checkpoints.items():
    fig, axes = plt.subplots(1, len(eps), figsize=(12, 4))
    for ax, e in zip(axes, eps):
        p = SAMPLES / name / f"epoch_{e:04d}.png"
        if p.exists():
            ax.imshow(mpimg.imread(str(p))); ax.set_title(f"epoch {e}", fontsize=9)
        else:
            ax.text(0.5, 0.5, f"epoch {e}\\n(missing)", ha="center", va="center", transform=ax.transAxes)
        ax.axis("off")
    fig.suptitle(name, fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),
        md("""\
## 4. Conclusions

Phase 0 ran with the un-aligned raw dataset and bare-bones model variants. Outputs are
**face-shaped blobs** at best — enough to confirm the training loops converge, far from
recognisable.

- **WGAN** — generates colour blobs with rough oval shapes. No facial features resolved.
- **VAE** — heavily blurred mean-images; the MSE+KL objective pulls reconstructions toward
the dataset average.
- **DDPM** — better local texture than the others but still very noisy faces; needs a
stronger backbone and noise schedule (addressed in phase 4).

The clear takeaways feeding into phase 1 onward:

1. Data quality matters — un-aligned raw images caused most of the visual blur.
Phase 1 ablates this directly (alignment, augmentation, dataset mixing).
2. The vanilla VAE objective alone collapses to averages; phase 3 adds perceptual + adversarial
losses to restore detail.
3. The minimal DDPM is the strongest of the three at this scale; phase 4 extends it
(cosine schedule, v-prediction, wider backbone).
"""),
    ]
    write_nb("phase0_analysis", cells)
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# PHASE 1 — Pipeline ablations (DCGAN proxy)
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
def build_phase1():
    """Build ``phase1_analysis.ipynb`` for the DCGAN data-pipeline ablations.

    Headline and conclusion numbers are read from the phase-1 logs at build time.
    A missing log or FID entry renders as "—" instead of aborting the build. The
    headline labels the best value with the epoch it was actually reached at,
    because ``best_fid`` scans every logged epoch, not just epoch 50.
    """
    runs = {n: load(n) for n in ["p1a_dcgan_64", "p1a_dcgan_128", "p1b_dcgan_full",
                                 "p1b_dcgan_aligned", "p1c_dcgan_hflip",
                                 "p1c_dcgan_full_aug", "p1d_dcgan_combined"]}
    p0_gan = load("p0_wgan")

    def fid50(name):
        # One-decimal FID at epoch 50, or an em dash when it was never logged.
        value = get_fid(runs[name], 50)
        return "—" if value is None else f"{value:.1f}"

    # Rank only runs that actually logged a FID. A `or 9e9` sentinel would also
    # throw away a legitimate FID of 0.0.
    bests = {n: best_fid(log) for n, log in runs.items()}
    scored = {n: ev for n, ev in bests.items() if ev[1] is not None}
    if scored:
        best_name = min(scored, key=lambda n: scored[n][1])
        best_ep, best_val = scored[best_name]
        headline = f"`{best_name}` — **FID@{best_ep} = {best_val:.1f}**"
        lowest = f"**`{best_name}` at FID@{best_ep} = {best_val:.1f}**"
    else:
        headline = "no phase 1 FID logged yet"
        lowest = "**n/a** (no FID logged)"

    cells = [
        md(f"""\
# Phase 1 — Pipeline Selection (DCGAN ablations)

Goal: with a cheap proxy (vanilla DCGAN at 64×64, 50 epochs), isolate which **data-pipeline
choices** matter so phases 2–4 can train on the best preprocessing without burning compute
on dead-end variants.

Four ablations, one factor each:

- **1A** — Resolution: 64×64 vs 128×128
- **1B** — Face crop + alignment: full image vs MTCNN-aligned
- **1C** — Augmentation: H-flip only vs H-flip + rotation + colour jitter
- **1D** — Combined dataset: aligned only vs aligned + raw mixed

**Headline result:** {headline}. The pipeline carried
forward into all later phases is the one this experiment selected: MTCNN-aligned crops,
64×64, full augmentation for GANs (H-flip-only kept as a safer default for VAE/DDPM),
aligned-only (no mixing).
"""),
        md(f"""\
### Reference: phase 0 baseline (same family)

The phase 0 WGAN-GP (`p0_wgan`) trained on raw un-aligned images for {len(p0_gan['history']['g_loss']) if p0_gan else 200} epochs
without any pipeline tuning — also collapsed. Phase 1 below uses the same model class
with the data pipeline systematically varied; the architecture limitation is constant.
"""),
        code(SHARED_IMPORTS),
        md("## 1. Load all experiment logs"),
        code("""\
run_names = sorted(p.stem for p in LOGS.glob("p1*.json"))
runs = {name: load_log(name) for name in run_names}
runs = {k: v for k, v in runs.items() if v}
run_names = list(runs)  # keep only logs that actually loaded

print(f"Loaded {len(runs)} experiments:")
for name in run_names: print(f" {name}")
"""),
        code("""\
experiment_groups = {
    "1A — Resolution": {"p1a_dcgan_64": "64×64 (raw)",
                        "p1a_dcgan_128": "128×128 (raw)"},
    "1B — Alignment": {"p1b_dcgan_full": "Full image (raw)",
                       "p1b_dcgan_aligned": "MTCNN-aligned"},
    "1C — Augmentation": {"p1c_dcgan_hflip": "H-flip only",
                          "p1c_dcgan_full_aug": "H-flip + rot + colour"},
    "1D — Dataset mixing": {"p1b_dcgan_aligned": "Aligned only",
                            "p1d_dcgan_combined": "Aligned + raw mixed"},
}
"""),
        md("## 2. FID comparison table"),
        code("""\
rows = []
for name in run_names:
    r = runs[name]; cfg = r["config"]
    rows.append({
        "Experiment": name,
        "Size": f"{cfg.get('image_size')}×{cfg.get('image_size')}",
        "Augment": cfg.get("augment", False),
        "FID@25": get_fid(r, 25),
        "FID@50": get_fid(r, 50),
        "G loss (ep50)": r["history"]["g_loss"][-1],
        "D loss (ep50)": r["history"]["d_loss"][-1],
    })
df = pd.DataFrame(rows).sort_values("FID@50")
df.style.format({"FID@25": "{:.1f}", "FID@50": "{:.1f}",
                 "G loss (ep50)": "{:.3f}", "D loss (ep50)": "{:.3f}"}, na_rep="—")
"""),
        code("""\
fig, ax = plt.subplots(figsize=(10, 5))
labels = df["Experiment"].values
fid25 = df["FID@25"].values
fid50 = df["FID@50"].values
x = np.arange(len(labels)); w = 0.35

ax.bar(x - w/2, fid25, w, label="FID @ 25", color="#5B8DB8", alpha=0.85)
ax.bar(x + w/2, fid50, w, label="FID @ 50", color="#E8705A", alpha=0.85)
ax.set_ylabel("FID (lower is better)")
ax.set_title("Phase 1 — FID across all pipeline ablations")
ax.set_xticks(x); ax.set_xticklabels(labels, rotation=30, ha="right")
ax.legend(); plt.tight_layout(); plt.show()
"""),
        md("## 3. Per-group comparisons"),
        code("""\
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()
colors = ["#5B8DB8", "#E8705A"]

for idx, (group_title, experiments) in enumerate(experiment_groups.items()):
    ax = axes[idx]
    for i, (run_name, label) in enumerate(experiments.items()):
        if run_name not in runs: continue
        epochs, fid_vals = fid_series(runs[run_name])
        f50 = get_fid(runs[run_name], 50)
        f50_txt = f"{f50:.1f}" if f50 is not None else "n/a"
        ax.plot(epochs, fid_vals, "o-",
                label=f"{label} (FID@50={f50_txt})",
                color=colors[i], linewidth=2, markersize=8)
    ax.set_xlabel("Epoch"); ax.set_ylabel("FID")
    ax.set_title(group_title); ax.legend(fontsize=9)
    ax.set_xlim(20, 55)
fig.suptitle("FID per ablation group", fontsize=14, fontweight="bold", y=1.01)
plt.tight_layout(); plt.show()
"""),
        md("""\
## 4. Data pipeline visualisation

What each ablation actually changes — shown on the input data the model sees.
"""),
        code("""\
import random
from PIL import Image
import torch
import torchvision.transforms as T

# Seed both RNGs and sort every directory listing: iterdir()/glob() order is
# filesystem-dependent, so seeding alone does not make the sampled images reproducible.
random.seed(0)
torch.manual_seed(0)
RAW = Path("../../data/wiki")
ALIGNED = Path("../../cropped/generator/wiki")

def sample_paths(root, k=4):
    shards = sorted(d for d in root.iterdir() if d.is_dir())
    files = []
    for s in random.sample(shards, min(8, len(shards))):
        files += sorted(s.glob("*.jpg"))[:50]
    return random.sample(files, min(k, len(files)))

def matched_pairs(k=4):
    shards = sorted(d.name for d in ALIGNED.iterdir() if d.is_dir() and (RAW / d.name).is_dir())
    pairs = []
    for shard in random.sample(shards, min(8, len(shards))):
        for ali in sorted((ALIGNED / shard).glob("*.jpg")):
            raw = RAW / shard / ali.name
            if raw.exists():
                pairs.append((raw, ali))
            if len(pairs) >= 50: break
        if len(pairs) >= 50: break
    return random.sample(pairs, min(k, len(pairs)))

def show(ax, img, title=None):
    ax.imshow(img); ax.axis("off")
    if title: ax.set_title(title, fontsize=9)
"""),
        md("### 4A — Resolution\n\nSame raw image at 64×64 and 128×128. 4× more pixels at 128 — too much for a vanilla DCGAN at 50 epochs."),
        code("""\
paths = sample_paths(RAW, k=4)
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for col, p in enumerate(paths):
    img = Image.open(p).convert("RGB")
    show(axes[0][col], T.CenterCrop(min(img.size))(img).resize((64, 64)), "64×64")
    show(axes[1][col], T.CenterCrop(min(img.size))(img).resize((128, 128)), "128×128")
fig.suptitle("1A — Resolution: same image at two scales", fontsize=12, fontweight="bold")
plt.tight_layout(); plt.show()
"""),
        md("### 4B — Alignment\n\nRaw vs MTCNN-aligned 64×64 crops. Alignment removes pose/scale/translation variance so the model only has to learn identity."),
        code("""\
pairs = matched_pairs(k=4)
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for col, (raw_p, ali_p) in enumerate(pairs):
    raw_img = Image.open(raw_p).convert("RGB")
    show(axes[0][col], T.CenterCrop(min(raw_img.size))(raw_img).resize((128, 128)), "raw")
    show(axes[1][col], Image.open(ali_p).convert("RGB"), "MTCNN-aligned")
fig.suptitle("1B — Alignment: same source image, raw vs MTCNN-aligned", fontsize=12, fontweight="bold")
plt.tight_layout(); plt.show()
"""),
        md("### 4C — Augmentation\n\nOne aligned image, three variants: original, hflip, and full augmentation (hflip + rotation ±5° + brightness/contrast/saturation jitter)."),
        code("""\
src = sample_paths(ALIGNED, k=1)
if src:
    img = Image.open(src[0]).convert("RGB").resize((128, 128))
    none = T.Compose([])
    hflip = T.Compose([T.RandomHorizontalFlip(p=1.0)])
    full = T.Compose([
        T.RandomHorizontalFlip(p=1.0),
        T.RandomRotation(degrees=5, interpolation=T.InterpolationMode.BILINEAR),
        T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05),
    ])
    fig, axes = plt.subplots(1, 6, figsize=(15, 3))
    show(axes[0], none(img), "original")
    show(axes[1], hflip(img), "hflip")
    for i, ax in enumerate(axes[2:]):
        show(ax, full(img), f"full aug #{i+1}")
    fig.suptitle("1C — Augmentation: original vs hflip vs full augmentation", fontsize=12, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),
        md("### 4D — Dataset mixing\n\nMixing raw + aligned roughly doubles within-batch variance, splitting generator capacity across two distributions."),
        code("""\
pairs = matched_pairs(k=4)
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for col, (raw_p, ali_p) in enumerate(pairs):
    raw_img = Image.open(raw_p).convert("RGB")
    show(axes[0][col], T.CenterCrop(min(raw_img.size))(raw_img).resize((128, 128)), "raw (mixed in)")
    show(axes[1][col], Image.open(ali_p).convert("RGB"), "aligned")
fig.suptitle("1D — Mixing: same source image, raw vs aligned", fontsize=12, fontweight="bold")
plt.tight_layout(); plt.show()
"""),
        md(f"""\
## 5. Conclusions

Lowest FID of phase 1: {lowest} — but every model
in this phase collapsed. The numbers below are pipeline-ranking signal, not model-quality
claims.

| Ablation | Winner | FID@50 | Loser FID@50 | Δ |
|---|---|---|---|---|
| 1A — Resolution | 64×64 | {fid50('p1a_dcgan_64')} | {fid50('p1a_dcgan_128')} (128×128) | DCGAN lacks capacity at 128×128 in 50 ep |
| 1B — Alignment | MTCNN-aligned | {fid50('p1b_dcgan_aligned')} | {fid50('p1b_dcgan_full')} (full) | **largest single lever** |
| 1C — Augmentation | H-flip + rot + colour | {fid50('p1c_dcgan_full_aug')} | {fid50('p1c_dcgan_hflip')} (H-flip) | moderate gain; per-family validation needed |
| 1D — Dataset mixing | Aligned only | {fid50('p1b_dcgan_aligned')} | {fid50('p1d_dcgan_combined')} (mixed) | mixing raw+aligned doubles variance |

**Pipeline carried forward to phases 2–5:** MTCNN-aligned crops, 64×64, full augmentation
for GANs (H-flip-only kept as a safer default for VAE/DDPM), aligned-only (no mixing).

The pipeline question is *answered*. The model-quality question is **not** — every
phase 1 run collapsed. Phase 2 fixes that by upgrading the architecture (Wasserstein-GP,
spectral norm, GroupNorm, self-attention) on the now-locked pipeline.
"""),
    ]
    write_nb("phase1_analysis", cells)
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# PHASE 2 — GAN architecture/objective evolution
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
def build_phase2():
    """Build ``phase2_analysis.ipynb`` for the GAN objective/architecture ablations.

    Every FID quoted in the markdown is read from the logs at build time. This
    includes the phase-1 reference in the protocol warning, so the prose cannot
    drift from the data. Missing logs or FID entries render as "—" instead of
    aborting the build.
    """
    names = ["p2_1_dcgan", "p2_2_wgan", "p2_3_wgan_sn_attn", "p2_4_wgan_sn_attn_128"]
    runs = {n: load(n) for n in names}
    best = {n: best_fid(runs[n])[1] for n in names}
    b21, b22, b23, b24 = (best[n] for n in names)

    def fmt(value, spec=".1f"):
        # Format a FID value, or an em dash when it was never logged.
        return "—" if value is None else format(value, spec)

    def delta(cur, prev):
        # Signed change in best FID versus the previous step ("—" if either is missing).
        return "—" if cur is None or prev is None else f"{cur - prev:+.1f}"

    # Headline: lowest best-FID run. Runs without any FID are skipped instead of
    # being ranked through a 9e9 sentinel.
    scored = [n for n in names if best[n] is not None]
    best_name = min(scored, key=best.get) if scored else None
    best_val = best[best_name] if best_name else None
    if best_name:
        best_ep = best_fid(runs[best_name])[0]
        headline = f"`{best_name}` — **best FID = {best_val:.1f}** (reached at epoch {best_ep})"
    else:
        headline = "no phase 2 FID logged yet"

    # Phase-1 reference for the protocol warning: lowest best-epoch FID across the
    # phase-1 ablation logs.
    p1_names = ["p1a_dcgan_64", "p1a_dcgan_128", "p1b_dcgan_full", "p1b_dcgan_aligned",
                "p1c_dcgan_hflip", "p1c_dcgan_full_aug", "p1d_dcgan_combined"]
    p1_best = {n: best_fid(load(n))[1] for n in p1_names}
    p1_scored = [n for n in p1_names if p1_best[n] is not None]
    p1_name = min(p1_scored, key=p1_best.get) if p1_scored else None
    p1_ref = f"FID {p1_best[p1_name]:.0f} (`{p1_name}`)" if p1_name else "not logged"

    # 2.2 → 2.3 jump, typeset with a true minus sign like the surrounding prose.
    jump = "—" if b22 is None or b23 is None else f"{b23 - b22:+.0f}".replace("-", "−")

    cells = [
        md(f"""\
# Phase 2 — GAN Evolution

With the data pipeline locked (phase 1), iterate on the **GAN itself**: objective,
normalisation, attention, resolution. All four runs use the aligned pipeline (64×64,
except 2.4 at 128×128); only the model, training recipe and — for 2.4 — resolution change.

| Run | Step |
|-------------------------|------------------------------------------------|
| `p2_1_dcgan` | Phase 1 best, retrained 100 epochs |
| `p2_2_wgan` | BCE → Wasserstein-GP (n_critic=2, β=(0,0.9)) |
| `p2_3_wgan_sn_attn` | + spectral norm + GroupNorm + self-attention |
| `p2_4_wgan_sn_attn_128` | same as 2.3 but at 128×128 |

**Headline result:** {headline}.
"""),
        md(f"""\
> ### ⚠ FID is not comparable across phases
>
> Phase 1's "best" was {p1_ref}. Phase 2's "best" is FID {fmt(best_val, '.0f')}.
> **This is not a regression.** The two numbers were computed under different
> protocols:
>
> - Phase 1 used a quick proxy FID for fast pipeline ablation, with a smaller
> real-image reference set, on the un-augmented validation split.
> - Phase 2 uses the project's standard FID protocol — 5000 aligned 64×64 real
> images from the matched augmentation pipeline (`fid_n_real: 5000`).
>
> Within phase 2 the deltas are meaningful (2.2 → 2.3 = **{jump} FID** is a real
> architecture jump). Don't compare phase 1 vs phase 2 numbers absolutely —
> only compare within a phase, or against phase 5 which uses the same protocol.
"""),
        md("""\
### Reference: phase 0 baseline (same family)

`p0_wgan` was the un-aligned, no-augmentation, basic-architecture WGAN-GP — face blobs
with no recognisable features (no FID logged). Phase 2 below shows what happens once
the pipeline is fixed and the model is allowed to evolve.
"""),
        code(SHARED_IMPORTS),
        md("## 1. Load experiment logs"),
        code("""\
run_names = ["p2_1_dcgan", "p2_2_wgan", "p2_3_wgan_sn_attn", "p2_4_wgan_sn_attn_128"]
run_labels = {
    "p2_1_dcgan": "2.1 DCGAN (BCE)",
    "p2_2_wgan": "2.2 WGAN-GP",
    "p2_3_wgan_sn_attn": "2.3 + SN + Attn",
    "p2_4_wgan_sn_attn_128": "2.4 + 128×128",
}
runs = {name: load_log(name) for name in run_names}
runs = {k: v for k, v in runs.items() if v}
for n in run_names:
    if n in runs: print(f" {n}: {len(runs[n]['history']['g_loss'])} epochs")
    else: print(f" {n}: MISSING")
"""),
        md("## 2. FID comparison table"),
        code("""\
rows = []
for name in run_names:
    if name not in runs: continue
    r = runs[name]
    epochs, fid_vals = fid_series(r)
    best = min(fid_vals) if fid_vals else None
    rows.append({
        "Run": run_labels[name],
        "FID@25": get_fid(r, 25),
        "FID@50": get_fid(r, 50),
        "FID@100": get_fid(r, 100),
        "Best FID": best,
        "Train (min)": (r['history'].get('train_time_s') or 0) / 60,
    })
df = pd.DataFrame(rows).sort_values("Best FID")
df.style.format({"FID@25": "{:.1f}", "FID@50": "{:.1f}", "FID@100": "{:.1f}",
                 "Best FID": "{:.1f}", "Train (min)": "{:.1f}"}, na_rep="—")
"""),
        md("## 3. FID curves — evolution"),
        code("""\
fig, ax = plt.subplots(figsize=(10, 5))
cmap = plt.cm.viridis
for i, name in enumerate(run_names):
    if name not in runs: continue
    epochs, fid_vals = fid_series(runs[name])
    ax.plot(epochs, fid_vals, "o-", label=run_labels[name],
            color=cmap(i / len(run_names)), linewidth=2, markersize=7)
ax.set_xlabel("Epoch"); ax.set_ylabel("FID")
ax.set_title("Phase 2 — FID curves")
ax.legend(); plt.tight_layout(); plt.show()
"""),
        md("## 4. Training dynamics"),
        code("""\
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
cmap = plt.cm.viridis
for i, name in enumerate(run_names):
    if name not in runs: continue
    h = runs[name]["history"]
    color = cmap(i / len(run_names))
    epochs = range(1, len(h["g_loss"]) + 1)
    axes[0].plot(epochs, h["g_loss"], color=color, label=run_labels[name], linewidth=1.2)
    if "w_dist" in h:
        axes[1].plot(epochs, h["w_dist"], color=color, label=run_labels[name], linewidth=1.2)
    elif "d_loss" in h:
        axes[1].plot(epochs, h["d_loss"], color=color, label=run_labels[name], linewidth=1.2, linestyle="--")
axes[0].set_title("Generator loss"); axes[0].set_xlabel("Epoch"); axes[0].legend(fontsize=8)
axes[1].set_title("Wasserstein distance / D loss"); axes[1].set_xlabel("Epoch"); axes[1].legend(fontsize=8)
plt.tight_layout(); plt.show()
"""),
        md(f"""\
## 5. Sample grids — epoch 100

- **2.1 and 2.2 collapsed** — vanilla DCGAN/WGAN-GP at this scale is still too weak,
same failure mode as phase 1. Their best FIDs ({fmt(b21, '.0f')} and {fmt(b22, '.0f')}) confirm this; the grids reflect it.
- **2.3 is the breakthrough** — spectral norm + GroupNorm + self-attention escape
mode collapse and produce diverse, recognisable faces.
- **2.4 (128×128) regresses** — same architecture at higher resolution at fixed
compute under-trains.
"""),
        code("""\
fig, axes = plt.subplots(1, 4, figsize=(16, 4.5))
for ax, name in zip(axes, run_names):
    img_path = SAMPLES / name / "epoch_0100.png"
    if img_path.exists():
        ax.imshow(mpimg.imread(str(img_path)))
    f100 = get_fid(runs.get(name, {}), 100) if name in runs else None
    ax.set_title(f"{run_labels[name]}\\nFID@100={f100:.1f}" if f100 else run_labels[name], fontsize=9)
    ax.axis("off")
plt.tight_layout(); plt.show()
"""),
        md("## 6. Progression — epoch 10 → 50 → 100"),
        code("""\
check_epochs = [10, 50, 100]
for name in run_names:
    if name not in runs: continue
    fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))
    for ax, e in zip(axes, check_epochs):
        p = SAMPLES / name / f"epoch_{e:04d}.png"
        if p.exists():
            ax.imshow(mpimg.imread(str(p)))
        f = get_fid(runs[name], e)
        ax.set_title(f"epoch {e}" + (f"\\nFID={f:.1f}" if f else ""), fontsize=9)
        ax.axis("off")
    fig.suptitle(run_labels[name], fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),
        md("## 7. Pairwise comparison — what each step bought us"),
        code("""\
transitions = [
    ("2.1 → 2.2: BCE → Wasserstein", "p2_1_dcgan", "p2_2_wgan"),
    ("2.2 → 2.3: + SN + GroupNorm + Attn", "p2_2_wgan", "p2_3_wgan_sn_attn"),
    ("2.3 → 2.4: 64 → 128 resolution", "p2_3_wgan_sn_attn", "p2_4_wgan_sn_attn_128"),
]
for title, a, b in transitions:
    if a not in runs or b not in runs: continue
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    for ax, name in zip(axes, [a, b]):
        img_path = SAMPLES / name / "epoch_0100.png"
        if img_path.exists():
            ax.imshow(mpimg.imread(str(img_path)))
        f = get_fid(runs[name], 100)
        ax.set_title(f"{run_labels[name]}\\nFID@100={f:.1f}" if f else run_labels[name], fontsize=10)
        ax.axis("off")
    fig.suptitle(title, fontsize=12, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),
        # Lines are wrapped so none starts with "+": CommonMark would otherwise
        # turn them into bullet items and break the bold span.
        md(f"""\
## 8. Conclusions

| Step | Run | Best FID | Δ vs prev |
|---|---|---|---|
| 2.1 DCGAN baseline | `p2_1_dcgan` | {fmt(b21)} | — |
| 2.2 + Wasserstein-GP | `p2_2_wgan` | {fmt(b22)} | {delta(b22, b21)} |
| 2.3 + SN + Attn | `p2_3_wgan_sn_attn` | {fmt(b23)} | {delta(b23, b22)} |
| 2.4 + 128×128 | `p2_4_wgan_sn_attn_128` | {fmt(b24)} | {delta(b24, b23)} |

2.1 and 2.2 are extensions of the phase 1 collapse — vanilla DCGAN/WGAN-GP doesn't escape
mean-collapse just by changing the loss. The dramatic step is **2.2 → 2.3**: spectral
norm, GroupNorm and self-attention break out of collapse and start producing recognisable
faces (FID drops from {fmt(b22, '.0f')} to {fmt(b23, '.0f')}). 2.4 (resolution bump) made things
worse at fixed compute, so phase 5 retains 64×64. **Selected GAN architecture for phase 5:
WGAN-GP with spectral norm, GroupNorm and self-attention at 64×64.**
"""),
    ]
    write_nb("phase2_analysis", cells)
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# PHASE 3 — VAE evolution
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
def build_phase3():
    """Build ``phase3_analysis.ipynb`` for the VAE loss-stacking ablations.

    Headline and conclusion numbers come from the phase-3 logs at build time.
    Missing logs or FIDs render as "—" instead of aborting the build. The
    "monotonic improvement" verdict is only stated when the logged best FIDs
    actually decrease at each step.
    """
    names = ["p3_1_vae", "p3_2_vae_perceptual", "p3_3_vae_patchgan"]
    runs = {n: load(n) for n in names}
    best = {n: best_fid(runs[n])[1] for n in names}
    b1, b2, b3 = (best[n] for n in names)

    def fmt(value):
        # One-decimal FID, or an em dash when it was never logged.
        return "—" if value is None else f"{value:.1f}"

    def delta(cur, prev):
        # Signed change in best FID versus the previous step ("—" if either is missing).
        return "—" if cur is None or prev is None else f"{cur - prev:+.1f}"

    # Headline: lowest best-FID run, skipping runs that logged no FID at all.
    scored = [n for n in names if best[n] is not None]
    if scored:
        best_name = min(scored, key=best.get)
        best_ep = best_fid(runs[best_name])[0]
        headline = (f"`{best_name}` — **best FID = {best[best_name]:.1f}** "
                    f"(prior samples, reached at epoch {best_ep})")
    else:
        headline = "no phase 3 FID logged yet"

    if None not in (b1, b2, b3) and b1 > b2 > b3:
        verdict = "**Both additions help, monotonically.**"
    else:
        verdict = ("**The logged FIDs do not improve monotonically across the three steps** "
                   "(see the table above); the qualitative notes below come from the sample grids.")

    cells = [
        md(f"""\
# Phase 3 — VAE Evolution

Standard VAEs collapse to mean-blur. Phase 3 stacks losses on top of the basic
MSE+KL objective to recover detail.

| Run | Step |
|----------------------|------------------------------------------------------|
| `p3_1_vae` | MSE + KL only |
| `p3_2_vae_perceptual`| + VGG16 perceptual loss (`lambda_perceptual=0.1`) |
| `p3_3_vae_patchgan` | + PatchGAN adversarial loss (`lambda_adversarial=0.01`) |

**Headline result:** {headline}.
"""),
        md("""\
### Reference: phase 0 baseline (same family)

`p0_vae` was MSE+KL on raw un-aligned data. Prior samples were heavily blurred
mean-faces — the textbook VAE failure mode. Phase 3 keeps the encoder/decoder
architecture and pipeline fixed and shows that adding perceptual + adversarial
terms is what actually moves the needle.
"""),
        code(SHARED_IMPORTS),
        md("## 1. Load experiment logs"),
        code("""\
run_names = ["p3_1_vae", "p3_2_vae_perceptual", "p3_3_vae_patchgan"]
run_labels = {
    "p3_1_vae": "3.1 MSE + KL",
    "p3_2_vae_perceptual": "3.2 + Perceptual",
    "p3_3_vae_patchgan": "3.3 + PatchGAN",
}
runs = {name: load_log(name) for name in run_names}
runs = {k: v for k, v in runs.items() if v}
for n in run_names: print(f" {n}: {'OK' if n in runs else 'MISSING'}")
"""),
        md("## 2. FID comparison table (prior samples)"),
        code("""\
rows = []
for name in run_names:
    if name not in runs: continue
    r = runs[name]; h = r["history"]
    _, fid_vals = fid_series(r)
    rows.append({
        "Run": run_labels[name],
        "FID@50": get_fid(r, 50),
        "FID@100": get_fid(r, 100),
        "Best FID": min(fid_vals) if fid_vals else None,
        "Recon@100": h["recon_loss"][-1] if h.get("recon_loss") else None,
        "KL@100": h["kl_loss"][-1] if h.get("kl_loss") else None,
        "Train (min)": (h.get("train_time_s") or 0) / 60,
    })
df = pd.DataFrame(rows).sort_values("Best FID")
df.style.format({"FID@50": "{:.1f}", "FID@100": "{:.1f}", "Best FID": "{:.1f}",
                 "Recon@100": "{:.4f}", "KL@100": "{:.2f}", "Train (min)": "{:.1f}"}, na_rep="—")
"""),
        md("## 3. FID curves — evolution"),
        code("""\
fig, ax = plt.subplots(figsize=(10, 5))
cmap = plt.cm.plasma
for i, name in enumerate(run_names):
    if name not in runs: continue
    epochs, fid_vals = fid_series(runs[name])
    ax.plot(epochs, fid_vals, "o-", label=run_labels[name],
            color=cmap(i / len(run_names)), linewidth=2, markersize=7)
ax.set_xlabel("Epoch"); ax.set_ylabel("FID (prior samples)")
ax.set_title("Phase 3 — FID curves"); ax.legend()
plt.tight_layout(); plt.show()
"""),
        md("## 4. Training loss components"),
        code("""\
fig, axes = plt.subplots(1, 3, figsize=(16, 4))
cmap = plt.cm.plasma
for i, name in enumerate(run_names):
    if name not in runs: continue
    h = runs[name]["history"]; color = cmap(i / len(run_names))
    epochs = range(1, len(h["recon_loss"]) + 1)
    axes[0].plot(epochs, h["recon_loss"], color=color, label=run_labels[name])
    axes[1].plot(epochs, h["kl_loss"], color=color, label=run_labels[name])
    if "perc_loss" in h and any(h["perc_loss"]):
        axes[2].plot(epochs, h["perc_loss"], color=color, label=run_labels[name])
axes[0].set_title("Reconstruction (MSE)"); axes[0].set_xlabel("Epoch"); axes[0].legend(fontsize=8)
axes[1].set_title("KL divergence"); axes[1].set_xlabel("Epoch"); axes[1].legend(fontsize=8)
axes[2].set_title("Perceptual (VGG16)"); axes[2].set_xlabel("Epoch"); axes[2].legend(fontsize=8)
plt.tight_layout(); plt.show()
"""),
        md("## 5. Prior samples — epoch 100"),
        code("""\
fig, axes = plt.subplots(1, 3, figsize=(13, 4.5))
for ax, name in zip(axes, run_names):
    img_path = SAMPLES / name / "epoch_0100.png"
    if img_path.exists():
        ax.imshow(mpimg.imread(str(img_path)))
    f = get_fid(runs.get(name, {}), 100) if name in runs else None
    ax.set_title(f"{run_labels[name]}\\nFID@100={f:.1f}" if f else run_labels[name], fontsize=10)
    ax.axis("off")
fig.suptitle("Prior samples (decoded from N(0, I))", fontsize=12, fontweight="bold")
plt.tight_layout(); plt.show()
"""),
        md("## 6. Reconstructions — epoch 100"),
        code("""\
fig, axes = plt.subplots(1, 3, figsize=(13, 4.5))
for ax, name in zip(axes, run_names):
    img_path = SAMPLES / name / "epoch_0100_recon.png"
    if img_path.exists():
        ax.imshow(mpimg.imread(str(img_path)))
    ax.set_title(run_labels[name], fontsize=10)
    ax.axis("off")
fig.suptitle("Reconstructions (real / decoded interleaved)", fontsize=12, fontweight="bold")
plt.tight_layout(); plt.show()
"""),
        md("## 7. Progression — epoch 10 → 50 → 100 (prior samples)"),
        code("""\
check_epochs = [10, 50, 100]
for name in run_names:
    if name not in runs: continue
    fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))
    for ax, e in zip(axes, check_epochs):
        p = SAMPLES / name / f"epoch_{e:04d}.png"
        if p.exists():
            ax.imshow(mpimg.imread(str(p)))
        f = get_fid(runs[name], e)
        ax.set_title(f"epoch {e}" + (f"\\nFID={f:.1f}" if f else ""), fontsize=9)
        ax.axis("off")
    fig.suptitle(run_labels[name], fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),
        md(f"""\
## 8. Conclusions

| Step | Run | Best FID | Δ vs prev |
|---|---|---|---|
| 3.1 MSE + KL | `p3_1_vae` | {fmt(b1)} | — |
| 3.2 + Perceptual | `p3_2_vae_perceptual` | {fmt(b2)} | {delta(b2, b1)} |
| 3.3 + PatchGAN | `p3_3_vae_patchgan` | {fmt(b3)} | {delta(b3, b2)} |

{verdict} Perceptual loss recovers high-frequency detail
(sharper hair/skin texture); PatchGAN adversarial loss further pushes the prior samples
toward the data manifold. The selected VAE recipe for phase 5 is
**MSE + 0.25·KL + 0.1·VGG-perceptual + 0.01·PatchGAN-adversarial**.
"""),
    ]
    write_nb("phase3_analysis", cells)
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# PHASE 4 — DDPM evolution
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
def build_phase4():
    """Build ``notebooks/phase4_analysis.ipynb`` — the DDPM iteration story.

    Runs 4.1–4.4 change one thing at a time (noise schedule, prediction target,
    network width). Best-FID numbers are read via ``load``/``best_fid`` at build
    time and baked into the header and the conclusions table. A run whose log is
    missing, or has no FID history, is rendered as ``n/a`` instead of aborting
    the whole build with a ``TypeError`` from formatting ``None``.
    """
    # (conclusions-table label, run name) in iteration order; the labels mirror
    # the notebook's own ``run_labels`` dict below.
    steps = [
        ("4.1 linear / ε", "p4_1_ddpm_linear"),
        ("4.2 cosine / ε", "p4_2_ddpm_cosine"),
        ("4.3 cosine / v", "p4_3_ddpm_vpred"),
        ("4.4 wider net", "p4_4_ddpm_wider"),
    ]
    runs = {n: load(n) for _, n in steps}
    # best_fid() yields (epoch, value), or (None, None) when a run has no FID history.
    best = {n: best_fid(log) for n, log in runs.items()}

    def fid_cell(n):
        """Best FID of run *n* formatted as a table cell, or ``n/a`` when unknown."""
        v = best[n][1]
        return "n/a" if v is None else f"{v:.1f}"

    def delta_cell(n, prev):
        """Signed best-FID change from run *prev* to run *n*, or ``n/a``."""
        v, p = best[n][1], best[prev][1]
        return "n/a" if v is None or p is None else f"{v - p:+.1f}"

    scored = {n: ev for n, ev in best.items() if ev[1] is not None}
    if scored:
        best_name = min(scored, key=lambda n: scored[n][1])
        best_epoch, best_val = scored[best_name]
        # Report the epoch the best FID was actually measured at, not just the budget.
        headline = (f"`{best_name}` — **best FID = {best_val:.1f}** "
                    f"at epoch {best_epoch} (100-epoch runs).")
    else:
        headline = "no phase-4 FID history found in `outputs/logs/` yet."

    table_rows = []
    prev = None
    for label, n in steps:
        delta = "—" if prev is None else delta_cell(n, prev)
        table_rows.append(f"| {label} | `{n}` | {fid_cell(n)} | {delta} |")
        prev = n
    conclusions_table = "\n".join(table_rows)

    cells = [
        md(f"""\
# Phase 4 — DDPM Evolution

Iterate on the diffusion side: noise schedule, prediction target, model width.
Sampling everywhere uses DDIM with 50 steps (matches training preview); FID
uses DDIM-100 against 5000 real images.

| Run                | Step                                                   |
|--------------------|--------------------------------------------------------|
| `p4_1_ddpm_linear` | Linear β-schedule, ε-prediction (DDPM baseline)        |
| `p4_2_ddpm_cosine` | Cosine β-schedule (Nichol & Dhariwal)                  |
| `p4_3_ddpm_vpred`  | + v-prediction target (Salimans & Ho)                  |
| `p4_4_ddpm_wider`  | + base_ch 128 → 192, num_res_blocks 2, attn at 32/16/8 |

**Headline result:** {headline}
"""),
        md("""\
### Reference: phase 0 baseline (same family)

`p0_ddpm` was a vanilla DDPM (linear schedule, ε-prediction, base_ch=128) on raw
un-aligned data. Outputs were noisy face-shaped textures. Phase 4 fixes the
pipeline (aligned 64) and walks through the standard set of post-2020 DDPM
improvements one at a time.
"""),
        code(SHARED_IMPORTS),
        md("## 1. Load experiment logs"),
        code("""\
run_names = ["p4_1_ddpm_linear", "p4_2_ddpm_cosine", "p4_3_ddpm_vpred", "p4_4_ddpm_wider"]
run_labels = {
    "p4_1_ddpm_linear": "4.1 linear / ε",
    "p4_2_ddpm_cosine": "4.2 cosine / ε",
    "p4_3_ddpm_vpred":  "4.3 cosine / v",
    "p4_4_ddpm_wider":  "4.4 wider net",
}
runs = {n: load_log(n) for n in run_names}
runs = {k: v for k, v in runs.items() if v}
for n in run_names: print(f"  {n}: {'OK' if n in runs else 'MISSING'}")
"""),
        md("## 2. FID comparison table"),
        code("""\
rows = []
for name in run_names:
    if name not in runs: continue
    r = runs[name]; _, fid_vals = fid_series(r)
    rows.append({
        "Run":         run_labels[name],
        "FID@25":      get_fid(r, 25),
        "FID@50":      get_fid(r, 50),
        "FID@100":     get_fid(r, 100),
        "Best FID":    min(fid_vals) if fid_vals else None,
        "Loss@100":    (r["history"].get("loss") or [None])[-1],
        "Train (min)": (r['history'].get('train_time_s') or 0) / 60,
    })
df = pd.DataFrame(rows).sort_values("Best FID")
# na_rep: FID is only computed at some epochs; an all-None column would otherwise
# make the "{:.1f}" formatter raise.
df.style.format({"FID@25": "{:.1f}", "FID@50": "{:.1f}", "FID@100": "{:.1f}",
                 "Best FID": "{:.1f}", "Loss@100": "{:.4f}", "Train (min)": "{:.1f}"},
                na_rep="—")
"""),
        md("## 3. FID curves — evolution"),
        code("""\
fig, ax = plt.subplots(figsize=(10, 5))
cmap = plt.cm.cividis
for i, name in enumerate(run_names):
    if name not in runs: continue
    epochs, fid_vals = fid_series(runs[name])
    ax.plot(epochs, fid_vals, "o-", label=run_labels[name],
            color=cmap(i / len(run_names)), linewidth=2, markersize=7)
ax.set_xlabel("Epoch"); ax.set_ylabel("FID (DDIM-100)")
ax.set_title("Phase 4 — FID curves"); ax.legend()
plt.tight_layout(); plt.show()
"""),
        md("## 4. Training loss"),
        code("""\
fig, ax = plt.subplots(figsize=(10, 4))
cmap = plt.cm.cividis
for i, name in enumerate(run_names):
    if name not in runs: continue
    h = runs[name]["history"]
    epochs = range(1, len(h["loss"]) + 1)
    ax.plot(epochs, h["loss"], color=cmap(i / len(run_names)), label=run_labels[name], linewidth=1.3)
ax.set_xlabel("Epoch"); ax.set_ylabel("MSE on prediction target")
ax.set_title("Loss (note: ε-MSE and v-MSE are not directly comparable)")
ax.legend(); plt.tight_layout(); plt.show()
"""),
        md("## 5. Sample grids — epoch 100"),
        code("""\
fig, axes = plt.subplots(1, len(run_names), figsize=(4 * len(run_names), 4.5))
for ax, name in zip(axes, run_names):
    img_path = SAMPLES / name / "epoch_0100.png"
    if img_path.exists():
        ax.imshow(mpimg.imread(str(img_path)))
    f = get_fid(runs[name], 100) if name in runs else None
    ax.set_title(f"{run_labels[name]}\\nFID@100={f:.1f}" if f else run_labels[name], fontsize=9)
    ax.axis("off")
plt.tight_layout(); plt.show()
"""),
        md("## 6. Progression — epoch 10 → 50 → 100"),
        code("""\
check_epochs = [10, 50, 100]
for name in run_names:
    if name not in runs: continue
    fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))
    for ax, e in zip(axes, check_epochs):
        p = SAMPLES / name / f"epoch_{e:04d}.png"
        if p.exists():
            ax.imshow(mpimg.imread(str(p)))
        f = get_fid(runs[name], e)
        ax.set_title(f"epoch {e}" + (f"\\nFID={f:.1f}" if f else ""), fontsize=9)
        ax.axis("off")
    fig.suptitle(run_labels[name], fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),
        md("## 7. Noise schedule visualisation\n\nWhy cosine helps: linear allocates many timesteps near t=T where the image is already near-pure noise, wasting model capacity."),
        code("""\
import math
T = 1000; t = np.arange(T)
betas_lin = np.linspace(1e-4, 0.02, T)
ab_lin = np.cumprod(1 - betas_lin)
# Cosine schedule (Nichol & Dhariwal 2021): f is evaluated on the T+1 grid 0..T so
# that there are exactly T betas (1 − ᾱ_t/ᾱ_{t−1}); clipping at 0.999 avoids a
# singular final step where ᾱ_T ≈ 0.
s = 0.008
f = np.cos((np.arange(T + 1) / T + s) / (1 + s) * math.pi / 2) ** 2
f = f / f[0]
betas_cos = np.clip(1 - f[1:] / f[:-1], 0, 0.999)
ab_cos = np.cumprod(1 - betas_cos)

fig, axes = plt.subplots(1, 2, figsize=(13, 4))
axes[0].plot(t, ab_lin, label="linear", color="#5B8DB8", linewidth=2)
axes[0].plot(t, ab_cos, label="cosine", color="#E8705A", linewidth=2)
axes[0].set_xlabel("Timestep t"); axes[0].set_ylabel("ᾱ_t (signal fraction)")
axes[0].set_title("Cumulative signal preservation"); axes[0].legend()
axes[1].plot(betas_lin, label="linear β", color="#5B8DB8", linewidth=2)
axes[1].plot(betas_cos, label="cosine β", color="#E8705A", linewidth=2)
axes[1].set_xlabel("Timestep t"); axes[1].set_ylabel("β_t"); axes[1].set_title("β-schedule"); axes[1].legend()
plt.tight_layout(); plt.show()
"""),
        md(f"""\
## 8. Conclusions

| Step | Run | Best FID | Δ vs prev |
|---|---|---|---|
{conclusions_table}

The big jump is **4.2 → 4.3** (v-prediction): a cleaner training target for the same
model. The wider network in 4.4 buys a further drop at ~2× train time. Selected DDPM
recipe for phase 5: **cosine schedule, v-prediction, base_ch=192, attn at {{32,16,8}}**.
"""),
    ]
    write_nb("phase4_analysis", cells)
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# PHASE 5 — Cross-family
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
def build_phase5():
    """Build ``notebooks/phase5_analysis.ipynb`` — GAN vs VAE vs DDPM head-to-head.

    Best FID (and the epoch it occurred at) plus training time per family are
    read via ``load``/``best_fid``/``time_min`` at build time and baked into the
    header and the conclusions table, ordered by best FID. A run with no log or
    no FID history renders as ``n/a`` and sorts last, instead of crashing the
    build on ``format(None, ".1f")``.
    """
    # Architecture blurb per run; kept in sync with the notebook's FAMILIES labels.
    arch = {
        "p5_gan": "WGAN-GP + SN + Attn",
        "p5_vae": "VAE + Perceptual + PatchGAN",
        "p5_ddpm": "DDPM cosine v-pred wider",
    }
    p5 = {n: load(n) for n in arch}

    def fmt_fid(v):
        """FID formatted for markdown, or ``n/a`` when unknown."""
        return "n/a" if v is None else f"{v:.1f}"

    def fmt_min(t):
        """Training time in minutes formatted for markdown, or ``n/a`` when unknown."""
        return "n/a" if t is None else f"{t:.1f} min"

    rows = []
    for n, log in p5.items():
        e, v = best_fid(log)
        rows.append((n, v, e, time_min(log) if log else None))
    # Lowest best-FID first; runs without FID history go last.
    rows.sort(key=lambda r: float("inf") if r[1] is None else r[1])
    headline = ", ".join(f"{n}={fmt_fid(v)}" for n, v, _, _ in rows)
    conclusions_table = "\n".join(
        f"| {n.split('_')[1].upper()} | {arch[n]} | {fmt_fid(v)} | {fmt_min(t)} |"
        for n, v, _, t in rows
    )

    cells = [
        md(f"""\
# Phase 5 — Cross-Family Comparison

Take the best recipe from each family (phases 2/3/4) and train each on identical data
to the same epoch budget. Per-family iteration analyses live in their own notebooks
(phase 2 for GAN, phase 3 for VAE, phase 4 for DDPM); this notebook is **only** about
comparing the three families head-to-head.

**Headline FIDs (best across training):** {headline}.
"""),
        code(SHARED_IMPORTS + """

FAMILIES = {
    "GAN":  {"p5": "p5_gan",  "color": "#5B8DB8", "label": "WGAN-GP + SN + Attn"},
    "VAE":  {"p5": "p5_vae",  "color": "#E8B85A", "label": "VAE + Perceptual + PatchGAN"},
    "DDPM": {"p5": "p5_ddpm", "color": "#E8705A", "label": "DDPM cosine v-pred wider"},
}

logs_p5 = {fam: load_log(info["p5"]) for fam, info in FAMILIES.items()}
"""),
        md("## 1. Quantitative summary"),
        code("""\
rows = []
for fam, info in FAMILIES.items():
    log = logs_p5[fam]
    if log is None: continue
    h = log["history"]; cfg = log["config"]
    _, fid_vals = fid_series(log)
    rows.append({
        "Family":       fam,
        "Architecture": info["label"],
        "Resolution":   f"{cfg.get('image_size')}×{cfg.get('image_size')}",
        "Epochs":       len(h.get("loss") or h.get("g_loss") or h.get("recon_loss") or []),
        "Best FID":     min(fid_vals) if fid_vals else None,
        "Last FID":     fid_vals[-1] if fid_vals else None,
        "Train (min)":  (h.get("train_time_s") or 0) / 60,
    })
df = pd.DataFrame(rows).sort_values("Best FID")
# na_rep keeps the "{:.1f}" formatter from raising on runs without FID history.
df.style.format({"Best FID": "{:.1f}", "Last FID": "{:.1f}", "Train (min)": "{:.1f}"},
                na_rep="—")
"""),
        md("## 2. FID curves — all three families"),
        code("""\
fig, ax = plt.subplots(figsize=(10, 5))
for fam, info in FAMILIES.items():
    log = logs_p5[fam]
    if log is None: continue
    epochs, vals = fid_series(log)
    ax.plot(epochs, vals, "o-", color=info["color"], label=f"{fam} ({info['label']})",
            linewidth=2, markersize=7)
ax.set_xlabel("Epoch"); ax.set_ylabel("FID")
ax.set_title("Phase 5 — FID across families")
ax.legend(); plt.tight_layout(); plt.show()
"""),
        md("## 3. Final sample grids (last preview per family)"),
        code("""\
fig, axes = plt.subplots(1, 3, figsize=(16, 6))
for ax, (fam, info) in zip(axes, FAMILIES.items()):
    log = logs_p5[fam]
    samples_dir = SAMPLES / info["p5"]
    pngs = sorted(p for p in samples_dir.glob("epoch_*.png") if "_recon" not in p.stem)
    if not pngs:
        ax.set_title(f"{fam} (no samples)"); ax.axis("off"); continue
    img_path = pngs[-1]  # last preview = closest to final_ema
    ax.imshow(mpimg.imread(str(img_path)))
    v = None
    if log:
        _, fid_vals = fid_series(log)
        v = min(fid_vals) if fid_vals else None
    ax.set_title(f"{fam} — {info['label']}\\n{img_path.stem}" + (f" best FID={v:.1f}" if v else ""), fontsize=10)
    ax.axis("off")
fig.suptitle("Final samples from each family", fontsize=13, fontweight="bold")
plt.tight_layout(); plt.show()
"""),
        md("## 4. Training progression — early → late"),
        code("""\
for fam, info in FAMILIES.items():
    log = logs_p5[fam]
    samples_dir = SAMPLES / info["p5"]
    pngs = sorted(p for p in samples_dir.glob("epoch_*.png") if "_recon" not in p.stem)
    if not pngs: continue
    # Pick ~4 evenly spaced previews
    n_pick = min(4, len(pngs))
    picks = [pngs[i * (len(pngs) - 1) // (n_pick - 1)] for i in range(n_pick)] if n_pick > 1 else pngs
    fig, axes = plt.subplots(1, len(picks), figsize=(4 * len(picks), 4))
    if len(picks) == 1: axes = [axes]
    for ax, p in zip(axes, picks):
        ax.imshow(mpimg.imread(str(p)))
        ep = int(p.stem.split("_")[1])
        f = get_fid(log, ep) if log else None
        ax.set_title(f"epoch {ep}" + (f"\\nFID={f:.1f}" if f else ""), fontsize=9)
        ax.axis("off")
    fig.suptitle(f"{fam} — {info['label']}", fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),
        md("## 5. Per-family training loss"),
        code("""\
fig, axes = plt.subplots(1, len(FAMILIES), figsize=(6 * len(FAMILIES), 4))
for ax, (fam, info) in zip(axes, FAMILIES.items()):
    log = logs_p5[fam]
    if not log: ax.set_title(f"{fam} (missing)"); ax.axis("off"); continue
    h = log["history"]; c = info["color"]
    handles = []  # collected explicitly so twin-axis curves also reach the legend
    if fam == "GAN":
        handles += ax.plot(h["g_loss"], label="G loss", color=c, linewidth=1.2)
        if "w_dist" in h:
            handles += ax.plot(h["w_dist"], label="W-dist", color=c, linewidth=1.2, linestyle="--")
        ax.set_ylabel("Loss / W-distance")
    elif fam == "VAE":
        handles += ax.plot(h["recon_loss"], label="recon", color=c, linewidth=1.2)
        ax2 = ax.twinx()  # KL lives on its own scale
        handles += ax2.plot(h["kl_loss"], label="KL", color=c, alpha=0.5, linestyle="--")
        ax2.set_ylabel("KL")
        ax.set_ylabel("Recon")
    else:  # DDPM
        handles += ax.plot(h["loss"], label="v-MSE", color=c, linewidth=1.2)
        ax.set_ylabel("MSE on v-prediction")
    ax.set_xlabel("Epoch"); ax.set_title(f"{fam}")
    ax.legend(handles=handles, loc="upper right", fontsize=8)
plt.tight_layout(); plt.show()
"""),
        md("""\
## 6. Latent interpolation — GAN and VAE

Smooth interpolation between two latent codes reveals whether the generator has
learned a continuous manifold rather than a sparse memorisation. DDPM is left out:
its starting point is a full-resolution noise image refined over many denoising
steps rather than a compact latent code, so this section is GAN/VAE only.

**Note on checkpoint loading:** the cell below uses the same checkpoint priority
as `tools/sampling.py` — `final_ema` first, then `best_ema` as fallback. This is
important: `best_ema` is the lowest-FID snapshot, which for slowly-converging
runs (e.g. DDPM) can be saved while the EMA shadow is still close to random init,
producing pure noise.
"""),
        code("""\
import sys
sys.path.insert(0, "..")
from pathlib import Path
import torch
from src.models import get_model
from src.utils import load_config

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def load_ema_model(run_name, config_name):
    cfg = load_config(str(Path("../configs/phase5") / config_name))
    model_obj, _ = get_model(cfg)
    model = model_obj[0] if isinstance(model_obj, tuple) else model_obj
    # final_ema first — see note above
    for fname in [f"{run_name}_final_ema.pt", f"{run_name}_best_ema.pt"]:
        p = Path("../outputs/models") / fname
        if not p.exists(): continue
        sd = torch.load(p, map_location=DEVICE, weights_only=True)
        missing, unexpected = model.load_state_dict(sd, strict=False)
        if not missing and not unexpected:
            print(f"  loaded {fname}")
            return model.to(DEVICE).eval(), cfg
        # Say why a candidate was rejected instead of silently trying the next one.
        print(f"  skipped {fname}: {len(missing)} missing / {len(unexpected)} unexpected keys")
    raise FileNotFoundError(f"No usable EMA checkpoint for {run_name}")


def slerp(z1, z2, t):
    z1n = z1 / z1.norm(); z2n = z2 / z2.norm()
    omega = torch.acos((z1n * z2n).sum().clamp(-1, 1))
    if omega.abs() < 1e-6: return (1 - t) * z1 + t * z2
    return (torch.sin((1 - t) * omega) / torch.sin(omega)) * z1 + \\
           (torch.sin(t * omega) / torch.sin(omega)) * z2
"""),
        code("""\
# ── GAN slerp interpolation ─────────────────────────────────────────────────
try:
    gan_model, gan_cfg = load_ema_model("p5_gan", "p5_gan.json")
    latent_dim = gan_cfg.get("latent_dim", 128)
    z1 = torch.randn(1, latent_dim, 1, 1, device=DEVICE)
    z2 = torch.randn(1, latent_dim, 1, 1, device=DEVICE)
    with torch.no_grad():
        zs = torch.cat([slerp(z1, z2, t) for t in torch.linspace(0, 1, 10)])
        imgs = gan_model(zs).clamp(-1, 1)
        imgs = (imgs + 1) / 2
    fig, axes = plt.subplots(1, 10, figsize=(20, 2.5))
    for ax, img in zip(axes, imgs.cpu()):
        ax.imshow(img.permute(1, 2, 0).numpy()); ax.axis("off")
    fig.suptitle("GAN — slerp latent interpolation z₁ → z₂", fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
except Exception as e:
    print(f"GAN interpolation skipped: {e}")
"""),
        code("""\
# ── VAE encode-interpolate-decode ───────────────────────────────────────────
try:
    from src.data import GeneratorDataset, get_transform
    vae_model, vae_cfg = load_ema_model("p5_vae", "p5_vae.json")
    ds = GeneratorDataset("../" + vae_cfg["data_dir"],
                          sources=vae_cfg.get("sources", ["wiki"]),
                          transform=get_transform(vae_cfg["image_size"], augment=False))
    real = torch.stack([ds[i] for i in range(2)]).to(DEVICE)
    with torch.no_grad():
        mu, logvar = vae_model.encode(real)
        z1, z2 = mu[0:1], mu[1:2]
        zs = torch.cat([(1 - t) * z1 + t * z2 for t in torch.linspace(0, 1, 10)])
        imgs = vae_model.decode(zs).clamp(-1, 1)
        imgs = (imgs + 1) / 2
    fig, axes = plt.subplots(1, 10, figsize=(20, 2.5))
    for ax, img in zip(axes, imgs.cpu()):
        ax.imshow(img.permute(1, 2, 0).numpy()); ax.axis("off")
    fig.suptitle("VAE — linear latent interpolation (encoded real images)", fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
except Exception as e:
    print(f"VAE interpolation skipped: {e}")
"""),
        md(f"""\
## 7. Conclusions

| Family | Architecture | Best FID | Train time |
|---|---|---|---|
{conclusions_table}

The ranking is **DDPM ≈ GAN ≪ VAE** by FID. DDPM and GAN trade places depending on
sampling settings (truncation, DDIM step count) but both clearly beat the VAE here —
the perceptual + adversarial losses help VAE reconstructions more than they help its
prior samples (which is what FID measures).

**Practical sampling notes** (encoded in `tools/sampling.py`):
- Load `*_final_ema.pt` rather than `*_best_ema.pt` — the latter can be saved very early
  for slowly-converging runs.
- DDPM sampling at DDIM-50 matches the training preview; DDIM-100 is for FID only.
- GAN truncation is **off by default** (matches training); enable for sharper but less
  diverse samples.
"""),
    ]
    write_nb("phase5_analysis", cells)
|
||
|
||
|
||
if __name__ == "__main__":
    # Regenerate every phase notebook in order, 0 through 5.
    print("Building notebooks...")
    for builder in (build_phase0, build_phase1, build_phase2,
                    build_phase3, build_phase4, build_phase5):
        builder()
    print("Done.")
|