# %% [markdown]
# # Phase 3 — VAE Evolution Analysis
#
# Traces the VAE improvement story — each step motivated by the failure of the previous:
#
# | Step | Model | Key change | Expected failure |
# |------|-------|------------|------------------|
# | 3.1 | Vanilla VAE (MSE+KL) | Baseline | Blurry samples — MSE minimises pixel average |
# | 3.2 | + Perceptual loss (VGG) | Feature-space reconstruction | Residual texture blur |
# | 3.3 | + PatchGAN (VQGAN-lite) | Local texture adversarial | — |
#
# All runs use H-flip-only augmentation and MTCNN-aligned 64×64 crops.
# FID is computed from prior samples (`z ~ N(0, I) → decode`), same metric as GAN.
# Reconstructions are shown separately to diagnose encoder quality.

# %% cell b0000002 — imports and artifact locations
import json
from pathlib import Path

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Figure defaults shared by every plot in the notebook.
plt.rcParams["figure.dpi"] = 120
plt.rcParams["font.size"] = 10

# All paths are relative to the directory the notebook is run from.
OUTPUTS = Path("..") / "outputs"
LOGS = OUTPUTS.joinpath("logs")        # per-run "<name>.json" logs are read from here
SAMPLES = OUTPUTS.joinpath("samples")  # per-run "epoch_XXXX*.png" grids are read from here
# %% [markdown]
# ## 1. Load experiment logs

# %% cell b0000004
run_names = ["p3_1_vae", "p3_2_vae_perceptual", "p3_3_vae_patchgan"]
run_labels = {
    "p3_1_vae": "3.1 VAE (MSE+KL)",
    "p3_2_vae_perceptual": "3.2 +Perceptual",
    "p3_3_vae_patchgan": "3.3 +PatchGAN",
}

# Map run name -> parsed JSON log. Runs whose log file does not exist yet are
# simply absent, and every later cell checks `name in runs` before using one.
runs = {}
for name in run_names:
    log_path = LOGS / f"{name}.json"
    if log_path.exists():
        # JSON is UTF-8 by specification; do not depend on the platform default.
        with open(log_path, encoding="utf-8") as f:
            runs[name] = json.load(f)
    else:
        print(f" Missing: {log_path}")

print(f"Loaded {len(runs)}/{len(run_names)} experiments:")
for name in run_names:
    status = "✓" if name in runs else "✗"
    print(f" {status} {name}")

# %% [markdown]
# ## 2. FID Comparison Table (prior samples)

# %% cell b0000006
def get_fid(run, epoch):
    """Return the FID recorded for `epoch` in a run log, or None if there is none.

    Args:
        run: Parsed run log; the FID table is read from run["history"]["fid"].
        epoch: Epoch number. The string form is tried first (logs loaded with
            json.load carry string keys), then the raw value.

    Returns:
        The stored FID value, or None when the epoch, or the whole
        "history"/"fid" table, is missing from the log.
    """
    fid = run.get("history", {}).get("fid") or {}
    return fid.get(str(epoch), fid.get(epoch, None))


FID_EPOCHS = (25, 50, 75, 100)  # checkpoints shown as table columns

rows = []
for name in run_names:
    if name not in runs:
        continue
    r = runs[name]
    cfg = r["config"]
    row = {
        "Step": run_labels[name],
        "λ_perc": cfg.get("lambda_perceptual", 0),
        "λ_adv": cfg.get("lambda_adversarial", 0),
        "β_kl": cfg.get("beta_kl", 1.0),
    }
    for ep in FID_EPOCHS:
        row[f"FID@{ep} (prior)"] = get_fid(r, ep)
    rows.append(row)

df = pd.DataFrame(rows)
# na_rep keeps the table renderable while runs are still in progress: without
# it, "{:.1f}" is applied to a missing (None) FID and rendering raises TypeError.
df.style.format({c: "{:.1f}" for c in df.columns if "FID" in c}, na_rep="—")
# %% [markdown]
# ## 3. FID Curves — Evolution Story

# %% cell b0000008
fig, ax = plt.subplots(figsize=(10, 5))
# One colour per entry of run_names; the loss-curve cell below reuses this list.
colors = ["#5B8DB8", "#E8705A", "#6ABF69"]

