{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Phase 3: VAE Evolution\n", "\n", "A standard VAE trained on MSE + KL alone tends to produce blurry, mean-like samples.\n", "Phase 3 stacks additional losses on top of the basic MSE+KL objective to recover detail.\n", "\n", "| Run | Step |\n", "|----------------------|------------------------------------------------------|\n", "| `p3_1_vae` | MSE + KL only |\n", "| `p3_2_vae_perceptual`| + VGG16 perceptual loss (`lambda_perceptual=0.1`) |\n", "| `p3_3_vae_patchgan` | + PatchGAN adversarial loss (`lambda_adversarial=0.01`) |\n", "\n", "**Headline result:** `p3_3_vae_patchgan` reaches **best FID = 50.1** (prior samples) within 100 epochs; its FID at epoch 100 is 52.5.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Reference: phase 0 baseline (same family)\n", "\n", "`p0_vae` was trained with MSE+KL on raw, unaligned data. Its prior samples were heavily\n", "blurred mean faces, the textbook VAE failure mode. Phase 3 keeps the encoder/decoder\n", "architecture and pipeline fixed and shows that the added perceptual and adversarial\n", "terms are what lower FID.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import json\n", "from pathlib import Path\n", "\n", "import matplotlib.pyplot as plt\n", "import matplotlib.image as mpimg\n", "import numpy as np\n", "import pandas as pd\n", "\n", "plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n", "\n", "OUTPUTS = Path(\"../outputs\")\n", "LOGS = OUTPUTS / \"logs\"\n", "SAMPLES = OUTPUTS / \"samples\"\n", "\n", "\n", "def load_log(name):\n", "    # Return the parsed log for run `name`, or None if its file is missing.\n", "    p = LOGS / f\"{name}.json\"\n", "    if not p.exists():\n", "        return None\n", "    with open(p) as f:  # close the handle instead of leaking it\n", "        return json.load(f)\n", "\n", "def get_fid(log, epoch):\n", "    # Epoch keys are looked up as strings; returns None if the epoch is absent.\n", "    fid = log.get(\"history\", {}).get(\"fid\", {})\n", "    return fid.get(str(epoch))\n", "\n", "def fid_series(log):\n", "    # Return (epochs, fid_values) sorted by integer epoch.\n", "    fid = log.get(\"history\", {}).get(\"fid\", {})\n", "    items = sorted((int(k), v) for k, v in fid.items())\n", "    return [e for e, _ in items], [v for _, v in items]\n" ] }, { "cell_type": "markdown",
"metadata": {}, "source": [ "## 1. Load experiment logs" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " p3_1_vae: OK\n", " p3_2_vae_perceptual: OK\n", " p3_3_vae_patchgan: OK\n" ] } ], "source": [ "run_names = [\"p3_1_vae\", \"p3_2_vae_perceptual\", \"p3_3_vae_patchgan\"]\n", "run_labels = {\n", " \"p3_1_vae\": \"3.1 MSE + KL\",\n", " \"p3_2_vae_perceptual\": \"3.2 + Perceptual\",\n", " \"p3_3_vae_patchgan\": \"3.3 + PatchGAN\",\n", "}\n", "runs = {name: load_log(name) for name in run_names}\n", "runs = {k: v for k, v in runs.items() if v}\n", "for n in run_names: print(f\" {n}: {'OK' if n in runs else 'MISSING'}\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. FID comparison table (prior samples)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
|   | Run | FID@50 | FID@100 | Best FID | Recon@100 | KL@100 | Train (min) |
|---|---|---|---|---|---|---|---|
| 2 | 3.3 + PatchGAN | 51.9 | 52.5 | 50.1 | 0.0439 | 0.22 | 35.8 |
| 1 | 3.2 + Perceptual | 70.3 | 68.2 | 68.2 | 0.0409 | 0.19 | 25.4 |
| 0 | 3.1 MSE + KL | 93.7 | 88.4 | 88.4 | 0.0510 | 0.11 | 25.4 |