{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Phase 3 - VAE Progression\n", "\n", "Phase 3 studies the VAE family after the pipeline has been locked. The baseline\n", "VAE is fast and stable, but its MSE + KL objective tends to average away facial\n", "detail.\n", "\n", "This phase asks whether extra loss terms can recover sharper, more realistic\n", "samples. The saved runs stack the losses one at a time: each run keeps every\n", "term from the run before it and adds one more.\n", "\n", "## What this phase changes\n", "\n", "| Run | Recipe change |\n", "|---|---|\n", "| `p3_1_vae` | Baseline MSE + KL |\n", "| `p3_2_vae_perceptual` | Adds VGG16 perceptual loss (`lambda_perceptual=0.1`) |\n", "| `p3_3_vae_patchgan` | Adds PatchGAN adversarial loss (`lambda_adversarial=0.01`) |\n", "\n", "**Headline result:** `p3_3_vae_patchgan` reaches **best FID = 50.1** on\n", "prior samples. The more informative result is the progression: best FID drops\n", "from 88.4 (MSE + KL) to 68.2 (+ perceptual) to 50.1 (+ PatchGAN). Each added\n", "term gives a large improvement on top of the previous recipe, so the losses are\n", "complementary, not interchangeable.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Reference: Phase 0 baseline from the same family\n", "\n", "`p0_vae` trained the same MSE + KL objective on raw, unaligned data. Prior samples\n", "were heavily blurred mean faces, the textbook VAE failure mode. 
Phase 3 keeps the encoder/decoder\n", "architecture and pipeline fixed across the three runs, so the FID gains reported\n", "below can be attributed to the added perceptual and adversarial terms.\n" ] }, { "cell_type": "code", "execution_count": 17, "id": "060fd1c7", "metadata": {}, "outputs": [], "source": [ "import json\n", "from pathlib import Path\n", "\n", "import matplotlib.pyplot as plt\n", "import matplotlib.image as mpimg\n", "import numpy as np\n", "import pandas as pd\n", "\n", "plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n", "\n", "def find_generator_root():\n", "    \"\"\"Walk up from the CWD until a directory (or its generator/ child) holds outputs/logs and outputs/samples.\"\"\"\n", "    for base in [Path.cwd(), *Path.cwd().parents]:\n", "        for candidate in [base, base / \"generator\"]:\n", "            if (candidate / \"outputs\" / \"logs\").exists() and (candidate / \"outputs\" / \"samples\").exists():\n", "                return candidate.resolve()\n", "    raise FileNotFoundError(\"Could not locate generator/outputs from the current working directory\")\n", "\n", "GENERATOR_ROOT = find_generator_root()\n", "PROJECT_ROOT = GENERATOR_ROOT.parent\n", "OUTPUTS = GENERATOR_ROOT / \"outputs\"\n", "LOGS = OUTPUTS / \"logs\"\n", "SAMPLES = OUTPUTS / \"samples\"\n", "\n", "\n", "def load_log(name):\n", "    \"\"\"Return the parsed JSON log for a run, or None if the file does not exist.\"\"\"\n", "    p = LOGS / f\"{name}.json\"\n", "    if not p.exists():\n", "        return None\n", "    with p.open() as f:  # close the file handle after parsing\n", "        return json.load(f)\n", "\n", "def get_fid(log, epoch):\n", "    \"\"\"FID at a given epoch; JSON object keys are strings, so the epoch is looked up as str.\"\"\"\n", "    fid = log.get(\"history\", {}).get(\"fid\", {})\n", "    return fid.get(str(epoch))\n", "\n", "def fid_series(log):\n", "    \"\"\"Return (epochs, fids) sorted by epoch, ready for plotting.\"\"\"\n", "    fid = log.get(\"history\", {}).get(\"fid\", {})\n", "    items = sorted((int(k), v) for k, v in fid.items())\n", "    return [e for e, _ in items], [v for _, v in items]\n", "\n", "def show_image_or_missing(ax, path, title=None):\n", "    \"\"\"Draw a saved preview image on ax, or a text placeholder if the file is missing.\"\"\"\n", "    if path.exists():\n", "        ax.imshow(mpimg.imread(str(path)))\n", "    else:\n", "        ax.text(0.5, 0.5, f\"missing artifact\\n{path.name}\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n", "    if title:\n", "        ax.set_title(title, fontsize=9)\n", "    ax.axis(\"off\")\n", "\n" ] }, { "cell_type": "markdown", "id": "04c1769a", "metadata": {}, "source": [ "## 1. 
Load experiment logs\n", "\n", "This notebook only reads the existing VAE training logs and saved sample previews;\n", "no model is retrained here." ] }, { "cell_type": "code", "execution_count": 18, "id": "d60f5f69", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " p3_1_vae: OK\n", " p3_2_vae_perceptual: OK\n", " p3_3_vae_patchgan: OK\n" ] } ], "source": [ "run_names = [\"p3_1_vae\", \"p3_2_vae_perceptual\", \"p3_3_vae_patchgan\"]\n", "run_labels = {\n", "    \"p3_1_vae\": \"3.1 MSE + KL\",\n", "    \"p3_2_vae_perceptual\": \"3.2 + Perceptual\",\n", "    \"p3_3_vae_patchgan\": \"3.3 + PatchGAN\",\n", "}\n", "runs = {name: load_log(name) for name in run_names}\n", "# Keep only runs whose log was found and is non-empty.\n", "runs = {k: v for k, v in runs.items() if v}\n", "for n in run_names:\n", "    print(f\" {n}: {'OK' if n in runs else 'MISSING'}\")\n" ] }, { "cell_type": "markdown", "id": "94333e76", "metadata": {}, "source": [ "## 2. FID comparison table (prior samples)\n", "\n", "FID is computed on images decoded from prior samples, so it measures generation\n", "quality directly (lower is better). Reconstruction quality is reported separately\n", "in the `Recon@100` column." ] }, { "cell_type": "code", "execution_count": 19, "id": "2b220dfe", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
|   | Run | FID@50 | FID@100 | Best FID | Recon@100 | KL@100 | Train (min) |
|---|---|---|---|---|---|---|---|
| 2 | 3.3 + PatchGAN | 51.9 | 52.5 | 50.1 | 0.0439 | 0.22 | 35.8 |
| 1 | 3.2 + Perceptual | 70.3 | 68.2 | 68.2 | 0.0409 | 0.19 | 25.4 |
| 0 | 3.1 MSE + KL | 93.7 | 88.4 | 88.4 | 0.0510 | 0.11 | 25.4 |\n", "