669 lines
28 KiB
Plaintext
669 lines
28 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Phase 5 — Cross-Family Comparison\n",
|
||
"\n",
|
||
"Best-of-each finalist retrained for **200 epochs** under identical data conditions.\n",
|
||
"\n",
|
||
"| Family | Config | Resolution | Key design |\n",
|
||
"|--------|--------|-----------|------------|\n",
|
||
"| GAN | `p5_gan` | 128×128 | WGAN-GP + SpectralNorm + GroupNorm + Self-Attention |\n",
|
||
"| VAE | `p5_vae` | 64×64 | Convolutional VAE + VGG perceptual + PatchGAN |\n",
|
||
"| DDPM | `p5_ddpm` | 64×64 | Wider U-Net (192ch) + cosine schedule + v-prediction |\n",
|
||
"\n",
|
||
"> **Resolution note**: GAN runs at 128×128 (best architecture from Phase 2.4), while VAE and DDPM run at 64×64. FID is measured at each model's native resolution against real images at that resolution, so scores are not directly numerically comparable across families — they are indicators of within-family improvement."
|
||
],
|
||
"id": "d0000001"
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"metadata": {},
|
||
"source": [
|
||
"import json\n",
|
||
"import sys\n",
|
||
"from pathlib import Path\n",
|
||
"\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import matplotlib.image as mpimg\n",
|
||
"import numpy as np\n",
|
||
"import pandas as pd\n",
|
||
"\n",
|
||
"plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n",
|
||
"\n",
|
||
"OUTPUTS = Path(\"../outputs\")\n",
|
||
"LOGS = OUTPUTS / \"logs\"\n",
|
||
"SAMPLES = OUTPUTS / \"samples\"\n",
|
||
"\n",
|
||
"# Phase 5 finalists and their best phase-2/3/4 counterparts (100-epoch comparison)\n",
|
||
"FAMILIES = {\n",
|
||
" \"GAN\": {\"p5\": \"p5_gan\", \"p4ep\": \"p2_4_wgan_sn_attn_128\", \"label\": \"WGAN-GP+SN+Attn\", \"color\": \"#5B8DB8\"},\n",
|
||
" \"VAE\": {\"p5\": \"p5_vae\", \"p4ep\": \"p3_3_vae_patchgan\", \"label\": \"VAE+Perc+PatchGAN\",\"color\": \"#E8705A\"},\n",
|
||
" \"DDPM\": {\"p5\": \"p5_ddpm\", \"p4ep\": \"p4_4_ddpm_wider\", \"label\": \"DDPM wider 192ch\", \"color\": \"#6ABF69\"},\n",
|
||
"}"
|
||
],
|
||
"execution_count": null,
|
||
"outputs": [],
|
||
"id": "d0000002"
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 1. Load logs"
|
||
],
|
||
"id": "d0000003"
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"metadata": {},
|
||
"source": [
|
||
"def load_log(run_name):\n",
|
||
" p = LOGS / f\"{run_name}.json\"\n",
|
||
" if p.exists():\n",
|
||
" with open(p) as f:\n",
|
||
" return json.load(f)\n",
|
||
" return None\n",
|
||
"\n",
|
||
"def get_fid(log, epoch):\n",
|
||
" if log is None:\n",
|
||
" return None\n",
|
||
" fid = log[\"history\"][\"fid\"]\n",
|
||
" return fid.get(str(epoch), fid.get(epoch, None))\n",
|
||
"\n",
|
||
"logs_p5 = {fam: load_log(info[\"p5\"]) for fam, info in FAMILIES.items()}\n",
|
||
"logs_p4 = {fam: load_log(info[\"p4ep\"]) for fam, info in FAMILIES.items()}\n",
|
||
"\n",
|
||
"for fam in FAMILIES:\n",
|
||
" p5_ok = \"✓\" if logs_p5[fam] else \"✗\"\n",
|
||
" p4_ok = \"✓\" if logs_p4[fam] else \"✗\"\n",
|
||
" print(f\" {fam}: 200ep={p5_ok} 100ep={p4_ok}\")"
|
||
],
|
||
"execution_count": null,
|
||
"outputs": [],
|
||
"id": "d0000004"
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 2. Quantitative Summary Table"
|
||
],
|
||
"id": "d0000005"
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"metadata": {},
|
||
"source": [
|
||
"rows = []\n",
|
||
"for fam, info in FAMILIES.items():\n",
|
||
" log = logs_p5[fam]\n",
|
||
" train_time = log.get(\"history\", {}).get(\"train_time_s\") if log else None\n",
|
||
" train_min = f\"{train_time / 60:.1f}\" if train_time else \"?\"\n",
|
||
" rows.append({\n",
|
||
" \"Family\": fam,\n",
|
||
" \"Model\": info[\"label\"],\n",
|
||
" \"Res\": log.get(\"config\", log).get(\"image_size\", \"?\") if log else \"?\",\n",
|
||
" \"Params\": log.get(\"n_params\") if log else None,\n",
|
||
" \"Train (min)\": train_min,\n",
|
||
" \"FID@100\": get_fid(log, 100),\n",
|
||
" \"FID@150\": get_fid(log, 150),\n",
|
||
" \"FID@200\": get_fid(log, 200),\n",
|
||
" # IS and LPIPS filled in by Section 6\n",
|
||
" \"IS ↑\": None,\n",
|
||
" \"LPIPS ↑\": None,\n",
|
||
" })\n",
|
||
"\n",
|
||
"df = pd.DataFrame(rows).set_index(\"Family\")\n",
|
||
"\n",
|
||
"\n",
|
||
"def fmt_params(v):\n",
|
||
" if v is None:\n",
|
||
" return \"?\"\n",
|
||
" if v >= 1_000_000:\n",
|
||
" return f\"{v / 1_000_000:.1f}M\"\n",
|
||
" if v >= 1_000:\n",
|
||
" return f\"{v / 1_000:.0f}K\"\n",
|
||
" return str(v)\n",
|
||
"\n",
|
||
"\n",
|
||
"df_display = df.copy()\n",
|
||
"df_display[\"Params\"] = df_display[\"Params\"].apply(fmt_params)\n",
|
||
"df_display.style.format({c: \"{:.1f}\" for c in [\"FID@100\", \"FID@150\", \"FID@200\"] if c in df_display}, na_rep=\"?\")"
|
||
],
|
||
"execution_count": null,
|
||
"outputs": [],
|
||
"id": "d0000006"
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 3. FID Curves — All Three Families"
|
||
],
|
||
"id": "d0000007"
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"metadata": {},
|
||
"source": [
|
||
"fig, axes = plt.subplots(1, 2, figsize=(15, 5))\n",
|
||
"\n",
|
||
"for fam, info in FAMILIES.items():\n",
|
||
" c = info[\"color\"]\n",
|
||
" # Phase 5 (200ep) — solid\n",
|
||
" log = logs_p5[fam]\n",
|
||
" if log:\n",
|
||
" fid_dict = log[\"history\"][\"fid\"]\n",
|
||
" eps = sorted(int(k) for k in fid_dict)\n",
|
||
" fids = [fid_dict[str(e)] for e in eps]\n",
|
||
" axes[0].plot(eps, fids, \"-o\", color=c, linewidth=2, markersize=6,\n",
|
||
"                 label=f\"{fam} 200ep (FID@200={fid_dict.get('200','?'):.1f})\" if isinstance(fid_dict.get('200'), (int, float)) else f\"{fam} 200ep\")\n",
|
||
"\n",
|
||
" # Phase 4/3/2 best (100ep) — dashed, same colour\n",
|
||
" log4 = logs_p4[fam]\n",
|
||
" if log4:\n",
|
||
" fid4 = log4[\"history\"][\"fid\"]\n",
|
||
" eps4 = sorted(int(k) for k in fid4)\n",
|
||
" fids4 = [fid4[str(e)] for e in eps4]\n",
|
||
" axes[0].plot(eps4, fids4, \"--\", color=c, linewidth=1.2, alpha=0.55,\n",
|
||
" label=f\"{fam} 100ep\")\n",
|
||
"\n",
|
||
"axes[0].set_xlabel(\"Epoch\")\n",
|
||
"axes[0].set_ylabel(\"FID (lower is better)\")\n",
|
||
"axes[0].set_title(\"FID Curves — Phase 5 (solid) vs best 100-ep (dashed)\")\n",
|
||
"axes[0].legend(fontsize=8)\n",
|
||
"\n",
|
||
"# Bar chart: FID at 100 vs 200 epochs per family\n",
|
||
"fams = list(FAMILIES.keys())\n",
|
||
"fid100 = [get_fid(logs_p5[f], 100) for f in fams]\n",
|
||
"fid200 = [get_fid(logs_p5[f], 200) for f in fams]\n",
|
||
"x = np.arange(len(fams))\n",
|
||
"w = 0.35\n",
|
||
"bars1 = axes[1].bar(x - w/2, [v or 0 for v in fid100], w, label=\"FID@100\", color=[FAMILIES[f][\"color\"] for f in fams], alpha=0.6)\n",
|
||
"bars2 = axes[1].bar(x + w/2, [v or 0 for v in fid200], w, label=\"FID@200\", color=[FAMILIES[f][\"color\"] for f in fams], alpha=1.0)\n",
|
||
"axes[1].set_xticks(x); axes[1].set_xticklabels(fams)\n",
|
||
"axes[1].set_ylabel(\"FID\"); axes[1].set_title(\"FID: 100ep vs 200ep per family\")\n",
|
||
"axes[1].legend()\n",
|
||
"for bar in list(bars1) + list(bars2):\n",
|
||
" h = bar.get_height()\n",
|
||
" if h > 0:\n",
|
||
" axes[1].text(bar.get_x() + bar.get_width()/2, h + 1, f\"{h:.0f}\", ha=\"center\", va=\"bottom\", fontsize=8)\n",
|
||
"\n",
|
||
"plt.tight_layout()\n",
|
||
"plt.show()"
|
||
],
|
||
"execution_count": null,
|
||
"outputs": [],
|
||
"id": "d0000008"
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 4. Sample Grids — Epoch 200"
|
||
],
|
||
"id": "d0000009"
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"metadata": {},
|
||
"source": [
|
||
"fig, axes = plt.subplots(1, 3, figsize=(16, 6))\n",
|
||
"\n",
|
||
"for idx, (fam, info) in enumerate(FAMILIES.items()):\n",
|
||
" ax = axes[idx]\n",
|
||
" img_path = SAMPLES / info[\"p5\"] / \"epoch_0200.png\"\n",
|
||
" log = logs_p5[fam]\n",
|
||
" fid200 = get_fid(log, 200)\n",
|
||
" if img_path.exists():\n",
|
||
" ax.imshow(mpimg.imread(str(img_path)))\n",
|
||
" title = f\"{fam}: {info['label']}\\n\"\n",
|
||
" if fid200:\n",
|
||
"        res = log.get(\"config\", log).get(\"image_size\", \"?\")\n",
|
||
" title += f\"FID@200={fid200:.1f} ({res}×{res})\"\n",
|
||
" ax.set_title(title, fontsize=9)\n",
|
||
" else:\n",
|
||
" ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes, fontsize=12)\n",
|
||
" ax.set_title(f\"{fam}: {info['label']}\", fontsize=9)\n",
|
||
" ax.axis(\"off\")\n",
|
||
"\n",
|
||
"fig.suptitle(\"Phase 5 — Epoch 200 Sample Grids (4×4, prior samples)\", fontsize=13, fontweight=\"bold\")\n",
|
||
"plt.tight_layout()\n",
|
||
"plt.show()"
|
||
],
|
||
"execution_count": null,
|
||
"outputs": [],
|
||
"id": "d0000010"
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 5. Training Progression — Epoch 10 → 50 → 100 → 200"
|
||
],
|
||
"id": "d0000011"
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"metadata": {},
|
||
"source": [
|
||
"check_epochs = [10, 50, 100, 200]\n",
|
||
"\n",
|
||
"for fam, info in FAMILIES.items():\n",
|
||
" run = info[\"p5\"]\n",
|
||
" log = logs_p5[fam]\n",
|
||
" if log is None:\n",
|
||
" print(f\"{fam}: not yet run\")\n",
|
||
" continue\n",
|
||
"\n",
|
||
" fig, axes = plt.subplots(1, len(check_epochs), figsize=(16, 4))\n",
|
||
" for ax, ep in zip(axes, check_epochs):\n",
|
||
" img_path = SAMPLES / run / f\"epoch_{ep:04d}.png\"\n",
|
||
" if img_path.exists():\n",
|
||
" ax.imshow(mpimg.imread(str(img_path)))\n",
|
||
" fid = get_fid(log, ep)\n",
|
||
" ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid else \"\"), fontsize=9)\n",
|
||
" else:\n",
|
||
" ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\",\n",
|
||
" transform=ax.transAxes)\n",
|
||
" ax.axis(\"off\")\n",
|
||
" fig.suptitle(f\"{fam} ({info['label']}) — 200-epoch progression\", fontsize=11, fontweight=\"bold\")\n",
|
||
" plt.tight_layout()\n",
|
||
" plt.show()"
|
||
],
|
||
"execution_count": null,
|
||
"outputs": [],
|
||
"id": "d0000012"
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 6. Extended Metrics — IS and LPIPS Diversity\n",
|
||
"\n",
|
||
"Requires loading trained model weights. Run after all phase 5 models have finished.\n",
|
||
"Generates 5 000 samples per model and computes IS and LPIPS over 200 random pairs."
|
||
],
|
||
"id": "d0000013"
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"metadata": {},
|
||
"source": [
|
||
"import sys\n",
|
||
"sys.path.insert(0, \"..\")\n",
|
||
"\n",
|
||
"import torch\n",
|
||
"from src.utils import load_config\n",
|
||
"from src.models import get_model\n",
|
||
"from src.training.metrics import compute_is, compute_lpips_diversity\n",
|
||
"from src.training.diffusion import cosine_betas, make_alpha_bars, ddim_sample\n",
|
||
"from src.training.ema import EMA\n",
|
||
"\n",
|
||
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
||
"N_SAMPLE = 5_000\n",
|
||
"\n",
|
||
"def load_ema_model(run_name, config_path):\n",
|
||
" \"\"\"Load the best EMA weights for a given phase-5 run.\"\"\"\n",
|
||
" cfg = load_config(str(Path(\"../configs/phase5\") / config_path))\n",
|
||
" model_obj, kind = get_model(cfg)\n",
|
||
" ema_path = Path(\"../outputs/models\") / f\"{run_name}_best_ema.pt\"\n",
|
||
" if not ema_path.exists():\n",
|
||
" ema_path = Path(\"../outputs/models\") / f\"{run_name}_final_ema.pt\"\n",
|
||
" if isinstance(model_obj, tuple):\n",
|
||
" model = model_obj[0] # generator for GAN\n",
|
||
" else:\n",
|
||
" model = model_obj\n",
|
||
" model.load_state_dict(torch.load(ema_path, map_location=DEVICE))\n",
|
||
" return model.to(DEVICE).eval(), cfg, kind\n",
|
||
"\n",
|
||
"\n",
|
||
"@torch.no_grad()\n",
|
||
"def generate_samples(run_name, config_path, n=N_SAMPLE):\n",
|
||
" model, cfg, kind = load_ema_model(run_name, config_path)\n",
|
||
" image_size = cfg.get(\"image_size\", 64)\n",
|
||
"\n",
|
||
" if kind == \"wgan\":\n",
|
||
" latent_dim = cfg.get(\"latent_dim\", 128)\n",
|
||
" imgs = torch.cat([\n",
|
||
" model(torch.randn(min(64, n - i), latent_dim, 1, 1, device=DEVICE))\n",
|
||
" for i in range(0, n, 64)\n",
|
||
" ])[:n].cpu()\n",
|
||
"\n",
|
||
" elif kind == \"vae\":\n",
|
||
" imgs = torch.cat([\n",
|
||
" model.sample(min(64, n - i), DEVICE)\n",
|
||
" for i in range(0, n, 64)\n",
|
||
" ])[:n].cpu()\n",
|
||
"\n",
|
||
" elif kind == \"ddpm\":\n",
|
||
" schedule = cfg.get(\"noise_schedule\", \"cosine\")\n",
|
||
" pred_type = cfg.get(\"pred_type\", \"v\")\n",
|
||
" T = cfg.get(\"T\", 1000)\n",
|
||
" from src.training.diffusion import cosine_betas, linear_betas, make_alpha_bars, ddim_sample\n",
|
||
" betas = (cosine_betas(T) if schedule == \"cosine\" else linear_betas(T)).to(DEVICE)\n",
|
||
" ab = make_alpha_bars(betas)\n",
|
||
"        imgs = ddim_sample(model, n, image_size, ab, n_steps=100,\n",
"                           pred_type=pred_type, device=DEVICE, batch_size=32)\n",
"    else:\n",
"        raise ValueError(f\"Unknown model kind: {kind!r}\")\n",
"    return imgs\n",
|
||
"\n",
|
||
"\n",
|
||
"print(\"Run this cell once all phase-5 models have completed.\")\n",
|
||
"print(\"Expected to take ~5–10 minutes per family on an RTX 3090.\")"
|
||
],
|
||
"execution_count": null,
|
||
"outputs": [],
|
||
"id": "d0000014"
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"metadata": {},
|
||
"source": [
|
||
"# ── Compute IS and LPIPS per family ──────────────────────────────────────────\n",
|
||
"# Uncomment when models are ready.\n",
|
||
"\n",
|
||
"extended_metrics = {}\n",
|
||
"\n",
|
||
"for fam, info in FAMILIES.items():\n",
|
||
" run = info[\"p5\"]\n",
|
||
" config = f\"{run}.json\"\n",
|
||
" try:\n",
|
||
" print(f\"\\n{fam}: generating {N_SAMPLE} samples...\")\n",
|
||
" imgs = generate_samples(run, config)\n",
|
||
"\n",
|
||
" print(f\" Computing IS...\")\n",
|
||
" is_mean, is_std = compute_is(imgs, device=DEVICE)\n",
|
||
"\n",
|
||
" print(f\" Computing LPIPS diversity...\")\n",
|
||
" lpips = compute_lpips_diversity(imgs, n_pairs=200, device=DEVICE)\n",
|
||
"\n",
|
||
" extended_metrics[fam] = {\"IS_mean\": is_mean, \"IS_std\": is_std, \"LPIPS\": lpips}\n",
|
||
" print(f\" IS = {is_mean:.2f} ± {is_std:.2f} LPIPS = {lpips:.4f}\")\n",
|
||
"\n",
|
||
" except Exception as e:\n",
|
||
" print(f\" Skipped ({e})\")\n",
|
||
" extended_metrics[fam] = {}\n",
|
||
"\n",
|
||
"# Merge with FID table\n",
|
||
"for fam in FAMILIES:\n",
|
||
" em = extended_metrics.get(fam, {})\n",
|
||
" idx = list(FAMILIES.keys()).index(fam)\n",
|
||
" df.loc[fam, \"IS ↑\"] = f\"{em['IS_mean']:.2f}±{em['IS_std']:.2f}\" if \"IS_mean\" in em else \"—\"\n",
|
||
" df.loc[fam, \"LPIPS ↑\"] = f\"{em['LPIPS']:.4f}\" if \"LPIPS\" in em else \"—\"\n",
|
||
"\n",
|
||
"df.style.format({c: \"{:.1f}\" for c in [\"FID@100\", \"FID@150\", \"FID@200\"] if c in df}, na_rep=\"?\")"
|
||
],
|
||
"execution_count": null,
|
||
"outputs": [],
|
||
"id": "d0000015"
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 7. Latent Interpolation — GAN and VAE\n",
|
||
"\n",
|
||
"Smooth interpolation between two latent codes reveals whether the generator has learned a\n",
|
||
"continuous manifold. DDPM has no encoder, so interpolation is done by different noise seeds."
|
||
],
|
||
"id": "d0000016"
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Spherical linear interpolation (slerp)\n",
|
||
"def slerp(z1, z2, t):\n",
|
||
" z1_n = z1 / z1.norm()\n",
|
||
" z2_n = z2 / z2.norm()\n",
|
||
" omega = torch.acos((z1_n * z2_n).sum().clamp(-1, 1))\n",
|
||
" if omega.abs() < 1e-6:\n",
|
||
" return (1 - t) * z1 + t * z2\n",
|
||
" return (torch.sin((1-t)*omega)/torch.sin(omega)) * z1 + \\\n",
|
||
" (torch.sin(t*omega)/torch.sin(omega)) * z2\n",
|
||
"\n",
|
||
"\n",
|
||
"def gan_interpolation(model, latent_dim, n_steps=10, device=DEVICE):\n",
|
||
" z1 = torch.randn(1, latent_dim, 1, 1, device=device)\n",
|
||
" z2 = torch.randn(1, latent_dim, 1, 1, device=device)\n",
|
||
" alphas = torch.linspace(0, 1, n_steps)\n",
|
||
" imgs = []\n",
|
||
" with torch.no_grad():\n",
|
||
" for a in alphas:\n",
|
||
" z = slerp(z1.flatten(), z2.flatten(), a.item()).view_as(z1)\n",
|
||
" imgs.append(model(z).cpu())\n",
|
||
" return torch.cat(imgs)\n",
|
||
"\n",
|
||
"\n",
|
||
"def vae_interpolation(model, real_imgs, n_steps=10, device=DEVICE):\n",
|
||
" \"\"\"Encode two real images, interpolate in latent space, decode.\"\"\"\n",
|
||
" img1, img2 = real_imgs[:1].to(device), real_imgs[1:2].to(device)\n",
|
||
" with torch.no_grad():\n",
|
||
" mu1, _ = model.encode(img1)\n",
|
||
" mu2, _ = model.encode(img2)\n",
|
||
" alphas = torch.linspace(0, 1, n_steps, device=device)\n",
|
||
" imgs = [model.decode((1-a)*mu1 + a*mu2).cpu() for a in alphas]\n",
|
||
" return torch.cat(imgs)\n",
|
||
"\n",
|
||
"\n",
|
||
"print(\"Interpolation helpers defined. Run cells below after loading models.\")"
|
||
],
|
||
"execution_count": null,
|
||
"outputs": [],
|
||
"id": "d0000017"
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"metadata": {},
|
||
"source": [
|
||
"# ── GAN interpolation ─────────────────────────────────────────────────────────\n",
|
||
"try:\n",
|
||
" gan_model, gan_cfg, _ = load_ema_model(\"p5_gan\", \"p5_gan.json\")\n",
|
||
" latent_dim = gan_cfg.get(\"latent_dim\", 128)\n",
|
||
" interp_imgs = gan_interpolation(gan_model, latent_dim, n_steps=10)\n",
|
||
" interp_imgs = (interp_imgs.clamp(-1, 1) + 1) / 2\n",
|
||
"\n",
|
||
" fig, axes = plt.subplots(1, 10, figsize=(20, 2.5))\n",
|
||
" for ax, img in zip(axes, interp_imgs):\n",
|
||
" ax.imshow(img.permute(1, 2, 0).numpy())\n",
|
||
" ax.axis(\"off\")\n",
|
||
" fig.suptitle(\"GAN — Slerp latent interpolation (z₁ → z₂)\", fontsize=11, fontweight=\"bold\")\n",
|
||
" plt.tight_layout()\n",
|
||
" plt.show()\n",
|
||
"except Exception as e:\n",
|
||
" print(f\"GAN interpolation: {e}\")"
|
||
],
|
||
"execution_count": null,
|
||
"outputs": [],
|
||
"id": "d0000018"
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"metadata": {},
|
||
"source": [
|
||
"# ── VAE interpolation ─────────────────────────────────────────────────────────\n",
|
||
"try:\n",
|
||
" from src.data import GeneratorDataset, get_transform\n",
|
||
"\n",
|
||
" vae_model, vae_cfg, _ = load_ema_model(\"p5_vae\", \"p5_vae.json\")\n",
|
||
" ds = GeneratorDataset(\"../../\" + vae_cfg[\"data_dir\"],\n",
|
||
" sources=vae_cfg.get(\"sources\", [\"wiki\"]),\n",
|
||
" transform=get_transform(vae_cfg[\"image_size\"], augment=False))\n",
|
||
" sample_real = torch.stack([ds[i] for i in range(2)])\n",
|
||
"\n",
|
||
" interp_imgs = vae_interpolation(vae_model, sample_real, n_steps=10)\n",
|
||
" interp_imgs = (interp_imgs.clamp(-1, 1) + 1) / 2\n",
|
||
"\n",
|
||
" fig, axes = plt.subplots(1, 10, figsize=(20, 2.5))\n",
|
||
" for ax, img in zip(axes, interp_imgs):\n",
|
||
" ax.imshow(img.permute(1, 2, 0).numpy())\n",
|
||
" ax.axis(\"off\")\n",
|
||
" fig.suptitle(\"VAE — μ-space linear interpolation (image₁ → image₂)\", fontsize=11, fontweight=\"bold\")\n",
|
||
" plt.tight_layout()\n",
|
||
" plt.show()\n",
|
||
"except Exception as e:\n",
|
||
" print(f\"VAE interpolation: {e}\")"
|
||
],
|
||
"execution_count": null,
|
||
"outputs": [],
|
||
"id": "d0000019"
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 8. Failure Mode Analysis\n",
|
||
"\n",
|
||
"Identify the worst-generated images per model: those farthest from their nearest real neighbour in pixel space, or simply those with highest reconstruction error (VAE) or highest DDPM loss."
|
||
],
|
||
"id": "d0000020"
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"metadata": {},
|
||
"source": [
|
||
"# For GAN and DDPM: generate 256 images, pick the 8 with lowest mean activation\n",
|
||
"# (a proxy for less-coherent images — very rough)\n",
|
||
"\n",
|
||
"def worst_samples(imgs, n=8):\n",
|
||
" \"\"\"Heuristic: pick samples with lowest per-pixel mean (often darker / less structured).\"\"\"\n",
|
||
" scores = imgs.mean(dim=[1, 2, 3]) # mean brightness per image\n",
|
||
" worst_idx = scores.argsort()[:n]\n",
|
||
" return imgs[worst_idx]\n",
|
||
"\n",
|
||
"print(\"Failure mode analysis requires generated samples — run after model loading (Section 6).\")"
|
||
],
|
||
"execution_count": null,
|
||
"outputs": [],
|
||
"id": "d0000021"
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Optionally run with already-generated `imgs` from Section 6:\n",
|
||
"# worst = worst_samples(imgs.cpu())\n",
|
||
"# worst = (worst.clamp(-1,1) + 1) / 2\n",
|
||
"# fig, axes = plt.subplots(1, 8, figsize=(18, 2.5))\n",
|
||
"# for ax, img in zip(axes, worst):\n",
|
||
"# ax.imshow(img.permute(1,2,0)); ax.axis('off')\n",
|
||
"# plt.suptitle('Failure modes (lowest-brightness samples)', fontsize=11)\n",
|
||
"# plt.tight_layout(); plt.show()"
|
||
],
|
||
"execution_count": null,
|
||
"outputs": [],
|
||
"id": "d0000022"
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 9. Training Loss Overview (all families)"
|
||
],
|
||
"id": "d0000023"
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"metadata": {},
|
||
"source": [
|
||
"fig, axes = plt.subplots(1, 3, figsize=(18, 4))\n",
|
||
"\n",
|
||
"for ax, (fam, info) in zip(axes, FAMILIES.items()):\n",
|
||
" log = logs_p5[fam]\n",
|
||
" if log is None:\n",
|
||
" ax.set_title(f\"{fam} (not yet run)\"); ax.axis(\"off\"); continue\n",
|
||
" h = log[\"history\"]\n",
|
||
" c = info[\"color\"]\n",
|
||
"\n",
|
||
" if fam == \"GAN\":\n",
|
||
" ax.plot(h[\"g_loss\"], label=\"G loss\", color=c, linewidth=1.2)\n",
|
||
" ax.plot(h[\"w_dist\"], label=\"W-dist\", color=c, linewidth=1.2, linestyle=\"--\")\n",
|
||
" ax.set_ylabel(\"Loss / W-distance\")\n",
|
||
" elif fam == \"VAE\":\n",
|
||
" ax.plot(h[\"recon_loss\"], label=\"MSE\", color=c)\n",
|
||
" ax2 = ax.twinx()\n",
|
||
" ax2.plot(h[\"kl_loss\"], label=\"KL\", color=\"grey\", linestyle=\"--\")\n",
|
||
" ax2.set_ylabel(\"KL\", color=\"grey\")\n",
|
||
" ax.set_ylabel(\"MSE\")\n",
|
||
" elif fam == \"DDPM\":\n",
|
||
" ax.plot(h[\"loss\"], label=\"MSE (v-pred)\", color=c)\n",
|
||
" ax.set_ylabel(\"MSE loss\")\n",
|
||
"\n",
|
||
" ax.set_xlabel(\"Epoch\")\n",
|
||
" ax.set_title(f\"{fam}: {info['label']}\")\n",
|
||
" ax.legend(fontsize=8)\n",
|
||
"\n",
|
||
"fig.suptitle(\"Phase 5 — Training Dynamics (200 epochs)\", fontsize=12, fontweight=\"bold\")\n",
|
||
"plt.tight_layout()\n",
|
||
"plt.show()"
|
||
],
|
||
"execution_count": null,
|
||
"outputs": [],
|
||
"id": "d0000024"
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 10. Conclusions"
|
||
],
|
||
"id": "d0000025"
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"metadata": {},
|
||
"source": [
|
||
"print(\"=\" * 72)\n",
|
||
"print(\"PHASE 5 — CROSS-FAMILY COMPARISON (200 epochs)\")\n",
|
||
"print(\"=\" * 72)\n",
|
||
"\n",
|
||
"for fam, info in FAMILIES.items():\n",
|
||
" log = logs_p5[fam]\n",
|
||
"    em = extended_metrics.get(fam, {}) if \"extended_metrics\" in globals() else {}\n",
|
||
" print(f\"\\n ── {fam}: {info['label']} ──\")\n",
|
||
" if log:\n",
|
||
" res = log.get(\"config\", log).get(\"image_size\", \"?\")\n",
|
||
" print(f\" Resolution : {res}×{res}\")\n",
|
||
" n_p = log.get(\"n_params\")\n",
|
||
" print(f\" Params : {n_p:,}\" if n_p else \" Params : ?\")\n",
|
||
" tt = log.get(\"history\", {}).get(\"train_time_s\")\n",
|
||
" print(f\" Train time : {tt / 60:.1f} min\" if tt else \" Train time : ?\")\n",
|
||
" for ep in (100, 150, 200):\n",
|
||
" fid = get_fid(log, ep)\n",
|
||
" print(f\" FID@{ep:<3} : {fid:.1f}\" if fid else f\" FID@{ep:<3} : ?\")\n",
|
||
" if em:\n",
|
||
" print(f\" IS : {em.get('IS_mean','?'):.2f} ± {em.get('IS_std','?'):.2f}\" if 'IS_mean' in em else \" IS : ?\")\n",
|
||
" print(f\" LPIPS div : {em.get('LPIPS','?'):.4f}\" if 'LPIPS' in em else \" LPIPS div : ?\")\n",
|
||
" else:\n",
|
||
" print(\" IS / LPIPS : (run Section 6 to compute)\")\n",
|
||
"\n",
|
||
"print(\"\\n\" + \"=\" * 72)\n",
|
||
"print(\"Narrative to fill in after results:\")\n",
|
||
"print(\" - Which family achieves best FID?\")\n",
|
||
"print(\" - GAN: fast convergence but mode collapse risk?\")\n",
|
||
"print(\" - VAE: blurry priors improved by perceptual+adversarial loss?\")\n",
|
||
"print(\" - DDPM: highest quality but slowest inference (100 DDIM steps)?\")\n",
|
||
"print(\"=\" * 72)"
|
||
],
|
||
"execution_count": null,
|
||
"outputs": [],
|
||
"id": "d0000026"
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"name": "python",
|
||
"version": "3.10.0"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
} |