Phase 5 preparation

This commit is contained in:
Johnny Fernandes
2026-05-03 15:47:13 +01:00
parent bac52bc15e
commit 8c0e845b5c
16 changed files with 197 additions and 298 deletions
+4 -2
View File
@@ -1,7 +1,7 @@
{ {
"run_name": "p5_ddpm", "run_name": "p5_ddpm",
"model": "ddpm", "model": "ddpm",
"epochs": 200, "epochs": 400,
"data_dir": "cropped/generator", "data_dir": "cropped/generator",
"sources": ["wiki"], "sources": ["wiki"],
"augment": "hflip", "augment": "hflip",
@@ -18,5 +18,7 @@
"ddim_steps": 100, "ddim_steps": 100,
"sample_interval": 10, "sample_interval": 10,
"fid_interval": 25, "fid_interval": 25,
"fid_n_real": 5000 "fid_n_real": 5000,
"max_train_hours": 24,
"fid_patience": 2
} }
+5 -3
View File
@@ -1,11 +1,11 @@
{ {
"run_name": "p5_gan", "run_name": "p5_gan",
"model": "wgan", "model": "wgan",
"epochs": 200, "epochs": 400,
"data_dir": "cropped/generator", "data_dir": "cropped/generator",
"sources": ["wiki"], "sources": ["wiki"],
"augment": true, "augment": true,
"image_size": 128, "image_size": 64,
"latent_dim": 128, "latent_dim": 128,
"ngf": 128, "ngf": 128,
"ndf": 128, "ndf": 128,
@@ -17,5 +17,7 @@
"gp_lambda": 10, "gp_lambda": 10,
"sample_interval": 10, "sample_interval": 10,
"fid_interval": 25, "fid_interval": 25,
"fid_n_real": 5000 "fid_n_real": 5000,
"max_train_hours": 24,
"fid_patience": 2
} }
+5 -3
View File
@@ -1,7 +1,7 @@
{ {
"run_name": "p5_vae", "run_name": "p5_vae",
"model": "vae", "model": "vae",
"epochs": 200, "epochs": 400,
"data_dir": "cropped/generator", "data_dir": "cropped/generator",
"sources": ["wiki"], "sources": ["wiki"],
"augment": "hflip", "augment": "hflip",
@@ -10,11 +10,13 @@
"ngf": 64, "ngf": 64,
"lr": 1e-3, "lr": 1e-3,
"lr_d": 1e-4, "lr_d": 1e-4,
"beta_kl": 0.0001, "beta_kl": 0.25,
"lambda_perceptual": 0.1, "lambda_perceptual": 0.1,
"lambda_adversarial": 0.1, "lambda_adversarial": 0.1,
"ndf_patch": 64, "ndf_patch": 64,
"sample_interval": 10, "sample_interval": 10,
"fid_interval": 25, "fid_interval": 25,
"fid_n_real": 5000 "fid_n_real": 5000,
"max_train_hours": 24,
"fid_patience": 2
} }
+2 -12
View File
@@ -6,13 +6,8 @@ import torchvision.transforms as T
from torch.utils.data import Dataset from torch.utils.data import Dataset
# Unlabeled image dataset for generative model training; returns tensors only, no labels
class GeneratorDataset(Dataset): class GeneratorDataset(Dataset):
"""Unlabeled image dataset for generative model training.
Loads images from source subdirectories and returns tensors only —
no labels, since generation is unsupervised.
"""
def __init__(self, data_dir, sources=None, subsample=1.0, transform=None, seed=42): def __init__(self, data_dir, sources=None, subsample=1.0, transform=None, seed=42):
self.transform = transform self.transform = transform
self.samples = [] self.samples = []
@@ -51,13 +46,8 @@ class GeneratorDataset(Dataset):
return img return img
# Builds transform pipeline outputting [-1, 1]; augment=False (none), "hflip" (flip only), True (flip+rot+jitter)
def get_transform(image_size: int, augment=False) -> T.Compose: def get_transform(image_size: int, augment=False) -> T.Compose:
"""Build transform for generator training. Output is in [-1, 1].
augment=False — no augmentation (for FID real-image sets)
augment="hflip" — horizontal flip only (recommended for VAE/DDPM)
augment=True — H-flip + rotation ±5° + mild color jitter (for GAN)
"""
ops = [ ops = [
T.Resize(image_size), T.Resize(image_size),
T.CenterCrop(image_size), T.CenterCrop(image_size),
+5 -4
View File
@@ -1,18 +1,18 @@
from typing import Callable from typing import Callable
import torch.nn as nn import torch.nn as nn
# Maps model name → (builder, kind); populated by each model module at import time
_REGISTRY: dict[str, tuple[Callable, str]] = {} _REGISTRY: dict[str, tuple[Callable, str]] = {}
# Called by each model module to advertise its name and kind to get_model
def register(name: str, builder: Callable, *, kind: str) -> None: def register(name: str, builder: Callable, *, kind: str) -> None:
_REGISTRY[name] = (builder, kind) _REGISTRY[name] = (builder, kind)
# Returns (model_or_pair, kind) for the model named in cfg["model"]
# kind in {"dcgan", "wgan"} → (generator, critic) | kind in {"vae", "ddpm"} → model
def get_model(cfg: dict) -> tuple: def get_model(cfg: dict) -> tuple:
"""Return (model_or_pair, kind).
kind="dcgan" -> (generator, discriminator)
"""
name = cfg.get("model") name = cfg.get("model")
entry = _REGISTRY.get(name) entry = _REGISTRY.get(name)
if entry is None: if entry is None:
@@ -22,6 +22,7 @@ def get_model(cfg: dict) -> tuple:
return builder(cfg), kind return builder(cfg), kind
# Importing each module triggers its register() calls
from src.models import dcgan # noqa: E402, F401 from src.models import dcgan # noqa: E402, F401
from src.models import wgan # noqa: E402, F401 from src.models import wgan # noqa: E402, F401
from src.models import vae # noqa: E402, F401 from src.models import vae # noqa: E402, F401
+7 -17
View File
@@ -1,15 +1,6 @@
"""Vanilla DCGAN (Radford et al., 2015). # Vanilla DCGAN (Radford et al., 2015) — Phase 1 baseline for cheap pipeline ablations.
# Depth scales with image_size; each step doubles spatial size from a 4×4 stem.
Used as the Phase 1 baseline for cheap pipeline ablations. Architecture is # 64 → 4 upsample steps | 128 → 5 upsample steps
intentionally minimal — BatchNorm in both networks, no spectral norm, no
attention, no gradient penalty. The whole point is to be the cheapest GAN
we can run, so 1A1D pipeline deltas show up in FID quickly.
Depth scales with image_size: each step doubles the spatial dimension,
starting from 4×4 after the first transposed conv.
64 -> 5 layers (1 -> 4 -> 8 -> 16 -> 32 -> 64)
128 -> 6 layers (1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128)
"""
import math import math
import torch import torch
@@ -27,16 +18,15 @@ def _init_weights(m):
nn.init.zeros_(m.bias) nn.init.zeros_(m.bias)
# Number of 2× upsampling steps from the 4×4 stem to image_size
def _n_upsamples(image_size: int) -> int: def _n_upsamples(image_size: int) -> int:
"""Number of 2x upsampling steps from 4x4 to image_size."""
if image_size < 8 or image_size & (image_size - 1): if image_size < 8 or image_size & (image_size - 1):
raise ValueError(f"image_size must be a power of two ≥ 8, got {image_size}") raise ValueError(f"image_size must be a power of two ≥ 8, got {image_size}")
return int(math.log2(image_size)) - 2 # 64 -> 4, 128 -> 5 return int(math.log2(image_size)) - 2 # 64 -> 4, 128 -> 5
# Maps (latent_dim × 1 × 1) → (3 × image_size × image_size) in [-1, 1]
class DCGANGenerator(nn.Module): class DCGANGenerator(nn.Module):
"""Maps (latent_dim x 1 x 1) -> (3 x image_size x image_size) in [-1, 1]."""
def __init__(self, latent_dim: int = 100, ngf: int = 64, image_size: int = 64): def __init__(self, latent_dim: int = 100, ngf: int = 64, image_size: int = 64):
super().__init__() super().__init__()
n_up = _n_upsamples(image_size) # 64 -> 4 upsamples after the 1->4 init n_up = _n_upsamples(image_size) # 64 -> 4 upsamples after the 1->4 init
@@ -69,9 +59,8 @@ class DCGANGenerator(nn.Module):
return self.net(z) return self.net(z)
# Maps (3 × image_size × image_size) → scalar logit (no sigmoid)
class DCGANDiscriminator(nn.Module): class DCGANDiscriminator(nn.Module):
"""Maps (3 x image_size x image_size) -> scalar logit (no sigmoid)."""
def __init__(self, ndf: int = 64, image_size: int = 64): def __init__(self, ndf: int = 64, image_size: int = 64):
super().__init__() super().__init__()
n_down = _n_upsamples(image_size) n_down = _n_upsamples(image_size)
@@ -97,6 +86,7 @@ class DCGANDiscriminator(nn.Module):
return self.net(x).view(x.size(0)) return self.net(x).view(x.size(0))
# Builds a (DCGANGenerator, DCGANDiscriminator) pair from cfg
def _build(cfg: dict): def _build(cfg: dict):
image_size = cfg.get("image_size", 64) image_size = cfg.get("image_size", 64)
return ( return (
+6 -17
View File
@@ -1,11 +1,5 @@
"""PatchGAN discriminator for Phase 3.3 (VQGAN-lite adversarial training). # PatchGAN discriminator for Phase 3.3 adversarial training: outputs a spatial patch logit map, not a scalar.
# Not in the model registry — instantiated directly inside train_vae when lambda_adversarial > 0.
Outputs a spatial patch map instead of a single scalar — each patch
predicts real/fake independently. Loss is the mean over all patches.
Not registered in the model registry; instantiated inside train_vae
when lambda_adversarial > 0.
"""
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -17,14 +11,9 @@ def _init_weights(m):
nn.init.zeros_(m.bias) nn.init.zeros_(m.bias)
# Stride-2 + stride-1 conv chain → spatial patch logit map; supports image_size ∈ {64, 128}
# 64×64 → 6×6 map (70×70 receptive field); 128×128 adds an extra stride-2 layer. InstanceNorm except first layer.
class PatchGANDiscriminator(nn.Module): class PatchGANDiscriminator(nn.Module):
"""Stride-2 + stride-1 convolution chain → spatial patch logit map.
Supports image_size ∈ {64, 128}. For 64×64 input the final map is 6×6
(70×70 receptive field). For 128×128 an extra stride-2 layer is added.
InstanceNorm everywhere except the first layer.
"""
def __init__(self, ndf: int = 64, image_size: int = 64): def __init__(self, ndf: int = 64, image_size: int = 64):
super().__init__() super().__init__()
layers: list[nn.Module] = [ layers: list[nn.Module] = [
@@ -57,13 +46,13 @@ class PatchGANDiscriminator(nn.Module):
return self.net(x) # (B, 1, H', W') — patch logit map return self.net(x) # (B, 1, H', W') — patch logit map
# Hinge loss for the discriminator (Lim & Ye, 2017)
def hinge_d_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor: def hinge_d_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor:
"""Hinge loss for the discriminator (Lim & Ye, 2017)."""
loss_real = torch.mean(torch.relu(1.0 - real_logits)) loss_real = torch.mean(torch.relu(1.0 - real_logits))
loss_fake = torch.mean(torch.relu(1.0 + fake_logits)) loss_fake = torch.mean(torch.relu(1.0 + fake_logits))
return 0.5 * (loss_real + loss_fake) return 0.5 * (loss_real + loss_fake)
# Generator hinge loss — maximises D(fake)
def hinge_g_loss(fake_logits: torch.Tensor) -> torch.Tensor: def hinge_g_loss(fake_logits: torch.Tensor) -> torch.Tensor:
"""Generator hinge loss — maximise D(fake)."""
return -torch.mean(fake_logits) return -torch.mean(fake_logits)
+7 -24
View File
@@ -1,13 +1,5 @@
"""Time-conditioned U-Net for DDPM (Phase 4). # Time-conditioned U-Net for DDPM (Phase 4), following Ho et al. (2020) with Nichol & Dhariwal (2021) options.
# Sinusoidal time embedding → MLP → injected additively into every ResBlock; GroupNorm(32) + SiLU throughout.
Architecture follows Ho et al. (2020) with options from Nichol & Dhariwal (2021):
- Sinusoidal time embedding → MLP → added to every ResBlock
- GroupNorm (32 groups) + SiLU activations throughout
- Self-attention at configurable spatial resolutions
- Upsample(nearest) + Conv in the decoder — no checkerboard artefacts
Registered as kind="ddpm".
"""
import math import math
import torch import torch
@@ -37,9 +29,8 @@ class SinusoidalPosEmb(nn.Module):
# ── Core building blocks ────────────────────────────────────────────────────── # ── Core building blocks ──────────────────────────────────────────────────────
# ResNet block with additive time-embedding injection after the first conv
class ResBlock(nn.Module): class ResBlock(nn.Module):
"""ResNet block with time-embedding injection (additive, after first conv)."""
def __init__(self, in_ch: int, out_ch: int, t_emb_dim: int, dropout: float = 0.1): def __init__(self, in_ch: int, out_ch: int, t_emb_dim: int, dropout: float = 0.1):
super().__init__() super().__init__()
self.norm1 = nn.GroupNorm(_GN, in_ch) self.norm1 = nn.GroupNorm(_GN, in_ch)
@@ -57,9 +48,8 @@ class ResBlock(nn.Module):
return h + self.skip(x) return h + self.skip(x)
# Single-head self-attention with GroupNorm pre-norm and residual
class AttentionBlock(nn.Module): class AttentionBlock(nn.Module):
"""Single-head self-attention with GroupNorm pre-norm and residual."""
def __init__(self, ch: int): def __init__(self, ch: int):
super().__init__() super().__init__()
self.norm = nn.GroupNorm(_GN, ch) self.norm = nn.GroupNorm(_GN, ch)
@@ -155,17 +145,9 @@ class UpBlock(nn.Module):
# ── U-Net ───────────────────────────────────────────────────────────────────── # ── U-Net ─────────────────────────────────────────────────────────────────────
# Time-conditioned U-Net; image_size must be a power-of-two (64 or 128 recommended)
# ch_mult sets channel multipliers per level; attn_resolutions controls where attention is inserted
class UNet(nn.Module): class UNet(nn.Module):
"""Time-conditioned U-Net.
image_size — must be a power-of-two; 64 or 128 recommended.
base_ch — base channel count (128 for phases 4.1–4.3, 192 for 4.4).
ch_mult — channel multipliers per resolution level.
attn_resolutions — spatial resolutions at which attention is inserted.
num_res_blocks — ResBlocks per level (in both down and up paths).
dropout — applied inside every ResBlock.
"""
def __init__( def __init__(
self, self,
image_size: int = 64, image_size: int = 64,
@@ -265,6 +247,7 @@ class UNet(nn.Module):
return self.out_conv(F.silu(self.out_norm(x))) return self.out_conv(F.silu(self.out_norm(x)))
# Builds a UNet from cfg for DDPM training
def _build(cfg: dict): def _build(cfg: dict):
return UNet( return UNet(
image_size = cfg.get("image_size", 64), image_size = cfg.get("image_size", 64),
+8 -20
View File
@@ -1,13 +1,5 @@
"""Convolutional VAE for Phase 3. # Convolutional VAE for Phase 3. Encoder: stride-2 Conv → flatten → linear (μ, log σ²).
# Decoder: Linear → Upsample(nearest) + Conv — avoids ConvTranspose2d checkerboard artefacts.
Encoder uses stride-2 Conv → flatten → linear (μ, log σ²).
Decoder uses Linear → Upsample(nearest) + Conv to avoid ConvTranspose2d
checkerboard artefacts.
Registered as kind="vae". The run.py dispatcher passes the model to
train_vae(), which internally builds perceptual loss and PatchGAN when
the corresponding lambdas are non-zero.
"""
import math import math
import torch import torch
@@ -31,8 +23,8 @@ def _norm(channels: int) -> nn.GroupNorm:
return nn.GroupNorm(8, channels) return nn.GroupNorm(8, channels)
# Nearest-neighbour upsample followed by a 3×3 conv — no checkerboard artefacts
def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential: def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential:
"""Nearest-neighbour upsample followed by a 3×3 conv — no checkerboard."""
return nn.Sequential( return nn.Sequential(
nn.Upsample(scale_factor=2, mode="nearest"), nn.Upsample(scale_factor=2, mode="nearest"),
nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False), nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
@@ -41,20 +33,15 @@ def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential:
) )
# Convolutional VAE; image_size must be a power-of-two ≥ 32
# Spatial bottleneck is always at 4×4 — encoder and decoder scale stride-2 steps accordingly
class VAE(nn.Module): class VAE(nn.Module):
"""Convolutional VAE. image_size must be a power-of-two ≥ 32.
Spatial bottleneck is always at 4×4 regardless of image_size —
the encoder and decoder scale the number of stride-2 steps accordingly.
"""
def __init__(self, latent_dim: int = 256, ngf: int = 64, image_size: int = 64): def __init__(self, latent_dim: int = 256, ngf: int = 64, image_size: int = 64):
super().__init__() super().__init__()
if image_size < 32 or (image_size & (image_size - 1)): if image_size < 32 or (image_size & (image_size - 1)):
raise ValueError(f"image_size must be a power-of-two ≥ 32, got {image_size}") raise ValueError(f"image_size must be a power-of-two ≥ 32, got {image_size}")
self.latent_dim = latent_dim self.latent_dim = latent_dim
self.image_size = image_size
n_down = int(math.log2(image_size)) - 2 # steps from image_size to 4×4 n_down = int(math.log2(image_size)) - 2 # steps from image_size to 4×4
# 64 → n_down=4: 64→32→16→8→4 # 64 → n_down=4: 64→32→16→8→4
@@ -115,19 +102,20 @@ class VAE(nn.Module):
h = self.fc_dec(z).view(z.size(0), self._dec_ch, 4, 4) h = self.fc_dec(z).view(z.size(0), self._dec_ch, 4, 4)
return self.decoder(h) return self.decoder(h)
# Returns (reconstruction, mu, log_var)
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Returns (reconstruction, mu, log_var)."""
mu, log_var = self.encode(x) mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var) z = self.reparameterize(mu, log_var)
return self.decode(z), mu, log_var return self.decode(z), mu, log_var
# Samples n images by drawing z ~ N(0, I) and decoding
@torch.no_grad() @torch.no_grad()
def sample(self, n: int, device) -> torch.Tensor: def sample(self, n: int, device) -> torch.Tensor:
"""Sample n images by drawing z ~ N(0, I)."""
z = torch.randn(n, self.latent_dim, device=device) z = torch.randn(n, self.latent_dim, device=device)
return self.decode(z) return self.decode(z)
# Builds a VAE from cfg
def _build(cfg: dict): def _build(cfg: dict):
return VAE( return VAE(
latent_dim=cfg.get("latent_dim", 256), latent_dim=cfg.get("latent_dim", 256),
+12 -34
View File
@@ -1,8 +1,6 @@
"""WGAN-GP variants. # WGAN-GP variants.
# wgan_basic — Phase 2.2: BatchNorm/InstanceNorm, no attention, 64×64 only.
wgan_basic — Phase 2.2: BatchNorm/InstanceNorm, no attention, 64×64 only. # wgan — Phase 2.3/2.4: GroupNorm/SpectralNorm + self-attention, size-agnostic.
wgan — Phase 2.3/2.4: GroupNorm/SpectralNorm + self-attention, size-agnostic.
"""
import math import math
import torch import torch
@@ -23,9 +21,8 @@ def _sn(module):
return nn.utils.spectral_norm(module) return nn.utils.spectral_norm(module)
# SAGAN-style self-attention: gamma-gated residual, dot-products scaled by mid^-0.5
class SelfAttention(nn.Module): class SelfAttention(nn.Module):
"""SAGAN-style self-attention."""
def __init__(self, in_ch: int): def __init__(self, in_ch: int):
super().__init__() super().__init__()
mid = max(in_ch // 8, 1) mid = max(in_ch // 8, 1)
@@ -48,13 +45,8 @@ class SelfAttention(nn.Module):
# Phase 2.2 — basic WGAN-GP (BatchNorm in G, InstanceNorm in D, 64×64 only) # Phase 2.2 — basic WGAN-GP (BatchNorm in G, InstanceNorm in D, 64×64 only)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Maps (latent_dim, 1, 1) → (3, 64, 64) in [-1, 1]; BatchNorm in G is safe under WGAN-GP (constraint targets the critic)
class WGANBasicGenerator(nn.Module): class WGANBasicGenerator(nn.Module):
"""Maps (latent_dim, 1, 1) -> (3, 64, 64) in [-1, 1].
Same channel structure as DCGAN. BatchNorm in generator is fine because
WGAN-GP's constraint targets the critic, not the generator.
"""
def __init__(self, latent_dim: int = 128, ngf: int = 64): def __init__(self, latent_dim: int = 128, ngf: int = 64):
super().__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
@@ -80,11 +72,8 @@ class WGANBasicGenerator(nn.Module):
return self.net(z) return self.net(z)
# WGAN-GP critic (64×64); InstanceNorm instead of BatchNorm — BatchNorm breaks the per-sample Lipschitz constraint
class WGANBasicCritic(nn.Module): class WGANBasicCritic(nn.Module):
"""WGAN-GP critic (64×64). InstanceNorm instead of BatchNorm — BatchNorm
breaks the per-sample Lipschitz constraint the gradient penalty enforces.
"""
def __init__(self, ndf: int = 64): def __init__(self, ndf: int = 64):
super().__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
@@ -116,21 +105,14 @@ class WGANBasicCritic(nn.Module):
# Phase 2.3 / 2.4 — advanced WGAN-GP (GroupNorm, SpectralNorm, attention) # Phase 2.3 / 2.4 — advanced WGAN-GP (GroupNorm, SpectralNorm, attention)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# GroupNorm generator with SAGAN self-attention; supports image_size ∈ {64, 128}
# Stem always 1×1 → 4×4 → 8×8 → 16×16; attention at 16×16, and at 32×32 for 128×128
class WGANGenerator(nn.Module): class WGANGenerator(nn.Module):
"""GroupNorm generator with SAGAN self-attention.
Supports image_size ∈ {64, 128}.
Stem is always 1×1 → 4×4 → 8×8 → 16×16 (ngf×8 → ngf×4 → ngf×2 channels).
Attention at 16×16 always; additional attention at 32×32 for 128×128.
"""
def __init__(self, latent_dim: int = 128, ngf: int = 128, image_size: int = 64): def __init__(self, latent_dim: int = 128, ngf: int = 128, image_size: int = 64):
super().__init__() super().__init__()
if image_size not in (64, 128): if image_size not in (64, 128):
raise ValueError(f"WGANGenerator supports image_size 64 or 128, got {image_size}") raise ValueError(f"WGANGenerator supports image_size 64 or 128, got {image_size}")
self._image_size = image_size
self.stem = nn.Sequential( self.stem = nn.Sequential(
# 1×1 → 4×4 # 1×1 → 4×4
nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False), nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False),
@@ -179,20 +161,14 @@ class WGANGenerator(nn.Module):
return self.tail(h) return self.tail(h)
# SpectralNorm critic with SAGAN self-attention; supports image_size ∈ {64, 128}
# Attention at 16×16 always; additional attention at 32×32 for 128×128
class WGANCritic(nn.Module): class WGANCritic(nn.Module):
"""SpectralNorm critic with SAGAN self-attention.
Supports image_size ∈ {64, 128}.
Attention at 16×16 always; additional attention at 32×32 for 128×128.
"""
def __init__(self, ndf: int = 128, image_size: int = 64): def __init__(self, ndf: int = 128, image_size: int = 64):
super().__init__() super().__init__()
if image_size not in (64, 128): if image_size not in (64, 128):
raise ValueError(f"WGANCritic supports image_size 64 or 128, got {image_size}") raise ValueError(f"WGANCritic supports image_size 64 or 128, got {image_size}")
self._image_size = image_size
if image_size == 64: if image_size == 64:
# Head: 64→32 (ndf//2) # Head: 64→32 (ndf//2)
self.head = nn.Sequential( self.head = nn.Sequential(
@@ -242,6 +218,7 @@ class WGANCritic(nn.Module):
return self.tail(h).view(x.size(0)) return self.tail(h).view(x.size(0))
# Builds a (WGANBasicGenerator, WGANBasicCritic) pair from cfg
def _build_basic(cfg: dict): def _build_basic(cfg: dict):
return ( return (
WGANBasicGenerator(latent_dim=cfg.get("latent_dim", 128), ngf=cfg.get("ngf", 64)), WGANBasicGenerator(latent_dim=cfg.get("latent_dim", 128), ngf=cfg.get("ngf", 64)),
@@ -249,6 +226,7 @@ def _build_basic(cfg: dict):
) )
# Builds a (WGANGenerator, WGANCritic) pair from cfg
def _build(cfg: dict): def _build(cfg: dict):
image_size = cfg.get("image_size", 64) image_size = cfg.get("image_size", 64)
return ( return (
+9 -24
View File
@@ -1,13 +1,5 @@
"""Gaussian diffusion utilities for Phase 4 (DDPM). # Gaussian diffusion utilities for Phase 4 (DDPM): noise schedules, forward process, training loss, DDIM sampling.
# alpha_bars[t] = ᾱ_{t+1} (0-indexed); t=0 is near-clean (ᾱ ≈ 1−β₁), t=T−1 is near-pure-noise (ᾱ ≈ 0).
Provides noise schedules, the forward (noising) process, training loss,
and DDIM deterministic sampling (Song et al., 2020).
Convention: alpha_bars is a 1-D tensor of length T, where alpha_bars[t]
= ᾱ_{t+1} in 1-indexed notation. Timestep t used in the training loop
is a 0-indexed integer in [0, T). At t=0 the image is almost clean
(ᾱ ≈ 1 − β_1); at t=T−1 the image is almost pure noise (ᾱ ≈ 0).
"""
import math import math
import torch import torch
@@ -16,13 +8,13 @@ import torch.nn.functional as F
# ── Noise schedules ────────────────────────────────────────────────────────── # ── Noise schedules ──────────────────────────────────────────────────────────
# Ho et al. (2020) linear schedule
def linear_betas(T: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor: def linear_betas(T: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor:
"""Ho et al. (2020) linear schedule."""
return torch.linspace(beta_start, beta_end, T) return torch.linspace(beta_start, beta_end, T)
# Nichol & Dhariwal (2021) cosine schedule — avoids over-denoising at low t
def cosine_betas(T: int, s: float = 0.008) -> torch.Tensor: def cosine_betas(T: int, s: float = 0.008) -> torch.Tensor:
"""Nichol & Dhariwal (2021) cosine schedule — avoids over-denoising at low t."""
t = torch.linspace(0, T, T + 1) t = torch.linspace(0, T, T + 1)
f = torch.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2 f = torch.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2
alpha_bar = f / f[0] alpha_bar = f / f[0]
@@ -30,20 +22,20 @@ def cosine_betas(T: int, s: float = 0.008) -> torch.Tensor:
return betas.clamp(max=0.999) return betas.clamp(max=0.999)
# Cumulative product of (1 − β), shape (T,)
def make_alpha_bars(betas: torch.Tensor) -> torch.Tensor: def make_alpha_bars(betas: torch.Tensor) -> torch.Tensor:
"""Cumulative product of (1 − β), shape (T,)."""
return (1.0 - betas).cumprod(0) return (1.0 - betas).cumprod(0)
# ── Forward process ────────────────────────────────────────────────────────── # ── Forward process ──────────────────────────────────────────────────────────
# Adds noise to x0 at timestep t; returns (x_t, noise)
def q_sample( def q_sample(
x0: torch.Tensor, x0: torch.Tensor,
t: torch.Tensor, t: torch.Tensor,
alpha_bars: torch.Tensor, alpha_bars: torch.Tensor,
noise: torch.Tensor | None = None, noise: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""Add noise to x0 at timestep t. Returns (x_t, noise)."""
if noise is None: if noise is None:
noise = torch.randn_like(x0) noise = torch.randn_like(x0)
ab = alpha_bars[t].to(x0.device)[:, None, None, None] ab = alpha_bars[t].to(x0.device)[:, None, None, None]
@@ -53,6 +45,7 @@ def q_sample(
# ── Training loss ──────────────────────────────────────────────────────────── # ── Training loss ────────────────────────────────────────────────────────────
# MSE between model prediction and target; pred_type="eps" targets noise ε, pred_type="v" targets v = √ᾱ·ε − √(1−ᾱ)·x0
def diffusion_loss( def diffusion_loss(
model, model,
x0: torch.Tensor, x0: torch.Tensor,
@@ -60,11 +53,6 @@ def diffusion_loss(
alpha_bars: torch.Tensor, alpha_bars: torch.Tensor,
pred_type: str = "eps", pred_type: str = "eps",
) -> torch.Tensor: ) -> torch.Tensor:
"""MSE on the model's prediction vs the true target.
pred_type="eps" → target is the added noise ε (Ho et al.)
pred_type="v" → target is v = √ᾱ·ε − √(1−ᾱ)·x0 (Salimans & Ho)
"""
x_t, noise = q_sample(x0, t, alpha_bars) x_t, noise = q_sample(x0, t, alpha_bars)
pred = model(x_t, t) pred = model(x_t, t)
@@ -80,6 +68,8 @@ def diffusion_loss(
# ── DDIM deterministic sampling ─────────────────────────────────────────────── # ── DDIM deterministic sampling ───────────────────────────────────────────────
@torch.no_grad() @torch.no_grad()
# Generates n images via DDIM (eta=0, deterministic); batches internally to avoid OOM
# Returns tensor shape (n, 3, image_size, image_size) in [-1, 1]
def ddim_sample( def ddim_sample(
model, model,
n: int, n: int,
@@ -90,11 +80,6 @@ def ddim_sample(
device: str = "cuda", device: str = "cuda",
batch_size: int = 32, batch_size: int = 32,
) -> torch.Tensor: ) -> torch.Tensor:
"""Generate n images via DDIM (eta=0, deterministic).
Batches internally to avoid OOM when n is large.
Returns tensor shape (n, 3, image_size, image_size) in [-1, 1].
"""
model.eval() model.eval()
T = len(alpha_bars) T = len(alpha_bars)
+2 -6
View File
@@ -3,13 +3,9 @@ import torch
import torch.nn as nn import torch.nn as nn
# Exponential moving average of model weights; shadow copy updated after each optimizer step
# Sample from ema.model at eval time, never from the training model directly
class EMA: class EMA:
"""Exponential moving average of model weights.
Maintains a shadow copy of the model. Call update() after each
optimizer step. Sample from ema.model, never from the training model.
"""
def __init__(self, model: nn.Module, decay: float = 0.9999): def __init__(self, model: nn.Module, decay: float = 0.9999):
self.decay = decay self.decay = decay
self.model = copy.deepcopy(model).eval() self.model = copy.deepcopy(model).eval()
+9 -16
View File
@@ -1,20 +1,17 @@
"""FID evaluation helper. # FID evaluation helper: caches real images as a CPU tensor to avoid re-reading disk on every eval call.
Computes Fréchet Inception Distance between a fixed set of real images
and a batch of generated images. Real images are stored as a tensor on CPU
and moved to device only during evaluation — this avoids re-reading disk
every call while keeping GPU memory free between evaluations.
"""
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchmetrics.image.fid import FrechetInceptionDistance from torchmetrics.image.fid import FrechetInceptionDistance
# Computes FID between cached real images and a batch of generated images
class FIDEvaluator: class FIDEvaluator:
def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda", def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda",
num_workers: int = 2): num_workers: int = 2):
self.device = torch.device(device if torch.cuda.is_available() else "cpu") self.device = torch.device(device if torch.cuda.is_available() else "cpu")
self.n_real = n_real self.n_real = n_real
# Inception network loaded once here; compute() calls reset() between evaluations
self._fid = FrechetInceptionDistance(feature=2048, normalize=True).to(self.device)
# Cache real images as a CPU tensor ([-1, 1] range) # Cache real images as a CPU tensor ([-1, 1] range)
imgs_list = [] imgs_list = []
@@ -27,24 +24,20 @@ class FIDEvaluator:
real = torch.cat(imgs_list)[:n_real] real = torch.cat(imgs_list)[:n_real]
self._real = real # stored on CPU, shape (N, 3, H, W) in [-1, 1] self._real = real # stored on CPU, shape (N, 3, H, W) in [-1, 1]
# Computes FID score; fake_imgs should be float in [-1, 1], shape (N, 3, H, W), N ≥ 2048 for reliability
@torch.no_grad() @torch.no_grad()
def compute(self, fake_imgs: torch.Tensor) -> float: def compute(self, fake_imgs: torch.Tensor) -> float:
"""Compute FID score. self._fid.reset()
fake_imgs: float tensor in [-1, 1], shape (N, 3, H, W).
N should be at least 2048 for a reliable score.
"""
fid = FrechetInceptionDistance(feature=2048, normalize=True).to(self.device)
# Feed real images in batches # Feed real images in batches
for i in range(0, self._real.size(0), 256): for i in range(0, self._real.size(0), 256):
batch = (self._real[i:i + 256] * 0.5 + 0.5).clamp(0, 1).to(self.device) batch = (self._real[i:i + 256] * 0.5 + 0.5).clamp(0, 1).to(self.device)
fid.update(batch, real=True) self._fid.update(batch, real=True)
# Feed fake images in batches # Feed fake images in batches
fake = fake_imgs.cpu() fake = fake_imgs.cpu()
for i in range(0, fake.size(0), 256): for i in range(0, fake.size(0), 256):
batch = (fake[i:i + 256] * 0.5 + 0.5).clamp(0, 1).to(self.device) batch = (fake[i:i + 256] * 0.5 + 0.5).clamp(0, 1).to(self.device)
fid.update(batch, real=False) self._fid.update(batch, real=False)
return float(fid.compute()) return float(self._fid.compute())
+5 -16
View File
@@ -1,25 +1,17 @@
"""Extended generation quality metrics for Phase 5 cross-family comparison. # Extended generation quality metrics for Phase 5 cross-family comparison.
# IS (Salimans et al., 2016): quality × diversity. LPIPS diversity: pairwise perceptual distance.
IS — Inception Score (Salimans et al., 2016): measures sample quality × diversity. # Both functions accept float tensors in [-1, 1].
LPIPS — average pairwise learned perceptual distance: measures sample diversity alone.
Both functions accept float tensors in [-1, 1].
"""
import torch import torch
from torchmetrics.image.inception import InceptionScore from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
# Inception Score (mean ± std) over 10 splits; imgs in [-1, 1], N ≥ 2048 for reliability; returns (mean, std)
def compute_is( def compute_is(
imgs: torch.Tensor, imgs: torch.Tensor,
device: str = "cuda", device: str = "cuda",
batch_size: int = 64, batch_size: int = 64,
) -> tuple[float, float]: ) -> tuple[float, float]:
"""Inception Score (mean ± std) over 10 splits.
imgs: (N, 3, H, W) in [-1, 1]. N ≥ 2 048 for a reliable estimate.
Returns (is_mean, is_std).
"""
metric = InceptionScore(normalize=True).to(device) metric = InceptionScore(normalize=True).to(device)
imgs_01 = (imgs.clamp(-1, 1) * 0.5 + 0.5) imgs_01 = (imgs.clamp(-1, 1) * 0.5 + 0.5)
for i in range(0, len(imgs_01), batch_size): for i in range(0, len(imgs_01), batch_size):
@@ -28,16 +20,13 @@ def compute_is(
return float(mean), float(std) return float(mean), float(std)
# Average pairwise LPIPS distance over n_pairs random (i ≠ j) pairs; higher = more diverse
def compute_lpips_diversity( def compute_lpips_diversity(
imgs: torch.Tensor, imgs: torch.Tensor,
n_pairs: int = 200, n_pairs: int = 200,
device: str = "cuda", device: str = "cuda",
batch_size: int = 16, batch_size: int = 16,
) -> float: ) -> float:
"""Average pairwise LPIPS distance — higher means more diverse samples.
imgs: (N, 3, H, W) in [-1, 1]. Samples n_pairs random (i, j) pairs with i ≠ j.
"""
metric = LearnedPerceptualImagePatchSimilarity(net_type="alex", normalize=True).to(device) metric = LearnedPerceptualImagePatchSimilarity(net_type="alex", normalize=True).to(device)
imgs_01 = (imgs.clamp(-1, 1) * 0.5 + 0.5) imgs_01 = (imgs.clamp(-1, 1) * 0.5 + 0.5)
N = len(imgs_01) N = len(imgs_01)
+5 -18
View File
@@ -1,26 +1,13 @@
"""VGG-16 perceptual loss for Phase 3.2 and 3.3. # VGG-16 perceptual loss (Phase 3.2/3.3): L1 feature distance at relu1_2, relu2_2, relu3_3.
# Expects images in [-1, 1]; converts internally to [0, 1] then ImageNet-normalises before VGG. Weights frozen.
Extracts features at relu1_2, relu2_2, relu3_3 and returns the
L1 distance in feature space. VGG weights are frozen.
Input convention: images in [-1, 1] — the loss converts internally to
[0, 1] and then applies ImageNet normalisation before passing to VGG.
"""
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchvision.models as tv_models import torchvision.models as tv_models
# L1 feature-matching loss at relu1_2 ([:4]), relu2_2 ([4:9]), relu3_3 ([9:16]) of VGG-16
class PerceptualLoss(nn.Module): class PerceptualLoss(nn.Module):
"""L1 feature-matching loss at three VGG-16 layers.
VGG-16 feature indices:
relu1_2: features[:4] (before first maxpool)
relu2_2: features[4:9] (before second maxpool)
relu3_3: features[9:16] (before third maxpool)
"""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
vgg = tv_models.vgg16(weights=tv_models.VGG16_Weights.IMAGENET1K_V1) vgg = tv_models.vgg16(weights=tv_models.VGG16_Weights.IMAGENET1K_V1)
@@ -39,13 +26,13 @@ class PerceptualLoss(nn.Module):
"std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
) )
# Converts [-1, 1] → [0, 1] then applies ImageNet normalisation
def _normalise(self, x: torch.Tensor) -> torch.Tensor: def _normalise(self, x: torch.Tensor) -> torch.Tensor:
"""Convert [-1, 1] → ImageNet-normalised [0, 1]."""
x = x * 0.5 + 0.5 # → [0, 1] x = x * 0.5 + 0.5 # → [0, 1]
return (x - self.mean) / self.std return (x - self.mean) / self.std
# L1 feature distance; real gradients are stopped — only fake trains
def forward(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor: def forward(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
"""L1 feature distance. real gradients are stopped — only fake trains."""
f = self._normalise(fake) f = self._normalise(fake)
r = self._normalise(real) r = self._normalise(real)
+106 -82
View File
@@ -21,6 +21,31 @@ else:
_autocast = lambda device_type="", enabled=True, **kw: _AC(enabled=enabled, **kw) _autocast = lambda device_type="", enabled=True, **kw: _AC(enabled=enabled, **kw)
# Linear decay from decay_start to epochs; returns a multiplier clamped to [0, 1]
def _linear_decay(decay_start: int, epochs: int):
return lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1))
# Returns (new_best_fid, new_no_improve, should_stop); improved = new_best_fid < best_fid
def _check_fid_plateau(fid_score: float, best_fid: float, no_improve: int, patience) -> tuple[float, int, bool]:
if fid_score < best_fid:
return fid_score, 0, False
no_improve += 1
if patience and no_improve >= patience:
print(f"FID plateau — {no_improve} evals without improvement, stopping")
return best_fid, no_improve, True
return best_fid, no_improve, False
# Generates n GAN fake images in batches of 64 using the EMA model
@torch.no_grad()
def _generate_gan_fakes(ema_model, n: int, latent_dim: int, device) -> torch.Tensor:
return torch.cat([
ema_model(torch.randn(64, latent_dim, 1, 1, device=device))
for _ in range(n // 64 + 1)
])[:n]
def _save_samples(generator_ema, samples_dir: Path, epoch: int, *, fixed_noise: torch.Tensor, device) -> None: def _save_samples(generator_ema, samples_dir: Path, epoch: int, *, fixed_noise: torch.Tensor, device) -> None:
samples_dir.mkdir(parents=True, exist_ok=True) samples_dir.mkdir(parents=True, exist_ok=True)
with torch.no_grad(): with torch.no_grad():
@@ -29,6 +54,7 @@ def _save_samples(generator_ema, samples_dir: Path, epoch: int, *, fixed_noise:
save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=4) save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)
# Vanilla DCGAN training loop (Phase 1): BCE loss, single G/D step per batch, no gradient penalty
def train_dcgan( def train_dcgan(
generator, generator,
discriminator, discriminator,
@@ -39,11 +65,6 @@ def train_dcgan(
run_name: str, run_name: str,
device: str = "cuda", device: str = "cuda",
) -> dict: ) -> dict:
"""Vanilla DCGAN training loop with BCE loss (Radford et al., 2015).
Used as the Phase 1 baseline for cheap pipeline ablations. No gradient
penalty, no n_critic, single G/D step per batch.
"""
device = torch.device(device if torch.cuda.is_available() else "cpu") device = torch.device(device if torch.cuda.is_available() else "cpu")
generator = generator.to(device) generator = generator.to(device)
discriminator = discriminator.to(device) discriminator = discriminator.to(device)
@@ -79,6 +100,10 @@ def train_dcgan(
ema = EMA(generator, decay=ema_decay) ema = EMA(generator, decay=ema_decay)
if hasattr(torch, "compile"):
generator = torch.compile(generator)
discriminator = torch.compile(discriminator)
# Fixed noise for consistent sample tracking across epochs # Fixed noise for consistent sample tracking across epochs
fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device) fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device)
@@ -91,18 +116,22 @@ def train_dcgan(
history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}} history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}}
best_fid = float("inf") best_fid = float("inf")
print(f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}") print(f"Device: {device} AMP: {use_amp} compile: {hasattr(torch, 'compile')} Batches/epoch: {len(loader)}")
# Linear LR decay from epoch epochs//2 to epochs # Linear LR decay from epoch epochs//2 to epochs
decay_start = epochs // 2 decay_start = epochs // 2
sched_g = torch.optim.lr_scheduler.LambdaLR( sched_g = torch.optim.lr_scheduler.LambdaLR(opt_g, _linear_decay(decay_start, epochs))
opt_g, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start))) sched_d = torch.optim.lr_scheduler.LambdaLR(opt_d, _linear_decay(decay_start, epochs))
sched_d = torch.optim.lr_scheduler.LambdaLR(
opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start)))
t_start = time.time() t_start = time.time()
max_train_secs = cfg.get("max_train_hours", 0) * 3600 or None
fid_patience = cfg.get("fid_patience", None)
no_improve = 0
for epoch in range(1, epochs + 1): for epoch in range(1, epochs + 1):
if max_train_secs and time.time() - t_start >= max_train_secs:
print(f"Time limit reached — stopping after epoch {epoch - 1}")
break
generator.train() generator.train()
discriminator.train() discriminator.train()
g_sum = d_sum = real_sum = fake_sum = 0.0 g_sum = d_sum = real_sum = fake_sum = 0.0
@@ -160,19 +189,18 @@ def train_dcgan(
if epoch % fid_interval == 0: if epoch % fid_interval == 0:
ema.model.eval() ema.model.eval()
with torch.no_grad(): fake_imgs = _generate_gan_fakes(ema.model, fid_n_real, latent_dim, device)
fake_imgs = torch.cat([
ema.model(torch.randn(64, latent_dim, 1, 1, device=device))
for _ in range(fid_n_real // 64 + 1)
])[:fid_n_real]
fid_score = fid_eval.compute(fake_imgs) fid_score = fid_eval.compute(fake_imgs)
history["fid"][epoch] = fid_score history["fid"][epoch] = fid_score
print(f" FID @ epoch {epoch}: {fid_score:.2f}") print(f" FID @ epoch {epoch}: {fid_score:.2f}")
if fid_score < best_fid: new_best, no_improve, plateau = _check_fid_plateau(fid_score, best_fid, no_improve, fid_patience)
best_fid = fid_score if new_best < best_fid:
torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt") torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt") torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
best_fid = new_best
if plateau:
break
sched_g.step() sched_g.step()
sched_d.step() sched_d.step()
@@ -184,8 +212,8 @@ def train_dcgan(
return history return history
# Two-sided gradient penalty (Gulrajani et al., 2017)
def _gradient_penalty(critic, real: torch.Tensor, fake: torch.Tensor, device) -> torch.Tensor: def _gradient_penalty(critic, real: torch.Tensor, fake: torch.Tensor, device) -> torch.Tensor:
"""Two-sided gradient penalty (Gulrajani et al., 2017)."""
bsz = real.size(0) bsz = real.size(0)
eps = torch.rand(bsz, 1, 1, 1, device=device) eps = torch.rand(bsz, 1, 1, 1, device=device)
interp = (eps * real + (1.0 - eps) * fake).requires_grad_(True) interp = (eps * real + (1.0 - eps) * fake).requires_grad_(True)
@@ -200,6 +228,8 @@ def _gradient_penalty(critic, real: torch.Tensor, fake: torch.Tensor, device) ->
return ((grad.norm(2, dim=[1, 2, 3]) - 1) ** 2).mean() return ((grad.norm(2, dim=[1, 2, 3]) - 1) ** 2).mean()
# WGAN-GP training loop (Phase 2.2–2.4): gradient penalty replaces weight clipping
# Critic runs in float32 for GP stability; AMP used only for the generator forward/backward
def train_wgan( def train_wgan(
generator, generator,
critic, critic,
@@ -210,12 +240,6 @@ def train_wgan(
run_name: str, run_name: str,
device: str = "cuda", device: str = "cuda",
) -> dict: ) -> dict:
"""WGAN-GP training loop (Gulrajani et al., 2017).
Used for Phase 2.2–2.4. Gradient penalty replaces weight clipping.
The critic runs in float32 to keep GP gradient computation numerically
stable; AMP is used only for the generator forward/backward.
"""
device = torch.device(device if torch.cuda.is_available() else "cpu") device = torch.device(device if torch.cuda.is_available() else "cpu")
generator = generator.to(device) generator = generator.to(device)
critic = critic.to(device) critic = critic.to(device)
@@ -251,6 +275,10 @@ def train_wgan(
ema = EMA(generator, decay=ema_decay) ema = EMA(generator, decay=ema_decay)
if hasattr(torch, "compile"):
generator = torch.compile(generator)
critic = torch.compile(critic)
# Fixed noise for consistent sample tracking across epochs # Fixed noise for consistent sample tracking across epochs
fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device) fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device)
@@ -263,18 +291,22 @@ def train_wgan(
history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}} history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}}
best_fid = float("inf") best_fid = float("inf")
print(f"Device: {device} AMP (G only): {use_amp} Batches/epoch: {len(loader)} n_critic: {n_critic}") print(f"Device: {device} AMP (G only): {use_amp} compile: {hasattr(torch, 'compile')} Batches/epoch: {len(loader)} n_critic: {n_critic}")
# Linear LR decay from epoch epochs//2 to epochs # Linear LR decay from epoch epochs//2 to epochs
decay_start = epochs // 2 decay_start = epochs // 2
sched_g = torch.optim.lr_scheduler.LambdaLR( sched_g = torch.optim.lr_scheduler.LambdaLR(opt_g, _linear_decay(decay_start, epochs))
opt_g, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start))) sched_c = torch.optim.lr_scheduler.LambdaLR(opt_c, _linear_decay(decay_start, epochs))
sched_c = torch.optim.lr_scheduler.LambdaLR(
opt_c, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start)))
t_start = time.time() t_start = time.time()
max_train_secs = cfg.get("max_train_hours", 0) * 3600 or None
fid_patience = cfg.get("fid_patience", None)
no_improve = 0
for epoch in range(1, epochs + 1): for epoch in range(1, epochs + 1):
if max_train_secs and time.time() - t_start >= max_train_secs:
print(f"Time limit reached — stopping after epoch {epoch - 1}")
break
generator.train() generator.train()
critic.train() critic.train()
g_sum = w_sum = gp_sum = real_sum = fake_sum = 0.0 g_sum = w_sum = gp_sum = real_sum = fake_sum = 0.0
@@ -292,11 +324,11 @@ def train_wgan(
fake = generator(torch.randn(bsz, latent_dim, 1, 1, device=device)) fake = generator(torch.randn(bsz, latent_dim, 1, 1, device=device))
real_f32 = real.float() real_f32 = real.float()
fake_f32 = fake.float().detach() fake_f32 = fake.float() # already no_grad from torch.no_grad() context above
d_real = critic(real_f32) d_real = critic(real_f32)
d_fake = critic(fake_f32) d_fake = critic(fake_f32)
gp = _gradient_penalty(critic, real_f32, fake_f32.detach(), device) gp = _gradient_penalty(critic, real_f32, fake_f32, device)
c_loss = d_fake.mean() - d_real.mean() + gp_lambda * gp c_loss = d_fake.mean() - d_real.mean() + gp_lambda * gp
c_loss.backward() c_loss.backward()
opt_c.step() opt_c.step()
@@ -342,19 +374,18 @@ def train_wgan(
if epoch % fid_interval == 0: if epoch % fid_interval == 0:
ema.model.eval() ema.model.eval()
with torch.no_grad(): fake_imgs = _generate_gan_fakes(ema.model, fid_n_real, latent_dim, device)
fake_imgs = torch.cat([
ema.model(torch.randn(64, latent_dim, 1, 1, device=device))
for _ in range(fid_n_real // 64 + 1)
])[:fid_n_real]
fid_score = fid_eval.compute(fake_imgs) fid_score = fid_eval.compute(fake_imgs)
history["fid"][epoch] = fid_score history["fid"][epoch] = fid_score
print(f" FID @ epoch {epoch}: {fid_score:.2f}") print(f" FID @ epoch {epoch}: {fid_score:.2f}")
if fid_score < best_fid: new_best, no_improve, plateau = _check_fid_plateau(fid_score, best_fid, no_improve, fid_patience)
best_fid = fid_score if new_best < best_fid:
torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt") torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt") torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
best_fid = new_best
if plateau:
break
sched_g.step() sched_g.step()
sched_c.step() sched_c.step()
@@ -370,6 +401,7 @@ def train_wgan(
# Phase 3 — VAE (3.1 MSE+KL · 3.2 +Perceptual · 3.3 +PatchGAN) # Phase 3 — VAE (3.1 MSE+KL · 3.2 +Perceptual · 3.3 +PatchGAN)
# ──────────────────────────────────────────────────────────────────────────── # ────────────────────────────────────────────────────────────────────────────
# Saves prior samples and a real-vs-reconstruction interleaved grid for the epoch
def _save_vae_samples( def _save_vae_samples(
vae, vae,
samples_dir: Path, samples_dir: Path,
@@ -379,7 +411,6 @@ def _save_vae_samples(
fixed_real: torch.Tensor, fixed_real: torch.Tensor,
device, device,
) -> None: ) -> None:
"""Save prior samples and a real-vs-reconstruction grid side by side."""
samples_dir.mkdir(parents=True, exist_ok=True) samples_dir.mkdir(parents=True, exist_ok=True)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
@@ -395,6 +426,8 @@ def _save_vae_samples(
save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4) save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4)
# VAE training loop (Phase 3.1–3.3): lambda_perceptual > 0 adds VGG loss, lambda_adversarial > 0 adds PatchGAN
# AMP disabled — float16 overflows on KL spikes causing unrecoverable NaN cascades; everything runs in float32
def train_vae( def train_vae(
vae, vae,
train_dataset, train_dataset,
@@ -404,22 +437,6 @@ def train_vae(
run_name: str, run_name: str,
device: str = "cuda", device: str = "cuda",
) -> dict: ) -> dict:
"""VAE training loop covering Phase 3.1 3.3 and Phase 5.
Config toggles:
lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+)
lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3)
Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl
KL is computed as mean over latent dimensions (scale-invariant), so
beta_kl is comparable across different latent_dim values.
AMP is intentionally disabled for VAE training — mixed-precision float16
overflows when the KL divergence spikes, producing NaN cascades that
corrupt the model irrecoverably. All VAE + perceptual + PatchGAN
computation runs in float32.
"""
device = torch.device(device if torch.cuda.is_available() else "cpu") device = torch.device(device if torch.cuda.is_available() else "cpu")
vae = vae.to(device) vae = vae.to(device)
@@ -450,16 +467,13 @@ def train_vae(
) )
opt_vae = torch.optim.Adam(vae.parameters(), lr=lr) opt_vae = torch.optim.Adam(vae.parameters(), lr=lr)
# AMP disabled — float16 overflows on KL spikes, causing NaN cascades
use_amp = False
# KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training # KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training
kl_warmup_epochs = max(1, epochs // 5) kl_warmup_epochs = max(1, epochs // 5)
# Linear LR decay from epoch epochs//2 to epochs # Linear LR decay from epoch epochs//2 to epochs
decay_start = epochs // 2 decay_start = epochs // 2
sched_vae = torch.optim.lr_scheduler.LambdaLR( sched_vae = torch.optim.lr_scheduler.LambdaLR(opt_vae, _linear_decay(decay_start, epochs))
opt_vae, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
sched_d = None # set below if adversarial sched_d = None # set below if adversarial
# ── Optional components ─────────────────────────────────────────────── # ── Optional components ───────────────────────────────────────────────
@@ -479,12 +493,11 @@ def train_vae(
image_size=cfg.get("image_size", 64), image_size=cfg.get("image_size", 64),
).to(device).float() ).to(device).float()
opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999)) opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999))
sched_d = torch.optim.lr_scheduler.LambdaLR( sched_d = torch.optim.lr_scheduler.LambdaLR(opt_d, _linear_decay(decay_start, epochs))
opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
n_d = sum(p.numel() for p in patchgan.parameters()) n_d = sum(p.numel() for p in patchgan.parameters())
print(f"PatchGAN: {n_d:,} params") print(f"PatchGAN: {n_d:,} params")
else: if hasattr(torch, "compile"):
hinge_d_loss = hinge_g_loss = None # never called patchgan = torch.compile(patchgan)
# ── Fixed seeds for consistent visualisation ────────────────────────── # ── Fixed seeds for consistent visualisation ──────────────────────────
fixed_z = torch.randn(16, latent_dim, device=device) fixed_z = torch.randn(16, latent_dim, device=device)
@@ -494,6 +507,9 @@ def train_vae(
ema = EMA(vae, decay=ema_decay) ema = EMA(vae, decay=ema_decay)
if hasattr(torch, "compile"):
vae = torch.compile(vae)
save_dir = Path(save_dir) save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True) save_dir.mkdir(parents=True, exist_ok=True)
samples_dir = save_dir.parent / "samples" / run_name samples_dir = save_dir.parent / "samples" / run_name
@@ -514,8 +530,14 @@ def train_vae(
) )
t_start = time.time() t_start = time.time()
max_train_secs = cfg.get("max_train_hours", 0) * 3600 or None
fid_patience = cfg.get("fid_patience", None)
no_improve = 0
for epoch in range(1, epochs + 1): for epoch in range(1, epochs + 1):
if max_train_secs and time.time() - t_start >= max_train_secs:
print(f"Time limit reached — stopping after epoch {epoch - 1}")
break
vae.train() vae.train()
if patchgan is not None: if patchgan is not None:
patchgan.train() patchgan.train()
@@ -614,10 +636,13 @@ def train_vae(
history["fid"][epoch] = fid_score history["fid"][epoch] = fid_score
print(f" FID @ epoch {epoch}: {fid_score:.2f}") print(f" FID @ epoch {epoch}: {fid_score:.2f}")
if fid_score < best_fid: new_best, no_improve, plateau = _check_fid_plateau(fid_score, best_fid, no_improve, fid_patience)
best_fid = fid_score if new_best < best_fid:
torch.save(vae.state_dict(), save_dir / f"{run_name}_best_vae.pt") torch.save(vae.state_dict(), save_dir / f"{run_name}_best_vae.pt")
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt") torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
best_fid = new_best
if plateau:
break
sched_vae.step() sched_vae.step()
if sched_d is not None: if sched_d is not None:
@@ -636,6 +661,7 @@ def train_vae(
# Phase 4 — DDPM (4.1 linear·ε · 4.2 cosine·ε · 4.3 cosine·v · 4.4 wider) # Phase 4 — DDPM (4.1 linear·ε · 4.2 cosine·ε · 4.3 cosine·v · 4.4 wider)
# ──────────────────────────────────────────────────────────────────────────── # ────────────────────────────────────────────────────────────────────────────
# DDPM training loop (Phase 4.1–4.4): noise_schedule="linear"/"cosine", pred_type="eps"/"v"
def train_ddpm( def train_ddpm(
model, model,
train_dataset, train_dataset,
@@ -645,15 +671,6 @@ def train_ddpm(
run_name: str, run_name: str,
device: str = "cuda", device: str = "cuda",
) -> dict: ) -> dict:
"""DDPM training loop (Ho et al., 2020) covering Phase 4.1 4.4.
Config keys:
noise_schedule — "linear" (4.1) or "cosine" (4.2+)
pred_type — "eps" (4.1–4.2) or "v" (4.3+)
T — diffusion timesteps (default 1000)
base_ch / ch_mult / attn_resolutions — U-Net capacity (see unet.py)
ddim_steps — DDIM steps for FID evaluation (default 100)
"""
from src.training.diffusion import ( from src.training.diffusion import (
linear_betas, cosine_betas, make_alpha_bars, linear_betas, cosine_betas, make_alpha_bars,
diffusion_loss, ddim_sample, diffusion_loss, ddim_sample,
@@ -694,8 +711,8 @@ def train_ddpm(
ema = EMA(model, decay=ema_decay) ema = EMA(model, decay=ema_decay)
# Fixed noise for sample visualisation (same latents across epochs) if hasattr(torch, "compile"):
fixed_noise = torch.randn(16, 3, image_size, image_size, device=device) model = torch.compile(model)
save_dir = Path(save_dir) save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True) save_dir.mkdir(parents=True, exist_ok=True)
@@ -707,18 +724,23 @@ def train_ddpm(
history = {"loss": [], "fid": {}} history = {"loss": [], "fid": {}}
best_fid = float("inf") best_fid = float("inf")
print( print(
f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}" f"Device: {device} AMP: {use_amp} compile: {hasattr(torch, 'compile')} Batches/epoch: {len(loader)}"
f" T={T} schedule={noise_schedule} pred={pred_type} ddim_steps={ddim_steps}" f" T={T} schedule={noise_schedule} pred={pred_type} ddim_steps={ddim_steps}"
) )
# Linear LR decay from epoch epochs//2 to epochs # Linear LR decay from epoch epochs//2 to epochs
decay_start = epochs // 2 decay_start = epochs // 2
sched = torch.optim.lr_scheduler.LambdaLR( sched = torch.optim.lr_scheduler.LambdaLR(opt, _linear_decay(decay_start, epochs))
opt, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
t_start = time.time() t_start = time.time()
max_train_secs = cfg.get("max_train_hours", 0) * 3600 or None
fid_patience = cfg.get("fid_patience", None)
no_improve = 0
for epoch in range(1, epochs + 1): for epoch in range(1, epochs + 1):
if max_train_secs and time.time() - t_start >= max_train_secs:
print(f"Time limit reached — stopping after epoch {epoch - 1}")
break
model.train() model.train()
loss_sum = 0.0 loss_sum = 0.0
n_batches = 0 n_batches = 0
@@ -749,7 +771,6 @@ def train_ddpm(
samples_dir.mkdir(parents=True, exist_ok=True) samples_dir.mkdir(parents=True, exist_ok=True)
ema.model.eval() ema.model.eval()
with torch.no_grad(): with torch.no_grad():
# Quick visualisation: denoise fixed_noise via DDIM
imgs = ddim_sample( imgs = ddim_sample(
ema.model, 16, image_size, alpha_bars, ema.model, 16, image_size, alpha_bars,
n_steps=50, pred_type=pred_type, device=str(device), batch_size=16, n_steps=50, pred_type=pred_type, device=str(device), batch_size=16,
@@ -768,10 +789,13 @@ def train_ddpm(
history["fid"][epoch] = fid_score history["fid"][epoch] = fid_score
print(f" FID @ epoch {epoch}: {fid_score:.2f}") print(f" FID @ epoch {epoch}: {fid_score:.2f}")
if fid_score < best_fid: new_best, no_improve, plateau = _check_fid_plateau(fid_score, best_fid, no_improve, fid_patience)
best_fid = fid_score if new_best < best_fid:
torch.save(model.state_dict(), save_dir / f"{run_name}_best_unet.pt") torch.save(model.state_dict(), save_dir / f"{run_name}_best_unet.pt")
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt") torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
best_fid = new_best
if plateau:
break
sched.step() sched.step()