"""
Build all phase analysis notebooks from a single source of truth.

Run from generator/notebooks/:
    python3 _build.py

Each phase notebook follows the same template (header → load → FID table →
FID curves → loss curves → samples → progression → conclusions). Phase 0 is
baseline-only (no FID, no iteration story). Phases 1–4 add a phase-0
same-family reference. Phase 5 stays cross-family and adds latent
interpolation.

Real metric values are pulled from outputs/logs/*.json at build time and
rendered into markdown headers and conclusions, so reports never drift from
data.
"""
import json
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]
LOGS = ROOT / "outputs" / "logs"
OUT = ROOT / "notebooks"


# ── notebook helpers ─────────────────────────────────────────────────────────
def md(text: str) -> dict:
    """Return a markdown cell dict whose source is *text*.

    The text is split with ``keepends=True`` so each stored line keeps its
    own line terminator.
    """
    return {
        "cell_type": "markdown",
        "metadata": {},
        "source": text.splitlines(keepends=True),
    }


def code(text: str) -> dict:
    """Return an un-executed code cell dict (no outputs) for *text*."""
    return {
        "cell_type": "code",
        "metadata": {},
        "execution_count": None,
        "outputs": [],
        "source": text.splitlines(keepends=True),
    }


def write_nb(name: str, cells: list) -> None:
    """Write *cells* to ``OUT/<name>.ipynb`` and print the path written.

    The notebook is declared as nbformat 4.5 with a ``python3`` kernelspec.
    The printed path is relative to ``ROOT``.
    """
    nb = {
        "cells": cells,
        "metadata": {
            "kernelspec": {"name": "python3", "display_name": "Python 3"},
            "language_info": {"name": "python"},
        },
        "nbformat": 4,
        "nbformat_minor": 5,
    }
    path = OUT / f"{name}.ipynb"
    # json.dumps escapes non-ASCII by default, so the payload is plain ASCII;
    # the explicit encoding just keeps the write independent of the locale.
    path.write_text(json.dumps(nb, indent=1), encoding="utf-8")
    print(f" wrote {path.relative_to(ROOT)}")


# ── log-derived facts (computed once, baked into markdown) ───────────────────
def load(name: str):
    """Return the parsed ``LOGS/<name>.json``, or ``None`` if it is absent."""
    path = LOGS / f"{name}.json"
    if not path.exists():
        return None
    # Context manager closes the handle deterministically; JSON text is UTF-8.
    with path.open(encoding="utf-8") as fh:
        return json.load(fh)


def best_fid(log) -> tuple:
    """Return ``(epoch, fid)`` for the lowest FID recorded in *log*.

    FID values are read from ``log["history"]["fid"]``, a mapping keyed by
    epoch (keys are converted with ``int``). Returns ``(None, None)`` when
    *log* is falsy or has no FID entries. On ties the earliest epoch wins,
    because candidates are scanned in ascending epoch order.
    """
    fid = (log or {}).get("history", {}).get("fid", {})
    if not fid:
        return None, None
    by_epoch = sorted((int(epoch), value) for epoch, value in fid.items())
    return min(by_epoch, key=lambda item: item[1])


def time_min(log):
    """Return ``history.train_time_s`` converted to minutes.

    Returns ``None`` when *log* is falsy or the value is missing or zero.
    """
    seconds = (log or {}).get("history", {}).get("train_time_s")
    return seconds / 60 if seconds else None


def get_fid(log, epoch):
    """Build-time helper (mirrors the runtime helper in SHARED_IMPORTS).

    Returns the FID stored for *epoch* (looked up by ``str(epoch)``), or
    ``None`` when *log* is falsy or has no entry for that epoch.
    """
    return (log or {}).get("history", {}).get("fid", {}).get(str(epoch))


# ── shared imports cell (all phases use the same setup) ──────────────────────
# Source of the first code cell of every notebook. The helpers tolerate a
# missing log (None) exactly like the build-time helpers above, and load_log
# closes its file handle.
SHARED_IMPORTS = """\
import json
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import pandas as pd

plt.rcParams.update({"figure.dpi": 120, "font.size": 10})

OUTPUTS = Path("../outputs")
LOGS = OUTPUTS / "logs"
SAMPLES = OUTPUTS / "samples"

def load_log(name):
    p = LOGS / f"{name}.json"
    if not p.exists():
        return None
    with p.open(encoding="utf-8") as fh:
        return json.load(fh)

def get_fid(log, epoch):
    fid = (log or {}).get("history", {}).get("fid", {})
    return fid.get(str(epoch))

def fid_series(log):
    fid = (log or {}).get("history", {}).get("fid", {})
    items = sorted((int(k), v) for k, v in fid.items())
    return [e for e, _ in items], [v for _, v in items]
"""

