diff --git a/generator/configs/phase5/p5_ddpm.json b/generator/configs/phase5/p5_ddpm.json index 600dbe9..b12169d 100644 --- a/generator/configs/phase5/p5_ddpm.json +++ b/generator/configs/phase5/p5_ddpm.json @@ -1,7 +1,7 @@ { "run_name": "p5_ddpm", "model": "ddpm", - "epochs": 200, + "epochs": 400, "data_dir": "cropped/generator", "sources": ["wiki"], "augment": "hflip", @@ -18,5 +18,7 @@ "ddim_steps": 100, "sample_interval": 10, "fid_interval": 25, - "fid_n_real": 5000 + "fid_n_real": 5000, + "max_train_hours": 24, + "fid_patience": 2 } diff --git a/generator/configs/phase5/p5_gan.json b/generator/configs/phase5/p5_gan.json index 68b35ea..ba78dd5 100644 --- a/generator/configs/phase5/p5_gan.json +++ b/generator/configs/phase5/p5_gan.json @@ -1,11 +1,11 @@ { "run_name": "p5_gan", "model": "wgan", - "epochs": 200, + "epochs": 400, "data_dir": "cropped/generator", "sources": ["wiki"], "augment": true, - "image_size": 128, + "image_size": 64, "latent_dim": 128, "ngf": 128, "ndf": 128, @@ -17,5 +17,7 @@ "gp_lambda": 10, "sample_interval": 10, "fid_interval": 25, - "fid_n_real": 5000 + "fid_n_real": 5000, + "max_train_hours": 24, + "fid_patience": 2 } diff --git a/generator/configs/phase5/p5_vae.json b/generator/configs/phase5/p5_vae.json index 4eaf91d..a62acd2 100644 --- a/generator/configs/phase5/p5_vae.json +++ b/generator/configs/phase5/p5_vae.json @@ -1,7 +1,7 @@ { "run_name": "p5_vae", "model": "vae", - "epochs": 200, + "epochs": 400, "data_dir": "cropped/generator", "sources": ["wiki"], "augment": "hflip", @@ -10,11 +10,13 @@ "ngf": 64, "lr": 1e-3, "lr_d": 1e-4, - "beta_kl": 0.0001, + "beta_kl": 0.25, "lambda_perceptual": 0.1, "lambda_adversarial": 0.1, "ndf_patch": 64, "sample_interval": 10, "fid_interval": 25, - "fid_n_real": 5000 + "fid_n_real": 5000, + "max_train_hours": 24, + "fid_patience": 2 } diff --git a/generator/src/data/dataset.py b/generator/src/data/dataset.py index ed88f62..c03f8e9 100644 --- a/generator/src/data/dataset.py +++ b/generator/src/data/dataset.py @@ -6,13 +6,8 @@ import torchvision.transforms as T from torch.utils.data import Dataset +# Unlabeled image dataset for generative model training; returns tensors only, no labels class GeneratorDataset(Dataset): - """Unlabeled image dataset for generative model training. - - Loads images from source subdirectories and returns tensors only — - no labels, since generation is unsupervised. - """ - def __init__(self, data_dir, sources=None, subsample=1.0, transform=None, seed=42): self.transform = transform self.samples = [] @@ -51,13 +46,8 @@ class GeneratorDataset(Dataset): return img +# Builds transform pipeline outputting [-1, 1]; augment=False (none), "hflip" (flip only), True (flip+rot+jitter) def get_transform(image_size: int, augment=False) -> T.Compose: - """Build transform for generator training. Output is in [-1, 1]. - - augment=False — no augmentation (for FID real-image sets) - augment="hflip" — horizontal flip only (recommended for VAE/DDPM) - augment=True — H-flip + rotation ±5° + mild color jitter (for GAN) - """ ops = [ T.Resize(image_size), T.CenterCrop(image_size), diff --git a/generator/src/models/__init__.py b/generator/src/models/__init__.py index cf8f908..4c4239a 100644 --- a/generator/src/models/__init__.py +++ b/generator/src/models/__init__.py @@ -1,18 +1,18 @@ from typing import Callable import torch.nn as nn +# Maps model name → (builder, kind); populated by each model module at import time _REGISTRY: dict[str, tuple[Callable, str]] = {} +# Called by each model module to advertise its name and kind to get_model def register(name: str, builder: Callable, *, kind: str) -> None: _REGISTRY[name] = (builder, kind) +# Returns (model_or_pair, kind) for the model named in cfg["model"] +# kind in {"dcgan", "wgan"} → (generator, critic) | kind in {"vae", "ddpm"} → model def get_model(cfg: dict) -> tuple: - """Return (model_or_pair, kind). - - kind="dcgan" -> (generator, discriminator) - """ name = cfg.get("model") entry = _REGISTRY.get(name) if entry is None: @@ -22,6 +22,7 @@ def get_model(cfg: dict) -> tuple: return builder(cfg), kind +# Importing each module triggers its register() calls from src.models import dcgan # noqa: E402, F401 from src.models import wgan # noqa: E402, F401 from src.models import vae # noqa: E402, F401 diff --git a/generator/src/models/dcgan.py b/generator/src/models/dcgan.py index 84a951b..b80213f 100644 --- a/generator/src/models/dcgan.py +++ b/generator/src/models/dcgan.py @@ -1,15 +1,6 @@ -"""Vanilla DCGAN (Radford et al., 2015). - -Used as the Phase 1 baseline for cheap pipeline ablations. Architecture is -intentionally minimal — BatchNorm in both networks, no spectral norm, no -attention, no gradient penalty. The whole point is to be the cheapest GAN -we can run, so 1A–1D pipeline deltas show up in FID quickly. - -Depth scales with image_size: each step doubles the spatial dimension, -starting from 4×4 after the first transposed conv. - 64 -> 5 layers (1 -> 4 -> 8 -> 16 -> 32 -> 64) - 128 -> 6 layers (1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128) -""" +# Vanilla DCGAN (Radford et al., 2015) — Phase 1 baseline for cheap pipeline ablations. +# Depth scales with image_size; each step doubles spatial size from a 4×4 stem. +# 64 → 5 upsample steps | 128 → 6 upsample steps import math import torch @@ -27,16 +18,15 @@ def _init_weights(m): nn.init.zeros_(m.bias) +# Number of 2× upsampling steps from the 4×4 stem to image_size def _n_upsamples(image_size: int) -> int: - """Number of 2x upsampling steps from 4x4 to image_size.""" if image_size < 8 or image_size & (image_size - 1): raise ValueError(f"image_size must be a power of two ≥ 8, got {image_size}") return int(math.log2(image_size)) - 2 # 64 -> 4, 128 -> 5 +# Maps (latent_dim × 1 × 1) → (3 × image_size × image_size) in [-1, 1] class DCGANGenerator(nn.Module): - """Maps (latent_dim x 1 x 1) -> (3 x image_size x image_size) in [-1, 1].""" - def __init__(self, latent_dim: int = 100, ngf: int = 64, image_size: int = 64): super().__init__() n_up = _n_upsamples(image_size) # 64 -> 4 upsamples after the 1->4 init @@ -69,9 +59,8 @@ class DCGANGenerator(nn.Module): return self.net(z) +# Maps (3 × image_size × image_size) → scalar logit (no sigmoid) class DCGANDiscriminator(nn.Module): - """Maps (3 x image_size x image_size) -> scalar logit (no sigmoid).""" - def __init__(self, ndf: int = 64, image_size: int = 64): super().__init__() n_down = _n_upsamples(image_size) @@ -97,6 +86,7 @@ class DCGANDiscriminator(nn.Module): return self.net(x).view(x.size(0)) +# Builds a (DCGANGenerator, DCGANDiscriminator) pair from cfg def _build(cfg: dict): image_size = cfg.get("image_size", 64) return ( diff --git a/generator/src/models/patchgan.py b/generator/src/models/patchgan.py index 9245416..4cf217c 100644 --- a/generator/src/models/patchgan.py +++ b/generator/src/models/patchgan.py @@ -1,11 +1,5 @@ -"""PatchGAN discriminator for Phase 3.3 (VQGAN-lite adversarial training). - -Outputs a spatial patch map instead of a single scalar — each patch -predicts real/fake independently. Loss is the mean over all patches. - -Not registered in the model registry; instantiated inside train_vae -when lambda_adversarial > 0. -""" +# PatchGAN discriminator for Phase 3.3 adversarial training: outputs a spatial patch logit map, not a scalar. +# Not in the model registry — instantiated directly inside train_vae when lambda_adversarial > 0. import torch import torch.nn as nn @@ -17,14 +11,9 @@ def _init_weights(m): nn.init.zeros_(m.bias) +# Stride-2 + stride-1 conv chain → spatial patch logit map; supports image_size ∈ {64, 128} +# 64×64 → 6×6 map (70×70 receptive field); 128×128 adds an extra stride-2 layer. InstanceNorm except first layer. class PatchGANDiscriminator(nn.Module): - """Stride-2 + stride-1 convolution chain → spatial patch logit map. - - Supports image_size ∈ {64, 128}. For 64×64 input the final map is 6×6 - (70×70 receptive field). For 128×128 an extra stride-2 layer is added. - InstanceNorm everywhere except the first layer. - """ - def __init__(self, ndf: int = 64, image_size: int = 64): super().__init__() layers: list[nn.Module] = [ @@ -57,13 +46,13 @@ class PatchGANDiscriminator(nn.Module): return self.net(x) # (B, 1, H', W') — patch logit map +# Hinge loss for the discriminator (Lim & Ye, 2017) def hinge_d_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor: - """Hinge loss for the discriminator (Lim & Ye, 2017).""" loss_real = torch.mean(torch.relu(1.0 - real_logits)) loss_fake = torch.mean(torch.relu(1.0 + fake_logits)) return 0.5 * (loss_real + loss_fake) +# Generator hinge loss — maximises D(fake) def hinge_g_loss(fake_logits: torch.Tensor) -> torch.Tensor: - """Generator hinge loss — maximise D(fake).""" return -torch.mean(fake_logits) diff --git a/generator/src/models/unet.py b/generator/src/models/unet.py index 8250bd8..9e21b58 100644 --- a/generator/src/models/unet.py +++ b/generator/src/models/unet.py @@ -1,13 +1,5 @@ -"""Time-conditioned U-Net for DDPM (Phase 4). - -Architecture follows Ho et al. (2020) with options from Nichol & Dhariwal (2021): - - Sinusoidal time embedding → MLP → added to every ResBlock - - GroupNorm (32 groups) + SiLU activations throughout - - Self-attention at configurable spatial resolutions - - Upsample(nearest) + Conv in the decoder — no checkerboard artefacts - -Registered as kind="ddpm". -""" +# Time-conditioned U-Net for DDPM (Phase 4), following Ho et al. (2020) with Nichol & Dhariwal (2021) options. +# Sinusoidal time embedding → MLP → injected additively into every ResBlock; GroupNorm(32) + SiLU throughout. import math import torch @@ -37,9 +29,8 @@ class SinusoidalPosEmb(nn.Module): # ── Core building blocks ────────────────────────────────────────────────────── +# ResNet block with additive time-embedding injection after the first conv class ResBlock(nn.Module): - """ResNet block with time-embedding injection (additive, after first conv).""" - def __init__(self, in_ch: int, out_ch: int, t_emb_dim: int, dropout: float = 0.1): super().__init__() self.norm1 = nn.GroupNorm(_GN, in_ch) @@ -57,9 +48,8 @@ class ResBlock(nn.Module): return h + self.skip(x) +# Single-head self-attention with GroupNorm pre-norm and residual class AttentionBlock(nn.Module): - """Single-head self-attention with GroupNorm pre-norm and residual.""" - def __init__(self, ch: int): super().__init__() self.norm = nn.GroupNorm(_GN, ch) @@ -155,17 +145,9 @@ class UpBlock(nn.Module): # ── U-Net ───────────────────────────────────────────────────────────────────── +# Time-conditioned U-Net; image_size must be a power-of-two (64 or 128 recommended) +# ch_mult sets channel multipliers per level; attn_resolutions controls where attention is inserted class UNet(nn.Module): - """Time-conditioned U-Net. - - image_size — must be a power-of-two; 64 or 128 recommended. - base_ch — base channel count (128 for phases 4.1–4.3, 192 for 4.4). - ch_mult — channel multipliers per resolution level. - attn_resolutions — spatial resolutions at which attention is inserted. - num_res_blocks — ResBlocks per level (in both down and up paths). - dropout — applied inside every ResBlock. - """ - def __init__( self, image_size: int = 64, @@ -265,6 +247,7 @@ class UNet(nn.Module): return self.out_conv(F.silu(self.out_norm(x))) +# Builds a UNet from cfg for DDPM training def _build(cfg: dict): return UNet( image_size = cfg.get("image_size", 64), diff --git a/generator/src/models/vae.py b/generator/src/models/vae.py index 126a08b..14e1788 100644 --- a/generator/src/models/vae.py +++ b/generator/src/models/vae.py @@ -1,13 +1,5 @@ -"""Convolutional VAE for Phase 3. - -Encoder uses stride-2 Conv → flatten → linear (μ, log σ²). -Decoder uses Linear → Upsample(nearest) + Conv to avoid ConvTranspose2d -checkerboard artefacts. - -Registered as kind="vae". The run.py dispatcher passes the model to -train_vae(), which internally builds perceptual loss and PatchGAN when -the corresponding lambdas are non-zero. -""" +# Convolutional VAE for Phase 3. Encoder: stride-2 Conv → flatten → linear (μ, log σ²). +# Decoder: Linear → Upsample(nearest) + Conv — avoids ConvTranspose2d checkerboard artefacts. import math import torch @@ -31,8 +23,8 @@ def _norm(channels: int) -> nn.GroupNorm: return nn.GroupNorm(8, channels) +# Nearest-neighbour upsample followed by a 3×3 conv — no checkerboard artefacts def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential: - """Nearest-neighbour upsample followed by a 3×3 conv — no checkerboard.""" return nn.Sequential( nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False), @@ -41,20 +33,15 @@ def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential: ) +# Convolutional VAE; image_size must be a power-of-two ≥ 32 +# Spatial bottleneck is always at 4×4 — encoder and decoder scale stride-2 steps accordingly class VAE(nn.Module): - """Convolutional VAE. image_size must be a power-of-two ≥ 32. - - Spatial bottleneck is always at 4×4 regardless of image_size — - the encoder and decoder scale the number of stride-2 steps accordingly. - """ - def __init__(self, latent_dim: int = 256, ngf: int = 64, image_size: int = 64): super().__init__() if image_size < 32 or (image_size & (image_size - 1)): raise ValueError(f"image_size must be a power-of-two ≥ 32, got {image_size}") self.latent_dim = latent_dim - self.image_size = image_size n_down = int(math.log2(image_size)) - 2 # steps from image_size to 4×4 # 64 → n_down=4: 64→32→16→8→4 @@ -115,19 +102,20 @@ class VAE(nn.Module): h = self.fc_dec(z).view(z.size(0), self._dec_ch, 4, 4) return self.decoder(h) + # Returns (reconstruction, mu, log_var) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Returns (reconstruction, mu, log_var).""" mu, log_var = self.encode(x) z = self.reparameterize(mu, log_var) return self.decode(z), mu, log_var + # Samples n images by drawing z ~ N(0, I) and decoding @torch.no_grad() def sample(self, n: int, device) -> torch.Tensor: - """Sample n images by drawing z ~ N(0, I).""" z = torch.randn(n, self.latent_dim, device=device) return self.decode(z) +# Builds a VAE from cfg def _build(cfg: dict): return VAE( latent_dim=cfg.get("latent_dim", 256), diff --git a/generator/src/models/wgan.py b/generator/src/models/wgan.py index 2c61e22..d27059d 100644 --- a/generator/src/models/wgan.py +++ b/generator/src/models/wgan.py @@ -1,8 +1,6 @@ -"""WGAN-GP variants. - -wgan_basic — Phase 2.2: BatchNorm/InstanceNorm, no attention, 64×64 only. -wgan — Phase 2.3/2.4: GroupNorm/SpectralNorm + self-attention, size-agnostic. -""" +# WGAN-GP variants. +# wgan_basic — Phase 2.2: BatchNorm/InstanceNorm, no attention, 64×64 only. +# wgan — Phase 2.3/2.4: GroupNorm/SpectralNorm + self-attention, size-agnostic. import math import torch @@ -23,9 +21,8 @@ def _sn(module): return nn.utils.spectral_norm(module) +# SAGAN-style self-attention: gamma-gated residual, dot-products scaled by mid^-0.5 class SelfAttention(nn.Module): - """SAGAN-style self-attention.""" - def __init__(self, in_ch: int): super().__init__() mid = max(in_ch // 8, 1) @@ -48,13 +45,8 @@ class SelfAttention(nn.Module): # Phase 2.2 — basic WGAN-GP (BatchNorm in G, InstanceNorm in D, 64×64 only) # --------------------------------------------------------------------------- +# Maps (latent_dim, 1, 1) → (3, 64, 64) in [-1, 1]; BatchNorm in G is safe under WGAN-GP (constraint targets the critic) class WGANBasicGenerator(nn.Module): - """Maps (latent_dim, 1, 1) -> (3, 64, 64) in [-1, 1]. - - Same channel structure as DCGAN. BatchNorm in generator is fine because - WGAN-GP's constraint targets the critic, not the generator. - """ - def __init__(self, latent_dim: int = 128, ngf: int = 64): super().__init__() self.net = nn.Sequential( @@ -80,11 +72,8 @@ class WGANBasicGenerator(nn.Module): return self.net(z) +# WGAN-GP critic (64×64); InstanceNorm instead of BatchNorm — BatchNorm breaks the per-sample Lipschitz constraint class WGANBasicCritic(nn.Module): - """WGAN-GP critic (64×64). InstanceNorm instead of BatchNorm — BatchNorm - breaks the per-sample Lipschitz constraint the gradient penalty enforces. - """ - def __init__(self, ndf: int = 64): super().__init__() self.net = nn.Sequential( @@ -116,21 +105,14 @@ class WGANBasicCritic(nn.Module): # Phase 2.3 / 2.4 — advanced WGAN-GP (GroupNorm, SpectralNorm, attention) # --------------------------------------------------------------------------- +# GroupNorm generator with SAGAN self-attention; supports image_size ∈ {64, 128} +# Stem always 1×1 → 4×4 → 8×8 → 16×16; attention at 16×16, and at 32×32 for 128×128 class WGANGenerator(nn.Module): - """GroupNorm generator with SAGAN self-attention. - - Supports image_size ∈ {64, 128}. - Stem is always 1×1 → 4×4 → 8×8 → 16×16 (ngf×8 → ngf×4 → ngf×2 channels). - Attention at 16×16 always; additional attention at 32×32 for 128×128. - """ - def __init__(self, latent_dim: int = 128, ngf: int = 128, image_size: int = 64): super().__init__() if image_size not in (64, 128): raise ValueError(f"WGANGenerator supports image_size 64 or 128, got {image_size}") - self._image_size = image_size - self.stem = nn.Sequential( # 1×1 → 4×4 nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False), @@ -179,20 +161,14 @@ class WGANGenerator(nn.Module): return self.tail(h) +# SpectralNorm critic with SAGAN self-attention; supports image_size ∈ {64, 128} +# Attention at 16×16 always; additional attention at 32×32 for 128×128 class WGANCritic(nn.Module): - """SpectralNorm critic with SAGAN self-attention. - - Supports image_size ∈ {64, 128}. - Attention at 16×16 always; additional attention at 32×32 for 128×128. - """ - def __init__(self, ndf: int = 128, image_size: int = 64): super().__init__() if image_size not in (64, 128): raise ValueError(f"WGANCritic supports image_size 64 or 128, got {image_size}") - self._image_size = image_size - if image_size == 64: # Head: 64→32 (ndf//2) self.head = nn.Sequential( @@ -242,6 +218,7 @@ class WGANCritic(nn.Module): return self.tail(h).view(x.size(0)) +# Builds a (WGANBasicGenerator, WGANBasicCritic) pair from cfg def _build_basic(cfg: dict): return ( WGANBasicGenerator(latent_dim=cfg.get("latent_dim", 128), ngf=cfg.get("ngf", 64)), @@ -249,6 +226,7 @@ def _build_basic(cfg: dict): ) +# Builds a (WGANGenerator, WGANCritic) pair from cfg def _build(cfg: dict): image_size = cfg.get("image_size", 64) return ( diff --git a/generator/src/training/diffusion.py b/generator/src/training/diffusion.py index b737d6e..32f22a1 100644 --- a/generator/src/training/diffusion.py +++ b/generator/src/training/diffusion.py @@ -1,13 +1,5 @@ -"""Gaussian diffusion utilities for Phase 4 (DDPM). - -Provides noise schedules, the forward (noising) process, training loss, -and DDIM deterministic sampling (Song et al., 2020). - -Convention: alpha_bars is a 1-D tensor of length T, where alpha_bars[t] -= ᾱ_{t+1} in 1-indexed notation. Timestep t used in the training loop -is a 0-indexed integer in [0, T). At t=0 the image is almost clean -(ᾱ ≈ 1 − β_1); at t=T−1 the image is almost pure noise (ᾱ ≈ 0). -""" +# Gaussian diffusion utilities for Phase 4 (DDPM): noise schedules, forward process, training loss, DDIM sampling. +# alpha_bars[t] = ᾱ_{t+1} (0-indexed); t=0 is near-clean (ᾱ ≈ 1−β₁), t=T−1 is near-pure-noise (ᾱ ≈ 0). import math import torch @@ -16,13 +8,13 @@ import torch.nn.functional as F # ── Noise schedules ────────────────────────────────────────────────────────── +# Ho et al. (2020) linear schedule def linear_betas(T: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor: - """Ho et al. (2020) linear schedule.""" return torch.linspace(beta_start, beta_end, T) +# Nichol & Dhariwal (2021) cosine schedule — avoids over-denoising at low t def cosine_betas(T: int, s: float = 0.008) -> torch.Tensor: - """Nichol & Dhariwal (2021) cosine schedule — avoids over-denoising at low t.""" t = torch.linspace(0, T, T + 1) f = torch.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2 alpha_bar = f / f[0] @@ -30,20 +22,20 @@ def cosine_betas(T: int, s: float = 0.008) -> torch.Tensor: return betas.clamp(max=0.999) +# Cumulative product of (1 − β), shape (T,) def make_alpha_bars(betas: torch.Tensor) -> torch.Tensor: - """Cumulative product of (1 − β), shape (T,).""" return (1.0 - betas).cumprod(0) # ── Forward process ────────────────────────────────────────────────────────── +# Adds noise to x0 at timestep t; returns (x_t, noise) def q_sample( x0: torch.Tensor, t: torch.Tensor, alpha_bars: torch.Tensor, noise: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """Add noise to x0 at timestep t. Returns (x_t, noise).""" if noise is None: noise = torch.randn_like(x0) ab = alpha_bars[t].to(x0.device)[:, None, None, None] @@ -53,6 +45,7 @@ def q_sample( # ── Training loss ──────────────────────────────────────────────────────────── +# MSE between model prediction and target; pred_type="eps" targets noise ε, pred_type="v" targets v = √ᾱ·ε − √(1−ᾱ)·x0 def diffusion_loss( model, x0: torch.Tensor, @@ -60,11 +53,6 @@ def diffusion_loss( alpha_bars: torch.Tensor, pred_type: str = "eps", ) -> torch.Tensor: - """MSE on the model's prediction vs the true target. - - pred_type="eps" → target is the added noise ε (Ho et al.) - pred_type="v" → target is v = √ᾱ·ε − √(1−ᾱ)·x0 (Salimans & Ho) - """ x_t, noise = q_sample(x0, t, alpha_bars) pred = model(x_t, t) @@ -80,6 +68,8 @@ def diffusion_loss( # ── DDIM deterministic sampling ─────────────────────────────────────────────── @torch.no_grad() +# Generates n images via DDIM (eta=0, deterministic); batches internally to avoid OOM +# Returns tensor shape (n, 3, image_size, image_size) in [-1, 1] def ddim_sample( model, n: int, @@ -90,11 +80,6 @@ def ddim_sample( device: str = "cuda", batch_size: int = 32, ) -> torch.Tensor: - """Generate n images via DDIM (eta=0, deterministic). - - Batches internally to avoid OOM when n is large. - Returns tensor shape (n, 3, image_size, image_size) in [-1, 1]. - """ model.eval() T = len(alpha_bars) diff --git a/generator/src/training/ema.py b/generator/src/training/ema.py index 97f5bd8..804048b 100644 --- a/generator/src/training/ema.py +++ b/generator/src/training/ema.py @@ -3,13 +3,9 @@ import torch import torch.nn as nn +# Exponential moving average of model weights; shadow copy updated after each optimizer step +# Sample from ema.model at eval time, never from the training model directly class EMA: - """Exponential moving average of model weights. - - Maintains a shadow copy of the model. Call update() after each - optimizer step. Sample from ema.model, never from the training model. - """ - def __init__(self, model: nn.Module, decay: float = 0.9999): self.decay = decay self.model = copy.deepcopy(model).eval() diff --git a/generator/src/training/fid.py b/generator/src/training/fid.py index 206f0b7..a7c9b8a 100644 --- a/generator/src/training/fid.py +++ b/generator/src/training/fid.py @@ -1,20 +1,17 @@ -"""FID evaluation helper. - -Computes Fréchet Inception Distance between a fixed set of real images -and a batch of generated images. Real images are stored as a tensor on CPU -and moved to device only during evaluation — this avoids re-reading disk -every call while keeping GPU memory free between evaluations. -""" +# FID evaluation helper: caches real images as a CPU tensor to avoid re-reading disk on every eval call. import torch from torch.utils.data import DataLoader from torchmetrics.image.fid import FrechetInceptionDistance +# Computes FID between cached real images and a batch of generated images class FIDEvaluator: def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda", num_workers: int = 2): self.device = torch.device(device if torch.cuda.is_available() else "cpu") self.n_real = n_real + # Inception network loaded once here; compute() calls reset() between evaluations + self._fid = FrechetInceptionDistance(feature=2048, normalize=True).to(self.device) # Cache real images as a CPU tensor ([-1, 1] range) imgs_list = [] @@ -27,24 +24,20 @@ class FIDEvaluator: real = torch.cat(imgs_list)[:n_real] self._real = real # stored on CPU, shape (N, 3, H, W) in [-1, 1] + # Computes FID score; fake_imgs should be float in [-1, 1], shape (N, 3, H, W), N ≥ 2048 for reliability @torch.no_grad() def compute(self, fake_imgs: torch.Tensor) -> float: - """Compute FID score. - - fake_imgs: float tensor in [-1, 1], shape (N, 3, H, W). - N should be at least 2048 for a reliable score. - """ - fid = FrechetInceptionDistance(feature=2048, normalize=True).to(self.device) + self._fid.reset() # Feed real images in batches for i in range(0, self._real.size(0), 256): batch = (self._real[i:i + 256] * 0.5 + 0.5).clamp(0, 1).to(self.device) - fid.update(batch, real=True) + self._fid.update(batch, real=True) # Feed fake images in batches fake = fake_imgs.cpu() for i in range(0, fake.size(0), 256): batch = (fake[i:i + 256] * 0.5 + 0.5).clamp(0, 1).to(self.device) - fid.update(batch, real=False) + self._fid.update(batch, real=False) - return float(fid.compute()) + return float(self._fid.compute()) diff --git a/generator/src/training/metrics.py b/generator/src/training/metrics.py index 1bca7df..dcd8912 100644 --- a/generator/src/training/metrics.py +++ b/generator/src/training/metrics.py @@ -1,25 +1,17 @@ -"""Extended generation quality metrics for Phase 5 cross-family comparison. - -IS — Inception Score (Salimans et al., 2016): measures sample quality × diversity. -LPIPS — average pairwise learned perceptual distance: measures sample diversity alone. - -Both functions accept float tensors in [-1, 1]. -""" +# Extended generation quality metrics for Phase 5 cross-family comparison. +# IS (Salimans et al., 2016): quality × diversity. LPIPS diversity: pairwise perceptual distance. +# Both functions accept float tensors in [-1, 1]. import torch from torchmetrics.image.inception import InceptionScore from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity +# Inception Score (mean ± std) over 10 splits; imgs in [-1, 1], N ≥ 2048 for reliability; returns (mean, std) def compute_is( imgs: torch.Tensor, device: str = "cuda", batch_size: int = 64, ) -> tuple[float, float]: - """Inception Score (mean ± std) over 10 splits. - - imgs: (N, 3, H, W) in [-1, 1]. N ≥ 2 048 for a reliable estimate. - Returns (is_mean, is_std). - """ metric = InceptionScore(normalize=True).to(device) imgs_01 = (imgs.clamp(-1, 1) * 0.5 + 0.5) for i in range(0, len(imgs_01), batch_size): @@ -28,16 +20,13 @@ def compute_is( return float(mean), float(std) +# Average pairwise LPIPS distance over n_pairs random (i ≠ j) pairs; higher = more diverse def compute_lpips_diversity( imgs: torch.Tensor, n_pairs: int = 200, device: str = "cuda", batch_size: int = 16, ) -> float: - """Average pairwise LPIPS distance — higher means more diverse samples. - - imgs: (N, 3, H, W) in [-1, 1]. Samples n_pairs random (i, j) pairs with i ≠ j. - """ metric = LearnedPerceptualImagePatchSimilarity(net_type="alex", normalize=True).to(device) imgs_01 = (imgs.clamp(-1, 1) * 0.5 + 0.5) N = len(imgs_01) diff --git a/generator/src/training/perceptual.py b/generator/src/training/perceptual.py index 51ab7a0..eef469f 100644 --- a/generator/src/training/perceptual.py +++ b/generator/src/training/perceptual.py @@ -1,26 +1,13 @@ -"""VGG-16 perceptual loss for Phase 3.2 and 3.3. - -Extracts features at relu1_2, relu2_2, relu3_3 and returns the -L1 distance in feature space. VGG weights are frozen. - -Input convention: images in [-1, 1] — the loss converts internally to -[0, 1] and then applies ImageNet normalisation before passing to VGG. -""" +# VGG-16 perceptual loss (Phase 3.2/3.3): L1 feature distance at relu1_2, relu2_2, relu3_3. +# Expects images in [-1, 1]; converts internally to [0, 1] then ImageNet-normalises before VGG. Weights frozen. import torch import torch.nn as nn import torch.nn.functional as F import torchvision.models as tv_models +# L1 feature-matching loss at relu1_2 ([:4]), relu2_2 ([4:9]), relu3_3 ([9:16]) of VGG-16 class PerceptualLoss(nn.Module): - """L1 feature-matching loss at three VGG-16 layers. - - VGG-16 feature indices: - relu1_2: features[:4] (before first maxpool) - relu2_2: features[4:9] (before second maxpool) - relu3_3: features[9:16] (before third maxpool) - """ - def __init__(self): super().__init__() vgg = tv_models.vgg16(weights=tv_models.VGG16_Weights.IMAGENET1K_V1) @@ -39,13 +26,13 @@ class PerceptualLoss(nn.Module): "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) ) + # Converts [-1, 1] → [0, 1] then applies ImageNet normalisation def _normalise(self, x: torch.Tensor) -> torch.Tensor: - """Convert [-1, 1] → ImageNet-normalised [0, 1].""" x = x * 0.5 + 0.5 # → [0, 1] return (x - self.mean) / self.std + # L1 feature distance; real gradients are stopped — only fake trains def forward(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor: - """L1 feature distance. real gradients are stopped — only fake trains.""" f = self._normalise(fake) r = self._normalise(real) diff --git a/generator/src/training/trainer.py b/generator/src/training/trainer.py index 3f3b7f3..10ad7fb 100644 --- a/generator/src/training/trainer.py +++ b/generator/src/training/trainer.py @@ -21,6 +21,31 @@ else: _autocast = lambda device_type="", enabled=True, **kw: _AC(enabled=enabled, **kw) +# Linear decay from decay_start to epochs; returns a multiplier clamped to [0, 1] +def _linear_decay(decay_start: int, epochs: int): + return lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)) + + +# Returns (new_best_fid, new_no_improve, should_stop); improved = new_best_fid < best_fid +def _check_fid_plateau(fid_score: float, best_fid: float, no_improve: int, patience) -> tuple[float, int, bool]: + if fid_score < best_fid: + return fid_score, 0, False + no_improve += 1 + if patience and no_improve >= patience: + print(f"FID plateau — {no_improve} evals without improvement, stopping") + return best_fid, no_improve, True + return best_fid, no_improve, False + + +# Generates n GAN fake images in batches of 64 using the EMA model +@torch.no_grad() +def _generate_gan_fakes(ema_model, n: int, latent_dim: int, device) -> torch.Tensor: + return torch.cat([ + ema_model(torch.randn(64, latent_dim, 1, 1, device=device)) + for _ in range(n // 64 + 1) + ])[:n] + + def _save_samples(generator_ema, samples_dir: Path, epoch: int, *, fixed_noise: torch.Tensor, device) -> None: samples_dir.mkdir(parents=True, exist_ok=True) with torch.no_grad(): @@ -29,6 +54,7 @@ def _save_samples(generator_ema, samples_dir: Path, epoch: int, *, fixed_noise: save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=4) +# Vanilla DCGAN training loop (Phase 1): BCE loss, single G/D step per batch, no gradient penalty def train_dcgan( generator, discriminator, @@ -39,11 +65,6 @@ def train_dcgan( run_name: str, device: str = "cuda", ) -> dict: - """Vanilla DCGAN training loop with BCE loss (Radford et al., 2015). - - Used as the Phase 1 baseline for cheap pipeline ablations. No gradient - penalty, no n_critic, single G/D step per batch. - """ device = torch.device(device if torch.cuda.is_available() else "cpu") generator = generator.to(device) discriminator = discriminator.to(device) @@ -79,6 +100,10 @@ def train_dcgan( ema = EMA(generator, decay=ema_decay) + if hasattr(torch, "compile"): + generator = torch.compile(generator) + discriminator = torch.compile(discriminator) + # Fixed noise for consistent sample tracking across epochs fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device) @@ -91,18 +116,22 @@ def train_dcgan( history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}} best_fid = float("inf") - print(f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}") + print(f"Device: {device} AMP: {use_amp} compile: {hasattr(torch, 'compile')} Batches/epoch: {len(loader)}") # Linear LR decay from epoch epochs//2 to epochs decay_start = epochs // 2 - sched_g = torch.optim.lr_scheduler.LambdaLR( - opt_g, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start))) - sched_d = torch.optim.lr_scheduler.LambdaLR( - opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start))) + sched_g = torch.optim.lr_scheduler.LambdaLR(opt_g, _linear_decay(decay_start, epochs)) + sched_d = torch.optim.lr_scheduler.LambdaLR(opt_d, _linear_decay(decay_start, epochs)) t_start = time.time() + max_train_secs = cfg.get("max_train_hours", 0) * 3600 or None + fid_patience = cfg.get("fid_patience", None) + no_improve = 0 for epoch in range(1, epochs + 1): + if max_train_secs and time.time() - t_start >= max_train_secs: + print(f"Time limit reached — stopping after epoch {epoch - 1}") + break generator.train() discriminator.train() g_sum = d_sum = real_sum = fake_sum = 0.0 @@ -160,19 +189,18 @@ def train_dcgan( if epoch % fid_interval == 0: ema.model.eval() - with torch.no_grad(): - fake_imgs = torch.cat([ - ema.model(torch.randn(64, latent_dim, 1, 1, device=device)) - for _ in range(fid_n_real // 64 + 1) - ])[:fid_n_real] + fake_imgs = _generate_gan_fakes(ema.model, fid_n_real, latent_dim, device) fid_score = fid_eval.compute(fake_imgs) history["fid"][epoch] = fid_score print(f" FID @ epoch {epoch}: {fid_score:.2f}") - if fid_score < best_fid: - best_fid = fid_score + new_best, no_improve, plateau = _check_fid_plateau(fid_score, best_fid, no_improve, fid_patience) + if new_best < best_fid: torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt") torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt") + best_fid = new_best + if plateau: + break sched_g.step() sched_d.step() @@ -184,8 +212,8 @@ def train_dcgan( return history +# Two-sided gradient penalty (Gulrajani et al., 2017) def _gradient_penalty(critic, real: torch.Tensor, fake: torch.Tensor, device) -> torch.Tensor: - """Two-sided gradient penalty (Gulrajani et al., 2017).""" bsz = real.size(0) eps = torch.rand(bsz, 1, 1, 1, device=device) interp = (eps * real + (1.0 - eps) * fake).requires_grad_(True) @@ -200,6 +228,8 @@ def _gradient_penalty(critic, real: torch.Tensor, fake: torch.Tensor, device) -> return ((grad.norm(2, dim=[1, 2, 3]) - 1) ** 2).mean() +# WGAN-GP training loop (Phase 2.2–2.4): gradient penalty replaces weight clipping +# Critic runs in float32 for GP stability; AMP used only for the generator forward/backward def train_wgan( generator, critic, @@ -210,12 +240,6 @@ def train_wgan( run_name: str, device: str = "cuda", ) -> dict: - """WGAN-GP training loop (Gulrajani et al., 2017). - - Used for Phase 2.2–2.4. Gradient penalty replaces weight clipping. - The critic runs in float32 to keep GP gradient computation numerically - stable; AMP is used only for the generator forward/backward. - """ device = torch.device(device if torch.cuda.is_available() else "cpu") generator = generator.to(device) critic = critic.to(device) @@ -251,6 +275,10 @@ def train_wgan( ema = EMA(generator, decay=ema_decay) + if hasattr(torch, "compile"): + generator = torch.compile(generator) + critic = torch.compile(critic) + # Fixed noise for consistent sample tracking across epochs fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device) @@ -263,18 +291,22 @@ def train_wgan( history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}} best_fid = float("inf") - print(f"Device: {device} AMP (G only): {use_amp} Batches/epoch: {len(loader)} n_critic: {n_critic}") + print(f"Device: {device} AMP (G only): {use_amp} compile: {hasattr(torch, 'compile')} Batches/epoch: {len(loader)} n_critic: {n_critic}") # Linear LR decay from epoch epochs//2 to epochs decay_start = epochs // 2 - sched_g = torch.optim.lr_scheduler.LambdaLR( - opt_g, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start))) - sched_c = torch.optim.lr_scheduler.LambdaLR( - opt_c, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start))) + sched_g = torch.optim.lr_scheduler.LambdaLR(opt_g, _linear_decay(decay_start, epochs)) + sched_c = torch.optim.lr_scheduler.LambdaLR(opt_c, _linear_decay(decay_start, epochs)) t_start = time.time() + max_train_secs = cfg.get("max_train_hours", 0) * 3600 or None + fid_patience = cfg.get("fid_patience", None) + no_improve = 0 for epoch in range(1, epochs + 1): + if max_train_secs and time.time() - t_start >= max_train_secs: + print(f"Time limit reached — stopping after epoch {epoch - 1}") + break generator.train() critic.train() g_sum = w_sum = gp_sum = real_sum = fake_sum = 0.0 @@ -292,11 +324,11 @@ def train_wgan( fake = generator(torch.randn(bsz, latent_dim, 1, 1, device=device)) real_f32 = real.float() - fake_f32 = fake.float().detach() + fake_f32 = fake.float() # already no_grad from torch.no_grad() context above d_real = critic(real_f32) d_fake = critic(fake_f32) - gp = _gradient_penalty(critic, real_f32, fake_f32.detach(), device) + gp = _gradient_penalty(critic, real_f32, fake_f32, device) c_loss = d_fake.mean() - d_real.mean() + gp_lambda * gp c_loss.backward() opt_c.step() @@ -342,19 +374,18 @@ def train_wgan( if epoch % fid_interval == 0: ema.model.eval() - with torch.no_grad(): - fake_imgs = torch.cat([ - ema.model(torch.randn(64, latent_dim, 1, 1, device=device)) - for _ in range(fid_n_real // 64 + 1) - ])[:fid_n_real] + fake_imgs = _generate_gan_fakes(ema.model, fid_n_real, latent_dim, device) fid_score = fid_eval.compute(fake_imgs) history["fid"][epoch] = fid_score print(f" FID @ epoch {epoch}: {fid_score:.2f}") - if fid_score < best_fid: - best_fid = fid_score + new_best, no_improve, plateau = _check_fid_plateau(fid_score, best_fid, no_improve, fid_patience) + if new_best < best_fid: torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt") torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt") + best_fid = new_best + if plateau: + break sched_g.step() sched_c.step() @@ -370,6 +401,7 @@ def train_wgan( # Phase 3 — VAE (3.1 MSE+KL · 3.2 +Perceptual · 3.3 +PatchGAN) # ──────────────────────────────────────────────────────────────────────────── +# Saves prior samples and a real-vs-reconstruction interleaved grid for the epoch def _save_vae_samples( vae, samples_dir: Path, @@ -379,7 +411,6 @@ def _save_vae_samples( fixed_real: torch.Tensor, device, ) -> None: - """Save prior samples and a real-vs-reconstruction grid side by side.""" samples_dir.mkdir(parents=True, exist_ok=True) vae.eval() with torch.no_grad(): @@ -395,6 +426,8 @@ def _save_vae_samples( save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4) +# VAE training loop (Phase 3.1–3.3): lambda_perceptual > 0 adds VGG loss, lambda_adversarial > 0 adds PatchGAN +# AMP disabled — float16 overflows on KL spikes causing unrecoverable NaN cascades; everything runs in float32 def train_vae( vae, train_dataset, @@ -404,22 +437,6 @@ def train_vae( run_name: str, device: str = "cuda", ) -> dict: - """VAE training loop covering Phase 3.1 – 3.3 and Phase 5. - - Config toggles: - lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+) - lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3) - - Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl - - KL is computed as mean over latent dimensions (scale-invariant), so - beta_kl is comparable across different latent_dim values. - - AMP is intentionally disabled for VAE training — mixed-precision float16 - overflows when the KL divergence spikes, producing NaN cascades that - corrupt the model irrecoverably. All VAE + perceptual + PatchGAN - computation runs in float32. - """ device = torch.device(device if torch.cuda.is_available() else "cpu") vae = vae.to(device) @@ -450,16 +467,13 @@ def train_vae( ) opt_vae = torch.optim.Adam(vae.parameters(), lr=lr) - # AMP disabled — float16 overflows on KL spikes, causing NaN cascades - use_amp = False # KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training kl_warmup_epochs = max(1, epochs // 5) # Linear LR decay from epoch epochs//2 to epochs decay_start = epochs // 2 - sched_vae = torch.optim.lr_scheduler.LambdaLR( - opt_vae, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1))) + sched_vae = torch.optim.lr_scheduler.LambdaLR(opt_vae, _linear_decay(decay_start, epochs)) sched_d = None # set below if adversarial # ── Optional components ─────────────────────────────────────────────── @@ -479,12 +493,11 @@ def train_vae( image_size=cfg.get("image_size", 64), ).to(device).float() opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999)) - sched_d = torch.optim.lr_scheduler.LambdaLR( - opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1))) + sched_d = torch.optim.lr_scheduler.LambdaLR(opt_d, _linear_decay(decay_start, epochs)) n_d = sum(p.numel() for p in patchgan.parameters()) print(f"PatchGAN: {n_d:,} params") - else: - hinge_d_loss = hinge_g_loss = None # never called + if hasattr(torch, "compile"): + patchgan = torch.compile(patchgan) # ── Fixed seeds for consistent visualisation ────────────────────────── fixed_z = torch.randn(16, latent_dim, device=device) @@ -494,6 +507,9 @@ def train_vae( ema = EMA(vae, decay=ema_decay) + if hasattr(torch, "compile"): + vae = torch.compile(vae) + save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=True) samples_dir = save_dir.parent / "samples" / run_name @@ -514,8 +530,14 @@ def train_vae( ) t_start = time.time() + max_train_secs = cfg.get("max_train_hours", 0) * 3600 or None + fid_patience = cfg.get("fid_patience", None) + no_improve = 0 for epoch in range(1, epochs + 1): + if max_train_secs and time.time() - t_start >= max_train_secs: + print(f"Time limit reached — stopping after epoch {epoch - 1}") + break vae.train() if patchgan is not None: patchgan.train() @@ -614,10 +636,13 @@ def train_vae( history["fid"][epoch] = fid_score print(f" FID @ epoch {epoch}: {fid_score:.2f}") - if fid_score < best_fid: - best_fid = fid_score + new_best, no_improve, plateau = _check_fid_plateau(fid_score, best_fid, no_improve, fid_patience) + if new_best < best_fid: torch.save(vae.state_dict(), save_dir / f"{run_name}_best_vae.pt") torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt") + best_fid = new_best + if plateau: + break sched_vae.step() if sched_d is not None: @@ -636,6 +661,7 @@ def train_vae( # Phase 4 — DDPM (4.1 linear·ε · 4.2 cosine·ε · 4.3 cosine·v · 4.4 wider) # ──────────────────────────────────────────────────────────────────────────── +# DDPM training loop (Phase 4.1–4.4): noise_schedule="linear"/"cosine", pred_type="eps"/"v" def train_ddpm( model, train_dataset, @@ -645,15 +671,6 @@ def train_ddpm( run_name: str, device: str = "cuda", ) -> dict: - """DDPM training loop (Ho et al., 2020) covering Phase 4.1 – 4.4. - - Config keys: - noise_schedule — "linear" (4.1) or "cosine" (4.2+) - pred_type — "eps" (4.1–4.2) or "v" (4.3+) - T — diffusion timesteps (default 1000) - base_ch / ch_mult / attn_resolutions — U-Net capacity (see unet.py) - ddim_steps — DDIM steps for FID evaluation (default 100) - """ from src.training.diffusion import ( linear_betas, cosine_betas, make_alpha_bars, diffusion_loss, ddim_sample, @@ -694,8 +711,8 @@ def train_ddpm( ema = EMA(model, decay=ema_decay) - # Fixed noise for sample visualisation (same latents across epochs) - fixed_noise = torch.randn(16, 3, image_size, image_size, device=device) + if hasattr(torch, "compile"): + model = torch.compile(model) save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=True) @@ -707,18 +724,23 @@ def train_ddpm( history = {"loss": [], "fid": {}} best_fid = float("inf") print( - f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}" + f"Device: {device} AMP: {use_amp} compile: {hasattr(torch, 'compile')} Batches/epoch: {len(loader)}" f" T={T} schedule={noise_schedule} pred={pred_type} ddim_steps={ddim_steps}" ) # Linear LR decay from epoch epochs//2 to epochs decay_start = epochs // 2 - sched = torch.optim.lr_scheduler.LambdaLR( - opt, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1))) + sched = torch.optim.lr_scheduler.LambdaLR(opt, _linear_decay(decay_start, epochs)) t_start = time.time() + max_train_secs = cfg.get("max_train_hours", 0) * 3600 or None + fid_patience = cfg.get("fid_patience", None) + no_improve = 0 for epoch in range(1, epochs + 1): + if max_train_secs and time.time() - t_start >= max_train_secs: + print(f"Time limit reached — stopping after epoch {epoch - 1}") + break model.train() loss_sum = 0.0 n_batches = 0 @@ -749,7 +771,6 @@ def train_ddpm( samples_dir.mkdir(parents=True, exist_ok=True) ema.model.eval() with torch.no_grad(): - # Quick visualisation: denoise fixed_noise via DDIM imgs = ddim_sample( ema.model, 16, image_size, alpha_bars, n_steps=50, pred_type=pred_type, device=str(device), batch_size=16, @@ -768,10 +789,13 @@ def train_ddpm( history["fid"][epoch] = fid_score print(f" FID @ epoch {epoch}: {fid_score:.2f}") - if fid_score < best_fid: - best_fid = fid_score + new_best, no_improve, plateau = _check_fid_plateau(fid_score, best_fid, no_improve, fid_patience) + if new_best < best_fid: torch.save(model.state_dict(), save_dir / f"{run_name}_best_unet.pt") torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt") + best_fid = new_best + if plateau: + break sched.step()