Preview of phases 2–5 implementation; needs a full check

This commit is contained in:
Johnny Fernandes
2026-04-30 13:10:33 +01:00
parent 6e32001ebc
commit 7417267117
35 changed files with 3605 additions and 115 deletions
@@ -0,0 +1,10 @@
{
"epochs": 100,
"data_dir": "cropped/generator",
"sources": ["wiki"],
"augment": true,
"image_size": 64,
"sample_interval": 10,
"fid_interval": 25,
"fid_n_real": 5000
}
+12
View File
@@ -0,0 +1,12 @@
{
"extends": "_base_phase2.json",
"run_name": "p2_1_dcgan",
"model": "dcgan",
"latent_dim": 100,
"ngf": 64,
"ndf": 64,
"lr_g": 2e-4,
"lr_d": 2e-4,
"beta1": 0.5,
"beta2": 0.999
}
+14
View File
@@ -0,0 +1,14 @@
{
"extends": "_base_phase2.json",
"run_name": "p2_2_wgan",
"model": "wgan_basic",
"latent_dim": 128,
"ngf": 64,
"ndf": 64,
"lr_g": 1e-4,
"lr_d": 1e-4,
"beta1": 0.0,
"beta2": 0.9,
"n_critic": 2,
"gp_lambda": 10
}
@@ -0,0 +1,15 @@
{
"extends": "_base_phase2.json",
"run_name": "p2_3_wgan_sn_attn",
"model": "wgan",
"image_size": 64,
"latent_dim": 128,
"ngf": 128,
"ndf": 128,
"lr_g": 1e-4,
"lr_d": 1e-4,
"beta1": 0.0,
"beta2": 0.9,
"n_critic": 2,
"gp_lambda": 10
}
@@ -0,0 +1,15 @@
{
"extends": "_base_phase2.json",
"run_name": "p2_4_wgan_sn_attn_128",
"model": "wgan",
"image_size": 128,
"latent_dim": 128,
"ngf": 128,
"ndf": 128,
"lr_g": 1e-4,
"lr_d": 1e-4,
"beta1": 0.0,
"beta2": 0.9,
"n_critic": 2,
"gp_lambda": 10
}
@@ -0,0 +1,13 @@
{
"epochs": 100,
"data_dir": "cropped/generator",
"sources": ["wiki"],
"augment": "hflip",
"image_size": 64,
"model": "vae",
"latent_dim": 256,
"ngf": 64,
"sample_interval": 10,
"fid_interval": 25,
"fid_n_real": 5000
}
+8
View File
@@ -0,0 +1,8 @@
{
"extends": "_base_phase3.json",
"run_name": "p3_1_vae",
"lr": 1e-3,
"beta_kl": 1.0,
"lambda_perceptual": 0.0,
"lambda_adversarial": 0.0
}
@@ -0,0 +1,8 @@
{
"extends": "_base_phase3.json",
"run_name": "p3_2_vae_perceptual",
"lr": 1e-3,
"beta_kl": 0.0001,
"lambda_perceptual": 0.1,
"lambda_adversarial": 0.0
}
@@ -0,0 +1,10 @@
{
"extends": "_base_phase3.json",
"run_name": "p3_3_vae_patchgan",
"lr": 1e-3,
"lr_d": 1e-4,
"beta_kl": 0.0001,
"lambda_perceptual": 0.1,
"lambda_adversarial": 0.1,
"ndf_patch": 64
}
@@ -0,0 +1,19 @@
{
"epochs": 100,
"data_dir": "cropped/generator",
"sources": ["wiki"],
"augment": "hflip",
"image_size": 64,
"model": "ddpm",
"T": 1000,
"ddim_steps": 100,
"lr": 2e-4,
"base_ch": 128,
"ch_mult": [1, 2, 2, 2],
"attn_resolutions": [16, 8],
"num_res_blocks": 2,
"dropout": 0.1,
"sample_interval": 10,
"fid_interval": 25,
"fid_n_real": 5000
}
@@ -0,0 +1,6 @@
{
"extends": "_base_phase4.json",
"run_name": "p4_1_ddpm_linear",
"noise_schedule": "linear",
"pred_type": "eps"
}
@@ -0,0 +1,6 @@
{
"extends": "_base_phase4.json",
"run_name": "p4_2_ddpm_cosine",
"noise_schedule": "cosine",
"pred_type": "eps"
}
@@ -0,0 +1,6 @@
{
"extends": "_base_phase4.json",
"run_name": "p4_3_ddpm_vpred",
"noise_schedule": "cosine",
"pred_type": "v"
}
@@ -0,0 +1,8 @@
{
"extends": "_base_phase4.json",
"run_name": "p4_4_ddpm_wider",
"noise_schedule": "cosine",
"pred_type": "v",
"base_ch": 192,
"attn_resolutions": [32, 16, 8]
}
+22
View File
@@ -0,0 +1,22 @@
{
"run_name": "p5_ddpm",
"model": "ddpm",
"epochs": 200,
"data_dir": "cropped/generator",
"sources": ["wiki"],
"augment": "hflip",
"image_size": 64,
"T": 1000,
"noise_schedule": "cosine",
"pred_type": "v",
"base_ch": 192,
"ch_mult": [1, 2, 2, 2],
"attn_resolutions": [32, 16, 8],
"num_res_blocks": 2,
"dropout": 0.1,
"lr": 2e-4,
"ddim_steps": 100,
"sample_interval": 10,
"fid_interval": 25,
"fid_n_real": 5000
}
+21
View File
@@ -0,0 +1,21 @@
{
"run_name": "p5_gan",
"model": "wgan",
"epochs": 200,
"data_dir": "cropped/generator",
"sources": ["wiki"],
"augment": true,
"image_size": 128,
"latent_dim": 128,
"ngf": 128,
"ndf": 128,
"lr_g": 1e-4,
"lr_d": 1e-4,
"beta1": 0.0,
"beta2": 0.9,
"n_critic": 2,
"gp_lambda": 10,
"sample_interval": 10,
"fid_interval": 25,
"fid_n_real": 5000
}
+20
View File
@@ -0,0 +1,20 @@
{
"run_name": "p5_vae",
"model": "vae",
"epochs": 200,
"data_dir": "cropped/generator",
"sources": ["wiki"],
"augment": "hflip",
"image_size": 64,
"latent_dim": 256,
"ngf": 64,
"lr": 1e-3,
"lr_d": 1e-4,
"beta_kl": 0.0001,
"lambda_perceptual": 0.1,
"lambda_adversarial": 0.1,
"ndf_patch": 64,
"sample_interval": 10,
"fid_interval": 25,
"fid_n_real": 5000
}
+366
View File
@@ -0,0 +1,366 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a0000001",
"metadata": {},
"source": [
"# Phase 2 — GAN Evolution Analysis\n",
"\n",
"Traces the GAN improvement story, each step motivated by the failure of the previous:\n",
"\n",
"| Step | Model | Key change | Expected failure |\n",
"|------|-------|------------|------------------|\n",
"| 2.1 | DCGAN 64×64 | Baseline on best pipeline | Mode collapse, training instability |\n",
"| 2.2 | WGAN-GP | Wasserstein loss + GP | Texture artifacts, limited coherence |\n",
"| 2.3 | WGAN-GP + SN + GroupNorm + Attn | Principled Lipschitz + long-range deps | Possible underfitting at 64×64 |\n",
"| 2.4 | 2.3 @ 128×128 | Scale resolution | ? |\n",
"\n",
"All runs use the best pipeline from Phase 1: MTCNN-aligned crops, H-flip + rotation + colour jitter, aligned-only dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000002",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n",
"\n",
"OUTPUTS = Path(\"../outputs\")\n",
"LOGS = OUTPUTS / \"logs\"\n",
"SAMPLES = OUTPUTS / \"samples\""
]
},
{
"cell_type": "markdown",
"id": "a0000003",
"metadata": {},
"source": [
"## 1. Load experiment logs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000004",
"metadata": {},
"outputs": [],
"source": [
"run_names = [\"p2_1_dcgan\", \"p2_2_wgan\", \"p2_3_wgan_sn_attn\", \"p2_4_wgan_sn_attn_128\"]\n",
"run_labels = {\n",
" \"p2_1_dcgan\": \"2.1 DCGAN\",\n",
" \"p2_2_wgan\": \"2.2 WGAN-GP\",\n",
" \"p2_3_wgan_sn_attn\": \"2.3 WGAN-GP+SN+Attn\",\n",
" \"p2_4_wgan_sn_attn_128\": \"2.4 +128×128\",\n",
"}\n",
"\n",
"runs = {}\n",
"for name in run_names:\n",
" log_path = LOGS / f\"{name}.json\"\n",
" if log_path.exists():\n",
" with open(log_path) as f:\n",
" runs[name] = json.load(f)\n",
" else:\n",
" print(f\" Missing: {log_path}\")\n",
"\n",
"print(f\"Loaded {len(runs)}/{len(run_names)} experiments:\")\n",
"for name in run_names:\n",
" status = \"✓\" if name in runs else \"✗\"\n",
" print(f\" {status} {name}\")"
]
},
{
"cell_type": "markdown",
"id": "a0000005",
"metadata": {},
"source": [
"## 2. FID Comparison Table"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000006",
"metadata": {},
"outputs": [],
"source": [
"def get_fid(run, epoch):\n",
" fid = run[\"history\"][\"fid\"]\n",
" return fid.get(str(epoch), fid.get(epoch, None))\n",
"\n",
"rows = []\n",
"for name in run_names:\n",
" if name not in runs:\n",
" continue\n",
" r = runs[name]\n",
" cfg = r[\"config\"]\n",
" h = r[\"history\"]\n",
" rows.append({\n",
" \"Step\": run_labels[name],\n",
" \"Model\": cfg.get(\"model\"),\n",
" \"Size\": f\"{cfg.get('image_size', 64)}×{cfg.get('image_size', 64)}\",\n",
" \"FID@25\": get_fid(r, 25),\n",
" \"FID@50\": get_fid(r, 50),\n",
" \"FID@75\": get_fid(r, 75),\n",
" \"FID@100\": get_fid(r, 100),\n",
" })\n",
"\n",
"df = pd.DataFrame(rows)\n",
"df.style.format({c: \"{:.1f}\" for c in [\"FID@25\", \"FID@50\", \"FID@75\", \"FID@100\"] if c in df}, na_rep=\"?\")"
]
},
{
"cell_type": "markdown",
"id": "a0000007",
"metadata": {},
"source": [
"## 3. FID Curves — Evolution Story"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000008",
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(11, 5))\n",
"colors = [\"#5B8DB8\", \"#E8705A\", \"#6ABF69\", \"#B86FB8\"]\n",
"\n",
"for i, name in enumerate(run_names):\n",
" if name not in runs:\n",
" continue\n",
" fid_dict = runs[name][\"history\"][\"fid\"]\n",
" epochs = sorted(int(k) for k in fid_dict)\n",
" fids = [fid_dict[str(e)] for e in epochs]\n",
"    fid100 = fid_dict.get(\"100\")\n",
"    label = f\"{run_labels[name]} (FID@100={fid100:.1f})\" if isinstance(fid100, (int, float)) else run_labels[name]\n",
"    ax.plot(epochs, fids, \"o-\", label=label,\n",
"            color=colors[i], linewidth=2, markersize=8)\n",
"\n",
"ax.set_xlabel(\"Epoch\")\n",
"ax.set_ylabel(\"FID (lower is better)\")\n",
"ax.set_title(\"Phase 2 — FID Curves: DCGAN → WGAN-GP → +SN+Attn → 128×128\")\n",
"ax.legend()\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "a0000009",
"metadata": {},
"source": [
"## 4. Training Dynamics"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000010",
"metadata": {},
"outputs": [],
"source": [
"# Separate DCGAN (has g_loss/d_loss/d_real/d_fake) from WGAN (has g_loss/w_dist/gp)\n",
"dcgan_names = [n for n in run_names if n in runs and runs[n][\"config\"].get(\"model\") == \"dcgan\"]\n",
"wgan_names = [n for n in run_names if n in runs and runs[n][\"config\"].get(\"model\") != \"dcgan\"]\n",
"\n",
"if dcgan_names:\n",
" fig, axes = plt.subplots(1, 2, figsize=(13, 4))\n",
" for name in dcgan_names:\n",
" h = runs[name][\"history\"]\n",
" epochs = range(1, len(h[\"g_loss\"]) + 1)\n",
" axes[0].plot(epochs, h[\"g_loss\"], label=run_labels[name], linewidth=1.2)\n",
" axes[1].plot(epochs, h[\"d_loss\"], label=run_labels[name], linewidth=1.2)\n",
" axes[0].set_title(\"DCGAN — Generator Loss (BCE)\")\n",
" axes[1].set_title(\"DCGAN — Discriminator Loss\")\n",
" for ax in axes:\n",
" ax.set_xlabel(\"Epoch\"); ax.set_ylabel(\"Loss\"); ax.legend(fontsize=8)\n",
" plt.suptitle(\"Phase 2.1 — DCGAN Training Dynamics\", fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
"if wgan_names:\n",
" fig, axes = plt.subplots(1, 3, figsize=(16, 4))\n",
" cmap = plt.cm.Set1\n",
" for i, name in enumerate(wgan_names):\n",
" h = runs[name][\"history\"]\n",
" epochs = range(1, len(h[\"g_loss\"]) + 1)\n",
" c = cmap(i / max(len(wgan_names), 1))\n",
" axes[0].plot(epochs, h[\"g_loss\"], label=run_labels[name], color=c, linewidth=1.2)\n",
" axes[1].plot(epochs, h[\"w_dist\"], label=run_labels[name], color=c, linewidth=1.2)\n",
" axes[2].plot(epochs, h[\"gp\"], label=run_labels[name], color=c, linewidth=1.2)\n",
" axes[0].set_title(\"Generator Loss (E[D(G(z))])\")\n",
" axes[1].set_title(\"Wasserstein Distance Est. (↑ better)\")\n",
" axes[2].set_title(\"Gradient Penalty\")\n",
" for ax in axes:\n",
" ax.set_xlabel(\"Epoch\"); ax.legend(fontsize=8)\n",
"    plt.suptitle(\"Phase 2.2–2.4 — WGAN-GP Training Dynamics\", fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "a0000011",
"metadata": {},
"source": [
"## 5. Sample Image Grids — Epoch 100"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000012",
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(1, 4, figsize=(18, 5))\n",
"\n",
"for idx, name in enumerate(run_names):\n",
" ax = axes[idx]\n",
" img_path = SAMPLES / name / \"epoch_0100.png\"\n",
" if img_path.exists():\n",
" fid = get_fid(runs[name], 100) if name in runs else None\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" ax.set_title(f\"{run_labels[name]}\\nFID@100={fid:.1f}\" if fid else run_labels[name], fontsize=8)\n",
" else:\n",
" ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
" ax.set_title(run_labels[name], fontsize=8)\n",
" ax.axis(\"off\")\n",
"\n",
"fig.suptitle(\"Phase 2 — Epoch 100 Sample Grids (4×4)\", fontsize=13, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "a0000013",
"metadata": {},
"source": [
"## 6. Progression: Epoch 10 → 50 → 100"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000014",
"metadata": {},
"outputs": [],
"source": [
"check_epochs = [10, 50, 100]\n",
"\n",
"for name in run_names:\n",
" if name not in runs:\n",
" continue\n",
" fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))\n",
" for ax, ep in zip(axes, check_epochs):\n",
" img_path = SAMPLES / name / f\"epoch_{ep:04d}.png\"\n",
" if img_path.exists():\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" fid = get_fid(runs[name], ep)\n",
" ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid else \"\"), fontsize=9)\n",
" else:\n",
" ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
" ax.axis(\"off\")\n",
" fig.suptitle(f\"{run_labels[name]} — Training Progression\", fontsize=11, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "a0000015",
"metadata": {},
"source": [
"## 7. Step-by-step Pairwise Comparisons"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000016",
"metadata": {},
"outputs": [],
"source": [
"transitions = [\n",
" (\"2.1→2.2: BCE→Wasserstein\", \"p2_1_dcgan\", \"p2_2_wgan\"),\n",
" (\"2.2→2.3: +SN+GroupNorm+Attn\", \"p2_2_wgan\", \"p2_3_wgan_sn_attn\"),\n",
" (\"2.3→2.4: 64→128 resolution\", \"p2_3_wgan_sn_attn\", \"p2_4_wgan_sn_attn_128\"),\n",
"]\n",
"\n",
"for title, name_a, name_b in transitions:\n",
" fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n",
" for ax, name in zip(axes, [name_a, name_b]):\n",
" img_path = SAMPLES / name / \"epoch_0100.png\"\n",
" if img_path.exists():\n",
" fid = get_fid(runs[name], 100) if name in runs else None\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" ax.set_title(f\"{run_labels[name]}\\nFID@100 = {fid:.1f}\" if fid else run_labels[name], fontsize=10)\n",
" else:\n",
" ax.text(0.5, 0.5, \"Pending\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
" ax.set_title(run_labels[name], fontsize=10)\n",
" ax.axis(\"off\")\n",
" fig.suptitle(title, fontsize=12, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "a0000017",
"metadata": {},
"source": [
"## 8. Conclusions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0000018",
"metadata": {},
"outputs": [],
"source": [
"print(\"=\" * 70)\n",
"print(\"PHASE 2 — GAN EVOLUTION SUMMARY\")\n",
"print(\"=\" * 70)\n",
"\n",
"for name in run_names:\n",
" if name not in runs:\n",
" print(f\"\\n {run_labels[name]}: NOT YET RUN\")\n",
" continue\n",
" fid100 = get_fid(runs[name], 100)\n",
" fid50 = get_fid(runs[name], 50)\n",
" print(f\"\\n {run_labels[name]}:\")\n",
" print(f\" FID@50 = {fid50:.1f}\" if fid50 else \" FID@50 = ?\")\n",
" print(f\" FID@100 = {fid100:.1f}\" if fid100 else \" FID@100 = ?\")\n",
"\n",
"print(\"\\n\" + \"=\" * 70)\n",
"print(\"Best model for Phase 3/4 comparison: fill in after runs complete.\")\n",
"print(\"=\" * 70)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+396
View File
@@ -0,0 +1,396 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "b0000001",
"metadata": {},
"source": [
"# Phase 3 — VAE Evolution Analysis\n",
"\n",
"Traces the VAE improvement story — each step motivated by the failure of the previous:\n",
"\n",
"| Step | Model | Key change | Expected failure |\n",
"|------|-------|------------|------------------|\n",
"| 3.1 | Vanilla VAE (MSE+KL) | Baseline | Blurry samples — MSE minimises pixel average |\n",
"| 3.2 | + Perceptual loss (VGG) | Feature-space reconstruction | Residual texture blur |\n",
"| 3.3 | + PatchGAN (VQGAN-lite) | Local texture adversarial | — |\n",
"\n",
"All runs use H-flip-only augmentation and MTCNN-aligned 64×64 crops.\n",
"FID is computed from prior samples (`z ~ N(0, I) → decode`), same metric as GAN.\n",
"Reconstructions are shown separately to diagnose encoder quality."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000002",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n",
"\n",
"OUTPUTS = Path(\"../outputs\")\n",
"LOGS = OUTPUTS / \"logs\"\n",
"SAMPLES = OUTPUTS / \"samples\""
]
},
{
"cell_type": "markdown",
"id": "b0000003",
"metadata": {},
"source": [
"## 1. Load experiment logs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000004",
"metadata": {},
"outputs": [],
"source": [
"run_names = [\"p3_1_vae\", \"p3_2_vae_perceptual\", \"p3_3_vae_patchgan\"]\n",
"run_labels = {\n",
" \"p3_1_vae\": \"3.1 VAE (MSE+KL)\",\n",
" \"p3_2_vae_perceptual\": \"3.2 +Perceptual\",\n",
" \"p3_3_vae_patchgan\": \"3.3 +PatchGAN\",\n",
"}\n",
"\n",
"runs = {}\n",
"for name in run_names:\n",
" log_path = LOGS / f\"{name}.json\"\n",
" if log_path.exists():\n",
" with open(log_path) as f:\n",
" runs[name] = json.load(f)\n",
" else:\n",
" print(f\" Missing: {log_path}\")\n",
"\n",
"print(f\"Loaded {len(runs)}/{len(run_names)} experiments:\")\n",
"for name in run_names:\n",
" status = \"✓\" if name in runs else \"✗\"\n",
" print(f\" {status} {name}\")"
]
},
{
"cell_type": "markdown",
"id": "b0000005",
"metadata": {},
"source": [
"## 2. FID Comparison Table (prior samples)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000006",
"metadata": {},
"outputs": [],
"source": [
"def get_fid(run, epoch):\n",
" fid = run[\"history\"][\"fid\"]\n",
" return fid.get(str(epoch), fid.get(epoch, None))\n",
"\n",
"rows = []\n",
"for name in run_names:\n",
" if name not in runs:\n",
" continue\n",
" r = runs[name]\n",
" cfg = r[\"config\"]\n",
" h = r[\"history\"]\n",
" rows.append({\n",
" \"Step\": run_labels[name],\n",
" \"λ_perc\": cfg.get(\"lambda_perceptual\", 0),\n",
" \"λ_adv\": cfg.get(\"lambda_adversarial\", 0),\n",
" \"β_kl\": cfg.get(\"beta_kl\", 1.0),\n",
" \"FID@25 (prior)\": get_fid(r, 25),\n",
" \"FID@50 (prior)\": get_fid(r, 50),\n",
" \"FID@75 (prior)\": get_fid(r, 75),\n",
" \"FID@100 (prior)\": get_fid(r, 100),\n",
" })\n",
"\n",
"df = pd.DataFrame(rows)\n",
"df.style.format({c: \"{:.1f}\" for c in df.columns if \"FID\" in c}, na_rep=\"?\")"
]
},
{
"cell_type": "markdown",
"id": "b0000007",
"metadata": {},
"source": [
"## 3. FID Curves — Evolution Story"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000008",
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(10, 5))\n",
"colors = [\"#5B8DB8\", \"#E8705A\", \"#6ABF69\"]\n",
"\n",
"for i, name in enumerate(run_names):\n",
" if name not in runs:\n",
" continue\n",
" fid_dict = runs[name][\"history\"][\"fid\"]\n",
" epochs = sorted(int(k) for k in fid_dict)\n",
" fids = [fid_dict[str(e)] for e in epochs]\n",
"    fid100 = fid_dict.get(\"100\")\n",
"    label = f\"{run_labels[name]} (FID@100={fid100:.1f})\" if isinstance(fid100, (int, float)) else run_labels[name]\n",
" ax.plot(epochs, fids, \"o-\", label=label, color=colors[i], linewidth=2, markersize=8)\n",
"\n",
"ax.set_xlabel(\"Epoch\")\n",
"ax.set_ylabel(\"FID (lower is better) — prior samples\")\n",
"ax.set_title(\"Phase 3 — FID Curves: VAE → +Perceptual → +PatchGAN\")\n",
"ax.legend()\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b0000009",
"metadata": {},
"source": [
"## 4. Training Loss Curves"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000010",
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(2, 3, figsize=(16, 9))\n",
"axes = axes.flatten()\n",
"keys = [\"recon_loss\", \"kl_loss\", \"perc_loss\", \"adv_g_loss\", \"adv_d_loss\"]\n",
"titles = [\"MSE Reconstruction\", \"KL Divergence\", \"Perceptual (VGG)\", \"Adv G Loss\", \"Adv D Loss (PatchGAN)\"]\n",
"\n",
"for ax, key, title in zip(axes, keys, titles):\n",
" for i, name in enumerate(run_names):\n",
" if name not in runs:\n",
" continue\n",
" h = runs[name][\"history\"].get(key, [])\n",
" if any(v != 0.0 for v in h):\n",
" ax.plot(range(1, len(h)+1), h, label=run_labels[name],\n",
" color=colors[i], linewidth=1.2, alpha=0.9)\n",
" ax.set_title(title)\n",
" ax.set_xlabel(\"Epoch\")\n",
" ax.legend(fontsize=8)\n",
"\n",
"axes[-1].axis(\"off\") # empty sixth panel\n",
"fig.suptitle(\"Phase 3 — Training Dynamics\", fontsize=13, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b0000011",
"metadata": {},
"source": [
"## 5. Prior Samples — Epoch 100"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000012",
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
"\n",
"for idx, name in enumerate(run_names):\n",
" ax = axes[idx]\n",
" img_path = SAMPLES / name / \"epoch_0100.png\"\n",
" if img_path.exists():\n",
" fid = get_fid(runs[name], 100) if name in runs else None\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" ax.set_title(f\"{run_labels[name]}\\nFID@100={fid:.1f}\" if fid else run_labels[name], fontsize=9)\n",
" else:\n",
" ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
" ax.set_title(run_labels[name], fontsize=9)\n",
" ax.axis(\"off\")\n",
"\n",
"fig.suptitle(\"Phase 3 — Epoch 100 Prior Samples (4×4 grids)\", fontsize=12, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b0000013",
"metadata": {},
"source": [
"## 6. Reconstructions — Epoch 100\n",
"\n",
"Left half = real images, right half = reconstructions (interleaved pairs)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000014",
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
"\n",
"for idx, name in enumerate(run_names):\n",
" ax = axes[idx]\n",
" img_path = SAMPLES / name / \"epoch_0100_recon.png\"\n",
" if img_path.exists():\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" ax.set_title(f\"{run_labels[name]}\\nreal | recon\", fontsize=9)\n",
" else:\n",
" ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
" ax.set_title(run_labels[name], fontsize=9)\n",
" ax.axis(\"off\")\n",
"\n",
"fig.suptitle(\"Phase 3 — Epoch 100 Reconstructions\", fontsize=12, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b0000015",
"metadata": {},
"source": [
"## 7. Step-by-step Pairwise Comparisons"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000016",
"metadata": {},
"outputs": [],
"source": [
"transitions = [\n",
" (\"3.1→3.2: MSE→+Perceptual\", \"p3_1_vae\", \"p3_2_vae_perceptual\"),\n",
" (\"3.2→3.3: +PatchGAN adversarial\", \"p3_2_vae_perceptual\", \"p3_3_vae_patchgan\"),\n",
"]\n",
"\n",
"for title, name_a, name_b in transitions:\n",
" fig, axes = plt.subplots(2, 2, figsize=(12, 10))\n",
" for col, name in enumerate([name_a, name_b]):\n",
" for row, suffix in enumerate([\"\", \"_recon\"]):\n",
" ax = axes[row][col]\n",
" ep = 100\n",
" img_path = SAMPLES / name / f\"epoch_{ep:04d}{suffix}.png\"\n",
" label = run_labels[name]\n",
" kind = \"prior\" if suffix == \"\" else \"recon\"\n",
" if img_path.exists():\n",
" fid = get_fid(runs[name], ep) if (suffix == \"\" and name in runs) else None\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" ax.set_title(f\"{label}\\n{kind}\" + (f\" FID={fid:.1f}\" if fid else \"\"), fontsize=9)\n",
" else:\n",
" ax.text(0.5, 0.5, \"Pending\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
" ax.set_title(f\"{label} ({kind})\", fontsize=9)\n",
" ax.axis(\"off\")\n",
" fig.suptitle(title, fontsize=12, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b0000017",
"metadata": {},
"source": [
"## 8. Progression: Epoch 10 → 50 → 100 (prior samples)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000018",
"metadata": {},
"outputs": [],
"source": [
"check_epochs = [10, 50, 100]\n",
"\n",
"for name in run_names:\n",
" if name not in runs:\n",
" continue\n",
" fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))\n",
" for ax, ep in zip(axes, check_epochs):\n",
" img_path = SAMPLES / name / f\"epoch_{ep:04d}.png\"\n",
" if img_path.exists():\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" fid = get_fid(runs[name], ep)\n",
" ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid else \"\"), fontsize=9)\n",
" else:\n",
" ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\",\n",
" transform=ax.transAxes)\n",
" ax.axis(\"off\")\n",
" fig.suptitle(f\"{run_labels[name]} — Prior Sample Progression\", fontsize=11, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b0000019",
"metadata": {},
"source": [
"## 9. Conclusions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0000020",
"metadata": {},
"outputs": [],
"source": [
"print(\"=\" * 70)\n",
"print(\"PHASE 3 — VAE EVOLUTION SUMMARY\")\n",
"print(\"=\" * 70)\n",
"\n",
"for name in run_names:\n",
" if name not in runs:\n",
" print(f\"\\n {run_labels[name]}: NOT YET RUN\")\n",
" continue\n",
" h = runs[name][\"history\"]\n",
" fid100 = get_fid(runs[name], 100)\n",
" fid50 = get_fid(runs[name], 50)\n",
" mse50 = h[\"recon_loss\"][49] if len(h[\"recon_loss\"]) > 49 else None\n",
" kl50 = h[\"kl_loss\"][49] if len(h[\"kl_loss\"]) > 49 else None\n",
" print(f\"\\n {run_labels[name]}:\")\n",
" print(f\" FID@50 = {fid50:.1f}\" if fid50 else \" FID@50 = ?\")\n",
" print(f\" FID@100 = {fid100:.1f}\" if fid100 else \" FID@100 = ?\")\n",
" print(f\" MSE@50 = {mse50:.4f}\" if mse50 else \" MSE@50 = ?\")\n",
" print(f\" KL@50 = {kl50:.2f}\" if kl50 else \" KL@50 = ?\")\n",
"\n",
"print(\"\\n\" + \"=\" * 70)\n",
"print(\"Best VAE model for Phase 5 cross-family comparison: fill in after runs.\")\n",
"print(\"=\" * 70)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+403
View File
@@ -0,0 +1,403 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "c0000001",
"metadata": {},
"source": [
"# Phase 4 — DDPM Evolution Analysis\n",
"\n",
"Traces the DDPM improvement story:\n",
"\n",
"| Step | Model | Key change | Expected failure |\n",
"|------|-------|------------|------------------|\n",
"| 4.1 | DDPM linear + ε-pred | Baseline | Noise prediction unstable at very low t (linear schedule over-denoises) |\n",
"| 4.2 | + cosine schedule | Less noise wasted at low timesteps | Residual instability from ε parameterisation |\n",
"| 4.3 | + v-prediction | Numerically stable across full trajectory | Possible underfitting at 64×64 |\n",
"| 4.4 | + wider U-Net (192ch) + 32×32 attention | More capacity and longer-range context | — |\n",
"\n",
"FID is computed via DDIM (100 steps, deterministic) from the EMA model.\n",
"H-flip-only augmentation, MTCNN-aligned 64×64 crops, T=1000."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000002",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n",
"\n",
"OUTPUTS = Path(\"../outputs\")\n",
"LOGS = OUTPUTS / \"logs\"\n",
"SAMPLES = OUTPUTS / \"samples\""
]
},
{
"cell_type": "markdown",
"id": "c0000003",
"metadata": {},
"source": [
"## 1. Load experiment logs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000004",
"metadata": {},
"outputs": [],
"source": [
"run_names = [\"p4_1_ddpm_linear\", \"p4_2_ddpm_cosine\", \"p4_3_ddpm_vpred\", \"p4_4_ddpm_wider\"]\n",
"run_labels = {\n",
" \"p4_1_ddpm_linear\": \"4.1 linear + ε\",\n",
" \"p4_2_ddpm_cosine\": \"4.2 cosine + ε\",\n",
" \"p4_3_ddpm_vpred\": \"4.3 cosine + v\",\n",
" \"p4_4_ddpm_wider\": \"4.4 wider + 32×32 attn\",\n",
"}\n",
"\n",
"runs = {}\n",
"for name in run_names:\n",
" log_path = LOGS / f\"{name}.json\"\n",
" if log_path.exists():\n",
" with open(log_path) as f:\n",
" runs[name] = json.load(f)\n",
" else:\n",
" print(f\" Missing: {log_path}\")\n",
"\n",
"print(f\"Loaded {len(runs)}/{len(run_names)} experiments:\")\n",
"for name in run_names:\n",
" print(f\" {'✓' if name in runs else '✗'} {name}\")"
]
},
{
"cell_type": "markdown",
"id": "c0000005",
"metadata": {},
"source": [
"## 2. FID Comparison Table"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000006",
"metadata": {},
"outputs": [],
"source": [
"def get_fid(run, epoch):\n",
" fid = run[\"history\"][\"fid\"]\n",
" return fid.get(str(epoch), fid.get(epoch, None))\n",
"\n",
"rows = []\n",
"for name in run_names:\n",
" if name not in runs:\n",
" continue\n",
" r = runs[name]\n",
" cfg = r[\"config\"]\n",
" rows.append({\n",
" \"Step\": run_labels[name],\n",
" \"Schedule\": cfg.get(\"noise_schedule\"),\n",
" \"Pred\": cfg.get(\"pred_type\"),\n",
" \"base_ch\": cfg.get(\"base_ch\"),\n",
" \"FID@25\": get_fid(r, 25),\n",
" \"FID@50\": get_fid(r, 50),\n",
" \"FID@75\": get_fid(r, 75),\n",
" \"FID@100\": get_fid(r, 100),\n",
" })\n",
"\n",
"df = pd.DataFrame(rows)\n",
"df.style.format({c: \"{:.1f}\" for c in df.columns if \"FID\" in c}, na_rep=\"?\")"
]
},
{
"cell_type": "markdown",
"id": "c0000007",
"metadata": {},
"source": [
"## 3. FID Curves — Evolution Story"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000008",
"metadata": {},
"outputs": [],
"source": [
"colors = [\"#5B8DB8\", \"#E8705A\", \"#6ABF69\", \"#B86FB8\"]\n",
"\n",
"fig, ax = plt.subplots(figsize=(11, 5))\n",
"for i, name in enumerate(run_names):\n",
" if name not in runs:\n",
" continue\n",
" fid_dict = runs[name][\"history\"][\"fid\"]\n",
" epochs = sorted(int(k) for k in fid_dict)\n",
" fids = [fid_dict[str(e)] for e in epochs]\n",
" fid100 = fid_dict.get(\"100\", \"?\")\n",
" label = f\"{run_labels[name]} (FID@100={fid100:.1f})\" if isinstance(fid100, float) else run_labels[name]\n",
" ax.plot(epochs, fids, \"o-\", label=label, color=colors[i], linewidth=2, markersize=8)\n",
"\n",
"ax.set_xlabel(\"Epoch\")\n",
"ax.set_ylabel(\"FID (DDIM 100 steps, lower is better)\")\n",
"ax.set_title(\"Phase 4 — FID Curves: linear·ε → cosine·ε → cosine·v → wider\")\n",
"ax.legend()\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c0000009",
"metadata": {},
"source": [
"## 4. Training Loss Curves (MSE on predicted target)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000010",
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(11, 4))\n",
"for i, name in enumerate(run_names):\n",
" if name not in runs:\n",
" continue\n",
" losses = runs[name][\"history\"][\"loss\"]\n",
" ax.plot(range(1, len(losses)+1), losses, label=run_labels[name],\n",
" color=colors[i], linewidth=1.2, alpha=0.9)\n",
"\n",
"ax.set_xlabel(\"Epoch\")\n",
"ax.set_ylabel(\"MSE (noise / v prediction loss)\")\n",
"ax.set_title(\"Phase 4 — Training Loss\")\n",
"ax.legend(fontsize=9)\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c0000011",
"metadata": {},
"source": [
"## 5. Sample Grids — Epoch 100 (DDIM 50 steps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000012",
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(1, 4, figsize=(20, 5))\n",
"\n",
"for idx, name in enumerate(run_names):\n",
" ax = axes[idx]\n",
" img_path = SAMPLES / name / \"epoch_0100.png\"\n",
" if img_path.exists():\n",
" fid = get_fid(runs[name], 100) if name in runs else None\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" ax.set_title(f\"{run_labels[name]}\\nFID@100={fid:.1f}\" if fid else run_labels[name], fontsize=8)\n",
" else:\n",
" ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
" ax.set_title(run_labels[name], fontsize=8)\n",
" ax.axis(\"off\")\n",
"\n",
"fig.suptitle(\"Phase 4 — Epoch 100 DDIM Samples (4×4 grids)\", fontsize=12, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c0000013",
"metadata": {},
"source": [
"## 6. Step-by-step Pairwise Comparisons"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000014",
"metadata": {},
"outputs": [],
"source": [
"transitions = [\n",
" (\"4.1→4.2: linear→cosine schedule\", \"p4_1_ddpm_linear\", \"p4_2_ddpm_cosine\"),\n",
" (\"4.2→4.3: ε-pred→v-pred\", \"p4_2_ddpm_cosine\", \"p4_3_ddpm_vpred\"),\n",
" (\"4.3→4.4: 128ch→192ch + 32×32 attn\", \"p4_3_ddpm_vpred\", \"p4_4_ddpm_wider\"),\n",
"]\n",
"\n",
"for title, name_a, name_b in transitions:\n",
" fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n",
" for ax, name in zip(axes, [name_a, name_b]):\n",
" img_path = SAMPLES / name / \"epoch_0100.png\"\n",
" if img_path.exists():\n",
" fid = get_fid(runs[name], 100) if name in runs else None\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" ax.set_title(f\"{run_labels[name]}\\nFID@100={fid:.1f}\" if fid else run_labels[name], fontsize=9)\n",
" else:\n",
" ax.text(0.5, 0.5, \"Pending\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
" ax.set_title(run_labels[name], fontsize=9)\n",
" ax.axis(\"off\")\n",
" fig.suptitle(title, fontsize=12, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c0000015",
"metadata": {},
"source": [
"## 7. Progression: Epoch 10 → 50 → 100"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000016",
"metadata": {},
"outputs": [],
"source": [
"check_epochs = [10, 50, 100]\n",
"\n",
"for name in run_names:\n",
" if name not in runs:\n",
" continue\n",
" fig, axes = plt.subplots(1, len(check_epochs), figsize=(12, 4))\n",
" for ax, ep in zip(axes, check_epochs):\n",
" img_path = SAMPLES / name / f\"epoch_{ep:04d}.png\"\n",
" if img_path.exists():\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" fid = get_fid(runs[name], ep)\n",
" ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid else \"\"), fontsize=9)\n",
" else:\n",
" ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\",\n",
" transform=ax.transAxes)\n",
" ax.axis(\"off\")\n",
" fig.suptitle(f\"{run_labels[name]} — Progression\", fontsize=11, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c0000017",
"metadata": {},
"source": [
"## 8. Noise Schedule Visualisation\n",
"\n",
"Illustrates why cosine outperforms linear: the linear schedule allocates many timesteps near t=T where the image is already near-pure noise, wasting model capacity."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000018",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import numpy as np\n",
"\n",
"T = 1000\n",
"t = np.arange(T)\n",
"\n",
"# Linear betas\n",
"betas_lin = np.linspace(1e-4, 0.02, T)\n",
"ab_lin = np.cumprod(1 - betas_lin)\n",
"\n",
"# Cosine betas\n",
"s = 0.008\n",
"f = np.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2\n",
"f = f / f[0]\n",
"betas_cos = np.clip(1 - f[1:] / f[:-1], 0, 0.999)\n",
"ab_cos = np.cumprod(1 - betas_cos)\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(13, 4))\n",
"\n",
"axes[0].plot(t, ab_lin, label=\"linear\", color=\"#5B8DB8\", linewidth=2)\n",
"axes[0].plot(t, ab_cos, label=\"cosine\", color=\"#E8705A\", linewidth=2)\n",
"axes[0].set_xlabel(\"Timestep t\")\n",
"axes[0].set_ylabel(\"ᾱ_t (signal fraction)\")\n",
"axes[0].set_title(\"ᾱ_t vs t — cosine stays informative longer\")\n",
"axes[0].legend()\n",
"\n",
"axes[1].plot(t, np.sqrt(ab_lin), label=\"linear √ᾱ\", color=\"#5B8DB8\", linewidth=2)\n",
"axes[1].plot(t, np.sqrt(ab_cos), label=\"cosine √ᾱ\", color=\"#E8705A\", linewidth=2)\n",
"axes[1].set_xlabel(\"Timestep t\")\n",
"axes[1].set_ylabel(\"√ᾱ_t (signal amplitude)\")\n",
"axes[1].set_title(\"Signal amplitude — cosine is more uniform\")\n",
"axes[1].legend()\n",
"\n",
"plt.suptitle(\"Noise Schedule Comparison\", fontsize=12, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c0000019",
"metadata": {},
"source": [
"## 9. Conclusions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0000020",
"metadata": {},
"outputs": [],
"source": [
"print(\"=\" * 70)\n",
"print(\"PHASE 4 — DDPM EVOLUTION SUMMARY\")\n",
"print(\"=\" * 70)\n",
"\n",
"for name in run_names:\n",
" if name not in runs:\n",
" print(f\"\\n {run_labels[name]}: NOT YET RUN\")\n",
" continue\n",
" fid100 = get_fid(runs[name], 100)\n",
" fid50 = get_fid(runs[name], 50)\n",
" h = runs[name][\"history\"]\n",
" loss50 = h[\"loss\"][49] if len(h[\"loss\"]) > 49 else None\n",
" print(f\"\\n {run_labels[name]}:\")\n",
" print(f\" FID@50 = {fid50:.1f}\" if fid50 else \" FID@50 = ?\")\n",
" print(f\" FID@100 = {fid100:.1f}\" if fid100 else \" FID@100 = ?\")\n",
" print(f\" Loss@50 = {loss50:.5f}\" if loss50 else \" Loss@50 = ?\")\n",
"\n",
"print(\"\\n\" + \"=\" * 70)\n",
"print(\"Best DDPM model for Phase 5 comparison: fill in after runs complete.\")\n",
"print(\"=\" * 70)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+669
View File
@@ -0,0 +1,669 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Phase 5 — Cross-Family Comparison\n",
"\n",
"Best-of-each finalist retrained for **200 epochs** under identical data conditions.\n",
"\n",
"| Family | Config | Resolution | Key design |\n",
"|--------|--------|-----------|------------|\n",
"| GAN | `p5_gan` | 128×128 | WGAN-GP + SpectralNorm + GroupNorm + Self-Attention |\n",
"| VAE | `p5_vae` | 64×64 | Convolutional VAE + VGG perceptual + PatchGAN |\n",
"| DDPM | `p5_ddpm` | 64×64 | Wider U-Net (192ch) + cosine schedule + v-prediction |\n",
"\n",
"> **Resolution note**: GAN runs at 128×128 (best architecture from Phase 2.4), while VAE and DDPM run at 64×64. FID is measured at each model's native resolution against real images at that resolution, so scores are not directly numerically comparable across families — they are indicators of within-family improvement."
],
"id": "d0000001"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import json\n",
"import sys\n",
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"plt.rcParams.update({\"figure.dpi\": 120, \"font.size\": 10})\n",
"\n",
"OUTPUTS = Path(\"../outputs\")\n",
"LOGS = OUTPUTS / \"logs\"\n",
"SAMPLES = OUTPUTS / \"samples\"\n",
"\n",
"# Phase 5 finalists and their best phase-2/3/4 counterparts (100-epoch comparison)\n",
"FAMILIES = {\n",
" \"GAN\": {\"p5\": \"p5_gan\", \"p4ep\": \"p2_4_wgan_sn_attn_128\", \"label\": \"WGAN-GP+SN+Attn\", \"color\": \"#5B8DB8\"},\n",
" \"VAE\": {\"p5\": \"p5_vae\", \"p4ep\": \"p3_3_vae_patchgan\", \"label\": \"VAE+Perc+PatchGAN\",\"color\": \"#E8705A\"},\n",
" \"DDPM\": {\"p5\": \"p5_ddpm\", \"p4ep\": \"p4_4_ddpm_wider\", \"label\": \"DDPM wider 192ch\", \"color\": \"#6ABF69\"},\n",
"}"
],
"execution_count": null,
"outputs": [],
"id": "d0000002"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Load logs"
],
"id": "d0000003"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def load_log(run_name):\n",
"    \"\"\"Load outputs/logs/<run_name>.json, or return None if the run has no log yet.\"\"\"\n",
"    p = LOGS / f\"{run_name}.json\"\n",
"    if p.exists():\n",
"        with open(p) as f:\n",
"            return json.load(f)\n",
"    return None\n",
"\n",
"def get_fid(log, epoch):\n",
"    \"\"\"Return the FID recorded at `epoch`, or None if the log or FID history is missing.\"\"\"\n",
"    if log is None:\n",
"        return None\n",
"    # Tolerate logs written before FID tracking existed (no 'fid' key).\n",
"    fid = log.get(\"history\", {}).get(\"fid\", {})\n",
"    # JSON round-trips dict keys as strings; fall back to int keys for in-memory dicts.\n",
"    return fid.get(str(epoch), fid.get(epoch, None))\n",
"\n",
"logs_p5 = {fam: load_log(info[\"p5\"]) for fam, info in FAMILIES.items()}\n",
"logs_p4 = {fam: load_log(info[\"p4ep\"]) for fam, info in FAMILIES.items()}\n",
"\n",
"for fam in FAMILIES:\n",
"    p5_ok = \"✓\" if logs_p5[fam] else \"✗\"\n",
"    p4_ok = \"✓\" if logs_p4[fam] else \"✗\"\n",
"    print(f\" {fam}: 200ep={p5_ok} 100ep={p4_ok}\")"
],
"execution_count": null,
"outputs": [],
"id": "d0000004"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Quantitative Summary Table"
],
"id": "d0000005"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"rows = []\n",
"for fam, info in FAMILIES.items():\n",
" log = logs_p5[fam]\n",
" train_time = log.get(\"history\", {}).get(\"train_time_s\") if log else None\n",
" train_min = f\"{train_time / 60:.1f}\" if train_time else \"?\"\n",
" rows.append({\n",
" \"Family\": fam,\n",
" \"Model\": info[\"label\"],\n",
" \"Res\": log.get(\"config\", log).get(\"image_size\", \"?\") if log else \"?\",\n",
" \"Params\": log.get(\"n_params\") if log else None,\n",
" \"Train (min)\": train_min,\n",
" \"FID@100\": get_fid(log, 100),\n",
" \"FID@150\": get_fid(log, 150),\n",
" \"FID@200\": get_fid(log, 200),\n",
" # IS and LPIPS filled in by Section 6\n",
" \"IS ↑\": None,\n",
" \"LPIPS ↑\": None,\n",
" })\n",
"\n",
"df = pd.DataFrame(rows).set_index(\"Family\")\n",
"\n",
"\n",
"def fmt_params(v):\n",
" if v is None:\n",
" return \"?\"\n",
" if v >= 1_000_000:\n",
" return f\"{v / 1_000_000:.1f}M\"\n",
" if v >= 1_000:\n",
" return f\"{v / 1_000:.0f}K\"\n",
" return str(v)\n",
"\n",
"\n",
"df_display = df.copy()\n",
"df_display[\"Params\"] = df_display[\"Params\"].apply(fmt_params)\n",
"df_display.style.format({c: \"{:.1f}\" for c in [\"FID@100\", \"FID@150\", \"FID@200\"] if c in df_display})"
],
"execution_count": null,
"outputs": [],
"id": "d0000006"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. FID Curves — All Three Families"
],
"id": "d0000007"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"fig, axes = plt.subplots(1, 2, figsize=(15, 5))\n",
"\n",
"for fam, info in FAMILIES.items():\n",
" c = info[\"color\"]\n",
" # Phase 5 (200ep) — solid\n",
" log = logs_p5[fam]\n",
" if log:\n",
" fid_dict = log[\"history\"][\"fid\"]\n",
" eps = sorted(int(k) for k in fid_dict)\n",
" fids = [fid_dict[str(e)] for e in eps]\n",
" axes[0].plot(eps, fids, \"-o\", color=c, linewidth=2, markersize=6,\n",
" label=f\"{fam} 200ep (FID@200={fid_dict.get('200','?'):.1f})\" if isinstance(fid_dict.get('200'), float) else f\"{fam} 200ep\")\n",
"\n",
" # Phase 4/3/2 best (100ep) — dashed, same colour\n",
" log4 = logs_p4[fam]\n",
" if log4:\n",
" fid4 = log4[\"history\"][\"fid\"]\n",
" eps4 = sorted(int(k) for k in fid4)\n",
" fids4 = [fid4[str(e)] for e in eps4]\n",
" axes[0].plot(eps4, fids4, \"--\", color=c, linewidth=1.2, alpha=0.55,\n",
" label=f\"{fam} 100ep\")\n",
"\n",
"axes[0].set_xlabel(\"Epoch\")\n",
"axes[0].set_ylabel(\"FID (lower is better)\")\n",
"axes[0].set_title(\"FID Curves — Phase 5 (solid) vs best 100-ep (dashed)\")\n",
"axes[0].legend(fontsize=8)\n",
"\n",
"# Bar chart: FID at 100 vs 200 epochs per family\n",
"fams = list(FAMILIES.keys())\n",
"fid100 = [get_fid(logs_p5[f], 100) for f in fams]\n",
"fid200 = [get_fid(logs_p5[f], 200) for f in fams]\n",
"x = np.arange(len(fams))\n",
"w = 0.35\n",
"bars1 = axes[1].bar(x - w/2, [v or 0 for v in fid100], w, label=\"FID@100\", color=[FAMILIES[f][\"color\"] for f in fams], alpha=0.6)\n",
"bars2 = axes[1].bar(x + w/2, [v or 0 for v in fid200], w, label=\"FID@200\", color=[FAMILIES[f][\"color\"] for f in fams], alpha=1.0)\n",
"axes[1].set_xticks(x); axes[1].set_xticklabels(fams)\n",
"axes[1].set_ylabel(\"FID\"); axes[1].set_title(\"FID: 100ep vs 200ep per family\")\n",
"axes[1].legend()\n",
"for bar in list(bars1) + list(bars2):\n",
" h = bar.get_height()\n",
" if h > 0:\n",
" axes[1].text(bar.get_x() + bar.get_width()/2, h + 1, f\"{h:.0f}\", ha=\"center\", va=\"bottom\", fontsize=8)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
],
"execution_count": null,
"outputs": [],
"id": "d0000008"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Sample Grids — Epoch 200"
],
"id": "d0000009"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"fig, axes = plt.subplots(1, 3, figsize=(16, 6))\n",
"\n",
"for idx, (fam, info) in enumerate(FAMILIES.items()):\n",
" ax = axes[idx]\n",
" img_path = SAMPLES / info[\"p5\"] / \"epoch_0200.png\"\n",
" log = logs_p5[fam]\n",
" fid200 = get_fid(log, 200)\n",
" if img_path.exists():\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" title = f\"{fam}: {info['label']}\\n\"\n",
" if fid200:\n",
" res = log[\"config\"].get(\"image_size\", \"?\")\n",
" title += f\"FID@200={fid200:.1f} ({res}×{res})\"\n",
" ax.set_title(title, fontsize=9)\n",
" else:\n",
" ax.text(0.5, 0.5, \"Not yet run\", ha=\"center\", va=\"center\", transform=ax.transAxes, fontsize=12)\n",
" ax.set_title(f\"{fam}: {info['label']}\", fontsize=9)\n",
" ax.axis(\"off\")\n",
"\n",
"fig.suptitle(\"Phase 5 — Epoch 200 Sample Grids (4×4, prior samples)\", fontsize=13, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
],
"execution_count": null,
"outputs": [],
"id": "d0000010"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Training Progression — Epoch 10 → 50 → 100 → 200"
],
"id": "d0000011"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"check_epochs = [10, 50, 100, 200]\n",
"\n",
"for fam, info in FAMILIES.items():\n",
" run = info[\"p5\"]\n",
" log = logs_p5[fam]\n",
" if log is None:\n",
" print(f\"{fam}: not yet run\")\n",
" continue\n",
"\n",
" fig, axes = plt.subplots(1, len(check_epochs), figsize=(16, 4))\n",
" for ax, ep in zip(axes, check_epochs):\n",
" img_path = SAMPLES / run / f\"epoch_{ep:04d}.png\"\n",
" if img_path.exists():\n",
" ax.imshow(mpimg.imread(str(img_path)))\n",
" fid = get_fid(log, ep)\n",
" ax.set_title(f\"Ep {ep}\" + (f\"\\nFID={fid:.1f}\" if fid else \"\"), fontsize=9)\n",
" else:\n",
" ax.text(0.5, 0.5, f\"Ep {ep}\\n(pending)\", ha=\"center\", va=\"center\",\n",
" transform=ax.transAxes)\n",
" ax.axis(\"off\")\n",
" fig.suptitle(f\"{fam} ({info['label']}) — 200-epoch progression\", fontsize=11, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()"
],
"execution_count": null,
"outputs": [],
"id": "d0000012"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Extended Metrics — IS and LPIPS Diversity\n",
"\n",
"Requires loading trained model weights. Run after all phase 5 models have finished.\n",
"Generates 5 000 samples per model and computes IS and LPIPS over 200 random pairs."
],
"id": "d0000013"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import sys\n",
"sys.path.insert(0, \"..\")\n",
"\n",
"import torch\n",
"from src.utils import load_config\n",
"from src.models import get_model\n",
"from src.training.metrics import compute_is, compute_lpips_diversity\n",
"from src.training.diffusion import cosine_betas, linear_betas, make_alpha_bars, ddim_sample\n",
"from src.training.ema import EMA\n",
"\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"N_SAMPLE = 5_000\n",
"\n",
"def load_ema_model(run_name, config_path):\n",
"    \"\"\"Load the best EMA weights for a given phase-5 run.\n",
"\n",
"    Falls back to the final-epoch EMA checkpoint when no 'best' checkpoint\n",
"    exists. Returns (model in eval mode on DEVICE, cfg dict, model kind).\n",
"    \"\"\"\n",
"    cfg = load_config(str(Path(\"../configs/phase5\") / config_path))\n",
"    model_obj, kind = get_model(cfg)\n",
"    ema_path = Path(\"../outputs/models\") / f\"{run_name}_best_ema.pt\"\n",
"    if not ema_path.exists():\n",
"        ema_path = Path(\"../outputs/models\") / f\"{run_name}_final_ema.pt\"\n",
"    if isinstance(model_obj, tuple):\n",
"        model = model_obj[0]  # (generator, critic/discriminator) -> generator\n",
"    else:\n",
"        model = model_obj\n",
"    model.load_state_dict(torch.load(ema_path, map_location=DEVICE))\n",
"    return model.to(DEVICE).eval(), cfg, kind\n",
"\n",
"\n",
"@torch.no_grad()\n",
"def generate_samples(run_name, config_path, n=N_SAMPLE):\n",
"    \"\"\"Generate n images from a trained phase-5 model, batched to bound VRAM use.\"\"\"\n",
"    model, cfg, kind = load_ema_model(run_name, config_path)\n",
"    image_size = cfg.get(\"image_size\", 64)\n",
"\n",
"    if kind == \"wgan\":\n",
"        latent_dim = cfg.get(\"latent_dim\", 128)\n",
"        imgs = torch.cat([\n",
"            model(torch.randn(min(64, n - i), latent_dim, 1, 1, device=DEVICE))\n",
"            for i in range(0, n, 64)\n",
"        ])[:n].cpu()\n",
"\n",
"    elif kind == \"vae\":\n",
"        imgs = torch.cat([\n",
"            model.sample(min(64, n - i), DEVICE)\n",
"            for i in range(0, n, 64)\n",
"        ])[:n].cpu()\n",
"\n",
"    elif kind == \"ddpm\":\n",
"        schedule = cfg.get(\"noise_schedule\", \"cosine\")\n",
"        pred_type = cfg.get(\"pred_type\", \"v\")\n",
"        T = cfg.get(\"T\", 1000)\n",
"        betas = (cosine_betas(T) if schedule == \"cosine\" else linear_betas(T)).to(DEVICE)\n",
"        ab = make_alpha_bars(betas)\n",
"        imgs = ddim_sample(model, n, image_size, ab, n_steps=100,\n",
"                           pred_type=pred_type, device=DEVICE, batch_size=32)\n",
"    else:\n",
"        # Previously an unknown kind fell through and raised NameError on `imgs`.\n",
"        raise ValueError(f\"Unknown model kind: {kind!r}\")\n",
"    return imgs\n",
"\n",
"\n",
"print(\"Run this cell once all phase-5 models have completed.\")\n",
"print(\"Expected to take ~5–10 minutes per family on an RTX 3090.\")"
],
"execution_count": null,
"outputs": [],
"id": "d0000014"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# ── Compute IS and LPIPS per family ──────────────────────────────────────────\n",
"# Uncomment when models are ready.\n",
"\n",
"extended_metrics = {}\n",
"\n",
"for fam, info in FAMILIES.items():\n",
" run = info[\"p5\"]\n",
" config = f\"{run}.json\"\n",
" try:\n",
" print(f\"\\n{fam}: generating {N_SAMPLE} samples...\")\n",
" imgs = generate_samples(run, config)\n",
"\n",
" print(f\" Computing IS...\")\n",
" is_mean, is_std = compute_is(imgs, device=DEVICE)\n",
"\n",
" print(f\" Computing LPIPS diversity...\")\n",
" lpips = compute_lpips_diversity(imgs, n_pairs=200, device=DEVICE)\n",
"\n",
" extended_metrics[fam] = {\"IS_mean\": is_mean, \"IS_std\": is_std, \"LPIPS\": lpips}\n",
" print(f\" IS = {is_mean:.2f} ± {is_std:.2f} LPIPS = {lpips:.4f}\")\n",
"\n",
" except Exception as e:\n",
" print(f\" Skipped ({e})\")\n",
" extended_metrics[fam] = {}\n",
"\n",
"# Merge with FID table\n",
"for fam in FAMILIES:\n",
" em = extended_metrics.get(fam, {})\n",
" idx = list(FAMILIES.keys()).index(fam)\n",
" df.loc[fam, \"IS ↑\"] = f\"{em['IS_mean']:.2f}±{em['IS_std']:.2f}\" if \"IS_mean\" in em else \"—\"\n",
" df.loc[fam, \"LPIPS ↑\"] = f\"{em['LPIPS']:.4f}\" if \"LPIPS\" in em else \"—\"\n",
"\n",
"df.style.format({c: \"{:.1f}\" for c in [\"FID@100\", \"FID@150\", \"FID@200\"] if c in df})"
],
"execution_count": null,
"outputs": [],
"id": "d0000015"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Latent Interpolation — GAN and VAE\n",
"\n",
"Smooth interpolation between two latent codes reveals whether the generator has learned a\n",
"continuous manifold. DDPM has no encoder, so interpolation is done by different noise seeds."
],
"id": "d0000016"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Spherical linear interpolation (slerp)\n",
"def slerp(z1, z2, t):\n",
" z1_n = z1 / z1.norm()\n",
" z2_n = z2 / z2.norm()\n",
" omega = torch.acos((z1_n * z2_n).sum().clamp(-1, 1))\n",
" if omega.abs() < 1e-6:\n",
" return (1 - t) * z1 + t * z2\n",
" return (torch.sin((1-t)*omega)/torch.sin(omega)) * z1 + \\\n",
" (torch.sin(t*omega)/torch.sin(omega)) * z2\n",
"\n",
"\n",
"def gan_interpolation(model, latent_dim, n_steps=10, device=DEVICE):\n",
" z1 = torch.randn(1, latent_dim, 1, 1, device=device)\n",
" z2 = torch.randn(1, latent_dim, 1, 1, device=device)\n",
" alphas = torch.linspace(0, 1, n_steps)\n",
" imgs = []\n",
" with torch.no_grad():\n",
" for a in alphas:\n",
" z = slerp(z1.flatten(), z2.flatten(), a.item()).view_as(z1)\n",
" imgs.append(model(z).cpu())\n",
" return torch.cat(imgs)\n",
"\n",
"\n",
"def vae_interpolation(model, real_imgs, n_steps=10, device=DEVICE):\n",
" \"\"\"Encode two real images, interpolate in latent space, decode.\"\"\"\n",
" img1, img2 = real_imgs[:1].to(device), real_imgs[1:2].to(device)\n",
" with torch.no_grad():\n",
" mu1, _ = model.encode(img1)\n",
" mu2, _ = model.encode(img2)\n",
" alphas = torch.linspace(0, 1, n_steps, device=device)\n",
" imgs = [model.decode((1-a)*mu1 + a*mu2).cpu() for a in alphas]\n",
" return torch.cat(imgs)\n",
"\n",
"\n",
"print(\"Interpolation helpers defined. Run cells below after loading models.\")"
],
"execution_count": null,
"outputs": [],
"id": "d0000017"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# ── GAN interpolation ─────────────────────────────────────────────────────────\n",
"try:\n",
" gan_model, gan_cfg, _ = load_ema_model(\"p5_gan\", \"p5_gan.json\")\n",
" latent_dim = gan_cfg.get(\"latent_dim\", 128)\n",
" interp_imgs = gan_interpolation(gan_model, latent_dim, n_steps=10)\n",
" interp_imgs = (interp_imgs.clamp(-1, 1) + 1) / 2\n",
"\n",
" fig, axes = plt.subplots(1, 10, figsize=(20, 2.5))\n",
" for ax, img in zip(axes, interp_imgs):\n",
" ax.imshow(img.permute(1, 2, 0).numpy())\n",
" ax.axis(\"off\")\n",
" fig.suptitle(\"GAN — Slerp latent interpolation (z₁ → z₂)\", fontsize=11, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()\n",
"except Exception as e:\n",
" print(f\"GAN interpolation: {e}\")"
],
"execution_count": null,
"outputs": [],
"id": "d0000018"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# ── VAE interpolation ─────────────────────────────────────────────────────────\n",
"try:\n",
" from src.data import GeneratorDataset, get_transform\n",
"\n",
" vae_model, vae_cfg, _ = load_ema_model(\"p5_vae\", \"p5_vae.json\")\n",
" ds = GeneratorDataset(\"../../\" + vae_cfg[\"data_dir\"],\n",
" sources=vae_cfg.get(\"sources\", [\"wiki\"]),\n",
" transform=get_transform(vae_cfg[\"image_size\"], augment=False))\n",
" sample_real = torch.stack([ds[i] for i in range(2)])\n",
"\n",
" interp_imgs = vae_interpolation(vae_model, sample_real, n_steps=10)\n",
" interp_imgs = (interp_imgs.clamp(-1, 1) + 1) / 2\n",
"\n",
" fig, axes = plt.subplots(1, 10, figsize=(20, 2.5))\n",
" for ax, img in zip(axes, interp_imgs):\n",
" ax.imshow(img.permute(1, 2, 0).numpy())\n",
" ax.axis(\"off\")\n",
" fig.suptitle(\"VAE — μ-space linear interpolation (image₁ → image₂)\", fontsize=11, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.show()\n",
"except Exception as e:\n",
" print(f\"VAE interpolation: {e}\")"
],
"execution_count": null,
"outputs": [],
"id": "d0000019"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. Failure Mode Analysis\n",
"\n",
"Identify the worst-generated images per model: those farthest from their nearest real neighbour in pixel space, or simply those with highest reconstruction error (VAE) or highest DDPM loss."
],
"id": "d0000020"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# For GAN and DDPM: generate 256 images, pick the 8 with lowest mean activation\n",
"# (a proxy for less-coherent images — very rough)\n",
"\n",
"def worst_samples(imgs, n=8):\n",
" \"\"\"Heuristic: pick samples with lowest per-pixel mean (often darker / less structured).\"\"\"\n",
" scores = imgs.mean(dim=[1, 2, 3]) # mean brightness per image\n",
" worst_idx = scores.argsort()[:n]\n",
" return imgs[worst_idx]\n",
"\n",
"print(\"Failure mode analysis requires generated samples — run after model loading (Section 6).\")"
],
"execution_count": null,
"outputs": [],
"id": "d0000021"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Optionally run with already-generated `imgs` from Section 6:\n",
"# worst = worst_samples(imgs.cpu())\n",
"# worst = (worst.clamp(-1,1) + 1) / 2\n",
"# fig, axes = plt.subplots(1, 8, figsize=(18, 2.5))\n",
"# for ax, img in zip(axes, worst):\n",
"# ax.imshow(img.permute(1,2,0)); ax.axis('off')\n",
"# plt.suptitle('Failure modes (lowest-brightness samples)', fontsize=11)\n",
"# plt.tight_layout(); plt.show()"
],
"execution_count": null,
"outputs": [],
"id": "d0000022"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9. Training Loss Overview (all families)"
],
"id": "d0000023"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"fig, axes = plt.subplots(1, 3, figsize=(18, 4))\n",
"\n",
"for ax, (fam, info) in zip(axes, FAMILIES.items()):\n",
" log = logs_p5[fam]\n",
" if log is None:\n",
" ax.set_title(f\"{fam} (not yet run)\"); ax.axis(\"off\"); continue\n",
" h = log[\"history\"]\n",
" c = info[\"color\"]\n",
"\n",
" if fam == \"GAN\":\n",
" ax.plot(h[\"g_loss\"], label=\"G loss\", color=c, linewidth=1.2)\n",
" ax.plot(h[\"w_dist\"], label=\"W-dist\", color=c, linewidth=1.2, linestyle=\"--\")\n",
" ax.set_ylabel(\"Loss / W-distance\")\n",
" elif fam == \"VAE\":\n",
" ax.plot(h[\"recon_loss\"], label=\"MSE\", color=c)\n",
" ax2 = ax.twinx()\n",
" ax2.plot(h[\"kl_loss\"], label=\"KL\", color=\"grey\", linestyle=\"--\")\n",
" ax2.set_ylabel(\"KL\", color=\"grey\")\n",
" ax.set_ylabel(\"MSE\")\n",
" elif fam == \"DDPM\":\n",
" ax.plot(h[\"loss\"], label=\"MSE (v-pred)\", color=c)\n",
" ax.set_ylabel(\"MSE loss\")\n",
"\n",
" ax.set_xlabel(\"Epoch\")\n",
" ax.set_title(f\"{fam}: {info['label']}\")\n",
" ax.legend(fontsize=8)\n",
"\n",
"fig.suptitle(\"Phase 5 — Training Dynamics (200 epochs)\", fontsize=12, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.show()"
],
"execution_count": null,
"outputs": [],
"id": "d0000024"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 10. Conclusions"
],
"id": "d0000025"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"print(\"=\" * 72)\n",
"print(\"PHASE 5 — CROSS-FAMILY COMPARISON (200 epochs)\")\n",
"print(\"=\" * 72)\n",
"\n",
"for fam, info in FAMILIES.items():\n",
" log = logs_p5[fam]\n",
" em = extended_metrics.get(fam, {})\n",
" print(f\"\\n ── {fam}: {info['label']} ──\")\n",
" if log:\n",
" res = log.get(\"config\", log).get(\"image_size\", \"?\")\n",
" print(f\" Resolution : {res}×{res}\")\n",
" n_p = log.get(\"n_params\")\n",
" print(f\" Params : {n_p:,}\" if n_p else \" Params : ?\")\n",
" tt = log.get(\"history\", {}).get(\"train_time_s\")\n",
" print(f\" Train time : {tt / 60:.1f} min\" if tt else \" Train time : ?\")\n",
" for ep in (100, 150, 200):\n",
" fid = get_fid(log, ep)\n",
" print(f\" FID@{ep:<3} : {fid:.1f}\" if fid else f\" FID@{ep:<3} : ?\")\n",
" if em:\n",
" print(f\" IS : {em.get('IS_mean','?'):.2f} ± {em.get('IS_std','?'):.2f}\" if 'IS_mean' in em else \" IS : ?\")\n",
" print(f\" LPIPS div : {em.get('LPIPS','?'):.4f}\" if 'LPIPS' in em else \" LPIPS div : ?\")\n",
" else:\n",
" print(\" IS / LPIPS : (run Section 6 to compute)\")\n",
"\n",
"print(\"\\n\" + \"=\" * 72)\n",
"print(\"Narrative to fill in after results:\")\n",
"print(\" - Which family achieves best FID?\")\n",
"print(\" - GAN: fast convergence but mode collapse risk?\")\n",
"print(\" - VAE: blurry priors improved by perceptual+adversarial loss?\")\n",
"print(\" - DDPM: highest quality but slowest inference (100 DDIM steps)?\")\n",
"print(\"=\" * 72)"
],
"execution_count": null,
"outputs": [],
"id": "d0000026"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@@ -1,19 +0,0 @@
{
"created_at": "2026-04-30T01:21:57.632405+00:00",
"config_paths": [
"generator/configs/phase1/_base_dcgan.json",
"generator/configs/phase1/p1a_dcgan_128.json",
"generator/configs/phase1/p1a_dcgan_64.json",
"generator/configs/phase1/p1b_dcgan_aligned.json",
"generator/configs/phase1/p1b_dcgan_full.json",
"generator/configs/phase1/p1c_dcgan_full_aug.json",
"generator/configs/phase1/p1c_dcgan_hflip.json",
"generator/configs/phase1/p1d_dcgan_combined.json"
],
"instance_id": 35870989,
"offer_id": 26314481,
"ssh_host": "ssh4.vast.ai",
"ssh_port": 30988,
"status": "failed",
"remote_workspace": "/workspace/DRL_PROJ"
}
@@ -1,18 +0,0 @@
{
"created_at": "2026-04-30T02:33:57.318888+00:00",
"config_paths": [
"generator/configs/phase1/p1a_dcgan_128.json",
"generator/configs/phase1/p1a_dcgan_64.json",
"generator/configs/phase1/p1b_dcgan_aligned.json",
"generator/configs/phase1/p1b_dcgan_full.json",
"generator/configs/phase1/p1c_dcgan_full_aug.json",
"generator/configs/phase1/p1c_dcgan_hflip.json",
"generator/configs/phase1/p1d_dcgan_combined.json"
],
"instance_id": 35873942,
"offer_id": 35472561,
"ssh_host": "ssh4.vast.ai",
"ssh_port": 33942,
"status": "cancelled",
"remote_workspace": "/workspace/DRL_PROJ"
}
+27 -2
View File
@@ -32,7 +32,7 @@ def main(config_path, *, data_dir_override=None, output_root="generator/outputs"
import torch import torch
from src.data import GeneratorDataset, get_transform from src.data import GeneratorDataset, get_transform
from src.models import get_model from src.models import get_model
from src.training import train_dcgan from src.training import train_dcgan, train_wgan, train_vae, train_ddpm
from src.utils import load_config from src.utils import load_config
cfg = load_config(config_path) cfg = load_config(config_path)
@@ -50,6 +50,13 @@ def main(config_path, *, data_dir_override=None, output_root="generator/outputs"
model, kind = get_model(cfg) model, kind = get_model(cfg)
# Count total trainable parameters
if isinstance(model, tuple):
n_params = sum(p.numel() for p in model[0].parameters() if p.requires_grad)
else:
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable params: {n_params:,}")
augment = cfg.get("augment", True) augment = cfg.get("augment", True)
transform = get_transform(cfg.get("image_size", 128), augment=augment) transform = get_transform(cfg.get("image_size", 128), augment=augment)
dataset = GeneratorDataset( dataset = GeneratorDataset(
@@ -66,13 +73,31 @@ def main(config_path, *, data_dir_override=None, output_root="generator/outputs"
generator, discriminator, dataset, cfg, generator, discriminator, dataset, cfg,
save_dir=models_dir, run_name=run_name, device=device, save_dir=models_dir, run_name=run_name, device=device,
) )
elif kind == "wgan":
generator, critic = model
history = train_wgan(
generator, critic, dataset, cfg,
save_dir=models_dir, run_name=run_name, device=device,
)
elif kind == "vae":
history = train_vae(
model, dataset, cfg,
save_dir=models_dir, run_name=run_name, device=device,
)
elif kind == "ddpm":
history = train_ddpm(
model, dataset, cfg,
save_dir=models_dir, run_name=run_name, device=device,
)
else: else:
raise NotImplementedError(f"kind={kind!r} not yet implemented in this phase") raise NotImplementedError(f"kind={kind!r} not yet implemented in this phase")
logs_dir.mkdir(parents=True, exist_ok=True) logs_dir.mkdir(parents=True, exist_ok=True)
out = logs_dir / f"{run_name}.json" out = logs_dir / f"{run_name}.json"
log_data = {"run_name": run_name, "config": cfg, "history": history}
log_data["n_params"] = n_params
with open(out, "w") as f: with open(out, "w") as f:
json.dump({"run_name": run_name, "config": cfg, "history": history}, f, indent=2) json.dump(log_data, f, indent=2)
print(f"\nSaved log to {out}") print(f"\nSaved log to {out}")
+7 -4
View File
@@ -51,17 +51,20 @@ class GeneratorDataset(Dataset):
return img return img
def get_transform(image_size: int, augment: bool = False) -> T.Compose: def get_transform(image_size: int, augment=False) -> T.Compose:
"""Build transform for generator training. Output is in [-1, 1]. """Build transform for generator training. Output is in [-1, 1].
augment=True adds horizontal flip + mild rotation + mild color jitter. augment=False — no augmentation (for FID real-image sets)
Use augment=False for validation / FID real-image sets. augment="hflip" — horizontal flip only (recommended for VAE/DDPM)
augment=True — H-flip + rotation ±5° + mild color jitter (for GAN)
""" """
ops = [ ops = [
T.Resize(image_size), T.Resize(image_size),
T.CenterCrop(image_size), T.CenterCrop(image_size),
] ]
if augment: if augment == "hflip":
ops.append(T.RandomHorizontalFlip(p=0.5))
elif augment:
ops += [ ops += [
T.RandomHorizontalFlip(p=0.5), T.RandomHorizontalFlip(p=0.5),
T.RandomRotation(degrees=5, interpolation=T.InterpolationMode.BILINEAR), T.RandomRotation(degrees=5, interpolation=T.InterpolationMode.BILINEAR),
+3
View File
@@ -23,4 +23,7 @@ def get_model(cfg: dict) -> tuple:
from src.models import dcgan # noqa: E402, F401 from src.models import dcgan # noqa: E402, F401
from src.models import wgan # noqa: E402, F401
from src.models import vae # noqa: E402, F401
from src.models import unet # noqa: E402, F401
+69
View File
@@ -0,0 +1,69 @@
"""PatchGAN discriminator for Phase 3.3 (VQGAN-lite adversarial training).
Outputs a spatial patch map instead of a single scalar — each patch
predicts real/fake independently. Loss is the mean over all patches.
Not registered in the model registry; instantiated inside train_vae
when lambda_adversarial > 0.
"""
import torch
import torch.nn as nn
def _init_weights(m):
if isinstance(m, nn.Conv2d):
nn.init.normal_(m.weight, 0.0, 0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
class PatchGANDiscriminator(nn.Module):
    """Stride-2 + stride-1 convolution chain → spatial patch logit map.
    Supports image_size ∈ {64, 128}. For 64×64 input the final map is 6×6
    (70×70 receptive field). For 128×128 an extra stride-2 layer is added.
    InstanceNorm everywhere except the first layer.
    """
    def __init__(self, ndf: int = 64, image_size: int = 64):
        super().__init__()

        def norm_act(ch: int) -> list[nn.Module]:
            # InstanceNorm + LeakyReLU pair used after every conv except the first.
            return [nn.InstanceNorm2d(ch, affine=True), nn.LeakyReLU(0.2, inplace=True)]

        stages: list[nn.Module] = [
            # First layer: no norm
            nn.Conv2d(3, ndf, 4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        if image_size >= 128:
            # Extra stride-2 stage so a 128×128 input yields the same-size patch map.
            stages += [nn.Conv2d(ndf, ndf, 4, stride=2, padding=1, bias=False), *norm_act(ndf)]
        stages += [nn.Conv2d(ndf, ndf * 2, 4, stride=2, padding=1, bias=False), *norm_act(ndf * 2)]
        stages += [nn.Conv2d(ndf * 2, ndf * 4, 4, stride=2, padding=1, bias=False), *norm_act(ndf * 4)]
        stages += [nn.Conv2d(ndf * 4, ndf * 8, 4, stride=1, padding=1, bias=False), *norm_act(ndf * 8)]
        stages.append(nn.Conv2d(ndf * 8, 1, 4, stride=1, padding=1, bias=True))
        self.net = nn.Sequential(*stages)
        self.apply(_init_weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the (B, 1, H', W') patch logit map (no sigmoid)."""
        return self.net(x)
def hinge_d_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor:
    """Hinge loss for the discriminator (Lim & Ye, 2017)."""
    margin_real = (1.0 - real_logits).clamp(min=0.0).mean()
    margin_fake = (1.0 + fake_logits).clamp(min=0.0).mean()
    return 0.5 * (margin_real + margin_fake)
def hinge_g_loss(fake_logits: torch.Tensor) -> torch.Tensor:
    """Generator hinge loss — maximise D(fake)."""
    return fake_logits.mean().neg()
+279
View File
@@ -0,0 +1,279 @@
"""Time-conditioned U-Net for DDPM (Phase 4).
Architecture follows Ho et al. (2020) with options from Nichol & Dhariwal (2021):
- Sinusoidal time embedding → MLP → added to every ResBlock
- GroupNorm (32 groups) + SiLU activations throughout
- Self-attention at configurable spatial resolutions
- Upsample(nearest) + Conv in the decoder — no checkerboard artefacts
Registered as kind="ddpm".
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.models import register
_GN = 32 # GroupNorm groups — all channel counts used here are multiples of 32
# ── Time embedding ────────────────────────────────────────────────────────────
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding of integer timesteps → (B, dim)."""
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2
        # Geometric frequency ladder from 1 down to ~1/10000.
        exponent = torch.arange(half, device=t.device, dtype=torch.float) / half
        inv_freq = torch.exp(exponent * (-math.log(10000)))
        phase = t.float().unsqueeze(-1) * inv_freq.unsqueeze(0)  # (B, half)
        return torch.cat((phase.sin(), phase.cos()), dim=-1)  # (B, dim)
# ── Core building blocks ──────────────────────────────────────────────────────
class ResBlock(nn.Module):
    """ResNet block with time-embedding injection (additive, after first conv)."""
    def __init__(self, in_ch: int, out_ch: int, t_emb_dim: int, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.GroupNorm(_GN, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.t_proj = nn.Linear(t_emb_dim, out_ch)
        self.norm2 = nn.GroupNorm(_GN, out_ch)
        self.drop = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # 1×1 conv on the residual path only when channel counts differ.
        self.skip = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        hidden = self.conv1(F.silu(self.norm1(x)))
        # Broadcast the projected time embedding over both spatial dims.
        hidden = hidden + self.t_proj(F.silu(t_emb)).unsqueeze(-1).unsqueeze(-1)
        hidden = self.conv2(self.drop(F.silu(self.norm2(hidden))))
        return self.skip(x) + hidden
class AttentionBlock(nn.Module):
    """Single-head self-attention with GroupNorm pre-norm and residual."""
    def __init__(self, ch: int):
        super().__init__()
        self.norm = nn.GroupNorm(_GN, ch)
        self.qkv = nn.Conv2d(ch, ch * 3, 1, bias=False)
        self.proj = nn.Conv2d(ch, ch, 1)
        self._scale = ch ** -0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        n = h * w
        # One conv produces q, k, v stacked along the channel axis.
        q, k, v = self.qkv(self.norm(x)).reshape(b, 3, c, n).unbind(dim=1)
        scores = (q.transpose(-2, -1) @ k) * self._scale  # (b, n, n)
        weights = torch.softmax(scores, dim=-1)
        attended = v @ weights.transpose(-2, -1)  # aggregate values per query position
        return x + self.proj(attended.reshape(b, c, h, w))
class Downsample(nn.Module):
    """Halve spatial resolution with a strided 4×4 convolution."""
    def __init__(self, ch: int):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 4, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, ch, H, W) → (B, ch, H/2, W/2)
        out = self.conv(x)
        return out
class Upsample(nn.Module):
    """Double spatial resolution: nearest-neighbour ×2 then a 3×3 conv."""
    def __init__(self, ch: int):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Upsample-then-conv avoids ConvTranspose checkerboard artefacts.
        doubled = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(doubled)
# ── Down / Up blocks ──────────────────────────────────────────────────────────
class DownBlock(nn.Module):
    """num_res_blocks ResBlocks followed by optional self-attention."""
    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        t_emb_dim: int,
        num_res_blocks: int,
        with_attn: bool,
        dropout: float,
    ):
        super().__init__()
        blocks = []
        for j in range(num_res_blocks):
            # Only the first ResBlock changes the channel count.
            src_ch = in_ch if j == 0 else out_ch
            blocks.append(ResBlock(src_ch, out_ch, t_emb_dim, dropout))
        self.resnets = nn.ModuleList(blocks)
        self.attn = AttentionBlock(out_ch) if with_attn else nn.Identity()

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        for resnet in self.resnets:
            x = resnet(x, t_emb)
        return self.attn(x)
class UpBlock(nn.Module):
    """Skip-concat ResBlocks followed by optional self-attention.
    The first ResBlock absorbs the concatenated skip channels; one extra
    ResBlock (num_res_blocks + 1 total) is used to consume the concat.
    """
    def __init__(
        self,
        in_ch: int,
        skip_ch: int,
        out_ch: int,
        t_emb_dim: int,
        num_res_blocks: int,
        with_attn: bool,
        dropout: float,
    ):
        super().__init__()
        blocks = []
        for j in range(num_res_blocks + 1):  # +1 to consume the concat
            src_ch = in_ch + skip_ch if j == 0 else out_ch
            blocks.append(ResBlock(src_ch, out_ch, t_emb_dim, dropout))
        self.resnets = nn.ModuleList(blocks)
        self.attn = AttentionBlock(out_ch) if with_attn else nn.Identity()

    def forward(
        self, x: torch.Tensor, skip: torch.Tensor, t_emb: torch.Tensor
    ) -> torch.Tensor:
        x = torch.cat((x, skip), dim=1)
        for resnet in self.resnets:
            x = resnet(x, t_emb)
        return self.attn(x)
# ── U-Net ─────────────────────────────────────────────────────────────────────
class UNet(nn.Module):
    """Time-conditioned U-Net.
    image_size — must be a power-of-two; 64 or 128 recommended.
    base_ch — base channel count (128 for phases 4.1–4.3, 192 for 4.4).
    ch_mult — channel multipliers per resolution level.
    attn_resolutions — spatial resolutions at which attention is inserted.
    num_res_blocks — ResBlocks per level (in both down and up paths).
    dropout — applied inside every ResBlock.
    """
    def __init__(
        self,
        image_size: int = 64,
        base_ch: int = 128,
        ch_mult: tuple = (1, 2, 2, 2),
        attn_resolutions: tuple = (16, 8),
        num_res_blocks: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        n_levels = len(ch_mult)
        chs = [base_ch * m for m in ch_mult]
        t_emb_dim = base_ch * 4
        # ── Time embedding ────────────────────────────────────────────────
        self.time_embed = nn.Sequential(
            SinusoidalPosEmb(base_ch),
            nn.Linear(base_ch, t_emb_dim),
            nn.SiLU(),
            nn.Linear(t_emb_dim, t_emb_dim),
        )
        # ── Input projection ──────────────────────────────────────────────
        self.in_conv = nn.Conv2d(3, chs[0], 3, padding=1)
        # ── Down path ─────────────────────────────────────────────────────
        self.down_blocks = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        cur_res = image_size
        prev_ch = chs[0]
        for i, ch in enumerate(chs):
            # Attention decided by the resolution BEFORE this level's downsample.
            with_attn = (cur_res in attn_resolutions)
            self.down_blocks.append(
                DownBlock(prev_ch, ch, t_emb_dim, num_res_blocks, with_attn, dropout)
            )
            if i < n_levels - 1:  # no downsample after the deepest level
                self.downsamples.append(Downsample(ch))
                cur_res //= 2
            prev_ch = ch
        # ── Middle ────────────────────────────────────────────────────────
        mid_ch = chs[-1]
        self.mid_res1 = ResBlock(mid_ch, mid_ch, t_emb_dim, dropout)
        self.mid_attn = AttentionBlock(mid_ch)
        self.mid_res2 = ResBlock(mid_ch, mid_ch, t_emb_dim, dropout)
        # ── Up path ───────────────────────────────────────────────────────
        # Mirrors down path: iterate chs in reverse; skip_ch = chs[n_levels-1-i].
        self.up_blocks = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        in_ch = mid_ch
        for i in range(n_levels):
            level = n_levels - 1 - i  # index from deep (n-1) to shallow (0)
            skip_ch = chs[level]
            out_ch = chs[level - 1] if level > 0 else chs[0]
            # cur_res enters at the deepest resolution (left over from the down
            # loop) and doubles per level, so attention placement mirrors the
            # down path exactly.
            with_attn = (cur_res in attn_resolutions)
            self.up_blocks.append(
                UpBlock(in_ch, skip_ch, out_ch, t_emb_dim, num_res_blocks, with_attn, dropout)
            )
            if level > 0:
                self.upsamples.append(Upsample(out_ch))
                cur_res *= 2
            in_ch = out_ch
        # ── Output ────────────────────────────────────────────────────────
        self.out_norm = nn.GroupNorm(_GN, chs[0])
        self.out_conv = nn.Conv2d(chs[0], 3, 3, padding=1)
    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """x: (B, 3, H, W); t: (B,) timesteps. Returns a (B, 3, H, W) tensor."""
        t_emb = self.time_embed(t)
        x = self.in_conv(x)
        # Down
        skips = []
        ds_idx = 0
        for i, block in enumerate(self.down_blocks):
            x = block(x, t_emb)
            skips.append(x)  # skip is captured BEFORE downsampling
            if ds_idx < len(self.downsamples):
                x = self.downsamples[ds_idx](x)
                ds_idx += 1
        # Middle
        x = self.mid_res1(x, t_emb)
        x = self.mid_attn(x)
        x = self.mid_res2(x, t_emb)
        # Up
        us_idx = 0
        for i, block in enumerate(self.up_blocks):
            # Skips are consumed deepest-first (reverse of the down path).
            x = block(x, skips[-(i + 1)], t_emb)
            if us_idx < len(self.upsamples):
                x = self.upsamples[us_idx](x)
                us_idx += 1
        return self.out_conv(F.silu(self.out_norm(x)))
def _build(cfg: dict):
    """Registry factory: construct a UNet from a config dict (registered as "ddpm")."""
    defaults = {
        "image_size": 64,
        "base_ch": 128,
        "ch_mult": [1, 2, 2, 2],
        "attn_resolutions": [16, 8],
        "num_res_blocks": 2,
        "dropout": 0.1,
    }
    merged = {**defaults, **{k: cfg[k] for k in defaults if k in cfg}}
    return UNet(
        image_size=merged["image_size"],
        base_ch=merged["base_ch"],
        ch_mult=tuple(merged["ch_mult"]),
        attn_resolutions=tuple(merged["attn_resolutions"]),
        num_res_blocks=merged["num_res_blocks"],
        dropout=merged["dropout"],
    )


register("ddpm", _build, kind="ddpm")
+132
View File
@@ -0,0 +1,132 @@
"""Convolutional VAE for Phase 3.
Encoder uses stride-2 Conv → flatten → linear (μ, log σ²).
Decoder uses Linear → Upsample(nearest) + Conv to avoid ConvTranspose2d
checkerboard artefacts.
Registered as kind="vae". The run.py dispatcher passes the model to
train_vae(), which internally builds perceptual loss and PatchGAN when
the corresponding lambdas are non-zero.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.models import register
def _init_weights(m):
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
nn.init.normal_(m.weight, 0.0, 0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
nn.init.normal_(m.weight, 1.0, 0.02)
nn.init.zeros_(m.bias)
def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential:
"""Nearest-neighbour upsample followed by a 3×3 conv — no checkerboard."""
return nn.Sequential(
nn.Upsample(scale_factor=2, mode="nearest"),
nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
)
class VAE(nn.Module):
    """Convolutional VAE. image_size must be a power-of-two ≥ 32.
    Spatial bottleneck is always at 4×4 regardless of image_size —
    the encoder and decoder scale the number of stride-2 steps accordingly.
    """
    def __init__(self, latent_dim: int = 256, ngf: int = 64, image_size: int = 64):
        super().__init__()
        # (x & (x - 1)) == 0 holds exactly for powers of two.
        if image_size < 32 or (image_size & (image_size - 1)):
            raise ValueError(f"image_size must be a power-of-two ≥ 32, got {image_size}")
        self.latent_dim = latent_dim
        self.image_size = image_size
        n_down = int(math.log2(image_size)) - 2  # steps from image_size to 4×4
        # 64 → n_down=4: 64→32→16→8→4
        # 128 → n_down=5: 128→64→32→16→8→4
        # ── Encoder ──────────────────────────────────────────────────────────
        enc_layers: list[nn.Module] = [
            # First layer: no norm.
            nn.Conv2d(3, ngf, 4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        ch = ngf
        for _ in range(n_down - 1):
            # Each stage halves the spatial size and doubles the channels.
            enc_layers += [
                nn.Conv2d(ch, ch * 2, 4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(ch * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            ch *= 2
        # ch = ngf * 2^(n_down-1); spatial = 4×4
        self.encoder = nn.Sequential(*enc_layers)
        flat = ch * 4 * 4
        self.fc_mu = nn.Linear(flat, latent_dim)  # posterior mean μ
        self.fc_lv = nn.Linear(flat, latent_dim)  # posterior log-variance log σ²
        # ── Decoder ──────────────────────────────────────────────────────────
        self.fc_dec = nn.Linear(latent_dim, flat)
        self._dec_ch = ch  # channels at the 4×4 bottleneck
        dec_layers: list[nn.Module] = []
        for _ in range(n_down - 1):
            # Mirror of the encoder: double spatial size, halve channels.
            dec_layers.append(_upsample_block(ch, ch // 2))
            ch //= 2
        # Final upsample to image_size, output 3 channels, no BN, Tanh
        dec_layers += [
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(ch, 3, 3, padding=1, bias=True),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*dec_layers)
        self.apply(_init_weights)
    # ── Interface ────────────────────────────────────────────────────────────
    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Map images to posterior parameters (mu, log_var), each (B, latent_dim)."""
        h = self.encoder(x).flatten(1)
        return self.fc_mu(h), self.fc_lv(h)
    def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Reparameterization trick: z = μ + σ·ε with ε ~ N(0, I)."""
        std = torch.exp(0.5 * log_var)
        return mu + std * torch.randn_like(std)
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map latents (B, latent_dim) to images via the 4×4 bottleneck."""
        h = self.fc_dec(z).view(z.size(0), self._dec_ch, 4, 4)
        return self.decoder(h)
    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns (reconstruction, mu, log_var)."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var
    @torch.no_grad()
    def sample(self, n: int, device) -> torch.Tensor:
        """Sample n images by drawing z ~ N(0, I)."""
        z = torch.randn(n, self.latent_dim, device=device)
        return self.decode(z)
def _build(cfg: dict):
    """Registry factory: construct a VAE from a config dict (registered as "vae")."""
    kwargs = {
        "latent_dim": cfg.get("latent_dim", 256),
        "ngf": cfg.get("ngf", 64),
        "image_size": cfg.get("image_size", 64),
    }
    return VAE(**kwargs)


register("vae", _build, kind="vae")
+191 -63
View File
@@ -1,26 +1,31 @@
"""WGAN-GP with spectral normalization, self-attention, and GroupNorm. """WGAN-GP variants.
Improvements over the original: wgan_basic — Phase 2.2: BatchNorm/InstanceNorm, no attention, 64×64 only.
- Generator: BatchNorm -> GroupNorm (no batch-size coupling, stable with varied content) wgan — Phase 2.3/2.4: GroupNorm/SpectralNorm + self-attention, size-agnostic.
- Critic: InstanceNorm -> spectral normalization (principled Lipschitz constraint)
- Both: one SAGAN-style self-attention block at the 32x32 feature map
- Larger capacity: ngf=128, ndf=128
""" """
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from src.models import register from src.models import register
def _init_weights(m): def _init_weights(m):
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
nn.init.normal_(m.weight, 0.0, 0.02) nn.init.normal_(m.weight, 0.0, 0.02)
elif isinstance(m, nn.GroupNorm) and m.weight is not None: elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)) and m.weight is not None:
nn.init.normal_(m.weight, 1.0, 0.02) nn.init.normal_(m.weight, 1.0, 0.02)
nn.init.zeros_(m.bias) nn.init.zeros_(m.bias)
def _sn(module):
return nn.utils.spectral_norm(module)
class SelfAttention(nn.Module): class SelfAttention(nn.Module):
"""SAGAN-style self-attention."""
def __init__(self, in_ch: int): def __init__(self, in_ch: int):
super().__init__() super().__init__()
mid = max(in_ch // 8, 1) mid = max(in_ch // 8, 1)
@@ -36,98 +41,221 @@ class SelfAttention(nn.Module):
k = self.k(x).view(b, self._mid, -1) # (b, mid, hw) k = self.k(x).view(b, self._mid, -1) # (b, mid, hw)
v = self.v(x).view(b, c, -1) # (b, c, hw) v = self.v(x).view(b, c, -1) # (b, c, hw)
attn = torch.softmax(q @ k * self._mid ** -0.5, dim=-1) # (b, hw, hw) attn = torch.softmax(q @ k * self._mid ** -0.5, dim=-1) # (b, hw, hw)
out = (v @ attn.transpose(-2, -1)).view(b, c, h, w) return x + self.gamma * (v @ attn.transpose(-2, -1)).view(b, c, h, w)
return x + self.gamma * out
def _sn(module): # ---------------------------------------------------------------------------
"""Apply spectral normalization to a conv layer.""" # Phase 2.2 — basic WGAN-GP (BatchNorm in G, InstanceNorm in D, 64×64 only)
return nn.utils.spectral_norm(module) # ---------------------------------------------------------------------------
class WGANBasicGenerator(nn.Module):
"""Maps (latent_dim, 1, 1) -> (3, 64, 64) in [-1, 1].
class WGANGenerator(nn.Module): Same channel structure as DCGAN. BatchNorm in generator is fine because
"""Maps (latent_dim x 1 x 1) -> (3 x 128 x 128) in [-1, 1]. WGAN-GP's constraint targets the critic, not the generator.
Upsampling path: 1 -> 4 -> 8 -> 16 (+attn) -> 32 -> 64 -> 128
Self-attention sits at 16x16 (attention matrix 256x256 vs 1024x1024 at 32x32).
""" """
def __init__(self, latent_dim: int = 128, ngf: int = 64): def __init__(self, latent_dim: int = 128, ngf: int = 64):
super().__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
# 1x1 -> 4x4 # 1×1 4×4
nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False), nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False),
nn.GroupNorm(8, ngf * 8), nn.ReLU(True), nn.BatchNorm2d(ngf * 8), nn.ReLU(True),
# 4x4 -> 8x8 # 4×4 8×8
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.GroupNorm(8, ngf * 4), nn.ReLU(True), nn.BatchNorm2d(ngf * 4), nn.ReLU(True),
# 8x8 -> 16x16 # 8×8 16×16
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.GroupNorm(8, ngf * 2), nn.ReLU(True), nn.BatchNorm2d(ngf * 2), nn.ReLU(True),
) # 16×16 → 32×32
self.attn = SelfAttention(ngf * 2) # applied at 16x16
self.out = nn.Sequential(
# 16x16 -> 32x32
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
nn.GroupNorm(8, ngf), nn.ReLU(True), nn.BatchNorm2d(ngf), nn.ReLU(True),
# 32x32 -> 64x64 # 32×32 64×64
nn.ConvTranspose2d(ngf, ngf // 2, 4, 2, 1, bias=False), nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
nn.GroupNorm(8, ngf // 2), nn.ReLU(True),
# 64x64 -> 128x128
nn.ConvTranspose2d(ngf // 2, 3, 4, 2, 1, bias=False),
nn.Tanh(), nn.Tanh(),
) )
self.apply(_init_weights) self.apply(_init_weights)
def forward(self, z: torch.Tensor) -> torch.Tensor: def forward(self, z: torch.Tensor) -> torch.Tensor:
h = self.net(z) return self.net(z)
h = self.attn(h)
return self.out(h)
class WGANCritic(nn.Module): class WGANBasicCritic(nn.Module):
"""Critic (no sigmoid) for WGAN-GP. All conv layers are spectrally normalized. """WGAN-GP critic (64×64). InstanceNorm instead of BatchNorm — BatchNorm
breaks the per-sample Lipschitz constraint the gradient penalty enforces.
Downsampling path: 128 -> 64 -> 32 -> 16 (+attn) -> 8 -> 4 -> score
""" """
def __init__(self, ndf: int = 64): def __init__(self, ndf: int = 64):
super().__init__() super().__init__()
self.down = nn.Sequential( self.net = nn.Sequential(
# 128x128 -> 64x64 (no norm on first layer) # 64×64 → 32×32 (no norm on first layer)
_sn(nn.Conv2d(3, ndf // 2, 4, 2, 1, bias=False)), nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, True), nn.LeakyReLU(0.2, True),
# 64x64 -> 32x32 # 32×32 → 16×16
_sn(nn.Conv2d(ndf // 2, ndf, 4, 2, 1, bias=False)), nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
nn.InstanceNorm2d(ndf * 2, affine=True),
nn.LeakyReLU(0.2, True), nn.LeakyReLU(0.2, True),
# 32x32 -> 16x16 # 16×16 → 8×8
_sn(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False)), nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
nn.InstanceNorm2d(ndf * 4, affine=True),
nn.LeakyReLU(0.2, True), nn.LeakyReLU(0.2, True),
) # 8×8 → 4×4
self.attn = SelfAttention(ndf * 2) # applied at 16x16 nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
self.tail = nn.Sequential( nn.InstanceNorm2d(ndf * 8, affine=True),
# 16x16 -> 8x8
_sn(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False)),
nn.LeakyReLU(0.2, True), nn.LeakyReLU(0.2, True),
# 8x8 -> 4x4 # 4×4 → 1×1 (score, no sigmoid)
_sn(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False)), nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
nn.LeakyReLU(0.2, True),
# 4x4 -> 1x1
_sn(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False)),
) )
self.apply(_init_weights) self.apply(_init_weights)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.down(x) return self.net(x).view(x.size(0))
h = self.attn(h)
# ---------------------------------------------------------------------------
# Phase 2.3 / 2.4 — advanced WGAN-GP (GroupNorm, SpectralNorm, attention)
# ---------------------------------------------------------------------------
class WGANGenerator(nn.Module):
"""GroupNorm generator with SAGAN self-attention.
Supports image_size ∈ {64, 128}.
Stem is always 1×1 → 4×4 → 8×8 → 16×16 (ngf×8 → ngf×4 → ngf×2 channels).
Attention at 16×16 always; additional attention at 32×32 for 128×128.
"""
def __init__(self, latent_dim: int = 128, ngf: int = 128, image_size: int = 64):
super().__init__()
if image_size not in (64, 128):
raise ValueError(f"WGANGenerator supports image_size 64 or 128, got {image_size}")
self._image_size = image_size
self.stem = nn.Sequential(
# 1×1 → 4×4
nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False),
nn.GroupNorm(8, ngf * 8), nn.ReLU(True),
# 4×4 → 8×8
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.GroupNorm(8, ngf * 4), nn.ReLU(True),
# 8×8 → 16×16
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.GroupNorm(8, ngf * 2), nn.ReLU(True),
) # output: (ngf×2, 16, 16)
self.attn16 = SelfAttention(ngf * 2)
if image_size == 64:
self.mid = None
self.attn32 = None
self.tail = nn.Sequential(
# 16×16 → 32×32
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
nn.GroupNorm(8, ngf), nn.ReLU(True),
# 32×32 → 64×64
nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
nn.Tanh(),
)
else: # 128
self.mid = nn.Sequential(
# 16×16 → 32×32
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
nn.GroupNorm(8, ngf), nn.ReLU(True),
)
self.attn32 = SelfAttention(ngf)
self.tail = nn.Sequential(
# 32×32 → 64×64
nn.ConvTranspose2d(ngf, ngf // 2, 4, 2, 1, bias=False),
nn.GroupNorm(8, ngf // 2), nn.ReLU(True),
# 64×64 → 128×128
nn.ConvTranspose2d(ngf // 2, 3, 4, 2, 1, bias=False),
nn.Tanh(),
)
self.apply(_init_weights)
def forward(self, z: torch.Tensor) -> torch.Tensor:
h = self.attn16(self.stem(z))
if self.mid is not None:
h = self.attn32(self.mid(h))
return self.tail(h)
class WGANCritic(nn.Module):
"""SpectralNorm critic with SAGAN self-attention.
Supports image_size ∈ {64, 128}.
Attention at 16×16 always; additional attention at 32×32 for 128×128.
"""
def __init__(self, ndf: int = 128, image_size: int = 64):
super().__init__()
if image_size not in (64, 128):
raise ValueError(f"WGANCritic supports image_size 64 or 128, got {image_size}")
self._image_size = image_size
if image_size == 64:
# Head: 64→32 (ndf//2)
self.head = nn.Sequential(
_sn(nn.Conv2d(3, ndf // 2, 4, 2, 1, bias=False)),
nn.LeakyReLU(0.2, True),
)
self.attn32 = None
# 32→16 (ndf)
self.mid = nn.Sequential(
_sn(nn.Conv2d(ndf // 2, ndf, 4, 2, 1, bias=False)),
nn.LeakyReLU(0.2, True),
)
attn_ch = ndf
else: # 128
# Head: 128→64 (ndf//4), 64→32 (ndf//2)
self.head = nn.Sequential(
_sn(nn.Conv2d(3, ndf // 4, 4, 2, 1, bias=False)),
nn.LeakyReLU(0.2, True),
_sn(nn.Conv2d(ndf // 4, ndf // 2, 4, 2, 1, bias=False)),
nn.LeakyReLU(0.2, True),
)
self.attn32 = SelfAttention(ndf // 2)
# 32→16 (ndf)
self.mid = nn.Sequential(
_sn(nn.Conv2d(ndf // 2, ndf, 4, 2, 1, bias=False)),
nn.LeakyReLU(0.2, True),
)
attn_ch = ndf
self.attn16 = SelfAttention(attn_ch)
# Tail: 16×16 → 8×8 → 4×4 → score
self.tail = nn.Sequential(
_sn(nn.Conv2d(attn_ch, attn_ch * 2, 4, 2, 1, bias=False)),
nn.LeakyReLU(0.2, True),
_sn(nn.Conv2d(attn_ch * 2, attn_ch * 4, 4, 2, 1, bias=False)),
nn.LeakyReLU(0.2, True),
_sn(nn.Conv2d(attn_ch * 4, 1, 4, 1, 0, bias=False)),
)
self.apply(_init_weights)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.head(x)
if self.attn32 is not None:
h = self.attn32(h)
h = self.attn16(self.mid(h))
return self.tail(h).view(x.size(0)) return self.tail(h).view(x.size(0))
def _build(cfg: dict): def _build_basic(cfg: dict):
return ( return (
WGANGenerator(latent_dim=cfg.get("latent_dim", 128), ngf=cfg.get("ngf", 128)), WGANBasicGenerator(latent_dim=cfg.get("latent_dim", 128), ngf=cfg.get("ngf", 64)),
WGANCritic(ndf=cfg.get("ndf", 128)), WGANBasicCritic(ndf=cfg.get("ndf", 64)),
) )
def _build(cfg: dict):
image_size = cfg.get("image_size", 64)
return (
WGANGenerator(latent_dim=cfg.get("latent_dim", 128), ngf=cfg.get("ngf", 128), image_size=image_size),
WGANCritic(ndf=cfg.get("ndf", 128), image_size=image_size),
)
register("wgan_basic", _build_basic, kind="wgan")
register("wgan", _build, kind="wgan") register("wgan", _build, kind="wgan")
+2 -2
View File
@@ -1,3 +1,3 @@
from src.training.trainer import train_dcgan from src.training.trainer import train_dcgan, train_wgan, train_vae, train_ddpm
__all__ = ["train_dcgan"] __all__ = ["train_dcgan", "train_wgan", "train_vae", "train_ddpm"]
+136
View File
@@ -0,0 +1,136 @@
"""Gaussian diffusion utilities for Phase 4 (DDPM).
Provides noise schedules, the forward (noising) process, training loss,
and DDIM deterministic sampling (Song et al., 2020).
Convention: alpha_bars is a 1-D tensor of length T, where alpha_bars[t]
= ᾱ_{t+1} in 1-indexed notation. Timestep t used in the training loop
is a 0-indexed integer in [0, T). At t=0 the image is almost clean
(ᾱ ≈ 1 β_1); at t=T1 the image is almost pure noise (ᾱ ≈ 0).
"""
import math
import torch
import torch.nn.functional as F
# ── Noise schedules ──────────────────────────────────────────────────────────
def linear_betas(T: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor:
    """Ho et al. (2020) linear schedule: T evenly spaced betas in [beta_start, beta_end]."""
    return torch.linspace(beta_start, beta_end, steps=T)
def cosine_betas(T: int, s: float = 0.008) -> torch.Tensor:
    """Nichol & Dhariwal (2021) cosine schedule — avoids over-denoising at low t."""
    grid = torch.linspace(0, T, T + 1) / T  # T+1 evenly spaced points in [0, 1]
    alpha_bar = torch.cos((grid + s) / (1 + s) * (math.pi / 2)) ** 2
    alpha_bar = alpha_bar / alpha_bar[0]  # normalise so ᾱ_0 = 1
    betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
    return betas.clamp(max=0.999)  # cap to avoid singularities at the end
def make_alpha_bars(betas: torch.Tensor) -> torch.Tensor:
    """Cumulative product of (1 − β), shape (T,)."""
    alphas = 1.0 - betas
    return torch.cumprod(alphas, dim=0)
# ── Forward process ──────────────────────────────────────────────────────────
def q_sample(
x0: torch.Tensor,
t: torch.Tensor,
alpha_bars: torch.Tensor,
noise: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Add noise to x0 at timestep t. Returns (x_t, noise)."""
if noise is None:
noise = torch.randn_like(x0)
ab = alpha_bars[t].to(x0.device)[:, None, None, None]
x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * noise
return x_t, noise
# ── Training loss ────────────────────────────────────────────────────────────
def diffusion_loss(
    model,
    x0: torch.Tensor,
    t: torch.Tensor,
    alpha_bars: torch.Tensor,
    pred_type: str = "eps",
) -> torch.Tensor:
    """MSE on the model's prediction vs the true target.
    pred_type="eps" → target is the added noise ε (Ho et al.)
    pred_type="v"   → target is v = √ᾱ·ε − √(1−ᾱ)·x0 (Salimans & Ho)
    """
    x_t, noise = q_sample(x0, t, alpha_bars)
    pred = model(x_t, t)
    if pred_type == "eps":
        return F.mse_loss(pred, noise)
    # v-prediction target
    ab = alpha_bars[t].to(x0.device)[:, None, None, None]
    v_target = ab.sqrt() * noise - (1 - ab).sqrt() * x0
    return F.mse_loss(pred, v_target)
# ── DDIM deterministic sampling ───────────────────────────────────────────────
@torch.no_grad()
def ddim_sample(
    model,
    n: int,
    image_size: int,
    alpha_bars: torch.Tensor,
    n_steps: int = 100,
    pred_type: str = "eps",
    device: str = "cuda",
    batch_size: int = 32,
) -> torch.Tensor:
    """Generate n images via DDIM (eta=0, deterministic).
    Batches internally to avoid OOM when n is large.
    Returns tensor shape (n, 3, image_size, image_size) in [-1, 1].
    NOTE(review): default device is "cuda" — pass device="cpu" on CPU-only hosts.
    """
    model.eval()  # side effect: model is left in eval mode after sampling
    T = len(alpha_bars)
    # Build reversed subsequence: [T-1, T-1-step, ..., 0]
    step = max(T // n_steps, 1)
    ts = list(range(T - 1, -1, -step))[:n_steps]
    if ts[-1] != 0:
        ts.append(0)  # guarantee the trajectory ends at the cleanest timestep
    results = []
    remaining = n
    while remaining > 0:
        bsz = min(batch_size, remaining)
        x = torch.randn(bsz, 3, image_size, image_size, device=device)
        for i, t_cur in enumerate(ts):
            # t_prev = -1 is a sentinel for the final step: ᾱ_prev is taken as 1,
            # so the update below lands directly on the x0 estimate.
            t_prev = ts[i + 1] if i + 1 < len(ts) else -1
            t_batch = torch.full((bsz,), t_cur, device=device, dtype=torch.long)
            ab_t = alpha_bars[t_cur].to(device)
            ab_prev = alpha_bars[t_prev].to(device) if t_prev >= 0 else torch.ones(1, device=device)
            pred = model(x, t_batch)
            # Reconstruct x0 from prediction
            if pred_type == "eps":
                x0_hat = (x - (1 - ab_t).sqrt() * pred) / ab_t.sqrt()
            else:  # v
                x0_hat = ab_t.sqrt() * x - (1 - ab_t).sqrt() * pred
            x0_hat = x0_hat.clamp(-1, 1)  # keep the estimate in valid image range
            # DDIM step
            eps_hat = (x - ab_t.sqrt() * x0_hat) / (1 - ab_t).sqrt()
            x = ab_prev.sqrt() * x0_hat + (1 - ab_prev).sqrt() * eps_hat
        results.append(x.cpu())  # move finished batch off-device to save memory
        remaining -= bsz
    return torch.cat(results)[:n]
+56
View File
@@ -0,0 +1,56 @@
"""Extended generation quality metrics for Phase 5 cross-family comparison.
IS — Inception Score (Salimans et al., 2016): measures sample quality × diversity.
LPIPS — average pairwise learned perceptual distance: measures sample diversity alone.
Both functions accept float tensors in [-1, 1].
"""
import torch
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
def compute_is(
    imgs: torch.Tensor,
    device: str = "cuda",
    batch_size: int = 64,
) -> tuple[float, float]:
    """Inception Score (mean ± std) over 10 splits.
    imgs: (N, 3, H, W) in [-1, 1]. N ≥ 2048 for a reliable estimate.
    Returns (is_mean, is_std).
    """
    metric = InceptionScore(normalize=True).to(device)
    scaled = imgs.clamp(-1, 1).mul(0.5).add(0.5)  # → [0, 1] for normalize=True
    for start in range(0, len(scaled), batch_size):
        metric.update(scaled[start : start + batch_size].to(device))
    mean, std = metric.compute()
    return float(mean), float(std)
def compute_lpips_diversity(
    imgs: torch.Tensor,
    n_pairs: int = 200,
    device: str = "cuda",
    batch_size: int = 16,
) -> float:
    """Average pairwise LPIPS distance — higher means more diverse samples.
    imgs: (N, 3, H, W) in [-1, 1]. Samples n_pairs random (i, j) pairs with i ≠ j.

    Raises:
        ValueError: if fewer than two images are supplied (no distinct pair exists).
    """
    N = len(imgs)
    if N < 2:
        raise ValueError(f"need at least 2 images to measure diversity, got {N}")
    metric = LearnedPerceptualImagePatchSimilarity(net_type="alex", normalize=True).to(device)
    imgs_01 = (imgs.clamp(-1, 1) * 0.5 + 0.5)  # → [0, 1] for normalize=True
    # Sample i uniformly over [0, N), then j uniformly over the other N-1
    # indices, shifting j past i so i ≠ j is guaranteed by construction.
    # (The previous randperm(N*2)[:n_pairs*2].view(...) scheme raised a shape
    # error whenever n_pairs > N.)
    i_idx = torch.randint(0, N, (n_pairs,))
    j_idx = torch.randint(0, N - 1, (n_pairs,))
    j_idx = j_idx + (j_idx >= i_idx).long()
    for start in range(0, n_pairs, batch_size):
        end = min(start + batch_size, n_pairs)
        metric.update(
            imgs_01[i_idx[start:end]].to(device),
            imgs_01[j_idx[start:end]].to(device),
        )
    return float(metric.compute())
+57
View File
@@ -0,0 +1,57 @@
"""VGG-16 perceptual loss for Phase 3.2 and 3.3.
Extracts features at relu1_2, relu2_2, relu3_3 and returns the
L1 distance in feature space. VGG weights are frozen.
Input convention: images in [-1, 1] — the loss converts internally to
[0, 1] and then applies ImageNet normalisation before passing to VGG.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tv_models
class PerceptualLoss(nn.Module):
    """L1 feature-matching loss at three VGG-16 layers.
    VGG-16 feature indices:
    relu1_2: features[:4] (before first maxpool)
    relu2_2: features[4:9] (before second maxpool)
    relu3_3: features[9:16] (before third maxpool)
    """
    def __init__(self):
        super().__init__()
        feats = tv_models.vgg16(weights=tv_models.VGG16_Weights.IMAGENET1K_V1).features
        # Three frozen slices, cut just after each target ReLU.
        self.slice1 = nn.Sequential(*list(feats[:4]))    # relu1_2
        self.slice2 = nn.Sequential(*list(feats[4:9]))   # relu2_2
        self.slice3 = nn.Sequential(*list(feats[9:16]))  # relu3_3
        for p in self.parameters():
            p.requires_grad_(False)
        # ImageNet channel statistics, broadcastable over (B, 3, H, W).
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        )
    def _normalise(self, x: torch.Tensor) -> torch.Tensor:
        """Convert [-1, 1] → ImageNet-normalised [0, 1]."""
        return ((x * 0.5 + 0.5) - self.mean) / self.std
    def forward(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
        """L1 feature distance. real gradients are stopped — only fake trains."""
        f, r = self._normalise(fake), self._normalise(real)
        total = torch.tensor(0.0, device=fake.device)
        for stage in (self.slice1, self.slice2, self.slice3):
            f, r = stage(f), stage(r)
            total = total + F.l1_loss(f, r.detach())
        return total
+599 -7
View File
@@ -1,8 +1,10 @@
import os import os
import time
from pathlib import Path from pathlib import Path
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision.utils import save_image from torchvision.utils import save_image
from tqdm import tqdm from tqdm import tqdm
@@ -19,12 +21,11 @@ else:
_autocast = lambda device_type="", enabled=True, **kw: _AC(enabled=enabled, **kw) _autocast = lambda device_type="", enabled=True, **kw: _AC(enabled=enabled, **kw)
def _save_samples(generator_ema, samples_dir: Path, epoch: int, *, latent_dim: int, device) -> None: def _save_samples(generator_ema, samples_dir: Path, epoch: int, *, fixed_noise: torch.Tensor, device) -> None:
samples_dir.mkdir(parents=True, exist_ok=True) samples_dir.mkdir(parents=True, exist_ok=True)
with torch.no_grad(): with torch.no_grad():
noise = torch.randn(16, latent_dim, 1, 1, device=device) imgs = generator_ema.model(fixed_noise.to(device)) # EMA model, [-1, 1]
imgs = generator_ema.model(noise) # EMA model, [-1, 1] imgs = (imgs.clamp(-1, 1) + 1.0) / 2.0 # -> [0, 1]
imgs = (imgs.clamp(-1, 1) + 1.0) / 2.0 # -> [0, 1]
save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=4) save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)
@@ -78,6 +79,9 @@ def train_dcgan(
ema = EMA(generator, decay=ema_decay) ema = EMA(generator, decay=ema_decay)
# Fixed noise for consistent sample tracking across epochs
fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device)
save_dir = Path(save_dir) save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True) save_dir.mkdir(parents=True, exist_ok=True)
samples_dir = save_dir.parent / "samples" / run_name samples_dir = save_dir.parent / "samples" / run_name
@@ -88,6 +92,15 @@ def train_dcgan(
best_fid = float("inf") best_fid = float("inf")
print(f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}") print(f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}")
# Linear LR decay from epoch epochs//2 to epochs
decay_start = epochs // 2
sched_g = torch.optim.lr_scheduler.LambdaLR(
opt_g, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start)))
sched_d = torch.optim.lr_scheduler.LambdaLR(
opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start)))
t_start = time.time()
for epoch in range(1, epochs + 1): for epoch in range(1, epochs + 1):
generator.train() generator.train()
discriminator.train() discriminator.train()
@@ -142,13 +155,13 @@ def train_dcgan(
) )
if epoch % sample_interval == 0: if epoch % sample_interval == 0:
_save_samples(ema, samples_dir, epoch, latent_dim=latent_dim, device=device) _save_samples(ema, samples_dir, epoch, fixed_noise=fixed_noise, device=device)
if epoch % fid_interval == 0: if epoch % fid_interval == 0:
generator.eval() ema.model.eval()
with torch.no_grad(): with torch.no_grad():
fake_imgs = torch.cat([ fake_imgs = torch.cat([
generator(torch.randn(64, latent_dim, 1, 1, device=device)) ema.model(torch.randn(64, latent_dim, 1, 1, device=device))
for _ in range(fid_n_real // 64 + 1) for _ in range(fid_n_real // 64 + 1)
])[:fid_n_real] ])[:fid_n_real]
fid_score = fid_eval.compute(fake_imgs) fid_score = fid_eval.compute(fake_imgs)
@@ -160,7 +173,586 @@ def train_dcgan(
torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt") torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt") torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
sched_g.step()
sched_d.step()
torch.save(generator.state_dict(), save_dir / f"{run_name}_final_g.pt") torch.save(generator.state_dict(), save_dir / f"{run_name}_final_g.pt")
torch.save(discriminator.state_dict(), save_dir / f"{run_name}_final_d.pt") torch.save(discriminator.state_dict(), save_dir / f"{run_name}_final_d.pt")
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt") torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")
history["train_time_s"] = time.time() - t_start
return history
def _gradient_penalty(critic, real: torch.Tensor, fake: torch.Tensor, device) -> torch.Tensor:
"""Two-sided gradient penalty (Gulrajani et al., 2017)."""
bsz = real.size(0)
eps = torch.rand(bsz, 1, 1, 1, device=device)
interp = (eps * real + (1.0 - eps) * fake).requires_grad_(True)
d_interp = critic(interp)
grad = torch.autograd.grad(
outputs=d_interp,
inputs=interp,
grad_outputs=torch.ones_like(d_interp),
create_graph=True,
retain_graph=True,
)[0]
return ((grad.norm(2, dim=[1, 2, 3]) - 1) ** 2).mean()
def train_wgan(
    generator,
    critic,
    train_dataset,
    cfg: dict,
    *,
    save_dir,
    run_name: str,
    device: str = "cuda",
) -> dict:
    """WGAN-GP training loop (Gulrajani et al., 2017), used for Phase 2.2-2.4.

    Gradient penalty replaces weight clipping. The critic runs in float32 to
    keep GP gradient computation numerically stable; AMP is used only for the
    generator forward/backward.

    Args:
        generator: maps (B, latent_dim, 1, 1) noise to images in [-1, 1].
        critic: maps an image batch to unbounded scores (no sigmoid).
        train_dataset: dataset of real images (assumed in [-1, 1] — same
            convention as the rest of this file).
        cfg: hyperparameter dict; missing keys fall back to the defaults
            read via `cfg.get(...)` below.
        save_dir: checkpoint directory (created if missing).
        run_name: prefix for checkpoint / sample file names.
        device: "cuda" — silently falls back to CPU when unavailable.

    Returns:
        history dict: per-epoch "g_loss", "w_dist", "gp", "d_real", "d_fake"
        lists, "fid" (dict keyed by epoch) and "train_time_s".
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    generator = generator.to(device)
    critic = critic.to(device)
    n_g = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    n_c = sum(p.numel() for p in critic.parameters() if p.requires_grad)
    print(f"Generator: {n_g:,} params Critic: {n_c:,} params")
    epochs = cfg["epochs"]
    batch_size = cfg["batch_size"]
    lr_g = cfg.get("lr_g", 1e-4)
    lr_d = cfg.get("lr_d", 1e-4)
    beta1 = cfg.get("beta1", 0.0)
    beta2 = cfg.get("beta2", 0.9)
    latent_dim = cfg.get("latent_dim", 128)
    n_critic = cfg.get("n_critic", 5)
    gp_lambda = cfg.get("gp_lambda", 10)
    ema_decay = cfg.get("ema_decay", 0.9999)
    sample_interval = cfg.get("sample_interval", 10)
    fid_interval = cfg.get("fid_interval", 25)
    fid_n_real = cfg.get("fid_n_real", 5000)
    loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=min(4, os.cpu_count() or 1),
        pin_memory=(device.type == "cuda"), drop_last=True,
    )
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
    opt_c = torch.optim.Adam(critic.parameters(), lr=lr_d, betas=(beta1, beta2))
    use_amp = device.type == "cuda"
    scaler_g = _GradScaler("cuda", enabled=use_amp)  # generator-only AMP scaler
    ema = EMA(generator, decay=ema_decay)
    # Fixed noise for consistent sample tracking across epochs
    fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device)
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    samples_dir = save_dir.parent / "samples" / run_name
    fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
    history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}}
    best_fid = float("inf")
    print(f"Device: {device} AMP (G only): {use_amp} Batches/epoch: {len(loader)} n_critic: {n_critic}")
    # Linear LR decay from epoch epochs//2 to epochs. max(..., 1) guards the
    # denominator for degenerate configs, matching train_vae / train_ddpm.
    decay_start = epochs // 2
    sched_g = torch.optim.lr_scheduler.LambdaLR(
        opt_g, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
    sched_c = torch.optim.lr_scheduler.LambdaLR(
        opt_c, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
    t_start = time.time()
    for epoch in range(1, epochs + 1):
        generator.train()
        critic.train()
        g_sum = w_sum = gp_sum = real_sum = fake_sum = 0.0
        n_c_steps = n_g_steps = 0
        for batch_idx, real in enumerate(tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False)):
            real = real.to(device)
            bsz = real.size(0)
            # ── Critic step (every batch) ─────────────────────────────────
            # Run the critic in float32 — the GP needs stable gradient norms
            # and AMP can degrade numerical stability here.
            opt_c.zero_grad()
            with torch.no_grad():
                fake = generator(torch.randn(bsz, latent_dim, 1, 1, device=device))
            real_f32 = real.float()
            fake_f32 = fake.float().detach()
            d_real = critic(real_f32)
            d_fake = critic(fake_f32)
            # fake_f32 is already detached above — no extra detach needed.
            gp = _gradient_penalty(critic, real_f32, fake_f32, device)
            c_loss = d_fake.mean() - d_real.mean() + gp_lambda * gp
            c_loss.backward()
            opt_c.step()
            w_dist = (d_real.mean() - d_fake.mean()).item()
            w_sum += w_dist
            gp_sum += gp.item()
            real_sum += d_real.mean().item()
            fake_sum += d_fake.mean().item()
            n_c_steps += 1
            # ── Generator step (every n_critic batches) ───────────────────
            if (batch_idx + 1) % n_critic == 0:
                opt_g.zero_grad()
                with _autocast("cuda", enabled=use_amp):
                    fake = generator(torch.randn(bsz, latent_dim, 1, 1, device=device))
                # Score fakes outside the autocast region so the critic truly
                # runs in float32, matching the policy documented above.
                g_loss = -critic(fake.float()).mean()
                scaler_g.scale(g_loss).backward()
                scaler_g.step(opt_g)
                scaler_g.update()
                ema.update(generator)
                g_sum += g_loss.item()
                n_g_steps += 1
        # max(…, 1) guards against epochs where no step of a kind completed.
        avg_w = w_sum / max(n_c_steps, 1)
        avg_gp = gp_sum / max(n_c_steps, 1)
        avg_g = g_sum / max(n_g_steps, 1)
        avg_r = real_sum / max(n_c_steps, 1)
        avg_f = fake_sum / max(n_c_steps, 1)
        history["g_loss"].append(avg_g)
        history["w_dist"].append(avg_w)
        history["gp"].append(avg_gp)
        history["d_real"].append(avg_r)
        history["d_fake"].append(avg_f)
        print(
            f"[{epoch:03d}/{epochs}] "
            f"G: {avg_g:.4f} W-dist: {avg_w:.4f} GP: {avg_gp:.4f} "
            f"C(real): {avg_r:.4f} C(fake): {avg_f:.4f}"
        )
        if epoch % sample_interval == 0:
            _save_samples(ema, samples_dir, epoch, fixed_noise=fixed_noise, device=device)
        if epoch % fid_interval == 0:
            # FID is computed from the EMA generator (the model we ship).
            ema.model.eval()
            with torch.no_grad():
                fake_imgs = torch.cat([
                    ema.model(torch.randn(64, latent_dim, 1, 1, device=device))
                    for _ in range(fid_n_real // 64 + 1)
                ])[:fid_n_real]
            fid_score = fid_eval.compute(fake_imgs)
            history["fid"][epoch] = fid_score
            print(f" FID @ epoch {epoch}: {fid_score:.2f}")
            if fid_score < best_fid:
                best_fid = fid_score
                torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
                torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
        sched_g.step()
        sched_c.step()
    torch.save(generator.state_dict(), save_dir / f"{run_name}_final_g.pt")
    torch.save(critic.state_dict(), save_dir / f"{run_name}_final_d.pt")
    torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")
    history["train_time_s"] = time.time() - t_start
    return history
# ────────────────────────────────────────────────────────────────────────────
# Phase 3 — VAE (3.1 MSE+KL · 3.2 +Perceptual · 3.3 +PatchGAN)
# ────────────────────────────────────────────────────────────────────────────
def _save_vae_samples(
    vae,
    samples_dir: Path,
    epoch: int,
    *,
    fixed_z: torch.Tensor,
    fixed_real: torch.Tensor,
    device,
) -> None:
    """Save prior samples and a real-vs-reconstruction grid side by side.

    Writes two images: `epoch_NNNN.png` (decoded prior latents) and
    `epoch_NNNN_recon.png` (alternating real / reconstruction rows).
    Temporarily switches the VAE to eval mode; restores train mode on exit.
    """
    samples_dir.mkdir(parents=True, exist_ok=True)
    vae.eval()
    with torch.no_grad():
        # Decode fixed latents from the prior; map [-1, 1] → [0, 1].
        from_prior = vae.decode(fixed_z.to(device))
        from_prior = (from_prior.clamp(-1, 1) + 1.0) / 2.0
        save_image(from_prior, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)
        # Reconstruct the fixed real batch.
        real_dev = fixed_real.to(device)
        recon, _, _ = vae(real_dev)
        recon = (recon.clamp(-1, 1) + 1.0) / 2.0
        shown_real = (real_dev + 1.0) / 2.0
        # Interleave real / reconstruction pairs row by row.
        pairs = torch.stack([shown_real, recon], dim=1).flatten(0, 1)
        save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4)
    vae.train()
def train_vae(
    vae,
    train_dataset,
    cfg: dict,
    *,
    save_dir,
    run_name: str,
    device: str = "cuda",
) -> dict:
    """VAE training loop covering Phase 3.1-3.3.

    Config toggles:
        lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+)
        lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3)

    Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl

    Args:
        vae: model exposing `vae(x) -> (recon, mu, log_var)`, `decode(z)`
            and `sample(n, device)` (the latter is used for FID generation).
        train_dataset: dataset of images (assumed in [-1, 1], matching the
            clamp/rescale convention used in sampling — TODO confirm).
        cfg: hyperparameter dict; missing keys fall back to the defaults
            read via `cfg.get(...)` below.
        save_dir: checkpoint directory (created if missing).
        run_name: prefix for checkpoint and sample file names.
        device: "cuda" — silently falls back to CPU when unavailable.

    Returns:
        history dict: per-epoch "recon_loss", "kl_loss", "perc_loss",
        "adv_g_loss", "adv_d_loss" lists, "fid" (dict keyed by epoch)
        and "train_time_s".
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    vae = vae.to(device)
    n_vae = sum(p.numel() for p in vae.parameters() if p.requires_grad)
    print(f"VAE: {n_vae:,} params")
    # ── Hyperparameters (cfg with defaults) ───────────────────────────────
    epochs = cfg["epochs"]
    batch_size = cfg["batch_size"]
    lr = cfg.get("lr", 1e-3)
    latent_dim = cfg.get("latent_dim", 256)
    beta_kl = cfg.get("beta_kl", 1.0)
    lambda_perceptual = cfg.get("lambda_perceptual", 0.0)
    lambda_adversarial = cfg.get("lambda_adversarial", 0.0)
    lr_d = cfg.get("lr_d", 1e-4)  # PatchGAN discriminator LR (Phase 3.3 only)
    ema_decay = cfg.get("ema_decay", 0.9999)
    sample_interval = cfg.get("sample_interval", 10)
    fid_interval = cfg.get("fid_interval", 25)
    fid_n_real = cfg.get("fid_n_real", 5000)
    # Zero weights disable the corresponding loss terms entirely.
    use_perceptual = lambda_perceptual > 0
    use_adversarial = lambda_adversarial > 0
    loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=min(4, os.cpu_count() or 1),
        pin_memory=(device.type == "cuda"), drop_last=True,
    )
    opt_vae = torch.optim.Adam(vae.parameters(), lr=lr)
    use_amp = device.type == "cuda"
    scaler = _GradScaler("cuda", enabled=use_amp)
    # KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training
    kl_warmup_epochs = max(1, epochs // 5)
    # Linear LR decay from epoch epochs//2 to epochs
    decay_start = epochs // 2
    sched_vae = torch.optim.lr_scheduler.LambdaLR(
        opt_vae, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
    sched_d = None  # set below if adversarial
    # ── Optional components ───────────────────────────────────────────────
    perc_fn = None
    patchgan = None
    opt_d = None
    scaler_d = None
    if use_perceptual:
        # Imported lazily so Phase 3.1 runs without torchvision's VGG download.
        from src.training.perceptual import PerceptualLoss
        perc_fn = PerceptualLoss().to(device)
        print("Perceptual loss: VGG-16 relu1_2 + relu2_2 + relu3_3")
    if use_adversarial:
        from src.models.patchgan import PatchGANDiscriminator, hinge_d_loss, hinge_g_loss
        patchgan = PatchGANDiscriminator(
            ndf=cfg.get("ndf_patch", 64),
            image_size=cfg.get("image_size", 64),
        ).to(device)
        opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999))
        scaler_d = _GradScaler("cuda", enabled=use_amp)  # separate scaler for D
        sched_d = torch.optim.lr_scheduler.LambdaLR(
            opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
        n_d = sum(p.numel() for p in patchgan.parameters())
        print(f"PatchGAN: {n_d:,} params")
    else:
        hinge_d_loss = hinge_g_loss = None  # satisfy linter, never called
    # ── Fixed seeds for consistent visualisation ──────────────────────────
    fixed_z = torch.randn(16, latent_dim, device=device)
    # Grab first 16 real images from the loader for reconstruction tracking
    # NOTE(review): `_it` is never exhausted/closed — with num_workers > 0
    # this may keep worker processes alive; consider closing it. Confirm.
    _it = iter(loader)
    fixed_real = next(_it)[:16].cpu()
    ema = EMA(vae, decay=ema_decay)
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    samples_dir = save_dir.parent / "samples" / run_name
    fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
    history = {
        "recon_loss": [], "kl_loss": [], "perc_loss": [],
        "adv_g_loss": [], "adv_d_loss": [], "fid": {},
    }
    best_fid = float("inf")
    print(
        f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}"
        f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual} λ_adv={lambda_adversarial}"
    )
    t_start = time.time()
    for epoch in range(1, epochs + 1):
        vae.train()
        if patchgan is not None:
            patchgan.train()
        recon_sum = kl_sum = perc_sum = adv_g_sum = adv_d_sum = 0.0
        n_batches = 0
        for real in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            real = real.to(device)
            # KL warmup: ramp from 0 to beta_kl over kl_warmup_epochs
            current_beta = beta_kl * min(1.0, epoch / kl_warmup_epochs)
            # ── VAE forward ───────────────────────────────────────────────
            with _autocast("cuda", enabled=use_amp):
                recon, mu, log_var = vae(real)
                mse = F.mse_loss(recon, real)
                # Closed-form KL(q(z|x) ‖ N(0, I)): summed over latent dims,
                # averaged over the batch.
                kl = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(1).mean()
                # Zero placeholder keeps the logging path uniform when disabled.
                perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze()
                vae_loss = mse + current_beta * kl + lambda_perceptual * perc
            # ── PatchGAN discriminator step ───────────────────────────────
            adv_d = real.new_zeros(1).squeeze()
            if use_adversarial:
                opt_d.zero_grad()
                with _autocast("cuda", enabled=use_amp):
                    d_real = patchgan(real)
                    # detach(): the D step must not backprop into the VAE.
                    d_fake = patchgan(recon.detach())
                    adv_d = hinge_d_loss(d_real, d_fake)
                scaler_d.scale(adv_d).backward()
                scaler_d.step(opt_d)
                scaler_d.update()
            # ── PatchGAN generator adversarial loss ───────────────────────
            adv_g = real.new_zeros(1).squeeze()
            if use_adversarial:
                with _autocast("cuda", enabled=use_amp):
                    # No detach here — this term trains the VAE decoder.
                    adv_g = hinge_g_loss(patchgan(recon))
                vae_loss = vae_loss + lambda_adversarial * adv_g
            # ── VAE backward ──────────────────────────────────────────────
            opt_vae.zero_grad()
            scaler.scale(vae_loss).backward()
            scaler.step(opt_vae)
            scaler.update()
            ema.update(vae)
            recon_sum += mse.item()
            kl_sum += kl.item()
            perc_sum += perc.item()
            adv_g_sum += adv_g.item()
            adv_d_sum += adv_d.item()
            n_batches += 1
        avg_r = recon_sum / n_batches
        avg_k = kl_sum / n_batches
        avg_p = perc_sum / n_batches
        avg_g = adv_g_sum / n_batches
        avg_d = adv_d_sum / n_batches
        history["recon_loss"].append(avg_r)
        history["kl_loss"].append(avg_k)
        history["perc_loss"].append(avg_p)
        history["adv_g_loss"].append(avg_g)
        history["adv_d_loss"].append(avg_d)
        # current_beta here is the last batch's value (constant within epoch).
        print(
            f"[{epoch:03d}/{epochs}] "
            f"MSE: {avg_r:.4f} KL: {avg_k:.2f} β={current_beta:.6f} "
            f"Perc: {avg_p:.4f} AdvG: {avg_g:.4f} AdvD: {avg_d:.4f}"
        )
        if epoch % sample_interval == 0:
            # Samples come from the EMA weights, not the raw model.
            _save_vae_samples(
                ema.model, samples_dir, epoch,
                fixed_z=fixed_z, fixed_real=fixed_real, device=device,
            )
        if epoch % fid_interval == 0:
            ema.model.eval()
            with torch.no_grad():
                fake_imgs = torch.cat([
                    ema.model.sample(64, device)
                    for _ in range(fid_n_real // 64 + 1)
                ])[:fid_n_real]
            fid_score = fid_eval.compute(fake_imgs)
            history["fid"][epoch] = fid_score
            print(f" FID @ epoch {epoch}: {fid_score:.2f}")
            if fid_score < best_fid:
                best_fid = fid_score
                torch.save(vae.state_dict(), save_dir / f"{run_name}_best_vae.pt")
                torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
        sched_vae.step()
        if sched_d is not None:
            sched_d.step()
    torch.save(vae.state_dict(), save_dir / f"{run_name}_final_vae.pt")
    torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")
    if patchgan is not None:
        torch.save(patchgan.state_dict(), save_dir / f"{run_name}_final_patchgan.pt")
    history["train_time_s"] = time.time() - t_start
    return history
# ────────────────────────────────────────────────────────────────────────────
# Phase 4 — DDPM (4.1 linear·ε · 4.2 cosine·ε · 4.3 cosine·v · 4.4 wider)
# ────────────────────────────────────────────────────────────────────────────
def train_ddpm(
    model,
    train_dataset,
    cfg: dict,
    *,
    save_dir,
    run_name: str,
    device: str = "cuda",
) -> dict:
    """DDPM training loop (Ho et al., 2020) covering Phase 4.1-4.4.

    Config keys:
        noise_schedule — "linear" (4.1) or "cosine" (4.2+)
        pred_type — "eps" (4.1-4.2) or "v" (4.3+)
        T — diffusion timesteps (default 1000)
        base_ch / ch_mult / attn_resolutions — U-Net capacity (see unet.py)
        ddim_steps — DDIM steps for FID evaluation (default 100)

    Args:
        model: U-Net predicting eps or v, called by `diffusion_loss`.
        train_dataset: dataset of images (assumed in [-1, 1] — matches the
            clamp/rescale used when saving samples).
        cfg: hyperparameter dict; missing keys fall back to defaults below.
        save_dir: checkpoint directory (created if missing).
        run_name: prefix for checkpoint / sample file names.
        device: "cuda" — silently falls back to CPU when unavailable.

    Returns:
        history dict: per-epoch "loss" list, "fid" dict keyed by epoch, and
        "train_time_s".
    """
    from src.training.diffusion import (
        linear_betas, cosine_betas, make_alpha_bars,
        diffusion_loss, ddim_sample,
    )
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"U-Net: {n_params:,} params")
    epochs = cfg["epochs"]
    batch_size = cfg["batch_size"]
    lr = cfg.get("lr", 2e-4)
    T = cfg.get("T", 1000)
    noise_schedule = cfg.get("noise_schedule", "linear")
    pred_type = cfg.get("pred_type", "eps")
    ddim_steps = cfg.get("ddim_steps", 100)
    image_size = cfg.get("image_size", 64)
    ema_decay = cfg.get("ema_decay", 0.9999)
    sample_interval = cfg.get("sample_interval", 10)
    fid_interval = cfg.get("fid_interval", 25)
    fid_n_real = cfg.get("fid_n_real", 5000)
    # Build noise schedule on the training device
    betas = (cosine_betas(T) if noise_schedule == "cosine" else linear_betas(T)).to(device)
    alpha_bars = make_alpha_bars(betas)  # on device
    loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=min(4, os.cpu_count() or 1),
        pin_memory=(device.type == "cuda"), drop_last=True,
    )
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    use_amp = device.type == "cuda"
    scaler = _GradScaler("cuda", enabled=use_amp)
    ema = EMA(model, decay=ema_decay)
    # NOTE(review): a `fixed_noise` tensor was previously allocated here
    # ("fixed latents across epochs"), but ddim_sample draws its own noise
    # and never consumed it — dead variable removed.
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    samples_dir = save_dir.parent / "samples" / run_name
    fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
    history = {"loss": [], "fid": {}}
    best_fid = float("inf")
    print(
        f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}"
        f" T={T} schedule={noise_schedule} pred={pred_type} ddim_steps={ddim_steps}"
    )
    # Linear LR decay from epoch epochs//2 to epochs
    decay_start = epochs // 2
    sched = torch.optim.lr_scheduler.LambdaLR(
        opt, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
    t_start = time.time()
    for epoch in range(1, epochs + 1):
        model.train()
        loss_sum = 0.0
        n_batches = 0
        for x0 in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            x0 = x0.to(device)
            # Uniformly sample one timestep per image.
            t = torch.randint(0, T, (x0.size(0),), device=device)
            with _autocast("cuda", enabled=use_amp):
                loss = diffusion_loss(model, x0, t, alpha_bars, pred_type)
            opt.zero_grad()
            scaler.scale(loss).backward()
            # Unscale before clipping so the norm threshold applies to the
            # true (unscaled) gradients.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
            ema.update(model)
            loss_sum += loss.item()
            n_batches += 1
        avg_loss = loss_sum / n_batches
        history["loss"].append(avg_loss)
        print(f"[{epoch:03d}/{epochs}] Loss: {avg_loss:.5f}")
        if epoch % sample_interval == 0:
            samples_dir.mkdir(parents=True, exist_ok=True)
            ema.model.eval()
            with torch.no_grad():
                # Quick visualisation: 16 DDIM samples at 50 steps; latents
                # are drawn inside ddim_sample.
                imgs = ddim_sample(
                    ema.model, 16, image_size, alpha_bars,
                    n_steps=50, pred_type=pred_type, device=str(device), batch_size=16,
                )
                imgs = (imgs.clamp(-1, 1) + 1.0) / 2.0
                save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)
        if epoch % fid_interval == 0:
            # FID is computed from the EMA weights (the model we ship).
            ema.model.eval()
            fake_imgs = ddim_sample(
                ema.model, fid_n_real, image_size, alpha_bars,
                n_steps=ddim_steps, pred_type=pred_type,
                device=str(device), batch_size=32,
            )
            fid_score = fid_eval.compute(fake_imgs)
            history["fid"][epoch] = fid_score
            print(f" FID @ epoch {epoch}: {fid_score:.2f}")
            if fid_score < best_fid:
                best_fid = fid_score
                torch.save(model.state_dict(), save_dir / f"{run_name}_best_unet.pt")
                torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
        sched.step()
    torch.save(model.state_dict(), save_dir / f"{run_name}_final_unet.pt")
    torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")
    history["train_time_s"] = time.time() - t_start
    return history