1581 lines
67 KiB
Python
1581 lines
67 KiB
Python
"""
|
|
Build all generator analysis notebooks from a single source of truth.
|
|
|
|
Run from generator/notebooks/: python _build.py
|
|
|
|
The notebooks are report chapters, not experiment launchers. They load saved
|
|
logs, samples, checkpoints, and figures only. Each phase follows a consistent
|
|
story structure: goal, what changed, evidence, decision, and conclusion.
|
|
|
|
Real metric values are pulled from outputs/logs/*.json at build time and
|
|
rendered into markdown headers and conclusions, so reports never drift from data.
|
|
"""
|
|
import json
|
|
from pathlib import Path
|
|
|
|
# Directory layout, resolved relative to this script
# (generator/notebooks/_build.py -> ROOT is generator/).
ROOT = Path(__file__).resolve().parents[1]
LOGS = ROOT.joinpath("outputs", "logs")  # saved training-log JSON files
OUT = ROOT.joinpath("notebooks")  # where the generated .ipynb files go
|
|
|
|
|
|
# notebook helpers
|
|
def md(text):
    """Build a Jupyter markdown cell whose source is *text*, kept line by line."""
    lines = text.splitlines(keepends=True)
    return dict(cell_type="markdown", metadata={}, source=lines)
|
|
def code(text):
    """Build an unexecuted Jupyter code cell whose source is *text*, kept line by line."""
    lines = text.splitlines(keepends=True)
    return dict(
        cell_type="code",
        metadata={},
        execution_count=None,  # never executed at build time
        outputs=[],
        source=lines,
    )
|
|
|
|
def write_nb(name, cells):
    """Serialise *cells* to ``OUT/<name>.ipynb`` as an nbformat 4.5 notebook.

    nbformat 4.5 requires every cell to carry a unique ``id``. Cells that lack
    one get a deterministic ``cell-<index>`` id, so rebuilding produces stable
    diffs. The caller's cell dicts are copied, not mutated.

    Args:
        name: notebook file stem (without ``.ipynb``).
        cells: list of cell dicts as produced by ``md``/``code``.
    """
    cells = [
        {**cell, "id": cell.get("id", f"cell-{idx}")}
        for idx, cell in enumerate(cells)
    ]
    nb = {
        "cells": cells,
        "metadata": {
            "kernelspec": {"name": "python3", "display_name": "Python 3"},
            "language_info": {"name": "python"},
        },
        "nbformat": 4,
        "nbformat_minor": 5,
    }
    OUT.mkdir(parents=True, exist_ok=True)
    path = OUT / f"{name}.ipynb"
    path.write_text(json.dumps(nb, indent=1), encoding="utf-8")
    print(f" wrote {path.relative_to(ROOT)}")
|
|
|
|
|
|
# log-derived facts, computed once and baked into markdown
|
|
def load(name):
    """Return the parsed ``LOGS/<name>.json`` training log, or None if it is absent.

    The file is read inside a ``with`` block so the handle is closed
    deterministically; the previous ``json.load(open(p))`` leaked it until
    garbage collection. JSON is read as UTF-8 rather than the locale encoding.
    """
    path = LOGS / f"{name}.json"
    if not path.exists():
        return None
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)
|
|
|
|
def best_fid(log):
    """Return ``(epoch, fid)`` for the lowest FID recorded in *log*.

    Epoch keys are stored as strings in the JSON and are converted to int.
    Candidates are ordered by epoch before taking the minimum, so ties resolve
    to the earliest epoch. Returns ``(None, None)`` when *log* is falsy or has
    no FID history.
    """
    history = log.get("history", {}) if log else {}
    fid_by_epoch = history.get("fid", {})
    if not fid_by_epoch:
        return None, None
    ordered = sorted((int(epoch), value) for epoch, value in fid_by_epoch.items())
    return min(ordered, key=lambda pair: pair[1])
|
|
|
|
def time_min(log):
    """Return the logged training wall time in minutes, or None when unavailable.

    A missing log, a missing ``train_time_s`` entry, or a zero value all yield None.
    """
    seconds = (log or {}).get("history", {}).get("train_time_s")
    if not seconds:
        return None
    return seconds / 60
|
|
|
|
def get_fid(log, epoch):
    """Build-time helper (mirrors the runtime helper in SHARED_IMPORTS).

    Looks up the FID logged at *epoch* (keys are stringified epochs); returns
    None when the log, its FID history, or that epoch is missing.
    """
    history = (log or {}).get("history", {})
    fid_by_epoch = history.get("fid", {})
    return fid_by_epoch.get(str(epoch))
