From 741726711716733f0b2da2acf32a1f361ba1de7d Mon Sep 17 00:00:00 2001 From: Johnny Fernandes Date: Thu, 30 Apr 2026 13:10:33 +0100 Subject: [PATCH] Preview of phase 2-5 implementation; needs a full check --- generator/configs/phase2/_base_phase2.json | 10 + generator/configs/phase2/p2_1_dcgan.json | 12 + generator/configs/phase2/p2_2_wgan.json | 14 + .../configs/phase2/p2_3_wgan_sn_attn.json | 15 + .../configs/phase2/p2_4_wgan_sn_attn_128.json | 15 + generator/configs/phase3/_base_phase3.json | 13 + generator/configs/phase3/p3_1_vae.json | 8 + .../configs/phase3/p3_2_vae_perceptual.json | 8 + .../configs/phase3/p3_3_vae_patchgan.json | 10 + generator/configs/phase4/_base_phase4.json | 19 + .../configs/phase4/p4_1_ddpm_linear.json | 6 + .../configs/phase4/p4_2_ddpm_cosine.json | 6 + generator/configs/phase4/p4_3_ddpm_vpred.json | 6 + generator/configs/phase4/p4_4_ddpm_wider.json | 8 + generator/configs/phase5/p5_ddpm.json | 22 + generator/configs/phase5/p5_gan.json | 21 + generator/configs/phase5/p5_vae.json | 20 + generator/notebooks/phase2_analysis.ipynb | 366 ++++++++++ generator/notebooks/phase3_analysis.ipynb | 396 +++++++++++ generator/notebooks/phase4_analysis.ipynb | 403 +++++++++++ generator/notebooks/phase5_analysis.ipynb | 669 ++++++++++++++++++ .../pipeline/20260430T012157.632405+0000.json | 19 - .../pipeline/20260430T023357.318888+0000.json | 18 - generator/run.py | 29 +- generator/src/data/dataset.py | 11 +- generator/src/models/__init__.py | 3 + generator/src/models/patchgan.py | 69 ++ generator/src/models/unet.py | 279 ++++++++ generator/src/models/vae.py | 132 ++++ generator/src/models/wgan.py | 254 +++++-- generator/src/training/__init__.py | 4 +- generator/src/training/diffusion.py | 136 ++++ generator/src/training/metrics.py | 56 ++ generator/src/training/perceptual.py | 57 ++ generator/src/training/trainer.py | 606 +++++++++++++++- 35 files changed, 3605 insertions(+), 115 deletions(-) create mode 100644 
generator/configs/phase2/_base_phase2.json create mode 100644 generator/configs/phase2/p2_1_dcgan.json create mode 100644 generator/configs/phase2/p2_2_wgan.json create mode 100644 generator/configs/phase2/p2_3_wgan_sn_attn.json create mode 100644 generator/configs/phase2/p2_4_wgan_sn_attn_128.json create mode 100644 generator/configs/phase3/_base_phase3.json create mode 100644 generator/configs/phase3/p3_1_vae.json create mode 100644 generator/configs/phase3/p3_2_vae_perceptual.json create mode 100644 generator/configs/phase3/p3_3_vae_patchgan.json create mode 100644 generator/configs/phase4/_base_phase4.json create mode 100644 generator/configs/phase4/p4_1_ddpm_linear.json create mode 100644 generator/configs/phase4/p4_2_ddpm_cosine.json create mode 100644 generator/configs/phase4/p4_3_ddpm_vpred.json create mode 100644 generator/configs/phase4/p4_4_ddpm_wider.json create mode 100644 generator/configs/phase5/p5_ddpm.json create mode 100644 generator/configs/phase5/p5_gan.json create mode 100644 generator/configs/phase5/p5_vae.json create mode 100644 generator/notebooks/phase2_analysis.ipynb create mode 100644 generator/notebooks/phase3_analysis.ipynb create mode 100644 generator/notebooks/phase4_analysis.ipynb create mode 100644 generator/notebooks/phase5_analysis.ipynb delete mode 100644 generator/outputs/pipeline/20260430T012157.632405+0000.json delete mode 100644 generator/outputs/pipeline/20260430T023357.318888+0000.json create mode 100644 generator/src/models/patchgan.py create mode 100644 generator/src/models/unet.py create mode 100644 generator/src/models/vae.py create mode 100644 generator/src/training/diffusion.py create mode 100644 generator/src/training/metrics.py create mode 100644 generator/src/training/perceptual.py diff --git a/generator/configs/phase2/_base_phase2.json b/generator/configs/phase2/_base_phase2.json new file mode 100644 index 0000000..a294941 --- /dev/null +++ b/generator/configs/phase2/_base_phase2.json @@ -0,0 +1,10 @@ +{ + 
"epochs": 100, + "data_dir": "cropped/generator", + "sources": ["wiki"], + "augment": true, + "image_size": 64, + "sample_interval": 10, + "fid_interval": 25, + "fid_n_real": 5000 +} diff --git a/generator/configs/phase2/p2_1_dcgan.json b/generator/configs/phase2/p2_1_dcgan.json new file mode 100644 index 0000000..31ed19b --- /dev/null +++ b/generator/configs/phase2/p2_1_dcgan.json @@ -0,0 +1,12 @@ +{ + "extends": "_base_phase2.json", + "run_name": "p2_1_dcgan", + "model": "dcgan", + "latent_dim": 100, + "ngf": 64, + "ndf": 64, + "lr_g": 2e-4, + "lr_d": 2e-4, + "beta1": 0.5, + "beta2": 0.999 +} diff --git a/generator/configs/phase2/p2_2_wgan.json b/generator/configs/phase2/p2_2_wgan.json new file mode 100644 index 0000000..50322e3 --- /dev/null +++ b/generator/configs/phase2/p2_2_wgan.json @@ -0,0 +1,14 @@ +{ + "extends": "_base_phase2.json", + "run_name": "p2_2_wgan", + "model": "wgan_basic", + "latent_dim": 128, + "ngf": 64, + "ndf": 64, + "lr_g": 1e-4, + "lr_d": 1e-4, + "beta1": 0.0, + "beta2": 0.9, + "n_critic": 2, + "gp_lambda": 10 +} diff --git a/generator/configs/phase2/p2_3_wgan_sn_attn.json b/generator/configs/phase2/p2_3_wgan_sn_attn.json new file mode 100644 index 0000000..3f43d18 --- /dev/null +++ b/generator/configs/phase2/p2_3_wgan_sn_attn.json @@ -0,0 +1,15 @@ +{ + "extends": "_base_phase2.json", + "run_name": "p2_3_wgan_sn_attn", + "model": "wgan", + "image_size": 64, + "latent_dim": 128, + "ngf": 128, + "ndf": 128, + "lr_g": 1e-4, + "lr_d": 1e-4, + "beta1": 0.0, + "beta2": 0.9, + "n_critic": 2, + "gp_lambda": 10 +} diff --git a/generator/configs/phase2/p2_4_wgan_sn_attn_128.json b/generator/configs/phase2/p2_4_wgan_sn_attn_128.json new file mode 100644 index 0000000..f012313 --- /dev/null +++ b/generator/configs/phase2/p2_4_wgan_sn_attn_128.json @@ -0,0 +1,15 @@ +{ + "extends": "_base_phase2.json", + "run_name": "p2_4_wgan_sn_attn_128", + "model": "wgan", + "image_size": 128, + "latent_dim": 128, + "ngf": 128, + "ndf": 128, + "lr_g": 1e-4, + 
"lr_d": 1e-4, + "beta1": 0.0, + "beta2": 0.9, + "n_critic": 2, + "gp_lambda": 10 +} diff --git a/generator/configs/phase3/_base_phase3.json b/generator/configs/phase3/_base_phase3.json new file mode 100644 index 0000000..3c3e438 --- /dev/null +++ b/generator/configs/phase3/_base_phase3.json @@ -0,0 +1,13 @@ +{ + "epochs": 100, + "data_dir": "cropped/generator", + "sources": ["wiki"], + "augment": "hflip", + "image_size": 64, + "model": "vae", + "latent_dim": 256, + "ngf": 64, + "sample_interval": 10, + "fid_interval": 25, + "fid_n_real": 5000 +} diff --git a/generator/configs/phase3/p3_1_vae.json b/generator/configs/phase3/p3_1_vae.json new file mode 100644 index 0000000..75b6af9 --- /dev/null +++ b/generator/configs/phase3/p3_1_vae.json @@ -0,0 +1,8 @@ +{ + "extends": "_base_phase3.json", + "run_name": "p3_1_vae", + "lr": 1e-3, + "beta_kl": 1.0, + "lambda_perceptual": 0.0, + "lambda_adversarial": 0.0 +} diff --git a/generator/configs/phase3/p3_2_vae_perceptual.json b/generator/configs/phase3/p3_2_vae_perceptual.json new file mode 100644 index 0000000..8116ea2 --- /dev/null +++ b/generator/configs/phase3/p3_2_vae_perceptual.json @@ -0,0 +1,8 @@ +{ + "extends": "_base_phase3.json", + "run_name": "p3_2_vae_perceptual", + "lr": 1e-3, + "beta_kl": 0.0001, + "lambda_perceptual": 0.1, + "lambda_adversarial": 0.0 +} diff --git a/generator/configs/phase3/p3_3_vae_patchgan.json b/generator/configs/phase3/p3_3_vae_patchgan.json new file mode 100644 index 0000000..1c6e443 --- /dev/null +++ b/generator/configs/phase3/p3_3_vae_patchgan.json @@ -0,0 +1,10 @@ +{ + "extends": "_base_phase3.json", + "run_name": "p3_3_vae_patchgan", + "lr": 1e-3, + "lr_d": 1e-4, + "beta_kl": 0.0001, + "lambda_perceptual": 0.1, + "lambda_adversarial": 0.1, + "ndf_patch": 64 +} diff --git a/generator/configs/phase4/_base_phase4.json b/generator/configs/phase4/_base_phase4.json new file mode 100644 index 0000000..7b11ceb --- /dev/null +++ b/generator/configs/phase4/_base_phase4.json @@ -0,0 +1,19 @@ +{ 
+ "epochs": 100, + "data_dir": "cropped/generator", + "sources": ["wiki"], + "augment": "hflip", + "image_size": 64, + "model": "ddpm", + "T": 1000, + "ddim_steps": 100, + "lr": 2e-4, + "base_ch": 128, + "ch_mult": [1, 2, 2, 2], + "attn_resolutions": [16, 8], + "num_res_blocks": 2, + "dropout": 0.1, + "sample_interval": 10, + "fid_interval": 25, + "fid_n_real": 5000 +} diff --git a/generator/configs/phase4/p4_1_ddpm_linear.json b/generator/configs/phase4/p4_1_ddpm_linear.json new file mode 100644 index 0000000..1b87412 --- /dev/null +++ b/generator/configs/phase4/p4_1_ddpm_linear.json @@ -0,0 +1,6 @@ +{ + "extends": "_base_phase4.json", + "run_name": "p4_1_ddpm_linear", + "noise_schedule": "linear", + "pred_type": "eps" +} diff --git a/generator/configs/phase4/p4_2_ddpm_cosine.json b/generator/configs/phase4/p4_2_ddpm_cosine.json new file mode 100644 index 0000000..d82a794 --- /dev/null +++ b/generator/configs/phase4/p4_2_ddpm_cosine.json @@ -0,0 +1,6 @@ +{ + "extends": "_base_phase4.json", + "run_name": "p4_2_ddpm_cosine", + "noise_schedule": "cosine", + "pred_type": "eps" +} diff --git a/generator/configs/phase4/p4_3_ddpm_vpred.json b/generator/configs/phase4/p4_3_ddpm_vpred.json new file mode 100644 index 0000000..81b97e2 --- /dev/null +++ b/generator/configs/phase4/p4_3_ddpm_vpred.json @@ -0,0 +1,6 @@ +{ + "extends": "_base_phase4.json", + "run_name": "p4_3_ddpm_vpred", + "noise_schedule": "cosine", + "pred_type": "v" +} diff --git a/generator/configs/phase4/p4_4_ddpm_wider.json b/generator/configs/phase4/p4_4_ddpm_wider.json new file mode 100644 index 0000000..1b97106 --- /dev/null +++ b/generator/configs/phase4/p4_4_ddpm_wider.json @@ -0,0 +1,8 @@ +{ + "extends": "_base_phase4.json", + "run_name": "p4_4_ddpm_wider", + "noise_schedule": "cosine", + "pred_type": "v", + "base_ch": 192, + "attn_resolutions": [32, 16, 8] +} diff --git a/generator/configs/phase5/p5_ddpm.json b/generator/configs/phase5/p5_ddpm.json new file mode 100644 index 0000000..600dbe9 --- 
/dev/null +++ b/generator/configs/phase5/p5_ddpm.json @@ -0,0 +1,22 @@ +{ + "run_name": "p5_ddpm", + "model": "ddpm", + "epochs": 200, + "data_dir": "cropped/generator", + "sources": ["wiki"], + "augment": "hflip", + "image_size": 64, + "T": 1000, + "noise_schedule": "cosine", + "pred_type": "v", + "base_ch": 192, + "ch_mult": [1, 2, 2, 2], + "attn_resolutions": [32, 16, 8], + "num_res_blocks": 2, + "dropout": 0.1, + "lr": 2e-4, + "ddim_steps": 100, + "sample_interval": 10, + "fid_interval": 25, + "fid_n_real": 5000 +} diff --git a/generator/configs/phase5/p5_gan.json b/generator/configs/phase5/p5_gan.json new file mode 100644 index 0000000..68b35ea --- /dev/null +++ b/generator/configs/phase5/p5_gan.json @@ -0,0 +1,21 @@ +{ + "run_name": "p5_gan", + "model": "wgan", + "epochs": 200, + "data_dir": "cropped/generator", + "sources": ["wiki"], + "augment": true, + "image_size": 128, + "latent_dim": 128, + "ngf": 128, + "ndf": 128, + "lr_g": 1e-4, + "lr_d": 1e-4, + "beta1": 0.0, + "beta2": 0.9, + "n_critic": 2, + "gp_lambda": 10, + "sample_interval": 10, + "fid_interval": 25, + "fid_n_real": 5000 +} diff --git a/generator/configs/phase5/p5_vae.json b/generator/configs/phase5/p5_vae.json new file mode 100644 index 0000000..4eaf91d --- /dev/null +++ b/generator/configs/phase5/p5_vae.json @@ -0,0 +1,20 @@ +{ + "run_name": "p5_vae", + "model": "vae", + "epochs": 200, + "data_dir": "cropped/generator", + "sources": ["wiki"], + "augment": "hflip", + "image_size": 64, + "latent_dim": 256, + "ngf": 64, + "lr": 1e-3, + "lr_d": 1e-4, + "beta_kl": 0.0001, + "lambda_perceptual": 0.1, + "lambda_adversarial": 0.1, + "ndf_patch": 64, + "sample_interval": 10, + "fid_interval": 25, + "fid_n_real": 5000 +} diff --git a/generator/notebooks/phase2_analysis.ipynb b/generator/notebooks/phase2_analysis.ipynb new file mode 100644 index 0000000..a52425d --- /dev/null +++ b/generator/notebooks/phase2_analysis.ipynb @@ -0,0 +1,366 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": 
"a0000001", + "metadata": {}, + "source": [ + "# Phase 2 — GAN Evolution Analysis\n", + "\n", + "Traces the GAN improvement story, each step motivated by the failure of the previous:\n", + "\n", + "| Step | Model | Key change | Expected failure |\n", + "|------|-------|------------|------------------|\n", + "| 2.1 | DCGAN 64×64 | Baseline on best pipeline | Mode collapse, training instability |\n", + "| 2.2 | WGAN-GP | Wasserstein loss + GP | Texture artifacts, limited coherence |\n", + "| 2.3 | WGAN-GP + SN + GroupNorm + Attn | Principled Lipschitz + long-range deps | Possible underfitting at 64×64 |\n", + "| 2.4 | 2.3 @ 128×128 | Scale resolution | ? |\n", + "\n", + "All runs use the best pipeline from Phase 1: MTCNN-aligned crops, H-flip + rotation + colour jitter, aligned-only dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0000002", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "from pathlib import Path\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.image as mpimg\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n", + "\n", + "OUTPUTS = Path(\"../outputs\")\n", + "LOGS = OUTPUTS / \"logs\"\n", + "SAMPLES = OUTPUTS / \"samples\"" + ] + }, + { + "cell_type": "markdown", + "id": "a0000003", + "metadata": {}, + "source": [ + "## 1. 
Load experiment logs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0000004", + "metadata": {}, + "outputs": [], + "source": [ + "run_names = [\"p2_1_dcgan\", \"p2_2_wgan\", \"p2_3_wgan_sn_attn\", \"p2_4_wgan_sn_attn_128\"]\n", + "run_labels = {\n", + " \"p2_1_dcgan\": \"2.1 DCGAN\",\n", + " \"p2_2_wgan\": \"2.2 WGAN-GP\",\n", + " \"p2_3_wgan_sn_attn\": \"2.3 WGAN-GP+SN+Attn\",\n", + " \"p2_4_wgan_sn_attn_128\": \"2.4 +128×128\",\n", + "}\n", + "\n", + "runs = {}\n", + "for name in run_names:\n", + " log_path = LOGS / f\"{name}.json\"\n", + " if log_path.exists():\n", + " with open(log_path) as f:\n", + " runs[name] = json.load(f)\n", + " else:\n", + " print(f\" Missing: {log_path}\")\n", + "\n", + "print(f\"Loaded {len(runs)}/{len(run_names)} experiments:\")\n", + "for name in run_names:\n", + " status = \"✓\" if name in runs else \"✗\"\n", + " print(f\" {status} {name}\")" + ] + }, + { + "cell_type": "markdown", + "id": "a0000005", + "metadata": {}, + "source": [ + "## 2. 
FID Comparison Table" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0000006", + "metadata": {}, + "outputs": [], + "source": [ + "def get_fid(run, epoch):\n", + " fid = run[\"history\"][\"fid\"]\n", + " return fid.get(str(epoch), fid.get(epoch, None))\n", + "\n", + "rows = []\n", + "for name in run_names:\n", + " if name not in runs:\n", + " continue\n", + " r = runs[name]\n", + " cfg = r[\"config\"]\n", + " h = r[\"history\"]\n", + " rows.append({\n", + " \"Step\": run_labels[name],\n", + " \"Model\": cfg.get(\"model\"),\n", + " \"Size\": f\"{cfg.get('image_size', 64)}×{cfg.get('image_size', 64)}\",\n", + " \"FID@25\": get_fid(r, 25),\n", + " \"FID@50\": get_fid(r, 50),\n", + " \"FID@75\": get_fid(r, 75),\n", + " \"FID@100\": get_fid(r, 100),\n", + " })\n", + "\n", + "df = pd.DataFrame(rows)\n", + "df.style.format({c: \"{:.1f}\" for c in [\"FID@25\", \"FID@50\", \"FID@75\", \"FID@100\"] if c in df})" + ] + }, + { + "cell_type": "markdown", + "id": "a0000007", + "metadata": {}, + "source": [ + "## 3. 
FID Curves — Evolution Story" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0000008", + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(11, 5))\n", + "colors = [\"#5B8DB8\", \"#E8705A\", \"#6ABF69\", \"#B86FB8\"]\n", + "\n", + "for i, name in enumerate(run_names):\n", + " if name not in runs:\n", + " continue\n", + " fid_dict = runs[name][\"history\"][\"fid\"]\n", + " epochs = sorted(int(k) for k in fid_dict)\n", + " fids = [fid_dict[str(e)] for e in epochs]\n", + " ax.plot(epochs, fids, \"o-\", label=f\"{run_labels[name]} (FID@100={fid_dict.get('100', float('nan')):.1f})\",\n", + " color=colors[i], linewidth=2, markersize=8)\n", + "\n", + "ax.set_xlabel(\"Epoch\")\n", + "ax.set_ylabel(\"FID (lower is better)\")\n", + "ax.set_title(\"Phase 2 — FID Curves: DCGAN → WGAN-GP → +SN+Attn → 128×128\")\n", + "ax.legend()\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "a0000009", + "metadata": {}, + "source": [ + "## 4. 
Training Dynamics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0000010", + "metadata": {}, + "outputs": [], + "source": [ + "# Separate DCGAN (has g_loss/d_loss/d_real/d_fake) from WGAN (has g_loss/w_dist/gp)\n", + "dcgan_names = [n for n in run_names if n in runs and runs[n][\"config\"].get(\"model\") == \"dcgan\"]\n", + "wgan_names = [n for n in run_names if n in runs and runs[n][\"config\"].get(\"model\") != \"dcgan\"]\n", + "\n", + "if dcgan_names:\n", + " fig, axes = plt.subplots(1, 2, figsize=(13, 4))\n", + " for name in dcgan_names:\n", + " h = runs[name][\"history\"]\n", + " epochs = range(1, len(h[\"g_loss\"]) + 1)\n", + " axes[0].plot(epochs, h[\"g_loss\"], label=run_labels[name], linewidth=1.2)\n", + " axes[1].plot(epochs, h[\"d_loss\"], label=run_labels[name], linewidth=1.2)\n", + " axes[0].set_title(\"DCGAN — Generator Loss (BCE)\")\n", + " axes[1].set_title(\"DCGAN — Discriminator Loss\")\n", + " for ax in axes:\n", + " ax.set_xlabel(\"Epoch\"); ax.set_ylabel(\"Loss\"); ax.legend(fontsize=8)\n", + " plt.suptitle(\"Phase 2.1 — DCGAN Training Dynamics\", fontweight=\"bold\")\n", + " plt.tight_layout()\n", + " plt.show()\n", + "\n", + "if wgan_names:\n", + " fig, axes = plt.subplots(1, 3, figsize=(16, 4))\n", + " cmap = plt.cm.Set1\n", + " for i, name in enumerate(wgan_names):\n", + " h = runs[name][\"history\"]\n", + " epochs = range(1, len(h[\"g_loss\"]) + 1)\n", + " c = cmap(i / max(len(wgan_names), 1))\n", + " axes[0].plot(epochs, h[\"g_loss\"], label=run_labels[name], color=c, linewidth=1.2)\n", + " axes[1].plot(epochs, h[\"w_dist\"], label=run_labels[name], color=c, linewidth=1.2)\n", + " axes[2].plot(epochs, h[\"gp\"], label=run_labels[name], color=c, linewidth=1.2)\n", + " axes[0].set_title(\"Generator Loss (−E[D(G(z))])\")\n", + " axes[1].set_title(\"Wasserstein Distance Est. 
(↑ better)\")\n", + " axes[2].set_title(\"Gradient Penalty\")\n", + " for ax in axes:\n", + " ax.set_xlabel(\"Epoch\"); ax.legend(fontsize=8)\n", + " plt.suptitle(\"Phase 2.2–2.4 — WGAN-GP Training Dynamics\", fontweight=\"bold\")\n", + " plt.tight_layout()\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "a0000011", + "metadata": {}, + "source": [ + "## 5. Sample Image Grids — Epoch 100" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0000012", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 4, figsize=(18, 5))\n", + "\n", + "for idx, name in enumerate(run_names):\n", + " ax = axes[idx]\n", + " img_path = SAMPLES / name / \"epoch_0100.png\"\n", + " if img_path.exists():\n", + " fid = get_fid(runs[name], 100) if name in runs else None\n", + " ax.imshow(mpimg.imread(str(img_path)))\n", + " ax.set_title(f\"{run_labels[name]}\\nFID@100={fid:.1f}\" if fid else run_labels[name], fontsize=8)\n", + " else:\n", + " ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n", + " ax.set_title(run_labels[name], fontsize=8)\n", + " ax.axis(\"off\")\n", + "\n", + "fig.suptitle(\"Phase 2 — Epoch 100 Sample Grids (4×4)\", fontsize=13, fontweight=\"bold\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "a0000013", + "metadata": {}, + "source": [ + "## 6. 
Progression: Epoch 10 → 50 → 100" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0000014", + "metadata": {}, + "outputs": [], + "source": [ + "check_epochs = [10, 50, 100]\n", + "\n", + "for name in run_names:\n", + " if name not in runs:\n", + " continue\n", + " fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))\n", + " for ax, ep in zip(axes, check_epochs):\n", + " img_path = SAMPLES / name / f\"epoch_{ep:04d}.png\"\n", + " if img_path.exists():\n", + " ax.imshow(mpimg.imread(str(img_path)))\n", + " fid = get_fid(runs[name], ep)\n", + " ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid else \"\"), fontsize=9)\n", + " else:\n", + " ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n", + " ax.axis(\"off\")\n", + " fig.suptitle(f\"{run_labels[name]} — Training Progression\", fontsize=11, fontweight=\"bold\")\n", + " plt.tight_layout()\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "a0000015", + "metadata": {}, + "source": [ + "## 7. 
Step-by-step Pairwise Comparisons" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0000016", + "metadata": {}, + "outputs": [], + "source": [ + "transitions = [\n", + " (\"2.1→2.2: BCE→Wasserstein\", \"p2_1_dcgan\", \"p2_2_wgan\"),\n", + " (\"2.2→2.3: +SN+GroupNorm+Attn\", \"p2_2_wgan\", \"p2_3_wgan_sn_attn\"),\n", + " (\"2.3→2.4: 64→128 resolution\", \"p2_3_wgan_sn_attn\", \"p2_4_wgan_sn_attn_128\"),\n", + "]\n", + "\n", + "for title, name_a, name_b in transitions:\n", + " fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n", + " for ax, name in zip(axes, [name_a, name_b]):\n", + " img_path = SAMPLES / name / \"epoch_0100.png\"\n", + " if img_path.exists():\n", + " fid = get_fid(runs[name], 100) if name in runs else None\n", + " ax.imshow(mpimg.imread(str(img_path)))\n", + " ax.set_title(f\"{run_labels[name]}\\nFID@100 = {fid:.1f}\" if fid else run_labels[name], fontsize=10)\n", + " else:\n", + " ax.text(0.5, 0.5, \"Pending\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n", + " ax.set_title(run_labels[name], fontsize=10)\n", + " ax.axis(\"off\")\n", + " fig.suptitle(title, fontsize=12, fontweight=\"bold\")\n", + " plt.tight_layout()\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "a0000017", + "metadata": {}, + "source": [ + "## 8. 
Conclusions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0000018", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"=\" * 70)\n", + "print(\"PHASE 2 — GAN EVOLUTION SUMMARY\")\n", + "print(\"=\" * 70)\n", + "\n", + "for name in run_names:\n", + " if name not in runs:\n", + " print(f\"\\n {run_labels[name]}: NOT YET RUN\")\n", + " continue\n", + " fid100 = get_fid(runs[name], 100)\n", + " fid50 = get_fid(runs[name], 50)\n", + " print(f\"\\n {run_labels[name]}:\")\n", + " print(f\" FID@50 = {fid50:.1f}\" if fid50 else \" FID@50 = ?\")\n", + " print(f\" FID@100 = {fid100:.1f}\" if fid100 else \" FID@100 = ?\")\n", + "\n", + "print(\"\\n\" + \"=\" * 70)\n", + "print(\"Best model for Phase 3/4 comparison: fill in after runs complete.\")\n", + "print(\"=\" * 70)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/generator/notebooks/phase3_analysis.ipynb b/generator/notebooks/phase3_analysis.ipynb new file mode 100644 index 0000000..331bed2 --- /dev/null +++ b/generator/notebooks/phase3_analysis.ipynb @@ -0,0 +1,396 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b0000001", + "metadata": {}, + "source": [ + "# Phase 3 — VAE Evolution Analysis\n", + "\n", + "Traces the VAE improvement story — each step motivated by the failure of the previous:\n", + "\n", + "| Step | Model | Key change | Expected failure |\n", + "|------|-------|------------|------------------|\n", + "| 3.1 | Vanilla VAE (MSE+KL) | Baseline | Blurry samples — MSE minimises pixel average |\n", + "| 3.2 | + Perceptual loss (VGG) | Feature-space reconstruction | Residual texture blur |\n", + "| 3.3 | + PatchGAN (VQGAN-lite) | Local texture adversarial | — |\n", + "\n", + "All runs use H-flip-only augmentation and MTCNN-aligned 64×64 crops.\n", + "FID is 
computed from prior samples (`z ~ N(0, I) → decode`), same metric as GAN.\n", + "Reconstructions are shown separately to diagnose encoder quality." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0000002", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "from pathlib import Path\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.image as mpimg\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n", + "\n", + "OUTPUTS = Path(\"../outputs\")\n", + "LOGS = OUTPUTS / \"logs\"\n", + "SAMPLES = OUTPUTS / \"samples\"" + ] + }, + { + "cell_type": "markdown", + "id": "b0000003", + "metadata": {}, + "source": [ + "## 1. Load experiment logs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0000004", + "metadata": {}, + "outputs": [], + "source": [ + "run_names = [\"p3_1_vae\", \"p3_2_vae_perceptual\", \"p3_3_vae_patchgan\"]\n", + "run_labels = {\n", + " \"p3_1_vae\": \"3.1 VAE (MSE+KL)\",\n", + " \"p3_2_vae_perceptual\": \"3.2 +Perceptual\",\n", + " \"p3_3_vae_patchgan\": \"3.3 +PatchGAN\",\n", + "}\n", + "\n", + "runs = {}\n", + "for name in run_names:\n", + " log_path = LOGS / f\"{name}.json\"\n", + " if log_path.exists():\n", + " with open(log_path) as f:\n", + " runs[name] = json.load(f)\n", + " else:\n", + " print(f\" Missing: {log_path}\")\n", + "\n", + "print(f\"Loaded {len(runs)}/{len(run_names)} experiments:\")\n", + "for name in run_names:\n", + " status = \"✓\" if name in runs else \"✗\"\n", + " print(f\" {status} {name}\")" + ] + }, + { + "cell_type": "markdown", + "id": "b0000005", + "metadata": {}, + "source": [ + "## 2. 
FID Comparison Table (prior samples)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0000006", + "metadata": {}, + "outputs": [], + "source": [ + "def get_fid(run, epoch):\n", + " fid = run[\"history\"][\"fid\"]\n", + " return fid.get(str(epoch), fid.get(epoch, None))\n", + "\n", + "rows = []\n", + "for name in run_names:\n", + " if name not in runs:\n", + " continue\n", + " r = runs[name]\n", + " cfg = r[\"config\"]\n", + " h = r[\"history\"]\n", + " rows.append({\n", + " \"Step\": run_labels[name],\n", + " \"λ_perc\": cfg.get(\"lambda_perceptual\", 0),\n", + " \"λ_adv\": cfg.get(\"lambda_adversarial\", 0),\n", + " \"β_kl\": cfg.get(\"beta_kl\", 1.0),\n", + " \"FID@25 (prior)\": get_fid(r, 25),\n", + " \"FID@50 (prior)\": get_fid(r, 50),\n", + " \"FID@75 (prior)\": get_fid(r, 75),\n", + " \"FID@100 (prior)\": get_fid(r, 100),\n", + " })\n", + "\n", + "df = pd.DataFrame(rows)\n", + "df.style.format({c: \"{:.1f}\" for c in df.columns if \"FID\" in c})" + ] + }, + { + "cell_type": "markdown", + "id": "b0000007", + "metadata": {}, + "source": [ + "## 3. 
FID Curves — Evolution Story" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0000008", + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(10, 5))\n", + "colors = [\"#5B8DB8\", \"#E8705A\", \"#6ABF69\"]\n", + "\n", + "for i, name in enumerate(run_names):\n", + " if name not in runs:\n", + " continue\n", + " fid_dict = runs[name][\"history\"][\"fid\"]\n", + " epochs = sorted(int(k) for k in fid_dict)\n", + " fids = [fid_dict[str(e)] for e in epochs]\n", + " label = f\"{run_labels[name]} (FID@100={fid_dict.get('100', float('nan')):.1f})\"\n", + " ax.plot(epochs, fids, \"o-\", label=label, color=colors[i], linewidth=2, markersize=8)\n", + "\n", + "ax.set_xlabel(\"Epoch\")\n", + "ax.set_ylabel(\"FID (lower is better) — prior samples\")\n", + "ax.set_title(\"Phase 3 — FID Curves: VAE → +Perceptual → +PatchGAN\")\n", + "ax.legend()\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "b0000009", + "metadata": {}, + "source": [ + "## 4. 
Training Loss Curves" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0000010", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(2, 3, figsize=(16, 9))\n", + "axes = axes.flatten()\n", + "keys = [\"recon_loss\", \"kl_loss\", \"perc_loss\", \"adv_g_loss\", \"adv_d_loss\"]\n", + "titles = [\"MSE Reconstruction\", \"KL Divergence\", \"Perceptual (VGG)\", \"Adv G Loss\", \"Adv D Loss (PatchGAN)\"]\n", + "\n", + "for ax, key, title in zip(axes, keys, titles):\n", + " for i, name in enumerate(run_names):\n", + " if name not in runs:\n", + " continue\n", + " h = runs[name][\"history\"].get(key, [])\n", + " if any(v != 0.0 for v in h):\n", + " ax.plot(range(1, len(h)+1), h, label=run_labels[name],\n", + " color=colors[i], linewidth=1.2, alpha=0.9)\n", + " ax.set_title(title)\n", + " ax.set_xlabel(\"Epoch\")\n", + " ax.legend(fontsize=8)\n", + "\n", + "axes[-1].axis(\"off\") # empty sixth panel\n", + "fig.suptitle(\"Phase 3 — Training Dynamics\", fontsize=13, fontweight=\"bold\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "b0000011", + "metadata": {}, + "source": [ + "## 5. 
Prior Samples — Epoch 100" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0000012", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n", + "\n", + "for idx, name in enumerate(run_names):\n", + " ax = axes[idx]\n", + " img_path = SAMPLES / name / \"epoch_0100.png\"\n", + " if img_path.exists():\n", + " fid = get_fid(runs[name], 100) if name in runs else None\n", + " ax.imshow(mpimg.imread(str(img_path)))\n", + " ax.set_title(f\"{run_labels[name]}\\nFID@100={fid:.1f}\" if fid else run_labels[name], fontsize=9)\n", + " else:\n", + " ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n", + " ax.set_title(run_labels[name], fontsize=9)\n", + " ax.axis(\"off\")\n", + "\n", + "fig.suptitle(\"Phase 3 — Epoch 100 Prior Samples (4×4 grids)\", fontsize=12, fontweight=\"bold\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "b0000013", + "metadata": {}, + "source": [ + "## 6. Reconstructions — Epoch 100\n", + "\n", + "Left half = real images, right half = reconstructions (interleaved pairs)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0000014", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n", + "\n", + "for idx, name in enumerate(run_names):\n", + " ax = axes[idx]\n", + " img_path = SAMPLES / name / \"epoch_0100_recon.png\"\n", + " if img_path.exists():\n", + " ax.imshow(mpimg.imread(str(img_path)))\n", + " ax.set_title(f\"{run_labels[name]}\\nreal | recon\", fontsize=9)\n", + " else:\n", + " ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n", + " ax.set_title(run_labels[name], fontsize=9)\n", + " ax.axis(\"off\")\n", + "\n", + "fig.suptitle(\"Phase 3 — Epoch 100 Reconstructions\", fontsize=12, fontweight=\"bold\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "b0000015", + "metadata": {}, + "source": [ + "## 7. Step-by-step Pairwise Comparisons" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0000016", + "metadata": {}, + "outputs": [], + "source": [ + "transitions = [\n", + " (\"3.1→3.2: MSE→+Perceptual\", \"p3_1_vae\", \"p3_2_vae_perceptual\"),\n", + " (\"3.2→3.3: +PatchGAN adversarial\", \"p3_2_vae_perceptual\", \"p3_3_vae_patchgan\"),\n", + "]\n", + "\n", + "for title, name_a, name_b in transitions:\n", + " fig, axes = plt.subplots(2, 2, figsize=(12, 10))\n", + " for col, name in enumerate([name_a, name_b]):\n", + " for row, suffix in enumerate([\"\", \"_recon\"]):\n", + " ax = axes[row][col]\n", + " ep = 100\n", + " img_path = SAMPLES / name / f\"epoch_{ep:04d}{suffix}.png\"\n", + " label = run_labels[name]\n", + " kind = \"prior\" if suffix == \"\" else \"recon\"\n", + " if img_path.exists():\n", + " fid = get_fid(runs[name], ep) if (suffix == \"\" and name in runs) else None\n", + " ax.imshow(mpimg.imread(str(img_path)))\n", + " ax.set_title(f\"{label}\\n{kind}\" + (f\" FID={fid:.1f}\" if fid else \"\"), fontsize=9)\n", + " else:\n", + " ax.text(0.5, 
0.5, \"Pending\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n", + " ax.set_title(f\"{label} ({kind})\", fontsize=9)\n", + " ax.axis(\"off\")\n", + " fig.suptitle(title, fontsize=12, fontweight=\"bold\")\n", + " plt.tight_layout()\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "b0000017", + "metadata": {}, + "source": [ + "## 8. Progression: Epoch 10 → 50 → 100 (prior samples)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0000018", + "metadata": {}, + "outputs": [], + "source": [ + "check_epochs = [10, 50, 100]\n", + "\n", + "for name in run_names:\n", + " if name not in runs:\n", + " continue\n", + " fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))\n", + " for ax, ep in zip(axes, check_epochs):\n", + " img_path = SAMPLES / name / f\"epoch_{ep:04d}.png\"\n", + " if img_path.exists():\n", + " ax.imshow(mpimg.imread(str(img_path)))\n", + " fid = get_fid(runs[name], ep)\n", + " ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid else \"\"), fontsize=9)\n", + " else:\n", + " ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\",\n", + " transform=ax.transAxes)\n", + " ax.axis(\"off\")\n", + " fig.suptitle(f\"{run_labels[name]} — Prior Sample Progression\", fontsize=11, fontweight=\"bold\")\n", + " plt.tight_layout()\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "b0000019", + "metadata": {}, + "source": [ + "## 9. 
Conclusions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0000020", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"=\" * 70)\n", + "print(\"PHASE 3 — VAE EVOLUTION SUMMARY\")\n", + "print(\"=\" * 70)\n", + "\n", + "for name in run_names:\n", + " if name not in runs:\n", + " print(f\"\\n {run_labels[name]}: NOT YET RUN\")\n", + " continue\n", + " h = runs[name][\"history\"]\n", + " fid100 = get_fid(runs[name], 100)\n", + " fid50 = get_fid(runs[name], 50)\n", + " mse50 = h[\"recon_loss\"][49] if len(h[\"recon_loss\"]) > 49 else None\n", + " kl50 = h[\"kl_loss\"][49] if len(h[\"kl_loss\"]) > 49 else None\n", + " print(f\"\\n {run_labels[name]}:\")\n", + " print(f\" FID@50 = {fid50:.1f}\" if fid50 else \" FID@50 = ?\")\n", + " print(f\" FID@100 = {fid100:.1f}\" if fid100 else \" FID@100 = ?\")\n", + " print(f\" MSE@50 = {mse50:.4f}\" if mse50 else \" MSE@50 = ?\")\n", + " print(f\" KL@50 = {kl50:.2f}\" if kl50 else \" KL@50 = ?\")\n", + "\n", + "print(\"\\n\" + \"=\" * 70)\n", + "print(\"Best VAE model for Phase 5 cross-family comparison: fill in after runs.\")\n", + "print(\"=\" * 70)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/generator/notebooks/phase4_analysis.ipynb b/generator/notebooks/phase4_analysis.ipynb new file mode 100644 index 0000000..5d681d2 --- /dev/null +++ b/generator/notebooks/phase4_analysis.ipynb @@ -0,0 +1,403 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c0000001", + "metadata": {}, + "source": [ + "# Phase 4 — DDPM Evolution Analysis\n", + "\n", + "Traces the DDPM improvement story:\n", + "\n", + "| Step | Model | Key change | Expected failure |\n", + "|------|-------|------------|------------------|\n", + "| 4.1 | DDPM linear + ε-pred | Baseline | Noise prediction unstable at very low 
t (linear schedule over-denoises) |\n", + "| 4.2 | + cosine schedule | Less noise wasted at low timesteps | Residual instability from ε parameterisation |\n", + "| 4.3 | + v-prediction | Numerically stable across full trajectory | Possible underfitting at 64×64 |\n", + "| 4.4 | + wider U-Net (192ch) + 32×32 attention | More capacity and longer-range context | — |\n", + "\n", + "FID is computed via DDIM (100 steps, deterministic) from the EMA model.\n", + "H-flip-only augmentation, MTCNN-aligned 64×64 crops, T=1000." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0000002", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "from pathlib import Path\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.image as mpimg\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n", + "\n", + "OUTPUTS = Path(\"../outputs\")\n", + "LOGS = OUTPUTS / \"logs\"\n", + "SAMPLES = OUTPUTS / \"samples\"" + ] + }, + { + "cell_type": "markdown", + "id": "c0000003", + "metadata": {}, + "source": [ + "## 1. 
Load experiment logs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0000004", + "metadata": {}, + "outputs": [], + "source": [ + "run_names = [\"p4_1_ddpm_linear\", \"p4_2_ddpm_cosine\", \"p4_3_ddpm_vpred\", \"p4_4_ddpm_wider\"]\n", + "run_labels = {\n", + " \"p4_1_ddpm_linear\": \"4.1 linear + ε\",\n", + " \"p4_2_ddpm_cosine\": \"4.2 cosine + ε\",\n", + " \"p4_3_ddpm_vpred\": \"4.3 cosine + v\",\n", + " \"p4_4_ddpm_wider\": \"4.4 wider + 32×32 attn\",\n", + "}\n", + "\n", + "runs = {}\n", + "for name in run_names:\n", + " log_path = LOGS / f\"{name}.json\"\n", + " if log_path.exists():\n", + " with open(log_path) as f:\n", + " runs[name] = json.load(f)\n", + " else:\n", + " print(f\" Missing: {log_path}\")\n", + "\n", + "print(f\"Loaded {len(runs)}/{len(run_names)} experiments:\")\n", + "for name in run_names:\n", + " print(f\" {'✓' if name in runs else '✗'} {name}\")" + ] + }, + { + "cell_type": "markdown", + "id": "c0000005", + "metadata": {}, + "source": [ + "## 2. 
FID Comparison Table" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0000006", + "metadata": {}, + "outputs": [], + "source": [ + "def get_fid(run, epoch):\n", + " fid = run[\"history\"][\"fid\"]\n", + " return fid.get(str(epoch), fid.get(epoch, None))\n", + "\n", + "rows = []\n", + "for name in run_names:\n", + " if name not in runs:\n", + " continue\n", + " r = runs[name]\n", + " cfg = r[\"config\"]\n", + " rows.append({\n", + " \"Step\": run_labels[name],\n", + " \"Schedule\": cfg.get(\"noise_schedule\"),\n", + " \"Pred\": cfg.get(\"pred_type\"),\n", + " \"base_ch\": cfg.get(\"base_ch\"),\n", + " \"FID@25\": get_fid(r, 25),\n", + " \"FID@50\": get_fid(r, 50),\n", + " \"FID@75\": get_fid(r, 75),\n", + " \"FID@100\": get_fid(r, 100),\n", + " })\n", + "\n", + "df = pd.DataFrame(rows)\n", + "df.style.format({c: \"{:.1f}\" for c in df.columns if \"FID\" in c})" + ] + }, + { + "cell_type": "markdown", + "id": "c0000007", + "metadata": {}, + "source": [ + "## 3. FID Curves — Evolution Story" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0000008", + "metadata": {}, + "outputs": [], + "source": [ + "colors = [\"#5B8DB8\", \"#E8705A\", \"#6ABF69\", \"#B86FB8\"]\n", + "\n", + "fig, ax = plt.subplots(figsize=(11, 5))\n", + "for i, name in enumerate(run_names):\n", + " if name not in runs:\n", + " continue\n", + " fid_dict = runs[name][\"history\"][\"fid\"]\n", + " epochs = sorted(int(k) for k in fid_dict)\n", + " fids = [fid_dict[str(e)] for e in epochs]\n", + " fid100 = fid_dict.get(\"100\", \"?\")\n", + " label = f\"{run_labels[name]} (FID@100={fid100:.1f})\" if isinstance(fid100, float) else run_labels[name]\n", + " ax.plot(epochs, fids, \"o-\", label=label, color=colors[i], linewidth=2, markersize=8)\n", + "\n", + "ax.set_xlabel(\"Epoch\")\n", + "ax.set_ylabel(\"FID (DDIM 100 steps, lower is better)\")\n", + "ax.set_title(\"Phase 4 — FID Curves: linear·ε → cosine·ε → cosine·v → wider\")\n", + "ax.legend()\n", + 
"plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "c0000009", + "metadata": {}, + "source": [ + "## 4. Training Loss Curves (MSE on predicted target)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0000010", + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(11, 4))\n", + "for i, name in enumerate(run_names):\n", + " if name not in runs:\n", + " continue\n", + " losses = runs[name][\"history\"][\"loss\"]\n", + " ax.plot(range(1, len(losses)+1), losses, label=run_labels[name],\n", + " color=colors[i], linewidth=1.2, alpha=0.9)\n", + "\n", + "ax.set_xlabel(\"Epoch\")\n", + "ax.set_ylabel(\"MSE (noise / v prediction loss)\")\n", + "ax.set_title(\"Phase 4 — Training Loss\")\n", + "ax.legend(fontsize=9)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "c0000011", + "metadata": {}, + "source": [ + "## 5. Sample Grids — Epoch 100 (DDIM 50 steps)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0000012", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 4, figsize=(20, 5))\n", + "\n", + "for idx, name in enumerate(run_names):\n", + " ax = axes[idx]\n", + " img_path = SAMPLES / name / \"epoch_0100.png\"\n", + " if img_path.exists():\n", + " fid = get_fid(runs[name], 100) if name in runs else None\n", + " ax.imshow(mpimg.imread(str(img_path)))\n", + " ax.set_title(f\"{run_labels[name]}\\nFID@100={fid:.1f}\" if fid else run_labels[name], fontsize=8)\n", + " else:\n", + " ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n", + " ax.set_title(run_labels[name], fontsize=8)\n", + " ax.axis(\"off\")\n", + "\n", + "fig.suptitle(\"Phase 4 — Epoch 100 DDIM Samples (4×4 grids)\", fontsize=12, fontweight=\"bold\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "c0000013", + "metadata": {}, + "source": [ + "## 
6. Step-by-step Pairwise Comparisons" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0000014", + "metadata": {}, + "outputs": [], + "source": [ + "transitions = [\n", + " (\"4.1→4.2: linear→cosine schedule\", \"p4_1_ddpm_linear\", \"p4_2_ddpm_cosine\"),\n", + " (\"4.2→4.3: ε-pred→v-pred\", \"p4_2_ddpm_cosine\", \"p4_3_ddpm_vpred\"),\n", + " (\"4.3→4.4: 128ch→192ch + 32×32 attn\", \"p4_3_ddpm_vpred\", \"p4_4_ddpm_wider\"),\n", + "]\n", + "\n", + "for title, name_a, name_b in transitions:\n", + " fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n", + " for ax, name in zip(axes, [name_a, name_b]):\n", + " img_path = SAMPLES / name / \"epoch_0100.png\"\n", + " if img_path.exists():\n", + " fid = get_fid(runs[name], 100) if name in runs else None\n", + " ax.imshow(mpimg.imread(str(img_path)))\n", + " ax.set_title(f\"{run_labels[name]}\\nFID@100={fid:.1f}\" if fid else run_labels[name], fontsize=9)\n", + " else:\n", + " ax.text(0.5, 0.5, \"Pending\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n", + " ax.set_title(run_labels[name], fontsize=9)\n", + " ax.axis(\"off\")\n", + " fig.suptitle(title, fontsize=12, fontweight=\"bold\")\n", + " plt.tight_layout()\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "c0000015", + "metadata": {}, + "source": [ + "## 7. 
Progression: Epoch 10 → 50 → 100" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0000016", + "metadata": {}, + "outputs": [], + "source": [ + "check_epochs = [10, 50, 100]\n", + "\n", + "for name in run_names:\n", + " if name not in runs:\n", + " continue\n", + " fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))\n", + " for ax, ep in zip(axes, check_epochs):\n", + " img_path = SAMPLES / name / f\"epoch_{ep:04d}.png\"\n", + " if img_path.exists():\n", + " ax.imshow(mpimg.imread(str(img_path)))\n", + " fid = get_fid(runs[name], ep)\n", + " ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid else \"\"), fontsize=9)\n", + " else:\n", + " ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\",\n", + " transform=ax.transAxes)\n", + " ax.axis(\"off\")\n", + " fig.suptitle(f\"{run_labels[name]} — Progression\", fontsize=11, fontweight=\"bold\")\n", + " plt.tight_layout()\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "c0000017", + "metadata": {}, + "source": [ + "## 8. Noise Schedule Visualisation\n", + "\n", + "Illustrates why cosine outperforms linear: the linear schedule allocates many timesteps near t=T where the image is already near-pure noise, wasting model capacity." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0000018", + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "import numpy as np\n", + "\n", + "T = 1000\n", + "t = np.arange(T)\n", + "\n", + "# Linear betas\n", + "betas_lin = np.linspace(1e-4, 0.02, T)\n", + "ab_lin = np.cumprod(1 - betas_lin)\n", + "\n", + "# Cosine betas — evaluate f on T+1 points so betas_cos (and ab_cos) have length T, matching t for plotting\n", + "s = 0.008\n", + "f = np.cos((np.arange(T + 1) / T + s) / (1 + s) * math.pi / 2) ** 2\n", + "f = f / f[0]\n", + "betas_cos = np.clip(1 - f[1:] / f[:-1], 0, 0.999)\n", + "ab_cos = np.cumprod(1 - betas_cos)\n", + "\n", + "fig, axes = plt.subplots(1, 2, figsize=(13, 4))\n", + "\n", + "axes[0].plot(t, ab_lin, label=\"linear\", color=\"#5B8DB8\", linewidth=2)\n", + "axes[0].plot(t, ab_cos, label=\"cosine\", color=\"#E8705A\", linewidth=2)\n", + "axes[0].set_xlabel(\"Timestep t\")\n", + "axes[0].set_ylabel(\"ᾱ_t (signal fraction)\")\n", + "axes[0].set_title(\"ᾱ_t vs t — cosine stays informative longer\")\n", + "axes[0].legend()\n", + "\n", + "axes[1].plot(t, np.sqrt(ab_lin), label=\"linear √ᾱ\", color=\"#5B8DB8\", linewidth=2)\n", + "axes[1].plot(t, np.sqrt(ab_cos), label=\"cosine √ᾱ\", color=\"#E8705A\", linewidth=2)\n", + "axes[1].set_xlabel(\"Timestep t\")\n", + "axes[1].set_ylabel(\"√ᾱ_t (signal amplitude)\")\n", + "axes[1].set_title(\"Signal amplitude — cosine is more uniform\")\n", + "axes[1].legend()\n", + "\n", + "plt.suptitle(\"Noise Schedule Comparison\", fontsize=12, fontweight=\"bold\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "c0000019", + "metadata": {}, + "source": [ + "## 9. 
Conclusions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0000020", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"=\" * 70)\n", + "print(\"PHASE 4 — DDPM EVOLUTION SUMMARY\")\n", + "print(\"=\" * 70)\n", + "\n", + "for name in run_names:\n", + " if name not in runs:\n", + " print(f\"\\n {run_labels[name]}: NOT YET RUN\")\n", + " continue\n", + " fid100 = get_fid(runs[name], 100)\n", + " fid50 = get_fid(runs[name], 50)\n", + " h = runs[name][\"history\"]\n", + " loss50 = h[\"loss\"][49] if len(h[\"loss\"]) > 49 else None\n", + " print(f\"\\n {run_labels[name]}:\")\n", + " print(f\" FID@50 = {fid50:.1f}\" if fid50 else \" FID@50 = ?\")\n", + " print(f\" FID@100 = {fid100:.1f}\" if fid100 else \" FID@100 = ?\")\n", + " print(f\" Loss@50 = {loss50:.5f}\" if loss50 else \" Loss@50 = ?\")\n", + "\n", + "print(\"\\n\" + \"=\" * 70)\n", + "print(\"Best DDPM model for Phase 5 comparison: fill in after runs complete.\")\n", + "print(\"=\" * 70)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/generator/notebooks/phase5_analysis.ipynb b/generator/notebooks/phase5_analysis.ipynb new file mode 100644 index 0000000..b526d56 --- /dev/null +++ b/generator/notebooks/phase5_analysis.ipynb @@ -0,0 +1,669 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Phase 5 — Cross-Family Comparison\n", + "\n", + "Best-of-each finalist retrained for **200 epochs** under identical data conditions.\n", + "\n", + "| Family | Config | Resolution | Key design |\n", + "|--------|--------|-----------|------------|\n", + "| GAN | `p5_gan` | 128×128 | WGAN-GP + SpectralNorm + GroupNorm + Self-Attention |\n", + "| VAE | `p5_vae` | 64×64 | Convolutional VAE + VGG perceptual + PatchGAN |\n", + "| DDPM | `p5_ddpm` | 64×64 | Wider 
U-Net (192ch) + cosine schedule + v-prediction |\n", + "\n", + "> **Resolution note**: GAN runs at 128×128 (best architecture from Phase 2.4), while VAE and DDPM run at 64×64. FID is measured at each model's native resolution against real images at that resolution, so scores are not directly numerically comparable across families — they are indicators of within-family improvement." + ], + "id": "d0000001" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "import json\n", + "import sys\n", + "from pathlib import Path\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.image as mpimg\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n", + "\n", + "OUTPUTS = Path(\"../outputs\")\n", + "LOGS = OUTPUTS / \"logs\"\n", + "SAMPLES = OUTPUTS / \"samples\"\n", + "\n", + "# Phase 5 finalists and their best phase-2/3/4 counterparts (100-epoch comparison)\n", + "FAMILIES = {\n", + " \"GAN\": {\"p5\": \"p5_gan\", \"p4ep\": \"p2_4_wgan_sn_attn_128\", \"label\": \"WGAN-GP+SN+Attn\", \"color\": \"#5B8DB8\"},\n", + " \"VAE\": {\"p5\": \"p5_vae\", \"p4ep\": \"p3_3_vae_patchgan\", \"label\": \"VAE+Perc+PatchGAN\",\"color\": \"#E8705A\"},\n", + " \"DDPM\": {\"p5\": \"p5_ddpm\", \"p4ep\": \"p4_4_ddpm_wider\", \"label\": \"DDPM wider 192ch\", \"color\": \"#6ABF69\"},\n", + "}" + ], + "execution_count": null, + "outputs": [], + "id": "d0000002" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. 
Load logs" + ], + "id": "d0000003" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "def load_log(run_name):\n", + " p = LOGS / f\"{run_name}.json\"\n", + " if p.exists():\n", + " with open(p) as f:\n", + " return json.load(f)\n", + " return None\n", + "\n", + "def get_fid(log, epoch):\n", + " if log is None:\n", + " return None\n", + " fid = log[\"history\"][\"fid\"]\n", + " return fid.get(str(epoch), fid.get(epoch, None))\n", + "\n", + "logs_p5 = {fam: load_log(info[\"p5\"]) for fam, info in FAMILIES.items()}\n", + "logs_p4 = {fam: load_log(info[\"p4ep\"]) for fam, info in FAMILIES.items()}\n", + "\n", + "for fam in FAMILIES:\n", + " p5_ok = \"✓\" if logs_p5[fam] else \"✗\"\n", + " p4_ok = \"✓\" if logs_p4[fam] else \"✗\"\n", + " print(f\" {fam}: 200ep={p5_ok} 100ep={p4_ok}\")" + ], + "execution_count": null, + "outputs": [], + "id": "d0000004" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Quantitative Summary Table" + ], + "id": "d0000005" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "rows = []\n", + "for fam, info in FAMILIES.items():\n", + " log = logs_p5[fam]\n", + " train_time = log.get(\"history\", {}).get(\"train_time_s\") if log else None\n", + " train_min = f\"{train_time / 60:.1f}\" if train_time else \"?\"\n", + " rows.append({\n", + " \"Family\": fam,\n", + " \"Model\": info[\"label\"],\n", + " \"Res\": log.get(\"config\", log).get(\"image_size\", \"?\") if log else \"?\",\n", + " \"Params\": log.get(\"n_params\") if log else None,\n", + " \"Train (min)\": train_min,\n", + " \"FID@100\": get_fid(log, 100),\n", + " \"FID@150\": get_fid(log, 150),\n", + " \"FID@200\": get_fid(log, 200),\n", + " # IS and LPIPS filled in by Section 6\n", + " \"IS ↑\": None,\n", + " \"LPIPS ↑\": None,\n", + " })\n", + "\n", + "df = pd.DataFrame(rows).set_index(\"Family\")\n", + "\n", + "\n", + "def fmt_params(v):\n", + " if v is None:\n", + " return \"?\"\n", + " if v >= 1_000_000:\n", + " return f\"{v / 
1_000_000:.1f}M\"\n", + " if v >= 1_000:\n", + " return f\"{v / 1_000:.0f}K\"\n", + " return str(v)\n", + "\n", + "\n", + "df_display = df.copy()\n", + "df_display[\"Params\"] = df_display[\"Params\"].apply(fmt_params)\n", + "df_display.style.format({c: \"{:.1f}\" for c in [\"FID@100\", \"FID@150\", \"FID@200\"] if c in df_display})" + ], + "execution_count": null, + "outputs": [], + "id": "d0000006" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. FID Curves — All Three Families" + ], + "id": "d0000007" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "fig, axes = plt.subplots(1, 2, figsize=(15, 5))\n", + "\n", + "for fam, info in FAMILIES.items():\n", + " c = info[\"color\"]\n", + " # Phase 5 (200ep) — solid\n", + " log = logs_p5[fam]\n", + " if log:\n", + " fid_dict = log[\"history\"][\"fid\"]\n", + " eps = sorted(int(k) for k in fid_dict)\n", + " fids = [fid_dict[str(e)] for e in eps]\n", + " axes[0].plot(eps, fids, \"-o\", color=c, linewidth=2, markersize=6,\n", + " label=f\"{fam} 200ep (FID@200={fid_dict.get('200','?'):.1f})\" if isinstance(fid_dict.get('200'), float) else f\"{fam} 200ep\")\n", + "\n", + " # Phase 4/3/2 best (100ep) — dashed, same colour\n", + " log4 = logs_p4[fam]\n", + " if log4:\n", + " fid4 = log4[\"history\"][\"fid\"]\n", + " eps4 = sorted(int(k) for k in fid4)\n", + " fids4 = [fid4[str(e)] for e in eps4]\n", + " axes[0].plot(eps4, fids4, \"--\", color=c, linewidth=1.2, alpha=0.55,\n", + " label=f\"{fam} 100ep\")\n", + "\n", + "axes[0].set_xlabel(\"Epoch\")\n", + "axes[0].set_ylabel(\"FID (lower is better)\")\n", + "axes[0].set_title(\"FID Curves — Phase 5 (solid) vs best 100-ep (dashed)\")\n", + "axes[0].legend(fontsize=8)\n", + "\n", + "# Bar chart: FID at 100 vs 200 epochs per family\n", + "fams = list(FAMILIES.keys())\n", + "fid100 = [get_fid(logs_p5[f], 100) for f in fams]\n", + "fid200 = [get_fid(logs_p5[f], 200) for f in fams]\n", + "x = np.arange(len(fams))\n", + "w = 0.35\n", + "bars1 
= axes[1].bar(x - w/2, [v or 0 for v in fid100], w, label=\"FID@100\", color=[FAMILIES[f][\"color\"] for f in fams], alpha=0.6)\n", + "bars2 = axes[1].bar(x + w/2, [v or 0 for v in fid200], w, label=\"FID@200\", color=[FAMILIES[f][\"color\"] for f in fams], alpha=1.0)\n", + "axes[1].set_xticks(x); axes[1].set_xticklabels(fams)\n", + "axes[1].set_ylabel(\"FID\"); axes[1].set_title(\"FID: 100ep vs 200ep per family\")\n", + "axes[1].legend()\n", + "for bar in list(bars1) + list(bars2):\n", + " h = bar.get_height()\n", + " if h > 0:\n", + " axes[1].text(bar.get_x() + bar.get_width()/2, h + 1, f\"{h:.0f}\", ha=\"center\", va=\"bottom\", fontsize=8)\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ], + "execution_count": null, + "outputs": [], + "id": "d0000008" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Sample Grids — Epoch 200" + ], + "id": "d0000009" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "fig, axes = plt.subplots(1, 3, figsize=(16, 6))\n", + "\n", + "for idx, (fam, info) in enumerate(FAMILIES.items()):\n", + " ax = axes[idx]\n", + " img_path = SAMPLES / info[\"p5\"] / \"epoch_0200.png\"\n", + " log = logs_p5[fam]\n", + " fid200 = get_fid(log, 200)\n", + " if img_path.exists():\n", + " ax.imshow(mpimg.imread(str(img_path)))\n", + " title = f\"{fam}: {info['label']}\\n\"\n", + " if fid200:\n", + " res = log[\"config\"].get(\"image_size\", \"?\")\n", + " title += f\"FID@200={fid200:.1f} ({res}×{res})\"\n", + " ax.set_title(title, fontsize=9)\n", + " else:\n", + " ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes, fontsize=12)\n", + " ax.set_title(f\"{fam}: {info['label']}\", fontsize=9)\n", + " ax.axis(\"off\")\n", + "\n", + "fig.suptitle(\"Phase 5 — Epoch 200 Sample Grids (4×4, prior samples)\", fontsize=13, fontweight=\"bold\")\n", + "plt.tight_layout()\n", + "plt.show()" + ], + "execution_count": null, + "outputs": [], + "id": "d0000010" + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "## 5. Training Progression — Epoch 10 → 50 → 100 → 200" + ], + "id": "d0000011" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "check_epochs = [10, 50, 100, 200]\n", + "\n", + "for fam, info in FAMILIES.items():\n", + " run = info[\"p5\"]\n", + " log = logs_p5[fam]\n", + " if log is None:\n", + " print(f\"{fam}: not yet run\")\n", + " continue\n", + "\n", + " fig, axes = plt.subplots(1, len(check_epochs), figsize=(16, 4))\n", + " for ax, ep in zip(axes, check_epochs):\n", + " img_path = SAMPLES / run / f\"epoch_{ep:04d}.png\"\n", + " if img_path.exists():\n", + " ax.imshow(mpimg.imread(str(img_path)))\n", + " fid = get_fid(log, ep)\n", + " ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid else \"\"), fontsize=9)\n", + " else:\n", + " ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\",\n", + " transform=ax.transAxes)\n", + " ax.axis(\"off\")\n", + " fig.suptitle(f\"{fam} ({info['label']}) — 200-epoch progression\", fontsize=11, fontweight=\"bold\")\n", + " plt.tight_layout()\n", + " plt.show()" + ], + "execution_count": null, + "outputs": [], + "id": "d0000012" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Extended Metrics — IS and LPIPS Diversity\n", + "\n", + "Requires loading trained model weights. Run after all phase 5 models have finished.\n", + "Generates 5 000 samples per model and computes IS and LPIPS over 200 random pairs." 
+ ], + "id": "d0000013" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "import sys\n", + "sys.path.insert(0, \"..\")\n", + "\n", + "import torch\n", + "from src.utils import load_config\n", + "from src.models import get_model\n", + "from src.training.metrics import compute_is, compute_lpips_diversity\n", + "from src.training.diffusion import cosine_betas, make_alpha_bars, ddim_sample\n", + "from src.training.ema import EMA\n", + "\n", + "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "N_SAMPLE = 5_000\n", + "\n", + "def load_ema_model(run_name, config_path):\n", + " \"\"\"Load the best EMA weights for a given phase-5 run.\"\"\"\n", + " cfg = load_config(str(Path(\"../configs/phase5\") / config_path))\n", + " model_obj, kind = get_model(cfg)\n", + " ema_path = Path(\"../outputs/models\") / f\"{run_name}_best_ema.pt\"\n", + " if not ema_path.exists():\n", + " ema_path = Path(\"../outputs/models\") / f\"{run_name}_final_ema.pt\"\n", + " if isinstance(model_obj, tuple):\n", + " model = model_obj[0] # generator for GAN\n", + " else:\n", + " model = model_obj\n", + " model.load_state_dict(torch.load(ema_path, map_location=DEVICE))\n", + " return model.to(DEVICE).eval(), cfg, kind\n", + "\n", + "\n", + "@torch.no_grad()\n", + "def generate_samples(run_name, config_path, n=N_SAMPLE):\n", + " model, cfg, kind = load_ema_model(run_name, config_path)\n", + " image_size = cfg.get(\"image_size\", 64)\n", + "\n", + " if kind == \"wgan\":\n", + " latent_dim = cfg.get(\"latent_dim\", 128)\n", + " imgs = torch.cat([\n", + " model(torch.randn(min(64, n - i), latent_dim, 1, 1, device=DEVICE))\n", + " for i in range(0, n, 64)\n", + " ])[:n].cpu()\n", + "\n", + " elif kind == \"vae\":\n", + " imgs = torch.cat([\n", + " model.sample(min(64, n - i), DEVICE)\n", + " for i in range(0, n, 64)\n", + " ])[:n].cpu()\n", + "\n", + " elif kind == \"ddpm\":\n", + " schedule = cfg.get(\"noise_schedule\", \"cosine\")\n", + " pred_type = cfg.get(\"pred_type\", 
\"v\")\n", + " T = cfg.get(\"T\", 1000)\n", + " from src.training.diffusion import cosine_betas, linear_betas, make_alpha_bars, ddim_sample\n", + " betas = (cosine_betas(T) if schedule == \"cosine\" else linear_betas(T)).to(DEVICE)\n", + " ab = make_alpha_bars(betas)\n", + " imgs = ddim_sample(model, n, image_size, ab, n_steps=100,\n", + " pred_type=pred_type, device=DEVICE, batch_size=32)\n", + " return imgs\n", + "\n", + "\n", + "print(\"Run this cell once all phase-5 models have completed.\")\n", + "print(\"Expected to take ~5–10 minutes per family on an RTX 3090.\")" + ], + "execution_count": null, + "outputs": [], + "id": "d0000014" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# ── Compute IS and LPIPS per family ──────────────────────────────────────────\n", + "# Uncomment when models are ready.\n", + "\n", + "extended_metrics = {}\n", + "\n", + "for fam, info in FAMILIES.items():\n", + " run = info[\"p5\"]\n", + " config = f\"{run}.json\"\n", + " try:\n", + " print(f\"\\n{fam}: generating {N_SAMPLE} samples...\")\n", + " imgs = generate_samples(run, config)\n", + "\n", + " print(f\" Computing IS...\")\n", + " is_mean, is_std = compute_is(imgs, device=DEVICE)\n", + "\n", + " print(f\" Computing LPIPS diversity...\")\n", + " lpips = compute_lpips_diversity(imgs, n_pairs=200, device=DEVICE)\n", + "\n", + " extended_metrics[fam] = {\"IS_mean\": is_mean, \"IS_std\": is_std, \"LPIPS\": lpips}\n", + " print(f\" IS = {is_mean:.2f} ± {is_std:.2f} LPIPS = {lpips:.4f}\")\n", + "\n", + " except Exception as e:\n", + " print(f\" Skipped ({e})\")\n", + " extended_metrics[fam] = {}\n", + "\n", + "# Merge with FID table\n", + "for fam in FAMILIES:\n", + " em = extended_metrics.get(fam, {})\n", + " idx = list(FAMILIES.keys()).index(fam)\n", + " df.loc[fam, \"IS ↑\"] = f\"{em['IS_mean']:.2f}±{em['IS_std']:.2f}\" if \"IS_mean\" in em else \"—\"\n", + " df.loc[fam, \"LPIPS ↑\"] = f\"{em['LPIPS']:.4f}\" if \"LPIPS\" in em else \"—\"\n", + "\n", + 
"df.style.format({c: \"{:.1f}\" for c in [\"FID@100\", \"FID@150\", \"FID@200\"] if c in df})" + ], + "execution_count": null, + "outputs": [], + "id": "d0000015" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. Latent Interpolation — GAN and VAE\n", + "\n", + "Smooth interpolation between two latent codes reveals whether the generator has learned a\n", + "continuous manifold. DDPM has no encoder, so interpolation is done by different noise seeds." + ], + "id": "d0000016" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Spherical linear interpolation (slerp)\n", + "def slerp(z1, z2, t):\n", + " z1_n = z1 / z1.norm()\n", + " z2_n = z2 / z2.norm()\n", + " omega = torch.acos((z1_n * z2_n).sum().clamp(-1, 1))\n", + " if omega.abs() < 1e-6:\n", + " return (1 - t) * z1 + t * z2\n", + " return (torch.sin((1-t)*omega)/torch.sin(omega)) * z1 + \\\n", + " (torch.sin(t*omega)/torch.sin(omega)) * z2\n", + "\n", + "\n", + "def gan_interpolation(model, latent_dim, n_steps=10, device=DEVICE):\n", + " z1 = torch.randn(1, latent_dim, 1, 1, device=device)\n", + " z2 = torch.randn(1, latent_dim, 1, 1, device=device)\n", + " alphas = torch.linspace(0, 1, n_steps)\n", + " imgs = []\n", + " with torch.no_grad():\n", + " for a in alphas:\n", + " z = slerp(z1.flatten(), z2.flatten(), a.item()).view_as(z1)\n", + " imgs.append(model(z).cpu())\n", + " return torch.cat(imgs)\n", + "\n", + "\n", + "def vae_interpolation(model, real_imgs, n_steps=10, device=DEVICE):\n", + " \"\"\"Encode two real images, interpolate in latent space, decode.\"\"\"\n", + " img1, img2 = real_imgs[:1].to(device), real_imgs[1:2].to(device)\n", + " with torch.no_grad():\n", + " mu1, _ = model.encode(img1)\n", + " mu2, _ = model.encode(img2)\n", + " alphas = torch.linspace(0, 1, n_steps, device=device)\n", + " imgs = [model.decode((1-a)*mu1 + a*mu2).cpu() for a in alphas]\n", + " return torch.cat(imgs)\n", + "\n", + "\n", + "print(\"Interpolation helpers defined. 
Run cells below after loading models.\")" + ], + "execution_count": null, + "outputs": [], + "id": "d0000017" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# ── GAN interpolation ─────────────────────────────────────────────────────────\n", + "try:\n", + " gan_model, gan_cfg, _ = load_ema_model(\"p5_gan\", \"p5_gan.json\")\n", + " latent_dim = gan_cfg.get(\"latent_dim\", 128)\n", + " interp_imgs = gan_interpolation(gan_model, latent_dim, n_steps=10)\n", + " interp_imgs = (interp_imgs.clamp(-1, 1) + 1) / 2\n", + "\n", + " fig, axes = plt.subplots(1, 10, figsize=(20, 2.5))\n", + " for ax, img in zip(axes, interp_imgs):\n", + " ax.imshow(img.permute(1, 2, 0).numpy())\n", + " ax.axis(\"off\")\n", + " fig.suptitle(\"GAN — Slerp latent interpolation (z₁ → z₂)\", fontsize=11, fontweight=\"bold\")\n", + " plt.tight_layout()\n", + " plt.show()\n", + "except Exception as e:\n", + " print(f\"GAN interpolation: {e}\")" + ], + "execution_count": null, + "outputs": [], + "id": "d0000018" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# ── VAE interpolation ─────────────────────────────────────────────────────────\n", + "try:\n", + " from src.data import GeneratorDataset, get_transform\n", + "\n", + " vae_model, vae_cfg, _ = load_ema_model(\"p5_vae\", \"p5_vae.json\")\n", + " ds = GeneratorDataset(\"../../\" + vae_cfg[\"data_dir\"],\n", + " sources=vae_cfg.get(\"sources\", [\"wiki\"]),\n", + " transform=get_transform(vae_cfg[\"image_size\"], augment=False))\n", + " sample_real = torch.stack([ds[i] for i in range(2)])\n", + "\n", + " interp_imgs = vae_interpolation(vae_model, sample_real, n_steps=10)\n", + " interp_imgs = (interp_imgs.clamp(-1, 1) + 1) / 2\n", + "\n", + " fig, axes = plt.subplots(1, 10, figsize=(20, 2.5))\n", + " for ax, img in zip(axes, interp_imgs):\n", + " ax.imshow(img.permute(1, 2, 0).numpy())\n", + " ax.axis(\"off\")\n", + " fig.suptitle(\"VAE — μ-space linear interpolation (image₁ → image₂)\", fontsize=11, 
fontweight=\"bold\")\n", + " plt.tight_layout()\n", + " plt.show()\n", + "except Exception as e:\n", + " print(f\"VAE interpolation: {e}\")" + ], + "execution_count": null, + "outputs": [], + "id": "d0000019" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8. Failure Mode Analysis\n", + "\n", + "Identify the worst-generated images per model: those farthest from their nearest real neighbour in pixel space, or simply those with highest reconstruction error (VAE) or highest DDPM loss." + ], + "id": "d0000020" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# For GAN and DDPM: generate 256 images, pick the 8 with lowest mean activation\n", + "# (a proxy for less-coherent images — very rough)\n", + "\n", + "def worst_samples(imgs, n=8):\n", + " \"\"\"Heuristic: pick samples with lowest per-pixel mean (often darker / less structured).\"\"\"\n", + " scores = imgs.mean(dim=[1, 2, 3]) # mean brightness per image\n", + " worst_idx = scores.argsort()[:n]\n", + " return imgs[worst_idx]\n", + "\n", + "print(\"Failure mode analysis requires generated samples — run after model loading (Section 6).\")" + ], + "execution_count": null, + "outputs": [], + "id": "d0000021" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Optionally run with already-generated `imgs` from Section 6:\n", + "# worst = worst_samples(imgs.cpu())\n", + "# worst = (worst.clamp(-1,1) + 1) / 2\n", + "# fig, axes = plt.subplots(1, 8, figsize=(18, 2.5))\n", + "# for ax, img in zip(axes, worst):\n", + "# ax.imshow(img.permute(1,2,0)); ax.axis('off')\n", + "# plt.suptitle('Failure modes (lowest-brightness samples)', fontsize=11)\n", + "# plt.tight_layout(); plt.show()" + ], + "execution_count": null, + "outputs": [], + "id": "d0000022" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 9. 
Training Loss Overview (all families)" + ], + "id": "d0000023" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "fig, axes = plt.subplots(1, 3, figsize=(18, 4))\n", + "\n", + "for ax, (fam, info) in zip(axes, FAMILIES.items()):\n", + " log = logs_p5[fam]\n", + " if log is None:\n", + " ax.set_title(f\"{fam} (not yet run)\"); ax.axis(\"off\"); continue\n", + " h = log[\"history\"]\n", + " c = info[\"color\"]\n", + "\n", + " if fam == \"GAN\":\n", + " ax.plot(h[\"g_loss\"], label=\"G loss\", color=c, linewidth=1.2)\n", + " ax.plot(h[\"w_dist\"], label=\"W-dist\", color=c, linewidth=1.2, linestyle=\"--\")\n", + " ax.set_ylabel(\"Loss / W-distance\")\n", + " elif fam == \"VAE\":\n", + " ax.plot(h[\"recon_loss\"], label=\"MSE\", color=c)\n", + " ax2 = ax.twinx()\n", + " ax2.plot(h[\"kl_loss\"], label=\"KL\", color=\"grey\", linestyle=\"--\")\n", + " ax2.set_ylabel(\"KL\", color=\"grey\")\n", + " ax.set_ylabel(\"MSE\")\n", + " elif fam == \"DDPM\":\n", + " ax.plot(h[\"loss\"], label=\"MSE (v-pred)\", color=c)\n", + " ax.set_ylabel(\"MSE loss\")\n", + "\n", + " ax.set_xlabel(\"Epoch\")\n", + " ax.set_title(f\"{fam}: {info['label']}\")\n", + " ax.legend(fontsize=8)\n", + "\n", + "fig.suptitle(\"Phase 5 — Training Dynamics (200 epochs)\", fontsize=12, fontweight=\"bold\")\n", + "plt.tight_layout()\n", + "plt.show()" + ], + "execution_count": null, + "outputs": [], + "id": "d0000024" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 10. 
Conclusions" + ], + "id": "d0000025" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "print(\"=\" * 72)\n", + "print(\"PHASE 5 — CROSS-FAMILY COMPARISON (200 epochs)\")\n", + "print(\"=\" * 72)\n", + "\n", + "for fam, info in FAMILIES.items():\n", + " log = logs_p5[fam]\n", + " em = extended_metrics.get(fam, {})\n", + " print(f\"\\n ── {fam}: {info['label']} ──\")\n", + " if log:\n", + " res = log.get(\"config\", log).get(\"image_size\", \"?\")\n", + " print(f\" Resolution : {res}×{res}\")\n", + " n_p = log.get(\"n_params\")\n", + " print(f\" Params : {n_p:,}\" if n_p else \" Params : ?\")\n", + " tt = log.get(\"history\", {}).get(\"train_time_s\")\n", + " print(f\" Train time : {tt / 60:.1f} min\" if tt else \" Train time : ?\")\n", + " for ep in (100, 150, 200):\n", + " fid = get_fid(log, ep)\n", + " print(f\" FID@{ep:<3} : {fid:.1f}\" if fid else f\" FID@{ep:<3} : ?\")\n", + " if em:\n", + " print(f\" IS : {em.get('IS_mean','?'):.2f} ± {em.get('IS_std','?'):.2f}\" if 'IS_mean' in em else \" IS : ?\")\n", + " print(f\" LPIPS div : {em.get('LPIPS','?'):.4f}\" if 'LPIPS' in em else \" LPIPS div : ?\")\n", + " else:\n", + " print(\" IS / LPIPS : (run Section 6 to compute)\")\n", + "\n", + "print(\"\\n\" + \"=\" * 72)\n", + "print(\"Narrative to fill in after results:\")\n", + "print(\" - Which family achieves best FID?\")\n", + "print(\" - GAN: fast convergence but mode collapse risk?\")\n", + "print(\" - VAE: blurry priors improved by perceptual+adversarial loss?\")\n", + "print(\" - DDPM: highest quality but slowest inference (100 DDIM steps)?\")\n", + "print(\"=\" * 72)" + ], + "execution_count": null, + "outputs": [], + "id": "d0000026" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git 
a/generator/outputs/pipeline/20260430T012157.632405+0000.json b/generator/outputs/pipeline/20260430T012157.632405+0000.json deleted file mode 100644 index c1deb81..0000000 --- a/generator/outputs/pipeline/20260430T012157.632405+0000.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - "created_at": "2026-04-30T01:21:57.632405+00:00", - "config_paths": [ - "generator/configs/phase1/_base_dcgan.json", - "generator/configs/phase1/p1a_dcgan_128.json", - "generator/configs/phase1/p1a_dcgan_64.json", - "generator/configs/phase1/p1b_dcgan_aligned.json", - "generator/configs/phase1/p1b_dcgan_full.json", - "generator/configs/phase1/p1c_dcgan_full_aug.json", - "generator/configs/phase1/p1c_dcgan_hflip.json", - "generator/configs/phase1/p1d_dcgan_combined.json" - ], - "instance_id": 35870989, - "offer_id": 26314481, - "ssh_host": "ssh4.vast.ai", - "ssh_port": 30988, - "status": "failed", - "remote_workspace": "/workspace/DRL_PROJ" -} \ No newline at end of file diff --git a/generator/outputs/pipeline/20260430T023357.318888+0000.json b/generator/outputs/pipeline/20260430T023357.318888+0000.json deleted file mode 100644 index 84ca1b8..0000000 --- a/generator/outputs/pipeline/20260430T023357.318888+0000.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "created_at": "2026-04-30T02:33:57.318888+00:00", - "config_paths": [ - "generator/configs/phase1/p1a_dcgan_128.json", - "generator/configs/phase1/p1a_dcgan_64.json", - "generator/configs/phase1/p1b_dcgan_aligned.json", - "generator/configs/phase1/p1b_dcgan_full.json", - "generator/configs/phase1/p1c_dcgan_full_aug.json", - "generator/configs/phase1/p1c_dcgan_hflip.json", - "generator/configs/phase1/p1d_dcgan_combined.json" - ], - "instance_id": 35873942, - "offer_id": 35472561, - "ssh_host": "ssh4.vast.ai", - "ssh_port": 33942, - "status": "cancelled", - "remote_workspace": "/workspace/DRL_PROJ" -} \ No newline at end of file diff --git a/generator/run.py b/generator/run.py index 2586f88..f11727e 100644 --- a/generator/run.py +++ b/generator/run.py 
@@ -32,7 +32,7 @@ def main(config_path, *, data_dir_override=None, output_root="generator/outputs" import torch from src.data import GeneratorDataset, get_transform from src.models import get_model - from src.training import train_dcgan + from src.training import train_dcgan, train_wgan, train_vae, train_ddpm from src.utils import load_config cfg = load_config(config_path) @@ -50,6 +50,13 @@ def main(config_path, *, data_dir_override=None, output_root="generator/outputs" model, kind = get_model(cfg) + # Count total trainable parameters + if isinstance(model, tuple): + n_params = sum(p.numel() for p in model[0].parameters() if p.requires_grad) + else: + n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"Trainable params: {n_params:,}") + augment = cfg.get("augment", True) transform = get_transform(cfg.get("image_size", 128), augment=augment) dataset = GeneratorDataset( @@ -66,13 +73,31 @@ def main(config_path, *, data_dir_override=None, output_root="generator/outputs" generator, discriminator, dataset, cfg, save_dir=models_dir, run_name=run_name, device=device, ) + elif kind == "wgan": + generator, critic = model + history = train_wgan( + generator, critic, dataset, cfg, + save_dir=models_dir, run_name=run_name, device=device, + ) + elif kind == "vae": + history = train_vae( + model, dataset, cfg, + save_dir=models_dir, run_name=run_name, device=device, + ) + elif kind == "ddpm": + history = train_ddpm( + model, dataset, cfg, + save_dir=models_dir, run_name=run_name, device=device, + ) else: raise NotImplementedError(f"kind={kind!r} not yet implemented in this phase") logs_dir.mkdir(parents=True, exist_ok=True) out = logs_dir / f"{run_name}.json" + log_data = {"run_name": run_name, "config": cfg, "history": history} + log_data["n_params"] = n_params with open(out, "w") as f: - json.dump({"run_name": run_name, "config": cfg, "history": history}, f, indent=2) + json.dump(log_data, f, indent=2) print(f"\nSaved log to {out}") diff --git 
a/generator/src/data/dataset.py b/generator/src/data/dataset.py index 5281296..ed88f62 100644 --- a/generator/src/data/dataset.py +++ b/generator/src/data/dataset.py @@ -51,17 +51,20 @@ class GeneratorDataset(Dataset): return img -def get_transform(image_size: int, augment: bool = False) -> T.Compose: +def get_transform(image_size: int, augment=False) -> T.Compose: """Build transform for generator training. Output is in [-1, 1]. - augment=True adds horizontal flip + mild rotation + mild color jitter. - Use augment=False for validation / FID real-image sets. + augment=False — no augmentation (for FID real-image sets) + augment="hflip" — horizontal flip only (recommended for VAE/DDPM) + augment=True — H-flip + rotation ±5° + mild color jitter (for GAN) """ ops = [ T.Resize(image_size), T.CenterCrop(image_size), ] - if augment: + if augment == "hflip": + ops.append(T.RandomHorizontalFlip(p=0.5)) + elif augment: ops += [ T.RandomHorizontalFlip(p=0.5), T.RandomRotation(degrees=5, interpolation=T.InterpolationMode.BILINEAR), diff --git a/generator/src/models/__init__.py b/generator/src/models/__init__.py index 6f25924..cf8f908 100644 --- a/generator/src/models/__init__.py +++ b/generator/src/models/__init__.py @@ -23,4 +23,7 @@ def get_model(cfg: dict) -> tuple: from src.models import dcgan # noqa: E402, F401 +from src.models import wgan # noqa: E402, F401 +from src.models import vae # noqa: E402, F401 +from src.models import unet # noqa: E402, F401 diff --git a/generator/src/models/patchgan.py b/generator/src/models/patchgan.py new file mode 100644 index 0000000..9245416 --- /dev/null +++ b/generator/src/models/patchgan.py @@ -0,0 +1,69 @@ +"""PatchGAN discriminator for Phase 3.3 (VQGAN-lite adversarial training). + +Outputs a spatial patch map instead of a single scalar — each patch +predicts real/fake independently. Loss is the mean over all patches. + +Not registered in the model registry; instantiated inside train_vae +when lambda_adversarial > 0. 
+""" +import torch +import torch.nn as nn + + +def _init_weights(m): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, 0.0, 0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + +class PatchGANDiscriminator(nn.Module): + """Stride-2 + stride-1 convolution chain → spatial patch logit map. + + Supports image_size ∈ {64, 128}. For 64×64 input the final map is 6×6 + (70×70 receptive field). For 128×128 an extra stride-2 layer is added. + InstanceNorm everywhere except the first layer. + """ + + def __init__(self, ndf: int = 64, image_size: int = 64): + super().__init__() + layers: list[nn.Module] = [ + # First layer: no norm + nn.Conv2d(3, ndf, 4, stride=2, padding=1, bias=False), + nn.LeakyReLU(0.2, inplace=True), + ] + if image_size >= 128: + layers += [ + nn.Conv2d(ndf, ndf, 4, stride=2, padding=1, bias=False), + nn.InstanceNorm2d(ndf, affine=True), + nn.LeakyReLU(0.2, inplace=True), + ] + layers += [ + nn.Conv2d(ndf, ndf * 2, 4, stride=2, padding=1, bias=False), + nn.InstanceNorm2d(ndf * 2, affine=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(ndf * 2, ndf * 4, 4, stride=2, padding=1, bias=False), + nn.InstanceNorm2d(ndf * 4, affine=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(ndf * 4, ndf * 8, 4, stride=1, padding=1, bias=False), + nn.InstanceNorm2d(ndf * 8, affine=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(ndf * 8, 1, 4, stride=1, padding=1, bias=True), + ] + self.net = nn.Sequential(*layers) + self.apply(_init_weights) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) # (B, 1, H', W') — patch logit map + + +def hinge_d_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor: + """Hinge loss for the discriminator (Lim & Ye, 2017).""" + loss_real = torch.mean(torch.relu(1.0 - real_logits)) + loss_fake = torch.mean(torch.relu(1.0 + fake_logits)) + return 0.5 * (loss_real + loss_fake) + + +def hinge_g_loss(fake_logits: torch.Tensor) -> torch.Tensor: + """Generator hinge loss — 
maximise D(fake).""" + return -torch.mean(fake_logits) diff --git a/generator/src/models/unet.py b/generator/src/models/unet.py new file mode 100644 index 0000000..8250bd8 --- /dev/null +++ b/generator/src/models/unet.py @@ -0,0 +1,279 @@ +"""Time-conditioned U-Net for DDPM (Phase 4). + +Architecture follows Ho et al. (2020) with options from Nichol & Dhariwal (2021): + - Sinusoidal time embedding → MLP → added to every ResBlock + - GroupNorm (32 groups) + SiLU activations throughout + - Self-attention at configurable spatial resolutions + - Upsample(nearest) + Conv in the decoder — no checkerboard artefacts + +Registered as kind="ddpm". +""" +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from src.models import register + +_GN = 32 # GroupNorm groups — all channel counts used here are multiples of 32 + + +# ── Time embedding ──────────────────────────────────────────────────────────── + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, t: torch.Tensor) -> torch.Tensor: + half = self.dim // 2 + freqs = torch.exp( + -math.log(10000) * torch.arange(half, device=t.device, dtype=torch.float) / half + ) + angles = t[:, None].float() * freqs[None] # (B, half) + return torch.cat([angles.sin(), angles.cos()], dim=-1) # (B, dim) + + +# ── Core building blocks ────────────────────────────────────────────────────── + +class ResBlock(nn.Module): + """ResNet block with time-embedding injection (additive, after first conv).""" + + def __init__(self, in_ch: int, out_ch: int, t_emb_dim: int, dropout: float = 0.1): + super().__init__() + self.norm1 = nn.GroupNorm(_GN, in_ch) + self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1) + self.t_proj = nn.Linear(t_emb_dim, out_ch) + self.norm2 = nn.GroupNorm(_GN, out_ch) + self.drop = nn.Dropout(dropout) + self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1) + self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else 
nn.Identity() + + def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor: + h = self.conv1(F.silu(self.norm1(x))) + h = h + self.t_proj(F.silu(t_emb))[:, :, None, None] + h = self.conv2(self.drop(F.silu(self.norm2(h)))) + return h + self.skip(x) + + +class AttentionBlock(nn.Module): + """Single-head self-attention with GroupNorm pre-norm and residual.""" + + def __init__(self, ch: int): + super().__init__() + self.norm = nn.GroupNorm(_GN, ch) + self.qkv = nn.Conv2d(ch, ch * 3, 1, bias=False) + self.proj = nn.Conv2d(ch, ch, 1) + self._scale = ch ** -0.5 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, c, h, w = x.shape + n = h * w + qkv = self.norm(x) + q, k, v = self.qkv(qkv).reshape(b, 3, c, n).unbind(1) # each (b, c, n) + attn = torch.softmax(q.transpose(-2, -1) @ k * self._scale, dim=-1) # (b, n, n) + out = (v @ attn.transpose(-2, -1)).reshape(b, c, h, w) + return x + self.proj(out) + + +class Downsample(nn.Module): + def __init__(self, ch: int): + super().__init__() + self.conv = nn.Conv2d(ch, ch, 4, stride=2, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +class Upsample(nn.Module): + def __init__(self, ch: int): + super().__init__() + self.conv = nn.Conv2d(ch, ch, 3, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(F.interpolate(x, scale_factor=2, mode="nearest")) + + +# ── Down / Up blocks ────────────────────────────────────────────────────────── + +class DownBlock(nn.Module): + def __init__( + self, + in_ch: int, + out_ch: int, + t_emb_dim: int, + num_res_blocks: int, + with_attn: bool, + dropout: float, + ): + super().__init__() + self.resnets = nn.ModuleList( + ResBlock(in_ch if j == 0 else out_ch, out_ch, t_emb_dim, dropout) + for j in range(num_res_blocks) + ) + self.attn = AttentionBlock(out_ch) if with_attn else nn.Identity() + + def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor: + for res in self.resnets: + x = res(x, 
t_emb) + return self.attn(x) + + +class UpBlock(nn.Module): + def __init__( + self, + in_ch: int, + skip_ch: int, + out_ch: int, + t_emb_dim: int, + num_res_blocks: int, + with_attn: bool, + dropout: float, + ): + super().__init__() + # First ResBlock absorbs the skip-connection channels via concat + self.resnets = nn.ModuleList( + ResBlock( + (in_ch + skip_ch) if j == 0 else out_ch, + out_ch, + t_emb_dim, + dropout, + ) + for j in range(num_res_blocks + 1) # +1 to consume the concat + ) + self.attn = AttentionBlock(out_ch) if with_attn else nn.Identity() + + def forward( + self, x: torch.Tensor, skip: torch.Tensor, t_emb: torch.Tensor + ) -> torch.Tensor: + x = torch.cat([x, skip], dim=1) + for res in self.resnets: + x = res(x, t_emb) + return self.attn(x) + + +# ── U-Net ───────────────────────────────────────────────────────────────────── + +class UNet(nn.Module): + """Time-conditioned U-Net. + + image_size — must be a power-of-two; 64 or 128 recommended. + base_ch — base channel count (128 for phases 4.1–4.3, 192 for 4.4). + ch_mult — channel multipliers per resolution level. + attn_resolutions — spatial resolutions at which attention is inserted. + num_res_blocks — ResBlocks per level (in both down and up paths). + dropout — applied inside every ResBlock. 
+ """ + + def __init__( + self, + image_size: int = 64, + base_ch: int = 128, + ch_mult: tuple = (1, 2, 2, 2), + attn_resolutions: tuple = (16, 8), + num_res_blocks: int = 2, + dropout: float = 0.1, + ): + super().__init__() + n_levels = len(ch_mult) + chs = [base_ch * m for m in ch_mult] + t_emb_dim = base_ch * 4 + + # ── Time embedding ──────────────────────────────────────────────── + self.time_embed = nn.Sequential( + SinusoidalPosEmb(base_ch), + nn.Linear(base_ch, t_emb_dim), + nn.SiLU(), + nn.Linear(t_emb_dim, t_emb_dim), + ) + + # ── Input projection ────────────────────────────────────────────── + self.in_conv = nn.Conv2d(3, chs[0], 3, padding=1) + + # ── Down path ───────────────────────────────────────────────────── + self.down_blocks = nn.ModuleList() + self.downsamples = nn.ModuleList() + cur_res = image_size + prev_ch = chs[0] + for i, ch in enumerate(chs): + with_attn = (cur_res in attn_resolutions) + self.down_blocks.append( + DownBlock(prev_ch, ch, t_emb_dim, num_res_blocks, with_attn, dropout) + ) + if i < n_levels - 1: + self.downsamples.append(Downsample(ch)) + cur_res //= 2 + prev_ch = ch + + # ── Middle ──────────────────────────────────────────────────────── + mid_ch = chs[-1] + self.mid_res1 = ResBlock(mid_ch, mid_ch, t_emb_dim, dropout) + self.mid_attn = AttentionBlock(mid_ch) + self.mid_res2 = ResBlock(mid_ch, mid_ch, t_emb_dim, dropout) + + # ── Up path ─────────────────────────────────────────────────────── + # Mirrors down path: iterate chs in reverse; skip_ch = chs[n_levels-1-i]. 
+ self.up_blocks = nn.ModuleList() + self.upsamples = nn.ModuleList() + in_ch = mid_ch + for i in range(n_levels): + level = n_levels - 1 - i # index from deep (n-1) to shallow (0) + skip_ch = chs[level] + out_ch = chs[level - 1] if level > 0 else chs[0] + with_attn = (cur_res in attn_resolutions) + self.up_blocks.append( + UpBlock(in_ch, skip_ch, out_ch, t_emb_dim, num_res_blocks, with_attn, dropout) + ) + if level > 0: + self.upsamples.append(Upsample(out_ch)) + cur_res *= 2 + in_ch = out_ch + + # ── Output ──────────────────────────────────────────────────────── + self.out_norm = nn.GroupNorm(_GN, chs[0]) + self.out_conv = nn.Conv2d(chs[0], 3, 3, padding=1) + + def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + t_emb = self.time_embed(t) + + x = self.in_conv(x) + + # Down + skips = [] + ds_idx = 0 + for i, block in enumerate(self.down_blocks): + x = block(x, t_emb) + skips.append(x) + if ds_idx < len(self.downsamples): + x = self.downsamples[ds_idx](x) + ds_idx += 1 + + # Middle + x = self.mid_res1(x, t_emb) + x = self.mid_attn(x) + x = self.mid_res2(x, t_emb) + + # Up + us_idx = 0 + for i, block in enumerate(self.up_blocks): + x = block(x, skips[-(i + 1)], t_emb) + if us_idx < len(self.upsamples): + x = self.upsamples[us_idx](x) + us_idx += 1 + + return self.out_conv(F.silu(self.out_norm(x))) + + +def _build(cfg: dict): + return UNet( + image_size = cfg.get("image_size", 64), + base_ch = cfg.get("base_ch", 128), + ch_mult = tuple(cfg.get("ch_mult", [1, 2, 2, 2])), + attn_resolutions = tuple(cfg.get("attn_resolutions", [16, 8])), + num_res_blocks = cfg.get("num_res_blocks", 2), + dropout = cfg.get("dropout", 0.1), + ) + + +register("ddpm", _build, kind="ddpm") diff --git a/generator/src/models/vae.py b/generator/src/models/vae.py new file mode 100644 index 0000000..285ad56 --- /dev/null +++ b/generator/src/models/vae.py @@ -0,0 +1,132 @@ +"""Convolutional VAE for Phase 3. + +Encoder uses stride-2 Conv → flatten → linear (μ, log σ²). 
+Decoder uses Linear → Upsample(nearest) + Conv to avoid ConvTranspose2d +checkerboard artefacts. + +Registered as kind="vae". The run.py dispatcher passes the model to +train_vae(), which internally builds perceptual loss and PatchGAN when +the corresponding lambdas are non-zero. +""" +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from src.models import register + + +def _init_weights(m): + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)): + nn.init.normal_(m.weight, 0.0, 0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d) and m.weight is not None: + nn.init.normal_(m.weight, 1.0, 0.02) + nn.init.zeros_(m.bias) + + +def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential: + """Nearest-neighbour upsample followed by a 3×3 conv — no checkerboard.""" + return nn.Sequential( + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), + ) + + +class VAE(nn.Module): + """Convolutional VAE. image_size must be a power-of-two ≥ 32. + + Spatial bottleneck is always at 4×4 regardless of image_size — + the encoder and decoder scale the number of stride-2 steps accordingly. 
+ """ + + def __init__(self, latent_dim: int = 256, ngf: int = 64, image_size: int = 64): + super().__init__() + if image_size < 32 or (image_size & (image_size - 1)): + raise ValueError(f"image_size must be a power-of-two ≥ 32, got {image_size}") + + self.latent_dim = latent_dim + self.image_size = image_size + + n_down = int(math.log2(image_size)) - 2 # steps from image_size to 4×4 + # 64 → n_down=4: 64→32→16→8→4 + # 128 → n_down=5: 128→64→32→16→8→4 + + # ── Encoder ────────────────────────────────────────────────────────── + enc_layers: list[nn.Module] = [ + nn.Conv2d(3, ngf, 4, stride=2, padding=1, bias=False), + nn.LeakyReLU(0.2, inplace=True), + ] + ch = ngf + for _ in range(n_down - 1): + enc_layers += [ + nn.Conv2d(ch, ch * 2, 4, stride=2, padding=1, bias=False), + nn.BatchNorm2d(ch * 2), + nn.LeakyReLU(0.2, inplace=True), + ] + ch *= 2 + # ch = ngf * 2^(n_down-1); spatial = 4×4 + self.encoder = nn.Sequential(*enc_layers) + + flat = ch * 4 * 4 + self.fc_mu = nn.Linear(flat, latent_dim) + self.fc_lv = nn.Linear(flat, latent_dim) + + # ── Decoder ────────────────────────────────────────────────────────── + self.fc_dec = nn.Linear(latent_dim, flat) + self._dec_ch = ch # channels at the 4×4 bottleneck + + dec_layers: list[nn.Module] = [] + for _ in range(n_down - 1): + dec_layers.append(_upsample_block(ch, ch // 2)) + ch //= 2 + # Final upsample to image_size, output 3 channels, no BN, Tanh + dec_layers += [ + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(ch, 3, 3, padding=1, bias=True), + nn.Tanh(), + ] + self.decoder = nn.Sequential(*dec_layers) + + self.apply(_init_weights) + + # ── Interface ──────────────────────────────────────────────────────────── + + def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + h = self.encoder(x).flatten(1) + return self.fc_mu(h), self.fc_lv(h) + + def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor: + std = torch.exp(0.5 * log_var) + return mu + std * 
torch.randn_like(std) + + def decode(self, z: torch.Tensor) -> torch.Tensor: + h = self.fc_dec(z).view(z.size(0), self._dec_ch, 4, 4) + return self.decoder(h) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Returns (reconstruction, mu, log_var).""" + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + return self.decode(z), mu, log_var + + @torch.no_grad() + def sample(self, n: int, device) -> torch.Tensor: + """Sample n images by drawing z ~ N(0, I).""" + z = torch.randn(n, self.latent_dim, device=device) + return self.decode(z) + + +def _build(cfg: dict): + return VAE( + latent_dim=cfg.get("latent_dim", 256), + ngf=cfg.get("ngf", 64), + image_size=cfg.get("image_size", 64), + ) + + +register("vae", _build, kind="vae") diff --git a/generator/src/models/wgan.py b/generator/src/models/wgan.py index 1084697..2c61e22 100644 --- a/generator/src/models/wgan.py +++ b/generator/src/models/wgan.py @@ -1,26 +1,31 @@ -"""WGAN-GP with spectral normalization, self-attention, and GroupNorm. +"""WGAN-GP variants. -Improvements over the original: -- Generator: BatchNorm -> GroupNorm (no batch-size coupling, stable with varied content) -- Critic: InstanceNorm -> spectral normalization (principled Lipschitz constraint) -- Both: one SAGAN-style self-attention block at the 32x32 feature map -- Larger capacity: ngf=128, ndf=128 +wgan_basic — Phase 2.2: BatchNorm/InstanceNorm, no attention, 64×64 only. +wgan — Phase 2.3/2.4: GroupNorm/SpectralNorm + self-attention, size-agnostic. 
""" +import math + import torch import torch.nn as nn -import torch.nn.functional as F + from src.models import register def _init_weights(m): if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): nn.init.normal_(m.weight, 0.0, 0.02) - elif isinstance(m, nn.GroupNorm) and m.weight is not None: + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)) and m.weight is not None: nn.init.normal_(m.weight, 1.0, 0.02) nn.init.zeros_(m.bias) +def _sn(module): + return nn.utils.spectral_norm(module) + + class SelfAttention(nn.Module): + """SAGAN-style self-attention.""" + def __init__(self, in_ch: int): super().__init__() mid = max(in_ch // 8, 1) @@ -36,98 +41,221 @@ class SelfAttention(nn.Module): k = self.k(x).view(b, self._mid, -1) # (b, mid, hw) v = self.v(x).view(b, c, -1) # (b, c, hw) attn = torch.softmax(q @ k * self._mid ** -0.5, dim=-1) # (b, hw, hw) - out = (v @ attn.transpose(-2, -1)).view(b, c, h, w) - return x + self.gamma * out + return x + self.gamma * (v @ attn.transpose(-2, -1)).view(b, c, h, w) -def _sn(module): - """Apply spectral normalization to a conv layer.""" - return nn.utils.spectral_norm(module) +# --------------------------------------------------------------------------- +# Phase 2.2 — basic WGAN-GP (BatchNorm in G, InstanceNorm in D, 64×64 only) +# --------------------------------------------------------------------------- +class WGANBasicGenerator(nn.Module): + """Maps (latent_dim, 1, 1) -> (3, 64, 64) in [-1, 1]. -class WGANGenerator(nn.Module): - """Maps (latent_dim x 1 x 1) -> (3 x 128 x 128) in [-1, 1]. - - Upsampling path: 1 -> 4 -> 8 -> 16 (+attn) -> 32 -> 64 -> 128 - Self-attention sits at 16x16 (attention matrix 256x256 vs 1024x1024 at 32x32). + Same channel structure as DCGAN. BatchNorm in generator is fine because + WGAN-GP's constraint targets the critic, not the generator. 
""" def __init__(self, latent_dim: int = 128, ngf: int = 64): super().__init__() self.net = nn.Sequential( - # 1x1 -> 4x4 + # 1×1 → 4×4 nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False), - nn.GroupNorm(8, ngf * 8), nn.ReLU(True), - # 4x4 -> 8x8 + nn.BatchNorm2d(ngf * 8), nn.ReLU(True), + # 4×4 → 8×8 nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), - nn.GroupNorm(8, ngf * 4), nn.ReLU(True), - # 8x8 -> 16x16 + nn.BatchNorm2d(ngf * 4), nn.ReLU(True), + # 8×8 → 16×16 nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), - nn.GroupNorm(8, ngf * 2), nn.ReLU(True), - ) - self.attn = SelfAttention(ngf * 2) # applied at 16x16 - self.out = nn.Sequential( - # 16x16 -> 32x32 + nn.BatchNorm2d(ngf * 2), nn.ReLU(True), + # 16×16 → 32×32 nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), - nn.GroupNorm(8, ngf), nn.ReLU(True), - # 32x32 -> 64x64 - nn.ConvTranspose2d(ngf, ngf // 2, 4, 2, 1, bias=False), - nn.GroupNorm(8, ngf // 2), nn.ReLU(True), - # 64x64 -> 128x128 - nn.ConvTranspose2d(ngf // 2, 3, 4, 2, 1, bias=False), + nn.BatchNorm2d(ngf), nn.ReLU(True), + # 32×32 → 64×64 + nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False), nn.Tanh(), ) self.apply(_init_weights) def forward(self, z: torch.Tensor) -> torch.Tensor: - h = self.net(z) - h = self.attn(h) - return self.out(h) + return self.net(z) -class WGANCritic(nn.Module): - """Critic (no sigmoid) for WGAN-GP. All conv layers are spectrally normalized. - - Downsampling path: 128 -> 64 -> 32 -> 16 (+attn) -> 8 -> 4 -> score +class WGANBasicCritic(nn.Module): + """WGAN-GP critic (64×64). InstanceNorm instead of BatchNorm — BatchNorm + breaks the per-sample Lipschitz constraint the gradient penalty enforces. 
""" def __init__(self, ndf: int = 64): super().__init__() - self.down = nn.Sequential( - # 128x128 -> 64x64 (no norm on first layer) - _sn(nn.Conv2d(3, ndf // 2, 4, 2, 1, bias=False)), + self.net = nn.Sequential( + # 64×64 → 32×32 (no norm on first layer) + nn.Conv2d(3, ndf, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, True), - # 64x64 -> 32x32 - _sn(nn.Conv2d(ndf // 2, ndf, 4, 2, 1, bias=False)), + # 32×32 → 16×16 + nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), + nn.InstanceNorm2d(ndf * 2, affine=True), nn.LeakyReLU(0.2, True), - # 32x32 -> 16x16 - _sn(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False)), + # 16×16 → 8×8 + nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), + nn.InstanceNorm2d(ndf * 4, affine=True), nn.LeakyReLU(0.2, True), - ) - self.attn = SelfAttention(ndf * 2) # applied at 16x16 - self.tail = nn.Sequential( - # 16x16 -> 8x8 - _sn(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False)), + # 8×8 → 4×4 + nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), + nn.InstanceNorm2d(ndf * 8, affine=True), nn.LeakyReLU(0.2, True), - # 8x8 -> 4x4 - _sn(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False)), - nn.LeakyReLU(0.2, True), - # 4x4 -> 1x1 - _sn(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False)), + # 4×4 → 1×1 (score, no sigmoid) + nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), ) self.apply(_init_weights) def forward(self, x: torch.Tensor) -> torch.Tensor: - h = self.down(x) - h = self.attn(h) + return self.net(x).view(x.size(0)) + + +# --------------------------------------------------------------------------- +# Phase 2.3 / 2.4 — advanced WGAN-GP (GroupNorm, SpectralNorm, attention) +# --------------------------------------------------------------------------- + +class WGANGenerator(nn.Module): + """GroupNorm generator with SAGAN self-attention. + + Supports image_size ∈ {64, 128}. + Stem is always 1×1 → 4×4 → 8×8 → 16×16 (ngf×8 → ngf×4 → ngf×2 channels). + Attention at 16×16 always; additional attention at 32×32 for 128×128. 
class WGANGenerator(nn.Module):
    """GroupNorm generator with SAGAN self-attention (Phase 2.3 / 2.4).

    Supports image_size in {64, 128}.
    Stem is always 1x1 -> 4x4 -> 8x8 -> 16x16 (ngf*8 -> ngf*4 -> ngf*2 channels).
    Attention at 16x16 always; additional attention at 32x32 for 128x128.
    """

    def __init__(self, latent_dim: int = 128, ngf: int = 128, image_size: int = 64):
        super().__init__()
        if image_size not in (64, 128):
            raise ValueError(f"WGANGenerator supports image_size 64 or 128, got {image_size}")

        self._image_size = image_size

        self.stem = nn.Sequential(
            # 1x1 -> 4x4
            nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False),
            nn.GroupNorm(8, ngf * 8), nn.ReLU(True),
            # 4x4 -> 8x8
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.GroupNorm(8, ngf * 4), nn.ReLU(True),
            # 8x8 -> 16x16
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.GroupNorm(8, ngf * 2), nn.ReLU(True),
        )  # output: (ngf*2, 16, 16)
        self.attn16 = SelfAttention(ngf * 2)

        if image_size == 64:
            self.mid = None
            self.attn32 = None
            self.tail = nn.Sequential(
                # 16x16 -> 32x32
                nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
                nn.GroupNorm(8, ngf), nn.ReLU(True),
                # 32x32 -> 64x64
                nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
                nn.Tanh(),
            )
        else:  # 128
            self.mid = nn.Sequential(
                # 16x16 -> 32x32
                nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
                nn.GroupNorm(8, ngf), nn.ReLU(True),
            )
            self.attn32 = SelfAttention(ngf)
            self.tail = nn.Sequential(
                # 32x32 -> 64x64
                nn.ConvTranspose2d(ngf, ngf // 2, 4, 2, 1, bias=False),
                nn.GroupNorm(8, ngf // 2), nn.ReLU(True),
                # 64x64 -> 128x128
                nn.ConvTranspose2d(ngf // 2, 3, 4, 2, 1, bias=False),
                nn.Tanh(),
            )
        self.apply(_init_weights)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        h = self.attn16(self.stem(z))
        if self.mid is not None:
            h = self.attn32(self.mid(h))
        return self.tail(h)


class WGANCritic(nn.Module):
    """SpectralNorm critic with SAGAN self-attention (Phase 2.3 / 2.4).

    Supports image_size in {64, 128}.
    Attention at 16x16 always; additional attention at 32x32 for 128x128.
    """

    def __init__(self, ndf: int = 128, image_size: int = 64):
        super().__init__()
        if image_size not in (64, 128):
            raise ValueError(f"WGANCritic supports image_size 64 or 128, got {image_size}")

        self._image_size = image_size

        if image_size == 64:
            # Head: 64 -> 32 (ndf//2)
            self.head = nn.Sequential(
                _sn(nn.Conv2d(3, ndf // 2, 4, 2, 1, bias=False)),
                nn.LeakyReLU(0.2, True),
            )
            self.attn32 = None
            # 32 -> 16 (ndf)
            self.mid = nn.Sequential(
                _sn(nn.Conv2d(ndf // 2, ndf, 4, 2, 1, bias=False)),
                nn.LeakyReLU(0.2, True),
            )
            attn_ch = ndf
        else:  # 128
            # Head: 128 -> 64 (ndf//4), 64 -> 32 (ndf//2)
            self.head = nn.Sequential(
                _sn(nn.Conv2d(3, ndf // 4, 4, 2, 1, bias=False)),
                nn.LeakyReLU(0.2, True),
                _sn(nn.Conv2d(ndf // 4, ndf // 2, 4, 2, 1, bias=False)),
                nn.LeakyReLU(0.2, True),
            )
            self.attn32 = SelfAttention(ndf // 2)
            # 32 -> 16 (ndf)
            self.mid = nn.Sequential(
                _sn(nn.Conv2d(ndf // 2, ndf, 4, 2, 1, bias=False)),
                nn.LeakyReLU(0.2, True),
            )
            attn_ch = ndf

        self.attn16 = SelfAttention(attn_ch)

        # Tail: 16x16 -> 8x8 -> 4x4 -> score
        self.tail = nn.Sequential(
            _sn(nn.Conv2d(attn_ch, attn_ch * 2, 4, 2, 1, bias=False)),
            nn.LeakyReLU(0.2, True),
            _sn(nn.Conv2d(attn_ch * 2, attn_ch * 4, 4, 2, 1, bias=False)),
            nn.LeakyReLU(0.2, True),
            _sn(nn.Conv2d(attn_ch * 4, 1, 4, 1, 0, bias=False)),
        )
        self.apply(_init_weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.head(x)
        if self.attn32 is not None:
            h = self.attn32(h)
        h = self.attn16(self.mid(h))
        return self.tail(h).view(x.size(0))


def _build_basic(cfg: dict):
    """Factory for the Phase 2.1/2.2 baseline pair (cfg keys are optional)."""
    return (
        WGANBasicGenerator(latent_dim=cfg.get("latent_dim", 128), ngf=cfg.get("ngf", 64)),
        WGANBasicCritic(ndf=cfg.get("ndf", 64)),
    )


def _build(cfg: dict):
    """Factory for the advanced SN+attention pair; image_size selects 64/128 variants."""
    image_size = cfg.get("image_size", 64)
    return (
        WGANGenerator(latent_dim=cfg.get("latent_dim", 128), ngf=cfg.get("ngf", 128), image_size=image_size),
        WGANCritic(ndf=cfg.get("ndf", 128), image_size=image_size),
    )


register("wgan_basic", _build_basic, kind="wgan")
register("wgan", _build, kind="wgan")
(2020) linear schedule.""" + return torch.linspace(beta_start, beta_end, T) + + +def cosine_betas(T: int, s: float = 0.008) -> torch.Tensor: + """Nichol & Dhariwal (2021) cosine schedule — avoids over-denoising at low t.""" + t = torch.linspace(0, T, T + 1) + f = torch.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2 + alpha_bar = f / f[0] + betas = 1 - alpha_bar[1:] / alpha_bar[:-1] + return betas.clamp(max=0.999) + + +def make_alpha_bars(betas: torch.Tensor) -> torch.Tensor: + """Cumulative product of (1 − β), shape (T,).""" + return (1.0 - betas).cumprod(0) + + +# ── Forward process ────────────────────────────────────────────────────────── + +def q_sample( + x0: torch.Tensor, + t: torch.Tensor, + alpha_bars: torch.Tensor, + noise: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Add noise to x0 at timestep t. Returns (x_t, noise).""" + if noise is None: + noise = torch.randn_like(x0) + ab = alpha_bars[t].to(x0.device)[:, None, None, None] + x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * noise + return x_t, noise + + +# ── Training loss ──────────────────────────────────────────────────────────── + +def diffusion_loss( + model, + x0: torch.Tensor, + t: torch.Tensor, + alpha_bars: torch.Tensor, + pred_type: str = "eps", +) -> torch.Tensor: + """MSE on the model's prediction vs the true target. + + pred_type="eps" → target is the added noise ε (Ho et al.) 
+ pred_type="v" → target is v = √ᾱ·ε − √(1−ᾱ)·x0 (Salimans & Ho) + """ + x_t, noise = q_sample(x0, t, alpha_bars) + pred = model(x_t, t) + + if pred_type == "eps": + target = noise + else: # v + ab = alpha_bars[t].to(x0.device)[:, None, None, None] + target = ab.sqrt() * noise - (1 - ab).sqrt() * x0 + + return F.mse_loss(pred, target) + + +# ── DDIM deterministic sampling ─────────────────────────────────────────────── + +@torch.no_grad() +def ddim_sample( + model, + n: int, + image_size: int, + alpha_bars: torch.Tensor, + n_steps: int = 100, + pred_type: str = "eps", + device: str = "cuda", + batch_size: int = 32, +) -> torch.Tensor: + """Generate n images via DDIM (eta=0, deterministic). + + Batches internally to avoid OOM when n is large. + Returns tensor shape (n, 3, image_size, image_size) in [-1, 1]. + """ + model.eval() + T = len(alpha_bars) + + # Build reversed subsequence: [T-1, T-1-step, ..., 0] + step = max(T // n_steps, 1) + ts = list(range(T - 1, -1, -step))[:n_steps] + if ts[-1] != 0: + ts.append(0) + + results = [] + remaining = n + while remaining > 0: + bsz = min(batch_size, remaining) + x = torch.randn(bsz, 3, image_size, image_size, device=device) + + for i, t_cur in enumerate(ts): + t_prev = ts[i + 1] if i + 1 < len(ts) else -1 + + t_batch = torch.full((bsz,), t_cur, device=device, dtype=torch.long) + ab_t = alpha_bars[t_cur].to(device) + ab_prev = alpha_bars[t_prev].to(device) if t_prev >= 0 else torch.ones(1, device=device) + + pred = model(x, t_batch) + + # Reconstruct x0 from prediction + if pred_type == "eps": + x0_hat = (x - (1 - ab_t).sqrt() * pred) / ab_t.sqrt() + else: # v + x0_hat = ab_t.sqrt() * x - (1 - ab_t).sqrt() * pred + x0_hat = x0_hat.clamp(-1, 1) + + # DDIM step + eps_hat = (x - ab_t.sqrt() * x0_hat) / (1 - ab_t).sqrt() + x = ab_prev.sqrt() * x0_hat + (1 - ab_prev).sqrt() * eps_hat + + results.append(x.cpu()) + remaining -= bsz + + return torch.cat(results)[:n] diff --git a/generator/src/training/metrics.py 
def compute_is(
    imgs: torch.Tensor,
    device: str = "cuda",
    batch_size: int = 64,
) -> tuple[float, float]:
    """Inception Score (mean, std) over the metric's default 10 splits.

    imgs: (N, 3, H, W) in [-1, 1]. N >= 2048 recommended for a reliable
    estimate. Returns (is_mean, is_std).
    """
    metric = InceptionScore(normalize=True).to(device)
    imgs_01 = imgs.clamp(-1, 1) * 0.5 + 0.5  # metric expects [0, 1]
    for i in range(0, len(imgs_01), batch_size):
        metric.update(imgs_01[i : i + batch_size].to(device))
    mean, std = metric.compute()
    return float(mean), float(std)


def compute_lpips_diversity(
    imgs: torch.Tensor,
    n_pairs: int = 200,
    device: str = "cuda",
    batch_size: int = 16,
) -> float:
    """Average pairwise LPIPS distance — higher means more diverse samples.

    imgs: (N, 3, H, W) in [-1, 1]. Samples n_pairs random (i, j) pairs with
    i != j.

    Fix: the previous randperm-based sampling
    (``randperm(N*2)[:n_pairs*2].view(n_pairs, 2)``) raised a view error
    whenever n_pairs > N; sampling indices with randint works for any N >= 2.
    """
    metric = LearnedPerceptualImagePatchSimilarity(net_type="alex", normalize=True).to(device)
    imgs_01 = imgs.clamp(-1, 1) * 0.5 + 0.5
    N = len(imgs_01)

    # Draw pairs uniformly, then shift any i == j duplicate to guarantee i != j.
    idx = torch.randint(0, N, (n_pairs, 2))
    same = idx[:, 0] == idx[:, 1]
    idx[same, 1] = (idx[same, 1] + 1) % N

    for start in range(0, n_pairs, batch_size):
        end = min(start + batch_size, n_pairs)
        i_batch = idx[start:end, 0]
        j_batch = idx[start:end, 1]
        metric.update(imgs_01[i_batch].to(device), imgs_01[j_batch].to(device))

    return float(metric.compute())
class PerceptualLoss(nn.Module):
    """L1 feature-matching loss at three VGG-16 layers (Phase 3.2 / 3.3).

    VGG-16 feature indices:
        relu1_2: features[:4]   (before first maxpool)
        relu2_2: features[4:9]  (before second maxpool)
        relu3_3: features[9:16] (before third maxpool)

    Input convention: images in [-1, 1] — converted internally to [0, 1] and
    ImageNet-normalised before the VGG forward. VGG weights are frozen.
    """

    def __init__(self):
        super().__init__()
        vgg = tv_models.vgg16(weights=tv_models.VGG16_Weights.IMAGENET1K_V1)
        feats = vgg.features
        self.slice1 = nn.Sequential(*list(feats[:4]))    # relu1_2
        self.slice2 = nn.Sequential(*list(feats[4:9]))   # relu2_2
        self.slice3 = nn.Sequential(*list(feats[9:16]))  # relu3_3

        # Freeze all VGG weights — this module is a fixed feature extractor.
        for p in self.parameters():
            p.requires_grad_(False)

        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        )

    def _normalise(self, x: torch.Tensor) -> torch.Tensor:
        """Convert [-1, 1] input to ImageNet-normalised tensors."""
        x = x * 0.5 + 0.5  # -> [0, 1]
        return (x - self.mean) / self.std

    def forward(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
        """L1 feature distance. real gradients are stopped — only fake trains."""
        f = self._normalise(fake)
        r = self._normalise(real)

        loss = torch.tensor(0.0, device=fake.device)
        # Feed both images through the slices cumulatively, accumulating the
        # L1 distance at each of the three tap points.
        for layer in (self.slice1, self.slice2, self.slice3):
            f = layer(f)
            r = layer(r)
            loss = loss + F.l1_loss(f, r.detach())
        return loss
def _gradient_penalty(critic, real: torch.Tensor, fake: torch.Tensor, device) -> torch.Tensor:
    """Two-sided gradient penalty (Gulrajani et al., 2017).

    Interpolates uniformly between real and fake samples and penalises the
    squared deviation of the critic's gradient norm from 1.
    """
    bsz = real.size(0)
    eps = torch.rand(bsz, 1, 1, 1, device=device)
    interp = (eps * real + (1.0 - eps) * fake).requires_grad_(True)
    d_interp = critic(interp)
    grad = torch.autograd.grad(
        outputs=d_interp,
        inputs=interp,
        grad_outputs=torch.ones_like(d_interp),
        create_graph=True,   # needed so the penalty itself is differentiable
        retain_graph=True,
    )[0]
    return ((grad.norm(2, dim=[1, 2, 3]) - 1) ** 2).mean()


def train_wgan(
    generator,
    critic,
    train_dataset,
    cfg: dict,
    *,
    save_dir,
    run_name: str,
    device: str = "cuda",
) -> dict:
    """WGAN-GP training loop (Gulrajani et al., 2017).

    Used for Phase 2.2-2.4. Gradient penalty replaces weight clipping.
    The critic runs in float32 to keep GP gradient computation numerically
    stable; AMP is used only for the generator forward/backward.

    Returns a history dict with per-epoch losses, FID checkpoints and total
    wall-clock training time.
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    generator = generator.to(device)
    critic = critic.to(device)

    n_g = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    n_c = sum(p.numel() for p in critic.parameters() if p.requires_grad)
    print(f"Generator: {n_g:,} params Critic: {n_c:,} params")

    epochs = cfg["epochs"]
    batch_size = cfg["batch_size"]
    lr_g = cfg.get("lr_g", 1e-4)
    lr_d = cfg.get("lr_d", 1e-4)
    beta1 = cfg.get("beta1", 0.0)
    beta2 = cfg.get("beta2", 0.9)
    latent_dim = cfg.get("latent_dim", 128)
    n_critic = cfg.get("n_critic", 5)
    gp_lambda = cfg.get("gp_lambda", 10)
    ema_decay = cfg.get("ema_decay", 0.9999)
    sample_interval = cfg.get("sample_interval", 10)
    fid_interval = cfg.get("fid_interval", 25)
    fid_n_real = cfg.get("fid_n_real", 5000)

    loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=min(4, os.cpu_count() or 1),
        pin_memory=(device.type == "cuda"), drop_last=True,
    )
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
    opt_c = torch.optim.Adam(critic.parameters(), lr=lr_d, betas=(beta1, beta2))

    use_amp = device.type == "cuda"
    scaler_g = _GradScaler("cuda", enabled=use_amp)

    ema = EMA(generator, decay=ema_decay)

    # Fixed noise for consistent sample tracking across epochs
    fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device)

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    samples_dir = save_dir.parent / "samples" / run_name

    fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))

    history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}}
    best_fid = float("inf")
    print(f"Device: {device} AMP (G only): {use_amp} Batches/epoch: {len(loader)} n_critic: {n_critic}")

    # Linear LR decay from epoch epochs//2 to epochs.
    # Denominator guarded with max(..., 1) for consistency with the other
    # trainers (avoids a zero denominator for degenerate epoch counts).
    decay_start = epochs // 2
    sched_g = torch.optim.lr_scheduler.LambdaLR(
        opt_g, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
    sched_c = torch.optim.lr_scheduler.LambdaLR(
        opt_c, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))

    t_start = time.time()

    for epoch in range(1, epochs + 1):
        generator.train()
        critic.train()
        g_sum = w_sum = gp_sum = real_sum = fake_sum = 0.0
        n_c_steps = n_g_steps = 0

        for batch_idx, real in enumerate(tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False)):
            real = real.to(device)
            bsz = real.size(0)

            # ── Critic step (every batch) ───────────────────────────────
            # Run the critic in full float32 precision — the gradient
            # penalty needs stable gradients, and AMP/half precision can
            # degrade them.
            opt_c.zero_grad()
            with torch.no_grad():
                fake = generator(torch.randn(bsz, latent_dim, 1, 1, device=device))

            real_f32 = real.float()
            fake_f32 = fake.float().detach()

            d_real = critic(real_f32)
            d_fake = critic(fake_f32)
            # fake_f32 is already detached above — no second detach needed.
            gp = _gradient_penalty(critic, real_f32, fake_f32, device)
            c_loss = d_fake.mean() - d_real.mean() + gp_lambda * gp
            c_loss.backward()
            opt_c.step()

            w_dist = (d_real.mean() - d_fake.mean()).item()
            w_sum += w_dist
            gp_sum += gp.item()
            real_sum += d_real.mean().item()
            fake_sum += d_fake.mean().item()
            n_c_steps += 1

            # ── Generator step (every n_critic batches) ─────────────────
            if (batch_idx + 1) % n_critic == 0:
                opt_g.zero_grad()
                with _autocast("cuda", enabled=use_amp):
                    fake = generator(torch.randn(bsz, latent_dim, 1, 1, device=device))
                    g_loss = -critic(fake.float()).mean()
                scaler_g.scale(g_loss).backward()
                scaler_g.step(opt_g)
                scaler_g.update()
                ema.update(generator)
                g_sum += g_loss.item()
                n_g_steps += 1

        avg_w = w_sum / max(n_c_steps, 1)
        avg_gp = gp_sum / max(n_c_steps, 1)
        avg_g = g_sum / max(n_g_steps, 1)
        avg_r = real_sum / max(n_c_steps, 1)
        avg_f = fake_sum / max(n_c_steps, 1)
        history["g_loss"].append(avg_g)
        history["w_dist"].append(avg_w)
        history["gp"].append(avg_gp)
        history["d_real"].append(avg_r)
        history["d_fake"].append(avg_f)
        print(
            f"[{epoch:03d}/{epochs}] "
            f"G: {avg_g:.4f} W-dist: {avg_w:.4f} GP: {avg_gp:.4f} "
            f"C(real): {avg_r:.4f} C(fake): {avg_f:.4f}"
        )

        if epoch % sample_interval == 0:
            _save_samples(ema, samples_dir, epoch, fixed_noise=fixed_noise, device=device)

        if epoch % fid_interval == 0:
            # FID is evaluated on the EMA generator — the model we ship.
            ema.model.eval()
            with torch.no_grad():
                fake_imgs = torch.cat([
                    ema.model(torch.randn(64, latent_dim, 1, 1, device=device))
                    for _ in range(fid_n_real // 64 + 1)
                ])[:fid_n_real]
            fid_score = fid_eval.compute(fake_imgs)
            history["fid"][epoch] = fid_score
            print(f" FID @ epoch {epoch}: {fid_score:.2f}")

            if fid_score < best_fid:
                best_fid = fid_score
                torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
                torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")

        sched_g.step()
        sched_c.step()

    torch.save(generator.state_dict(), save_dir / f"{run_name}_final_g.pt")
    torch.save(critic.state_dict(), save_dir / f"{run_name}_final_d.pt")
    torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")
    history["train_time_s"] = time.time() - t_start
    return history


# ────────────────────────────────────────────────────────────────────────────
# Phase 3 — VAE (3.1 MSE+KL · 3.2 +Perceptual · 3.3 +PatchGAN)
# ────────────────────────────────────────────────────────────────────────────

def _save_vae_samples(
    vae,
    samples_dir: Path,
    epoch: int,
    *,
    fixed_z: torch.Tensor,
    fixed_real: torch.Tensor,
    device,
) -> None:
    """Save prior samples and a real-vs-reconstruction grid side by side."""
    samples_dir.mkdir(parents=True, exist_ok=True)
    vae.eval()
    with torch.no_grad():
        prior = vae.decode(fixed_z.to(device))
        prior = (prior.clamp(-1, 1) + 1.0) / 2.0
        save_image(prior, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)

        recon, _, _ = vae(fixed_real.to(device))
        recon = (recon.clamp(-1, 1) + 1.0) / 2.0
        real = (fixed_real.to(device) + 1.0) / 2.0
        # Interleave real / reconstruction pairs
        pairs = torch.stack([real, recon], dim=1).flatten(0, 1)
        save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4)
    vae.train()
def train_vae(
    vae,
    train_dataset,
    cfg: dict,
    *,
    save_dir,
    run_name: str,
    device: str = "cuda",
) -> dict:
    """VAE training loop covering Phase 3.1 – 3.3.

    Config toggles:
        lambda_perceptual > 0  → VGG-16 perceptual loss (Phase 3.2+)
        lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3)

    Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl

    Returns a history dict with per-epoch losses, FID checkpoints and total
    wall-clock training time.
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    vae = vae.to(device)

    n_vae = sum(p.numel() for p in vae.parameters() if p.requires_grad)
    print(f"VAE: {n_vae:,} params")

    epochs = cfg["epochs"]
    batch_size = cfg["batch_size"]
    lr = cfg.get("lr", 1e-3)
    latent_dim = cfg.get("latent_dim", 256)
    beta_kl = cfg.get("beta_kl", 1.0)
    lambda_perceptual = cfg.get("lambda_perceptual", 0.0)
    lambda_adversarial = cfg.get("lambda_adversarial", 0.0)
    lr_d = cfg.get("lr_d", 1e-4)
    ema_decay = cfg.get("ema_decay", 0.9999)
    sample_interval = cfg.get("sample_interval", 10)
    fid_interval = cfg.get("fid_interval", 25)
    fid_n_real = cfg.get("fid_n_real", 5000)

    use_perceptual = lambda_perceptual > 0
    use_adversarial = lambda_adversarial > 0

    loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=min(4, os.cpu_count() or 1),
        pin_memory=(device.type == "cuda"), drop_last=True,
    )

    opt_vae = torch.optim.Adam(vae.parameters(), lr=lr)
    use_amp = device.type == "cuda"
    scaler = _GradScaler("cuda", enabled=use_amp)

    # KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training
    kl_warmup_epochs = max(1, epochs // 5)

    # Linear LR decay from epoch epochs//2 to epochs
    decay_start = epochs // 2
    sched_vae = torch.optim.lr_scheduler.LambdaLR(
        opt_vae, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
    sched_d = None  # set below if adversarial

    # ── Optional components ─────────────────────────────────────────────
    perc_fn = None
    patchgan = None
    opt_d = None
    scaler_d = None

    if use_perceptual:
        from src.training.perceptual import PerceptualLoss
        perc_fn = PerceptualLoss().to(device)
        print("Perceptual loss: VGG-16 relu1_2 + relu2_2 + relu3_3")

    if use_adversarial:
        from src.models.patchgan import PatchGANDiscriminator, hinge_d_loss, hinge_g_loss
        patchgan = PatchGANDiscriminator(
            ndf=cfg.get("ndf_patch", 64),
            image_size=cfg.get("image_size", 64),
        ).to(device)
        opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999))
        scaler_d = _GradScaler("cuda", enabled=use_amp)
        sched_d = torch.optim.lr_scheduler.LambdaLR(
            opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
        n_d = sum(p.numel() for p in patchgan.parameters())
        print(f"PatchGAN: {n_d:,} params")
    else:
        hinge_d_loss = hinge_g_loss = None  # satisfy linter, never called

    # ── Fixed seeds for consistent visualisation ────────────────────────
    fixed_z = torch.randn(16, latent_dim, device=device)
    # Grab first 16 real images from the loader for reconstruction tracking
    _it = iter(loader)
    fixed_real = next(_it)[:16].cpu()

    ema = EMA(vae, decay=ema_decay)

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    samples_dir = save_dir.parent / "samples" / run_name

    fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))

    history = {
        "recon_loss": [], "kl_loss": [], "perc_loss": [],
        "adv_g_loss": [], "adv_d_loss": [], "fid": {},
    }
    best_fid = float("inf")
    print(
        f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}"
        f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual} λ_adv={lambda_adversarial}"
    )

    t_start = time.time()

    for epoch in range(1, epochs + 1):
        vae.train()
        if patchgan is not None:
            patchgan.train()

        # KL warmup coefficient depends only on the epoch — hoisted out of
        # the batch loop (it was recomputed per batch, and the post-epoch
        # print referenced it after the loop).
        current_beta = beta_kl * min(1.0, epoch / kl_warmup_epochs)

        recon_sum = kl_sum = perc_sum = adv_g_sum = adv_d_sum = 0.0
        n_batches = 0

        for real in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            real = real.to(device)

            # ── VAE forward ─────────────────────────────────────────────
            with _autocast("cuda", enabled=use_amp):
                recon, mu, log_var = vae(real)
                mse = F.mse_loss(recon, real)
                kl = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(1).mean()
                perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze()
                vae_loss = mse + current_beta * kl + lambda_perceptual * perc

            # ── PatchGAN discriminator step (recon detached) ────────────
            adv_d = real.new_zeros(1).squeeze()
            if use_adversarial:
                opt_d.zero_grad()
                with _autocast("cuda", enabled=use_amp):
                    d_real = patchgan(real)
                    d_fake = patchgan(recon.detach())
                    adv_d = hinge_d_loss(d_real, d_fake)
                scaler_d.scale(adv_d).backward()
                scaler_d.step(opt_d)
                scaler_d.update()

            # ── PatchGAN generator adversarial loss (through the VAE) ───
            adv_g = real.new_zeros(1).squeeze()
            if use_adversarial:
                with _autocast("cuda", enabled=use_amp):
                    adv_g = hinge_g_loss(patchgan(recon))
                vae_loss = vae_loss + lambda_adversarial * adv_g

            # ── VAE backward ────────────────────────────────────────────
            opt_vae.zero_grad()
            scaler.scale(vae_loss).backward()
            scaler.step(opt_vae)
            scaler.update()
            ema.update(vae)

            recon_sum += mse.item()
            kl_sum += kl.item()
            perc_sum += perc.item()
            adv_g_sum += adv_g.item()
            adv_d_sum += adv_d.item()
            n_batches += 1

        avg_r = recon_sum / n_batches
        avg_k = kl_sum / n_batches
        avg_p = perc_sum / n_batches
        avg_g = adv_g_sum / n_batches
        avg_d = adv_d_sum / n_batches
        history["recon_loss"].append(avg_r)
        history["kl_loss"].append(avg_k)
        history["perc_loss"].append(avg_p)
        history["adv_g_loss"].append(avg_g)
        history["adv_d_loss"].append(avg_d)

        print(
            f"[{epoch:03d}/{epochs}] "
            f"MSE: {avg_r:.4f} KL: {avg_k:.2f} β={current_beta:.6f} "
            f"Perc: {avg_p:.4f} AdvG: {avg_g:.4f} AdvD: {avg_d:.4f}"
        )

        if epoch % sample_interval == 0:
            _save_vae_samples(
                ema.model, samples_dir, epoch,
                fixed_z=fixed_z, fixed_real=fixed_real, device=device,
            )

        if epoch % fid_interval == 0:
            # Sample from the prior through the EMA decoder for FID.
            ema.model.eval()
            with torch.no_grad():
                fake_imgs = torch.cat([
                    ema.model.sample(64, device)
                    for _ in range(fid_n_real // 64 + 1)
                ])[:fid_n_real]
            fid_score = fid_eval.compute(fake_imgs)
            history["fid"][epoch] = fid_score
            print(f" FID @ epoch {epoch}: {fid_score:.2f}")

            if fid_score < best_fid:
                best_fid = fid_score
                torch.save(vae.state_dict(), save_dir / f"{run_name}_best_vae.pt")
                torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")

        sched_vae.step()
        if sched_d is not None:
            sched_d.step()

    torch.save(vae.state_dict(), save_dir / f"{run_name}_final_vae.pt")
    torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")
    if patchgan is not None:
        torch.save(patchgan.state_dict(), save_dir / f"{run_name}_final_patchgan.pt")
    history["train_time_s"] = time.time() - t_start
    return history
+ + Config keys: + noise_schedule — "linear" (4.1) or "cosine" (4.2+) + pred_type — "eps" (4.1–4.2) or "v" (4.3+) + T — diffusion timesteps (default 1000) + base_ch / ch_mult / attn_resolutions — U-Net capacity (see unet.py) + ddim_steps — DDIM steps for FID evaluation (default 100) + """ + from src.training.diffusion import ( + linear_betas, cosine_betas, make_alpha_bars, + diffusion_loss, ddim_sample, + ) + + device = torch.device(device if torch.cuda.is_available() else "cpu") + model = model.to(device) + + n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"U-Net: {n_params:,} params") + + epochs = cfg["epochs"] + batch_size = cfg["batch_size"] + lr = cfg.get("lr", 2e-4) + T = cfg.get("T", 1000) + noise_schedule = cfg.get("noise_schedule", "linear") + pred_type = cfg.get("pred_type", "eps") + ddim_steps = cfg.get("ddim_steps", 100) + image_size = cfg.get("image_size", 64) + ema_decay = cfg.get("ema_decay", 0.9999) + sample_interval = cfg.get("sample_interval", 10) + fid_interval = cfg.get("fid_interval", 25) + fid_n_real = cfg.get("fid_n_real", 5000) + + # Build noise schedule and register on device + betas = (cosine_betas(T) if noise_schedule == "cosine" else linear_betas(T)).to(device) + alpha_bars = make_alpha_bars(betas) # on device + + loader = DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, + num_workers=min(4, os.cpu_count() or 1), + pin_memory=(device.type == "cuda"), drop_last=True, + ) + opt = torch.optim.AdamW(model.parameters(), lr=lr) + + use_amp = device.type == "cuda" + scaler = _GradScaler("cuda", enabled=use_amp) + + ema = EMA(model, decay=ema_decay) + + # Fixed noise for sample visualisation (same latents across epochs) + fixed_noise = torch.randn(16, 3, image_size, image_size, device=device) + + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + samples_dir = save_dir.parent / "samples" / run_name + + fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, 
device=str(device)) + + history = {"loss": [], "fid": {}} + best_fid = float("inf") + print( + f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}" + f" T={T} schedule={noise_schedule} pred={pred_type} ddim_steps={ddim_steps}" + ) + + # Linear LR decay from epoch epochs//2 to epochs + decay_start = epochs // 2 + sched = torch.optim.lr_scheduler.LambdaLR( + opt, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1))) + + t_start = time.time() + + for epoch in range(1, epochs + 1): + model.train() + loss_sum = 0.0 + n_batches = 0 + + for x0 in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False): + x0 = x0.to(device) + t = torch.randint(0, T, (x0.size(0),), device=device) + + with _autocast("cuda", enabled=use_amp): + loss = diffusion_loss(model, x0, t, alpha_bars, pred_type) + + opt.zero_grad() + scaler.scale(loss).backward() + scaler.unscale_(opt) + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + scaler.step(opt) + scaler.update() + ema.update(model) + + loss_sum += loss.item() + n_batches += 1 + + avg_loss = loss_sum / n_batches + history["loss"].append(avg_loss) + print(f"[{epoch:03d}/{epochs}] Loss: {avg_loss:.5f}") + + if epoch % sample_interval == 0: + samples_dir.mkdir(parents=True, exist_ok=True) + ema.model.eval() + with torch.no_grad(): + # Quick visualisation: denoise fixed_noise via DDIM + imgs = ddim_sample( + ema.model, 16, image_size, alpha_bars, + n_steps=50, pred_type=pred_type, device=str(device), batch_size=16, + ) + imgs = (imgs.clamp(-1, 1) + 1.0) / 2.0 + save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=4) + + if epoch % fid_interval == 0: + ema.model.eval() + fake_imgs = ddim_sample( + ema.model, fid_n_real, image_size, alpha_bars, + n_steps=ddim_steps, pred_type=pred_type, + device=str(device), batch_size=32, + ) + fid_score = fid_eval.compute(fake_imgs) + history["fid"][epoch] = fid_score + print(f" FID @ epoch {epoch}: {fid_score:.2f}") + + if fid_score < best_fid: 
+ best_fid = fid_score + torch.save(model.state_dict(), save_dir / f"{run_name}_best_unet.pt") + torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt") + + sched.step() + + torch.save(model.state_dict(), save_dir / f"{run_name}_final_unet.pt") + torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt") + history["train_time_s"] = time.time() - t_start return history