# NOTE(review): build_phase0() starts on the next line and continues beyond this block; that text is reproduced exactly as found.
# ───────────────────────────────────────────────────────────────────────────── # PHASE 0 — Baseline (new) # ───────────────────────────────────────────────────────────────────────────── def build_phase0(): p0 = {n: load(n) for n in ["p0_wgan", "p0_vae", "p0_ddpm", "p0_ddpm_small"]} cells = [ md(f"""\ # Phase 0 — Baseline The starting point. Three model families trained at minimal scale on the raw (un-aligned) dataset to establish a reference point and verify the training loops work end-to-end. No FID was tracked at this stage — these runs predate the FID instrumentation added in phase 1, so the only signals are training loss and visual sample quality. 
| Run | Model | Epochs | Notes | |-----------------------|--------------|--------|------------------------------------------| | `p0_wgan` | WGAN-GP | {len(p0['p0_wgan']['history']['g_loss']) if p0['p0_wgan'] else '—'} | basic generator/critic, no SN/attention | | `p0_vae` | VAE | {len(p0['p0_vae']['history']['loss']) if p0['p0_vae'] else '—'} | MSE + KL only | | `p0_ddpm` | DDPM | {len(p0['p0_ddpm']['history']['loss']) if p0['p0_ddpm'] else '—'} | linear schedule, ε-prediction | | `p0_ddpm_small` | DDPM (small) | {len(p0['p0_ddpm_small']['history']['loss']) if p0['p0_ddpm_small'] else '—'} | reduced base channels — sanity check | Phase 0 outputs are **deliberately rough**. The point is to have something to compare against in phases 1–4 so improvements have a baseline. """), code(SHARED_IMPORTS), md("## 1. Training loss curves"), code("""\ runs = {n: load_log(n) for n in ["p0_wgan", "p0_vae", "p0_ddpm", "p0_ddpm_small"]} runs = {k: v for k, v in runs.items() if v} fig, axes = plt.subplots(1, 3, figsize=(16, 4)) # WGAN: g_loss + w_dist h = runs["p0_wgan"]["history"] ep = range(1, len(h["g_loss"]) + 1) axes[0].plot(ep, h["g_loss"], label="G loss", color="#5B8DB8") axes[0].plot(ep, h["c_loss"], label="C loss", color="#E8705A", alpha=0.7) axes[0].set_title("p0_wgan — generator vs critic loss") axes[0].set_xlabel("Epoch"); axes[0].legend() # VAE: total loss + components h = runs["p0_vae"]["history"] ep = range(1, len(h["loss"]) + 1) axes[1].plot(ep, h["recon_loss"], label="Recon (MSE)", color="#5B8DB8") axes[1].plot(ep, h["kl_loss"], label="KL", color="#E8705A") axes[1].set_title("p0_vae — recon vs KL") axes[1].set_xlabel("Epoch"); axes[1].legend() # DDPM: noise prediction MSE h = runs["p0_ddpm"]["history"] ep = range(1, len(h["loss"]) + 1) axes[2].plot(ep, h["loss"], color="#5B8DB8", label="ε-MSE") if "p0_ddpm_small" in runs: h2 = runs["p0_ddpm_small"]["history"] axes[2].plot(range(1, len(h2["loss"]) + 1), h2["loss"], color="#E8705A", linestyle="--", label="small 
variant") axes[2].set_title("p0_ddpm — noise prediction loss") axes[2].set_xlabel("Epoch"); axes[2].legend() plt.tight_layout(); plt.show() """), md("## 2. Final sample grids\n\nLast available preview from each run."), code("""\ last_epochs = {"p0_wgan": 200, "p0_vae": 100, "p0_ddpm": 200, "p0_ddpm_small": 100} fig, axes = plt.subplots(1, 4, figsize=(16, 4)) for ax, (name, ep) in zip(axes, last_epochs.items()): img_path = SAMPLES / name / f"epoch_{ep:04d}.png" if img_path.exists(): ax.imshow(mpimg.imread(str(img_path))) ax.set_title(f"{name}\\n(epoch {ep})", fontsize=10) else: ax.set_title(f"{name}\\n(missing)", fontsize=10) ax.axis("off") plt.tight_layout(); plt.show() """), md("## 3. Progression — early vs late"), code("""\ checkpoints = { "p0_wgan": [50, 100, 200], "p0_vae": [25, 50, 100], "p0_ddpm": [50, 100, 200], } for name, eps in checkpoints.items(): fig, axes = plt.subplots(1, len(eps), figsize=(12, 4)) for ax, e in zip(axes, eps): p = SAMPLES / name / f"epoch_{e:04d}.png" if p.exists(): ax.imshow(mpimg.imread(str(p))); ax.set_title(f"epoch {e}", fontsize=9) else: ax.text(0.5, 0.5, f"epoch {e}\\n(missing)", ha="center", va="center", transform=ax.transAxes) ax.axis("off") fig.suptitle(name, fontsize=11, fontweight="bold") plt.tight_layout(); plt.show() """), md("""\ ## 4. Conclusions Phase 0 ran with the un-aligned raw dataset and bare-bones model variants. Outputs are **face-shaped blobs** at best — enough to confirm the training loops converge, far from recognisable. - **WGAN** — generates colour blobs with rough oval shapes. No facial features resolved. - **VAE** — heavily blurred mean-images; the MSE+KL objective pulls reconstructions toward the dataset average. - **DDPM** — better local texture than the others but still very noisy faces; needs a stronger backbone and noise schedule (addressed in phase 4). The clear takeaways feeding into phase 1 onward: 1. Data quality matters — un-aligned raw images caused most of the visual blur. 
Phase 1 ablates this directly (alignment, augmentation, dataset mixing). 2. The vanilla VAE objective alone collapses to averages; phase 3 adds perceptual + adversarial losses to restore detail. 3. The minimal DDPM is the strongest of the three at this scale; phase 4 extends it (cosine schedule, v-prediction, wider backbone). """), ] write_nb("phase0_analysis", cells) # ───────────────────────────────────────────────────────────────────────────── # PHASE 1 — Pipeline ablations (DCGAN proxy) # ───────────────────────────────────────────────────────────────────────────── def build_phase1(): runs = {n: load(n) for n in ["p1a_dcgan_64", "p1a_dcgan_128", "p1b_dcgan_full", "p1b_dcgan_aligned", "p1c_dcgan_hflip", "p1c_dcgan_full_aug", "p1d_dcgan_combined"]} best_run = min(runs.items(), key=lambda kv: best_fid(kv[1])[1] or 9e9) best_name, best_log = best_run _, best_val = best_fid(best_log) p0_gan = load("p0_wgan") cells = [ md(f"""\ # Phase 1 — Pipeline Selection (DCGAN ablations) Goal: with a cheap proxy (vanilla DCGAN at 64×64, 50 epochs), isolate which **data-pipeline choices** matter so phases 2–4 can train on the best preprocessing without burning compute on dead-end variants. Four ablations, one factor each: - **1A** — Resolution: 64×64 vs 128×128 - **1B** — Face crop + alignment: full image vs MTCNN-aligned - **1C** — Augmentation: H-flip only vs H-flip + rotation + colour jitter - **1D** — Combined dataset: aligned only vs aligned + raw mixed **Headline result:** `{best_name}` — **FID@50 = {best_val:.1f}**. The pipeline carried forward into all later phases is the one this experiment selected: MTCNN-aligned crops, 64×64, full augmentation for GANs (H-flip-only kept as a safer default for VAE/DDPM), aligned-only (no mixing). """), md(f"""\ ### Reference: phase 0 baseline (same family) The phase 0 WGAN-GP (`p0_wgan`) trained on raw un-aligned images for {len(p0_gan['history']['g_loss']) if p0_gan else 200} epochs without any pipeline tuning — also collapsed. 
Phase 1 below uses the same model class with the data pipeline systematically varied; the architecture limitation is constant. """), code(SHARED_IMPORTS), md("## 1. Load all experiment logs"), code("""\ run_names = sorted(p.stem for p in LOGS.glob("p1*.json")) runs = {name: load_log(name) for name in run_names} runs = {k: v for k, v in runs.items() if v} print(f"Loaded {len(runs)} experiments:") for name in run_names: print(f" {name}") """), code("""\ experiment_groups = { "1A — Resolution": {"p1a_dcgan_64": "64×64 (raw)", "p1a_dcgan_128": "128×128 (raw)"}, "1B — Alignment": {"p1b_dcgan_full": "Full image (raw)", "p1b_dcgan_aligned": "MTCNN-aligned"}, "1C — Augmentation": {"p1c_dcgan_hflip": "H-flip only", "p1c_dcgan_full_aug": "H-flip + rot + colour"}, "1D — Dataset mixing": {"p1b_dcgan_aligned": "Aligned only", "p1d_dcgan_combined": "Aligned + raw mixed"}, } """), md("## 2. FID comparison table"), code("""\ rows = [] for name in run_names: r = runs[name]; cfg = r["config"] rows.append({ "Experiment": name, "Size": f"{cfg.get('image_size')}×{cfg.get('image_size')}", "Augment": cfg.get("augment", False), "FID@25": get_fid(r, 25), "FID@50": get_fid(r, 50), "G loss (ep50)": r["history"]["g_loss"][-1], "D loss (ep50)": r["history"]["d_loss"][-1], }) df = pd.DataFrame(rows).sort_values("FID@50") df.style.format({"FID@25": "{:.1f}", "FID@50": "{:.1f}", "G loss (ep50)": "{:.3f}", "D loss (ep50)": "{:.3f}"}) """), code("""\ fig, ax = plt.subplots(figsize=(10, 5)) labels = df["Experiment"].values fid25 = df["FID@25"].values fid50 = df["FID@50"].values x = np.arange(len(labels)); w = 0.35 ax.bar(x - w/2, fid25, w, label="FID @ 25", color="#5B8DB8", alpha=0.85) ax.bar(x + w/2, fid50, w, label="FID @ 50", color="#E8705A", alpha=0.85) ax.set_ylabel("FID (lower is better)") ax.set_title("Phase 1 — FID across all pipeline ablations") ax.set_xticks(x); ax.set_xticklabels(labels, rotation=30, ha="right") ax.legend(); plt.tight_layout(); plt.show() """), md("## 3. 
Per-group comparisons"), code("""\ fig, axes = plt.subplots(2, 2, figsize=(14, 10)) axes = axes.flatten() colors = ["#5B8DB8", "#E8705A"] for idx, (group_title, experiments) in enumerate(experiment_groups.items()): ax = axes[idx] for i, (run_name, label) in enumerate(experiments.items()): epochs, fid_vals = fid_series(runs[run_name]) f50 = get_fid(runs[run_name], 50) ax.plot(epochs, fid_vals, "o-", label=f"{label} (FID@50={f50:.1f})", color=colors[i], linewidth=2, markersize=8) ax.set_xlabel("Epoch"); ax.set_ylabel("FID") ax.set_title(group_title); ax.legend(fontsize=9) ax.set_xlim(20, 55) fig.suptitle("FID per ablation group", fontsize=14, fontweight="bold", y=1.01) plt.tight_layout(); plt.show() """), md("""\ ## 4. Data pipeline visualisation What each ablation actually changes — shown on the input data the model sees. """), code("""\ import random from PIL import Image import torchvision.transforms as T random.seed(0) RAW = Path("../../data/wiki") ALIGNED = Path("../../cropped/generator/wiki") def sample_paths(root, k=4): shards = [d for d in root.iterdir() if d.is_dir()] files = [] for s in random.sample(shards, min(8, len(shards))): files += list(s.glob("*.jpg"))[:50] return random.sample(files, min(k, len(files))) def matched_pairs(k=4): shards = sorted(d.name for d in ALIGNED.iterdir() if d.is_dir() and (RAW / d.name).is_dir()) pairs = [] for shard in random.sample(shards, min(8, len(shards))): for ali in (ALIGNED / shard).glob("*.jpg"): raw = RAW / shard / ali.name if raw.exists(): pairs.append((raw, ali)) if len(pairs) >= 50: break if len(pairs) >= 50: break return random.sample(pairs, min(k, len(pairs))) def show(ax, img, title=None): ax.imshow(img); ax.axis("off") if title: ax.set_title(title, fontsize=9) """), md("### 4A — Resolution\n\nSame raw image at 64×64 and 128×128. 
4× more pixels at 128 — too much for a vanilla DCGAN at 50 epochs."), code("""\ paths = sample_paths(RAW, k=4) fig, axes = plt.subplots(2, 4, figsize=(12, 6)) for col, p in enumerate(paths): img = Image.open(p).convert("RGB") show(axes[0][col], T.CenterCrop(min(img.size))(img).resize((64, 64)), "64×64") show(axes[1][col], T.CenterCrop(min(img.size))(img).resize((128, 128)), "128×128") fig.suptitle("1A — Resolution: same image at two scales", fontsize=12, fontweight="bold") plt.tight_layout(); plt.show() """), md("### 4B — Alignment\n\nRaw vs MTCNN-aligned 64×64 crops. Alignment removes pose/scale/translation variance so the model only has to learn identity."), code("""\ pairs = matched_pairs(k=4) fig, axes = plt.subplots(2, 4, figsize=(12, 6)) for col, (raw_p, ali_p) in enumerate(pairs): raw_img = Image.open(raw_p).convert("RGB") show(axes[0][col], T.CenterCrop(min(raw_img.size))(raw_img).resize((128, 128)), "raw") show(axes[1][col], Image.open(ali_p).convert("RGB"), "MTCNN-aligned") fig.suptitle("1B — Alignment: same source image, raw vs MTCNN-aligned", fontsize=12, fontweight="bold") plt.tight_layout(); plt.show() """), md("### 4C — Augmentation\n\nOne aligned image, three variants: original, hflip, and full augmentation (hflip + rotation ±5° + brightness/contrast/saturation jitter)."), code("""\ src = sample_paths(ALIGNED, k=1) if src: img = Image.open(src[0]).convert("RGB").resize((128, 128)) none = T.Compose([]) hflip = T.Compose([T.RandomHorizontalFlip(p=1.0)]) full = T.Compose([ T.RandomHorizontalFlip(p=1.0), T.RandomRotation(degrees=5, interpolation=T.InterpolationMode.BILINEAR), T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05), ]) fig, axes = plt.subplots(1, 6, figsize=(15, 3)) show(axes[0], none(img), "original") show(axes[1], hflip(img), "hflip") for i, ax in enumerate(axes[2:]): show(ax, full(img), f"full aug #{i+1}") fig.suptitle("1C — Augmentation: original vs hflip vs full augmentation", fontsize=12, fontweight="bold") plt.tight_layout(); 
plt.show() """), md("### 4D — Dataset mixing\n\nMixing raw + aligned roughly doubles within-batch variance, splitting generator capacity across two distributions."), code("""\ pairs = matched_pairs(k=4) fig, axes = plt.subplots(2, 4, figsize=(12, 6)) for col, (raw_p, ali_p) in enumerate(pairs): raw_img = Image.open(raw_p).convert("RGB") show(axes[0][col], T.CenterCrop(min(raw_img.size))(raw_img).resize((128, 128)), "raw (mixed in)") show(axes[1][col], Image.open(ali_p).convert("RGB"), "aligned") fig.suptitle("1D — Mixing: same source image, raw vs aligned", fontsize=12, fontweight="bold") plt.tight_layout(); plt.show() """), md(f"""\ ## 5. Conclusions Lowest FID of phase 1: **`{best_name}` at FID@50 = {best_val:.1f}** — but every model in this phase collapsed. The numbers below are pipeline-ranking signal, not model-quality claims. | Ablation | Winner | FID@50 | Loser FID@50 | Δ | |---|---|---|---|---| | 1A — Resolution | 64×64 | {get_fid(runs['p1a_dcgan_64'], 50):.1f} | {get_fid(runs['p1a_dcgan_128'], 50):.1f} (128×128) | DCGAN lacks capacity at 128×128 in 50 ep | | 1B — Alignment | MTCNN-aligned | {get_fid(runs['p1b_dcgan_aligned'], 50):.1f} | {get_fid(runs['p1b_dcgan_full'], 50):.1f} (full) | **largest single lever** | | 1C — Augmentation | H-flip + rot + colour | {get_fid(runs['p1c_dcgan_full_aug'], 50):.1f} | {get_fid(runs['p1c_dcgan_hflip'], 50):.1f} (H-flip) | moderate gain; per-family validation needed | | 1D — Dataset mixing | Aligned only | {get_fid(runs['p1b_dcgan_aligned'], 50):.1f} | {get_fid(runs['p1d_dcgan_combined'], 50):.1f} (mixed) | mixing raw+aligned doubles variance | **Pipeline carried forward to phases 2–5:** MTCNN-aligned crops, 64×64, full augmentation for GANs (H-flip-only kept as a safer default for VAE/DDPM), aligned-only (no mixing). The pipeline question is *answered*. The model-quality question is **not** — every phase 1 run collapsed. 
Phase 2 fixes that by upgrading the architecture (Wasserstein-GP, spectral norm, GroupNorm, self-attention) on the now-locked pipeline. """), ] write_nb("phase1_analysis", cells) # ───────────────────────────────────────────────────────────────────────────── # PHASE 2 — GAN architecture/objective evolution # ───────────────────────────────────────────────────────────────────────────── def build_phase2(): runs = {n: load(n) for n in ["p2_1_dcgan", "p2_2_wgan", "p2_3_wgan_sn_attn", "p2_4_wgan_sn_attn_128"]} best_name = min(runs, key=lambda n: best_fid(runs[n])[1] or 9e9) _, best_val = best_fid(runs[best_name]) p0_gan = load("p0_wgan") cells = [ md(f"""\ # Phase 2 — GAN Evolution With the data pipeline locked (phase 1), iterate on the **GAN itself**: objective, normalisation, attention, resolution. All four runs use the aligned-64 pipeline; only the model and training recipe change. | Run | Step | |-------------------------|------------------------------------------------| | `p2_1_dcgan` | Phase 1 best, retrained 100 epochs | | `p2_2_wgan` | BCE → Wasserstein-GP (n_critic=2, β=(0,0.9)) | | `p2_3_wgan_sn_attn` | + spectral norm + GroupNorm + self-attention | | `p2_4_wgan_sn_attn_128` | same as 2.3 but at 128×128 | **Headline result:** `{best_name}` — **best FID = {best_val:.1f}** at 100 epochs. """), md("""\ > ### ⚠ FID is not comparable across phases > > Phase 1's "best" was FID 33 (`p1c_dcgan_full_aug`). Phase 2's "best" is FID 110. > **This is not a regression.** The two numbers were computed under different > protocols: > > - Phase 1 used a quick proxy FID for fast pipeline ablation, with a smaller > real-image reference set, on the un-augmented validation split. > - Phase 2 uses the project's standard FID protocol — 5000 aligned 64×64 real > images from the matched augmentation pipeline (`fid_n_real: 5000`). > > Within phase 2 the deltas are meaningful (2.2 → 2.3 = **−311 FID** is a real > architecture jump). 
Don't compare phase 1 vs phase 2 numbers absolutely — > only compare within a phase, or against phase 5 which uses the same protocol. """), md(f"""\ ### Reference: phase 0 baseline (same family) `p0_wgan` was the un-aligned, no-augmentation, basic-architecture WGAN-GP — face blobs with no recognisable features (no FID logged). Phase 2 below shows what happens once the pipeline is fixed and the model is allowed to evolve. """), code(SHARED_IMPORTS), md("## 1. Load experiment logs"), code("""\ run_names = ["p2_1_dcgan", "p2_2_wgan", "p2_3_wgan_sn_attn", "p2_4_wgan_sn_attn_128"] run_labels = { "p2_1_dcgan": "2.1 DCGAN (BCE)", "p2_2_wgan": "2.2 WGAN-GP", "p2_3_wgan_sn_attn": "2.3 + SN + Attn", "p2_4_wgan_sn_attn_128": "2.4 + 128×128", } runs = {name: load_log(name) for name in run_names} runs = {k: v for k, v in runs.items() if v} for n in run_names: if n in runs: print(f" {n}: {len(runs[n]['history']['g_loss'])} epochs") else: print(f" {n}: MISSING") """), md("## 2. FID comparison table"), code("""\ rows = [] for name in run_names: if name not in runs: continue r = runs[name] epochs, fid_vals = fid_series(r) best = min(fid_vals) if fid_vals else None rows.append({ "Run": run_labels[name], "FID@25": get_fid(r, 25), "FID@50": get_fid(r, 50), "FID@100": get_fid(r, 100), "Best FID": best, "Train (min)": (r['history'].get('train_time_s') or 0) / 60, }) df = pd.DataFrame(rows).sort_values("Best FID") df.style.format({"FID@25": "{:.1f}", "FID@50": "{:.1f}", "FID@100": "{:.1f}", "Best FID": "{:.1f}", "Train (min)": "{:.1f}"}) """), md("## 3. 
FID curves — evolution"), code("""\ fig, ax = plt.subplots(figsize=(10, 5)) cmap = plt.cm.viridis for i, name in enumerate(run_names): if name not in runs: continue epochs, fid_vals = fid_series(runs[name]) ax.plot(epochs, fid_vals, "o-", label=run_labels[name], color=cmap(i / len(run_names)), linewidth=2, markersize=7) ax.set_xlabel("Epoch"); ax.set_ylabel("FID") ax.set_title("Phase 2 — FID curves") ax.legend(); plt.tight_layout(); plt.show() """), md("## 4. Training dynamics"), code("""\ fig, axes = plt.subplots(1, 2, figsize=(14, 5)) cmap = plt.cm.viridis for i, name in enumerate(run_names): if name not in runs: continue h = runs[name]["history"] color = cmap(i / len(run_names)) epochs = range(1, len(h["g_loss"]) + 1) axes[0].plot(epochs, h["g_loss"], color=color, label=run_labels[name], linewidth=1.2) if "w_dist" in h: axes[1].plot(epochs, h["w_dist"], color=color, label=run_labels[name], linewidth=1.2) elif "d_loss" in h: axes[1].plot(epochs, h["d_loss"], color=color, label=run_labels[name], linewidth=1.2, linestyle="--") axes[0].set_title("Generator loss"); axes[0].set_xlabel("Epoch"); axes[0].legend(fontsize=8) axes[1].set_title("Wasserstein distance / D loss"); axes[1].set_xlabel("Epoch"); axes[1].legend(fontsize=8) plt.tight_layout(); plt.show() """), md("""\ ## 5. Sample grids — epoch 100 - **2.1 and 2.2 collapsed** — vanilla DCGAN/WGAN-GP at this scale is still too weak, same failure mode as phase 1. Their FIDs (>400) confirm this; the grids reflect it. - **2.3 is the breakthrough** — spectral norm + GroupNorm + self-attention escape mode collapse and produce diverse, recognisable faces. - **2.4 (128×128) regresses** — same architecture at higher resolution at fixed compute under-trains. 
"""), code("""\ fig, axes = plt.subplots(1, 4, figsize=(16, 4.5)) for ax, name in zip(axes, run_names): img_path = SAMPLES / name / "epoch_0100.png" if img_path.exists(): ax.imshow(mpimg.imread(str(img_path))) f100 = get_fid(runs.get(name, {}), 100) if name in runs else None ax.set_title(f"{run_labels[name]}\\nFID@100={f100:.1f}" if f100 else run_labels[name], fontsize=9) ax.axis("off") plt.tight_layout(); plt.show() """), md("## 6. Progression — epoch 10 → 50 → 100"), code("""\ check_epochs = [10, 50, 100] for name in run_names: if name not in runs: continue fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4)) for ax, e in zip(axes, check_epochs): p = SAMPLES / name / f"epoch_{e:04d}.png" if p.exists(): ax.imshow(mpimg.imread(str(p))) f = get_fid(runs[name], e) ax.set_title(f"epoch {e}" + (f"\\nFID={f:.1f}" if f else ""), fontsize=9) ax.axis("off") fig.suptitle(run_labels[name], fontsize=11, fontweight="bold") plt.tight_layout(); plt.show() """), md("## 7. Pairwise comparison — what each step bought us"), code("""\ transitions = [ ("2.1 → 2.2: BCE → Wasserstein", "p2_1_dcgan", "p2_2_wgan"), ("2.2 → 2.3: + SN + GroupNorm + Attn", "p2_2_wgan", "p2_3_wgan_sn_attn"), ("2.3 → 2.4: 64 → 128 resolution", "p2_3_wgan_sn_attn", "p2_4_wgan_sn_attn_128"), ] for title, a, b in transitions: if a not in runs or b not in runs: continue fig, axes = plt.subplots(1, 2, figsize=(10, 5)) for ax, name in zip(axes, [a, b]): img_path = SAMPLES / name / "epoch_0100.png" if img_path.exists(): ax.imshow(mpimg.imread(str(img_path))) f = get_fid(runs[name], 100) ax.set_title(f"{run_labels[name]}\\nFID@100={f:.1f}" if f else run_labels[name], fontsize=10) ax.axis("off") fig.suptitle(title, fontsize=12, fontweight="bold") plt.tight_layout(); plt.show() """), md(f"""\ ## 8. 
Conclusions | Step | Run | Best FID | Δ vs prev | |---|---|---|---| | 2.1 DCGAN baseline | `p2_1_dcgan` | {best_fid(runs['p2_1_dcgan'])[1]:.1f} | — | | 2.2 + Wasserstein-GP | `p2_2_wgan` | {best_fid(runs['p2_2_wgan'])[1]:.1f} | {best_fid(runs['p2_2_wgan'])[1] - best_fid(runs['p2_1_dcgan'])[1]:+.1f} | | 2.3 + SN + Attn | `p2_3_wgan_sn_attn` | {best_fid(runs['p2_3_wgan_sn_attn'])[1]:.1f} | {best_fid(runs['p2_3_wgan_sn_attn'])[1] - best_fid(runs['p2_2_wgan'])[1]:+.1f} | | 2.4 + 128×128 | `p2_4_wgan_sn_attn_128` | {best_fid(runs['p2_4_wgan_sn_attn_128'])[1]:.1f} | {best_fid(runs['p2_4_wgan_sn_attn_128'])[1] - best_fid(runs['p2_3_wgan_sn_attn'])[1]:+.1f} | 2.1 and 2.2 are extensions of the phase 1 collapse — vanilla DCGAN/WGAN-GP doesn't escape mean-collapse just by changing the loss. The dramatic step is **2.2 → 2.3**: spectral norm + GroupNorm + self-attention break out of collapse and start producing recognisable faces (FID drops from 421 to 110). 2.4 (resolution bump) made things worse at fixed compute, so phase 5 retains 64×64. **Selected GAN architecture for phase 5: WGAN-GP with spectral norm + GroupNorm + self-attention at 64×64.** """), ] write_nb("phase2_analysis", cells) # ───────────────────────────────────────────────────────────────────────────── # PHASE 3 — VAE evolution # ───────────────────────────────────────────────────────────────────────────── def build_phase3(): runs = {n: load(n) for n in ["p3_1_vae", "p3_2_vae_perceptual", "p3_3_vae_patchgan"]} best_name = min(runs, key=lambda n: best_fid(runs[n])[1] or 9e9) _, best_val = best_fid(runs[best_name]) cells = [ md(f"""\ # Phase 3 — VAE Evolution Standard VAEs collapse to mean-blur. Phase 3 stacks losses on top of the basic MSE+KL objective to recover detail. 
| Run | Step | |----------------------|------------------------------------------------------| | `p3_1_vae` | MSE + KL only | | `p3_2_vae_perceptual`| + VGG16 perceptual loss (`lambda_perceptual=0.1`) | | `p3_3_vae_patchgan` | + PatchGAN adversarial loss (`lambda_adversarial=0.01`) | **Headline result:** `{best_name}` — **best FID = {best_val:.1f}** (prior samples) at 100 epochs. """), md("""\ ### Reference: phase 0 baseline (same family) `p0_vae` was MSE+KL on raw un-aligned data. Prior samples were heavily blurred mean-faces — the textbook VAE failure mode. Phase 3 keeps the encoder/decoder architecture and pipeline fixed and shows that adding perceptual + adversarial terms is what actually moves the needle. """), code(SHARED_IMPORTS), md("## 1. Load experiment logs"), code("""\ run_names = ["p3_1_vae", "p3_2_vae_perceptual", "p3_3_vae_patchgan"] run_labels = { "p3_1_vae": "3.1 MSE + KL", "p3_2_vae_perceptual": "3.2 + Perceptual", "p3_3_vae_patchgan": "3.3 + PatchGAN", } runs = {name: load_log(name) for name in run_names} runs = {k: v for k, v in runs.items() if v} for n in run_names: print(f" {n}: {'OK' if n in runs else 'MISSING'}") """), md("## 2. FID comparison table (prior samples)"), code("""\ rows = [] for name in run_names: if name not in runs: continue r = runs[name]; h = r["history"] _, fid_vals = fid_series(r) rows.append({ "Run": run_labels[name], "FID@50": get_fid(r, 50), "FID@100": get_fid(r, 100), "Best FID": min(fid_vals) if fid_vals else None, "Recon@100": h["recon_loss"][-1] if h.get("recon_loss") else None, "KL@100": h["kl_loss"][-1] if h.get("kl_loss") else None, "Train (min)": (h.get("train_time_s") or 0) / 60, }) df = pd.DataFrame(rows).sort_values("Best FID") df.style.format({"FID@50": "{:.1f}", "FID@100": "{:.1f}", "Best FID": "{:.1f}", "Recon@100": "{:.4f}", "KL@100": "{:.2f}", "Train (min)": "{:.1f}"}) """), md("## 3. 
FID curves — evolution"), code("""\ fig, ax = plt.subplots(figsize=(10, 5)) cmap = plt.cm.plasma for i, name in enumerate(run_names): if name not in runs: continue epochs, fid_vals = fid_series(runs[name]) ax.plot(epochs, fid_vals, "o-", label=run_labels[name], color=cmap(i / len(run_names)), linewidth=2, markersize=7) ax.set_xlabel("Epoch"); ax.set_ylabel("FID (prior samples)") ax.set_title("Phase 3 — FID curves"); ax.legend() plt.tight_layout(); plt.show() """), md("## 4. Training loss components"), code("""\ fig, axes = plt.subplots(1, 3, figsize=(16, 4)) cmap = plt.cm.plasma for i, name in enumerate(run_names): if name not in runs: continue h = runs[name]["history"]; color = cmap(i / len(run_names)) epochs = range(1, len(h["recon_loss"]) + 1) axes[0].plot(epochs, h["recon_loss"], color=color, label=run_labels[name]) axes[1].plot(epochs, h["kl_loss"], color=color, label=run_labels[name]) if "perc_loss" in h and any(h["perc_loss"]): axes[2].plot(epochs, h["perc_loss"], color=color, label=run_labels[name]) axes[0].set_title("Reconstruction (MSE)"); axes[0].set_xlabel("Epoch"); axes[0].legend(fontsize=8) axes[1].set_title("KL divergence"); axes[1].set_xlabel("Epoch"); axes[1].legend(fontsize=8) axes[2].set_title("Perceptual (VGG16)"); axes[2].set_xlabel("Epoch"); axes[2].legend(fontsize=8) plt.tight_layout(); plt.show() """), md("## 5. Prior samples — epoch 100"), code("""\ fig, axes = plt.subplots(1, 3, figsize=(13, 4.5)) for ax, name in zip(axes, run_names): img_path = SAMPLES / name / "epoch_0100.png" if img_path.exists(): ax.imshow(mpimg.imread(str(img_path))) f = get_fid(runs.get(name, {}), 100) if name in runs else None ax.set_title(f"{run_labels[name]}\\nFID@100={f:.1f}" if f else run_labels[name], fontsize=10) ax.axis("off") fig.suptitle("Prior samples (decoded from N(0, I))", fontsize=12, fontweight="bold") plt.tight_layout(); plt.show() """), md("## 6. 
Reconstructions — epoch 100"), code("""\ fig, axes = plt.subplots(1, 3, figsize=(13, 4.5)) for ax, name in zip(axes, run_names): img_path = SAMPLES / name / "epoch_0100_recon.png" if img_path.exists(): ax.imshow(mpimg.imread(str(img_path))) ax.set_title(run_labels[name], fontsize=10) ax.axis("off") fig.suptitle("Reconstructions (real / decoded interleaved)", fontsize=12, fontweight="bold") plt.tight_layout(); plt.show() """), md("## 7. Progression — epoch 10 → 50 → 100 (prior samples)"), code("""\ check_epochs = [10, 50, 100] for name in run_names: if name not in runs: continue fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4)) for ax, e in zip(axes, check_epochs): p = SAMPLES / name / f"epoch_{e:04d}.png" if p.exists(): ax.imshow(mpimg.imread(str(p))) f = get_fid(runs[name], e) ax.set_title(f"epoch {e}" + (f"\\nFID={f:.1f}" if f else ""), fontsize=9) ax.axis("off") fig.suptitle(run_labels[name], fontsize=11, fontweight="bold") plt.tight_layout(); plt.show() """), md(f"""\ ## 8. Conclusions | Step | Run | Best FID | Δ vs prev | |---|---|---|---| | 3.1 MSE + KL | `p3_1_vae` | {best_fid(runs['p3_1_vae'])[1]:.1f} | — | | 3.2 + Perceptual | `p3_2_vae_perceptual` | {best_fid(runs['p3_2_vae_perceptual'])[1]:.1f} | {best_fid(runs['p3_2_vae_perceptual'])[1] - best_fid(runs['p3_1_vae'])[1]:+.1f} | | 3.3 + PatchGAN | `p3_3_vae_patchgan` | {best_fid(runs['p3_3_vae_patchgan'])[1]:.1f} | {best_fid(runs['p3_3_vae_patchgan'])[1] - best_fid(runs['p3_2_vae_perceptual'])[1]:+.1f} | **Both additions help, monotonically.** Perceptual loss recovers high-frequency detail (sharper hair/skin texture); PatchGAN adversarial loss further pushes the prior samples toward the data manifold. The selected VAE recipe for phase 5 is **MSE + 0.25·KL + 0.1·VGG-perceptual + 0.01·PatchGAN-adversarial**. 
"""), ] write_nb("phase3_analysis", cells) # ───────────────────────────────────────────────────────────────────────────── # PHASE 4 — DDPM evolution # ───────────────────────────────────────────────────────────────────────────── def build_phase4(): runs = {n: load(n) for n in ["p4_1_ddpm_linear", "p4_2_ddpm_cosine", "p4_3_ddpm_vpred", "p4_4_ddpm_wider"]} best_name = min(runs, key=lambda n: best_fid(runs[n])[1] or 9e9) _, best_val = best_fid(runs[best_name]) cells = [ md(f"""\ # Phase 4 — DDPM Evolution Iterate on the diffusion side: noise schedule, prediction target, model width. Sampling everywhere uses DDIM with 50 steps (matches training preview); FID uses DDIM-100 against 5000 real images. | Run | Step | |--------------------|--------------------------------------------------------| | `p4_1_ddpm_linear` | Linear β-schedule, ε-prediction (DDPM baseline) | | `p4_2_ddpm_cosine` | Cosine β-schedule (Nichol & Dhariwal) | | `p4_3_ddpm_vpred` | + v-prediction target (Salimans & Ho) | | `p4_4_ddpm_wider` | + base_ch 128 → 192, num_res_blocks 2, attn at 32/16/8 | **Headline result:** `{best_name}` — **best FID = {best_val:.1f}** at 100 epochs. """), md("""\ ### Reference: phase 0 baseline (same family) `p0_ddpm` was a vanilla DDPM (linear schedule, ε-prediction, base_ch=128) on raw un-aligned data. Outputs were noisy face-shaped textures. Phase 4 fixes the pipeline (aligned 64) and walks through the standard set of post-2020 DDPM improvements one at a time. """), code(SHARED_IMPORTS), md("## 1. Load experiment logs"), code("""\ run_names = ["p4_1_ddpm_linear", "p4_2_ddpm_cosine", "p4_3_ddpm_vpred", "p4_4_ddpm_wider"] run_labels = { "p4_1_ddpm_linear": "4.1 linear / ε", "p4_2_ddpm_cosine": "4.2 cosine / ε", "p4_3_ddpm_vpred": "4.3 cosine / v", "p4_4_ddpm_wider": "4.4 wider net", } runs = {n: load_log(n) for n in run_names} runs = {k: v for k, v in runs.items() if v} for n in run_names: print(f" {n}: {'OK' if n in runs else 'MISSING'}") """), md("## 2. 
FID comparison table"), code("""\ rows = [] for name in run_names: if name not in runs: continue r = runs[name]; _, fid_vals = fid_series(r) rows.append({ "Run": run_labels[name], "FID@25": get_fid(r, 25), "FID@50": get_fid(r, 50), "FID@100": get_fid(r, 100), "Best FID": min(fid_vals) if fid_vals else None, "Loss@100": r["history"]["loss"][-1], "Train (min)": (r['history'].get('train_time_s') or 0) / 60, }) df = pd.DataFrame(rows).sort_values("Best FID") df.style.format({"FID@25": "{:.1f}", "FID@50": "{:.1f}", "FID@100": "{:.1f}", "Best FID": "{:.1f}", "Loss@100": "{:.4f}", "Train (min)": "{:.1f}"}) """), md("## 3. FID curves — evolution"), code("""\ fig, ax = plt.subplots(figsize=(10, 5)) cmap = plt.cm.cividis for i, name in enumerate(run_names): if name not in runs: continue epochs, fid_vals = fid_series(runs[name]) ax.plot(epochs, fid_vals, "o-", label=run_labels[name], color=cmap(i / len(run_names)), linewidth=2, markersize=7) ax.set_xlabel("Epoch"); ax.set_ylabel("FID (DDIM-100)") ax.set_title("Phase 4 — FID curves"); ax.legend() plt.tight_layout(); plt.show() """), md("## 4. Training loss"), code("""\ fig, ax = plt.subplots(figsize=(10, 4)) cmap = plt.cm.cividis for i, name in enumerate(run_names): if name not in runs: continue h = runs[name]["history"] epochs = range(1, len(h["loss"]) + 1) ax.plot(epochs, h["loss"], color=cmap(i / len(run_names)), label=run_labels[name], linewidth=1.3) ax.set_xlabel("Epoch"); ax.set_ylabel("MSE on prediction target") ax.set_title("Loss (note: ε-MSE and v-MSE are not directly comparable)") ax.legend(); plt.tight_layout(); plt.show() """), md("## 5. 
Sample grids — epoch 100"), code("""\ fig, axes = plt.subplots(1, 4, figsize=(16, 4.5)) for ax, name in zip(axes, run_names): img_path = SAMPLES / name / "epoch_0100.png" if img_path.exists(): ax.imshow(mpimg.imread(str(img_path))) f = get_fid(runs.get(name, {}), 100) if name in runs else None ax.set_title(f"{run_labels[name]}\\nFID@100={f:.1f}" if f else run_labels[name], fontsize=9) ax.axis("off") plt.tight_layout(); plt.show() """), md("## 6. Progression — epoch 10 → 50 → 100"), code("""\ check_epochs = [10, 50, 100] for name in run_names: if name not in runs: continue fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4)) for ax, e in zip(axes, check_epochs): p = SAMPLES / name / f"epoch_{e:04d}.png" if p.exists(): ax.imshow(mpimg.imread(str(p))) f = get_fid(runs[name], e) ax.set_title(f"epoch {e}" + (f"\\nFID={f:.1f}" if f else ""), fontsize=9) ax.axis("off") fig.suptitle(run_labels[name], fontsize=11, fontweight="bold") plt.tight_layout(); plt.show() """), md("## 7. Noise schedule visualisation\n\nWhy cosine helps: linear allocates many timesteps near t=T where the image is already near-pure noise, wasting model capacity."), code("""\ import math T = 1000; t = np.arange(T) betas_lin = np.linspace(1e-4, 0.02, T) ab_lin = np.cumprod(1 - betas_lin) s = 0.008 f = np.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2 f = f / f[0] betas_cos = np.clip(1 - f[1:] / f[:-1], 0, 0.999) ab_cos = np.cumprod(1 - betas_cos) fig, axes = plt.subplots(1, 2, figsize=(13, 4)) axes[0].plot(t, ab_lin, label="linear", color="#5B8DB8", linewidth=2) axes[0].plot(t[:len(ab_cos)], ab_cos, label="cosine", color="#E8705A", linewidth=2) axes[0].set_xlabel("Timestep t"); axes[0].set_ylabel("ᾱ_t (signal fraction)") axes[0].set_title("Cumulative signal preservation"); axes[0].legend() axes[1].plot(betas_lin, label="linear β", color="#5B8DB8", linewidth=2) axes[1].plot(betas_cos, label="cosine β", color="#E8705A", linewidth=2) axes[1].set_xlabel("Timestep t"); axes[1].set_ylabel("β_t"); 
# ─────────────────────────────────────────────────────────────────────────────
# PHASE 5 — Cross-family
# ─────────────────────────────────────────────────────────────────────────────