|
|
|
|
|
|
# shared imports cell, used by all phase notebooks
|
|
# Source of the first code cell in every phase notebook. It locates the saved
# outputs and defines the read-only helpers the analysis cells call.
# load_log closes its file handle, and get_fid/fid_series accept a missing
# (None) log, mirroring the build-time get_fid helper.
SHARED_IMPORTS = """\
import json
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import pandas as pd

plt.rcParams.update({"figure.dpi": 120, "font.size": 10})

try:
    display
except NameError:
    def display(obj):
        print(obj)

def find_generator_root():
    for base in [Path.cwd(), *Path.cwd().parents]:
        for candidate in [base, base / "generator"]:
            if (candidate / "outputs" / "logs").exists() and (candidate / "outputs" / "samples").exists():
                return candidate.resolve()
    raise FileNotFoundError("Could not locate generator/outputs from the current working directory")

GENERATOR_ROOT = find_generator_root()
PROJECT_ROOT = GENERATOR_ROOT.parent
OUTPUTS = GENERATOR_ROOT / "outputs"
LOGS = OUTPUTS / "logs"
SAMPLES = OUTPUTS / "samples"


def load_log(name):
    p = LOGS / f"{name}.json"
    if not p.exists():
        return None
    with open(p, encoding="utf-8") as fh:
        return json.load(fh)

def get_fid(log, epoch):
    fid = (log or {}).get("history", {}).get("fid", {})
    return fid.get(str(epoch))

def fid_series(log):
    fid = (log or {}).get("history", {}).get("fid", {})
    items = sorted((int(k), v) for k, v in fid.items())
    return [e for e, _ in items], [v for _, v in items]

def show_image_or_missing(ax, path, title=None):
    if path.exists():
        ax.imshow(mpimg.imread(str(path)))
    else:
        ax.text(0.5, 0.5, f"missing artifact\\n{path.name}", ha="center", va="center", transform=ax.transAxes)
    if title:
        ax.set_title(title, fontsize=9)
    ax.axis("off")

"""
|
|
|
|
|
|
# PHASE 0 - Baseline sanity check
|
|
def build_phase0():
    """Write ``phase0_analysis.ipynb``, the raw-pipeline baseline sanity check.

    Epoch counts in the intro table are read from the Phase 0 logs at build
    time ('n/a' when a log is absent). The notebook cells only load saved logs
    and sample PNGs; each loss panel tolerates a missing log instead of raising
    KeyError.
    """
    p0 = {n: load(n) for n in ["p0_wgan", "p0_vae", "p0_ddpm", "p0_ddpm_small"]}

    def logged_epochs(name, key):
        # Length of one per-epoch history series, or 'n/a' if the log is absent.
        log = p0[name]
        return len(log["history"][key]) if log else "n/a"

    cells = [
        md(f"""\
# Phase 0 - Baseline Sanity Check

Phase 0 is the starting point of the generator story. It uses the raw, un-aligned
images and very plain versions of each model family so we can confirm that the
training code runs end-to-end before making any stronger claims.

The goal is not to choose a final model here. The goal is to expose the first
failure modes: rough WGAN blobs, blurry VAE averages, and noisy DDPM textures.
Those failures motivate the pipeline work in Phase 1.

## What this phase changes

Nothing is optimized yet. This phase keeps the input pipeline rough and uses the
minimal available recipes:

| Run | Family | Logged epochs | Purpose |
|---|---|---:|---|
| `p0_wgan` | WGAN-GP | {logged_epochs('p0_wgan', 'g_loss')} | Basic generator/critic sanity check |
| `p0_vae` | VAE | {logged_epochs('p0_vae', 'loss')} | MSE + KL reconstruction baseline |
| `p0_ddpm` | DDPM | {logged_epochs('p0_ddpm', 'loss')} | Linear schedule, epsilon-prediction baseline |
| `p0_ddpm_small` | DDPM small | {logged_epochs('p0_ddpm_small', 'loss')} | Reduced-capacity sanity variant |

FID was not logged in Phase 0. The evidence here is loss behavior plus saved
sample grids.
"""),
        code(SHARED_IMPORTS),
        md("## 1. Training loss curves\n\nThese curves check that the loops ran and produced stable logs. They are not enough to prove visual quality."),
        code("""\
runs = {n: load_log(n) for n in ["p0_wgan", "p0_vae", "p0_ddpm", "p0_ddpm_small"]}
runs = {k: v for k, v in runs.items() if v}


def mark_missing(ax, name):
    # Keep the panel layout stable when a log was never written.
    ax.text(0.5, 0.5, f"{name}\\n(log missing)", ha="center", va="center", transform=ax.transAxes)
    ax.set_title(name)
    ax.axis("off")


fig, axes = plt.subplots(1, 3, figsize=(16, 4))

# WGAN: generator loss vs critic loss
if "p0_wgan" in runs:
    h = runs["p0_wgan"]["history"]
    ep = range(1, len(h["g_loss"]) + 1)
    axes[0].plot(ep, h["g_loss"], label="G loss", color="#5B8DB8")
    axes[0].plot(ep, h["c_loss"], label="C loss", color="#E8705A", alpha=0.7)
    axes[0].set_title("p0_wgan - generator vs critic loss")
    axes[0].set_xlabel("Epoch"); axes[0].legend()
else:
    mark_missing(axes[0], "p0_wgan")

# VAE: reconstruction vs KL components
if "p0_vae" in runs:
    h = runs["p0_vae"]["history"]
    ep = range(1, len(h["recon_loss"]) + 1)
    axes[1].plot(ep, h["recon_loss"], label="Recon (MSE)", color="#5B8DB8")
    axes[1].plot(ep, h["kl_loss"], label="KL", color="#E8705A")
    axes[1].set_title("p0_vae - recon vs KL")
    axes[1].set_xlabel("Epoch"); axes[1].legend()
else:
    mark_missing(axes[1], "p0_vae")

# DDPM: noise prediction MSE (full and small variants share one axis)
if "p0_ddpm" in runs or "p0_ddpm_small" in runs:
    if "p0_ddpm" in runs:
        h = runs["p0_ddpm"]["history"]
        axes[2].plot(range(1, len(h["loss"]) + 1), h["loss"], color="#5B8DB8", label="epsilon-MSE")
    if "p0_ddpm_small" in runs:
        h2 = runs["p0_ddpm_small"]["history"]
        axes[2].plot(range(1, len(h2["loss"]) + 1), h2["loss"], color="#E8705A", linestyle="--", label="small variant")
    axes[2].set_title("p0_ddpm - noise prediction loss")
    axes[2].set_xlabel("Epoch"); axes[2].legend()
else:
    mark_missing(axes[2], "p0_ddpm")

plt.tight_layout(); plt.show()
"""),
        md("## 2. Final sample grids\n\nThe final previews show the practical failure mode of the raw pipeline: the samples have some face-like structure, but identity, alignment, and detail are not under control. These PNGs are displayed exactly as saved, so older Phase 0 matrices keep their original layout instead of being forced into 4x4."),
        code("""\
last_epochs = {"p0_wgan": 200, "p0_vae": 100, "p0_ddpm": 200, "p0_ddpm_small": 100}

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for ax, (name, ep) in zip(axes, last_epochs.items()):
    img_path = SAMPLES / name / f"epoch_{ep:04d}.png"
    if img_path.exists():
        ax.imshow(mpimg.imread(str(img_path)))
        ax.set_title(f"{name}\\n(epoch {ep})", fontsize=10)
    else:
        ax.set_title(f"{name}\\n(missing)", fontsize=10)
    ax.axis("off")
plt.tight_layout(); plt.show()
"""),
        md("## 3. Progression - early vs late\n\nThe progression grids make the baseline failure visible over time. Later samples improve slightly, but the raw input distribution keeps the task too broad. The saved matrices are shown in their original layout."),
        code("""\
checkpoints = {
    "p0_wgan": [50, 100, 200],
    "p0_vae": [25, 50, 100],
    "p0_ddpm": [50, 100, 200],
}

for name, eps in checkpoints.items():
    fig, axes = plt.subplots(1, len(eps), figsize=(12, 4))
    for ax, e in zip(axes, eps):
        p = SAMPLES / name / f"epoch_{e:04d}.png"
        if p.exists():
            ax.imshow(mpimg.imread(str(p))); ax.set_title(f"epoch {e}", fontsize=9)
        else:
            ax.text(0.5, 0.5, f"epoch {e}\\n(missing)", ha="center", va="center", transform=ax.transAxes)
        ax.axis("off")
    fig.suptitle(name, fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),
        md("""\
## 4. What this phase proves

Phase 0 proves that the code path works, but it also proves that raw images are
too noisy a starting point for the rest of the project. The WGAN produces rough
color blobs, the VAE averages faces into blur, and the DDPM is the most textured
but still noisy.

**Decision:** treat data quality as the first bottleneck. Phase 1 therefore
locks the pipeline before the project spends more compute on stronger recipes.

**Report conclusion:** Phase 0 is a sanity check, not a competitive result. It
establishes the baseline failure and motivates the move to aligned face crops.
"""),
    ]
    write_nb("phase0_analysis", cells)
|
|
|
|
|
|
# PHASE 1 - Pipeline ablations with a DCGAN proxy
|
|
def build_phase1():
    """Write ``phase1_analysis.ipynb``, the DCGAN-proxy pipeline ablations.

    The headline "best" run is ranked by FID@50, the number it is labelled
    with and the column the notebook table sorts on. It is not ranked by the
    minimum over all logged epochs, which could come from epoch 25 and so
    mislabel the value. Missing logs or FIDs render as 'n/a' instead of
    crashing the build on ``:.1f`` of None.
    """
    names = ["p1a_dcgan_64", "p1a_dcgan_128", "p1b_dcgan_full",
             "p1b_dcgan_aligned", "p1c_dcgan_hflip",
             "p1c_dcgan_full_aug", "p1d_dcgan_combined"]
    runs = {n: load(n) for n in names}

    def fid50(name):
        # FID at the final proxy epoch, or None if the log or value is absent.
        return get_fid(runs[name], 50)

    def fmt(value):
        return "n/a" if value is None else f"{value:.1f}"

    scored = {n: fid50(n) for n in names if fid50(n) is not None}
    best_name = min(scored, key=scored.get) if scored else "n/a"
    best_val = scored.get(best_name)
    p0_gan = load("p0_wgan")

    cells = [
        md(f"""\
# Phase 1 - Pipeline Selection

Phase 1 answers the data-handling question left open by the baseline. Instead
of changing the model family, it uses a cheap DCGAN proxy and varies one
pipeline choice at a time.

This phase is deliberately controlled. The output quality is still limited, but
the relative differences tell us which input pipeline gives later recipes the
best chance.

## What this phase changes

Four pipeline choices are tested as ablations:

| Ablation | Question | Choices |
|---|---|---|
| 1A | How much resolution can the proxy handle? | 64x64 vs 128x128 |
| 1B | Does alignment matter? | Full raw image vs MTCNN-aligned crop |
| 1C | Does augmentation help the proxy? | H-flip only vs H-flip + rotation + color jitter |
| 1D | Should raw and aligned images be mixed? | Aligned only vs aligned + raw mixed |

**Headline result:** `{best_name}` reaches **FID@50 = {fmt(best_val)}**. The
locked pipeline for the following phases is aligned face crops at 64x64, no
raw/aligned mixing, with augmentation choices following the saved family configs.
"""),
        md(f"""\
### Reference: Phase 0 baseline from the same family

The phase 0 WGAN-GP (`p0_wgan`) trained on raw un-aligned images for {len(p0_gan['history']['g_loss']) if p0_gan else 200} epochs
without any pipeline tuning, and it also collapsed. Phase 1 below uses the same model class
with the data pipeline systematically varied; the architecture limitation is constant.
"""),
        code(SHARED_IMPORTS),
        md("## 1. Load all experiment logs\n\nAll evidence in this notebook comes from the existing Phase 1 logs and sample folders."),
        code("""\
run_names = sorted(p.stem for p in LOGS.glob("p1*.json"))
runs = {name: load_log(name) for name in run_names}
runs = {k: v for k, v in runs.items() if v}

print(f"Loaded {len(runs)} experiments:")
for name in run_names: print(f"  {name}")
"""),
        code("""\
experiment_groups = {
    "1A - Resolution": {"p1a_dcgan_64": "64x64 (raw)",
                        "p1a_dcgan_128": "128x128 (raw)"},
    "1B - Alignment": {"p1b_dcgan_full": "Full image (raw)",
                       "p1b_dcgan_aligned": "MTCNN-aligned"},
    "1C - Augmentation": {"p1c_dcgan_hflip": "H-flip only",
                          "p1c_dcgan_full_aug": "H-flip + rot + colour"},
    "1D - Dataset mixing": {"p1b_dcgan_aligned": "Aligned only",
                            "p1d_dcgan_combined": "Aligned + raw mixed"},
}
"""),
        md("## 2. FID comparison table\n\nThe table ranks the proxy runs. The values are useful within Phase 1, but they should not be compared directly with later FID protocols."),
        code("""\
rows = []
for name in run_names:
    if name not in runs: continue
    r = runs[name]; cfg = r["config"]
    rows.append({
        "Experiment": name,
        "Size": f"{cfg.get('image_size')}x{cfg.get('image_size')}",
        "Augment": cfg.get("augment", False),
        "FID@25": get_fid(r, 25),
        "FID@50": get_fid(r, 50),
        "G loss (ep50)": r["history"]["g_loss"][-1],
        "D loss (ep50)": r["history"]["d_loss"][-1],
    })
df = pd.DataFrame(rows).sort_values("FID@50")
df.style.format({"FID@25": "{:.1f}", "FID@50": "{:.1f}",
                 "G loss (ep50)": "{:.3f}", "D loss (ep50)": "{:.3f}"})
"""),
        code("""\
fig, ax = plt.subplots(figsize=(10, 5))
labels = df["Experiment"].values
fid25 = df["FID@25"].values
fid50 = df["FID@50"].values
x = np.arange(len(labels)); w = 0.35

ax.bar(x - w/2, fid25, w, label="FID @ 25", color="#5B8DB8", alpha=0.85)
ax.bar(x + w/2, fid50, w, label="FID @ 50", color="#E8705A", alpha=0.85)
ax.set_ylabel("FID (lower is better)")
ax.set_title("Phase 1 - FID across all pipeline ablations")
ax.set_xticks(x); ax.set_xticklabels(labels, rotation=30, ha="right")
ax.legend(); plt.tight_layout(); plt.show()
"""),
        md("## 3. Controlled ablation results\n\nEach subplot holds the model approximately fixed and changes one pipeline factor. This is the decision evidence for the rest of the generator suite."),
        code("""\
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()
colors = ["#5B8DB8", "#E8705A"]

for idx, (group_title, experiments) in enumerate(experiment_groups.items()):
    ax = axes[idx]
    for i, (run_name, label) in enumerate(experiments.items()):
        if run_name not in runs: continue
        epochs, fid_vals = fid_series(runs[run_name])
        f50 = get_fid(runs[run_name], 50)
        f50_txt = "n/a" if f50 is None else f"{f50:.1f}"
        ax.plot(epochs, fid_vals, "o-",
                label=f"{label} (FID@50={f50_txt})",
                color=colors[i], linewidth=2, markersize=8)
    ax.set_xlabel("Epoch"); ax.set_ylabel("FID")
    ax.set_title(group_title); ax.legend(fontsize=9)
    ax.set_xlim(20, 55)
fig.suptitle("FID per ablation group", fontsize=14, fontweight="bold", y=1.01)
plt.tight_layout(); plt.show()
"""),
        md("""\
## 4. Data pipeline visualization

What each ablation actually changes, shown on the input data the model sees.
"""),
        code("""\
import random
from PIL import Image
import torch
import torchvision.transforms as T

random.seed(0)
RAW = PROJECT_ROOT / "data" / "wiki"
ALIGNED = PROJECT_ROOT / "cropped" / "generator" / "wiki"

def sample_paths(root, k=4):
    if not root.exists():
        print(f"Missing image directory: {root}")
        return []
    shards = [d for d in root.iterdir() if d.is_dir()]
    files = []
    for s in random.sample(shards, min(8, len(shards))):
        files += list(s.glob("*.jpg"))[:50]
    return random.sample(files, min(k, len(files)))

def matched_pairs(k=4):
    if not RAW.exists() or not ALIGNED.exists():
        print(f"Missing raw/aligned image directory: RAW={RAW.exists()} ALIGNED={ALIGNED.exists()}")
        return []
    shards = sorted(d.name for d in ALIGNED.iterdir() if d.is_dir() and (RAW / d.name).is_dir())
    pairs = []
    for shard in random.sample(shards, min(8, len(shards))):
        for ali in (ALIGNED / shard).glob("*.jpg"):
            raw = RAW / shard / ali.name
            if raw.exists():
                pairs.append((raw, ali))
            if len(pairs) >= 50: break
        if len(pairs) >= 50: break
    return random.sample(pairs, min(k, len(pairs)))

def show(ax, img, title=None):
    ax.imshow(img); ax.axis("off")
    if title: ax.set_title(title, fontsize=9)

def show_unavailable(ax, message):
    ax.text(0.5, 0.5, message, ha="center", va="center", wrap=True, transform=ax.transAxes)
    ax.axis("off")
"""),
        md("### 4A - Resolution\n\nSame raw image at 64x64 and 128x128. This is a paired comparison layout, so it keeps the original 2x4 format instead of being forced into a 4x4 sample grid."),
        code("""\
paths = sample_paths(RAW, k=4)
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
if not paths:
    for ax in axes.ravel():
        show_unavailable(ax, "raw images unavailable")
else:
    for col, p in enumerate(paths):
        img = Image.open(p).convert("RGB")
        show(axes[0][col], T.CenterCrop(min(img.size))(img).resize((64, 64)), "64x64")
        show(axes[1][col], T.CenterCrop(min(img.size))(img).resize((128, 128)), "128x128")
fig.suptitle("1A - Resolution: same image at two scales", fontsize=12, fontweight="bold")
plt.tight_layout(); plt.show()
"""),
        md("### 4B - Alignment\n\nRaw vs MTCNN-aligned 64x64 crops. This paired layout keeps the original 2x4 format so each raw image is directly above its aligned crop."),
        code("""\
pairs = matched_pairs(k=4)
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
if not pairs:
    for ax in axes.ravel():
        show_unavailable(ax, "matched raw/aligned crops unavailable")
else:
    for col, (raw_p, ali_p) in enumerate(pairs):
        raw_img = Image.open(raw_p).convert("RGB")
        show(axes[0][col], T.CenterCrop(min(raw_img.size))(raw_img).resize((128, 128)), "raw")
        show(axes[1][col], Image.open(ali_p).convert("RGB"), "MTCNN-aligned")
fig.suptitle("1B - Alignment: same source image, raw vs MTCNN-aligned", fontsize=12, fontweight="bold")
plt.tight_layout(); plt.show()
"""),
        md("### 4C - Augmentation\n\nOne aligned image, then deterministic examples of the saved augmentation idea. This keeps the original compact strip because the point is to compare transforms on one image, not to make a generated 4x4 sample matrix."),
        code("""\
src = sample_paths(ALIGNED, k=1)
if src:
    # torchvision's random transforms draw from torch's RNG, not Python's
    # `random`, so seed torch too or the "full aug" panels change every run.
    torch.manual_seed(0)
    img = Image.open(src[0]).convert("RGB").resize((128, 128))
    none = T.Compose([])
    hflip = T.Compose([T.RandomHorizontalFlip(p=1.0)])
    full = T.Compose([
        T.RandomHorizontalFlip(p=1.0),
        T.RandomRotation(degrees=5, interpolation=T.InterpolationMode.BILINEAR),
        T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05),
    ])
    fig, axes = plt.subplots(1, 6, figsize=(15, 3))
    show(axes[0], none(img), "original")
    show(axes[1], hflip(img), "hflip")
    for i, ax in enumerate(axes[2:]):
        show(ax, full(img), f"full aug #{i+1}")
    fig.suptitle("1C - Augmentation: original vs hflip vs full augmentation", fontsize=12, fontweight="bold")
    plt.tight_layout(); plt.show()
else:
    fig, ax = plt.subplots(figsize=(6, 2))
    show_unavailable(ax, "aligned crop directory unavailable")
    fig.suptitle("1C - Augmentation", fontsize=12, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),
        md("### 4D - Dataset mixing\n\nMixing raw and aligned images asks one generator to model two different input distributions. This keeps the original paired 2x4 layout so the contrast is easy to read."),
        code("""\
pairs = matched_pairs(k=4)
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
if not pairs:
    for ax in axes.ravel():
        show_unavailable(ax, "matched raw/aligned crops unavailable")
else:
    for col, (raw_p, ali_p) in enumerate(pairs):
        raw_img = Image.open(raw_p).convert("RGB")
        show(axes[0][col], T.CenterCrop(min(raw_img.size))(raw_img).resize((128, 128)), "raw (mixed in)")
        show(axes[1][col], Image.open(ali_p).convert("RGB"), "aligned")
fig.suptitle("1D - Mixing: same source image, raw vs aligned", fontsize=12, fontweight="bold")
plt.tight_layout(); plt.show()
"""),
        md(f"""\
## 5. What this phase proves

Lowest FID of Phase 1: **`{best_name}` at FID@50 = {fmt(best_val)}**. These
numbers rank pipeline choices only; the proxy generator is still not the final
quality target.

| Ablation | Winner | Winner FID@50 | Comparison | Interpretation |
|---|---|---:|---|---|
| 1A - Resolution | 64x64 | {fmt(fid50('p1a_dcgan_64'))} | 128x128: {fmt(fid50('p1a_dcgan_128'))} | 128x128 is too hard for this proxy budget |
| 1B - Alignment | MTCNN-aligned | {fmt(fid50('p1b_dcgan_aligned'))} | full image: {fmt(fid50('p1b_dcgan_full'))} | alignment is the strongest lever |
| 1C - Augmentation | H-flip + rotation + color | {fmt(fid50('p1c_dcgan_full_aug'))} | H-flip: {fmt(fid50('p1c_dcgan_hflip'))} | richer augmentation helps this proxy |
| 1D - Dataset mixing | aligned only | {fmt(fid50('p1b_dcgan_aligned'))} | mixed: {fmt(fid50('p1d_dcgan_combined'))} | raw+aligned mixing increases distribution variance |

**Decision:** carry forward MTCNN-aligned crops, 64x64 images, and aligned-only
data. The saved later configs keep this pipeline and choose augmentation per
model family.

**Report conclusion:** Phase 1 turns the Phase 0 failure into a pipeline
decision. Alignment is the main fix; Phase 2 can now focus on the GAN recipe
instead of fighting raw-image variance.
"""),
    ]
    write_nb("phase1_analysis", cells)
