Phase 5 preparation
This commit is contained in:
@@ -6,13 +6,8 @@ import torchvision.transforms as T
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
# Unlabeled image dataset for generative model training; returns tensors only, no labels
|
||||
class GeneratorDataset(Dataset):
|
||||
"""Unlabeled image dataset for generative model training.
|
||||
|
||||
Loads images from source subdirectories and returns tensors only —
|
||||
no labels, since generation is unsupervised.
|
||||
"""
|
||||
|
||||
def __init__(self, data_dir, sources=None, subsample=1.0, transform=None, seed=42):
|
||||
self.transform = transform
|
||||
self.samples = []
|
||||
@@ -51,13 +46,8 @@ class GeneratorDataset(Dataset):
|
||||
return img
|
||||
|
||||
|
||||
# Builds transform pipeline outputting [-1, 1]; augment=False (none), "hflip" (flip only), True (flip+rot+jitter)
|
||||
def get_transform(image_size: int, augment=False) -> T.Compose:
|
||||
"""Build transform for generator training. Output is in [-1, 1].
|
||||
|
||||
augment=False — no augmentation (for FID real-image sets)
|
||||
augment="hflip" — horizontal flip only (recommended for VAE/DDPM)
|
||||
augment=True — H-flip + rotation ±5° + mild color jitter (for GAN)
|
||||
"""
|
||||
ops = [
|
||||
T.Resize(image_size),
|
||||
T.CenterCrop(image_size),
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
from typing import Callable
|
||||
import torch.nn as nn
|
||||
|
||||
# Maps model name → (builder, kind); populated by each model module at import time
|
||||
_REGISTRY: dict[str, tuple[Callable, str]] = {}
|
||||
|
||||
|
||||
# Called by each model module to advertise its name and kind to get_model
|
||||
def register(name: str, builder: Callable, *, kind: str) -> None:
|
||||
_REGISTRY[name] = (builder, kind)
|
||||
|
||||
|
||||
# Returns (model_or_pair, kind) for the model named in cfg["model"]
|
||||
# kind in {"dcgan", "wgan"} → (generator, critic) | kind in {"vae", "ddpm"} → model
|
||||
def get_model(cfg: dict) -> tuple:
|
||||
"""Return (model_or_pair, kind).
|
||||
|
||||
kind="dcgan" -> (generator, discriminator)
|
||||
"""
|
||||
name = cfg.get("model")
|
||||
entry = _REGISTRY.get(name)
|
||||
if entry is None:
|
||||
@@ -22,6 +22,7 @@ def get_model(cfg: dict) -> tuple:
|
||||
return builder(cfg), kind
|
||||
|
||||
|
||||
# Importing each module triggers its register() calls
|
||||
from src.models import dcgan # noqa: E402, F401
|
||||
from src.models import wgan # noqa: E402, F401
|
||||
from src.models import vae # noqa: E402, F401
|
||||
|
||||
@@ -1,15 +1,6 @@
|
||||
"""Vanilla DCGAN (Radford et al., 2015).
|
||||
|
||||
Used as the Phase 1 baseline for cheap pipeline ablations. Architecture is
|
||||
intentionally minimal — BatchNorm in both networks, no spectral norm, no
|
||||
attention, no gradient penalty. The whole point is to be the cheapest GAN
|
||||
we can run, so 1A–1D pipeline deltas show up in FID quickly.
|
||||
|
||||
Depth scales with image_size: each step doubles the spatial dimension,
|
||||
starting from 4×4 after the first transposed conv.
|
||||
64 -> 5 layers (1 -> 4 -> 8 -> 16 -> 32 -> 64)
|
||||
128 -> 6 layers (1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128)
|
||||
"""
|
||||
# Vanilla DCGAN (Radford et al., 2015) — Phase 1 baseline for cheap pipeline ablations.
|
||||
# Depth scales with image_size; each step doubles spatial size from a 4×4 stem.
|
||||
# 64 → 5 upsample steps | 128 → 6 upsample steps
|
||||
import math
|
||||
|
||||
import torch
|
||||
@@ -27,16 +18,15 @@ def _init_weights(m):
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
|
||||
# Number of 2× upsampling steps from the 4×4 stem to image_size
|
||||
def _n_upsamples(image_size: int) -> int:
|
||||
"""Number of 2x upsampling steps from 4x4 to image_size."""
|
||||
if image_size < 8 or image_size & (image_size - 1):
|
||||
raise ValueError(f"image_size must be a power of two ≥ 8, got {image_size}")
|
||||
return int(math.log2(image_size)) - 2 # 64 -> 4, 128 -> 5
|
||||
|
||||
|
||||
# Maps (latent_dim × 1 × 1) → (3 × image_size × image_size) in [-1, 1]
|
||||
class DCGANGenerator(nn.Module):
|
||||
"""Maps (latent_dim x 1 x 1) -> (3 x image_size x image_size) in [-1, 1]."""
|
||||
|
||||
def __init__(self, latent_dim: int = 100, ngf: int = 64, image_size: int = 64):
|
||||
super().__init__()
|
||||
n_up = _n_upsamples(image_size) # 64 -> 4 upsamples after the 1->4 init
|
||||
@@ -69,9 +59,8 @@ class DCGANGenerator(nn.Module):
|
||||
return self.net(z)
|
||||
|
||||
|
||||
# Maps (3 × image_size × image_size) → scalar logit (no sigmoid)
|
||||
class DCGANDiscriminator(nn.Module):
|
||||
"""Maps (3 x image_size x image_size) -> scalar logit (no sigmoid)."""
|
||||
|
||||
def __init__(self, ndf: int = 64, image_size: int = 64):
|
||||
super().__init__()
|
||||
n_down = _n_upsamples(image_size)
|
||||
@@ -97,6 +86,7 @@ class DCGANDiscriminator(nn.Module):
|
||||
return self.net(x).view(x.size(0))
|
||||
|
||||
|
||||
# Builds a (DCGANGenerator, DCGANDiscriminator) pair from cfg
|
||||
def _build(cfg: dict):
|
||||
image_size = cfg.get("image_size", 64)
|
||||
return (
|
||||
|
||||
@@ -1,11 +1,5 @@
|
||||
"""PatchGAN discriminator for Phase 3.3 (VQGAN-lite adversarial training).
|
||||
|
||||
Outputs a spatial patch map instead of a single scalar — each patch
|
||||
predicts real/fake independently. Loss is the mean over all patches.
|
||||
|
||||
Not registered in the model registry; instantiated inside train_vae
|
||||
when lambda_adversarial > 0.
|
||||
"""
|
||||
# PatchGAN discriminator for Phase 3.3 adversarial training: outputs a spatial patch logit map, not a scalar.
|
||||
# Not in the model registry — instantiated directly inside train_vae when lambda_adversarial > 0.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -17,14 +11,9 @@ def _init_weights(m):
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
|
||||
# Stride-2 + stride-1 conv chain → spatial patch logit map; supports image_size ∈ {64, 128}
|
||||
# 64×64 → 6×6 map (70×70 receptive field); 128×128 adds an extra stride-2 layer. InstanceNorm except first layer.
|
||||
class PatchGANDiscriminator(nn.Module):
|
||||
"""Stride-2 + stride-1 convolution chain → spatial patch logit map.
|
||||
|
||||
Supports image_size ∈ {64, 128}. For 64×64 input the final map is 6×6
|
||||
(70×70 receptive field). For 128×128 an extra stride-2 layer is added.
|
||||
InstanceNorm everywhere except the first layer.
|
||||
"""
|
||||
|
||||
def __init__(self, ndf: int = 64, image_size: int = 64):
|
||||
super().__init__()
|
||||
layers: list[nn.Module] = [
|
||||
@@ -57,13 +46,13 @@ class PatchGANDiscriminator(nn.Module):
|
||||
return self.net(x) # (B, 1, H', W') — patch logit map
|
||||
|
||||
|
||||
# Hinge loss for the discriminator (Lim & Ye, 2017)
|
||||
def hinge_d_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor:
|
||||
"""Hinge loss for the discriminator (Lim & Ye, 2017)."""
|
||||
loss_real = torch.mean(torch.relu(1.0 - real_logits))
|
||||
loss_fake = torch.mean(torch.relu(1.0 + fake_logits))
|
||||
return 0.5 * (loss_real + loss_fake)
|
||||
|
||||
|
||||
# Generator hinge loss — maximises D(fake)
|
||||
def hinge_g_loss(fake_logits: torch.Tensor) -> torch.Tensor:
|
||||
"""Generator hinge loss — maximise D(fake)."""
|
||||
return -torch.mean(fake_logits)
|
||||
|
||||
@@ -1,13 +1,5 @@
|
||||
"""Time-conditioned U-Net for DDPM (Phase 4).
|
||||
|
||||
Architecture follows Ho et al. (2020) with options from Nichol & Dhariwal (2021):
|
||||
- Sinusoidal time embedding → MLP → added to every ResBlock
|
||||
- GroupNorm (32 groups) + SiLU activations throughout
|
||||
- Self-attention at configurable spatial resolutions
|
||||
- Upsample(nearest) + Conv in the decoder — no checkerboard artefacts
|
||||
|
||||
Registered as kind="ddpm".
|
||||
"""
|
||||
# Time-conditioned U-Net for DDPM (Phase 4), following Ho et al. (2020) with Nichol & Dhariwal (2021) options.
|
||||
# Sinusoidal time embedding → MLP → injected additively into every ResBlock; GroupNorm(32) + SiLU throughout.
|
||||
import math
|
||||
|
||||
import torch
|
||||
@@ -37,9 +29,8 @@ class SinusoidalPosEmb(nn.Module):
|
||||
|
||||
# ── Core building blocks ──────────────────────────────────────────────────────
|
||||
|
||||
# ResNet block with additive time-embedding injection after the first conv
|
||||
class ResBlock(nn.Module):
|
||||
"""ResNet block with time-embedding injection (additive, after first conv)."""
|
||||
|
||||
def __init__(self, in_ch: int, out_ch: int, t_emb_dim: int, dropout: float = 0.1):
|
||||
super().__init__()
|
||||
self.norm1 = nn.GroupNorm(_GN, in_ch)
|
||||
@@ -57,9 +48,8 @@ class ResBlock(nn.Module):
|
||||
return h + self.skip(x)
|
||||
|
||||
|
||||
# Single-head self-attention with GroupNorm pre-norm and residual
|
||||
class AttentionBlock(nn.Module):
|
||||
"""Single-head self-attention with GroupNorm pre-norm and residual."""
|
||||
|
||||
def __init__(self, ch: int):
|
||||
super().__init__()
|
||||
self.norm = nn.GroupNorm(_GN, ch)
|
||||
@@ -155,17 +145,9 @@ class UpBlock(nn.Module):
|
||||
|
||||
# ── U-Net ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
# Time-conditioned U-Net; image_size must be a power-of-two (64 or 128 recommended)
|
||||
# ch_mult sets channel multipliers per level; attn_resolutions controls where attention is inserted
|
||||
class UNet(nn.Module):
|
||||
"""Time-conditioned U-Net.
|
||||
|
||||
image_size — must be a power-of-two; 64 or 128 recommended.
|
||||
base_ch — base channel count (128 for phases 4.1–4.3, 192 for 4.4).
|
||||
ch_mult — channel multipliers per resolution level.
|
||||
attn_resolutions — spatial resolutions at which attention is inserted.
|
||||
num_res_blocks — ResBlocks per level (in both down and up paths).
|
||||
dropout — applied inside every ResBlock.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size: int = 64,
|
||||
@@ -265,6 +247,7 @@ class UNet(nn.Module):
|
||||
return self.out_conv(F.silu(self.out_norm(x)))
|
||||
|
||||
|
||||
# Builds a UNet from cfg for DDPM training
|
||||
def _build(cfg: dict):
|
||||
return UNet(
|
||||
image_size = cfg.get("image_size", 64),
|
||||
|
||||
@@ -1,13 +1,5 @@
|
||||
"""Convolutional VAE for Phase 3.
|
||||
|
||||
Encoder uses stride-2 Conv → flatten → linear (μ, log σ²).
|
||||
Decoder uses Linear → Upsample(nearest) + Conv to avoid ConvTranspose2d
|
||||
checkerboard artefacts.
|
||||
|
||||
Registered as kind="vae". The run.py dispatcher passes the model to
|
||||
train_vae(), which internally builds perceptual loss and PatchGAN when
|
||||
the corresponding lambdas are non-zero.
|
||||
"""
|
||||
# Convolutional VAE for Phase 3. Encoder: stride-2 Conv → flatten → linear (μ, log σ²).
|
||||
# Decoder: Linear → Upsample(nearest) + Conv — avoids ConvTranspose2d checkerboard artefacts.
|
||||
import math
|
||||
|
||||
import torch
|
||||
@@ -31,8 +23,8 @@ def _norm(channels: int) -> nn.GroupNorm:
|
||||
return nn.GroupNorm(8, channels)
|
||||
|
||||
|
||||
# Nearest-neighbour upsample followed by a 3×3 conv — no checkerboard artefacts
|
||||
def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential:
|
||||
"""Nearest-neighbour upsample followed by a 3×3 conv — no checkerboard."""
|
||||
return nn.Sequential(
|
||||
nn.Upsample(scale_factor=2, mode="nearest"),
|
||||
nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
|
||||
@@ -41,20 +33,15 @@ def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential:
|
||||
)
|
||||
|
||||
|
||||
# Convolutional VAE; image_size must be a power-of-two ≥ 32
|
||||
# Spatial bottleneck is always at 4×4 — encoder and decoder scale stride-2 steps accordingly
|
||||
class VAE(nn.Module):
|
||||
"""Convolutional VAE. image_size must be a power-of-two ≥ 32.
|
||||
|
||||
Spatial bottleneck is always at 4×4 regardless of image_size —
|
||||
the encoder and decoder scale the number of stride-2 steps accordingly.
|
||||
"""
|
||||
|
||||
def __init__(self, latent_dim: int = 256, ngf: int = 64, image_size: int = 64):
|
||||
super().__init__()
|
||||
if image_size < 32 or (image_size & (image_size - 1)):
|
||||
raise ValueError(f"image_size must be a power-of-two ≥ 32, got {image_size}")
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
self.image_size = image_size
|
||||
|
||||
n_down = int(math.log2(image_size)) - 2 # steps from image_size to 4×4
|
||||
# 64 → n_down=4: 64→32→16→8→4
|
||||
@@ -115,19 +102,20 @@ class VAE(nn.Module):
|
||||
h = self.fc_dec(z).view(z.size(0), self._dec_ch, 4, 4)
|
||||
return self.decoder(h)
|
||||
|
||||
# Returns (reconstruction, mu, log_var)
|
||||
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Returns (reconstruction, mu, log_var)."""
|
||||
mu, log_var = self.encode(x)
|
||||
z = self.reparameterize(mu, log_var)
|
||||
return self.decode(z), mu, log_var
|
||||
|
||||
# Samples n images by drawing z ~ N(0, I) and decoding
|
||||
@torch.no_grad()
|
||||
def sample(self, n: int, device) -> torch.Tensor:
|
||||
"""Sample n images by drawing z ~ N(0, I)."""
|
||||
z = torch.randn(n, self.latent_dim, device=device)
|
||||
return self.decode(z)
|
||||
|
||||
|
||||
# Builds a VAE from cfg
|
||||
def _build(cfg: dict):
|
||||
return VAE(
|
||||
latent_dim=cfg.get("latent_dim", 256),
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
"""WGAN-GP variants.
|
||||
|
||||
wgan_basic — Phase 2.2: BatchNorm/InstanceNorm, no attention, 64×64 only.
|
||||
wgan — Phase 2.3/2.4: GroupNorm/SpectralNorm + self-attention, size-agnostic.
|
||||
"""
|
||||
# WGAN-GP variants.
|
||||
# wgan_basic — Phase 2.2: BatchNorm/InstanceNorm, no attention, 64×64 only.
|
||||
# wgan — Phase 2.3/2.4: GroupNorm/SpectralNorm + self-attention, size-agnostic.
|
||||
import math
|
||||
|
||||
import torch
|
||||
@@ -23,9 +21,8 @@ def _sn(module):
|
||||
return nn.utils.spectral_norm(module)
|
||||
|
||||
|
||||
# SAGAN-style self-attention: gamma-gated residual, dot-products scaled by mid^-0.5
|
||||
class SelfAttention(nn.Module):
|
||||
"""SAGAN-style self-attention."""
|
||||
|
||||
def __init__(self, in_ch: int):
|
||||
super().__init__()
|
||||
mid = max(in_ch // 8, 1)
|
||||
@@ -48,13 +45,8 @@ class SelfAttention(nn.Module):
|
||||
# Phase 2.2 — basic WGAN-GP (BatchNorm in G, InstanceNorm in D, 64×64 only)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Maps (latent_dim, 1, 1) → (3, 64, 64) in [-1, 1]; BatchNorm in G is safe under WGAN-GP (constraint targets the critic)
|
||||
class WGANBasicGenerator(nn.Module):
|
||||
"""Maps (latent_dim, 1, 1) -> (3, 64, 64) in [-1, 1].
|
||||
|
||||
Same channel structure as DCGAN. BatchNorm in generator is fine because
|
||||
WGAN-GP's constraint targets the critic, not the generator.
|
||||
"""
|
||||
|
||||
def __init__(self, latent_dim: int = 128, ngf: int = 64):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
@@ -80,11 +72,8 @@ class WGANBasicGenerator(nn.Module):
|
||||
return self.net(z)
|
||||
|
||||
|
||||
# WGAN-GP critic (64×64); InstanceNorm instead of BatchNorm — BatchNorm breaks the per-sample Lipschitz constraint
|
||||
class WGANBasicCritic(nn.Module):
|
||||
"""WGAN-GP critic (64×64). InstanceNorm instead of BatchNorm — BatchNorm
|
||||
breaks the per-sample Lipschitz constraint the gradient penalty enforces.
|
||||
"""
|
||||
|
||||
def __init__(self, ndf: int = 64):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
@@ -116,21 +105,14 @@ class WGANBasicCritic(nn.Module):
|
||||
# Phase 2.3 / 2.4 — advanced WGAN-GP (GroupNorm, SpectralNorm, attention)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# GroupNorm generator with SAGAN self-attention; supports image_size ∈ {64, 128}
|
||||
# Stem always 1×1 → 4×4 → 8×8 → 16×16; attention at 16×16, and at 32×32 for 128×128
|
||||
class WGANGenerator(nn.Module):
|
||||
"""GroupNorm generator with SAGAN self-attention.
|
||||
|
||||
Supports image_size ∈ {64, 128}.
|
||||
Stem is always 1×1 → 4×4 → 8×8 → 16×16 (ngf×8 → ngf×4 → ngf×2 channels).
|
||||
Attention at 16×16 always; additional attention at 32×32 for 128×128.
|
||||
"""
|
||||
|
||||
def __init__(self, latent_dim: int = 128, ngf: int = 128, image_size: int = 64):
|
||||
super().__init__()
|
||||
if image_size not in (64, 128):
|
||||
raise ValueError(f"WGANGenerator supports image_size 64 or 128, got {image_size}")
|
||||
|
||||
self._image_size = image_size
|
||||
|
||||
self.stem = nn.Sequential(
|
||||
# 1×1 → 4×4
|
||||
nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False),
|
||||
@@ -179,20 +161,14 @@ class WGANGenerator(nn.Module):
|
||||
return self.tail(h)
|
||||
|
||||
|
||||
# SpectralNorm critic with SAGAN self-attention; supports image_size ∈ {64, 128}
|
||||
# Attention at 16×16 always; additional attention at 32×32 for 128×128
|
||||
class WGANCritic(nn.Module):
|
||||
"""SpectralNorm critic with SAGAN self-attention.
|
||||
|
||||
Supports image_size ∈ {64, 128}.
|
||||
Attention at 16×16 always; additional attention at 32×32 for 128×128.
|
||||
"""
|
||||
|
||||
def __init__(self, ndf: int = 128, image_size: int = 64):
|
||||
super().__init__()
|
||||
if image_size not in (64, 128):
|
||||
raise ValueError(f"WGANCritic supports image_size 64 or 128, got {image_size}")
|
||||
|
||||
self._image_size = image_size
|
||||
|
||||
if image_size == 64:
|
||||
# Head: 64→32 (ndf//2)
|
||||
self.head = nn.Sequential(
|
||||
@@ -242,6 +218,7 @@ class WGANCritic(nn.Module):
|
||||
return self.tail(h).view(x.size(0))
|
||||
|
||||
|
||||
# Builds a (WGANBasicGenerator, WGANBasicCritic) pair from cfg
|
||||
def _build_basic(cfg: dict):
|
||||
return (
|
||||
WGANBasicGenerator(latent_dim=cfg.get("latent_dim", 128), ngf=cfg.get("ngf", 64)),
|
||||
@@ -249,6 +226,7 @@ def _build_basic(cfg: dict):
|
||||
)
|
||||
|
||||
|
||||
# Builds a (WGANGenerator, WGANCritic) pair from cfg
|
||||
def _build(cfg: dict):
|
||||
image_size = cfg.get("image_size", 64)
|
||||
return (
|
||||
|
||||
@@ -1,13 +1,5 @@
|
||||
"""Gaussian diffusion utilities for Phase 4 (DDPM).
|
||||
|
||||
Provides noise schedules, the forward (noising) process, training loss,
|
||||
and DDIM deterministic sampling (Song et al., 2020).
|
||||
|
||||
Convention: alpha_bars is a 1-D tensor of length T, where alpha_bars[t]
|
||||
= ᾱ_{t+1} in 1-indexed notation. Timestep t used in the training loop
|
||||
is a 0-indexed integer in [0, T). At t=0 the image is almost clean
|
||||
(ᾱ ≈ 1 − β_1); at t=T−1 the image is almost pure noise (ᾱ ≈ 0).
|
||||
"""
|
||||
# Gaussian diffusion utilities for Phase 4 (DDPM): noise schedules, forward process, training loss, DDIM sampling.
|
||||
# alpha_bars[t] = ᾱ_{t+1} (0-indexed); t=0 is near-clean (ᾱ ≈ 1−β₁), t=T−1 is near-pure-noise (ᾱ ≈ 0).
|
||||
import math
|
||||
|
||||
import torch
|
||||
@@ -16,13 +8,13 @@ import torch.nn.functional as F
|
||||
|
||||
# ── Noise schedules ──────────────────────────────────────────────────────────
|
||||
|
||||
# Ho et al. (2020) linear schedule
|
||||
def linear_betas(T: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor:
|
||||
"""Ho et al. (2020) linear schedule."""
|
||||
return torch.linspace(beta_start, beta_end, T)
|
||||
|
||||
|
||||
# Nichol & Dhariwal (2021) cosine schedule — avoids over-denoising at low t
|
||||
def cosine_betas(T: int, s: float = 0.008) -> torch.Tensor:
|
||||
"""Nichol & Dhariwal (2021) cosine schedule — avoids over-denoising at low t."""
|
||||
t = torch.linspace(0, T, T + 1)
|
||||
f = torch.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2
|
||||
alpha_bar = f / f[0]
|
||||
@@ -30,20 +22,20 @@ def cosine_betas(T: int, s: float = 0.008) -> torch.Tensor:
|
||||
return betas.clamp(max=0.999)
|
||||
|
||||
|
||||
# Cumulative product of (1 − β), shape (T,)
|
||||
def make_alpha_bars(betas: torch.Tensor) -> torch.Tensor:
|
||||
"""Cumulative product of (1 − β), shape (T,)."""
|
||||
return (1.0 - betas).cumprod(0)
|
||||
|
||||
|
||||
# ── Forward process ──────────────────────────────────────────────────────────
|
||||
|
||||
# Adds noise to x0 at timestep t; returns (x_t, noise)
|
||||
def q_sample(
|
||||
x0: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
alpha_bars: torch.Tensor,
|
||||
noise: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Add noise to x0 at timestep t. Returns (x_t, noise)."""
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
ab = alpha_bars[t].to(x0.device)[:, None, None, None]
|
||||
@@ -53,6 +45,7 @@ def q_sample(
|
||||
|
||||
# ── Training loss ────────────────────────────────────────────────────────────
|
||||
|
||||
# MSE between model prediction and target; pred_type="eps" targets noise ε, pred_type="v" targets v = √ᾱ·ε − √(1−ᾱ)·x0
|
||||
def diffusion_loss(
|
||||
model,
|
||||
x0: torch.Tensor,
|
||||
@@ -60,11 +53,6 @@ def diffusion_loss(
|
||||
alpha_bars: torch.Tensor,
|
||||
pred_type: str = "eps",
|
||||
) -> torch.Tensor:
|
||||
"""MSE on the model's prediction vs the true target.
|
||||
|
||||
pred_type="eps" → target is the added noise ε (Ho et al.)
|
||||
pred_type="v" → target is v = √ᾱ·ε − √(1−ᾱ)·x0 (Salimans & Ho)
|
||||
"""
|
||||
x_t, noise = q_sample(x0, t, alpha_bars)
|
||||
pred = model(x_t, t)
|
||||
|
||||
@@ -80,6 +68,8 @@ def diffusion_loss(
|
||||
# ── DDIM deterministic sampling ───────────────────────────────────────────────
|
||||
|
||||
@torch.no_grad()
|
||||
# Generates n images via DDIM (eta=0, deterministic); batches internally to avoid OOM
|
||||
# Returns tensor shape (n, 3, image_size, image_size) in [-1, 1]
|
||||
def ddim_sample(
|
||||
model,
|
||||
n: int,
|
||||
@@ -90,11 +80,6 @@ def ddim_sample(
|
||||
device: str = "cuda",
|
||||
batch_size: int = 32,
|
||||
) -> torch.Tensor:
|
||||
"""Generate n images via DDIM (eta=0, deterministic).
|
||||
|
||||
Batches internally to avoid OOM when n is large.
|
||||
Returns tensor shape (n, 3, image_size, image_size) in [-1, 1].
|
||||
"""
|
||||
model.eval()
|
||||
T = len(alpha_bars)
|
||||
|
||||
|
||||
@@ -3,13 +3,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# Exponential moving average of model weights; shadow copy updated after each optimizer step
|
||||
# Sample from ema.model at eval time, never from the training model directly
|
||||
class EMA:
|
||||
"""Exponential moving average of model weights.
|
||||
|
||||
Maintains a shadow copy of the model. Call update() after each
|
||||
optimizer step. Sample from ema.model, never from the training model.
|
||||
"""
|
||||
|
||||
def __init__(self, model: nn.Module, decay: float = 0.9999):
|
||||
self.decay = decay
|
||||
self.model = copy.deepcopy(model).eval()
|
||||
|
||||
@@ -1,20 +1,17 @@
|
||||
"""FID evaluation helper.
|
||||
|
||||
Computes Fréchet Inception Distance between a fixed set of real images
|
||||
and a batch of generated images. Real images are stored as a tensor on CPU
|
||||
and moved to device only during evaluation — this avoids re-reading disk
|
||||
every call while keeping GPU memory free between evaluations.
|
||||
"""
|
||||
# FID evaluation helper: caches real images as a CPU tensor to avoid re-reading disk on every eval call.
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torchmetrics.image.fid import FrechetInceptionDistance
|
||||
|
||||
|
||||
# Computes FID between cached real images and a batch of generated images
|
||||
class FIDEvaluator:
|
||||
def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda",
|
||||
num_workers: int = 2):
|
||||
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
|
||||
self.n_real = n_real
|
||||
# Inception network loaded once here; compute() calls reset() between evaluations
|
||||
self._fid = FrechetInceptionDistance(feature=2048, normalize=True).to(self.device)
|
||||
|
||||
# Cache real images as a CPU tensor ([-1, 1] range)
|
||||
imgs_list = []
|
||||
@@ -27,24 +24,20 @@ class FIDEvaluator:
|
||||
real = torch.cat(imgs_list)[:n_real]
|
||||
self._real = real # stored on CPU, shape (N, 3, H, W) in [-1, 1]
|
||||
|
||||
# Computes FID score; fake_imgs should be float in [-1, 1], shape (N, 3, H, W), N ≥ 2048 for reliability
|
||||
@torch.no_grad()
|
||||
def compute(self, fake_imgs: torch.Tensor) -> float:
|
||||
"""Compute FID score.
|
||||
|
||||
fake_imgs: float tensor in [-1, 1], shape (N, 3, H, W).
|
||||
N should be at least 2048 for a reliable score.
|
||||
"""
|
||||
fid = FrechetInceptionDistance(feature=2048, normalize=True).to(self.device)
|
||||
self._fid.reset()
|
||||
|
||||
# Feed real images in batches
|
||||
for i in range(0, self._real.size(0), 256):
|
||||
batch = (self._real[i:i + 256] * 0.5 + 0.5).clamp(0, 1).to(self.device)
|
||||
fid.update(batch, real=True)
|
||||
self._fid.update(batch, real=True)
|
||||
|
||||
# Feed fake images in batches
|
||||
fake = fake_imgs.cpu()
|
||||
for i in range(0, fake.size(0), 256):
|
||||
batch = (fake[i:i + 256] * 0.5 + 0.5).clamp(0, 1).to(self.device)
|
||||
fid.update(batch, real=False)
|
||||
self._fid.update(batch, real=False)
|
||||
|
||||
return float(fid.compute())
|
||||
return float(self._fid.compute())
|
||||
|
||||
@@ -1,25 +1,17 @@
|
||||
"""Extended generation quality metrics for Phase 5 cross-family comparison.
|
||||
|
||||
IS — Inception Score (Salimans et al., 2016): measures sample quality × diversity.
|
||||
LPIPS — average pairwise learned perceptual distance: measures sample diversity alone.
|
||||
|
||||
Both functions accept float tensors in [-1, 1].
|
||||
"""
|
||||
# Extended generation quality metrics for Phase 5 cross-family comparison.
|
||||
# IS (Salimans et al., 2016): quality × diversity. LPIPS diversity: pairwise perceptual distance.
|
||||
# Both functions accept float tensors in [-1, 1].
|
||||
import torch
|
||||
from torchmetrics.image.inception import InceptionScore
|
||||
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
||||
|
||||
|
||||
# Inception Score (mean ± std) over 10 splits; imgs in [-1, 1], N ≥ 2048 for reliability; returns (mean, std)
|
||||
def compute_is(
|
||||
imgs: torch.Tensor,
|
||||
device: str = "cuda",
|
||||
batch_size: int = 64,
|
||||
) -> tuple[float, float]:
|
||||
"""Inception Score (mean ± std) over 10 splits.
|
||||
|
||||
imgs: (N, 3, H, W) in [-1, 1]. N ≥ 2 048 for a reliable estimate.
|
||||
Returns (is_mean, is_std).
|
||||
"""
|
||||
metric = InceptionScore(normalize=True).to(device)
|
||||
imgs_01 = (imgs.clamp(-1, 1) * 0.5 + 0.5)
|
||||
for i in range(0, len(imgs_01), batch_size):
|
||||
@@ -28,16 +20,13 @@ def compute_is(
|
||||
return float(mean), float(std)
|
||||
|
||||
|
||||
# Average pairwise LPIPS distance over n_pairs random (i ≠ j) pairs; higher = more diverse
|
||||
def compute_lpips_diversity(
|
||||
imgs: torch.Tensor,
|
||||
n_pairs: int = 200,
|
||||
device: str = "cuda",
|
||||
batch_size: int = 16,
|
||||
) -> float:
|
||||
"""Average pairwise LPIPS distance — higher means more diverse samples.
|
||||
|
||||
imgs: (N, 3, H, W) in [-1, 1]. Samples n_pairs random (i, j) pairs with i ≠ j.
|
||||
"""
|
||||
metric = LearnedPerceptualImagePatchSimilarity(net_type="alex", normalize=True).to(device)
|
||||
imgs_01 = (imgs.clamp(-1, 1) * 0.5 + 0.5)
|
||||
N = len(imgs_01)
|
||||
|
||||
@@ -1,26 +1,13 @@
|
||||
"""VGG-16 perceptual loss for Phase 3.2 and 3.3.
|
||||
|
||||
Extracts features at relu1_2, relu2_2, relu3_3 and returns the
|
||||
L1 distance in feature space. VGG weights are frozen.
|
||||
|
||||
Input convention: images in [-1, 1] — the loss converts internally to
|
||||
[0, 1] and then applies ImageNet normalisation before passing to VGG.
|
||||
"""
|
||||
# VGG-16 perceptual loss (Phase 3.2/3.3): L1 feature distance at relu1_2, relu2_2, relu3_3.
|
||||
# Expects images in [-1, 1]; converts internally to [0, 1] then ImageNet-normalises before VGG. Weights frozen.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.models as tv_models
|
||||
|
||||
|
||||
# L1 feature-matching loss at relu1_2 ([:4]), relu2_2 ([4:9]), relu3_3 ([9:16]) of VGG-16
|
||||
class PerceptualLoss(nn.Module):
|
||||
"""L1 feature-matching loss at three VGG-16 layers.
|
||||
|
||||
VGG-16 feature indices:
|
||||
relu1_2: features[:4] (before first maxpool)
|
||||
relu2_2: features[4:9] (before second maxpool)
|
||||
relu3_3: features[9:16] (before third maxpool)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
vgg = tv_models.vgg16(weights=tv_models.VGG16_Weights.IMAGENET1K_V1)
|
||||
@@ -39,13 +26,13 @@ class PerceptualLoss(nn.Module):
|
||||
"std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
|
||||
)
|
||||
|
||||
# Converts [-1, 1] → [0, 1] then applies ImageNet normalisation
|
||||
def _normalise(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert [-1, 1] → ImageNet-normalised [0, 1]."""
|
||||
x = x * 0.5 + 0.5 # → [0, 1]
|
||||
return (x - self.mean) / self.std
|
||||
|
||||
# L1 feature distance; real gradients are stopped — only fake trains
|
||||
def forward(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
|
||||
"""L1 feature distance. real gradients are stopped — only fake trains."""
|
||||
f = self._normalise(fake)
|
||||
r = self._normalise(real)
|
||||
|
||||
|
||||
@@ -21,6 +21,31 @@ else:
|
||||
_autocast = lambda device_type="", enabled=True, **kw: _AC(enabled=enabled, **kw)
|
||||
|
||||
|
||||
# Linear decay from decay_start to epochs; returns a multiplier clamped to [0, 1]
|
||||
def _linear_decay(decay_start: int, epochs: int):
|
||||
return lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1))
|
||||
|
||||
|
||||
# Returns (new_best_fid, new_no_improve, should_stop); improved = new_best_fid < best_fid
|
||||
def _check_fid_plateau(fid_score: float, best_fid: float, no_improve: int, patience) -> tuple[float, int, bool]:
|
||||
if fid_score < best_fid:
|
||||
return fid_score, 0, False
|
||||
no_improve += 1
|
||||
if patience and no_improve >= patience:
|
||||
print(f"FID plateau — {no_improve} evals without improvement, stopping")
|
||||
return best_fid, no_improve, True
|
||||
return best_fid, no_improve, False
|
||||
|
||||
|
||||
# Generates n GAN fake images in batches of 64 using the EMA model
|
||||
@torch.no_grad()
|
||||
def _generate_gan_fakes(ema_model, n: int, latent_dim: int, device) -> torch.Tensor:
|
||||
return torch.cat([
|
||||
ema_model(torch.randn(64, latent_dim, 1, 1, device=device))
|
||||
for _ in range(n // 64 + 1)
|
||||
])[:n]
|
||||
|
||||
|
||||
def _save_samples(generator_ema, samples_dir: Path, epoch: int, *, fixed_noise: torch.Tensor, device) -> None:
|
||||
samples_dir.mkdir(parents=True, exist_ok=True)
|
||||
with torch.no_grad():
|
||||
@@ -29,6 +54,7 @@ def _save_samples(generator_ema, samples_dir: Path, epoch: int, *, fixed_noise:
|
||||
save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)
|
||||
|
||||
|
||||
# Vanilla DCGAN training loop (Phase 1): BCE loss, single G/D step per batch, no gradient penalty
|
||||
def train_dcgan(
|
||||
generator,
|
||||
discriminator,
|
||||
@@ -39,11 +65,6 @@ def train_dcgan(
|
||||
run_name: str,
|
||||
device: str = "cuda",
|
||||
) -> dict:
|
||||
"""Vanilla DCGAN training loop with BCE loss (Radford et al., 2015).
|
||||
|
||||
Used as the Phase 1 baseline for cheap pipeline ablations. No gradient
|
||||
penalty, no n_critic, single G/D step per batch.
|
||||
"""
|
||||
device = torch.device(device if torch.cuda.is_available() else "cpu")
|
||||
generator = generator.to(device)
|
||||
discriminator = discriminator.to(device)
|
||||
@@ -79,6 +100,10 @@ def train_dcgan(
|
||||
|
||||
ema = EMA(generator, decay=ema_decay)
|
||||
|
||||
if hasattr(torch, "compile"):
|
||||
generator = torch.compile(generator)
|
||||
discriminator = torch.compile(discriminator)
|
||||
|
||||
# Fixed noise for consistent sample tracking across epochs
|
||||
fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device)
|
||||
|
||||
@@ -91,18 +116,22 @@ def train_dcgan(
|
||||
|
||||
history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}}
|
||||
best_fid = float("inf")
|
||||
print(f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}")
|
||||
print(f"Device: {device} AMP: {use_amp} compile: {hasattr(torch, 'compile')} Batches/epoch: {len(loader)}")
|
||||
|
||||
# Linear LR decay from epoch epochs//2 to epochs
|
||||
decay_start = epochs // 2
|
||||
sched_g = torch.optim.lr_scheduler.LambdaLR(
|
||||
opt_g, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start)))
|
||||
sched_d = torch.optim.lr_scheduler.LambdaLR(
|
||||
opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start)))
|
||||
sched_g = torch.optim.lr_scheduler.LambdaLR(opt_g, _linear_decay(decay_start, epochs))
|
||||
sched_d = torch.optim.lr_scheduler.LambdaLR(opt_d, _linear_decay(decay_start, epochs))
|
||||
|
||||
t_start = time.time()
|
||||
max_train_secs = cfg.get("max_train_hours", 0) * 3600 or None
|
||||
fid_patience = cfg.get("fid_patience", None)
|
||||
no_improve = 0
|
||||
|
||||
for epoch in range(1, epochs + 1):
|
||||
if max_train_secs and time.time() - t_start >= max_train_secs:
|
||||
print(f"Time limit reached — stopping after epoch {epoch - 1}")
|
||||
break
|
||||
generator.train()
|
||||
discriminator.train()
|
||||
g_sum = d_sum = real_sum = fake_sum = 0.0
|
||||
@@ -160,19 +189,18 @@ def train_dcgan(
|
||||
|
||||
if epoch % fid_interval == 0:
|
||||
ema.model.eval()
|
||||
with torch.no_grad():
|
||||
fake_imgs = torch.cat([
|
||||
ema.model(torch.randn(64, latent_dim, 1, 1, device=device))
|
||||
for _ in range(fid_n_real // 64 + 1)
|
||||
])[:fid_n_real]
|
||||
fake_imgs = _generate_gan_fakes(ema.model, fid_n_real, latent_dim, device)
|
||||
fid_score = fid_eval.compute(fake_imgs)
|
||||
history["fid"][epoch] = fid_score
|
||||
print(f" FID @ epoch {epoch}: {fid_score:.2f}")
|
||||
|
||||
if fid_score < best_fid:
|
||||
best_fid = fid_score
|
||||
new_best, no_improve, plateau = _check_fid_plateau(fid_score, best_fid, no_improve, fid_patience)
|
||||
if new_best < best_fid:
|
||||
torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
|
||||
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
|
||||
best_fid = new_best
|
||||
if plateau:
|
||||
break
|
||||
|
||||
sched_g.step()
|
||||
sched_d.step()
|
||||
@@ -184,8 +212,8 @@ def train_dcgan(
|
||||
return history
|
||||
|
||||
|
||||
# Two-sided gradient penalty (Gulrajani et al., 2017)
|
||||
def _gradient_penalty(critic, real: torch.Tensor, fake: torch.Tensor, device) -> torch.Tensor:
|
||||
"""Two-sided gradient penalty (Gulrajani et al., 2017)."""
|
||||
bsz = real.size(0)
|
||||
eps = torch.rand(bsz, 1, 1, 1, device=device)
|
||||
interp = (eps * real + (1.0 - eps) * fake).requires_grad_(True)
|
||||
@@ -200,6 +228,8 @@ def _gradient_penalty(critic, real: torch.Tensor, fake: torch.Tensor, device) ->
|
||||
return ((grad.norm(2, dim=[1, 2, 3]) - 1) ** 2).mean()
|
||||
|
||||
|
||||
# WGAN-GP training loop (Phase 2.2–2.4): gradient penalty replaces weight clipping
|
||||
# Critic runs in float32 for GP stability; AMP used only for the generator forward/backward
|
||||
def train_wgan(
|
||||
generator,
|
||||
critic,
|
||||
@@ -210,12 +240,6 @@ def train_wgan(
|
||||
run_name: str,
|
||||
device: str = "cuda",
|
||||
) -> dict:
|
||||
"""WGAN-GP training loop (Gulrajani et al., 2017).
|
||||
|
||||
Used for Phase 2.2–2.4. Gradient penalty replaces weight clipping.
|
||||
The critic runs in float32 to keep GP gradient computation numerically
|
||||
stable; AMP is used only for the generator forward/backward.
|
||||
"""
|
||||
device = torch.device(device if torch.cuda.is_available() else "cpu")
|
||||
generator = generator.to(device)
|
||||
critic = critic.to(device)
|
||||
@@ -251,6 +275,10 @@ def train_wgan(
|
||||
|
||||
ema = EMA(generator, decay=ema_decay)
|
||||
|
||||
if hasattr(torch, "compile"):
|
||||
generator = torch.compile(generator)
|
||||
critic = torch.compile(critic)
|
||||
|
||||
# Fixed noise for consistent sample tracking across epochs
|
||||
fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device)
|
||||
|
||||
@@ -263,18 +291,22 @@ def train_wgan(
|
||||
|
||||
history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}}
|
||||
best_fid = float("inf")
|
||||
print(f"Device: {device} AMP (G only): {use_amp} Batches/epoch: {len(loader)} n_critic: {n_critic}")
|
||||
print(f"Device: {device} AMP (G only): {use_amp} compile: {hasattr(torch, 'compile')} Batches/epoch: {len(loader)} n_critic: {n_critic}")
|
||||
|
||||
# Linear LR decay from epoch epochs//2 to epochs
|
||||
decay_start = epochs // 2
|
||||
sched_g = torch.optim.lr_scheduler.LambdaLR(
|
||||
opt_g, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start)))
|
||||
sched_c = torch.optim.lr_scheduler.LambdaLR(
|
||||
opt_c, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start)))
|
||||
sched_g = torch.optim.lr_scheduler.LambdaLR(opt_g, _linear_decay(decay_start, epochs))
|
||||
sched_c = torch.optim.lr_scheduler.LambdaLR(opt_c, _linear_decay(decay_start, epochs))
|
||||
|
||||
t_start = time.time()
|
||||
max_train_secs = cfg.get("max_train_hours", 0) * 3600 or None
|
||||
fid_patience = cfg.get("fid_patience", None)
|
||||
no_improve = 0
|
||||
|
||||
for epoch in range(1, epochs + 1):
|
||||
if max_train_secs and time.time() - t_start >= max_train_secs:
|
||||
print(f"Time limit reached — stopping after epoch {epoch - 1}")
|
||||
break
|
||||
generator.train()
|
||||
critic.train()
|
||||
g_sum = w_sum = gp_sum = real_sum = fake_sum = 0.0
|
||||
@@ -292,11 +324,11 @@ def train_wgan(
|
||||
fake = generator(torch.randn(bsz, latent_dim, 1, 1, device=device))
|
||||
|
||||
real_f32 = real.float()
|
||||
fake_f32 = fake.float().detach()
|
||||
fake_f32 = fake.float() # already no_grad from torch.no_grad() context above
|
||||
|
||||
d_real = critic(real_f32)
|
||||
d_fake = critic(fake_f32)
|
||||
gp = _gradient_penalty(critic, real_f32, fake_f32.detach(), device)
|
||||
gp = _gradient_penalty(critic, real_f32, fake_f32, device)
|
||||
c_loss = d_fake.mean() - d_real.mean() + gp_lambda * gp
|
||||
c_loss.backward()
|
||||
opt_c.step()
|
||||
@@ -342,19 +374,18 @@ def train_wgan(
|
||||
|
||||
if epoch % fid_interval == 0:
|
||||
ema.model.eval()
|
||||
with torch.no_grad():
|
||||
fake_imgs = torch.cat([
|
||||
ema.model(torch.randn(64, latent_dim, 1, 1, device=device))
|
||||
for _ in range(fid_n_real // 64 + 1)
|
||||
])[:fid_n_real]
|
||||
fake_imgs = _generate_gan_fakes(ema.model, fid_n_real, latent_dim, device)
|
||||
fid_score = fid_eval.compute(fake_imgs)
|
||||
history["fid"][epoch] = fid_score
|
||||
print(f" FID @ epoch {epoch}: {fid_score:.2f}")
|
||||
|
||||
if fid_score < best_fid:
|
||||
best_fid = fid_score
|
||||
new_best, no_improve, plateau = _check_fid_plateau(fid_score, best_fid, no_improve, fid_patience)
|
||||
if new_best < best_fid:
|
||||
torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
|
||||
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
|
||||
best_fid = new_best
|
||||
if plateau:
|
||||
break
|
||||
|
||||
sched_g.step()
|
||||
sched_c.step()
|
||||
@@ -370,6 +401,7 @@ def train_wgan(
|
||||
# Phase 3 — VAE (3.1 MSE+KL · 3.2 +Perceptual · 3.3 +PatchGAN)
|
||||
# ────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
# Saves prior samples and a real-vs-reconstruction interleaved grid for the epoch
|
||||
def _save_vae_samples(
|
||||
vae,
|
||||
samples_dir: Path,
|
||||
@@ -379,7 +411,6 @@ def _save_vae_samples(
|
||||
fixed_real: torch.Tensor,
|
||||
device,
|
||||
) -> None:
|
||||
"""Save prior samples and a real-vs-reconstruction grid side by side."""
|
||||
samples_dir.mkdir(parents=True, exist_ok=True)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
@@ -395,6 +426,8 @@ def _save_vae_samples(
|
||||
save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4)
|
||||
|
||||
|
||||
# VAE training loop (Phase 3.1–3.3): lambda_perceptual > 0 adds VGG loss, lambda_adversarial > 0 adds PatchGAN
|
||||
# AMP disabled — float16 overflows on KL spikes causing unrecoverable NaN cascades; everything runs in float32
|
||||
def train_vae(
|
||||
vae,
|
||||
train_dataset,
|
||||
@@ -404,22 +437,6 @@ def train_vae(
|
||||
run_name: str,
|
||||
device: str = "cuda",
|
||||
) -> dict:
|
||||
"""VAE training loop covering Phase 3.1 – 3.3 and Phase 5.
|
||||
|
||||
Config toggles:
|
||||
lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+)
|
||||
lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3)
|
||||
|
||||
Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl
|
||||
|
||||
KL is computed as mean over latent dimensions (scale-invariant), so
|
||||
beta_kl is comparable across different latent_dim values.
|
||||
|
||||
AMP is intentionally disabled for VAE training — mixed-precision float16
|
||||
overflows when the KL divergence spikes, producing NaN cascades that
|
||||
corrupt the model irrecoverably. All VAE + perceptual + PatchGAN
|
||||
computation runs in float32.
|
||||
"""
|
||||
device = torch.device(device if torch.cuda.is_available() else "cpu")
|
||||
vae = vae.to(device)
|
||||
|
||||
@@ -450,16 +467,13 @@ def train_vae(
|
||||
)
|
||||
|
||||
opt_vae = torch.optim.Adam(vae.parameters(), lr=lr)
|
||||
# AMP disabled — float16 overflows on KL spikes, causing NaN cascades
|
||||
use_amp = False
|
||||
|
||||
# KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training
|
||||
kl_warmup_epochs = max(1, epochs // 5)
|
||||
|
||||
# Linear LR decay from epoch epochs//2 to epochs
|
||||
decay_start = epochs // 2
|
||||
sched_vae = torch.optim.lr_scheduler.LambdaLR(
|
||||
opt_vae, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
|
||||
sched_vae = torch.optim.lr_scheduler.LambdaLR(opt_vae, _linear_decay(decay_start, epochs))
|
||||
sched_d = None # set below if adversarial
|
||||
|
||||
# ── Optional components ───────────────────────────────────────────────
|
||||
@@ -479,12 +493,11 @@ def train_vae(
|
||||
image_size=cfg.get("image_size", 64),
|
||||
).to(device).float()
|
||||
opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999))
|
||||
sched_d = torch.optim.lr_scheduler.LambdaLR(
|
||||
opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
|
||||
sched_d = torch.optim.lr_scheduler.LambdaLR(opt_d, _linear_decay(decay_start, epochs))
|
||||
n_d = sum(p.numel() for p in patchgan.parameters())
|
||||
print(f"PatchGAN: {n_d:,} params")
|
||||
else:
|
||||
hinge_d_loss = hinge_g_loss = None # never called
|
||||
if hasattr(torch, "compile"):
|
||||
patchgan = torch.compile(patchgan)
|
||||
|
||||
# ── Fixed seeds for consistent visualisation ──────────────────────────
|
||||
fixed_z = torch.randn(16, latent_dim, device=device)
|
||||
@@ -494,6 +507,9 @@ def train_vae(
|
||||
|
||||
ema = EMA(vae, decay=ema_decay)
|
||||
|
||||
if hasattr(torch, "compile"):
|
||||
vae = torch.compile(vae)
|
||||
|
||||
save_dir = Path(save_dir)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
samples_dir = save_dir.parent / "samples" / run_name
|
||||
@@ -514,8 +530,14 @@ def train_vae(
|
||||
)
|
||||
|
||||
t_start = time.time()
|
||||
max_train_secs = cfg.get("max_train_hours", 0) * 3600 or None
|
||||
fid_patience = cfg.get("fid_patience", None)
|
||||
no_improve = 0
|
||||
|
||||
for epoch in range(1, epochs + 1):
|
||||
if max_train_secs and time.time() - t_start >= max_train_secs:
|
||||
print(f"Time limit reached — stopping after epoch {epoch - 1}")
|
||||
break
|
||||
vae.train()
|
||||
if patchgan is not None:
|
||||
patchgan.train()
|
||||
@@ -614,10 +636,13 @@ def train_vae(
|
||||
history["fid"][epoch] = fid_score
|
||||
print(f" FID @ epoch {epoch}: {fid_score:.2f}")
|
||||
|
||||
if fid_score < best_fid:
|
||||
best_fid = fid_score
|
||||
new_best, no_improve, plateau = _check_fid_plateau(fid_score, best_fid, no_improve, fid_patience)
|
||||
if new_best < best_fid:
|
||||
torch.save(vae.state_dict(), save_dir / f"{run_name}_best_vae.pt")
|
||||
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
|
||||
best_fid = new_best
|
||||
if plateau:
|
||||
break
|
||||
|
||||
sched_vae.step()
|
||||
if sched_d is not None:
|
||||
@@ -636,6 +661,7 @@ def train_vae(
|
||||
# Phase 4 — DDPM (4.1 linear·ε · 4.2 cosine·ε · 4.3 cosine·v · 4.4 wider)
|
||||
# ────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
# DDPM training loop (Phase 4.1–4.4): noise_schedule="linear"/"cosine", pred_type="eps"/"v"
|
||||
def train_ddpm(
|
||||
model,
|
||||
train_dataset,
|
||||
@@ -645,15 +671,6 @@ def train_ddpm(
|
||||
run_name: str,
|
||||
device: str = "cuda",
|
||||
) -> dict:
|
||||
"""DDPM training loop (Ho et al., 2020) covering Phase 4.1 – 4.4.
|
||||
|
||||
Config keys:
|
||||
noise_schedule — "linear" (4.1) or "cosine" (4.2+)
|
||||
pred_type — "eps" (4.1–4.2) or "v" (4.3+)
|
||||
T — diffusion timesteps (default 1000)
|
||||
base_ch / ch_mult / attn_resolutions — U-Net capacity (see unet.py)
|
||||
ddim_steps — DDIM steps for FID evaluation (default 100)
|
||||
"""
|
||||
from src.training.diffusion import (
|
||||
linear_betas, cosine_betas, make_alpha_bars,
|
||||
diffusion_loss, ddim_sample,
|
||||
@@ -694,8 +711,8 @@ def train_ddpm(
|
||||
|
||||
ema = EMA(model, decay=ema_decay)
|
||||
|
||||
# Fixed noise for sample visualisation (same latents across epochs)
|
||||
fixed_noise = torch.randn(16, 3, image_size, image_size, device=device)
|
||||
if hasattr(torch, "compile"):
|
||||
model = torch.compile(model)
|
||||
|
||||
save_dir = Path(save_dir)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -707,18 +724,23 @@ def train_ddpm(
|
||||
history = {"loss": [], "fid": {}}
|
||||
best_fid = float("inf")
|
||||
print(
|
||||
f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}"
|
||||
f"Device: {device} AMP: {use_amp} compile: {hasattr(torch, 'compile')} Batches/epoch: {len(loader)}"
|
||||
f" T={T} schedule={noise_schedule} pred={pred_type} ddim_steps={ddim_steps}"
|
||||
)
|
||||
|
||||
# Linear LR decay from epoch epochs//2 to epochs
|
||||
decay_start = epochs // 2
|
||||
sched = torch.optim.lr_scheduler.LambdaLR(
|
||||
opt, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
|
||||
sched = torch.optim.lr_scheduler.LambdaLR(opt, _linear_decay(decay_start, epochs))
|
||||
|
||||
t_start = time.time()
|
||||
max_train_secs = cfg.get("max_train_hours", 0) * 3600 or None
|
||||
fid_patience = cfg.get("fid_patience", None)
|
||||
no_improve = 0
|
||||
|
||||
for epoch in range(1, epochs + 1):
|
||||
if max_train_secs and time.time() - t_start >= max_train_secs:
|
||||
print(f"Time limit reached — stopping after epoch {epoch - 1}")
|
||||
break
|
||||
model.train()
|
||||
loss_sum = 0.0
|
||||
n_batches = 0
|
||||
@@ -749,7 +771,6 @@ def train_ddpm(
|
||||
samples_dir.mkdir(parents=True, exist_ok=True)
|
||||
ema.model.eval()
|
||||
with torch.no_grad():
|
||||
# Quick visualisation: denoise fixed_noise via DDIM
|
||||
imgs = ddim_sample(
|
||||
ema.model, 16, image_size, alpha_bars,
|
||||
n_steps=50, pred_type=pred_type, device=str(device), batch_size=16,
|
||||
@@ -768,10 +789,13 @@ def train_ddpm(
|
||||
history["fid"][epoch] = fid_score
|
||||
print(f" FID @ epoch {epoch}: {fid_score:.2f}")
|
||||
|
||||
if fid_score < best_fid:
|
||||
best_fid = fid_score
|
||||
new_best, no_improve, plateau = _check_fid_plateau(fid_score, best_fid, no_improve, fid_patience)
|
||||
if new_best < best_fid:
|
||||
torch.save(model.state_dict(), save_dir / f"{run_name}_best_unet.pt")
|
||||
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
|
||||
best_fid = new_best
|
||||
if plateau:
|
||||
break
|
||||
|
||||
sched.step()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user