for i, name in enumerate(run_names):
    if name not in runs:
        continue
    fid_dict = runs[name]["history"]["fid"]
    epochs = sorted(int(k) for k in fid_dict)
    fids = [get_fid(runs[name], e) for e in epochs]
    fid100 = get_fid(runs[name], 100)
    # A run that has not reached epoch 100 has no FID yet. Applying ":.1f" to a
    # "?" placeholder raises ValueError, so only real numbers are formatted.
    fid100_text = "?" if fid100 is None else f"{fid100:.1f}"
    label = f"{run_labels[name]} (FID@100={fid100_text})"
    ax.plot(epochs, fids, "o-", label=label, color=colors[i], linewidth=2, markersize=8)

ax.set_xlabel("Epoch")
ax.set_ylabel("FID (lower is better) — prior samples")
ax.set_title("Phase 3 — FID Curves: VAE → +Perceptual → +PatchGAN")
if ax.get_legend_handles_labels()[0]:  # no legend (and no warning) when nothing was plotted
    ax.legend()
plt.tight_layout()
plt.show()

# %% [markdown]
# ## 4. Training Loss Curves

# %% cell b0000010
fig, axes = plt.subplots(2, 3, figsize=(16, 9))
axes = axes.flatten()
keys = ["recon_loss", "kl_loss", "perc_loss", "adv_g_loss", "adv_d_loss"]
titles = ["MSE Reconstruction", "KL Divergence", "Perceptual (VGG)", "Adv G Loss", "Adv D Loss (PatchGAN)"]

# zip stops after the five loss keys, leaving the sixth axis untouched.
for ax, key, title in zip(axes, keys, titles):
    for i, name in enumerate(run_names):
        if name not in runs:
            continue
        h = runs[name]["history"].get(key, [])
        # Skip series that are absent or all zero for this run.
        if any(v != 0.0 for v in h):
            ax.plot(range(1, len(h) + 1), h, label=run_labels[name],
                    color=colors[i], linewidth=1.2, alpha=0.9)
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    if ax.get_legend_handles_labels()[0]:  # panels with no curve get no legend
        ax.legend(fontsize=8)

axes[-1].axis("off")  # empty sixth panel
fig.suptitle("Phase 3 — Training Dynamics", fontsize=13, fontweight="bold")
plt.tight_layout()
plt.show()

# %% [markdown]
# ## 5. Prior Samples — Epoch 100

# %% cell b0000012
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for idx, name in enumerate(run_names):
    ax = axes[idx]
    img_path = SAMPLES / name / "epoch_0100.png"
    if img_path.exists():
        fid = get_fid(runs[name], 100) if name in runs else None
        ax.imshow(mpimg.imread(str(img_path)))
        # Explicit None test: "missing" is the only case that drops the FID line.
        if fid is not None:
            ax.set_title(f"{run_labels[name]}\nFID@100={fid:.1f}", fontsize=9)
        else:
            ax.set_title(run_labels[name], fontsize=9)
    else:
        ax.text(0.5, 0.5, "Not yet run", ha="center", va="center", transform=ax.transAxes)
        ax.set_title(run_labels[name], fontsize=9)
    ax.axis("off")

fig.suptitle("Phase 3 — Epoch 100 Prior Samples (4×4 grids)", fontsize=12, fontweight="bold")
plt.tight_layout()
plt.show()

# %% [markdown]
# ## 6. Reconstructions — Epoch 100
#
# Left half = real images, right half = reconstructions (interleaved pairs).
# %% cell b0000014
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# One panel per run: the saved epoch-100 reconstruction grid, or a placeholder.
for ax, name in zip(axes, run_names):
    recon_png = SAMPLES / name / "epoch_0100_recon.png"
    panel_title = run_labels[name]
    if not recon_png.exists():
        ax.text(0.5, 0.5, "Not yet run", ha="center", va="center", transform=ax.transAxes)
    else:
        ax.imshow(mpimg.imread(str(recon_png)))
        panel_title = f"{run_labels[name]}\nreal | recon"
    ax.set_title(panel_title, fontsize=9)
    ax.axis("off")

fig.suptitle("Phase 3 — Epoch 100 Reconstructions", fontsize=12, fontweight="bold")
plt.tight_layout()
plt.show()

# %% [markdown]
# ## 7. Step-by-step Pairwise Comparisons