|
|
|
|
|
|
# PHASE 2 - GAN architecture and objective evolution
|
|
def build_phase2():
    """Write ``phase2_analysis.ipynb``, the GAN architecture/objective progression.

    Every number in the markdown is now derived from the saved logs at build
    time. This includes the cross-phase FID callout, which previously hard-coded
    33, 110 and -311. Missing logs render as 'n/a' rather than crashing the build.
    """
    names = ["p2_1_dcgan", "p2_2_wgan", "p2_3_wgan_sn_attn", "p2_4_wgan_sn_attn_128"]
    runs = {n: load(n) for n in names}

    def fmt(value, spec=".1f"):
        return "n/a" if value is None else format(value, spec)

    def best(name):
        # Lowest FID over all logged epochs for one Phase 2 run (None if absent).
        return best_fid(runs[name])[1]

    def delta(name, prev):
        # Best-FID change from *prev* to *name*; None if either is missing.
        cur, ref = best(name), best(prev)
        return None if cur is None or ref is None else cur - ref

    best_name = min(runs, key=lambda n: best(n) or 9e9)
    best_val = best(best_name)

    # Phase 1 reference for the "not comparable across phases" callout, ranked
    # on FID@50 exactly like the Phase 1 notebook's own table.
    p1_names = ["p1a_dcgan_64", "p1a_dcgan_128", "p1b_dcgan_full",
                "p1b_dcgan_aligned", "p1c_dcgan_hflip",
                "p1c_dcgan_full_aug", "p1d_dcgan_combined"]
    p1_fid50 = {n: get_fid(load(n), 50) for n in p1_names}
    p1_fid50 = {n: v for n, v in p1_fid50.items() if v is not None}
    p1_best_name = min(p1_fid50, key=p1_fid50.get) if p1_fid50 else "n/a"
    p1_best_val = p1_fid50.get(p1_best_name)

    cells = [
        md(f"""\
# Phase 2 - GAN Progression

Phase 2 keeps the Phase 1 pipeline fixed and changes the GAN recipe. This makes
the question narrow: once the data is aligned, what model changes are needed to
escape collapse?

The progression moves from the DCGAN proxy to Wasserstein training, then to the
stability package that finally makes the samples recognizable.

## What this phase changes

| Run | Recipe change |
|---|---|
| `p2_1_dcgan` | DCGAN baseline under the Phase 2 protocol |
| `p2_2_wgan` | BCE objective replaced by Wasserstein-GP |
| `p2_3_wgan_sn_attn` | Spectral norm + GroupNorm + self-attention |
| `p2_4_wgan_sn_attn_128` | Same stabilized recipe at 128x128 |

**Headline result:** `{best_name}` reaches **best FID = {fmt(best_val)}**. The
critical step is not the objective change alone; it is the stabilized 64x64
recipe with spectral normalization, GroupNorm, and self-attention.
"""),
        md(f"""\
> ### FID is not comparable across phases
>
> Phase 1's "best" was FID {fmt(p1_best_val, '.0f')} (`{p1_best_name}`). Phase 2's "best" is FID {fmt(best_val, '.0f')}.
> **This is not a regression.** The two numbers were computed under different
> protocols:
>
> - Phase 1 used a quick proxy FID for fast pipeline ablation, with a smaller
> real-image reference set, on the un-augmented validation split.
> - Phase 2 uses the project's standard FID protocol: 5000 aligned 64x64 real
> images from the matched augmentation pipeline (`fid_n_real: 5000`).
>
> Within Phase 2 the deltas are meaningful: 2.2 -> 2.3 is about **{fmt(delta('p2_3_wgan_sn_attn', 'p2_2_wgan'), '+.0f')} FID**,
> which is a real architecture jump. Don't compare phase 1 vs phase 2 numbers absolutely;
> only compare within a phase, or against phase 5 which uses the same protocol.
"""),
        md("""\
### Reference: Phase 0 baseline from the same family

`p0_wgan` was the un-aligned, no-augmentation, basic-architecture WGAN-GP: face blobs
with no recognisable features (no FID logged). Phase 2 below shows what happens once
the pipeline is fixed and the model is allowed to evolve.
"""),
        code(SHARED_IMPORTS),
        md("## 1. Load experiment logs\n\nOnly existing Phase 2 logs are loaded here. No training or re-evaluation is launched."),
        code("""\
run_names = ["p2_1_dcgan", "p2_2_wgan", "p2_3_wgan_sn_attn", "p2_4_wgan_sn_attn_128"]
run_labels = {
    "p2_1_dcgan": "2.1 DCGAN (BCE)",
    "p2_2_wgan": "2.2 WGAN-GP",
    "p2_3_wgan_sn_attn": "2.3 + SN + Attn",
    "p2_4_wgan_sn_attn_128": "2.4 + 128x128",
}
visual_notes = {
    "p2_1_dcgan": "collapsed gray output",
    "p2_2_wgan": "collapsed gray output",
    "p2_3_wgan_sn_attn": "recognizable faces",
    "p2_4_wgan_sn_attn_128": "under-trained 128x128",
}
runs = {name: load_log(name) for name in run_names}
runs = {k: v for k, v in runs.items() if v}
for n in run_names:
    if n in runs: print(f"  {n}: {len(runs[n]['history']['g_loss'])} epochs")
    else: print(f"  {n}: MISSING")
"""),
        md("## 2. FID comparison table\n\nThis table is the quantitative spine of the GAN progression: lower FID means generated samples are closer to the saved real reference distribution."),
        code("""\
rows = []
for name in run_names:
    if name not in runs: continue
    r = runs[name]
    epochs, fid_vals = fid_series(r)
    best = min(fid_vals) if fid_vals else None
    rows.append({
        "Run": run_labels[name],
        "FID@25": get_fid(r, 25),
        "FID@50": get_fid(r, 50),
        "FID@100": get_fid(r, 100),
        "Best FID": best,
        "Train (min)": (r['history'].get('train_time_s') or 0) / 60,
    })
df = pd.DataFrame(rows).sort_values("Best FID")
df.style.format({"FID@25": "{:.1f}", "FID@50": "{:.1f}", "FID@100": "{:.1f}",
                 "Best FID": "{:.1f}", "Train (min)": "{:.1f}"})
"""),
        md("## 3. FID curves - progression"),
        code("""\
fig, ax = plt.subplots(figsize=(10, 5))
cmap = plt.cm.viridis
for i, name in enumerate(run_names):
    if name not in runs: continue
    epochs, fid_vals = fid_series(runs[name])
    ax.plot(epochs, fid_vals, "o-", label=run_labels[name],
            color=cmap(i / len(run_names)), linewidth=2, markersize=7)
ax.set_xlabel("Epoch"); ax.set_ylabel("FID")
ax.set_title("Phase 2 - FID curves")
ax.legend(); plt.tight_layout(); plt.show()
"""),
        md("## 4. Training dynamics\n\nThe loss curves help explain why the visual jump happens. The objective alone is unstable; the normalized, attention-equipped recipe is where training becomes useful."),
        code("""\
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
cmap = plt.cm.viridis
for i, name in enumerate(run_names):
    if name not in runs: continue
    h = runs[name]["history"]
    color = cmap(i / len(run_names))
    epochs = range(1, len(h["g_loss"]) + 1)
    axes[0].plot(epochs, h["g_loss"], color=color, label=run_labels[name], linewidth=1.2)
    if "w_dist" in h:
        axes[1].plot(epochs, h["w_dist"], color=color, label=run_labels[name], linewidth=1.2)
    elif "d_loss" in h:
        axes[1].plot(epochs, h["d_loss"], color=color, label=run_labels[name], linewidth=1.2, linestyle="--")
axes[0].set_title("Generator loss"); axes[0].set_xlabel("Epoch"); axes[0].legend(fontsize=8)
axes[1].set_title("Wasserstein distance / D loss"); axes[1].set_xlabel("Epoch"); axes[1].legend(fontsize=8)
plt.tight_layout(); plt.show()
"""),
        md("""\
## 5. Sample grids - epoch 100

- **2.1 and 2.2 collapsed** - the gray cells are the actual saved generator outputs,
not missing images. The black lines are just grid separators. Vanilla DCGAN/WGAN-GP
at this scale is still too weak, and their FIDs (>400) confirm the failure.
- **2.3 is the breakthrough** - spectral norm + GroupNorm + self-attention escape
mode collapse and produce diverse, recognisable faces.
- **2.4 (128x128) regresses** - same architecture at higher resolution at fixed
compute under-trains.
"""),
        code("""\
fig, axes = plt.subplots(1, 4, figsize=(16, 4.5))
for ax, name in zip(axes, run_names):
    img_path = SAMPLES / name / "epoch_0100.png"
    f100 = get_fid(runs.get(name, {}), 100) if name in runs else None
    title = f"{run_labels[name]}\\nFID@100={f100:.1f}" if f100 else run_labels[name]
    title += f"\\n{visual_notes.get(name, '')}"
    show_image_or_missing(ax, img_path, title)
plt.tight_layout(); plt.show()
"""),
        md("## 6. Progression - epoch 10 -> 50 -> 100"),
        code("""\
check_epochs = [10, 50, 100]
for name in run_names:
    if name not in runs: continue
    fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))
    for ax, e in zip(axes, check_epochs):
        p = SAMPLES / name / f"epoch_{e:04d}.png"
        f = get_fid(runs[name], e)
        title = f"epoch {e}" + (f"\\nFID={f:.1f}" if f else "")
        if name in {"p2_1_dcgan", "p2_2_wgan"}:
            title += "\\ncollapsed output"
        show_image_or_missing(ax, p, title)
    fig.suptitle(run_labels[name], fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),
        md("## 7. Pairwise comparison - what each step bought us"),
        code("""\
transitions = [
    ("2.1 -> 2.2: BCE -> Wasserstein", "p2_1_dcgan", "p2_2_wgan"),
    ("2.2 -> 2.3: + SN + GroupNorm + Attn", "p2_2_wgan", "p2_3_wgan_sn_attn"),
    ("2.3 -> 2.4: 64 -> 128 resolution", "p2_3_wgan_sn_attn", "p2_4_wgan_sn_attn_128"),
]
for title, a, b in transitions:
    if a not in runs or b not in runs: continue
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    for ax, name in zip(axes, [a, b]):
        img_path = SAMPLES / name / "epoch_0100.png"
        f = get_fid(runs[name], 100)
        title_text = f"{run_labels[name]}\\nFID@100={f:.1f}" if f else run_labels[name]
        title_text += f"\\n{visual_notes.get(name, '')}"
        show_image_or_missing(ax, img_path, title_text)
    fig.suptitle(title, fontsize=12, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),
        md(f"""\
## 8. What this phase proves

| Step | Run | Best FID | Delta vs previous |
|---|---|---:|---:|
| 2.1 DCGAN baseline | `p2_1_dcgan` | {fmt(best('p2_1_dcgan'))} | n/a |
| 2.2 + Wasserstein-GP | `p2_2_wgan` | {fmt(best('p2_2_wgan'))} | {fmt(delta('p2_2_wgan', 'p2_1_dcgan'), '+.1f')} |
| 2.3 + SN + GroupNorm + attention | `p2_3_wgan_sn_attn` | {fmt(best('p2_3_wgan_sn_attn'))} | {fmt(delta('p2_3_wgan_sn_attn', 'p2_2_wgan'), '+.1f')} |
| 2.4 + 128x128 | `p2_4_wgan_sn_attn_128` | {fmt(best('p2_4_wgan_sn_attn_128'))} | {fmt(delta('p2_4_wgan_sn_attn_128', 'p2_3_wgan_sn_attn'), '+.1f')} |

Changing the loss from BCE to Wasserstein-GP is not enough by itself. The
breakthrough is the combined stability recipe in 2.3: spectral normalization,
GroupNorm, and self-attention. The 128x128 run then regresses under the saved
compute budget.

**Decision:** select the 64x64 WGAN-GP recipe with spectral normalization,
GroupNorm, and self-attention as the GAN representative for the final
comparison.

**Report conclusion:** Phase 2 turns the GAN from a collapsing proxy into a
usable generator recipe, but it also shows that higher resolution is not helpful
without enough training budget.
"""),
    ]
    write_nb("phase2_analysis", cells)
|
|
|
|
|
|
# PHASE 3 - VAE composite-loss evolution
def build_phase3():
    """Write ``phase3_analysis.ipynb``: the VAE composite-loss progression.

    Best-FID values shown in the markdown are read from the saved logs at
    build time (``load`` + ``best_fid``). A run whose log is missing, or whose
    log has no FID history, is rendered as ``n/a`` instead of aborting the
    whole build; the generated notebook cells skip such runs the same way.
    """
    run_order = ["p3_1_vae", "p3_2_vae_perceptual", "p3_3_vae_patchgan"]
    runs = {n: load(n) for n in run_order}

    def best_of(name):
        # best_fid() returns (None, None) for a missing log or empty FID history.
        return best_fid(runs[name])[1]

    def fid_cell(name):
        value = best_of(name)
        return "n/a" if value is None else f"{value:.1f}"

    def delta_cell(name, prev):
        value, previous = best_of(name), best_of(prev)
        if value is None or previous is None:
            return "n/a"
        return f"{value - previous:+.1f}"

    scored = [n for n in run_order if best_of(n) is not None]
    if scored:
        best_name = min(scored, key=best_of)
        headline = (f"`{best_name}` reaches **best FID = {best_of(best_name):.1f}** "
                    "on prior samples.")
    else:
        headline = "no best FID is available (no Phase 3 FID history in the saved logs)."

    cells = [
        md(f"""\
# Phase 3 - VAE Progression

Phase 3 studies the VAE family after the pipeline has been locked. The baseline
VAE is fast and stable, but its MSE + KL objective tends to average away facial
detail.

This phase asks whether extra loss terms can recover sharper, more realistic
samples. The saved runs stack the losses one step at a time.

## What this phase changes

| Run | Recipe change |
|---|---|
| `p3_1_vae` | Baseline MSE + KL |
| `p3_2_vae_perceptual` | Adds VGG16 perceptual loss (`lambda_perceptual=0.1`) |
| `p3_3_vae_patchgan` | Adds PatchGAN adversarial loss (`lambda_adversarial=0.01`) |

**Headline result:** {headline} The important result is the sequence: the added
losses are complementary, not interchangeable.
"""),
        md("""\
### Reference: Phase 0 baseline from the same family

`p0_vae` was MSE+KL on raw un-aligned data. Prior samples were heavily blurred
mean-faces, the textbook VAE failure mode. Phase 3 keeps the encoder/decoder
architecture and pipeline fixed and shows that adding perceptual + adversarial
terms is what actually moves the needle.
"""),
        code(SHARED_IMPORTS),
        md("## 1. Load experiment logs\n\nThe notebook loads the existing VAE logs and saved previews only."),
        code("""\
run_names = ["p3_1_vae", "p3_2_vae_perceptual", "p3_3_vae_patchgan"]
run_labels = {
    "p3_1_vae": "3.1 MSE + KL",
    "p3_2_vae_perceptual": "3.2 + Perceptual",
    "p3_3_vae_patchgan": "3.3 + PatchGAN",
}
runs = {name: load_log(name) for name in run_names}
runs = {k: v for k, v in runs.items() if v}
for n in run_names: print(f"  {n}: {'OK' if n in runs else 'MISSING'}")
"""),
        md("## 2. FID comparison table (prior samples)\n\nFID is computed on samples decoded from the prior, so it evaluates generation quality rather than only reconstruction quality."),
        code("""\
rows = []
for name in run_names:
    if name not in runs: continue
    r = runs[name]; h = r["history"]
    _, fid_vals = fid_series(r)
    rows.append({
        "Run": run_labels[name],
        "FID@50": get_fid(r, 50),
        "FID@100": get_fid(r, 100),
        "Best FID": min(fid_vals) if fid_vals else None,
        "Recon@100": h["recon_loss"][-1] if h.get("recon_loss") else None,
        "KL@100": h["kl_loss"][-1] if h.get("kl_loss") else None,
        "Train (min)": (h.get("train_time_s") or 0) / 60,
    })
df = pd.DataFrame(rows).sort_values("Best FID")
df.style.format({"FID@50": "{:.1f}", "FID@100": "{:.1f}", "Best FID": "{:.1f}",
                 "Recon@100": "{:.4f}", "KL@100": "{:.2f}", "Train (min)": "{:.1f}"},
                na_rep="n/a")
"""),
        md("## 3. FID curves - progression"),
        code("""\
fig, ax = plt.subplots(figsize=(10, 5))
cmap = plt.cm.plasma
for i, name in enumerate(run_names):
    if name not in runs: continue
    epochs, fid_vals = fid_series(runs[name])
    ax.plot(epochs, fid_vals, "o-", label=run_labels[name],
            color=cmap(i / len(run_names)), linewidth=2, markersize=7)
ax.set_xlabel("Epoch"); ax.set_ylabel("FID (prior samples)")
ax.set_title("Phase 3 - FID curves"); ax.legend()
plt.tight_layout(); plt.show()
"""),
        md("## 4. Training loss components\n\nThe component losses explain the tradeoff: reconstruction and KL preserve the VAE structure, perceptual loss encourages visual detail, and the PatchGAN term pushes samples toward the face manifold."),
        code("""\
fig, axes = plt.subplots(1, 3, figsize=(16, 4))
cmap = plt.cm.plasma
for i, name in enumerate(run_names):
    if name not in runs: continue
    h = runs[name]["history"]; color = cmap(i / len(run_names))
    epochs = range(1, len(h["recon_loss"]) + 1)
    axes[0].plot(epochs, h["recon_loss"], color=color, label=run_labels[name])
    axes[1].plot(epochs, h["kl_loss"], color=color, label=run_labels[name])
    if "perc_loss" in h and any(h["perc_loss"]):
        axes[2].plot(epochs, h["perc_loss"], color=color, label=run_labels[name])
axes[0].set_title("Reconstruction (MSE)"); axes[0].set_xlabel("Epoch"); axes[0].legend(fontsize=8)
axes[1].set_title("KL divergence"); axes[1].set_xlabel("Epoch"); axes[1].legend(fontsize=8)
axes[2].set_title("Perceptual (VGG16)"); axes[2].set_xlabel("Epoch"); axes[2].legend(fontsize=8)
plt.tight_layout(); plt.show()
"""),
        md("## 5. Prior samples - epoch 100"),
        code("""\
fig, axes = plt.subplots(1, 3, figsize=(13, 4.5))
for ax, name in zip(axes, run_names):
    img_path = SAMPLES / name / "epoch_0100.png"
    if img_path.exists():
        ax.imshow(mpimg.imread(str(img_path)))
    f = get_fid(runs.get(name, {}), 100) if name in runs else None
    ax.set_title(f"{run_labels[name]}\\nFID@100={f:.1f}" if f else run_labels[name], fontsize=10)
    ax.axis("off")
fig.suptitle("Prior samples (decoded from N(0, I))", fontsize=12, fontweight="bold")
plt.tight_layout(); plt.show()
"""),
        md("## 6. Reconstructions - epoch 100"),
        code("""\
fig, axes = plt.subplots(1, 3, figsize=(13, 4.5))
for ax, name in zip(axes, run_names):
    img_path = SAMPLES / name / "epoch_0100_recon.png"
    if img_path.exists():
        ax.imshow(mpimg.imread(str(img_path)))
    ax.set_title(run_labels[name], fontsize=10)
    ax.axis("off")
fig.suptitle("Reconstructions (real / decoded interleaved)", fontsize=12, fontweight="bold")
plt.tight_layout(); plt.show()
"""),
        md("## 7. Progression - epoch 10 -> 50 -> 100 (prior samples)"),
        code("""\
check_epochs = [10, 50, 100]
for name in run_names:
    if name not in runs: continue
    fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))
    for ax, e in zip(axes, check_epochs):
        p = SAMPLES / name / f"epoch_{e:04d}.png"
        if p.exists():
            ax.imshow(mpimg.imread(str(p)))
        f = get_fid(runs[name], e)
        ax.set_title(f"epoch {e}" + (f"\\nFID={f:.1f}" if f else ""), fontsize=9)
        ax.axis("off")
    fig.suptitle(run_labels[name], fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),
        md(f"""\
## 8. What this phase proves

| Step | Run | Best FID | Delta vs previous |
|---|---|---:|---:|
| 3.1 MSE + KL | `p3_1_vae` | {fid_cell('p3_1_vae')} | n/a |
| 3.2 + perceptual | `p3_2_vae_perceptual` | {fid_cell('p3_2_vae_perceptual')} | {delta_cell('p3_2_vae_perceptual', 'p3_1_vae')} |
| 3.3 + PatchGAN | `p3_3_vae_patchgan` | {fid_cell('p3_3_vae_patchgan')} | {delta_cell('p3_3_vae_patchgan', 'p3_2_vae_perceptual')} |

The losses add different kinds of pressure. MSE + KL keeps the latent model
stable but blurry. Perceptual loss restores more visual texture. PatchGAN adds
an adversarial signal that makes the prior samples look more face-like.

**Decision:** select the composite VAE recipe for the final comparison:
MSE + 0.25 KL + 0.1 VGG perceptual + 0.01 PatchGAN adversarial.

**Report conclusion:** Phase 3 shows that the VAE can be improved through
complementary losses, but even the selected recipe remains a speed-oriented
family rather than the strongest quality candidate.
"""),
    ]
    write_nb("phase3_analysis", cells)
|
|
|
|
|
|
# PHASE 4 - DDPM schedule, target, and width evolution
def build_phase4():
    """Write ``phase4_analysis.ipynb``: the DDPM schedule/target/width progression.

    Best-FID values in the markdown come from the saved logs at build time.
    Missing logs (or logs without FID history) render as ``n/a`` instead of
    aborting the build. The schedule-visualization cell evaluates the cosine
    schedule on T + 1 grid points so that it yields exactly T betas, aligned
    with the T linear betas it is plotted against.
    """
    run_order = ["p4_1_ddpm_linear", "p4_2_ddpm_cosine", "p4_3_ddpm_vpred", "p4_4_ddpm_wider"]
    runs = {n: load(n) for n in run_order}

    def best_of(name):
        # best_fid() returns (None, None) for a missing log or empty FID history.
        return best_fid(runs[name])[1]

    def fid_cell(name):
        value = best_of(name)
        return "n/a" if value is None else f"{value:.1f}"

    def delta_cell(name, prev):
        value, previous = best_of(name), best_of(prev)
        if value is None or previous is None:
            return "n/a"
        return f"{value - previous:+.1f}"

    scored = [n for n in run_order if best_of(n) is not None]
    if scored:
        best_name = min(scored, key=best_of)
        headline = f"`{best_name}` reaches **best FID = {best_of(best_name):.1f}**."
    else:
        headline = "no best FID is available (no Phase 4 FID history in the saved logs)."

    cells = [
        md(f"""\
# Phase 4 - DDPM Progression

Phase 4 applies the same report logic to diffusion models. The pipeline is
already fixed, so this notebook isolates the DDPM recipe: schedule, prediction
target, and backbone width.

The story is stepwise. A cosine schedule helps, v-prediction is the major gain,
and the wider backbone becomes useful only after the target and schedule are
improved.

## What this phase changes

| Run | Recipe change |
|---|---|
| `p4_1_ddpm_linear` | Linear noise schedule, epsilon-prediction |
| `p4_2_ddpm_cosine` | Cosine noise schedule |
| `p4_3_ddpm_vpred` | v-prediction target |
| `p4_4_ddpm_wider` | Wider U-Net: base channels 192 with attention at 32/16/8 |

Sampling previews use DDIM-50. Logged FID uses DDIM-100 against the saved real
reference set.

**Headline result:** {headline}

## How to read DDPM sample grids

The DDPM grids should not be read as the same faces improving from epoch to
epoch. GAN and VAE previews can reuse a fixed latent grid, so each position can
look like the same latent code becoming sharper over training. A DDPM preview
starts from noise and runs a stochastic reverse-diffusion sampler. Unless the
exact initial noise and sampler randomness are fixed and stored, each epoch
preview is a fresh draw from the model.

So for DDPM, the progression panels show distribution-level improvement:
cleaner faces, fewer artifacts, and better global structure. They are not
identity-by-identity refinements of the same preview images.
"""),
        md("""\
### Reference: Phase 0 baseline from the same family

`p0_ddpm` was a vanilla DDPM (linear schedule, epsilon-prediction, base_ch=128) on raw
un-aligned data. Outputs were noisy face-shaped textures. Phase 4 fixes the
pipeline (aligned 64) and walks through the standard set of post-2020 DDPM
improvements one at a time.
"""),
        code(SHARED_IMPORTS),
        md("## 1. Load experiment logs\n\nThe notebook reads existing DDPM logs only. Sampling and FID values are already saved."),
        code("""\
run_names = ["p4_1_ddpm_linear", "p4_2_ddpm_cosine", "p4_3_ddpm_vpred", "p4_4_ddpm_wider"]
run_labels = {
    "p4_1_ddpm_linear": "4.1 linear / epsilon",
    "p4_2_ddpm_cosine": "4.2 cosine / epsilon",
    "p4_3_ddpm_vpred": "4.3 cosine / v",
    "p4_4_ddpm_wider": "4.4 wider net",
}
runs = {n: load_log(n) for n in run_names}
runs = {k: v for k, v in runs.items() if v}
for n in run_names: print(f"  {n}: {'OK' if n in runs else 'MISSING'}")
"""),
        md("## 2. FID comparison table\n\nThe table shows whether each recipe change improves generation quality under the saved DDIM-100 FID protocol."),
        code("""\
rows = []
for name in run_names:
    if name not in runs: continue
    r = runs[name]; h = r["history"]
    _, fid_vals = fid_series(r)
    rows.append({
        "Run": run_labels[name],
        "FID@25": get_fid(r, 25),
        "FID@50": get_fid(r, 50),
        "FID@100": get_fid(r, 100),
        "Best FID": min(fid_vals) if fid_vals else None,
        "Loss@100": h["loss"][-1] if h.get("loss") else None,
        "Train (min)": (h.get("train_time_s") or 0) / 60,
    })
df = pd.DataFrame(rows).sort_values("Best FID")
df.style.format({"FID@25": "{:.1f}", "FID@50": "{:.1f}", "FID@100": "{:.1f}",
                 "Best FID": "{:.1f}", "Loss@100": "{:.4f}", "Train (min)": "{:.1f}"},
                na_rep="n/a")
"""),
        md("## 3. FID curves - progression"),
        code("""\
fig, ax = plt.subplots(figsize=(10, 5))
cmap = plt.cm.cividis
for i, name in enumerate(run_names):
    if name not in runs: continue
    epochs, fid_vals = fid_series(runs[name])
    ax.plot(epochs, fid_vals, "o-", label=run_labels[name],
            color=cmap(i / len(run_names)), linewidth=2, markersize=7)
ax.set_xlabel("Epoch"); ax.set_ylabel("FID (DDIM-100)")
ax.set_title("Phase 4 - FID curves"); ax.legend()
plt.tight_layout(); plt.show()
"""),
        md("## 4. Training loss\n\nThe loss plot is diagnostic, but epsilon-MSE and v-MSE are different targets. FID and sample grids carry the decision."),
        code("""\
fig, ax = plt.subplots(figsize=(10, 4))
cmap = plt.cm.cividis
for i, name in enumerate(run_names):
    if name not in runs: continue
    h = runs[name]["history"]
    epochs = range(1, len(h["loss"]) + 1)
    ax.plot(epochs, h["loss"], color=cmap(i / len(run_names)), label=run_labels[name], linewidth=1.3)
ax.set_xlabel("Epoch"); ax.set_ylabel("MSE on prediction target")
ax.set_title("Loss (epsilon-MSE and v-MSE are not directly comparable)")
ax.legend(); plt.tight_layout(); plt.show()
"""),
        md("## 5. Sample grids - epoch 100"),
        code("""\
fig, axes = plt.subplots(1, 4, figsize=(16, 4.5))
for ax, name in zip(axes, run_names):
    img_path = SAMPLES / name / "epoch_0100.png"
    if img_path.exists():
        ax.imshow(mpimg.imread(str(img_path)))
    f = get_fid(runs.get(name, {}), 100) if name in runs else None
    ax.set_title(f"{run_labels[name]}\\nFID@100={f:.1f}" if f else run_labels[name], fontsize=9)
    ax.axis("off")
plt.tight_layout(); plt.show()
"""),
        md("## 6. Progression - epoch 10 -> 50 -> 100\n\nRead these as fresh samples from each checkpoint, not the same DDPM images being refined over time."),
        code("""\
check_epochs = [10, 50, 100]
for name in run_names:
    if name not in runs: continue
    fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))
    for ax, e in zip(axes, check_epochs):
        p = SAMPLES / name / f"epoch_{e:04d}.png"
        if p.exists():
            ax.imshow(mpimg.imread(str(p)))
        f = get_fid(runs[name], e)
        ax.set_title(f"epoch {e}" + (f"\\nFID={f:.1f}" if f else ""), fontsize=9)
        ax.axis("off")
    fig.suptitle(run_labels[name], fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),
        md("## 7. Noise schedule visualization\n\nThe cosine schedule preserves useful signal more smoothly across timesteps. That gives the model a better learning problem before v-prediction and width are added."),
        code("""\
import math
T = 1000; t = np.arange(T)
betas_lin = np.linspace(1e-4, 0.02, T)
ab_lin = np.cumprod(1 - betas_lin)

# Cosine schedule: evaluate f on T + 1 grid points so the ratio f[k+1] / f[k]
# yields exactly T betas, one per timestep, aligned with the linear schedule.
s = 0.008
f = np.cos((np.arange(T + 1) / T + s) / (1 + s) * math.pi / 2) ** 2
f = f / f[0]
betas_cos = np.clip(1 - f[1:] / f[:-1], 0, 0.999)
ab_cos = np.cumprod(1 - betas_cos)

fig, axes = plt.subplots(1, 2, figsize=(13, 4))
axes[0].plot(t, ab_lin, label="linear", color="#5B8DB8", linewidth=2)
axes[0].plot(t, ab_cos, label="cosine", color="#E8705A", linewidth=2)
axes[0].set_xlabel("Timestep t"); axes[0].set_ylabel("alpha_bar_t (signal fraction)")
axes[0].set_title("Cumulative signal preservation"); axes[0].legend()
axes[1].plot(t, betas_lin, label="linear beta", color="#5B8DB8", linewidth=2)
axes[1].plot(t, betas_cos, label="cosine beta", color="#E8705A", linewidth=2)
axes[1].set_xlabel("Timestep t"); axes[1].set_ylabel("beta_t"); axes[1].set_title("beta schedule"); axes[1].legend()
plt.tight_layout(); plt.show()
"""),
        md(f"""\
## 8. What this phase proves

| Step | Run | Best FID | Delta vs previous |
|---|---|---:|---:|
| 4.1 linear / epsilon | `p4_1_ddpm_linear` | {fid_cell('p4_1_ddpm_linear')} | n/a |
| 4.2 cosine / epsilon | `p4_2_ddpm_cosine` | {fid_cell('p4_2_ddpm_cosine')} | {delta_cell('p4_2_ddpm_cosine', 'p4_1_ddpm_linear')} |
| 4.3 cosine / v | `p4_3_ddpm_vpred` | {fid_cell('p4_3_ddpm_vpred')} | {delta_cell('p4_3_ddpm_vpred', 'p4_2_ddpm_cosine')} |
| 4.4 wider net | `p4_4_ddpm_wider` | {fid_cell('p4_4_ddpm_wider')} | {delta_cell('p4_4_ddpm_wider', 'p4_3_ddpm_vpred')} |

The largest improvement is v-prediction. The wider network then helps because
the schedule and prediction target have already made the learning problem
better aligned with sample quality.

**Decision:** select the DDPM recipe with cosine schedule, v-prediction,
base_ch=192, and attention at 32/16/8 for the final comparison.

**Report conclusion:** Phase 4 turns DDPM from the textured but noisy baseline
into the strongest quality candidate for Phase 5.
"""),
    ]
    write_nb("phase4_analysis", cells)
