Files
DRL_PROJ/generator/notebooks/phase3_analysis.ipynb
T
2026-04-30 13:10:33 +01:00

397 lines
13 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "markdown",
"id": "b0000001",
"metadata": {},
"source": [
"# Phase 3 — VAE Evolution Analysis\n",
"\n",
"Traces the VAE improvement story — each step motivated by the failure of the previous:\n",
"\n",
"| Step | Model | Key change | Expected failure |\n",
"|------|-------|------------|------------------|\n",
"| 3.1 | Vanilla VAE (MSE+KL) | Baseline | Blurry samples — the MSE-optimal output is a per-pixel average |\n",
"| 3.2 | + Perceptual loss (VGG) | Feature-space reconstruction | Residual texture blur |\n",
"| 3.3 | + PatchGAN (VQGAN-lite) | Local texture adversarial | — |\n",
"\n",
"All runs use H-flip-only augmentation and MTCNN-aligned 64×64 crops.\n",
"FID is computed from prior samples (`z ~ N(0, I) → decode`), the same metric used for the GAN runs.\n",
"Reconstructions are shown separately to diagnose encoder quality."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000002",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Figure defaults shared by every plot in this notebook.\n",
"plt.rcParams[\"figure.dpi\"] = 120\n",
"plt.rcParams[\"font.size\"] = 10\n",
"\n",
"# Artefact locations, resolved relative to the notebook's working directory.\n",
"OUTPUTS = Path(\"../outputs\")\n",
"LOGS, SAMPLES = OUTPUTS / \"logs\", OUTPUTS / \"samples\""
]
},
{
"cell_type": "markdown",
"id": "b0000003",
"metadata": {},
"source": [
"## 1. Load experiment logs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000004",
"metadata": {},
"outputs": [],
"source": [
"# Run identifiers double as log file stems and sample sub-directory names.\n",
"run_names = [\"p3_1_vae\", \"p3_2_vae_perceptual\", \"p3_3_vae_patchgan\"]\n",
"# Display label used in tables, legends and figure titles.\n",
"run_labels = dict(\n",
"    p3_1_vae=\"3.1 VAE (MSE+KL)\",\n",
"    p3_2_vae_perceptual=\"3.2 +Perceptual\",\n",
"    p3_3_vae_patchgan=\"3.3 +PatchGAN\",\n",
")\n",
"\n",
"# Load whichever per-run JSON logs exist and report the ones that do not.\n",
"runs = {}\n",
"for name in run_names:\n",
"    log_path = LOGS / f\"{name}.json\"\n",
"    if not log_path.exists():\n",
"        print(f\" Missing: {log_path}\")\n",
"        continue\n",
"    with open(log_path) as log_file:\n",
"        runs[name] = json.load(log_file)\n",
"\n",
"print(f\"Loaded {len(runs)}/{len(run_names)} experiments:\")\n",
"for name in run_names:\n",
"    print(f\" {'✓' if name in runs else '✗'} {name}\")"
]
},
{
"cell_type": "markdown",
"id": "b0000005",
"metadata": {},
"source": [
"## 2. FID Comparison Table (prior samples)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000006",
"metadata": {},
"outputs": [],
"source": [
"def get_fid(run, epoch):\n",
"    \"\"\"Return the FID recorded for ``epoch`` in ``run``, or ``None`` if unavailable.\n",
"\n",
"    ``run`` is one loaded experiment log; the FID history is read from\n",
"    ``run[\"history\"][\"fid\"]``. A log without that entry yields ``None``\n",
"    instead of raising, which is how every caller in this notebook already\n",
"    treats a missing FID. The string form of ``epoch`` is tried first,\n",
"    then ``epoch`` itself.\n",
"    \"\"\"\n",
"    fid_by_epoch = run.get(\"history\", {}).get(\"fid\") or {}\n",
"    for key in (str(epoch), epoch):\n",
"        if key in fid_by_epoch:\n",
"            return fid_by_epoch[key]\n",
"    return None\n",
"\n",
"# One table row per loaded run: loss weights from its config plus the\n",
"# prior-sample FID at each checkpoint epoch (None when not logged).\n",
"rows = []\n",
"for name in run_names:\n",
"    if name not in runs:\n",
"        continue\n",
"    r = runs[name]\n",
"    cfg = r[\"config\"]\n",
"    row = {\n",
"        \"Step\": run_labels[name],\n",
"        \"λ_perc\": cfg.get(\"lambda_perceptual\", 0),\n",
"        \"λ_adv\": cfg.get(\"lambda_adversarial\", 0),\n",
"        \"β_kl\": cfg.get(\"beta_kl\", 1.0),\n",
"    }\n",
"    for fid_epoch in (25, 50, 75, 100):\n",
"        row[f\"FID@{fid_epoch} (prior)\"] = get_fid(r, fid_epoch)\n",
"    rows.append(row)\n",
"\n",
"df = pd.DataFrame(rows)\n",
"# na_rep keeps the table renderable when a checkpoint is missing: without it the\n",
"# float format spec is applied to None and the Styler raises TypeError.\n",
"df.style.format({c: \"{:.1f}\" for c in df.columns if \"FID\" in c}, na_rep=\"?\")"
]
},
{
"cell_type": "markdown",
"id": "b0000007",
"metadata": {},
"source": [
"## 3. FID Curves — Evolution Story"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000008",
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(10, 5))\n",
"# One colour per entry of run_names; the training-loss cell further down reuses this list.\n",
"colors = [\"#5B8DB8\", \"#E8705A\", \"#6ABF69\"]\n",
"\n",
"for i, name in enumerate(run_names):\n",
"    if name not in runs:\n",
"        continue\n",
"    fid_dict = runs[name][\"history\"][\"fid\"]\n",
"    # Epoch keys are looked up as strings, so sort them numerically rather than lexically.\n",
"    epochs = sorted(int(k) for k in fid_dict)\n",
"    fids = [fid_dict[str(e)] for e in epochs]\n",
"    # Format the epoch-100 FID only when it exists: applying a float format spec\n",
"    # to the '?' placeholder would raise ValueError for unfinished runs.\n",
"    fid_100 = fid_dict.get(\"100\")\n",
"    fid_100_text = \"?\" if fid_100 is None else f\"{fid_100:.1f}\"\n",
"    label = f\"{run_labels[name]} (FID@100={fid_100_text})\"\n",
"    ax.plot(epochs, fids, \"o-\", label=label, color=colors[i], linewidth=2, markersize=8)\n",
"\n",
"ax.set_xlabel(\"Epoch\")\n",
"ax.set_ylabel(\"FID (lower is better) — prior samples\")\n",
"ax.set_title(\"Phase 3 — FID Curves: VAE → +Perceptual → +PatchGAN\")\n",
"# Skip the legend (and matplotlib's empty-legend warning) when no run was plotted.\n",
"if ax.lines:\n",
"    ax.legend()\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b0000009",
"metadata": {},
"source": [
"## 4. Training Loss Curves"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000010",
"metadata": {},
"outputs": [],
"source": [
"# One panel per loss term on a 2x3 grid; `colors` comes from the FID-curve cell above.\n",
"fig, axes = plt.subplots(2, 3, figsize=(16, 9))\n",
"axes = axes.flatten()\n",
"keys = [\"recon_loss\", \"kl_loss\", \"perc_loss\", \"adv_g_loss\", \"adv_d_loss\"]\n",
"titles = [\"MSE Reconstruction\", \"KL Divergence\", \"Perceptual (VGG)\", \"Adv G Loss\", \"Adv D Loss (PatchGAN)\"]\n",
"\n",
"for ax, key, title in zip(axes, keys, titles):\n",
"    for i, name in enumerate(run_names):\n",
"        if name not in runs:\n",
"            continue\n",
"        h = runs[name][\"history\"].get(key, [])\n",
"        # Skip series that are absent or all zero for this run.\n",
"        if any(v != 0.0 for v in h):\n",
"            # History lists are 0-indexed, so shift by one to plot against epoch numbers.\n",
"            ax.plot(range(1, len(h) + 1), h, label=run_labels[name],\n",
"                    color=colors[i], linewidth=1.2, alpha=0.9)\n",
"    ax.set_title(title)\n",
"    ax.set_xlabel(\"Epoch\")\n",
"    # Only draw a legend when something was plotted; an empty legend makes matplotlib warn.\n",
"    if ax.lines:\n",
"        ax.legend(fontsize=8)\n",
"\n",
"# Hide every panel that has no loss term assigned (the sixth one with five keys).\n",
"for unused_ax in axes[len(keys):]:\n",
"    unused_ax.axis(\"off\")\n",
"fig.suptitle(\"Phase 3 — Training Dynamics\", fontsize=13, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b0000011",
"metadata": {},
"source": [
"## 5. Prior Samples — Epoch 100"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000012",
"metadata": {},
"outputs": [],
"source": [
"# One panel per run: the saved epoch-100 sample image, or a placeholder if it is absent.\n",
"fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
"\n",
"for ax, name in zip(axes, run_names):\n",
"    panel_title = run_labels[name]\n",
"    img_path = SAMPLES / name / \"epoch_0100.png\"\n",
"    if not img_path.exists():\n",
"        ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
"    else:\n",
"        fid = get_fid(runs[name], 100) if name in runs else None\n",
"        ax.imshow(mpimg.imread(str(img_path)))\n",
"        # Append the FID to the title only when one was logged.\n",
"        if fid:\n",
"            panel_title = f\"{run_labels[name]}\\nFID@100={fid:.1f}\"\n",
"    ax.set_title(panel_title, fontsize=9)\n",
"    ax.axis(\"off\")\n",
"\n",
"fig.suptitle(\"Phase 3 — Epoch 100 Prior Samples (4×4 grids)\", fontsize=12, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b0000013",
"metadata": {},
"source": [
"## 6. Reconstructions — Epoch 100\n",
"\n",
"Left half = real images, right half = reconstructions (interleaved pairs)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000014",
"metadata": {},
"outputs": [],
"source": [
"# One panel per run: the saved epoch-100 reconstruction image, or a placeholder.\n",
"fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
"\n",
"for ax, name in zip(axes, run_names):\n",
"    recon_path = SAMPLES / name / \"epoch_0100_recon.png\"\n",
"    if recon_path.exists():\n",
"        ax.imshow(mpimg.imread(str(recon_path)))\n",
"        panel_title = f\"{run_labels[name]}\\nreal | recon\"\n",
"    else:\n",
"        ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
"        panel_title = run_labels[name]\n",
"    ax.set_title(panel_title, fontsize=9)\n",
"    ax.axis(\"off\")\n",
"\n",
"fig.suptitle(\"Phase 3 — Epoch 100 Reconstructions\", fontsize=12, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b0000015",
"metadata": {},
"source": [
"## 7. Step-by-step Pairwise Comparisons"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000016",
"metadata": {},
"outputs": [],
"source": [
"# Each transition compares two consecutive steps: (figure title, earlier run, later run).\n",
"transitions = [\n",
"    (\"3.1→3.2: MSE→+Perceptual\", \"p3_1_vae\", \"p3_2_vae_perceptual\"),\n",
"    (\"3.2→3.3: +PatchGAN adversarial\", \"p3_2_vae_perceptual\", \"p3_3_vae_patchgan\"),\n",
"]\n",
"\n",
"for title, name_a, name_b in transitions:\n",
"    # 2x2 grid: columns are the two runs, rows are prior samples then reconstructions.\n",
"    fig, axes = plt.subplots(2, 2, figsize=(12, 10))\n",
"    for col, name in enumerate((name_a, name_b)):\n",
"        for row, (suffix, kind) in enumerate(((\"\", \"prior\"), (\"_recon\", \"recon\"))):\n",
"            ax = axes[row][col]\n",
"            ep = 100\n",
"            label = run_labels[name]\n",
"            img_path = SAMPLES / name / f\"epoch_{ep:04d}{suffix}.png\"\n",
"            if not img_path.exists():\n",
"                ax.text(0.5, 0.5, \"Pending\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
"                ax.set_title(f\"{label} ({kind})\", fontsize=9)\n",
"            else:\n",
"                # FID is looked up for the prior-sample row only.\n",
"                fid = get_fid(runs[name], ep) if (suffix == \"\" and name in runs) else None\n",
"                ax.imshow(mpimg.imread(str(img_path)))\n",
"                fid_text = f\" FID={fid:.1f}\" if fid else \"\"\n",
"                ax.set_title(f\"{label}\\n{kind}\" + fid_text, fontsize=9)\n",
"            ax.axis(\"off\")\n",
"    fig.suptitle(title, fontsize=12, fontweight=\"bold\")\n",
"    plt.tight_layout()\n",
"    plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b0000017",
"metadata": {},
"source": [
"## 8. Progression: Epoch 10 → 50 → 100 (prior samples)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000018",
"metadata": {},
"outputs": [],
"source": [
"# Epochs whose saved prior-sample images are shown side by side for each run.\n",
"check_epochs = [10, 50, 100]\n",
"\n",
"for name in run_names:\n",
"    if name in runs:\n",
"        fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))\n",
"        for ax, ep in zip(axes, check_epochs):\n",
"            img_path = SAMPLES / name / f\"epoch_{ep:04d}.png\"\n",
"            if not img_path.exists():\n",
"                ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\",\n",
"                        transform=ax.transAxes)\n",
"            else:\n",
"                ax.imshow(mpimg.imread(str(img_path)))\n",
"                fid = get_fid(runs[name], ep)\n",
"                # Add a second title line with the FID when one was logged for this epoch.\n",
"                caption = f\"Ep {ep}\"\n",
"                if fid:\n",
"                    caption += f\"\\nFID={fid:.1f}\"\n",
"                ax.set_title(caption, fontsize=9)\n",
"            ax.axis(\"off\")\n",
"        fig.suptitle(f\"{run_labels[name]} — Prior Sample Progression\", fontsize=11, fontweight=\"bold\")\n",
"        plt.tight_layout()\n",
"        plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b0000019",
"metadata": {},
"source": [
"## 9. Conclusions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000020",
"metadata": {},
"outputs": [],
"source": [
"def _fmt_metric(value, spec):\n",
"    \"\"\"Format ``value`` with format spec ``spec``, or return '?' when it is None.\n",
"\n",
"    Testing against None (not truthiness) keeps a legitimate 0.0 visible.\n",
"    \"\"\"\n",
"    return \"?\" if value is None else format(value, spec)\n",
"\n",
"\n",
"print(\"=\" * 70)\n",
"print(\"PHASE 3 — VAE EVOLUTION SUMMARY\")\n",
"print(\"=\" * 70)\n",
"\n",
"for name in run_names:\n",
"    if name not in runs:\n",
"        print(f\"\\n {run_labels[name]}: NOT YET RUN\")\n",
"        continue\n",
"    h = runs[name][\"history\"]\n",
"    fid100 = get_fid(runs[name], 100)\n",
"    fid50 = get_fid(runs[name], 50)\n",
"    # Loss histories are 0-indexed per epoch, so epoch 50 sits at index 49.\n",
"    # .get(..., []) tolerates a log that lacks a loss series, as the loss-curve cell does.\n",
"    recon_hist = h.get(\"recon_loss\", [])\n",
"    kl_hist = h.get(\"kl_loss\", [])\n",
"    mse50 = recon_hist[49] if len(recon_hist) > 49 else None\n",
"    kl50 = kl_hist[49] if len(kl_hist) > 49 else None\n",
"    print(f\"\\n {run_labels[name]}:\")\n",
"    print(f\" FID@50 = {_fmt_metric(fid50, '.1f')}\")\n",
"    print(f\" FID@100 = {_fmt_metric(fid100, '.1f')}\")\n",
"    print(f\" MSE@50 = {_fmt_metric(mse50, '.4f')}\")\n",
"    print(f\" KL@50 = {_fmt_metric(kl50, '.2f')}\")\n",
"\n",
"print(\"\\n\" + \"=\" * 70)\n",
"print(\"Best VAE model for Phase 5 cross-family comparison: fill in after runs.\")\n",
"print(\"=\" * 70)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}