397 lines
13 KiB
Plaintext
397 lines
13 KiB
Plaintext
{
|
||
"cells": [
|
||
{
 "cell_type": "markdown",
 "id": "b0000001",
 "metadata": {},
 "source": [
  "# Phase 3 — VAE Evolution Analysis\n",
  "\n",
  "Traces the VAE improvement story; each step is motivated by the failure of the previous one:\n",
  "\n",
  "| Step | Model | Key change | Expected failure |\n",
  "|------|-------|------------|------------------|\n",
  "| 3.1 | Vanilla VAE (MSE+KL) | Baseline | Blurry samples — MSE regresses to the per-pixel mean of plausible outputs |\n",
  "| 3.2 | + Perceptual loss (VGG) | Feature-space reconstruction | Residual texture blur |\n",
  "| 3.3 | + PatchGAN (VQGAN-lite) | Adversarial loss on local texture patches | — (final step) |\n",
  "\n",
  "All runs use horizontal-flip-only augmentation and MTCNN-aligned 64×64 crops.\n",
  "FID is computed from prior samples (`z ~ N(0, I) → decode`), the same metric used for the GAN experiments.\n",
  "Reconstructions are shown separately to diagnose encoder quality."
 ]
},
|
||
{
 "cell_type": "code",
 "execution_count": null,
 "id": "b0000002",
 "metadata": {},
 "outputs": [],
 "source": [
  "import json\n",
  "from pathlib import Path\n",
  "\n",
  "import matplotlib.pyplot as plt\n",
  "import matplotlib.image as mpimg\n",
  "# NOTE(review): numpy does not appear to be used by any cell in this notebook.\n",
  "import numpy as np\n",
  "import pandas as pd\n",
  "\n",
  "plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n",
  "\n",
  "# Relative paths resolve against the kernel's working directory\n",
  "# (presumably this notebook's folder — TODO confirm before Run All).\n",
  "OUTPUTS = Path(\"../outputs\")\n",
  "# Per-run training logs, read later as LOGS / \"<run_name>.json\".\n",
  "LOGS = OUTPUTS / \"logs\"\n",
  "# Per-run image grids, read later as SAMPLES / \"<run_name>\" / \"epoch_XXXX.png\" (prior samples)\n",
  "# and \"epoch_XXXX_recon.png\" (reconstructions).\n",
  "SAMPLES = OUTPUTS / \"samples\""
 ]
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "b0000003",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 1. Load experiment logs"
|
||
]
|
||
},
|
||
{
 "cell_type": "code",
 "execution_count": null,
 "id": "b0000004",
 "metadata": {},
 "outputs": [],
 "source": [
  "run_names = [\"p3_1_vae\", \"p3_2_vae_perceptual\", \"p3_3_vae_patchgan\"]\n",
  "run_labels = {\n",
  "    \"p3_1_vae\": \"3.1 VAE (MSE+KL)\",\n",
  "    \"p3_2_vae_perceptual\": \"3.2 +Perceptual\",\n",
  "    \"p3_3_vae_patchgan\": \"3.3 +PatchGAN\",\n",
  "}\n",
  "assert set(run_labels) == set(run_names), \"every run needs exactly one display label\"\n",
  "\n",
  "# Load one JSON log per run. Runs without a log are reported and skipped, so the\n",
  "# notebook still executes end-to-end while some experiments are pending.\n",
  "runs = {}\n",
  "for name in run_names:\n",
  "    log_path = LOGS / f\"{name}.json\"\n",
  "    if log_path.exists():\n",
  "        with open(log_path, encoding=\"utf-8\") as f:\n",
  "            runs[name] = json.load(f)\n",
  "    else:\n",
  "        print(f\" Missing: {log_path}\")\n",
  "\n",
  "print(f\"Loaded {len(runs)}/{len(run_names)} experiments:\")\n",
  "for name in run_names:\n",
  "    status = \"✓\" if name in runs else \"✗\"\n",
  "    print(f\" {status} {name}\")"
 ]
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "b0000005",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 2. FID Comparison Table (prior samples)"
|
||
]
|
||
},
|
||
{
 "cell_type": "code",
 "execution_count": null,
 "id": "b0000006",
 "metadata": {},
 "outputs": [],
 "source": [
  "def get_fid(run, epoch):\n",
  "    \"\"\"Return the logged FID of `run` at `epoch`, or None if it was not evaluated.\n",
  "\n",
  "    Keys of a JSON object are always strings after `json.load`, so the str key is\n",
  "    tried first; the int key is kept as a fallback for int-keyed histories.\n",
  "    \"\"\"\n",
  "    fid = run[\"history\"].get(\"fid\", {})\n",
  "    return fid.get(str(epoch), fid.get(epoch))\n",
  "\n",
  "\n",
  "FID_EPOCHS = [25, 50, 75, 100]  # epochs shown as table columns\n",
  "\n",
  "rows = []\n",
  "for name in run_names:\n",
  "    if name not in runs:\n",
  "        continue\n",
  "    r = runs[name]\n",
  "    cfg = r[\"config\"]\n",
  "    row = {\n",
  "        \"Step\": run_labels[name],\n",
  "        \"λ_perc\": cfg.get(\"lambda_perceptual\", 0),\n",
  "        \"λ_adv\": cfg.get(\"lambda_adversarial\", 0),\n",
  "        \"β_kl\": cfg.get(\"beta_kl\", 1.0),\n",
  "    }\n",
  "    for ep in FID_EPOCHS:\n",
  "        row[f\"FID@{ep} (prior)\"] = get_fid(r, ep)\n",
  "    rows.append(row)\n",
  "\n",
  "df = pd.DataFrame(rows)\n",
  "# na_rep: an epoch with no FID yet is None/NaN; without it '{:.1f}' raises on None\n",
  "# (all-missing column) or renders 'nan'.\n",
  "df.style.format({c: \"{:.1f}\" for c in df.columns if \"FID\" in c}, na_rep=\"—\")"
 ]
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "b0000007",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 3. FID Curves — Evolution Story"
|
||
]
|
||
},
|
||
{
 "cell_type": "code",
 "execution_count": null,
 "id": "b0000008",
 "metadata": {},
 "outputs": [],
 "source": [
  "fig, ax = plt.subplots(figsize=(10, 5))\n",
  "# One colour per entry in run_names (also reused by the loss-curve cell below).\n",
  "colors = [\"#5B8DB8\", \"#E8705A\", \"#6ABF69\"]\n",
  "\n",
  "for i, name in enumerate(run_names):\n",
  "    if name not in runs:\n",
  "        continue\n",
  "    fid_dict = runs[name][\"history\"][\"fid\"]\n",
  "    # JSON keys are strings: convert to int so epochs sort numerically, not lexically.\n",
  "    points = sorted((int(k), v) for k, v in fid_dict.items())\n",
  "    epochs = [e for e, _ in points]\n",
  "    fids = [v for _, v in points]\n",
  "    # Format only real numbers: applying ':.1f' to a '?' placeholder raises ValueError\n",
  "    # whenever epoch 100 has no FID entry yet.\n",
  "    fid100 = fid_dict.get(\"100\")\n",
  "    fid100_txt = f\"{fid100:.1f}\" if fid100 is not None else \"?\"\n",
  "    label = f\"{run_labels[name]} (FID@100={fid100_txt})\"\n",
  "    ax.plot(epochs, fids, \"o-\", label=label, color=colors[i], linewidth=2, markersize=8)\n",
  "\n",
  "ax.set_xlabel(\"Epoch\")\n",
  "ax.set_ylabel(\"FID (lower is better) — prior samples\")\n",
  "ax.set_title(\"Phase 3 — FID Curves: VAE → +Perceptual → +PatchGAN\")\n",
  "if ax.lines:  # no runs loaded -> nothing to put in a legend\n",
  "    ax.legend()\n",
  "plt.tight_layout()\n",
  "plt.show()"
 ]
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "b0000009",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 4. Training Loss Curves"
|
||
]
|
||
},
|
||
{
 "cell_type": "code",
 "execution_count": null,
 "id": "b0000010",
 "metadata": {},
 "outputs": [],
 "source": [
  "fig, axes = plt.subplots(2, 3, figsize=(16, 9))\n",
  "axes = axes.flatten()\n",
  "keys = [\"recon_loss\", \"kl_loss\", \"perc_loss\", \"adv_g_loss\", \"adv_d_loss\"]\n",
  "titles = [\"MSE Reconstruction\", \"KL Divergence\", \"Perceptual (VGG)\", \"Adv G Loss\", \"Adv D Loss (PatchGAN)\"]\n",
  "assert len(keys) == len(titles) <= len(axes), \"one title per key, and enough panels\"\n",
  "\n",
  "# `colors` is defined in the FID-curve cell above.\n",
  "for ax, key, title in zip(axes, keys, titles):\n",
  "    for i, name in enumerate(run_names):\n",
  "        if name not in runs:\n",
  "            continue\n",
  "        h = runs[name][\"history\"].get(key, [])\n",
  "        # All-zero series are skipped (presumably that loss term is disabled for the run).\n",
  "        if any(v != 0.0 for v in h):\n",
  "            # Assumes one value per epoch, starting at epoch 1 — TODO confirm against the trainer.\n",
  "            ax.plot(range(1, len(h) + 1), h, label=run_labels[name],\n",
  "                    color=colors[i], linewidth=1.2, alpha=0.9)\n",
  "    ax.set_title(title)\n",
  "    ax.set_xlabel(\"Epoch\")\n",
  "    if ax.lines:  # empty panels would trigger matplotlib's 'No artists with labels' warning\n",
  "        ax.legend(fontsize=8)\n",
  "\n",
  "for ax in axes[len(keys):]:\n",
  "    ax.axis(\"off\")  # hide grid slots that have no loss panel\n",
  "fig.suptitle(\"Phase 3 — Training Dynamics\", fontsize=13, fontweight=\"bold\")\n",
  "plt.tight_layout()\n",
  "plt.show()"
 ]
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "b0000011",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 5. Prior Samples — Epoch 100"
|
||
]
|
||
},
|
||
{
 "cell_type": "code",
 "execution_count": null,
 "id": "b0000012",
 "metadata": {},
 "outputs": [],
 "source": [
  "# One panel per run; squeeze=False keeps `axes` 2-D regardless of the run count.\n",
  "n_runs = len(run_names)\n",
  "fig, axes = plt.subplots(1, n_runs, figsize=(5 * n_runs, 5), squeeze=False)\n",
  "\n",
  "for ax, name in zip(axes[0], run_names):\n",
  "    img_path = SAMPLES / name / \"epoch_0100.png\"\n",
  "    if img_path.exists():\n",
  "        fid = get_fid(runs[name], 100) if name in runs else None\n",
  "        ax.imshow(mpimg.imread(str(img_path)))\n",
  "        # Explicit None check so a numeric FID is never dropped for being falsy.\n",
  "        title = f\"{run_labels[name]}\\nFID@100={fid:.1f}\" if fid is not None else run_labels[name]\n",
  "        ax.set_title(title, fontsize=9)\n",
  "    else:\n",
  "        ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
  "        ax.set_title(run_labels[name], fontsize=9)\n",
  "    ax.axis(\"off\")\n",
  "\n",
  "fig.suptitle(\"Phase 3 — Epoch 100 Prior Samples (4×4 grids)\", fontsize=12, fontweight=\"bold\")\n",
  "plt.tight_layout()\n",
  "plt.show()"
 ]
},
|
||
{
 "cell_type": "markdown",
 "id": "b0000013",
 "metadata": {},
 "source": [
  "## 6. Reconstructions — Epoch 100\n",
  "\n",
  "Each panel pairs real images with their reconstructions (\"real | recon\").\n",
  "**TODO:** confirm the grid layout against the sample-saving code. Is it left half real / right half reconstructed, or interleaved real/recon pairs?"
 ]
},
|
||
{
 "cell_type": "code",
 "execution_count": null,
 "id": "b0000014",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Reconstruction grids are stored next to the prior samples with a \"_recon\" suffix.\n",
  "# squeeze=False keeps `axes` 2-D regardless of the run count.\n",
  "n_runs = len(run_names)\n",
  "fig, axes = plt.subplots(1, n_runs, figsize=(5 * n_runs, 5), squeeze=False)\n",
  "\n",
  "for ax, name in zip(axes[0], run_names):\n",
  "    img_path = SAMPLES / name / \"epoch_0100_recon.png\"\n",
  "    if img_path.exists():\n",
  "        ax.imshow(mpimg.imread(str(img_path)))\n",
  "        ax.set_title(f\"{run_labels[name]}\\nreal | recon\", fontsize=9)\n",
  "    else:\n",
  "        ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
  "        ax.set_title(run_labels[name], fontsize=9)\n",
  "    ax.axis(\"off\")\n",
  "\n",
  "fig.suptitle(\"Phase 3 — Epoch 100 Reconstructions\", fontsize=12, fontweight=\"bold\")\n",
  "plt.tight_layout()\n",
  "plt.show()"
 ]
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "b0000015",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 7. Step-by-step Pairwise Comparisons"
|
||
]
|
||
},
|
||
{
 "cell_type": "code",
 "execution_count": null,
 "id": "b0000016",
 "metadata": {},
 "outputs": [],
 "source": [
  "transitions = [\n",
  "    (\"3.1→3.2: MSE→+Perceptual\", \"p3_1_vae\", \"p3_2_vae_perceptual\"),\n",
  "    (\"3.2→3.3: +PatchGAN adversarial\", \"p3_2_vae_perceptual\", \"p3_3_vae_patchgan\"),\n",
  "]\n",
  "COMPARE_EPOCH = 100\n",
  "# Grid rows: (file-name suffix, row kind). Columns: run before / after the change.\n",
  "grid_rows = [(\"\", \"prior\"), (\"_recon\", \"recon\")]\n",
  "\n",
  "for title, name_a, name_b in transitions:\n",
  "    fig, axes = plt.subplots(2, 2, figsize=(12, 10))\n",
  "    for col, name in enumerate([name_a, name_b]):\n",
  "        label = run_labels[name]\n",
  "        for row, (suffix, kind) in enumerate(grid_rows):\n",
  "            ax = axes[row, col]\n",
  "            img_path = SAMPLES / name / f\"epoch_{COMPARE_EPOCH:04d}{suffix}.png\"\n",
  "            if img_path.exists():\n",
  "                # FID is measured on prior samples, so it is only annotated on the prior row.\n",
  "                fid = get_fid(runs[name], COMPARE_EPOCH) if (suffix == \"\" and name in runs) else None\n",
  "                ax.imshow(mpimg.imread(str(img_path)))\n",
  "                fid_txt = f\" FID={fid:.1f}\" if fid is not None else \"\"\n",
  "                ax.set_title(f\"{label}\\n{kind}{fid_txt}\", fontsize=9)\n",
  "            else:\n",
  "                ax.text(0.5, 0.5, \"Pending\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
  "                ax.set_title(f\"{label} ({kind})\", fontsize=9)\n",
  "            ax.axis(\"off\")\n",
  "    fig.suptitle(title, fontsize=12, fontweight=\"bold\")\n",
  "    plt.tight_layout()\n",
  "    plt.show()"
 ]
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "b0000017",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 8. Progression: Epoch 10 → 50 → 100 (prior samples)"
|
||
]
|
||
},
|
||
{
 "cell_type": "code",
 "execution_count": null,
 "id": "b0000018",
 "metadata": {},
 "outputs": [],
 "source": [
  "check_epochs = [10, 50, 100]\n",
  "\n",
  "for name in run_names:\n",
  "    if name not in runs:\n",
  "        continue\n",
  "    # squeeze=False keeps `axes` 2-D even if check_epochs holds a single epoch.\n",
  "    fig, axes = plt.subplots(1, len(check_epochs), figsize=(4 * len(check_epochs), 4), squeeze=False)\n",
  "    for ax, ep in zip(axes[0], check_epochs):\n",
  "        img_path = SAMPLES / name / f\"epoch_{ep:04d}.png\"\n",
  "        if img_path.exists():\n",
  "            ax.imshow(mpimg.imread(str(img_path)))\n",
  "            fid = get_fid(runs[name], ep)\n",
  "            fid_txt = f\"\\nFID={fid:.1f}\" if fid is not None else \"\"\n",
  "            ax.set_title(f\"Ep {ep}{fid_txt}\", fontsize=9)\n",
  "        else:\n",
  "            ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\",\n",
  "                    transform=ax.transAxes)\n",
  "        ax.axis(\"off\")\n",
  "    fig.suptitle(f\"{run_labels[name]} — Prior Sample Progression\", fontsize=11, fontweight=\"bold\")\n",
  "    plt.tight_layout()\n",
  "    plt.show()"
 ]
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "b0000019",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 9. Conclusions"
|
||
]
|
||
},
|
||
{
 "cell_type": "code",
 "execution_count": null,
 "id": "b0000020",
 "metadata": {},
 "outputs": [],
 "source": [
  "def _at_epoch(series, epoch):\n",
  "    \"\"\"Value of a per-epoch series at 1-based `epoch`, or None if not logged yet.\n",
  "\n",
  "    Assumes series[0] belongs to epoch 1 (the loss-curve plot uses the same x-axis).\n",
  "    \"\"\"\n",
  "    return series[epoch - 1] if len(series) >= epoch else None\n",
  "\n",
  "\n",
  "def _fmt(label, value, spec):\n",
  "    \"\"\"One summary line; explicit None check so a legitimate 0.0 is not shown as '?'.\"\"\"\n",
  "    return f\" {label} = {value:{spec}}\" if value is not None else f\" {label} = ?\"\n",
  "\n",
  "\n",
  "print(\"=\" * 70)\n",
  "print(\"PHASE 3 — VAE EVOLUTION SUMMARY\")\n",
  "print(\"=\" * 70)\n",
  "\n",
  "for name in run_names:\n",
  "    if name not in runs:\n",
  "        print(f\"\\n {run_labels[name]}: NOT YET RUN\")\n",
  "        continue\n",
  "    h = runs[name][\"history\"]\n",
  "    print(f\"\\n {run_labels[name]}:\")\n",
  "    print(_fmt(\"FID@50\", get_fid(runs[name], 50), \".1f\"))\n",
  "    print(_fmt(\"FID@100\", get_fid(runs[name], 100), \".1f\"))\n",
  "    print(_fmt(\"MSE@50\", _at_epoch(h.get(\"recon_loss\", []), 50), \".4f\"))\n",
  "    print(_fmt(\"KL@50\", _at_epoch(h.get(\"kl_loss\", []), 50), \".2f\"))\n",
  "\n",
  "# Pick the run with the lowest FID@100 (lower is better) among runs that have one.\n",
  "scored = [(get_fid(runs[n], 100), n) for n in run_names if n in runs]\n",
  "scored = [(f, n) for f, n in scored if f is not None]\n",
  "print(\"\\n\" + \"=\" * 70)\n",
  "if scored:\n",
  "    best_fid, best_name = min(scored)\n",
  "    print(f\"Best VAE model for Phase 5 cross-family comparison: {run_labels[best_name]} (FID@100={best_fid:.1f})\")\n",
  "else:\n",
  "    print(\"Best VAE model for Phase 5 cross-family comparison: fill in after runs.\")\n",
  "print(\"=\" * 70)"
 ]
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"name": "python",
|
||
"version": "3.10.0"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|