|
|
|
|
|
|
# PHASE 5 - Cross-family final comparison
def build_phase5():
    """Write ``phase5_analysis.ipynb``: the cross-family final comparison.

    The headline FIDs and the final decision table are rendered from the saved
    Phase 5 logs at build time (``load``, ``best_fid``, ``time_min``). Runs whose
    log, FID history, or training time is missing are shown as ``n/a`` (and
    sorted last) instead of aborting the build.
    """
    architectures = {
        "p5_gan": "WGAN-GP + SN + Attn",
        "p5_vae": "VAE + Perceptual + PatchGAN",
        "p5_ddpm": "DDPM cosine v-pred wider",
    }

    def fmt(value, suffix=""):
        # best_fid()/time_min() return None when the underlying log field is absent.
        return "n/a" if value is None else f"{value:.1f}{suffix}"

    summary = []
    for n in architectures:
        log = load(n)
        e, v = best_fid(log)
        summary.append((n, v, e, time_min(log)))
    # Best FID first; runs without a FID value go last.
    summary.sort(key=lambda r: float("inf") if r[1] is None else r[1])
    headline = ", ".join(f"{n}={fmt(v)}" for n, v, _, _ in summary)
    decision_rows = "\n".join(
        f"| {n.split('_')[1].upper()} | {architectures[n]} | {fmt(v)} | {fmt(t, ' min')} |"
        for n, v, _, t in summary
    )

    cells = [
        md(f"""\
# Phase 5 - Final Comparison

Phase 5 is the project-level comparison. It loads the already trained best
recipes from the GAN, VAE, and DDPM branches and compares their saved logs,
sample grids, and checkpoint-based interpolation diagnostics.

The earlier notebooks explain how each recipe was selected. This notebook asks
the practical question: which family gives the best quality, which is fastest,
and which one should the project recommend overall?

## What this phase changes

Nothing new is trained here. The comparison uses the saved Phase 5 artifacts:

| Family | Selected recipe |
|---|---|
| GAN | WGAN-GP with spectral norm, GroupNorm, and self-attention |
| VAE | MSE + KL + perceptual + PatchGAN losses |
| DDPM | Cosine schedule + v-prediction + wider U-Net |

**Headline FIDs from saved logs:** {headline}.
"""),
        code(SHARED_IMPORTS + """

FAMILIES = {
    "GAN": {"p5": "p5_gan", "color": "#5B8DB8", "label": "WGAN-GP + SN + Attn"},
    "VAE": {"p5": "p5_vae", "color": "#E8B85A", "label": "VAE + Perceptual + PatchGAN"},
    "DDPM": {"p5": "p5_ddpm", "color": "#E8705A", "label": "DDPM cosine v-pred wider"},
}

logs_p5 = {fam: load_log(info["p5"]) for fam, info in FAMILIES.items()}
"""),
        md("## 1. Quantitative summary\n\nThe table compares the saved best-of-family runs under the same Phase 5 setup. Lower FID is better; training time gives the speed side of the tradeoff."),
        code("""\
rows = []
for fam, info in FAMILIES.items():
    log = logs_p5[fam]
    if log is None: continue
    h = log["history"]; cfg = log["config"]
    _, fid_vals = fid_series(log)
    rows.append({
        "Family": fam,
        "Architecture": info["label"],
        "Resolution": f"{cfg.get('image_size')}x{cfg.get('image_size')}",
        "Epochs": len(h.get("loss") or h.get("g_loss") or h.get("recon_loss") or []),
        "Best FID": min(fid_vals) if fid_vals else None,
        "Last FID": fid_vals[-1] if fid_vals else None,
        "Train (min)": (h.get("train_time_s") or 0) / 60,
    })
df = pd.DataFrame(rows).sort_values("Best FID")
df.style.format({"Best FID": "{:.1f}", "Last FID": "{:.1f}", "Train (min)": "{:.1f}"}, na_rep="n/a")
"""),
        md("## 2. Quality vs training time\n\nThis post-hoc plot uses only the existing log summaries. It makes the practical decision visible: DDPM is best by FID, GAN is the stronger speed-quality compromise, and VAE is fastest but behind in sample quality."),
        code("""\
fig, ax = plt.subplots(figsize=(7, 4.8))
plot_df = df.copy()
for _, row in plot_df.iterrows():
    fam = row["Family"]
    info = FAMILIES[fam]
    ax.scatter(row["Train (min)"], row["Best FID"], s=120, color=info["color"], label=fam)
    ax.text(row["Train (min)"] + 1.0, row["Best FID"], fam, va="center", fontsize=10)
ax.set_xlabel("Training time (min)")
ax.set_ylabel("Best FID (lower is better)")
ax.set_title("Final comparison: quality vs training time")
ax.grid(alpha=0.25)
plt.tight_layout(); plt.show()
"""),
        md("## 3. FID curves - all three families"),
        code("""\
fig, ax = plt.subplots(figsize=(10, 5))
for fam, info in FAMILIES.items():
    log = logs_p5[fam]
    if log is None: continue
    epochs, vals = fid_series(log)
    ax.plot(epochs, vals, "o-", color=info["color"], label=f"{fam} ({info['label']})",
            linewidth=2, markersize=7)
ax.set_xlabel("Epoch"); ax.set_ylabel("FID")
ax.set_title("Phase 5 - FID across families")
ax.legend(); plt.tight_layout(); plt.show()
"""),
        md("## 4. Final sample grids\n\nEach panel shows the last saved preview for the family (the one closest to `final_ema`); the title reports the run's best logged FID, which may come from an earlier epoch. The grids support the numeric ranking with visible sample quality. DDPM and GAN produce sharper, more plausible faces than the VAE prior samples."),
        code("""\
fig, axes = plt.subplots(1, 3, figsize=(16, 6))
for ax, (fam, info) in zip(axes, FAMILIES.items()):
    log = logs_p5[fam]
    samples_dir = SAMPLES / info["p5"]
    pngs = sorted(samples_dir.glob("epoch_*.png"))
    pngs = [p for p in pngs if "_recon" not in p.stem]
    if not pngs:
        ax.set_title(f"{fam} (no samples)"); ax.axis("off"); continue
    img_path = pngs[-1]  # last preview = closest to final_ema
    ax.imshow(mpimg.imread(str(img_path)))
    v = None
    if log:
        _, fid_vals = fid_series(log)
        v = min(fid_vals) if fid_vals else None
    ax.set_title(f"{fam} - {info['label']}\\n{img_path.stem}" + (f"  best FID={v:.1f}" if v else ""), fontsize=10)
    ax.axis("off")
fig.suptitle("Final samples from each family", fontsize=13, fontweight="bold")
plt.tight_layout(); plt.show()
"""),
        md("## 5. Training progression - early -> late\n\nFor GAN and VAE runs, preview grids may reuse a fixed latent layout, so positions can feel like they improve over time. DDPM previews are different: each grid is a fresh stochastic sample from that checkpoint unless the exact starting noise was fixed and stored. Read the DDPM row as distribution quality improving, not the same faces being polished."),
        code("""\
for fam, info in FAMILIES.items():
    log = logs_p5[fam]
    samples_dir = SAMPLES / info["p5"]
    pngs = sorted(p for p in samples_dir.glob("epoch_*.png") if "_recon" not in p.stem)
    if not pngs: continue
    # Pick ~4 evenly spaced previews
    n_pick = min(4, len(pngs))
    picks = [pngs[i * (len(pngs) - 1) // (n_pick - 1)] for i in range(n_pick)] if n_pick > 1 else pngs
    fig, axes = plt.subplots(1, len(picks), figsize=(4 * len(picks), 4))
    if len(picks) == 1: axes = [axes]
    for ax, p in zip(axes, picks):
        ax.imshow(mpimg.imread(str(p)))
        ep = int(p.stem.split("_")[1])
        f = get_fid(log, ep) if log else None
        ax.set_title(f"epoch {ep}" + (f"\\nFID={f:.1f}" if f else ""), fontsize=9)
        ax.axis("off")
    fig.suptitle(f"{fam} - {info['label']}", fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
"""),
        md("## 6. Per-family training loss\n\nThe losses are not directly comparable across families, but they confirm that each saved recipe ran through its expected optimization path."),
        code("""\
fig, axes = plt.subplots(1, 3, figsize=(18, 4))
for ax, (fam, info) in zip(axes, FAMILIES.items()):
    log = logs_p5[fam]
    if not log: ax.set_title(f"{fam} (missing)"); ax.axis("off"); continue
    h = log["history"]; c = info["color"]
    if fam == "GAN":
        ax.plot(h["g_loss"], label="G loss", color=c, linewidth=1.2)
        if "w_dist" in h:
            ax.plot(h["w_dist"], label="W-dist", color=c, linewidth=1.2, linestyle="--")
        ax.set_ylabel("Loss / W-distance")
    elif fam == "VAE":
        ax.plot(h["recon_loss"], label="recon", color=c, linewidth=1.2)
        ax2 = ax.twinx()
        ax2.plot(h["kl_loss"], label="KL", color=c, alpha=0.5, linestyle="--")
        ax2.set_ylabel("KL")
        ax.set_ylabel("Recon")
    else:  # DDPM
        ax.plot(h["loss"], color=c, linewidth=1.2, label="loss")
        ax.set_ylabel("MSE on v-prediction")
    ax.set_xlabel("Epoch"); ax.set_title(f"{fam}"); ax.legend(loc="upper right", fontsize=8)
plt.tight_layout(); plt.show()
"""),
        md("""\
## 7. Latent interpolation - GAN and VAE

Smooth interpolation between two latent codes reveals whether the generator has
learned a continuous manifold rather than a sparse memorisation. The GAN cell
interpolates between two random prior codes; the VAE cell interpolates between
the encoder means of two real images. DDPM has no compact latent code of this
kind (its sampler starts from a full-resolution noise tensor), so this section
covers GAN and VAE only.

**Checkpoint loading note:** the cell below uses the same priority as
`tools/sampling.py`: `final_ema` first, then `best_ema` as fallback. This avoids
using a best-FID EMA snapshot that may have been saved very early for a
slowly-converging run.
"""),
        code("""\
import sys
sys.path.insert(0, str(GENERATOR_ROOT))
import torch
from src.models import get_model
from src.utils import load_config

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def load_ema_model(run_name, config_name):
    cfg = load_config(str(GENERATOR_ROOT / "configs" / "phase5" / config_name))
    model_obj, _ = get_model(cfg)
    model = model_obj[0] if isinstance(model_obj, tuple) else model_obj
    # final_ema first; see the checkpoint loading note above.
    rejected = []
    for fname in [f"{run_name}_final_ema.pt", f"{run_name}_best_ema.pt"]:
        p = GENERATOR_ROOT / "outputs" / "models" / fname
        if not p.exists(): continue
        sd = torch.load(p, map_location=DEVICE, weights_only=True)
        missing, unexpected = model.load_state_dict(sd, strict=False)
        if not missing and not unexpected:
            print(f"  loaded {fname}")
            return model.to(DEVICE).eval(), cfg
        # Record why the file was skipped so a key mismatch is not reported as "not found".
        rejected.append(f"{fname} ({len(missing)} missing / {len(unexpected)} unexpected keys)")
    detail = f" (rejected: {', '.join(rejected)})" if rejected else ""
    raise FileNotFoundError(f"No usable EMA checkpoint for {run_name}{detail}")


def slerp(z1, z2, t):
    z1n = z1 / z1.norm(); z2n = z2 / z2.norm()
    omega = torch.acos((z1n * z2n).sum().clamp(-1, 1))
    if omega.abs() < 1e-6: return (1 - t) * z1 + t * z2
    return (torch.sin((1 - t) * omega) / torch.sin(omega)) * z1 + \\
           (torch.sin(t * omega) / torch.sin(omega)) * z2
"""),
        code("""\
# GAN slerp interpolation
try:
    gan_model, gan_cfg = load_ema_model("p5_gan", "p5_gan.json")
    latent_dim = gan_cfg.get("latent_dim", 128)
    torch.manual_seed(0)  # fixed endpoints so re-running the report reproduces the figure
    z1 = torch.randn(1, latent_dim, 1, 1, device=DEVICE)
    z2 = torch.randn(1, latent_dim, 1, 1, device=DEVICE)
    with torch.no_grad():
        ts = torch.linspace(0, 1, 10)
        zs = torch.cat([slerp(z1, z2, t) for t in ts])
        imgs = gan_model(zs).clamp(-1, 1)
        imgs = (imgs + 1) / 2
    fig, axes = plt.subplots(1, 10, figsize=(20, 2.5))
    for ax, img, t in zip(axes, imgs.cpu(), ts):
        ax.imshow(img.permute(1, 2, 0).numpy())
        ax.set_title(f"t={float(t):.2f}", fontsize=8)
        ax.axis("off")
    fig.suptitle("GAN - slerp latent interpolation z1 -> z2", fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
except Exception as e:
    print(f"GAN interpolation skipped: {e}")
"""),
        code("""\
# VAE encode-interpolate-decode
try:
    from src.data import GeneratorDataset, get_transform
    vae_model, vae_cfg = load_ema_model("p5_vae", "p5_vae.json")
    ds = GeneratorDataset(str(PROJECT_ROOT / vae_cfg["data_dir"]),
                          sources=vae_cfg.get("sources", ["wiki"]),
                          transform=get_transform(vae_cfg["image_size"], augment=False))
    real = torch.stack([ds[i] for i in range(2)]).to(DEVICE)
    with torch.no_grad():
        mu, logvar = vae_model.encode(real)
        z1, z2 = mu[0:1], mu[1:2]
        ts = torch.linspace(0, 1, 10)
        zs = torch.cat([(1 - t) * z1 + t * z2 for t in ts])
        imgs = vae_model.decode(zs).clamp(-1, 1)
        imgs = (imgs + 1) / 2
    fig, axes = plt.subplots(1, 10, figsize=(20, 2.5))
    for ax, img, t in zip(axes, imgs.cpu(), ts):
        ax.imshow(img.permute(1, 2, 0).numpy())
        ax.set_title(f"t={float(t):.2f}", fontsize=8)
        ax.axis("off")
    fig.suptitle("VAE - linear latent interpolation (encoded real images)", fontsize=11, fontweight="bold")
    plt.tight_layout(); plt.show()
except Exception as e:
    print(f"VAE interpolation skipped: {e}")
"""),
        md(f"""\
## 8. Final decision

| Family | Architecture | Best FID | Train time |
|---|---|---:|---:|
{decision_rows}

The saved logs rank DDPM best by FID, GAN close enough to be the practical
speed-quality alternative, and VAE clearly behind for prior-sample quality. The
VAE remains useful when fast iteration or reconstruction behavior matters, but
it is not the strongest final generator.

**Practical sampling notes** (encoded in `tools/sampling.py`):
- Prefer `*_final_ema.pt` before `*_best_ema.pt`.
- DDPM sampling at DDIM-50 matches the training preview; DDIM-100 is for FID only.
- GAN truncation is **off by default** (matches training); enable for sharper but less
  diverse samples.

**Final recommendation:** choose DDPM when maximum visual quality is the priority.
Choose GAN when speed and quality both matter. Use VAE for fast prototyping or
reconstruction-focused analysis.

**Report conclusion:** Across the full pipeline, the project moves from raw
baseline failure to a locked aligned pipeline, then to family-specific recipes.
The final comparison supports DDPM as the best-quality generator and GAN as the
best practical compromise.
"""),
    ]
    write_nb("phase5_analysis", cells)
