Files
DRL_PROJ/generator/notebooks/phase5_analysis.ipynb
T
2026-04-30 13:10:33 +01:00

669 lines
28 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Phase 5 — Cross-Family Comparison\n",
"\n",
"Best-of-each finalist retrained for **200 epochs** under identical data conditions.\n",
"\n",
"| Family | Config | Resolution | Key design |\n",
"|--------|--------|-----------|------------|\n",
"| GAN | `p5_gan` | 128×128 | WGAN-GP + SpectralNorm + GroupNorm + Self-Attention |\n",
"| VAE | `p5_vae` | 64×64 | Convolutional VAE + VGG perceptual + PatchGAN |\n",
"| DDPM | `p5_ddpm` | 64×64 | Wider U-Net (192ch) + cosine schedule + v-prediction |\n",
"\n",
"> **Resolution note**: GAN runs at 128×128 (best architecture from Phase 2.4), while VAE and DDPM run at 64×64. FID is measured at each model's native resolution against real images at that resolution, so scores are not directly numerically comparable across families — they are indicators of within-family improvement."
],
"id": "d0000001"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import json\n",
"import sys\n",
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Global plot style for the whole notebook.\n",
"plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n",
"\n",
"# All training artefacts live under ../outputs.\n",
"OUTPUTS = Path(\"../outputs\")\n",
"LOGS = OUTPUTS / \"logs\"        # one JSON log per run\n",
"SAMPLES = OUTPUTS / \"samples\"  # sample grids saved during training\n",
"\n",
"# Phase-5 finalists, each paired with its best 100-epoch counterpart from\n",
"# phases 2/3/4 so 200-epoch and 100-epoch curves can be overlaid per family.\n",
"FAMILIES = {\n",
"    \"GAN\": {\"p5\": \"p5_gan\", \"p4ep\": \"p2_4_wgan_sn_attn_128\", \"label\": \"WGAN-GP+SN+Attn\", \"color\": \"#5B8DB8\"},\n",
"    \"VAE\": {\"p5\": \"p5_vae\", \"p4ep\": \"p3_3_vae_patchgan\", \"label\": \"VAE+Perc+PatchGAN\", \"color\": \"#E8705A\"},\n",
"    \"DDPM\": {\"p5\": \"p5_ddpm\", \"p4ep\": \"p4_4_ddpm_wider\", \"label\": \"DDPM wider 192ch\", \"color\": \"#6ABF69\"},\n",
"}"
],
"execution_count": null,
"outputs": [],
"id": "d0000002"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Load logs"
],
"id": "d0000003"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def load_log(run_name):\n",
"    \"\"\"Load a run's JSON training log from LOGS, or return None if absent.\"\"\"\n",
"    p = LOGS / f\"{run_name}.json\"\n",
"    if p.exists():\n",
"        with open(p) as f:\n",
"            return json.load(f)\n",
"    return None\n",
"\n",
"def get_fid(log, epoch):\n",
"    \"\"\"Return the FID logged at `epoch`, or None if the log or entry is missing.\n",
"\n",
"    JSON round-trips dict keys as strings, so the string key is tried first\n",
"    with the raw (int) key as a fallback.\n",
"    \"\"\"\n",
"    if log is None:\n",
"        return None\n",
"    # Tolerate partially-written logs that have no FID history yet\n",
"    # (previously a missing \"history\"/\"fid\" key raised KeyError).\n",
"    fid = log.get(\"history\", {}).get(\"fid\", {})\n",
"    return fid.get(str(epoch), fid.get(epoch, None))\n",
"\n",
"logs_p5 = {fam: load_log(info[\"p5\"]) for fam, info in FAMILIES.items()}\n",
"logs_p4 = {fam: load_log(info[\"p4ep\"]) for fam, info in FAMILIES.items()}\n",
"\n",
"# Availability check: ✓ = log found, ✗ = run missing / not finished.\n",
"for fam in FAMILIES:\n",
"    p5_ok = \"✓\" if logs_p5[fam] else \"✗\"\n",
"    p4_ok = \"✓\" if logs_p4[fam] else \"✗\"\n",
"    print(f\" {fam}: 200ep={p5_ok} 100ep={p4_ok}\")"
],
"execution_count": null,
"outputs": [],
"id": "d0000004"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Quantitative Summary Table"
],
"id": "d0000005"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Build the cross-family summary table from the phase-5 logs.\n",
"rows = []\n",
"for fam, info in FAMILIES.items():\n",
"    log = logs_p5[fam]\n",
"    train_time = log.get(\"history\", {}).get(\"train_time_s\") if log else None\n",
"    train_min = f\"{train_time / 60:.1f}\" if train_time else \"?\"\n",
"    rows.append({\n",
"        \"Family\": fam,\n",
"        \"Model\": info[\"label\"],\n",
"        # Older logs store image_size at the top level rather than under \"config\".\n",
"        \"Res\": log.get(\"config\", log).get(\"image_size\", \"?\") if log else \"?\",\n",
"        \"Params\": log.get(\"n_params\") if log else None,\n",
"        \"Train (min)\": train_min,\n",
"        \"FID@100\": get_fid(log, 100),\n",
"        \"FID@150\": get_fid(log, 150),\n",
"        \"FID@200\": get_fid(log, 200),\n",
"        # IS and LPIPS filled in by Section 6\n",
"        \"IS ↑\": None,\n",
"        \"LPIPS ↑\": None,\n",
"    })\n",
"\n",
"df = pd.DataFrame(rows).set_index(\"Family\")\n",
"\n",
"\n",
"def fmt_params(v):\n",
"    \"\"\"Human-readable parameter count: 12.3M / 456K / raw int / '?'.\"\"\"\n",
"    if v is None:\n",
"        return \"?\"\n",
"    if v >= 1_000_000:\n",
"        return f\"{v / 1_000_000:.1f}M\"\n",
"    if v >= 1_000:\n",
"        return f\"{v / 1_000:.0f}K\"\n",
"    return str(v)\n",
"\n",
"\n",
"df_display = df.copy()\n",
"df_display[\"Params\"] = df_display[\"Params\"].apply(fmt_params)\n",
"# na_rep keeps the styler from crashing on runs whose FID is not logged yet\n",
"# (\"{:.1f}\".format(None) raises TypeError without it).\n",
"df_display.style.format({c: \"{:.1f}\" for c in [\"FID@100\", \"FID@150\", \"FID@200\"] if c in df_display}, na_rep=\"—\")"
],
"execution_count": null,
"outputs": [],
"id": "d0000006"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. FID Curves — All Three Families"
],
"id": "d0000007"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"fig, axes = plt.subplots(1, 2, figsize=(15, 5))\n",
"\n",
"for fam, info in FAMILIES.items():\n",
"    c = info[\"color\"]\n",
"    # Phase 5 (200ep) — solid\n",
"    log = logs_p5[fam]\n",
"    if log:\n",
"        fid_dict = log[\"history\"][\"fid\"]\n",
"        eps = sorted(int(k) for k in fid_dict)\n",
"        fids = [fid_dict[str(e)] for e in eps]\n",
"        # Annotate the final FID when present; accept int or float — JSON may\n",
"        # round-trip whole numbers as ints, which isinstance(..., float) missed.\n",
"        fid_final = fid_dict.get(\"200\")\n",
"        label = f\"{fam} 200ep (FID@200={fid_final:.1f})\" if isinstance(fid_final, (int, float)) else f\"{fam} 200ep\"\n",
"        axes[0].plot(eps, fids, \"-o\", color=c, linewidth=2, markersize=6, label=label)\n",
"\n",
"    # Phase 4/3/2 best (100ep) — dashed, same colour\n",
"    log4 = logs_p4[fam]\n",
"    if log4:\n",
"        fid4 = log4[\"history\"][\"fid\"]\n",
"        eps4 = sorted(int(k) for k in fid4)\n",
"        fids4 = [fid4[str(e)] for e in eps4]\n",
"        axes[0].plot(eps4, fids4, \"--\", color=c, linewidth=1.2, alpha=0.55,\n",
"                     label=f\"{fam} 100ep\")\n",
"\n",
"axes[0].set_xlabel(\"Epoch\")\n",
"axes[0].set_ylabel(\"FID (lower is better)\")\n",
"axes[0].set_title(\"FID Curves — Phase 5 (solid) vs best 100-ep (dashed)\")\n",
"axes[0].legend(fontsize=8)\n",
"\n",
"# Bar chart: FID at 100 vs 200 epochs per family\n",
"fams = list(FAMILIES.keys())\n",
"fid100 = [get_fid(logs_p5[f], 100) for f in fams]\n",
"fid200 = [get_fid(logs_p5[f], 200) for f in fams]\n",
"x = np.arange(len(fams))\n",
"w = 0.35\n",
"bars1 = axes[1].bar(x - w/2, [v or 0 for v in fid100], w, label=\"FID@100\", color=[FAMILIES[f][\"color\"] for f in fams], alpha=0.6)\n",
"bars2 = axes[1].bar(x + w/2, [v or 0 for v in fid200], w, label=\"FID@200\", color=[FAMILIES[f][\"color\"] for f in fams], alpha=1.0)\n",
"axes[1].set_xticks(x); axes[1].set_xticklabels(fams)\n",
"axes[1].set_ylabel(\"FID\"); axes[1].set_title(\"FID: 100ep vs 200ep per family\")\n",
"axes[1].legend()\n",
"# Value labels on each bar; zero-height bars mark missing logs and are skipped.\n",
"for bar in list(bars1) + list(bars2):\n",
"    h = bar.get_height()\n",
"    if h > 0:\n",
"        axes[1].text(bar.get_x() + bar.get_width()/2, h + 1, f\"{h:.0f}\", ha=\"center\", va=\"bottom\", fontsize=8)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
],
"execution_count": null,
"outputs": [],
"id": "d0000008"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Sample Grids — Epoch 200"
],
"id": "d0000009"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"fig, axes = plt.subplots(1, 3, figsize=(16, 6))\n",
"\n",
"for idx, (fam, info) in enumerate(FAMILIES.items()):\n",
"    ax = axes[idx]\n",
"    img_path = SAMPLES / info[\"p5\"] / \"epoch_0200.png\"\n",
"    log = logs_p5[fam]\n",
"    fid200 = get_fid(log, 200)\n",
"    if img_path.exists():\n",
"        ax.imshow(mpimg.imread(str(img_path)))\n",
"        title = f\"{fam}: {info['label']}\\n\"\n",
"        # `is not None` so a (legitimate) FID of 0.0 is still annotated.\n",
"        if fid200 is not None:\n",
"            # Same resolution lookup as Section 2: older logs store\n",
"            # image_size at the top level rather than under \"config\".\n",
"            res = log.get(\"config\", log).get(\"image_size\", \"?\")\n",
"            title += f\"FID@200={fid200:.1f} ({res}×{res})\"\n",
"        ax.set_title(title, fontsize=9)\n",
"    else:\n",
"        ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes, fontsize=12)\n",
"        ax.set_title(f\"{fam}: {info['label']}\", fontsize=9)\n",
"    ax.axis(\"off\")\n",
"\n",
"fig.suptitle(\"Phase 5 — Epoch 200 Sample Grids (4×4, prior samples)\", fontsize=13, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
],
"execution_count": null,
"outputs": [],
"id": "d0000010"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Training Progression — Epoch 10 → 50 → 100 → 200"
],
"id": "d0000011"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"check_epochs = [10, 50, 100, 200]\n",
"\n",
"# One row of sample grids per family, at four checkpoints across training.\n",
"for fam, info in FAMILIES.items():\n",
"    run = info[\"p5\"]\n",
"    log = logs_p5[fam]\n",
"    if log is None:\n",
"        print(f\"{fam}: not yet run\")\n",
"        continue\n",
"\n",
"    fig, axes = plt.subplots(1, len(check_epochs), figsize=(16, 4))\n",
"    for ax, ep in zip(axes, check_epochs):\n",
"        img_path = SAMPLES / run / f\"epoch_{ep:04d}.png\"\n",
"        if img_path.exists():\n",
"            ax.imshow(mpimg.imread(str(img_path)))\n",
"            fid = get_fid(log, ep)\n",
"            # `is not None` so a FID of exactly 0.0 would still be shown.\n",
"            ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid is not None else \"\"), fontsize=9)\n",
"        else:\n",
"            ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\",\n",
"                    transform=ax.transAxes)\n",
"        ax.axis(\"off\")\n",
"    fig.suptitle(f\"{fam} ({info['label']}) — 200-epoch progression\", fontsize=11, fontweight=\"bold\")\n",
"    plt.tight_layout()\n",
"    plt.show()"
],
"execution_count": null,
"outputs": [],
"id": "d0000012"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Extended Metrics — IS and LPIPS Diversity\n",
"\n",
"Requires loading trained model weights. Run after all phase 5 models have finished.\n",
"Generates 5,000 samples per model and computes IS and LPIPS over 200 random pairs."
],
"id": "d0000013"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import sys\n",
"sys.path.insert(0, \"..\")\n",
"\n",
"import torch\n",
"from src.utils import load_config\n",
"from src.models import get_model\n",
"from src.training.metrics import compute_is, compute_lpips_diversity\n",
"from src.training.diffusion import cosine_betas, linear_betas, make_alpha_bars, ddim_sample\n",
"from src.training.ema import EMA\n",
"\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"N_SAMPLE = 5_000\n",
"GEN_BATCH = 64  # generation micro-batch; caps peak GPU memory\n",
"\n",
"def load_ema_model(run_name, config_path):\n",
"    \"\"\"Load the best (falling back to final) EMA weights for a phase-5 run.\n",
"\n",
"    Returns (model in eval mode on DEVICE, config dict, model-kind string).\n",
"    For GANs `get_model` returns a tuple; only the first element (the\n",
"    generator) is needed for sampling.\n",
"    \"\"\"\n",
"    cfg = load_config(str(Path(\"../configs/phase5\") / config_path))\n",
"    model_obj, kind = get_model(cfg)\n",
"    ema_path = Path(\"../outputs/models\") / f\"{run_name}_best_ema.pt\"\n",
"    if not ema_path.exists():\n",
"        ema_path = Path(\"../outputs/models\") / f\"{run_name}_final_ema.pt\"\n",
"    model = model_obj[0] if isinstance(model_obj, tuple) else model_obj\n",
"    model.load_state_dict(torch.load(ema_path, map_location=DEVICE))\n",
"    return model.to(DEVICE).eval(), cfg, kind\n",
"\n",
"\n",
"@torch.no_grad()\n",
"def generate_samples(run_name, config_path, n=N_SAMPLE):\n",
"    \"\"\"Generate `n` images from a trained phase-5 model, returned on CPU\n",
"    (except DDPM, which returns whatever device `ddim_sample` produces).\n",
"\n",
"    Raises ValueError for an unrecognised model kind — previously this fell\n",
"    through to an UnboundLocalError on `imgs`.\n",
"    \"\"\"\n",
"    model, cfg, kind = load_ema_model(run_name, config_path)\n",
"    image_size = cfg.get(\"image_size\", 64)\n",
"\n",
"    if kind == \"wgan\":\n",
"        latent_dim = cfg.get(\"latent_dim\", 128)\n",
"        # Move each micro-batch to CPU immediately so at most GEN_BATCH\n",
"        # generated images sit on the GPU at any one time.\n",
"        imgs = torch.cat([\n",
"            model(torch.randn(min(GEN_BATCH, n - i), latent_dim, 1, 1, device=DEVICE)).cpu()\n",
"            for i in range(0, n, GEN_BATCH)\n",
"        ])[:n]\n",
"    elif kind == \"vae\":\n",
"        imgs = torch.cat([\n",
"            model.sample(min(GEN_BATCH, n - i), DEVICE).cpu()\n",
"            for i in range(0, n, GEN_BATCH)\n",
"        ])[:n]\n",
"    elif kind == \"ddpm\":\n",
"        schedule = cfg.get(\"noise_schedule\", \"cosine\")\n",
"        pred_type = cfg.get(\"pred_type\", \"v\")\n",
"        T = cfg.get(\"T\", 1000)\n",
"        betas = (cosine_betas(T) if schedule == \"cosine\" else linear_betas(T)).to(DEVICE)\n",
"        ab = make_alpha_bars(betas)\n",
"        imgs = ddim_sample(model, n, image_size, ab, n_steps=100,\n",
"                           pred_type=pred_type, device=DEVICE, batch_size=32)\n",
"    else:\n",
"        raise ValueError(f\"Unknown model kind: {kind!r}\")\n",
"    return imgs\n",
"\n",
"\n",
"print(\"Run this cell once all phase-5 models have completed.\")\n",
"print(\"Expected to take ~5–10 minutes per family on an RTX 3090.\")"
],
"execution_count": null,
"outputs": [],
"id": "d0000014"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# ── Compute IS and LPIPS per family ──────────────────────────────────────────\n",
"# Safe to run before training finishes: families whose weights are missing\n",
"# are skipped with a message instead of aborting the loop.\n",
"\n",
"extended_metrics = {}\n",
"\n",
"for fam, info in FAMILIES.items():\n",
"    run = info[\"p5\"]\n",
"    config = f\"{run}.json\"\n",
"    try:\n",
"        print(f\"\\n{fam}: generating {N_SAMPLE} samples...\")\n",
"        imgs = generate_samples(run, config)\n",
"\n",
"        print(f\"  Computing IS...\")\n",
"        is_mean, is_std = compute_is(imgs, device=DEVICE)\n",
"\n",
"        print(f\"  Computing LPIPS diversity...\")\n",
"        lpips = compute_lpips_diversity(imgs, n_pairs=200, device=DEVICE)\n",
"\n",
"        extended_metrics[fam] = {\"IS_mean\": is_mean, \"IS_std\": is_std, \"LPIPS\": lpips}\n",
"        print(f\"  IS = {is_mean:.2f} ± {is_std:.2f}   LPIPS = {lpips:.4f}\")\n",
"\n",
"    except Exception as e:\n",
"        # Best-effort: a missing checkpoint or OOM for one family should not\n",
"        # prevent computing metrics for the others.\n",
"        print(f\"  Skipped ({e})\")\n",
"        extended_metrics[fam] = {}\n",
"\n",
"# Merge with the FID table built in Section 2.\n",
"for fam in FAMILIES:\n",
"    em = extended_metrics.get(fam, {})\n",
"    df.loc[fam, \"IS ↑\"] = f\"{em['IS_mean']:.2f}±{em['IS_std']:.2f}\" if \"IS_mean\" in em else \"—\"\n",
"    df.loc[fam, \"LPIPS ↑\"] = f\"{em['LPIPS']:.4f}\" if \"LPIPS\" in em else \"—\"\n",
"\n",
"# na_rep keeps the styler from crashing on runs whose FID is not logged yet.\n",
"df.style.format({c: \"{:.1f}\" for c in [\"FID@100\", \"FID@150\", \"FID@200\"] if c in df}, na_rep=\"—\")"
],
"execution_count": null,
"outputs": [],
"id": "d0000015"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Latent Interpolation — GAN and VAE\n",
"\n",
"Smooth interpolation between two latent codes reveals whether the generator has learned a\n",
"continuous manifold. DDPM has no encoder, so interpolation is done by different noise seeds."
],
"id": "d0000016"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Spherical linear interpolation (slerp)\n",
"def slerp(z1, z2, t):\n",
"    \"\"\"Spherically interpolate between latent vectors z1 and z2 at fraction t.\n",
"\n",
"    Follows the great-circle arc between the two directions; falls back to a\n",
"    plain lerp when the vectors are (near-)parallel and the arc degenerates.\n",
"    \"\"\"\n",
"    u1 = z1 / z1.norm()\n",
"    u2 = z2 / z2.norm()\n",
"    omega = torch.acos((u1 * u2).sum().clamp(-1, 1))\n",
"    if omega.abs() < 1e-6:\n",
"        return (1 - t) * z1 + t * z2\n",
"    sin_omega = torch.sin(omega)\n",
"    w1 = torch.sin((1 - t) * omega) / sin_omega\n",
"    w2 = torch.sin(t * omega) / sin_omega\n",
"    return w1 * z1 + w2 * z2\n",
"\n",
"\n",
"def gan_interpolation(model, latent_dim, n_steps=10, device=DEVICE):\n",
"    \"\"\"Decode `n_steps` frames along a slerp path between two random latents.\"\"\"\n",
"    z_start = torch.randn(1, latent_dim, 1, 1, device=device)\n",
"    z_end = torch.randn(1, latent_dim, 1, 1, device=device)\n",
"    frames = []\n",
"    with torch.no_grad():\n",
"        for t in torch.linspace(0, 1, n_steps):\n",
"            z = slerp(z_start.flatten(), z_end.flatten(), t.item()).view_as(z_start)\n",
"            frames.append(model(z).cpu())\n",
"    return torch.cat(frames)\n",
"\n",
"\n",
"def vae_interpolation(model, real_imgs, n_steps=10, device=DEVICE):\n",
"    \"\"\"Encode two real images, linearly interpolate their posterior means, decode.\"\"\"\n",
"    img_a = real_imgs[:1].to(device)\n",
"    img_b = real_imgs[1:2].to(device)\n",
"    with torch.no_grad():\n",
"        mu_a, _ = model.encode(img_a)\n",
"        mu_b, _ = model.encode(img_b)\n",
"        weights = torch.linspace(0, 1, n_steps, device=device)\n",
"        frames = [model.decode((1 - w) * mu_a + w * mu_b).cpu() for w in weights]\n",
"    return torch.cat(frames)\n",
"\n",
"\n",
"print(\"Interpolation helpers defined. Run cells below after loading models.\")"
],
"execution_count": null,
"outputs": [],
"id": "d0000017"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# ── GAN interpolation ─────────────────────────────────────────────────────────\n",
"try:\n",
"    gan_model, gan_cfg, _ = load_ema_model(\"p5_gan\", \"p5_gan.json\")\n",
"    z_dim = gan_cfg.get(\"latent_dim\", 128)\n",
"    frames = gan_interpolation(gan_model, z_dim, n_steps=10)\n",
"    frames = (frames.clamp(-1, 1) + 1) / 2  # [-1, 1] → [0, 1] for imshow\n",
"\n",
"    fig, axes = plt.subplots(1, 10, figsize=(20, 2.5))\n",
"    for ax, frame in zip(axes, frames):\n",
"        ax.imshow(frame.permute(1, 2, 0).numpy())\n",
"        ax.axis(\"off\")\n",
"    fig.suptitle(\"GAN — Slerp latent interpolation (z₁ → z₂)\", fontsize=11, fontweight=\"bold\")\n",
"    plt.tight_layout()\n",
"    plt.show()\n",
"except Exception as e:\n",
"    print(f\"GAN interpolation: {e}\")"
],
"execution_count": null,
"outputs": [],
"id": "d0000018"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# ── VAE interpolation ─────────────────────────────────────────────────────────\n",
"try:\n",
"    from src.data import GeneratorDataset, get_transform\n",
"\n",
"    vae_model, vae_cfg, _ = load_ema_model(\"p5_vae\", \"p5_vae.json\")\n",
"    ds = GeneratorDataset(\"../../\" + vae_cfg[\"data_dir\"],\n",
"                          sources=vae_cfg.get(\"sources\", [\"wiki\"]),\n",
"                          transform=get_transform(vae_cfg[\"image_size\"], augment=False))\n",
"    # Two real images anchor the interpolation endpoints.\n",
"    endpoints = torch.stack([ds[0], ds[1]])\n",
"\n",
"    frames = vae_interpolation(vae_model, endpoints, n_steps=10)\n",
"    frames = (frames.clamp(-1, 1) + 1) / 2  # [-1, 1] → [0, 1] for imshow\n",
"\n",
"    fig, axes = plt.subplots(1, 10, figsize=(20, 2.5))\n",
"    for ax, frame in zip(axes, frames):\n",
"        ax.imshow(frame.permute(1, 2, 0).numpy())\n",
"        ax.axis(\"off\")\n",
"    fig.suptitle(\"VAE — μ-space linear interpolation (image₁ → image₂)\", fontsize=11, fontweight=\"bold\")\n",
"    plt.tight_layout()\n",
"    plt.show()\n",
"except Exception as e:\n",
"    print(f\"VAE interpolation: {e}\")"
],
"execution_count": null,
"outputs": [],
"id": "d0000019"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. Failure Mode Analysis\n",
"\n",
"Identify the worst-generated images per model: those farthest from their nearest real neighbour in pixel space, or simply those with highest reconstruction error (VAE) or highest DDPM loss."
],
"id": "d0000020"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# For GAN and DDPM: generate 256 images, pick the 8 with lowest mean activation\n",
"# (a proxy for less-coherent images — very rough)\n",
"\n",
"def worst_samples(imgs, n=8):\n",
"    \"\"\"Heuristic failure-mode picker: the `n` darkest images in the batch.\n",
"\n",
"    Low mean pixel value is a rough proxy for incoherent / collapsed samples.\n",
"    \"\"\"\n",
"    brightness = imgs.mean(dim=[1, 2, 3])  # per-image mean over C, H, W\n",
"    darkest = brightness.argsort()[:n]     # ascending → lowest brightness first\n",
"    return imgs[darkest]\n",
"\n",
"print(\"Failure mode analysis requires generated samples — run after model loading (Section 6).\")"
],
"execution_count": null,
"outputs": [],
"id": "d0000021"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Optionally run with already-generated `imgs` from Section 6:\n",
"# worst = worst_samples(imgs.cpu())\n",
"# worst = (worst.clamp(-1,1) + 1) / 2\n",
"# fig, axes = plt.subplots(1, 8, figsize=(18, 2.5))\n",
"# for ax, img in zip(axes, worst):\n",
"# ax.imshow(img.permute(1,2,0)); ax.axis('off')\n",
"# plt.suptitle('Failure modes (lowest-brightness samples)', fontsize=11)\n",
"# plt.tight_layout(); plt.show()"
],
"execution_count": null,
"outputs": [],
"id": "d0000022"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9. Training Loss Overview (all families)"
],
"id": "d0000023"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"fig, axes = plt.subplots(1, 3, figsize=(18, 4))\n",
"\n",
"for ax, (fam, info) in zip(axes, FAMILIES.items()):\n",
"    log = logs_p5[fam]\n",
"    if log is None:\n",
"        ax.set_title(f\"{fam} (not yet run)\"); ax.axis(\"off\"); continue\n",
"    h = log[\"history\"]\n",
"    c = info[\"color\"]\n",
"    # Extra legend entries from a twin axis, if one is created below.\n",
"    extra_handles, extra_labels = [], []\n",
"\n",
"    if fam == \"GAN\":\n",
"        ax.plot(h[\"g_loss\"], label=\"G loss\", color=c, linewidth=1.2)\n",
"        ax.plot(h[\"w_dist\"], label=\"W-dist\", color=c, linewidth=1.2, linestyle=\"--\")\n",
"        ax.set_ylabel(\"Loss / W-distance\")\n",
"    elif fam == \"VAE\":\n",
"        ax.plot(h[\"recon_loss\"], label=\"MSE\", color=c)\n",
"        ax2 = ax.twinx()\n",
"        ax2.plot(h[\"kl_loss\"], label=\"KL\", color=\"grey\", linestyle=\"--\")\n",
"        ax2.set_ylabel(\"KL\", color=\"grey\")\n",
"        ax.set_ylabel(\"MSE\")\n",
"        # Lines on the twin axis are invisible to ax.legend(); collect their\n",
"        # handles so the \"KL\" entry actually appears in the legend.\n",
"        extra_handles, extra_labels = ax2.get_legend_handles_labels()\n",
"    elif fam == \"DDPM\":\n",
"        ax.plot(h[\"loss\"], label=\"MSE (v-pred)\", color=c)\n",
"        ax.set_ylabel(\"MSE loss\")\n",
"\n",
"    ax.set_xlabel(\"Epoch\")\n",
"    ax.set_title(f\"{fam}: {info['label']}\")\n",
"    handles, labels = ax.get_legend_handles_labels()\n",
"    ax.legend(handles + extra_handles, labels + extra_labels, fontsize=8)\n",
"\n",
"fig.suptitle(\"Phase 5 — Training Dynamics (200 epochs)\", fontsize=12, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
],
"execution_count": null,
"outputs": [],
"id": "d0000024"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 10. Conclusions"
],
"id": "d0000025"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"print(\"=\" * 72)\n",
"print(\"PHASE 5 — CROSS-FAMILY COMPARISON (200 epochs)\")\n",
"print(\"=\" * 72)\n",
"\n",
"for fam, info in FAMILIES.items():\n",
"    log = logs_p5[fam]\n",
"    em = extended_metrics.get(fam, {})\n",
"    print(f\"\\n ── {fam}: {info['label']} ──\")\n",
"    if log:\n",
"        # Same resolution lookup as Section 2: older logs store image_size\n",
"        # at the top level rather than under \"config\".\n",
"        res = log.get(\"config\", log).get(\"image_size\", \"?\")\n",
"        print(f\" Resolution : {res}×{res}\")\n",
"        n_p = log.get(\"n_params\")\n",
"        print(f\" Params : {n_p:,}\" if n_p else \" Params : ?\")\n",
"        tt = log.get(\"history\", {}).get(\"train_time_s\")\n",
"        print(f\" Train time : {tt / 60:.1f} min\" if tt else \" Train time : ?\")\n",
"        for ep in (100, 150, 200):\n",
"            fid = get_fid(log, ep)\n",
"            # `is not None` so a FID of exactly 0.0 would still print.\n",
"            print(f\" FID@{ep:<3} : {fid:.1f}\" if fid is not None else f\" FID@{ep:<3} : ?\")\n",
"    if em:\n",
"        print(f\" IS : {em.get('IS_mean','?'):.2f} ± {em.get('IS_std','?'):.2f}\" if 'IS_mean' in em else \" IS : ?\")\n",
"        print(f\" LPIPS div : {em.get('LPIPS','?'):.4f}\" if 'LPIPS' in em else \" LPIPS div : ?\")\n",
"    else:\n",
"        print(\" IS / LPIPS : (run Section 6 to compute)\")\n",
"\n",
"print(\"\\n\" + \"=\" * 72)\n",
"print(\"Narrative to fill in after results:\")\n",
"print(\" - Which family achieves best FID?\")\n",
"print(\" - GAN: fast convergence but mode collapse risk?\")\n",
"print(\" - VAE: blurry priors improved by perceptual+adversarial loss?\")\n",
"print(\" - DDPM: highest quality but slowest inference (100 DDIM steps)?\")\n",
"print(\"=\" * 72)"
],
"execution_count": null,
"outputs": [],
"id": "d0000026"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}