Phase 5 preparation
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"run_name": "p5_ddpm",
|
"run_name": "p5_ddpm",
|
||||||
"model": "ddpm",
|
"model": "ddpm",
|
||||||
"epochs": 200,
|
"epochs": 400,
|
||||||
"data_dir": "cropped/generator",
|
"data_dir": "cropped/generator",
|
||||||
"sources": ["wiki"],
|
"sources": ["wiki"],
|
||||||
"augment": "hflip",
|
"augment": "hflip",
|
||||||
@@ -18,5 +18,7 @@
|
|||||||
"ddim_steps": 100,
|
"ddim_steps": 100,
|
||||||
"sample_interval": 10,
|
"sample_interval": 10,
|
||||||
"fid_interval": 25,
|
"fid_interval": 25,
|
||||||
"fid_n_real": 5000
|
"fid_n_real": 5000,
|
||||||
|
"max_train_hours": 24,
|
||||||
|
"fid_patience": 2
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
{
|
{
|
||||||
"run_name": "p5_gan",
|
"run_name": "p5_gan",
|
||||||
"model": "wgan",
|
"model": "wgan",
|
||||||
"epochs": 200,
|
"epochs": 400,
|
||||||
"data_dir": "cropped/generator",
|
"data_dir": "cropped/generator",
|
||||||
"sources": ["wiki"],
|
"sources": ["wiki"],
|
||||||
"augment": true,
|
"augment": true,
|
||||||
"image_size": 128,
|
"image_size": 64,
|
||||||
"latent_dim": 128,
|
"latent_dim": 128,
|
||||||
"ngf": 128,
|
"ngf": 128,
|
||||||
"ndf": 128,
|
"ndf": 128,
|
||||||
@@ -17,5 +17,7 @@
|
|||||||
"gp_lambda": 10,
|
"gp_lambda": 10,
|
||||||
"sample_interval": 10,
|
"sample_interval": 10,
|
||||||
"fid_interval": 25,
|
"fid_interval": 25,
|
||||||
"fid_n_real": 5000
|
"fid_n_real": 5000,
|
||||||
|
"max_train_hours": 24,
|
||||||
|
"fid_patience": 2
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"run_name": "p5_vae",
|
"run_name": "p5_vae",
|
||||||
"model": "vae",
|
"model": "vae",
|
||||||
"epochs": 200,
|
"epochs": 400,
|
||||||
"data_dir": "cropped/generator",
|
"data_dir": "cropped/generator",
|
||||||
"sources": ["wiki"],
|
"sources": ["wiki"],
|
||||||
"augment": "hflip",
|
"augment": "hflip",
|
||||||
@@ -10,11 +10,13 @@
|
|||||||
"ngf": 64,
|
"ngf": 64,
|
||||||
"lr": 1e-3,
|
"lr": 1e-3,
|
||||||
"lr_d": 1e-4,
|
"lr_d": 1e-4,
|
||||||
"beta_kl": 0.0001,
|
"beta_kl": 0.25,
|
||||||
"lambda_perceptual": 0.1,
|
"lambda_perceptual": 0.1,
|
||||||
"lambda_adversarial": 0.1,
|
"lambda_adversarial": 0.1,
|
||||||
"ndf_patch": 64,
|
"ndf_patch": 64,
|
||||||
"sample_interval": 10,
|
"sample_interval": 10,
|
||||||
"fid_interval": 25,
|
"fid_interval": 25,
|
||||||
"fid_n_real": 5000
|
"fid_n_real": 5000,
|
||||||
|
"max_train_hours": 24,
|
||||||
|
"fid_patience": 2
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,13 +6,8 @@ import torchvision.transforms as T
|
|||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
|
||||||
|
# Unlabeled image dataset for generative model training; returns tensors only, no labels
|
||||||
class GeneratorDataset(Dataset):
|
class GeneratorDataset(Dataset):
|
||||||
"""Unlabeled image dataset for generative model training.
|
|
||||||
|
|
||||||
Loads images from source subdirectories and returns tensors only —
|
|
||||||
no labels, since generation is unsupervised.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, data_dir, sources=None, subsample=1.0, transform=None, seed=42):
|
def __init__(self, data_dir, sources=None, subsample=1.0, transform=None, seed=42):
|
||||||
self.transform = transform
|
self.transform = transform
|
||||||
self.samples = []
|
self.samples = []
|
||||||
@@ -51,13 +46,8 @@ class GeneratorDataset(Dataset):
|
|||||||
return img
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
# Builds transform pipeline outputting [-1, 1]; augment=False (none), "hflip" (flip only), True (flip+rot+jitter)
|
||||||
def get_transform(image_size: int, augment=False) -> T.Compose:
|
def get_transform(image_size: int, augment=False) -> T.Compose:
|
||||||
"""Build transform for generator training. Output is in [-1, 1].
|
|
||||||
|
|
||||||
augment=False — no augmentation (for FID real-image sets)
|
|
||||||
augment="hflip" — horizontal flip only (recommended for VAE/DDPM)
|
|
||||||
augment=True — H-flip + rotation ±5° + mild color jitter (for GAN)
|
|
||||||
"""
|
|
||||||
ops = [
|
ops = [
|
||||||
T.Resize(image_size),
|
T.Resize(image_size),
|
||||||
T.CenterCrop(image_size),
|
T.CenterCrop(image_size),
|
||||||
|
|||||||
@@ -1,18 +1,18 @@
|
|||||||
from typing import Callable
|
from typing import Callable
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# Maps model name → (builder, kind); populated by each model module at import time
|
||||||
_REGISTRY: dict[str, tuple[Callable, str]] = {}
|
_REGISTRY: dict[str, tuple[Callable, str]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
# Called by each model module to advertise its name and kind to get_model
|
||||||
def register(name: str, builder: Callable, *, kind: str) -> None:
|
def register(name: str, builder: Callable, *, kind: str) -> None:
|
||||||
_REGISTRY[name] = (builder, kind)
|
_REGISTRY[name] = (builder, kind)
|
||||||
|
|
||||||
|
|
||||||
|
# Returns (model_or_pair, kind) for the model named in cfg["model"]
|
||||||
|
# kind in {"dcgan", "wgan"} → (generator, critic) | kind in {"vae", "ddpm"} → model
|
||||||
def get_model(cfg: dict) -> tuple:
|
def get_model(cfg: dict) -> tuple:
|
||||||
"""Return (model_or_pair, kind).
|
|
||||||
|
|
||||||
kind="dcgan" -> (generator, discriminator)
|
|
||||||
"""
|
|
||||||
name = cfg.get("model")
|
name = cfg.get("model")
|
||||||
entry = _REGISTRY.get(name)
|
entry = _REGISTRY.get(name)
|
||||||
if entry is None:
|
if entry is None:
|
||||||
@@ -22,6 +22,7 @@ def get_model(cfg: dict) -> tuple:
|
|||||||
return builder(cfg), kind
|
return builder(cfg), kind
|
||||||
|
|
||||||
|
|
||||||
|
# Importing each module triggers its register() calls
|
||||||
from src.models import dcgan # noqa: E402, F401
|
from src.models import dcgan # noqa: E402, F401
|
||||||
from src.models import wgan # noqa: E402, F401
|
from src.models import wgan # noqa: E402, F401
|
||||||
from src.models import vae # noqa: E402, F401
|
from src.models import vae # noqa: E402, F401
|
||||||
|
|||||||
@@ -1,15 +1,6 @@
|
|||||||
"""Vanilla DCGAN (Radford et al., 2015).
|
# Vanilla DCGAN (Radford et al., 2015) — Phase 1 baseline for cheap pipeline ablations.
|
||||||
|
# Depth scales with image_size; each step doubles spatial size from a 4×4 stem.
|
||||||
Used as the Phase 1 baseline for cheap pipeline ablations. Architecture is
|
# 64 → 5 upsample steps | 128 → 6 upsample steps
|
||||||
intentionally minimal — BatchNorm in both networks, no spectral norm, no
|
|
||||||
attention, no gradient penalty. The whole point is to be the cheapest GAN
|
|
||||||
we can run, so 1A–1D pipeline deltas show up in FID quickly.
|
|
||||||
|
|
||||||
Depth scales with image_size: each step doubles the spatial dimension,
|
|
||||||
starting from 4×4 after the first transposed conv.
|
|
||||||
64 -> 5 layers (1 -> 4 -> 8 -> 16 -> 32 -> 64)
|
|
||||||
128 -> 6 layers (1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128)
|
|
||||||
"""
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -27,16 +18,15 @@ def _init_weights(m):
|
|||||||
nn.init.zeros_(m.bias)
|
nn.init.zeros_(m.bias)
|
||||||
|
|
||||||
|
|
||||||
|
# Number of 2× upsampling steps from the 4×4 stem to image_size
|
||||||
def _n_upsamples(image_size: int) -> int:
|
def _n_upsamples(image_size: int) -> int:
|
||||||
"""Number of 2x upsampling steps from 4x4 to image_size."""
|
|
||||||
if image_size < 8 or image_size & (image_size - 1):
|
if image_size < 8 or image_size & (image_size - 1):
|
||||||
raise ValueError(f"image_size must be a power of two ≥ 8, got {image_size}")
|
raise ValueError(f"image_size must be a power of two ≥ 8, got {image_size}")
|
||||||
return int(math.log2(image_size)) - 2 # 64 -> 4, 128 -> 5
|
return int(math.log2(image_size)) - 2 # 64 -> 4, 128 -> 5
|
||||||
|
|
||||||
|
|
||||||
|
# Maps (latent_dim × 1 × 1) → (3 × image_size × image_size) in [-1, 1]
|
||||||
class DCGANGenerator(nn.Module):
|
class DCGANGenerator(nn.Module):
|
||||||
"""Maps (latent_dim x 1 x 1) -> (3 x image_size x image_size) in [-1, 1]."""
|
|
||||||
|
|
||||||
def __init__(self, latent_dim: int = 100, ngf: int = 64, image_size: int = 64):
|
def __init__(self, latent_dim: int = 100, ngf: int = 64, image_size: int = 64):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
n_up = _n_upsamples(image_size) # 64 -> 4 upsamples after the 1->4 init
|
n_up = _n_upsamples(image_size) # 64 -> 4 upsamples after the 1->4 init
|
||||||
@@ -69,9 +59,8 @@ class DCGANGenerator(nn.Module):
|
|||||||
return self.net(z)
|
return self.net(z)
|
||||||
|
|
||||||
|
|
||||||
|
# Maps (3 × image_size × image_size) → scalar logit (no sigmoid)
|
||||||
class DCGANDiscriminator(nn.Module):
|
class DCGANDiscriminator(nn.Module):
|
||||||
"""Maps (3 x image_size x image_size) -> scalar logit (no sigmoid)."""
|
|
||||||
|
|
||||||
def __init__(self, ndf: int = 64, image_size: int = 64):
|
def __init__(self, ndf: int = 64, image_size: int = 64):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
n_down = _n_upsamples(image_size)
|
n_down = _n_upsamples(image_size)
|
||||||
@@ -97,6 +86,7 @@ class DCGANDiscriminator(nn.Module):
|
|||||||
return self.net(x).view(x.size(0))
|
return self.net(x).view(x.size(0))
|
||||||
|
|
||||||
|
|
||||||
|
# Builds a (DCGANGenerator, DCGANDiscriminator) pair from cfg
|
||||||
def _build(cfg: dict):
|
def _build(cfg: dict):
|
||||||
image_size = cfg.get("image_size", 64)
|
image_size = cfg.get("image_size", 64)
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -1,11 +1,5 @@
|
|||||||
"""PatchGAN discriminator for Phase 3.3 (VQGAN-lite adversarial training).
|
# PatchGAN discriminator for Phase 3.3 adversarial training: outputs a spatial patch logit map, not a scalar.
|
||||||
|
# Not in the model registry — instantiated directly inside train_vae when lambda_adversarial > 0.
|
||||||
Outputs a spatial patch map instead of a single scalar — each patch
|
|
||||||
predicts real/fake independently. Loss is the mean over all patches.
|
|
||||||
|
|
||||||
Not registered in the model registry; instantiated inside train_vae
|
|
||||||
when lambda_adversarial > 0.
|
|
||||||
"""
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
@@ -17,14 +11,9 @@ def _init_weights(m):
|
|||||||
nn.init.zeros_(m.bias)
|
nn.init.zeros_(m.bias)
|
||||||
|
|
||||||
|
|
||||||
|
# Stride-2 + stride-1 conv chain → spatial patch logit map; supports image_size ∈ {64, 128}
|
||||||
|
# 64×64 → 6×6 map (70×70 receptive field); 128×128 adds an extra stride-2 layer. InstanceNorm except first layer.
|
||||||
class PatchGANDiscriminator(nn.Module):
|
class PatchGANDiscriminator(nn.Module):
|
||||||
"""Stride-2 + stride-1 convolution chain → spatial patch logit map.
|
|
||||||
|
|
||||||
Supports image_size ∈ {64, 128}. For 64×64 input the final map is 6×6
|
|
||||||
(70×70 receptive field). For 128×128 an extra stride-2 layer is added.
|
|
||||||
InstanceNorm everywhere except the first layer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, ndf: int = 64, image_size: int = 64):
|
def __init__(self, ndf: int = 64, image_size: int = 64):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
layers: list[nn.Module] = [
|
layers: list[nn.Module] = [
|
||||||
@@ -57,13 +46,13 @@ class PatchGANDiscriminator(nn.Module):
|
|||||||
return self.net(x) # (B, 1, H', W') — patch logit map
|
return self.net(x) # (B, 1, H', W') — patch logit map
|
||||||
|
|
||||||
|
|
||||||
|
# Hinge loss for the discriminator (Lim & Ye, 2017)
|
||||||
def hinge_d_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor:
|
def hinge_d_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor:
|
||||||
"""Hinge loss for the discriminator (Lim & Ye, 2017)."""
|
|
||||||
loss_real = torch.mean(torch.relu(1.0 - real_logits))
|
loss_real = torch.mean(torch.relu(1.0 - real_logits))
|
||||||
loss_fake = torch.mean(torch.relu(1.0 + fake_logits))
|
loss_fake = torch.mean(torch.relu(1.0 + fake_logits))
|
||||||
return 0.5 * (loss_real + loss_fake)
|
return 0.5 * (loss_real + loss_fake)
|
||||||
|
|
||||||
|
|
||||||
|
# Generator hinge loss — maximises D(fake)
|
||||||
def hinge_g_loss(fake_logits: torch.Tensor) -> torch.Tensor:
|
def hinge_g_loss(fake_logits: torch.Tensor) -> torch.Tensor:
|
||||||
"""Generator hinge loss — maximise D(fake)."""
|
|
||||||
return -torch.mean(fake_logits)
|
return -torch.mean(fake_logits)
|
||||||
|
|||||||
@@ -1,13 +1,5 @@
|
|||||||
"""Time-conditioned U-Net for DDPM (Phase 4).
|
# Time-conditioned U-Net for DDPM (Phase 4), following Ho et al. (2020) with Nichol & Dhariwal (2021) options.
|
||||||
|
# Sinusoidal time embedding → MLP → injected additively into every ResBlock; GroupNorm(32) + SiLU throughout.
|
||||||
Architecture follows Ho et al. (2020) with options from Nichol & Dhariwal (2021):
|
|
||||||
- Sinusoidal time embedding → MLP → added to every ResBlock
|
|
||||||
- GroupNorm (32 groups) + SiLU activations throughout
|
|
||||||
- Self-attention at configurable spatial resolutions
|
|
||||||
- Upsample(nearest) + Conv in the decoder — no checkerboard artefacts
|
|
||||||
|
|
||||||
Registered as kind="ddpm".
|
|
||||||
"""
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -37,9 +29,8 @@ class SinusoidalPosEmb(nn.Module):
|
|||||||
|
|
||||||
# ── Core building blocks ──────────────────────────────────────────────────────
|
# ── Core building blocks ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# ResNet block with additive time-embedding injection after the first conv
|
||||||
class ResBlock(nn.Module):
|
class ResBlock(nn.Module):
|
||||||
"""ResNet block with time-embedding injection (additive, after first conv)."""
|
|
||||||
|
|
||||||
def __init__(self, in_ch: int, out_ch: int, t_emb_dim: int, dropout: float = 0.1):
|
def __init__(self, in_ch: int, out_ch: int, t_emb_dim: int, dropout: float = 0.1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.norm1 = nn.GroupNorm(_GN, in_ch)
|
self.norm1 = nn.GroupNorm(_GN, in_ch)
|
||||||
@@ -57,9 +48,8 @@ class ResBlock(nn.Module):
|
|||||||
return h + self.skip(x)
|
return h + self.skip(x)
|
||||||
|
|
||||||
|
|
||||||
|
# Single-head self-attention with GroupNorm pre-norm and residual
|
||||||
class AttentionBlock(nn.Module):
|
class AttentionBlock(nn.Module):
|
||||||
"""Single-head self-attention with GroupNorm pre-norm and residual."""
|
|
||||||
|
|
||||||
def __init__(self, ch: int):
|
def __init__(self, ch: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.norm = nn.GroupNorm(_GN, ch)
|
self.norm = nn.GroupNorm(_GN, ch)
|
||||||
@@ -155,17 +145,9 @@ class UpBlock(nn.Module):
|
|||||||
|
|
||||||
# ── U-Net ─────────────────────────────────────────────────────────────────────
|
# ── U-Net ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Time-conditioned U-Net; image_size must be a power-of-two (64 or 128 recommended)
|
||||||
|
# ch_mult sets channel multipliers per level; attn_resolutions controls where attention is inserted
|
||||||
class UNet(nn.Module):
|
class UNet(nn.Module):
|
||||||
"""Time-conditioned U-Net.
|
|
||||||
|
|
||||||
image_size — must be a power-of-two; 64 or 128 recommended.
|
|
||||||
base_ch — base channel count (128 for phases 4.1–4.3, 192 for 4.4).
|
|
||||||
ch_mult — channel multipliers per resolution level.
|
|
||||||
attn_resolutions — spatial resolutions at which attention is inserted.
|
|
||||||
num_res_blocks — ResBlocks per level (in both down and up paths).
|
|
||||||
dropout — applied inside every ResBlock.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
image_size: int = 64,
|
image_size: int = 64,
|
||||||
@@ -265,6 +247,7 @@ class UNet(nn.Module):
|
|||||||
return self.out_conv(F.silu(self.out_norm(x)))
|
return self.out_conv(F.silu(self.out_norm(x)))
|
||||||
|
|
||||||
|
|
||||||
|
# Builds a UNet from cfg for DDPM training
|
||||||
def _build(cfg: dict):
|
def _build(cfg: dict):
|
||||||
return UNet(
|
return UNet(
|
||||||
image_size = cfg.get("image_size", 64),
|
image_size = cfg.get("image_size", 64),
|
||||||
|
|||||||
@@ -1,13 +1,5 @@
|
|||||||
"""Convolutional VAE for Phase 3.
|
# Convolutional VAE for Phase 3. Encoder: stride-2 Conv → flatten → linear (μ, log σ²).
|
||||||
|
# Decoder: Linear → Upsample(nearest) + Conv — avoids ConvTranspose2d checkerboard artefacts.
|
||||||
Encoder uses stride-2 Conv → flatten → linear (μ, log σ²).
|
|
||||||
Decoder uses Linear → Upsample(nearest) + Conv to avoid ConvTranspose2d
|
|
||||||
checkerboard artefacts.
|
|
||||||
|
|
||||||
Registered as kind="vae". The run.py dispatcher passes the model to
|
|
||||||
train_vae(), which internally builds perceptual loss and PatchGAN when
|
|
||||||
the corresponding lambdas are non-zero.
|
|
||||||
"""
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -31,8 +23,8 @@ def _norm(channels: int) -> nn.GroupNorm:
|
|||||||
return nn.GroupNorm(8, channels)
|
return nn.GroupNorm(8, channels)
|
||||||
|
|
||||||
|
|
||||||
|
# Nearest-neighbour upsample followed by a 3×3 conv — no checkerboard artefacts
|
||||||
def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential:
|
def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential:
|
||||||
"""Nearest-neighbour upsample followed by a 3×3 conv — no checkerboard."""
|
|
||||||
return nn.Sequential(
|
return nn.Sequential(
|
||||||
nn.Upsample(scale_factor=2, mode="nearest"),
|
nn.Upsample(scale_factor=2, mode="nearest"),
|
||||||
nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
|
nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
|
||||||
@@ -41,20 +33,15 @@ def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Convolutional VAE; image_size must be a power-of-two ≥ 32
|
||||||
|
# Spatial bottleneck is always at 4×4 — encoder and decoder scale stride-2 steps accordingly
|
||||||
class VAE(nn.Module):
|
class VAE(nn.Module):
|
||||||
"""Convolutional VAE. image_size must be a power-of-two ≥ 32.
|
|
||||||
|
|
||||||
Spatial bottleneck is always at 4×4 regardless of image_size —
|
|
||||||
the encoder and decoder scale the number of stride-2 steps accordingly.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, latent_dim: int = 256, ngf: int = 64, image_size: int = 64):
|
def __init__(self, latent_dim: int = 256, ngf: int = 64, image_size: int = 64):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if image_size < 32 or (image_size & (image_size - 1)):
|
if image_size < 32 or (image_size & (image_size - 1)):
|
||||||
raise ValueError(f"image_size must be a power-of-two ≥ 32, got {image_size}")
|
raise ValueError(f"image_size must be a power-of-two ≥ 32, got {image_size}")
|
||||||
|
|
||||||
self.latent_dim = latent_dim
|
self.latent_dim = latent_dim
|
||||||
self.image_size = image_size
|
|
||||||
|
|
||||||
n_down = int(math.log2(image_size)) - 2 # steps from image_size to 4×4
|
n_down = int(math.log2(image_size)) - 2 # steps from image_size to 4×4
|
||||||
# 64 → n_down=4: 64→32→16→8→4
|
# 64 → n_down=4: 64→32→16→8→4
|
||||||
@@ -115,19 +102,20 @@ class VAE(nn.Module):
|
|||||||
h = self.fc_dec(z).view(z.size(0), self._dec_ch, 4, 4)
|
h = self.fc_dec(z).view(z.size(0), self._dec_ch, 4, 4)
|
||||||
return self.decoder(h)
|
return self.decoder(h)
|
||||||
|
|
||||||
|
# Returns (reconstruction, mu, log_var)
|
||||||
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""Returns (reconstruction, mu, log_var)."""
|
|
||||||
mu, log_var = self.encode(x)
|
mu, log_var = self.encode(x)
|
||||||
z = self.reparameterize(mu, log_var)
|
z = self.reparameterize(mu, log_var)
|
||||||
return self.decode(z), mu, log_var
|
return self.decode(z), mu, log_var
|
||||||
|
|
||||||
|
# Samples n images by drawing z ~ N(0, I) and decoding
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample(self, n: int, device) -> torch.Tensor:
|
def sample(self, n: int, device) -> torch.Tensor:
|
||||||
"""Sample n images by drawing z ~ N(0, I)."""
|
|
||||||
z = torch.randn(n, self.latent_dim, device=device)
|
z = torch.randn(n, self.latent_dim, device=device)
|
||||||
return self.decode(z)
|
return self.decode(z)
|
||||||
|
|
||||||
|
|
||||||
|
# Builds a VAE from cfg
|
||||||
def _build(cfg: dict):
|
def _build(cfg: dict):
|
||||||
return VAE(
|
return VAE(
|
||||||
latent_dim=cfg.get("latent_dim", 256),
|
latent_dim=cfg.get("latent_dim", 256),
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
"""WGAN-GP variants.
|
# WGAN-GP variants.
|
||||||
|
# wgan_basic — Phase 2.2: BatchNorm/InstanceNorm, no attention, 64×64 only.
|
||||||
wgan_basic — Phase 2.2: BatchNorm/InstanceNorm, no attention, 64×64 only.
|
# wgan — Phase 2.3/2.4: GroupNorm/SpectralNorm + self-attention, size-agnostic.
|
||||||
wgan — Phase 2.3/2.4: GroupNorm/SpectralNorm + self-attention, size-agnostic.
|
|
||||||
"""
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -23,9 +21,8 @@ def _sn(module):
|
|||||||
return nn.utils.spectral_norm(module)
|
return nn.utils.spectral_norm(module)
|
||||||
|
|
||||||
|
|
||||||
|
# SAGAN-style self-attention: gamma-gated residual, dot-products scaled by mid^-0.5
|
||||||
class SelfAttention(nn.Module):
|
class SelfAttention(nn.Module):
|
||||||
"""SAGAN-style self-attention."""
|
|
||||||
|
|
||||||
def __init__(self, in_ch: int):
|
def __init__(self, in_ch: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
mid = max(in_ch // 8, 1)
|
mid = max(in_ch // 8, 1)
|
||||||
@@ -48,13 +45,8 @@ class SelfAttention(nn.Module):
|
|||||||
# Phase 2.2 — basic WGAN-GP (BatchNorm in G, InstanceNorm in D, 64×64 only)
|
# Phase 2.2 — basic WGAN-GP (BatchNorm in G, InstanceNorm in D, 64×64 only)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Maps (latent_dim, 1, 1) → (3, 64, 64) in [-1, 1]; BatchNorm in G is safe under WGAN-GP (constraint targets the critic)
|
||||||
class WGANBasicGenerator(nn.Module):
|
class WGANBasicGenerator(nn.Module):
|
||||||
"""Maps (latent_dim, 1, 1) -> (3, 64, 64) in [-1, 1].
|
|
||||||
|
|
||||||
Same channel structure as DCGAN. BatchNorm in generator is fine because
|
|
||||||
WGAN-GP's constraint targets the critic, not the generator.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, latent_dim: int = 128, ngf: int = 64):
|
def __init__(self, latent_dim: int = 128, ngf: int = 64):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.net = nn.Sequential(
|
self.net = nn.Sequential(
|
||||||
@@ -80,11 +72,8 @@ class WGANBasicGenerator(nn.Module):
|
|||||||
return self.net(z)
|
return self.net(z)
|
||||||
|
|
||||||
|
|
||||||
|
# WGAN-GP critic (64×64); InstanceNorm instead of BatchNorm — BatchNorm breaks the per-sample Lipschitz constraint
|
||||||
class WGANBasicCritic(nn.Module):
|
class WGANBasicCritic(nn.Module):
|
||||||
"""WGAN-GP critic (64×64). InstanceNorm instead of BatchNorm — BatchNorm
|
|
||||||
breaks the per-sample Lipschitz constraint the gradient penalty enforces.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, ndf: int = 64):
|
def __init__(self, ndf: int = 64):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.net = nn.Sequential(
|
self.net = nn.Sequential(
|
||||||
@@ -116,21 +105,14 @@ class WGANBasicCritic(nn.Module):
|
|||||||
# Phase 2.3 / 2.4 — advanced WGAN-GP (GroupNorm, SpectralNorm, attention)
|
# Phase 2.3 / 2.4 — advanced WGAN-GP (GroupNorm, SpectralNorm, attention)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# GroupNorm generator with SAGAN self-attention; supports image_size ∈ {64, 128}
|
||||||
|
# Stem always 1×1 → 4×4 → 8×8 → 16×16; attention at 16×16, and at 32×32 for 128×128
|
||||||
class WGANGenerator(nn.Module):
|
class WGANGenerator(nn.Module):
|
||||||
"""GroupNorm generator with SAGAN self-attention.
|
|
||||||
|
|
||||||
Supports image_size ∈ {64, 128}.
|
|
||||||
Stem is always 1×1 → 4×4 → 8×8 → 16×16 (ngf×8 → ngf×4 → ngf×2 channels).
|
|
||||||
Attention at 16×16 always; additional attention at 32×32 for 128×128.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, latent_dim: int = 128, ngf: int = 128, image_size: int = 64):
|
def __init__(self, latent_dim: int = 128, ngf: int = 128, image_size: int = 64):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if image_size not in (64, 128):
|
if image_size not in (64, 128):
|
||||||
raise ValueError(f"WGANGenerator supports image_size 64 or 128, got {image_size}")
|
raise ValueError(f"WGANGenerator supports image_size 64 or 128, got {image_size}")
|
||||||
|
|
||||||
self._image_size = image_size
|
|
||||||
|
|
||||||
self.stem = nn.Sequential(
|
self.stem = nn.Sequential(
|
||||||
# 1×1 → 4×4
|
# 1×1 → 4×4
|
||||||
nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False),
|
nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False),
|
||||||
@@ -179,20 +161,14 @@ class WGANGenerator(nn.Module):
|
|||||||
return self.tail(h)
|
return self.tail(h)
|
||||||
|
|
||||||
|
|
||||||
|
# SpectralNorm critic with SAGAN self-attention; supports image_size ∈ {64, 128}
|
||||||
|
# Attention at 16×16 always; additional attention at 32×32 for 128×128
|
||||||
class WGANCritic(nn.Module):
|
class WGANCritic(nn.Module):
|
||||||
"""SpectralNorm critic with SAGAN self-attention.
|
|
||||||
|
|
||||||
Supports image_size ∈ {64, 128}.
|
|
||||||
Attention at 16×16 always; additional attention at 32×32 for 128×128.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, ndf: int = 128, image_size: int = 64):
|
def __init__(self, ndf: int = 128, image_size: int = 64):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if image_size not in (64, 128):
|
if image_size not in (64, 128):
|
||||||
raise ValueError(f"WGANCritic supports image_size 64 or 128, got {image_size}")
|
raise ValueError(f"WGANCritic supports image_size 64 or 128, got {image_size}")
|
||||||
|
|
||||||
self._image_size = image_size
|
|
||||||
|
|
||||||
if image_size == 64:
|
if image_size == 64:
|
||||||
# Head: 64→32 (ndf//2)
|
# Head: 64→32 (ndf//2)
|
||||||
self.head = nn.Sequential(
|
self.head = nn.Sequential(
|
||||||
@@ -242,6 +218,7 @@ class WGANCritic(nn.Module):
|
|||||||
return self.tail(h).view(x.size(0))
|
return self.tail(h).view(x.size(0))
|
||||||
|
|
||||||
|
|
||||||
|
# Builds a (WGANBasicGenerator, WGANBasicCritic) pair from cfg
|
||||||
def _build_basic(cfg: dict):
|
def _build_basic(cfg: dict):
|
||||||
return (
|
return (
|
||||||
WGANBasicGenerator(latent_dim=cfg.get("latent_dim", 128), ngf=cfg.get("ngf", 64)),
|
WGANBasicGenerator(latent_dim=cfg.get("latent_dim", 128), ngf=cfg.get("ngf", 64)),
|
||||||
@@ -249,6 +226,7 @@ def _build_basic(cfg: dict):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Builds a (WGANGenerator, WGANCritic) pair from cfg
|
||||||
def _build(cfg: dict):
|
def _build(cfg: dict):
|
||||||
image_size = cfg.get("image_size", 64)
|
image_size = cfg.get("image_size", 64)
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -1,13 +1,5 @@
|
|||||||
"""Gaussian diffusion utilities for Phase 4 (DDPM).
|
# Gaussian diffusion utilities for Phase 4 (DDPM): noise schedules, forward process, training loss, DDIM sampling.
|
||||||
|
# alpha_bars[t] = ᾱ_{t+1} (0-indexed); t=0 is near-clean (ᾱ ≈ 1−β₁), t=T−1 is near-pure-noise (ᾱ ≈ 0).
|
||||||
Provides noise schedules, the forward (noising) process, training loss,
|
|
||||||
and DDIM deterministic sampling (Song et al., 2020).
|
|
||||||
|
|
||||||
Convention: alpha_bars is a 1-D tensor of length T, where alpha_bars[t]
|
|
||||||
= ᾱ_{t+1} in 1-indexed notation. Timestep t used in the training loop
|
|
||||||
is a 0-indexed integer in [0, T). At t=0 the image is almost clean
|
|
||||||
(ᾱ ≈ 1 − β_1); at t=T−1 the image is almost pure noise (ᾱ ≈ 0).
|
|
||||||
"""
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -16,13 +8,13 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
# ── Noise schedules ──────────────────────────────────────────────────────────
|
# ── Noise schedules ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Ho et al. (2020) linear schedule
|
||||||
def linear_betas(T: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor:
|
def linear_betas(T: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor:
|
||||||
"""Ho et al. (2020) linear schedule."""
|
|
||||||
return torch.linspace(beta_start, beta_end, T)
|
return torch.linspace(beta_start, beta_end, T)
|
||||||
|
|
||||||
|
|
||||||
|
# Nichol & Dhariwal (2021) cosine schedule — avoids over-denoising at low t
|
||||||
def cosine_betas(T: int, s: float = 0.008) -> torch.Tensor:
|
def cosine_betas(T: int, s: float = 0.008) -> torch.Tensor:
|
||||||
"""Nichol & Dhariwal (2021) cosine schedule — avoids over-denoising at low t."""
|
|
||||||
t = torch.linspace(0, T, T + 1)
|
t = torch.linspace(0, T, T + 1)
|
||||||
f = torch.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2
|
f = torch.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2
|
||||||
alpha_bar = f / f[0]
|
alpha_bar = f / f[0]
|
||||||
@@ -30,20 +22,20 @@ def cosine_betas(T: int, s: float = 0.008) -> torch.Tensor:
|
|||||||
return betas.clamp(max=0.999)
|
return betas.clamp(max=0.999)
|
||||||
|
|
||||||
|
|
||||||
|
# Cumulative product of (1 − β), shape (T,)
|
||||||
def make_alpha_bars(betas: torch.Tensor) -> torch.Tensor:
|
def make_alpha_bars(betas: torch.Tensor) -> torch.Tensor:
|
||||||
"""Cumulative product of (1 − β), shape (T,)."""
|
|
||||||
return (1.0 - betas).cumprod(0)
|
return (1.0 - betas).cumprod(0)
|
||||||
|
|
||||||
|
|
||||||
# ── Forward process ──────────────────────────────────────────────────────────
|
# ── Forward process ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Adds noise to x0 at timestep t; returns (x_t, noise)
|
||||||
def q_sample(
|
def q_sample(
|
||||||
x0: torch.Tensor,
|
x0: torch.Tensor,
|
||||||
t: torch.Tensor,
|
t: torch.Tensor,
|
||||||
alpha_bars: torch.Tensor,
|
alpha_bars: torch.Tensor,
|
||||||
noise: torch.Tensor | None = None,
|
noise: torch.Tensor | None = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Add noise to x0 at timestep t. Returns (x_t, noise)."""
|
|
||||||
if noise is None:
|
if noise is None:
|
||||||
noise = torch.randn_like(x0)
|
noise = torch.randn_like(x0)
|
||||||
ab = alpha_bars[t].to(x0.device)[:, None, None, None]
|
ab = alpha_bars[t].to(x0.device)[:, None, None, None]
|
||||||
@@ -53,6 +45,7 @@ def q_sample(
|
|||||||
|
|
||||||
# ── Training loss ────────────────────────────────────────────────────────────
|
# ── Training loss ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# MSE between model prediction and target; pred_type="eps" targets noise ε, pred_type="v" targets v = √ᾱ·ε − √(1−ᾱ)·x0
|
||||||
def diffusion_loss(
|
def diffusion_loss(
|
||||||
model,
|
model,
|
||||||
x0: torch.Tensor,
|
x0: torch.Tensor,
|
||||||
@@ -60,11 +53,6 @@ def diffusion_loss(
|
|||||||
alpha_bars: torch.Tensor,
|
alpha_bars: torch.Tensor,
|
||||||
pred_type: str = "eps",
|
pred_type: str = "eps",
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""MSE on the model's prediction vs the true target.
|
|
||||||
|
|
||||||
pred_type="eps" → target is the added noise ε (Ho et al.)
|
|
||||||
pred_type="v" → target is v = √ᾱ·ε − √(1−ᾱ)·x0 (Salimans & Ho)
|
|
||||||
"""
|
|
||||||
x_t, noise = q_sample(x0, t, alpha_bars)
|
x_t, noise = q_sample(x0, t, alpha_bars)
|
||||||
pred = model(x_t, t)
|
pred = model(x_t, t)
|
||||||
|
|
||||||
@@ -80,6 +68,8 @@ def diffusion_loss(
|
|||||||
# ── DDIM deterministic sampling ───────────────────────────────────────────────
|
# ── DDIM deterministic sampling ───────────────────────────────────────────────
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
# Generates n images via DDIM (eta=0, deterministic); batches internally to avoid OOM
|
||||||
|
# Returns tensor shape (n, 3, image_size, image_size) in [-1, 1]
|
||||||
def ddim_sample(
|
def ddim_sample(
|
||||||
model,
|
model,
|
||||||
n: int,
|
n: int,
|
||||||
@@ -90,11 +80,6 @@ def ddim_sample(
|
|||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Generate n images via DDIM (eta=0, deterministic).
|
|
||||||
|
|
||||||
Batches internally to avoid OOM when n is large.
|
|
||||||
Returns tensor shape (n, 3, image_size, image_size) in [-1, 1].
|
|
||||||
"""
|
|
||||||
model.eval()
|
model.eval()
|
||||||
T = len(alpha_bars)
|
T = len(alpha_bars)
|
||||||
|
|
||||||
|
|||||||
@@ -3,13 +3,9 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
# Exponential moving average of model weights; shadow copy updated after each optimizer step
|
||||||
|
# Sample from ema.model at eval time, never from the training model directly
|
||||||
class EMA:
|
class EMA:
|
||||||
"""Exponential moving average of model weights.
|
|
||||||
|
|
||||||
Maintains a shadow copy of the model. Call update() after each
|
|
||||||
optimizer step. Sample from ema.model, never from the training model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, model: nn.Module, decay: float = 0.9999):
|
def __init__(self, model: nn.Module, decay: float = 0.9999):
|
||||||
self.decay = decay
|
self.decay = decay
|
||||||
self.model = copy.deepcopy(model).eval()
|
self.model = copy.deepcopy(model).eval()
|
||||||
|
|||||||
@@ -1,20 +1,17 @@
|
|||||||
"""FID evaluation helper.
|
# FID evaluation helper: caches real images as a CPU tensor to avoid re-reading disk on every eval call.
|
||||||
|
|
||||||
Computes Fréchet Inception Distance between a fixed set of real images
|
|
||||||
and a batch of generated images. Real images are stored as a tensor on CPU
|
|
||||||
and moved to device only during evaluation — this avoids re-reading disk
|
|
||||||
every call while keeping GPU memory free between evaluations.
|
|
||||||
"""
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torchmetrics.image.fid import FrechetInceptionDistance
|
from torchmetrics.image.fid import FrechetInceptionDistance
|
||||||
|
|
||||||
|
|
||||||
|
# Computes FID between cached real images and a batch of generated images
|
||||||
class FIDEvaluator:
|
class FIDEvaluator:
|
||||||
def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda",
|
def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda",
|
||||||
num_workers: int = 2):
|
num_workers: int = 2):
|
||||||
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
|
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
|
||||||
self.n_real = n_real
|
self.n_real = n_real
|
||||||
|
# Inception network loaded once here; compute() calls reset() between evaluations
|
||||||
|
self._fid = FrechetInceptionDistance(feature=2048, normalize=True).to(self.device)
|
||||||
|
|
||||||
# Cache real images as a CPU tensor ([-1, 1] range)
|
# Cache real images as a CPU tensor ([-1, 1] range)
|
||||||
imgs_list = []
|
imgs_list = []
|
||||||
@@ -27,24 +24,20 @@ class FIDEvaluator:
|
|||||||
real = torch.cat(imgs_list)[:n_real]
|
real = torch.cat(imgs_list)[:n_real]
|
||||||
self._real = real # stored on CPU, shape (N, 3, H, W) in [-1, 1]
|
self._real = real # stored on CPU, shape (N, 3, H, W) in [-1, 1]
|
||||||
|
|
||||||
|
# Computes FID score; fake_imgs should be float in [-1, 1], shape (N, 3, H, W), N ≥ 2048 for reliability
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def compute(self, fake_imgs: torch.Tensor) -> float:
|
def compute(self, fake_imgs: torch.Tensor) -> float:
|
||||||
"""Compute FID score.
|
self._fid.reset()
|
||||||
|
|
||||||
fake_imgs: float tensor in [-1, 1], shape (N, 3, H, W).
|
|
||||||
N should be at least 2048 for a reliable score.
|
|
||||||
"""
|
|
||||||
fid = FrechetInceptionDistance(feature=2048, normalize=True).to(self.device)
|
|
||||||
|
|
||||||
# Feed real images in batches
|
# Feed real images in batches
|
||||||
for i in range(0, self._real.size(0), 256):
|
for i in range(0, self._real.size(0), 256):
|
||||||
batch = (self._real[i:i + 256] * 0.5 + 0.5).clamp(0, 1).to(self.device)
|
batch = (self._real[i:i + 256] * 0.5 + 0.5).clamp(0, 1).to(self.device)
|
||||||
fid.update(batch, real=True)
|
self._fid.update(batch, real=True)
|
||||||
|
|
||||||
# Feed fake images in batches
|
# Feed fake images in batches
|
||||||
fake = fake_imgs.cpu()
|
fake = fake_imgs.cpu()
|
||||||
for i in range(0, fake.size(0), 256):
|
for i in range(0, fake.size(0), 256):
|
||||||
batch = (fake[i:i + 256] * 0.5 + 0.5).clamp(0, 1).to(self.device)
|
batch = (fake[i:i + 256] * 0.5 + 0.5).clamp(0, 1).to(self.device)
|
||||||
fid.update(batch, real=False)
|
self._fid.update(batch, real=False)
|
||||||
|
|
||||||
return float(fid.compute())
|
return float(self._fid.compute())
|
||||||
|
|||||||
@@ -1,25 +1,17 @@
|
|||||||
"""Extended generation quality metrics for Phase 5 cross-family comparison.
|
# Extended generation quality metrics for Phase 5 cross-family comparison.
|
||||||
|
# IS (Salimans et al., 2016): quality × diversity. LPIPS diversity: pairwise perceptual distance.
|
||||||
IS — Inception Score (Salimans et al., 2016): measures sample quality × diversity.
|
# Both functions accept float tensors in [-1, 1].
|
||||||
LPIPS — average pairwise learned perceptual distance: measures sample diversity alone.
|
|
||||||
|
|
||||||
Both functions accept float tensors in [-1, 1].
|
|
||||||
"""
|
|
||||||
import torch
|
import torch
|
||||||
from torchmetrics.image.inception import InceptionScore
|
from torchmetrics.image.inception import InceptionScore
|
||||||
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
||||||
|
|
||||||
|
|
||||||
|
# Inception Score (mean ± std) over 10 splits; imgs in [-1, 1], N ≥ 2048 for reliability; returns (mean, std)
|
||||||
def compute_is(
|
def compute_is(
|
||||||
imgs: torch.Tensor,
|
imgs: torch.Tensor,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
batch_size: int = 64,
|
batch_size: int = 64,
|
||||||
) -> tuple[float, float]:
|
) -> tuple[float, float]:
|
||||||
"""Inception Score (mean ± std) over 10 splits.
|
|
||||||
|
|
||||||
imgs: (N, 3, H, W) in [-1, 1]. N ≥ 2 048 for a reliable estimate.
|
|
||||||
Returns (is_mean, is_std).
|
|
||||||
"""
|
|
||||||
metric = InceptionScore(normalize=True).to(device)
|
metric = InceptionScore(normalize=True).to(device)
|
||||||
imgs_01 = (imgs.clamp(-1, 1) * 0.5 + 0.5)
|
imgs_01 = (imgs.clamp(-1, 1) * 0.5 + 0.5)
|
||||||
for i in range(0, len(imgs_01), batch_size):
|
for i in range(0, len(imgs_01), batch_size):
|
||||||
@@ -28,16 +20,13 @@ def compute_is(
|
|||||||
return float(mean), float(std)
|
return float(mean), float(std)
|
||||||
|
|
||||||
|
|
||||||
|
# Average pairwise LPIPS distance over n_pairs random (i ≠ j) pairs; higher = more diverse
|
||||||
def compute_lpips_diversity(
|
def compute_lpips_diversity(
|
||||||
imgs: torch.Tensor,
|
imgs: torch.Tensor,
|
||||||
n_pairs: int = 200,
|
n_pairs: int = 200,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
batch_size: int = 16,
|
batch_size: int = 16,
|
||||||
) -> float:
|
) -> float:
|
||||||
"""Average pairwise LPIPS distance — higher means more diverse samples.
|
|
||||||
|
|
||||||
imgs: (N, 3, H, W) in [-1, 1]. Samples n_pairs random (i, j) pairs with i ≠ j.
|
|
||||||
"""
|
|
||||||
metric = LearnedPerceptualImagePatchSimilarity(net_type="alex", normalize=True).to(device)
|
metric = LearnedPerceptualImagePatchSimilarity(net_type="alex", normalize=True).to(device)
|
||||||
imgs_01 = (imgs.clamp(-1, 1) * 0.5 + 0.5)
|
imgs_01 = (imgs.clamp(-1, 1) * 0.5 + 0.5)
|
||||||
N = len(imgs_01)
|
N = len(imgs_01)
|
||||||
|
|||||||
@@ -1,26 +1,13 @@
|
|||||||
"""VGG-16 perceptual loss for Phase 3.2 and 3.3.
|
# VGG-16 perceptual loss (Phase 3.2/3.3): L1 feature distance at relu1_2, relu2_2, relu3_3.
|
||||||
|
# Expects images in [-1, 1]; converts internally to [0, 1] then ImageNet-normalises before VGG. Weights frozen.
|
||||||
Extracts features at relu1_2, relu2_2, relu3_3 and returns the
|
|
||||||
L1 distance in feature space. VGG weights are frozen.
|
|
||||||
|
|
||||||
Input convention: images in [-1, 1] — the loss converts internally to
|
|
||||||
[0, 1] and then applies ImageNet normalisation before passing to VGG.
|
|
||||||
"""
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchvision.models as tv_models
|
import torchvision.models as tv_models
|
||||||
|
|
||||||
|
|
||||||
|
# L1 feature-matching loss at relu1_2 ([:4]), relu2_2 ([4:9]), relu3_3 ([9:16]) of VGG-16
|
||||||
class PerceptualLoss(nn.Module):
|
class PerceptualLoss(nn.Module):
|
||||||
"""L1 feature-matching loss at three VGG-16 layers.
|
|
||||||
|
|
||||||
VGG-16 feature indices:
|
|
||||||
relu1_2: features[:4] (before first maxpool)
|
|
||||||
relu2_2: features[4:9] (before second maxpool)
|
|
||||||
relu3_3: features[9:16] (before third maxpool)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
vgg = tv_models.vgg16(weights=tv_models.VGG16_Weights.IMAGENET1K_V1)
|
vgg = tv_models.vgg16(weights=tv_models.VGG16_Weights.IMAGENET1K_V1)
|
||||||
@@ -39,13 +26,13 @@ class PerceptualLoss(nn.Module):
|
|||||||
"std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
|
"std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Converts [-1, 1] → [0, 1] then applies ImageNet normalisation
|
||||||
def _normalise(self, x: torch.Tensor) -> torch.Tensor:
|
def _normalise(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""Convert [-1, 1] → ImageNet-normalised [0, 1]."""
|
|
||||||
x = x * 0.5 + 0.5 # → [0, 1]
|
x = x * 0.5 + 0.5 # → [0, 1]
|
||||||
return (x - self.mean) / self.std
|
return (x - self.mean) / self.std
|
||||||
|
|
||||||
|
# L1 feature distance; real gradients are stopped — only fake trains
|
||||||
def forward(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
|
def forward(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
|
||||||
"""L1 feature distance. real gradients are stopped — only fake trains."""
|
|
||||||
f = self._normalise(fake)
|
f = self._normalise(fake)
|
||||||
r = self._normalise(real)
|
r = self._normalise(real)
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,31 @@ else:
|
|||||||
_autocast = lambda device_type="", enabled=True, **kw: _AC(enabled=enabled, **kw)
|
_autocast = lambda device_type="", enabled=True, **kw: _AC(enabled=enabled, **kw)
|
||||||
|
|
||||||
|
|
||||||
|
# Linear decay from decay_start to epochs; returns a multiplier clamped to [0, 1]
|
||||||
|
def _linear_decay(decay_start: int, epochs: int):
|
||||||
|
return lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1))
|
||||||
|
|
||||||
|
|
||||||
|
# Returns (new_best_fid, new_no_improve, should_stop); improved = new_best_fid < best_fid
|
||||||
|
def _check_fid_plateau(fid_score: float, best_fid: float, no_improve: int, patience) -> tuple[float, int, bool]:
|
||||||
|
if fid_score < best_fid:
|
||||||
|
return fid_score, 0, False
|
||||||
|
no_improve += 1
|
||||||
|
if patience and no_improve >= patience:
|
||||||
|
print(f"FID plateau — {no_improve} evals without improvement, stopping")
|
||||||
|
return best_fid, no_improve, True
|
||||||
|
return best_fid, no_improve, False
|
||||||
|
|
||||||
|
|
||||||
|
# Generates n GAN fake images in batches of 64 using the EMA model
|
||||||
|
@torch.no_grad()
|
||||||
|
def _generate_gan_fakes(ema_model, n: int, latent_dim: int, device) -> torch.Tensor:
|
||||||
|
return torch.cat([
|
||||||
|
ema_model(torch.randn(64, latent_dim, 1, 1, device=device))
|
||||||
|
for _ in range(n // 64 + 1)
|
||||||
|
])[:n]
|
||||||
|
|
||||||
|
|
||||||
def _save_samples(generator_ema, samples_dir: Path, epoch: int, *, fixed_noise: torch.Tensor, device) -> None:
|
def _save_samples(generator_ema, samples_dir: Path, epoch: int, *, fixed_noise: torch.Tensor, device) -> None:
|
||||||
samples_dir.mkdir(parents=True, exist_ok=True)
|
samples_dir.mkdir(parents=True, exist_ok=True)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -29,6 +54,7 @@ def _save_samples(generator_ema, samples_dir: Path, epoch: int, *, fixed_noise:
|
|||||||
save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)
|
save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)
|
||||||
|
|
||||||
|
|
||||||
|
# Vanilla DCGAN training loop (Phase 1): BCE loss, single G/D step per batch, no gradient penalty
|
||||||
def train_dcgan(
|
def train_dcgan(
|
||||||
generator,
|
generator,
|
||||||
discriminator,
|
discriminator,
|
||||||
@@ -39,11 +65,6 @@ def train_dcgan(
|
|||||||
run_name: str,
|
run_name: str,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Vanilla DCGAN training loop with BCE loss (Radford et al., 2015).
|
|
||||||
|
|
||||||
Used as the Phase 1 baseline for cheap pipeline ablations. No gradient
|
|
||||||
penalty, no n_critic, single G/D step per batch.
|
|
||||||
"""
|
|
||||||
device = torch.device(device if torch.cuda.is_available() else "cpu")
|
device = torch.device(device if torch.cuda.is_available() else "cpu")
|
||||||
generator = generator.to(device)
|
generator = generator.to(device)
|
||||||
discriminator = discriminator.to(device)
|
discriminator = discriminator.to(device)
|
||||||
@@ -79,6 +100,10 @@ def train_dcgan(
|
|||||||
|
|
||||||
ema = EMA(generator, decay=ema_decay)
|
ema = EMA(generator, decay=ema_decay)
|
||||||
|
|
||||||
|
if hasattr(torch, "compile"):
|
||||||
|
generator = torch.compile(generator)
|
||||||
|
discriminator = torch.compile(discriminator)
|
||||||
|
|
||||||
# Fixed noise for consistent sample tracking across epochs
|
# Fixed noise for consistent sample tracking across epochs
|
||||||
fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device)
|
fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device)
|
||||||
|
|
||||||
@@ -91,18 +116,22 @@ def train_dcgan(
|
|||||||
|
|
||||||
history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}}
|
history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}}
|
||||||
best_fid = float("inf")
|
best_fid = float("inf")
|
||||||
print(f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}")
|
print(f"Device: {device} AMP: {use_amp} compile: {hasattr(torch, 'compile')} Batches/epoch: {len(loader)}")
|
||||||
|
|
||||||
# Linear LR decay from epoch epochs//2 to epochs
|
# Linear LR decay from epoch epochs//2 to epochs
|
||||||
decay_start = epochs // 2
|
decay_start = epochs // 2
|
||||||
sched_g = torch.optim.lr_scheduler.LambdaLR(
|
sched_g = torch.optim.lr_scheduler.LambdaLR(opt_g, _linear_decay(decay_start, epochs))
|
||||||
opt_g, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start)))
|
sched_d = torch.optim.lr_scheduler.LambdaLR(opt_d, _linear_decay(decay_start, epochs))
|
||||||
sched_d = torch.optim.lr_scheduler.LambdaLR(
|
|
||||||
opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start)))
|
|
||||||
|
|
||||||
t_start = time.time()
|
t_start = time.time()
|
||||||
|
max_train_secs = cfg.get("max_train_hours", 0) * 3600 or None
|
||||||
|
fid_patience = cfg.get("fid_patience", None)
|
||||||
|
no_improve = 0
|
||||||
|
|
||||||
for epoch in range(1, epochs + 1):
|
for epoch in range(1, epochs + 1):
|
||||||
|
if max_train_secs and time.time() - t_start >= max_train_secs:
|
||||||
|
print(f"Time limit reached — stopping after epoch {epoch - 1}")
|
||||||
|
break
|
||||||
generator.train()
|
generator.train()
|
||||||
discriminator.train()
|
discriminator.train()
|
||||||
g_sum = d_sum = real_sum = fake_sum = 0.0
|
g_sum = d_sum = real_sum = fake_sum = 0.0
|
||||||
@@ -160,19 +189,18 @@ def train_dcgan(
|
|||||||
|
|
||||||
if epoch % fid_interval == 0:
|
if epoch % fid_interval == 0:
|
||||||
ema.model.eval()
|
ema.model.eval()
|
||||||
with torch.no_grad():
|
fake_imgs = _generate_gan_fakes(ema.model, fid_n_real, latent_dim, device)
|
||||||
fake_imgs = torch.cat([
|
|
||||||
ema.model(torch.randn(64, latent_dim, 1, 1, device=device))
|
|
||||||
for _ in range(fid_n_real // 64 + 1)
|
|
||||||
])[:fid_n_real]
|
|
||||||
fid_score = fid_eval.compute(fake_imgs)
|
fid_score = fid_eval.compute(fake_imgs)
|
||||||
history["fid"][epoch] = fid_score
|
history["fid"][epoch] = fid_score
|
||||||
print(f" FID @ epoch {epoch}: {fid_score:.2f}")
|
print(f" FID @ epoch {epoch}: {fid_score:.2f}")
|
||||||
|
|
||||||
if fid_score < best_fid:
|
new_best, no_improve, plateau = _check_fid_plateau(fid_score, best_fid, no_improve, fid_patience)
|
||||||
best_fid = fid_score
|
if new_best < best_fid:
|
||||||
torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
|
torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
|
||||||
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
|
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
|
||||||
|
best_fid = new_best
|
||||||
|
if plateau:
|
||||||
|
break
|
||||||
|
|
||||||
sched_g.step()
|
sched_g.step()
|
||||||
sched_d.step()
|
sched_d.step()
|
||||||
@@ -184,8 +212,8 @@ def train_dcgan(
|
|||||||
return history
|
return history
|
||||||
|
|
||||||
|
|
||||||
|
# Two-sided gradient penalty (Gulrajani et al., 2017)
|
||||||
def _gradient_penalty(critic, real: torch.Tensor, fake: torch.Tensor, device) -> torch.Tensor:
|
def _gradient_penalty(critic, real: torch.Tensor, fake: torch.Tensor, device) -> torch.Tensor:
|
||||||
"""Two-sided gradient penalty (Gulrajani et al., 2017)."""
|
|
||||||
bsz = real.size(0)
|
bsz = real.size(0)
|
||||||
eps = torch.rand(bsz, 1, 1, 1, device=device)
|
eps = torch.rand(bsz, 1, 1, 1, device=device)
|
||||||
interp = (eps * real + (1.0 - eps) * fake).requires_grad_(True)
|
interp = (eps * real + (1.0 - eps) * fake).requires_grad_(True)
|
||||||
@@ -200,6 +228,8 @@ def _gradient_penalty(critic, real: torch.Tensor, fake: torch.Tensor, device) ->
|
|||||||
return ((grad.norm(2, dim=[1, 2, 3]) - 1) ** 2).mean()
|
return ((grad.norm(2, dim=[1, 2, 3]) - 1) ** 2).mean()
|
||||||
|
|
||||||
|
|
||||||
|
# WGAN-GP training loop (Phase 2.2–2.4): gradient penalty replaces weight clipping
|
||||||
|
# Critic runs in float32 for GP stability; AMP used only for the generator forward/backward
|
||||||
def train_wgan(
|
def train_wgan(
|
||||||
generator,
|
generator,
|
||||||
critic,
|
critic,
|
||||||
@@ -210,12 +240,6 @@ def train_wgan(
|
|||||||
run_name: str,
|
run_name: str,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""WGAN-GP training loop (Gulrajani et al., 2017).
|
|
||||||
|
|
||||||
Used for Phase 2.2–2.4. Gradient penalty replaces weight clipping.
|
|
||||||
The critic runs in float32 to keep GP gradient computation numerically
|
|
||||||
stable; AMP is used only for the generator forward/backward.
|
|
||||||
"""
|
|
||||||
device = torch.device(device if torch.cuda.is_available() else "cpu")
|
device = torch.device(device if torch.cuda.is_available() else "cpu")
|
||||||
generator = generator.to(device)
|
generator = generator.to(device)
|
||||||
critic = critic.to(device)
|
critic = critic.to(device)
|
||||||
@@ -251,6 +275,10 @@ def train_wgan(
|
|||||||
|
|
||||||
ema = EMA(generator, decay=ema_decay)
|
ema = EMA(generator, decay=ema_decay)
|
||||||
|
|
||||||
|
if hasattr(torch, "compile"):
|
||||||
|
generator = torch.compile(generator)
|
||||||
|
critic = torch.compile(critic)
|
||||||
|
|
||||||
# Fixed noise for consistent sample tracking across epochs
|
# Fixed noise for consistent sample tracking across epochs
|
||||||
fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device)
|
fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device)
|
||||||
|
|
||||||
@@ -263,18 +291,22 @@ def train_wgan(
|
|||||||
|
|
||||||
history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}}
|
history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}}
|
||||||
best_fid = float("inf")
|
best_fid = float("inf")
|
||||||
print(f"Device: {device} AMP (G only): {use_amp} Batches/epoch: {len(loader)} n_critic: {n_critic}")
|
print(f"Device: {device} AMP (G only): {use_amp} compile: {hasattr(torch, 'compile')} Batches/epoch: {len(loader)} n_critic: {n_critic}")
|
||||||
|
|
||||||
# Linear LR decay from epoch epochs//2 to epochs
|
# Linear LR decay from epoch epochs//2 to epochs
|
||||||
decay_start = epochs // 2
|
decay_start = epochs // 2
|
||||||
sched_g = torch.optim.lr_scheduler.LambdaLR(
|
sched_g = torch.optim.lr_scheduler.LambdaLR(opt_g, _linear_decay(decay_start, epochs))
|
||||||
opt_g, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start)))
|
sched_c = torch.optim.lr_scheduler.LambdaLR(opt_c, _linear_decay(decay_start, epochs))
|
||||||
sched_c = torch.optim.lr_scheduler.LambdaLR(
|
|
||||||
opt_c, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start)))
|
|
||||||
|
|
||||||
t_start = time.time()
|
t_start = time.time()
|
||||||
|
max_train_secs = cfg.get("max_train_hours", 0) * 3600 or None
|
||||||
|
fid_patience = cfg.get("fid_patience", None)
|
||||||
|
no_improve = 0
|
||||||
|
|
||||||
for epoch in range(1, epochs + 1):
|
for epoch in range(1, epochs + 1):
|
||||||
|
if max_train_secs and time.time() - t_start >= max_train_secs:
|
||||||
|
print(f"Time limit reached — stopping after epoch {epoch - 1}")
|
||||||
|
break
|
||||||
generator.train()
|
generator.train()
|
||||||
critic.train()
|
critic.train()
|
||||||
g_sum = w_sum = gp_sum = real_sum = fake_sum = 0.0
|
g_sum = w_sum = gp_sum = real_sum = fake_sum = 0.0
|
||||||
@@ -292,11 +324,11 @@ def train_wgan(
|
|||||||
fake = generator(torch.randn(bsz, latent_dim, 1, 1, device=device))
|
fake = generator(torch.randn(bsz, latent_dim, 1, 1, device=device))
|
||||||
|
|
||||||
real_f32 = real.float()
|
real_f32 = real.float()
|
||||||
fake_f32 = fake.float().detach()
|
fake_f32 = fake.float() # already no_grad from torch.no_grad() context above
|
||||||
|
|
||||||
d_real = critic(real_f32)
|
d_real = critic(real_f32)
|
||||||
d_fake = critic(fake_f32)
|
d_fake = critic(fake_f32)
|
||||||
gp = _gradient_penalty(critic, real_f32, fake_f32.detach(), device)
|
gp = _gradient_penalty(critic, real_f32, fake_f32, device)
|
||||||
c_loss = d_fake.mean() - d_real.mean() + gp_lambda * gp
|
c_loss = d_fake.mean() - d_real.mean() + gp_lambda * gp
|
||||||
c_loss.backward()
|
c_loss.backward()
|
||||||
opt_c.step()
|
opt_c.step()
|
||||||
@@ -342,19 +374,18 @@ def train_wgan(
|
|||||||
|
|
||||||
if epoch % fid_interval == 0:
|
if epoch % fid_interval == 0:
|
||||||
ema.model.eval()
|
ema.model.eval()
|
||||||
with torch.no_grad():
|
fake_imgs = _generate_gan_fakes(ema.model, fid_n_real, latent_dim, device)
|
||||||
fake_imgs = torch.cat([
|
|
||||||
ema.model(torch.randn(64, latent_dim, 1, 1, device=device))
|
|
||||||
for _ in range(fid_n_real // 64 + 1)
|
|
||||||
])[:fid_n_real]
|
|
||||||
fid_score = fid_eval.compute(fake_imgs)
|
fid_score = fid_eval.compute(fake_imgs)
|
||||||
history["fid"][epoch] = fid_score
|
history["fid"][epoch] = fid_score
|
||||||
print(f" FID @ epoch {epoch}: {fid_score:.2f}")
|
print(f" FID @ epoch {epoch}: {fid_score:.2f}")
|
||||||
|
|
||||||
if fid_score < best_fid:
|
new_best, no_improve, plateau = _check_fid_plateau(fid_score, best_fid, no_improve, fid_patience)
|
||||||
best_fid = fid_score
|
if new_best < best_fid:
|
||||||
torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
|
torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
|
||||||
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
|
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
|
||||||
|
best_fid = new_best
|
||||||
|
if plateau:
|
||||||
|
break
|
||||||
|
|
||||||
sched_g.step()
|
sched_g.step()
|
||||||
sched_c.step()
|
sched_c.step()
|
||||||
@@ -370,6 +401,7 @@ def train_wgan(
|
|||||||
# Phase 3 — VAE (3.1 MSE+KL · 3.2 +Perceptual · 3.3 +PatchGAN)
|
# Phase 3 — VAE (3.1 MSE+KL · 3.2 +Perceptual · 3.3 +PatchGAN)
|
||||||
# ────────────────────────────────────────────────────────────────────────────
|
# ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Saves prior samples and a real-vs-reconstruction interleaved grid for the epoch
|
||||||
def _save_vae_samples(
|
def _save_vae_samples(
|
||||||
vae,
|
vae,
|
||||||
samples_dir: Path,
|
samples_dir: Path,
|
||||||
@@ -379,7 +411,6 @@ def _save_vae_samples(
|
|||||||
fixed_real: torch.Tensor,
|
fixed_real: torch.Tensor,
|
||||||
device,
|
device,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Save prior samples and a real-vs-reconstruction grid side by side."""
|
|
||||||
samples_dir.mkdir(parents=True, exist_ok=True)
|
samples_dir.mkdir(parents=True, exist_ok=True)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -395,6 +426,8 @@ def _save_vae_samples(
|
|||||||
save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4)
|
save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4)
|
||||||
|
|
||||||
|
|
||||||
|
# VAE training loop (Phase 3.1–3.3): lambda_perceptual > 0 adds VGG loss, lambda_adversarial > 0 adds PatchGAN
|
||||||
|
# AMP disabled — float16 overflows on KL spikes causing unrecoverable NaN cascades; everything runs in float32
|
||||||
def train_vae(
|
def train_vae(
|
||||||
vae,
|
vae,
|
||||||
train_dataset,
|
train_dataset,
|
||||||
@@ -404,22 +437,6 @@ def train_vae(
|
|||||||
run_name: str,
|
run_name: str,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""VAE training loop covering Phase 3.1 – 3.3 and Phase 5.
|
|
||||||
|
|
||||||
Config toggles:
|
|
||||||
lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+)
|
|
||||||
lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3)
|
|
||||||
|
|
||||||
Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl
|
|
||||||
|
|
||||||
KL is computed as mean over latent dimensions (scale-invariant), so
|
|
||||||
beta_kl is comparable across different latent_dim values.
|
|
||||||
|
|
||||||
AMP is intentionally disabled for VAE training — mixed-precision float16
|
|
||||||
overflows when the KL divergence spikes, producing NaN cascades that
|
|
||||||
corrupt the model irrecoverably. All VAE + perceptual + PatchGAN
|
|
||||||
computation runs in float32.
|
|
||||||
"""
|
|
||||||
device = torch.device(device if torch.cuda.is_available() else "cpu")
|
device = torch.device(device if torch.cuda.is_available() else "cpu")
|
||||||
vae = vae.to(device)
|
vae = vae.to(device)
|
||||||
|
|
||||||
@@ -450,16 +467,13 @@ def train_vae(
|
|||||||
)
|
)
|
||||||
|
|
||||||
opt_vae = torch.optim.Adam(vae.parameters(), lr=lr)
|
opt_vae = torch.optim.Adam(vae.parameters(), lr=lr)
|
||||||
# AMP disabled — float16 overflows on KL spikes, causing NaN cascades
|
|
||||||
use_amp = False
|
|
||||||
|
|
||||||
# KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training
|
# KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training
|
||||||
kl_warmup_epochs = max(1, epochs // 5)
|
kl_warmup_epochs = max(1, epochs // 5)
|
||||||
|
|
||||||
# Linear LR decay from epoch epochs//2 to epochs
|
# Linear LR decay from epoch epochs//2 to epochs
|
||||||
decay_start = epochs // 2
|
decay_start = epochs // 2
|
||||||
sched_vae = torch.optim.lr_scheduler.LambdaLR(
|
sched_vae = torch.optim.lr_scheduler.LambdaLR(opt_vae, _linear_decay(decay_start, epochs))
|
||||||
opt_vae, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
|
|
||||||
sched_d = None # set below if adversarial
|
sched_d = None # set below if adversarial
|
||||||
|
|
||||||
# ── Optional components ───────────────────────────────────────────────
|
# ── Optional components ───────────────────────────────────────────────
|
||||||
@@ -479,12 +493,11 @@ def train_vae(
|
|||||||
image_size=cfg.get("image_size", 64),
|
image_size=cfg.get("image_size", 64),
|
||||||
).to(device).float()
|
).to(device).float()
|
||||||
opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999))
|
opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999))
|
||||||
sched_d = torch.optim.lr_scheduler.LambdaLR(
|
sched_d = torch.optim.lr_scheduler.LambdaLR(opt_d, _linear_decay(decay_start, epochs))
|
||||||
opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
|
|
||||||
n_d = sum(p.numel() for p in patchgan.parameters())
|
n_d = sum(p.numel() for p in patchgan.parameters())
|
||||||
print(f"PatchGAN: {n_d:,} params")
|
print(f"PatchGAN: {n_d:,} params")
|
||||||
else:
|
if hasattr(torch, "compile"):
|
||||||
hinge_d_loss = hinge_g_loss = None # never called
|
patchgan = torch.compile(patchgan)
|
||||||
|
|
||||||
# ── Fixed seeds for consistent visualisation ──────────────────────────
|
# ── Fixed seeds for consistent visualisation ──────────────────────────
|
||||||
fixed_z = torch.randn(16, latent_dim, device=device)
|
fixed_z = torch.randn(16, latent_dim, device=device)
|
||||||
@@ -494,6 +507,9 @@ def train_vae(
|
|||||||
|
|
||||||
ema = EMA(vae, decay=ema_decay)
|
ema = EMA(vae, decay=ema_decay)
|
||||||
|
|
||||||
|
if hasattr(torch, "compile"):
|
||||||
|
vae = torch.compile(vae)
|
||||||
|
|
||||||
save_dir = Path(save_dir)
|
save_dir = Path(save_dir)
|
||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
samples_dir = save_dir.parent / "samples" / run_name
|
samples_dir = save_dir.parent / "samples" / run_name
|
||||||
@@ -514,8 +530,14 @@ def train_vae(
|
|||||||
)
|
)
|
||||||
|
|
||||||
t_start = time.time()
|
t_start = time.time()
|
||||||
|
max_train_secs = cfg.get("max_train_hours", 0) * 3600 or None
|
||||||
|
fid_patience = cfg.get("fid_patience", None)
|
||||||
|
no_improve = 0
|
||||||
|
|
||||||
for epoch in range(1, epochs + 1):
|
for epoch in range(1, epochs + 1):
|
||||||
|
if max_train_secs and time.time() - t_start >= max_train_secs:
|
||||||
|
print(f"Time limit reached — stopping after epoch {epoch - 1}")
|
||||||
|
break
|
||||||
vae.train()
|
vae.train()
|
||||||
if patchgan is not None:
|
if patchgan is not None:
|
||||||
patchgan.train()
|
patchgan.train()
|
||||||
@@ -614,10 +636,13 @@ def train_vae(
|
|||||||
history["fid"][epoch] = fid_score
|
history["fid"][epoch] = fid_score
|
||||||
print(f" FID @ epoch {epoch}: {fid_score:.2f}")
|
print(f" FID @ epoch {epoch}: {fid_score:.2f}")
|
||||||
|
|
||||||
if fid_score < best_fid:
|
new_best, no_improve, plateau = _check_fid_plateau(fid_score, best_fid, no_improve, fid_patience)
|
||||||
best_fid = fid_score
|
if new_best < best_fid:
|
||||||
torch.save(vae.state_dict(), save_dir / f"{run_name}_best_vae.pt")
|
torch.save(vae.state_dict(), save_dir / f"{run_name}_best_vae.pt")
|
||||||
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
|
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
|
||||||
|
best_fid = new_best
|
||||||
|
if plateau:
|
||||||
|
break
|
||||||
|
|
||||||
sched_vae.step()
|
sched_vae.step()
|
||||||
if sched_d is not None:
|
if sched_d is not None:
|
||||||
@@ -636,6 +661,7 @@ def train_vae(
|
|||||||
# Phase 4 — DDPM (4.1 linear·ε · 4.2 cosine·ε · 4.3 cosine·v · 4.4 wider)
|
# Phase 4 — DDPM (4.1 linear·ε · 4.2 cosine·ε · 4.3 cosine·v · 4.4 wider)
|
||||||
# ────────────────────────────────────────────────────────────────────────────
|
# ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# DDPM training loop (Phase 4.1–4.4): noise_schedule="linear"/"cosine", pred_type="eps"/"v"
|
||||||
def train_ddpm(
|
def train_ddpm(
|
||||||
model,
|
model,
|
||||||
train_dataset,
|
train_dataset,
|
||||||
@@ -645,15 +671,6 @@ def train_ddpm(
|
|||||||
run_name: str,
|
run_name: str,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""DDPM training loop (Ho et al., 2020) covering Phase 4.1 – 4.4.
|
|
||||||
|
|
||||||
Config keys:
|
|
||||||
noise_schedule — "linear" (4.1) or "cosine" (4.2+)
|
|
||||||
pred_type — "eps" (4.1–4.2) or "v" (4.3+)
|
|
||||||
T — diffusion timesteps (default 1000)
|
|
||||||
base_ch / ch_mult / attn_resolutions — U-Net capacity (see unet.py)
|
|
||||||
ddim_steps — DDIM steps for FID evaluation (default 100)
|
|
||||||
"""
|
|
||||||
from src.training.diffusion import (
|
from src.training.diffusion import (
|
||||||
linear_betas, cosine_betas, make_alpha_bars,
|
linear_betas, cosine_betas, make_alpha_bars,
|
||||||
diffusion_loss, ddim_sample,
|
diffusion_loss, ddim_sample,
|
||||||
@@ -694,8 +711,8 @@ def train_ddpm(
|
|||||||
|
|
||||||
ema = EMA(model, decay=ema_decay)
|
ema = EMA(model, decay=ema_decay)
|
||||||
|
|
||||||
# Fixed noise for sample visualisation (same latents across epochs)
|
if hasattr(torch, "compile"):
|
||||||
fixed_noise = torch.randn(16, 3, image_size, image_size, device=device)
|
model = torch.compile(model)
|
||||||
|
|
||||||
save_dir = Path(save_dir)
|
save_dir = Path(save_dir)
|
||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -707,18 +724,23 @@ def train_ddpm(
|
|||||||
history = {"loss": [], "fid": {}}
|
history = {"loss": [], "fid": {}}
|
||||||
best_fid = float("inf")
|
best_fid = float("inf")
|
||||||
print(
|
print(
|
||||||
f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}"
|
f"Device: {device} AMP: {use_amp} compile: {hasattr(torch, 'compile')} Batches/epoch: {len(loader)}"
|
||||||
f" T={T} schedule={noise_schedule} pred={pred_type} ddim_steps={ddim_steps}"
|
f" T={T} schedule={noise_schedule} pred={pred_type} ddim_steps={ddim_steps}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Linear LR decay from epoch epochs//2 to epochs
|
# Linear LR decay from epoch epochs//2 to epochs
|
||||||
decay_start = epochs // 2
|
decay_start = epochs // 2
|
||||||
sched = torch.optim.lr_scheduler.LambdaLR(
|
sched = torch.optim.lr_scheduler.LambdaLR(opt, _linear_decay(decay_start, epochs))
|
||||||
opt, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
|
|
||||||
|
|
||||||
t_start = time.time()
|
t_start = time.time()
|
||||||
|
max_train_secs = cfg.get("max_train_hours", 0) * 3600 or None
|
||||||
|
fid_patience = cfg.get("fid_patience", None)
|
||||||
|
no_improve = 0
|
||||||
|
|
||||||
for epoch in range(1, epochs + 1):
|
for epoch in range(1, epochs + 1):
|
||||||
|
if max_train_secs and time.time() - t_start >= max_train_secs:
|
||||||
|
print(f"Time limit reached — stopping after epoch {epoch - 1}")
|
||||||
|
break
|
||||||
model.train()
|
model.train()
|
||||||
loss_sum = 0.0
|
loss_sum = 0.0
|
||||||
n_batches = 0
|
n_batches = 0
|
||||||
@@ -749,7 +771,6 @@ def train_ddpm(
|
|||||||
samples_dir.mkdir(parents=True, exist_ok=True)
|
samples_dir.mkdir(parents=True, exist_ok=True)
|
||||||
ema.model.eval()
|
ema.model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# Quick visualisation: denoise fixed_noise via DDIM
|
|
||||||
imgs = ddim_sample(
|
imgs = ddim_sample(
|
||||||
ema.model, 16, image_size, alpha_bars,
|
ema.model, 16, image_size, alpha_bars,
|
||||||
n_steps=50, pred_type=pred_type, device=str(device), batch_size=16,
|
n_steps=50, pred_type=pred_type, device=str(device), batch_size=16,
|
||||||
@@ -768,10 +789,13 @@ def train_ddpm(
|
|||||||
history["fid"][epoch] = fid_score
|
history["fid"][epoch] = fid_score
|
||||||
print(f" FID @ epoch {epoch}: {fid_score:.2f}")
|
print(f" FID @ epoch {epoch}: {fid_score:.2f}")
|
||||||
|
|
||||||
if fid_score < best_fid:
|
new_best, no_improve, plateau = _check_fid_plateau(fid_score, best_fid, no_improve, fid_patience)
|
||||||
best_fid = fid_score
|
if new_best < best_fid:
|
||||||
torch.save(model.state_dict(), save_dir / f"{run_name}_best_unet.pt")
|
torch.save(model.state_dict(), save_dir / f"{run_name}_best_unet.pt")
|
||||||
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
|
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
|
||||||
|
best_fid = new_best
|
||||||
|
if plateau:
|
||||||
|
break
|
||||||
|
|
||||||
sched.step()
|
sched.step()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user