|
|
|
|
|
|
# PHASE 6 - Final selected sample showcase
|
|
def build_phase6():
|
|
cells = [
|
|
md("""\
|
|
# Phase 6 - Final Selected Samples
|
|
|
|
This final notebook is a small showcase chapter. Phase 5 compared the model
|
|
families quantitatively; this notebook selects the three strongest individual
|
|
images from a large generated pool for each final Phase 5 recipe.
|
|
|
|
The candidate pool is the saved final-comparison output:
|
|
20 grids per architecture x 16 images per grid = 320 candidates for each model
|
|
family, or 960 individual generated images total.
|
|
|
|
## What this phase changes
|
|
|
|
No model is trained or fine-tuned. No FID is recomputed. The notebook only
|
|
splits already generated Phase 5 grids into individual images, scores them with
|
|
a deterministic visual-quality heuristic, and saves the top three examples per
|
|
architecture.
|
|
|
|
The score is useful for reproducible curation, but it is not a scientific
|
|
quality metric. The Phase 5 FID ranking remains the main quantitative result.
|
|
"""),
|
|
code(SHARED_IMPORTS + """\
|
|
from PIL import Image, ImageDraw, ImageFont
|
|
from skimage.color import rgb2gray, rgb2hsv
|
|
from skimage.filters import laplace
|
|
|
|
SHOWCASE_ROOT = OUTPUTS / "samples" / "final_showcase"
|
|
FINAL_COMPARISON_ROOT = OUTPUTS / "samples" / "final_comparison"
|
|
SHOWCASE_ROOT.mkdir(parents=True, exist_ok=True)
|
|
(SHOWCASE_ROOT / "top_tiles").mkdir(parents=True, exist_ok=True)
|
|
|
|
RUNS = {
|
|
"p5_gan": "GAN - WGAN-GP + SN + Attn",
|
|
"p5_vae": "VAE - perceptual + PatchGAN",
|
|
"p5_ddpm": "DDPM - cosine v-pred wider",
|
|
}
|
|
"""),
|
|
md("""\
|
|
## 1. Candidate pool
|
|
|
|
Each `grid_*.png` is a 4x4 generated sample matrix. The cell below confirms
|
|
how many candidates are available per architecture.
|
|
"""),
|
|
code("""\
|
|
rows = []
|
|
for run, label in RUNS.items():
|
|
grids = sorted((FINAL_COMPARISON_ROOT / run).glob("grid_*.png"))
|
|
rows.append({
|
|
"run": run,
|
|
"architecture": label,
|
|
"grids": len(grids),
|
|
"candidate_images": len(grids) * 16,
|
|
"folder": str((FINAL_COMPARISON_ROOT / run).relative_to(GENERATOR_ROOT)),
|
|
})
|
|
candidate_pool_df = pd.DataFrame(rows)
|
|
display(candidate_pool_df)
|
|
"""),
|
|
md("""\
|
|
## 2. Selection method
|
|
|
|
The selector scores every tile using four simple image properties:
|
|
|
|
- exposure: avoids images that are too dark or too bright
|
|
- contrast: avoids flat gray outputs
|
|
- detail: favors sharper structure
|
|
- color: avoids very washed-out samples
|
|
|
|
This is a report curation tool, not a replacement for FID or visual judgment.
|
|
It simply makes the final showcase deterministic and auditable.
|
|
"""),
|
|
code("""\
|
|
def split_grid(path, nrow=4, padding=2):
|
|
img = Image.open(path).convert("RGB")
|
|
arr = np.asarray(img)
|
|
h, w = arr.shape[:2]
|
|
tile_w = (w - (nrow + 1) * padding) // nrow
|
|
nrows = (h - padding) // (tile_w + padding)
|
|
tiles = []
|
|
for r in range(nrows):
|
|
for c in range(nrow):
|
|
x0 = padding + c * (tile_w + padding)
|
|
y0 = padding + r * (tile_w + padding)
|
|
tile = arr[y0:y0 + tile_w, x0:x0 + tile_w]
|
|
if tile.shape[0] == tile_w and tile.shape[1] == tile_w:
|
|
tiles.append((r, c, tile))
|
|
return tiles
|
|
|
|
|
|
def score_tile(tile):
|
|
x = tile.astype("float32") / 255.0
|
|
gray = rgb2gray(x)
|
|
hsv = rgb2hsv(x)
|
|
mean = float(gray.mean())
|
|
std = float(gray.std())
|
|
saturation = float(hsv[..., 1].mean())
|
|
sharp = float(np.var(laplace(gray)))
|
|
exposure = max(0.0, 1.0 - abs(mean - 0.48) / 0.32)
|
|
contrast = min(std / 0.24, 1.0)
|
|
detail = min(np.log1p(sharp * 6000.0) / 4.0, 1.0)
|
|
color = min(saturation / 0.38, 1.0)
|
|
score = 0.30 * exposure + 0.30 * contrast + 0.25 * detail + 0.15 * color
|
|
if std < 0.035:
|
|
score *= 0.15
|
|
if mean < 0.12 or mean > 0.88:
|
|
score *= 0.4
|
|
return {
|
|
"score": float(score),
|
|
"mean": mean,
|
|
"std": std,
|
|
"saturation": saturation,
|
|
"sharpness": sharp,
|
|
"exposure_score": exposure,
|
|
"contrast_score": contrast,
|
|
"detail_score": detail,
|
|
"color_score": color,
|
|
}
|
|
|
|
|
|
records = []
|
|
for run, label in RUNS.items():
|
|
grids = sorted((FINAL_COMPARISON_ROOT / run).glob("grid_*.png"))
|
|
for grid_path in grids:
|
|
grid_index = int(grid_path.stem.split("_")[-1])
|
|
for tile_index, (row, col, tile) in enumerate(split_grid(grid_path), start=1):
|
|
records.append({
|
|
"run": run,
|
|
"architecture": label,
|
|
"grid": grid_path.name,
|
|
"grid_index": grid_index,
|
|
"tile_index": tile_index,
|
|
"row": row,
|
|
"col": col,
|
|
"source_path": str(grid_path.relative_to(GENERATOR_ROOT)),
|
|
**score_tile(tile),
|
|
})
|
|
|
|
candidate_scores_df = pd.DataFrame(records)
|
|
candidate_scores_df.to_csv(SHOWCASE_ROOT / "candidate_scores.csv", index=False)
|
|
display(candidate_scores_df.groupby("run")["score"].agg(["count", "mean", "max"]).reset_index())
|
|
"""),
|
|
md("""\
|
|
## 3. Top three per architecture
|
|
|
|
The cell below saves the selected individual images and a combined showcase
|
|
panel under `generator/outputs/samples/final_showcase`.
|
|
"""),
|
|
code("""\
|
|
selected_parts = []
|
|
for run, group in candidate_scores_df.groupby("run", sort=False):
|
|
selected = group.sort_values("score", ascending=False).head(3).copy()
|
|
selected["rank"] = range(1, len(selected) + 1)
|
|
out_dir = SHOWCASE_ROOT / "top_tiles" / run
|
|
out_dir.mkdir(parents=True, exist_ok=True)
|
|
tile_paths = []
|
|
for _, row in selected.iterrows():
|
|
grid_path = GENERATOR_ROOT / row["source_path"]
|
|
tile = split_grid(grid_path)[int(row["tile_index"]) - 1][2]
|
|
tile_path = out_dir / f"rank_{int(row['rank']):02d}_{grid_path.stem}_tile_{int(row['tile_index']):02d}.png"
|
|
Image.fromarray(tile).save(tile_path)
|
|
tile_paths.append(str(tile_path.relative_to(GENERATOR_ROOT)))
|
|
selected["tile_path"] = tile_paths
|
|
selected_parts.append(selected)
|
|
|
|
selected_top3_df = pd.concat(selected_parts, ignore_index=True).sort_values(["run", "rank"])
|
|
selected_top3_df.to_csv(SHOWCASE_ROOT / "selected_top3.csv", index=False)
|
|
selected_top3_df.to_json(SHOWCASE_ROOT / "selected_top3.json", orient="records", indent=2)
|
|
|
|
tile_size = 128
|
|
label_h = 46
|
|
cols = 3
|
|
rows = len(RUNS)
|
|
panel = Image.new("RGB", (cols * tile_size, rows * (tile_size + label_h)), "white")
|
|
draw = ImageDraw.Draw(panel)
|
|
try:
|
|
font = ImageFont.truetype("arial.ttf", 13)
|
|
font_small = ImageFont.truetype("arial.ttf", 11)
|
|
except Exception:
|
|
font = ImageFont.load_default()
|
|
font_small = ImageFont.load_default()
|
|
|
|
for r, (run, label) in enumerate(RUNS.items()):
|
|
group = selected_top3_df[selected_top3_df["run"] == run].sort_values("rank")
|
|
y_base = r * (tile_size + label_h)
|
|
draw.text((4, y_base + 2), label, fill="black", font=font)
|
|
for c, (_, row) in enumerate(group.iterrows()):
|
|
tile = Image.open(GENERATOR_ROOT / row["tile_path"]).convert("RGB").resize((tile_size, tile_size), Image.Resampling.BICUBIC)
|
|
x = c * tile_size
|
|
y = y_base + label_h
|
|
panel.paste(tile, (x, y))
|
|
draw.text((x + 4, y_base + 22), f"rank {int(row['rank'])} score {row['score']:.3f}", fill="black", font=font_small)
|
|
|
|
panel_path = SHOWCASE_ROOT / "phase5_top3_panel.png"
|
|
panel.save(panel_path)
|
|
display(selected_top3_df[["run", "rank", "score", "grid", "tile_index", "tile_path"]])
|
|
print(f"Saved panel: {panel_path.relative_to(GENERATOR_ROOT)}")
|
|
"""),
|
|
md("""\
|
|
## 4. Final selected images
|
|
|
|
These are the top three selected images for each Phase 5 architecture.
|
|
"""),
|
|
code("""\
|
|
panel_path = SHOWCASE_ROOT / "phase5_top3_panel.png"
|
|
plt.figure(figsize=(8, 10))
|
|
plt.imshow(mpimg.imread(str(panel_path)))
|
|
plt.axis("off")
|
|
plt.title("Phase 5 selected top-3 images per architecture")
|
|
plt.show()
|
|
"""),
|
|
md("""\
|
|
## 5. Report conclusion
|
|
|
|
The showcase supports the Phase 5 conclusion visually: DDPM gives the cleanest
|
|
best-case samples, GAN is close and sharper than the VAE in many examples, and
|
|
the VAE remains smoother and more conservative. These images are curated from a
|
|
large generated pool, so they should be used as final qualitative examples, not
|
|
as a replacement for the full distribution-level metrics.
|
|
"""),
|
|
]
|
|
write_nb("phase6_final_showcase", cells)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: regenerate every phase notebook in order (phase 0
    # through phase 6), each builder writing its own .ipynb via write_nb.
    print("Building notebooks...")
    phase_builders = (
        build_phase0,
        build_phase1,
        build_phase2,
        build_phase3,
        build_phase4,
        build_phase5,
        build_phase6,
    )
    for build_notebook in phase_builders:
        build_notebook()
    print("Done.")