# %% cell b0000016
transitions = [
    ("3.1→3.2: MSE→+Perceptual", "p3_1_vae", "p3_2_vae_perceptual"),
    ("3.2→3.3: +PatchGAN adversarial", "p3_2_vae_perceptual", "p3_3_vae_patchgan"),
]
# (file-name suffix, panel caption): top row is prior samples, bottom row reconstructions.
panel_kinds = [("", "prior"), ("_recon", "recon")]

for title, name_a, name_b in transitions:
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    for col, name in enumerate((name_a, name_b)):
        label = run_labels[name]
        for row, (suffix, kind) in enumerate(panel_kinds):
            ax = axes[row][col]
            ep = 100
            img_path = SAMPLES / name / f"epoch_{ep:04d}{suffix}.png"
            if img_path.exists():
                # FID is only looked up for the prior-sample panel of a loaded run.
                fid = None
                if suffix == "" and name in runs:
                    fid = get_fid(runs[name], ep)
                ax.imshow(mpimg.imread(str(img_path)))
                fid_note = f" FID={fid:.1f}" if fid else ""
                ax.set_title(f"{label}\n{kind}" + fid_note, fontsize=9)
            else:
                ax.text(0.5, 0.5, "Pending", ha="center", va="center", transform=ax.transAxes)
                ax.set_title(f"{label} ({kind})", fontsize=9)
            ax.axis("off")
    fig.suptitle(title, fontsize=12, fontweight="bold")
    plt.tight_layout()
    plt.show()

# %% [markdown]
# ## 8. Progression: Epoch 10 → 50 → 100 (prior samples)

# %% cell b0000018
check_epochs = [10, 50, 100]

# One figure per loaded run, one panel per checkpoint epoch.
for name in [n for n in run_names if n in runs]:
    fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))
    for ax, ep in zip(axes, check_epochs):
        img_path = SAMPLES / name / f"epoch_{ep:04d}.png"
        if not img_path.exists():
            ax.text(0.5, 0.5, f"Ep {ep}\n(pending)", ha="center", va="center",
                    transform=ax.transAxes)
        else:
            ax.imshow(mpimg.imread(str(img_path)))
            fid = get_fid(runs[name], ep)
            caption = f"Ep {ep}"
            if fid:
                caption += f"\nFID={fid:.1f}"
            ax.set_title(caption, fontsize=9)
        ax.axis("off")
    fig.suptitle(f"{run_labels[name]} — Prior Sample Progression", fontsize=11, fontweight="bold")
    plt.tight_layout()
    plt.show()
# %% [markdown]
# ## 9. Conclusions

# %% cell b0000020
def _value_at_epoch(series, epoch):
    """Return the entry of a per-epoch list for 1-based `epoch`, or None if the list is too short."""
    return series[epoch - 1] if len(series) >= epoch else None


def _fmt_metric(value, spec):
    """Format `value` with format spec `spec`; a missing value (None) prints as "?"."""
    return "?" if value is None else format(value, spec)


print("=" * 70)
print("PHASE 3 — VAE EVOLUTION SUMMARY")
print("=" * 70)

for name in run_names:
    if name not in runs:
        print(f"\n {run_labels[name]}: NOT YET RUN")
        continue
    h = runs[name]["history"]
    fid100 = get_fid(runs[name], 100)
    fid50 = get_fid(runs[name], 50)
    # .get(..., []) mirrors the loss-curve cell: a log without a given loss
    # series is reported as "?" instead of raising KeyError.
    mse50 = _value_at_epoch(h.get("recon_loss", []), 50)
    kl50 = _value_at_epoch(h.get("kl_loss", []), 50)
    print(f"\n {run_labels[name]}:")
    # Only None means "missing": a value of exactly 0.0 (e.g. a KL term that has
    # collapsed to zero) is a real measurement and must be printed, not hidden.
    print(f" FID@50 = {_fmt_metric(fid50, '.1f')}")
    print(f" FID@100 = {_fmt_metric(fid100, '.1f')}")
    print(f" MSE@50 = {_fmt_metric(mse50, '.4f')}")
    print(f" KL@50 = {_fmt_metric(kl50, '.2f')}")

print("\n" + "=" * 70)
print("Best VAE model for Phase 5 cross-family comparison: fill in after runs.")
print("=" * 70)