def build_phase5():
    """Build ``phase5_analysis.ipynb`` (GAN vs VAE vs DDPM, head-to-head).

    Loads the three ``p5_*`` logs with ``load``, ranks the families by best FID
    and bakes the ranking into the headline and the conclusions table, then
    writes the notebook through ``write_nb("phase5_analysis", ...)``.

    A family whose log is missing, has no FID history, or has no recorded
    training time is rendered as ``n/a`` (and sorted last) instead of
    aborting the build with a ``TypeError`` from formatting ``None``.
    """
    # Log name -> architecture label shown in the conclusions table.
    arch = {
        "p5_gan": "WGAN-GP + SN + Attn",
        "p5_vae": "VAE + Perceptual + PatchGAN",
        "p5_ddpm": "DDPM cosine v-pred wider",
    }

    def fmt(value, spec=".1f"):
        """Format a metric, or return 'n/a' when the logs did not provide it."""
        return "n/a" if value is None else format(value, spec)

    # (log name, best FID, epoch of best FID, train minutes), best FID first.
    rows = []
    for name in arch:
        log = load(name)
        epoch, fid = best_fid(log)
        rows.append((name, fid, epoch, time_min(log)))
    rows.sort(key=lambda r: float("inf") if r[1] is None else r[1])

    headline = ", ".join(f"{name}={fmt(fid)}" for name, fid, _, _ in rows)

    conclusions_table = "\n".join(
        f"| {name.split('_')[1].upper()} | {arch[name]} | {fmt(fid)} | "
        f"{fmt(minutes)}{'' if minutes is None else ' min'} |"
        for name, fid, _, minutes in rows
    )

    # NOTE: the ranking sentence in the conclusions cell is hand-written
    # analysis; only the headline and the table are derived from the logs.
    cells = [
        md(f"""\
# Phase 5 — Cross-Family Comparison

Take the best recipe from each family (phases 2/3/4) and train each on identical
data to the same epoch budget. Per-family iteration analyses live in their own
notebooks (phase 2 for GAN, phase 3 for VAE, phase 4 for DDPM); this notebook is
**only** about comparing the three families head-to-head.

**Headline FIDs (best across training):** {headline}.
"""),

        code(SHARED_IMPORTS + """
FAMILIES = {
    "GAN": {"p5": "p5_gan", "color": "#5B8DB8", "label": "WGAN-GP + SN + Attn"},
    "VAE": {"p5": "p5_vae", "color": "#E8B85A", "label": "VAE + Perceptual + PatchGAN"},
    "DDPM": {"p5": "p5_ddpm", "color": "#E8705A", "label": "DDPM cosine v-pred wider"},
}
logs_p5 = {fam: load_log(info["p5"]) for fam, info in FAMILIES.items()}
"""),

        md("## 1. Quantitative summary"),
        code("""\
rows = []
for fam, info in FAMILIES.items():
    log = logs_p5[fam]
    if log is None: continue
    h = log["history"]; cfg = log["config"]
    _, fid_vals = fid_series(log)
    rows.append({
        "Family": fam,
        "Architecture": info["label"],
        "Resolution": f"{cfg.get('image_size')}×{cfg.get('image_size')}",
        "Epochs": len(h.get("loss") or h.get("g_loss") or h.get("recon_loss") or []),
        "Best FID": min(fid_vals) if fid_vals else None,
        "Last FID": fid_vals[-1] if fid_vals else None,
        "Train (min)": (h.get("train_time_s") or 0) / 60,
    })
df = pd.DataFrame(rows).sort_values("Best FID")
df.style.format({"Best FID": "{:.1f}", "Last FID": "{:.1f}", "Train (min)": "{:.1f}"})
"""),

        md("## 2. FID curves — all three families"),
        code("""\
fig, ax = plt.subplots(figsize=(10, 5))
for fam, info in FAMILIES.items():
    log = logs_p5[fam]
    if log is None: continue
    epochs, vals = fid_series(log)
    ax.plot(epochs, vals, "o-", color=info["color"],
            label=f"{fam} ({info['label']})", linewidth=2, markersize=7)
ax.set_xlabel("Epoch"); ax.set_ylabel("FID")
ax.set_title("Phase 5 — FID across families")
ax.legend(); plt.tight_layout(); plt.show()
"""),

        # The cell shows the LAST saved preview of each run (annotated with the
        # run's best FID), so the heading says "final", matching the suptitle.
        md("## 3. Final sample grids (latest preview)"),
        code("""\
fig, axes = plt.subplots(1, 3, figsize=(16, 6))
for ax, (fam, info) in zip(axes, FAMILIES.items()):
    log = logs_p5[fam]
    samples_dir = SAMPLES / info["p5"]
    pngs = sorted(samples_dir.glob("epoch_*.png"))
    pngs = [p for p in pngs if "_recon" not in p.stem]
    if not pngs:
        ax.set_title(f"{fam} (no samples)"); ax.axis("off"); continue
    img_path = pngs[-1]  # last preview = closest to final_ema
    ax.imshow(mpimg.imread(str(img_path)))
    v = None
    if log:
        _, fid_vals = fid_series(log)
        v = min(fid_vals) if fid_vals else None
    ax.set_title(f"{fam} — {info['label']}\\n{img_path.stem}" + (f"  best FID={v:.1f}" if v else ""),
                 fontsize=10)
    ax.axis("off")
fig.suptitle("Final samples from each family", fontsize=13, fontweight="bold")
plt.tight_layout(); plt.show()
"""),

        md("## 4. Training progression — early → late"),
        code("""\
for fam, info in FAMILIES.items():
    log = logs_p5[fam]
    samples_dir = SAMPLES / info["p5"]
    pngs = sorted(p for p in samples_dir.glob("epoch_*.png") if "_recon" not in p.stem)
    if not pngs: continue
    # Pick ~4 evenly spaced previews
    n_pick = min(4, len(pngs))
    picks = [pngs[i * (len(pngs) - 1) // (n_pick - 1)] for i in range(n_pick)] if n_pick > 1 else pngs
    fig, axes = plt.subplots(1, len(picks), figsize=(4 * len(picks), 4))
    if len(picks) == 1: axes = [axes]
    for ax, p in zip(axes, picks):
        ax.imshow(mpimg.imread(str(p)))
        ep = int(p.stem.split("_")[1])
        f = get_fid(log, ep) if log else None
        ax.set_title(f"epoch {ep}" + (f"\\nFID={f:.1f}" if f else ""), fontsize=9)
        ax.axis("off")
    fig.suptitle(f"{fam} — {info['label']}", fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),

        md("## 5. Per-family training loss"),
        code("""\
fig, axes = plt.subplots(1, 3, figsize=(18, 4))
for ax, (fam, info) in zip(axes, FAMILIES.items()):
    log = logs_p5[fam]
    if not log:
        ax.set_title(f"{fam} (missing)"); ax.axis("off"); continue
    h = log["history"]; c = info["color"]
    ax2 = None
    if fam == "GAN":
        ax.plot(h["g_loss"], label="G loss", color=c, linewidth=1.2)
        if "w_dist" in h:
            ax.plot(h["w_dist"], label="W-dist", color=c, linewidth=1.2, linestyle="--")
        ax.set_ylabel("Loss / W-distance")
    elif fam == "VAE":
        ax.plot(h["recon_loss"], label="recon", color=c, linewidth=1.2)
        ax2 = ax.twinx()
        ax2.plot(h["kl_loss"], label="KL", color=c, alpha=0.5, linestyle="--")
        ax2.set_ylabel("KL")
        ax.set_ylabel("Recon")
    else:  # DDPM
        ax.plot(h["loss"], label="loss", color=c, linewidth=1.2)
        ax.set_ylabel("MSE on v-prediction")
    ax.set_xlabel("Epoch"); ax.set_title(f"{fam}")
    # Merge twin-axis entries so the KL curve shows up in the legend too.
    handles, labels = ax.get_legend_handles_labels()
    if ax2 is not None:
        extra_handles, extra_labels = ax2.get_legend_handles_labels()
        handles += extra_handles; labels += extra_labels
    ax.legend(handles, labels, loc="upper right", fontsize=8)
plt.tight_layout(); plt.show()
"""),

        md("""\
## 6. Latent interpolation — GAN and VAE

Smooth interpolation between two latent codes reveals whether the generator has
learned a continuous manifold rather than a sparse memorisation. DDPM has no
encoder, so this section is GAN/VAE only.

**Note on checkpoint loading:** the cell below uses the same checkpoint priority
as `tools/sampling.py` — `final_ema` first, then `best_ema` as fallback. This is
important: `best_ema` is the lowest-FID snapshot, which for slowly-converging
runs (e.g. DDPM) can be saved while the EMA shadow is still close to random
init, producing pure noise.
"""),

        code("""\
import sys
sys.path.insert(0, "..")
import torch
from src.models import get_model
from src.utils import load_config

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def load_ema_model(run_name, config_name):
    cfg = load_config(str(Path("../configs/phase5") / config_name))
    model_obj, _ = get_model(cfg)
    model = model_obj[0] if isinstance(model_obj, tuple) else model_obj
    # final_ema first — see note above
    for fname in [f"{run_name}_final_ema.pt", f"{run_name}_best_ema.pt"]:
        p = Path("../outputs/models") / fname
        if not p.exists(): continue
        sd = torch.load(p, map_location=DEVICE, weights_only=True)
        missing, unexpected = model.load_state_dict(sd, strict=False)
        if not missing and not unexpected:
            print(f"  loaded {fname}")
            return model.to(DEVICE).eval(), cfg
    raise FileNotFoundError(f"No usable EMA checkpoint for {run_name}")

def slerp(z1, z2, t):
    z1n = z1 / z1.norm(); z2n = z2 / z2.norm()
    omega = torch.acos((z1n * z2n).sum().clamp(-1, 1))
    if omega.abs() < 1e-6: return (1 - t) * z1 + t * z2
    return (torch.sin((1 - t) * omega) / torch.sin(omega)) * z1 + \\
           (torch.sin(t * omega) / torch.sin(omega)) * z2
"""),

        code("""\
# ── GAN slerp interpolation ─────────────────────────────────────────────────
try:
    gan_model, gan_cfg = load_ema_model("p5_gan", "p5_gan.json")
    latent_dim = gan_cfg.get("latent_dim", 128)
    z1 = torch.randn(1, latent_dim, 1, 1, device=DEVICE)
    z2 = torch.randn(1, latent_dim, 1, 1, device=DEVICE)
    with torch.no_grad():
        zs = torch.cat([slerp(z1, z2, t) for t in torch.linspace(0, 1, 10)])
        imgs = gan_model(zs).clamp(-1, 1)
    imgs = (imgs + 1) / 2
    fig, axes = plt.subplots(1, 10, figsize=(20, 2.5))
    for ax, img in zip(axes, imgs.cpu()):
        ax.imshow(img.permute(1, 2, 0).numpy()); ax.axis("off")
    fig.suptitle("GAN — slerp latent interpolation z₁ → z₂", fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
except Exception as e:
    print(f"GAN interpolation skipped: {e}")
"""),

        code("""\
# ── VAE encode-interpolate-decode ───────────────────────────────────────────
try:
    from src.data import GeneratorDataset, get_transform
    vae_model, vae_cfg = load_ema_model("p5_vae", "p5_vae.json")
    ds = GeneratorDataset("../" + vae_cfg["data_dir"],
                          sources=vae_cfg.get("sources", ["wiki"]),
                          transform=get_transform(vae_cfg["image_size"], augment=False))
    real = torch.stack([ds[i] for i in range(2)]).to(DEVICE)
    with torch.no_grad():
        mu, logvar = vae_model.encode(real)
        z1, z2 = mu[0:1], mu[1:2]
        zs = torch.cat([(1 - t) * z1 + t * z2 for t in torch.linspace(0, 1, 10)])
        imgs = vae_model.decode(zs).clamp(-1, 1)
    imgs = (imgs + 1) / 2
    fig, axes = plt.subplots(1, 10, figsize=(20, 2.5))
    for ax, img in zip(axes, imgs.cpu()):
        ax.imshow(img.permute(1, 2, 0).numpy()); ax.axis("off")
    fig.suptitle("VAE — linear latent interpolation (encoded real images)",
                 fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
except Exception as e:
    print(f"VAE interpolation skipped: {e}")
"""),

        md(f"""\
## 7. Conclusions

| Family | Architecture | Best FID | Train time |
|---|---|---|---|
{conclusions_table}

The ranking is **DDPM ≈ GAN ≪ VAE** by FID. DDPM and GAN trade places depending
on sampling settings (truncation, DDIM step count) but both clearly beat the VAE
here — the perceptual + adversarial losses help VAE reconstructions more than
they help its prior samples (which is what FID measures).

**Practical sampling notes** (encoded in `tools/sampling.py`):
- Load `*_final_ema.pt` rather than `*_best_ema.pt` — the latter can be saved
  very early for slowly-converging runs.
- DDPM sampling at DDIM-50 matches the training preview; DDIM-100 is for FID only.
- GAN truncation is **off by default** (matches training); enable for sharper
  but less diverse samples.
"""),
    ]
    write_nb("phase5_analysis", cells)
"""), ] write_nb("phase5_analysis", cells) if __name__ == "__main__": print("Building notebooks...") build_phase0() build_phase1() build_phase2() build_phase3() build_phase4() build_phase5() print("Done.")