Preview of phase 2-5 implementation; needs a full check
This commit is contained in:
@@ -51,17 +51,20 @@ class GeneratorDataset(Dataset):
|
||||
return img
|
||||
|
||||
|
||||
def get_transform(image_size: int, augment: bool = False) -> T.Compose:
|
||||
def get_transform(image_size: int, augment=False) -> T.Compose:
|
||||
"""Build transform for generator training. Output is in [-1, 1].
|
||||
|
||||
augment=True adds horizontal flip + mild rotation + mild color jitter.
|
||||
Use augment=False for validation / FID real-image sets.
|
||||
augment=False — no augmentation (for FID real-image sets)
|
||||
augment="hflip" — horizontal flip only (recommended for VAE/DDPM)
|
||||
augment=True — H-flip + rotation ±5° + mild color jitter (for GAN)
|
||||
"""
|
||||
ops = [
|
||||
T.Resize(image_size),
|
||||
T.CenterCrop(image_size),
|
||||
]
|
||||
if augment:
|
||||
if augment == "hflip":
|
||||
ops.append(T.RandomHorizontalFlip(p=0.5))
|
||||
elif augment:
|
||||
ops += [
|
||||
T.RandomHorizontalFlip(p=0.5),
|
||||
T.RandomRotation(degrees=5, interpolation=T.InterpolationMode.BILINEAR),
|
||||
|
||||
@@ -23,4 +23,7 @@ def get_model(cfg: dict) -> tuple:
|
||||
|
||||
|
||||
from src.models import dcgan # noqa: E402, F401
|
||||
from src.models import wgan # noqa: E402, F401
|
||||
from src.models import vae # noqa: E402, F401
|
||||
from src.models import unet # noqa: E402, F401
|
||||
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
"""PatchGAN discriminator for Phase 3.3 (VQGAN-lite adversarial training).
|
||||
|
||||
Outputs a spatial patch map instead of a single scalar — each patch
|
||||
predicts real/fake independently. Loss is the mean over all patches.
|
||||
|
||||
Not registered in the model registry; instantiated inside train_vae
|
||||
when lambda_adversarial > 0.
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def _init_weights(m):
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.normal_(m.weight, 0.0, 0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
|
||||
class PatchGANDiscriminator(nn.Module):
|
||||
"""Stride-2 + stride-1 convolution chain → spatial patch logit map.
|
||||
|
||||
Supports image_size ∈ {64, 128}. For 64×64 input the final map is 6×6
|
||||
(70×70 receptive field). For 128×128 an extra stride-2 layer is added.
|
||||
InstanceNorm everywhere except the first layer.
|
||||
"""
|
||||
|
||||
def __init__(self, ndf: int = 64, image_size: int = 64):
|
||||
super().__init__()
|
||||
layers: list[nn.Module] = [
|
||||
# First layer: no norm
|
||||
nn.Conv2d(3, ndf, 4, stride=2, padding=1, bias=False),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
]
|
||||
if image_size >= 128:
|
||||
layers += [
|
||||
nn.Conv2d(ndf, ndf, 4, stride=2, padding=1, bias=False),
|
||||
nn.InstanceNorm2d(ndf, affine=True),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
]
|
||||
layers += [
|
||||
nn.Conv2d(ndf, ndf * 2, 4, stride=2, padding=1, bias=False),
|
||||
nn.InstanceNorm2d(ndf * 2, affine=True),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(ndf * 2, ndf * 4, 4, stride=2, padding=1, bias=False),
|
||||
nn.InstanceNorm2d(ndf * 4, affine=True),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(ndf * 4, ndf * 8, 4, stride=1, padding=1, bias=False),
|
||||
nn.InstanceNorm2d(ndf * 8, affine=True),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(ndf * 8, 1, 4, stride=1, padding=1, bias=True),
|
||||
]
|
||||
self.net = nn.Sequential(*layers)
|
||||
self.apply(_init_weights)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.net(x) # (B, 1, H', W') — patch logit map
|
||||
|
||||
|
||||
def hinge_d_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor:
|
||||
"""Hinge loss for the discriminator (Lim & Ye, 2017)."""
|
||||
loss_real = torch.mean(torch.relu(1.0 - real_logits))
|
||||
loss_fake = torch.mean(torch.relu(1.0 + fake_logits))
|
||||
return 0.5 * (loss_real + loss_fake)
|
||||
|
||||
|
||||
def hinge_g_loss(fake_logits: torch.Tensor) -> torch.Tensor:
|
||||
"""Generator hinge loss — maximise D(fake)."""
|
||||
return -torch.mean(fake_logits)
|
||||
@@ -0,0 +1,279 @@
|
||||
"""Time-conditioned U-Net for DDPM (Phase 4).
|
||||
|
||||
Architecture follows Ho et al. (2020) with options from Nichol & Dhariwal (2021):
|
||||
- Sinusoidal time embedding → MLP → added to every ResBlock
|
||||
- GroupNorm (32 groups) + SiLU activations throughout
|
||||
- Self-attention at configurable spatial resolutions
|
||||
- Upsample(nearest) + Conv in the decoder — no checkerboard artefacts
|
||||
|
||||
Registered as kind="ddpm".
|
||||
"""
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from src.models import register
|
||||
|
||||
_GN = 32 # GroupNorm groups — all channel counts used here are multiples of 32
|
||||
|
||||
|
||||
# ── Time embedding ────────────────────────────────────────────────────────────
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, t: torch.Tensor) -> torch.Tensor:
|
||||
half = self.dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(10000) * torch.arange(half, device=t.device, dtype=torch.float) / half
|
||||
)
|
||||
angles = t[:, None].float() * freqs[None] # (B, half)
|
||||
return torch.cat([angles.sin(), angles.cos()], dim=-1) # (B, dim)
|
||||
|
||||
|
||||
# ── Core building blocks ──────────────────────────────────────────────────────
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
"""ResNet block with time-embedding injection (additive, after first conv)."""
|
||||
|
||||
def __init__(self, in_ch: int, out_ch: int, t_emb_dim: int, dropout: float = 0.1):
|
||||
super().__init__()
|
||||
self.norm1 = nn.GroupNorm(_GN, in_ch)
|
||||
self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
|
||||
self.t_proj = nn.Linear(t_emb_dim, out_ch)
|
||||
self.norm2 = nn.GroupNorm(_GN, out_ch)
|
||||
self.drop = nn.Dropout(dropout)
|
||||
self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
|
||||
self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
|
||||
h = self.conv1(F.silu(self.norm1(x)))
|
||||
h = h + self.t_proj(F.silu(t_emb))[:, :, None, None]
|
||||
h = self.conv2(self.drop(F.silu(self.norm2(h))))
|
||||
return h + self.skip(x)
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""Single-head self-attention with GroupNorm pre-norm and residual."""
|
||||
|
||||
def __init__(self, ch: int):
|
||||
super().__init__()
|
||||
self.norm = nn.GroupNorm(_GN, ch)
|
||||
self.qkv = nn.Conv2d(ch, ch * 3, 1, bias=False)
|
||||
self.proj = nn.Conv2d(ch, ch, 1)
|
||||
self._scale = ch ** -0.5
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
b, c, h, w = x.shape
|
||||
n = h * w
|
||||
qkv = self.norm(x)
|
||||
q, k, v = self.qkv(qkv).reshape(b, 3, c, n).unbind(1) # each (b, c, n)
|
||||
attn = torch.softmax(q.transpose(-2, -1) @ k * self._scale, dim=-1) # (b, n, n)
|
||||
out = (v @ attn.transpose(-2, -1)).reshape(b, c, h, w)
|
||||
return x + self.proj(out)
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, ch: int):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(ch, ch, 4, stride=2, padding=1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, ch: int):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(ch, ch, 3, padding=1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(F.interpolate(x, scale_factor=2, mode="nearest"))
|
||||
|
||||
|
||||
# ── Down / Up blocks ──────────────────────────────────────────────────────────
|
||||
|
||||
class DownBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_ch: int,
|
||||
out_ch: int,
|
||||
t_emb_dim: int,
|
||||
num_res_blocks: int,
|
||||
with_attn: bool,
|
||||
dropout: float,
|
||||
):
|
||||
super().__init__()
|
||||
self.resnets = nn.ModuleList(
|
||||
ResBlock(in_ch if j == 0 else out_ch, out_ch, t_emb_dim, dropout)
|
||||
for j in range(num_res_blocks)
|
||||
)
|
||||
self.attn = AttentionBlock(out_ch) if with_attn else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
|
||||
for res in self.resnets:
|
||||
x = res(x, t_emb)
|
||||
return self.attn(x)
|
||||
|
||||
|
||||
class UpBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_ch: int,
|
||||
skip_ch: int,
|
||||
out_ch: int,
|
||||
t_emb_dim: int,
|
||||
num_res_blocks: int,
|
||||
with_attn: bool,
|
||||
dropout: float,
|
||||
):
|
||||
super().__init__()
|
||||
# First ResBlock absorbs the skip-connection channels via concat
|
||||
self.resnets = nn.ModuleList(
|
||||
ResBlock(
|
||||
(in_ch + skip_ch) if j == 0 else out_ch,
|
||||
out_ch,
|
||||
t_emb_dim,
|
||||
dropout,
|
||||
)
|
||||
for j in range(num_res_blocks + 1) # +1 to consume the concat
|
||||
)
|
||||
self.attn = AttentionBlock(out_ch) if with_attn else nn.Identity()
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, skip: torch.Tensor, t_emb: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
x = torch.cat([x, skip], dim=1)
|
||||
for res in self.resnets:
|
||||
x = res(x, t_emb)
|
||||
return self.attn(x)
|
||||
|
||||
|
||||
# ── U-Net ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
class UNet(nn.Module):
|
||||
"""Time-conditioned U-Net.
|
||||
|
||||
image_size — must be a power-of-two; 64 or 128 recommended.
|
||||
base_ch — base channel count (128 for phases 4.1–4.3, 192 for 4.4).
|
||||
ch_mult — channel multipliers per resolution level.
|
||||
attn_resolutions — spatial resolutions at which attention is inserted.
|
||||
num_res_blocks — ResBlocks per level (in both down and up paths).
|
||||
dropout — applied inside every ResBlock.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size: int = 64,
|
||||
base_ch: int = 128,
|
||||
ch_mult: tuple = (1, 2, 2, 2),
|
||||
attn_resolutions: tuple = (16, 8),
|
||||
num_res_blocks: int = 2,
|
||||
dropout: float = 0.1,
|
||||
):
|
||||
super().__init__()
|
||||
n_levels = len(ch_mult)
|
||||
chs = [base_ch * m for m in ch_mult]
|
||||
t_emb_dim = base_ch * 4
|
||||
|
||||
# ── Time embedding ────────────────────────────────────────────────
|
||||
self.time_embed = nn.Sequential(
|
||||
SinusoidalPosEmb(base_ch),
|
||||
nn.Linear(base_ch, t_emb_dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(t_emb_dim, t_emb_dim),
|
||||
)
|
||||
|
||||
# ── Input projection ──────────────────────────────────────────────
|
||||
self.in_conv = nn.Conv2d(3, chs[0], 3, padding=1)
|
||||
|
||||
# ── Down path ─────────────────────────────────────────────────────
|
||||
self.down_blocks = nn.ModuleList()
|
||||
self.downsamples = nn.ModuleList()
|
||||
cur_res = image_size
|
||||
prev_ch = chs[0]
|
||||
for i, ch in enumerate(chs):
|
||||
with_attn = (cur_res in attn_resolutions)
|
||||
self.down_blocks.append(
|
||||
DownBlock(prev_ch, ch, t_emb_dim, num_res_blocks, with_attn, dropout)
|
||||
)
|
||||
if i < n_levels - 1:
|
||||
self.downsamples.append(Downsample(ch))
|
||||
cur_res //= 2
|
||||
prev_ch = ch
|
||||
|
||||
# ── Middle ────────────────────────────────────────────────────────
|
||||
mid_ch = chs[-1]
|
||||
self.mid_res1 = ResBlock(mid_ch, mid_ch, t_emb_dim, dropout)
|
||||
self.mid_attn = AttentionBlock(mid_ch)
|
||||
self.mid_res2 = ResBlock(mid_ch, mid_ch, t_emb_dim, dropout)
|
||||
|
||||
# ── Up path ───────────────────────────────────────────────────────
|
||||
# Mirrors down path: iterate chs in reverse; skip_ch = chs[n_levels-1-i].
|
||||
self.up_blocks = nn.ModuleList()
|
||||
self.upsamples = nn.ModuleList()
|
||||
in_ch = mid_ch
|
||||
for i in range(n_levels):
|
||||
level = n_levels - 1 - i # index from deep (n-1) to shallow (0)
|
||||
skip_ch = chs[level]
|
||||
out_ch = chs[level - 1] if level > 0 else chs[0]
|
||||
with_attn = (cur_res in attn_resolutions)
|
||||
self.up_blocks.append(
|
||||
UpBlock(in_ch, skip_ch, out_ch, t_emb_dim, num_res_blocks, with_attn, dropout)
|
||||
)
|
||||
if level > 0:
|
||||
self.upsamples.append(Upsample(out_ch))
|
||||
cur_res *= 2
|
||||
in_ch = out_ch
|
||||
|
||||
# ── Output ────────────────────────────────────────────────────────
|
||||
self.out_norm = nn.GroupNorm(_GN, chs[0])
|
||||
self.out_conv = nn.Conv2d(chs[0], 3, 3, padding=1)
|
||||
|
||||
def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
||||
t_emb = self.time_embed(t)
|
||||
|
||||
x = self.in_conv(x)
|
||||
|
||||
# Down
|
||||
skips = []
|
||||
ds_idx = 0
|
||||
for i, block in enumerate(self.down_blocks):
|
||||
x = block(x, t_emb)
|
||||
skips.append(x)
|
||||
if ds_idx < len(self.downsamples):
|
||||
x = self.downsamples[ds_idx](x)
|
||||
ds_idx += 1
|
||||
|
||||
# Middle
|
||||
x = self.mid_res1(x, t_emb)
|
||||
x = self.mid_attn(x)
|
||||
x = self.mid_res2(x, t_emb)
|
||||
|
||||
# Up
|
||||
us_idx = 0
|
||||
for i, block in enumerate(self.up_blocks):
|
||||
x = block(x, skips[-(i + 1)], t_emb)
|
||||
if us_idx < len(self.upsamples):
|
||||
x = self.upsamples[us_idx](x)
|
||||
us_idx += 1
|
||||
|
||||
return self.out_conv(F.silu(self.out_norm(x)))
|
||||
|
||||
|
||||
def _build(cfg: dict):
|
||||
return UNet(
|
||||
image_size = cfg.get("image_size", 64),
|
||||
base_ch = cfg.get("base_ch", 128),
|
||||
ch_mult = tuple(cfg.get("ch_mult", [1, 2, 2, 2])),
|
||||
attn_resolutions = tuple(cfg.get("attn_resolutions", [16, 8])),
|
||||
num_res_blocks = cfg.get("num_res_blocks", 2),
|
||||
dropout = cfg.get("dropout", 0.1),
|
||||
)
|
||||
|
||||
|
||||
register("ddpm", _build, kind="ddpm")
|
||||
@@ -0,0 +1,132 @@
|
||||
"""Convolutional VAE for Phase 3.
|
||||
|
||||
Encoder uses stride-2 Conv → flatten → linear (μ, log σ²).
|
||||
Decoder uses Linear → Upsample(nearest) + Conv to avoid ConvTranspose2d
|
||||
checkerboard artefacts.
|
||||
|
||||
Registered as kind="vae". The run.py dispatcher passes the model to
|
||||
train_vae(), which internally builds perceptual loss and PatchGAN when
|
||||
the corresponding lambdas are non-zero.
|
||||
"""
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from src.models import register
|
||||
|
||||
|
||||
def _init_weights(m):
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
|
||||
nn.init.normal_(m.weight, 0.0, 0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
|
||||
nn.init.normal_(m.weight, 1.0, 0.02)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
|
||||
def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential:
|
||||
"""Nearest-neighbour upsample followed by a 3×3 conv — no checkerboard."""
|
||||
return nn.Sequential(
|
||||
nn.Upsample(scale_factor=2, mode="nearest"),
|
||||
nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_ch),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
|
||||
class VAE(nn.Module):
|
||||
"""Convolutional VAE. image_size must be a power-of-two ≥ 32.
|
||||
|
||||
Spatial bottleneck is always at 4×4 regardless of image_size —
|
||||
the encoder and decoder scale the number of stride-2 steps accordingly.
|
||||
"""
|
||||
|
||||
def __init__(self, latent_dim: int = 256, ngf: int = 64, image_size: int = 64):
|
||||
super().__init__()
|
||||
if image_size < 32 or (image_size & (image_size - 1)):
|
||||
raise ValueError(f"image_size must be a power-of-two ≥ 32, got {image_size}")
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
self.image_size = image_size
|
||||
|
||||
n_down = int(math.log2(image_size)) - 2 # steps from image_size to 4×4
|
||||
# 64 → n_down=4: 64→32→16→8→4
|
||||
# 128 → n_down=5: 128→64→32→16→8→4
|
||||
|
||||
# ── Encoder ──────────────────────────────────────────────────────────
|
||||
enc_layers: list[nn.Module] = [
|
||||
nn.Conv2d(3, ngf, 4, stride=2, padding=1, bias=False),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
]
|
||||
ch = ngf
|
||||
for _ in range(n_down - 1):
|
||||
enc_layers += [
|
||||
nn.Conv2d(ch, ch * 2, 4, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(ch * 2),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
]
|
||||
ch *= 2
|
||||
# ch = ngf * 2^(n_down-1); spatial = 4×4
|
||||
self.encoder = nn.Sequential(*enc_layers)
|
||||
|
||||
flat = ch * 4 * 4
|
||||
self.fc_mu = nn.Linear(flat, latent_dim)
|
||||
self.fc_lv = nn.Linear(flat, latent_dim)
|
||||
|
||||
# ── Decoder ──────────────────────────────────────────────────────────
|
||||
self.fc_dec = nn.Linear(latent_dim, flat)
|
||||
self._dec_ch = ch # channels at the 4×4 bottleneck
|
||||
|
||||
dec_layers: list[nn.Module] = []
|
||||
for _ in range(n_down - 1):
|
||||
dec_layers.append(_upsample_block(ch, ch // 2))
|
||||
ch //= 2
|
||||
# Final upsample to image_size, output 3 channels, no BN, Tanh
|
||||
dec_layers += [
|
||||
nn.Upsample(scale_factor=2, mode="nearest"),
|
||||
nn.Conv2d(ch, 3, 3, padding=1, bias=True),
|
||||
nn.Tanh(),
|
||||
]
|
||||
self.decoder = nn.Sequential(*dec_layers)
|
||||
|
||||
self.apply(_init_weights)
|
||||
|
||||
# ── Interface ────────────────────────────────────────────────────────────
|
||||
|
||||
def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
h = self.encoder(x).flatten(1)
|
||||
return self.fc_mu(h), self.fc_lv(h)
|
||||
|
||||
def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
|
||||
std = torch.exp(0.5 * log_var)
|
||||
return mu + std * torch.randn_like(std)
|
||||
|
||||
def decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
h = self.fc_dec(z).view(z.size(0), self._dec_ch, 4, 4)
|
||||
return self.decoder(h)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Returns (reconstruction, mu, log_var)."""
|
||||
mu, log_var = self.encode(x)
|
||||
z = self.reparameterize(mu, log_var)
|
||||
return self.decode(z), mu, log_var
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, n: int, device) -> torch.Tensor:
|
||||
"""Sample n images by drawing z ~ N(0, I)."""
|
||||
z = torch.randn(n, self.latent_dim, device=device)
|
||||
return self.decode(z)
|
||||
|
||||
|
||||
def _build(cfg: dict):
|
||||
return VAE(
|
||||
latent_dim=cfg.get("latent_dim", 256),
|
||||
ngf=cfg.get("ngf", 64),
|
||||
image_size=cfg.get("image_size", 64),
|
||||
)
|
||||
|
||||
|
||||
register("vae", _build, kind="vae")
|
||||
+191
-63
@@ -1,26 +1,31 @@
|
||||
"""WGAN-GP with spectral normalization, self-attention, and GroupNorm.
|
||||
"""WGAN-GP variants.
|
||||
|
||||
Improvements over the original:
|
||||
- Generator: BatchNorm -> GroupNorm (no batch-size coupling, stable with varied content)
|
||||
- Critic: InstanceNorm -> spectral normalization (principled Lipschitz constraint)
|
||||
- Both: one SAGAN-style self-attention block at the 32x32 feature map
|
||||
- Larger capacity: ngf=128, ndf=128
|
||||
wgan_basic — Phase 2.2: BatchNorm/InstanceNorm, no attention, 64×64 only.
|
||||
wgan — Phase 2.3/2.4: GroupNorm/SpectralNorm + self-attention, size-agnostic.
|
||||
"""
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from src.models import register
|
||||
|
||||
|
||||
def _init_weights(m):
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
nn.init.normal_(m.weight, 0.0, 0.02)
|
||||
elif isinstance(m, nn.GroupNorm) and m.weight is not None:
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)) and m.weight is not None:
|
||||
nn.init.normal_(m.weight, 1.0, 0.02)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
|
||||
def _sn(module):
|
||||
return nn.utils.spectral_norm(module)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
"""SAGAN-style self-attention."""
|
||||
|
||||
def __init__(self, in_ch: int):
|
||||
super().__init__()
|
||||
mid = max(in_ch // 8, 1)
|
||||
@@ -36,98 +41,221 @@ class SelfAttention(nn.Module):
|
||||
k = self.k(x).view(b, self._mid, -1) # (b, mid, hw)
|
||||
v = self.v(x).view(b, c, -1) # (b, c, hw)
|
||||
attn = torch.softmax(q @ k * self._mid ** -0.5, dim=-1) # (b, hw, hw)
|
||||
out = (v @ attn.transpose(-2, -1)).view(b, c, h, w)
|
||||
return x + self.gamma * out
|
||||
return x + self.gamma * (v @ attn.transpose(-2, -1)).view(b, c, h, w)
|
||||
|
||||
|
||||
def _sn(module):
|
||||
"""Apply spectral normalization to a conv layer."""
|
||||
return nn.utils.spectral_norm(module)
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 2.2 — basic WGAN-GP (BatchNorm in G, InstanceNorm in D, 64×64 only)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class WGANBasicGenerator(nn.Module):
|
||||
"""Maps (latent_dim, 1, 1) -> (3, 64, 64) in [-1, 1].
|
||||
|
||||
class WGANGenerator(nn.Module):
|
||||
"""Maps (latent_dim x 1 x 1) -> (3 x 128 x 128) in [-1, 1].
|
||||
|
||||
Upsampling path: 1 -> 4 -> 8 -> 16 (+attn) -> 32 -> 64 -> 128
|
||||
Self-attention sits at 16x16 (attention matrix 256x256 vs 1024x1024 at 32x32).
|
||||
Same channel structure as DCGAN. BatchNorm in generator is fine because
|
||||
WGAN-GP's constraint targets the critic, not the generator.
|
||||
"""
|
||||
|
||||
def __init__(self, latent_dim: int = 128, ngf: int = 64):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
# 1x1 -> 4x4
|
||||
# 1×1 → 4×4
|
||||
nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False),
|
||||
nn.GroupNorm(8, ngf * 8), nn.ReLU(True),
|
||||
# 4x4 -> 8x8
|
||||
nn.BatchNorm2d(ngf * 8), nn.ReLU(True),
|
||||
# 4×4 → 8×8
|
||||
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
|
||||
nn.GroupNorm(8, ngf * 4), nn.ReLU(True),
|
||||
# 8x8 -> 16x16
|
||||
nn.BatchNorm2d(ngf * 4), nn.ReLU(True),
|
||||
# 8×8 → 16×16
|
||||
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
|
||||
nn.GroupNorm(8, ngf * 2), nn.ReLU(True),
|
||||
)
|
||||
self.attn = SelfAttention(ngf * 2) # applied at 16x16
|
||||
self.out = nn.Sequential(
|
||||
# 16x16 -> 32x32
|
||||
nn.BatchNorm2d(ngf * 2), nn.ReLU(True),
|
||||
# 16×16 → 32×32
|
||||
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
|
||||
nn.GroupNorm(8, ngf), nn.ReLU(True),
|
||||
# 32x32 -> 64x64
|
||||
nn.ConvTranspose2d(ngf, ngf // 2, 4, 2, 1, bias=False),
|
||||
nn.GroupNorm(8, ngf // 2), nn.ReLU(True),
|
||||
# 64x64 -> 128x128
|
||||
nn.ConvTranspose2d(ngf // 2, 3, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(ngf), nn.ReLU(True),
|
||||
# 32×32 → 64×64
|
||||
nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.apply(_init_weights)
|
||||
|
||||
def forward(self, z: torch.Tensor) -> torch.Tensor:
|
||||
h = self.net(z)
|
||||
h = self.attn(h)
|
||||
return self.out(h)
|
||||
return self.net(z)
|
||||
|
||||
|
||||
class WGANCritic(nn.Module):
|
||||
"""Critic (no sigmoid) for WGAN-GP. All conv layers are spectrally normalized.
|
||||
|
||||
Downsampling path: 128 -> 64 -> 32 -> 16 (+attn) -> 8 -> 4 -> score
|
||||
class WGANBasicCritic(nn.Module):
|
||||
"""WGAN-GP critic (64×64). InstanceNorm instead of BatchNorm — BatchNorm
|
||||
breaks the per-sample Lipschitz constraint the gradient penalty enforces.
|
||||
"""
|
||||
|
||||
def __init__(self, ndf: int = 64):
|
||||
super().__init__()
|
||||
self.down = nn.Sequential(
|
||||
# 128x128 -> 64x64 (no norm on first layer)
|
||||
_sn(nn.Conv2d(3, ndf // 2, 4, 2, 1, bias=False)),
|
||||
self.net = nn.Sequential(
|
||||
# 64×64 → 32×32 (no norm on first layer)
|
||||
nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
# 64x64 -> 32x32
|
||||
_sn(nn.Conv2d(ndf // 2, ndf, 4, 2, 1, bias=False)),
|
||||
# 32×32 → 16×16
|
||||
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
|
||||
nn.InstanceNorm2d(ndf * 2, affine=True),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
# 32x32 -> 16x16
|
||||
_sn(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False)),
|
||||
# 16×16 → 8×8
|
||||
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
|
||||
nn.InstanceNorm2d(ndf * 4, affine=True),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
)
|
||||
self.attn = SelfAttention(ndf * 2) # applied at 16x16
|
||||
self.tail = nn.Sequential(
|
||||
# 16x16 -> 8x8
|
||||
_sn(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False)),
|
||||
# 8×8 → 4×4
|
||||
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
|
||||
nn.InstanceNorm2d(ndf * 8, affine=True),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
# 8x8 -> 4x4
|
||||
_sn(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False)),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
# 4x4 -> 1x1
|
||||
_sn(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False)),
|
||||
# 4×4 → 1×1 (score, no sigmoid)
|
||||
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
|
||||
)
|
||||
self.apply(_init_weights)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h = self.down(x)
|
||||
h = self.attn(h)
|
||||
return self.net(x).view(x.size(0))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 2.3 / 2.4 — advanced WGAN-GP (GroupNorm, SpectralNorm, attention)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class WGANGenerator(nn.Module):
|
||||
"""GroupNorm generator with SAGAN self-attention.
|
||||
|
||||
Supports image_size ∈ {64, 128}.
|
||||
Stem is always 1×1 → 4×4 → 8×8 → 16×16 (ngf×8 → ngf×4 → ngf×2 channels).
|
||||
Attention at 16×16 always; additional attention at 32×32 for 128×128.
|
||||
"""
|
||||
|
||||
def __init__(self, latent_dim: int = 128, ngf: int = 128, image_size: int = 64):
|
||||
super().__init__()
|
||||
if image_size not in (64, 128):
|
||||
raise ValueError(f"WGANGenerator supports image_size 64 or 128, got {image_size}")
|
||||
|
||||
self._image_size = image_size
|
||||
|
||||
self.stem = nn.Sequential(
|
||||
# 1×1 → 4×4
|
||||
nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False),
|
||||
nn.GroupNorm(8, ngf * 8), nn.ReLU(True),
|
||||
# 4×4 → 8×8
|
||||
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
|
||||
nn.GroupNorm(8, ngf * 4), nn.ReLU(True),
|
||||
# 8×8 → 16×16
|
||||
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
|
||||
nn.GroupNorm(8, ngf * 2), nn.ReLU(True),
|
||||
) # output: (ngf×2, 16, 16)
|
||||
self.attn16 = SelfAttention(ngf * 2)
|
||||
|
||||
if image_size == 64:
|
||||
self.mid = None
|
||||
self.attn32 = None
|
||||
self.tail = nn.Sequential(
|
||||
# 16×16 → 32×32
|
||||
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
|
||||
nn.GroupNorm(8, ngf), nn.ReLU(True),
|
||||
# 32×32 → 64×64
|
||||
nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
|
||||
nn.Tanh(),
|
||||
)
|
||||
else: # 128
|
||||
self.mid = nn.Sequential(
|
||||
# 16×16 → 32×32
|
||||
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
|
||||
nn.GroupNorm(8, ngf), nn.ReLU(True),
|
||||
)
|
||||
self.attn32 = SelfAttention(ngf)
|
||||
self.tail = nn.Sequential(
|
||||
# 32×32 → 64×64
|
||||
nn.ConvTranspose2d(ngf, ngf // 2, 4, 2, 1, bias=False),
|
||||
nn.GroupNorm(8, ngf // 2), nn.ReLU(True),
|
||||
# 64×64 → 128×128
|
||||
nn.ConvTranspose2d(ngf // 2, 3, 4, 2, 1, bias=False),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.apply(_init_weights)
|
||||
|
||||
def forward(self, z: torch.Tensor) -> torch.Tensor:
|
||||
h = self.attn16(self.stem(z))
|
||||
if self.mid is not None:
|
||||
h = self.attn32(self.mid(h))
|
||||
return self.tail(h)
|
||||
|
||||
|
||||
class WGANCritic(nn.Module):
|
||||
"""SpectralNorm critic with SAGAN self-attention.
|
||||
|
||||
Supports image_size ∈ {64, 128}.
|
||||
Attention at 16×16 always; additional attention at 32×32 for 128×128.
|
||||
"""
|
||||
|
||||
def __init__(self, ndf: int = 128, image_size: int = 64):
|
||||
super().__init__()
|
||||
if image_size not in (64, 128):
|
||||
raise ValueError(f"WGANCritic supports image_size 64 or 128, got {image_size}")
|
||||
|
||||
self._image_size = image_size
|
||||
|
||||
if image_size == 64:
|
||||
# Head: 64→32 (ndf//2)
|
||||
self.head = nn.Sequential(
|
||||
_sn(nn.Conv2d(3, ndf // 2, 4, 2, 1, bias=False)),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
)
|
||||
self.attn32 = None
|
||||
# 32→16 (ndf)
|
||||
self.mid = nn.Sequential(
|
||||
_sn(nn.Conv2d(ndf // 2, ndf, 4, 2, 1, bias=False)),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
)
|
||||
attn_ch = ndf
|
||||
else: # 128
|
||||
# Head: 128→64 (ndf//4), 64→32 (ndf//2)
|
||||
self.head = nn.Sequential(
|
||||
_sn(nn.Conv2d(3, ndf // 4, 4, 2, 1, bias=False)),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
_sn(nn.Conv2d(ndf // 4, ndf // 2, 4, 2, 1, bias=False)),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
)
|
||||
self.attn32 = SelfAttention(ndf // 2)
|
||||
# 32→16 (ndf)
|
||||
self.mid = nn.Sequential(
|
||||
_sn(nn.Conv2d(ndf // 2, ndf, 4, 2, 1, bias=False)),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
)
|
||||
attn_ch = ndf
|
||||
|
||||
self.attn16 = SelfAttention(attn_ch)
|
||||
|
||||
# Tail: 16×16 → 8×8 → 4×4 → score
|
||||
self.tail = nn.Sequential(
|
||||
_sn(nn.Conv2d(attn_ch, attn_ch * 2, 4, 2, 1, bias=False)),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
_sn(nn.Conv2d(attn_ch * 2, attn_ch * 4, 4, 2, 1, bias=False)),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
_sn(nn.Conv2d(attn_ch * 4, 1, 4, 1, 0, bias=False)),
|
||||
)
|
||||
self.apply(_init_weights)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h = self.head(x)
|
||||
if self.attn32 is not None:
|
||||
h = self.attn32(h)
|
||||
h = self.attn16(self.mid(h))
|
||||
return self.tail(h).view(x.size(0))
|
||||
|
||||
|
||||
def _build(cfg: dict):
|
||||
def _build_basic(cfg: dict):
|
||||
return (
|
||||
WGANGenerator(latent_dim=cfg.get("latent_dim", 128), ngf=cfg.get("ngf", 128)),
|
||||
WGANCritic(ndf=cfg.get("ndf", 128)),
|
||||
WGANBasicGenerator(latent_dim=cfg.get("latent_dim", 128), ngf=cfg.get("ngf", 64)),
|
||||
WGANBasicCritic(ndf=cfg.get("ndf", 64)),
|
||||
)
|
||||
|
||||
|
||||
def _build(cfg: dict):
|
||||
image_size = cfg.get("image_size", 64)
|
||||
return (
|
||||
WGANGenerator(latent_dim=cfg.get("latent_dim", 128), ngf=cfg.get("ngf", 128), image_size=image_size),
|
||||
WGANCritic(ndf=cfg.get("ndf", 128), image_size=image_size),
|
||||
)
|
||||
|
||||
|
||||
register("wgan_basic", _build_basic, kind="wgan")
|
||||
register("wgan", _build, kind="wgan")
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from src.training.trainer import train_dcgan
|
||||
from src.training.trainer import train_dcgan, train_wgan, train_vae, train_ddpm
|
||||
|
||||
__all__ = ["train_dcgan"]
|
||||
__all__ = ["train_dcgan", "train_wgan", "train_vae", "train_ddpm"]
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
"""Gaussian diffusion utilities for Phase 4 (DDPM).
|
||||
|
||||
Provides noise schedules, the forward (noising) process, training loss,
|
||||
and DDIM deterministic sampling (Song et al., 2020).
|
||||
|
||||
Convention: alpha_bars is a 1-D tensor of length T, where alpha_bars[t]
|
||||
= ᾱ_{t+1} in 1-indexed notation. Timestep t used in the training loop
|
||||
is a 0-indexed integer in [0, T). At t=0 the image is almost clean
|
||||
(ᾱ ≈ 1 − β_1); at t=T−1 the image is almost pure noise (ᾱ ≈ 0).
|
||||
"""
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# ── Noise schedules ──────────────────────────────────────────────────────────
|
||||
|
||||
def linear_betas(T: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor:
|
||||
"""Ho et al. (2020) linear schedule."""
|
||||
return torch.linspace(beta_start, beta_end, T)
|
||||
|
||||
|
||||
def cosine_betas(T: int, s: float = 0.008) -> torch.Tensor:
|
||||
"""Nichol & Dhariwal (2021) cosine schedule — avoids over-denoising at low t."""
|
||||
t = torch.linspace(0, T, T + 1)
|
||||
f = torch.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2
|
||||
alpha_bar = f / f[0]
|
||||
betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
|
||||
return betas.clamp(max=0.999)
|
||||
|
||||
|
||||
def make_alpha_bars(betas: torch.Tensor) -> torch.Tensor:
|
||||
"""Cumulative product of (1 − β), shape (T,)."""
|
||||
return (1.0 - betas).cumprod(0)
|
||||
|
||||
|
||||
# ── Forward process ──────────────────────────────────────────────────────────
|
||||
|
||||
def q_sample(
|
||||
x0: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
alpha_bars: torch.Tensor,
|
||||
noise: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Add noise to x0 at timestep t. Returns (x_t, noise)."""
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
ab = alpha_bars[t].to(x0.device)[:, None, None, None]
|
||||
x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * noise
|
||||
return x_t, noise
|
||||
|
||||
|
||||
# ── Training loss ────────────────────────────────────────────────────────────
|
||||
|
||||
def diffusion_loss(
|
||||
model,
|
||||
x0: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
alpha_bars: torch.Tensor,
|
||||
pred_type: str = "eps",
|
||||
) -> torch.Tensor:
|
||||
"""MSE on the model's prediction vs the true target.
|
||||
|
||||
pred_type="eps" → target is the added noise ε (Ho et al.)
|
||||
pred_type="v" → target is v = √ᾱ·ε − √(1−ᾱ)·x0 (Salimans & Ho)
|
||||
"""
|
||||
x_t, noise = q_sample(x0, t, alpha_bars)
|
||||
pred = model(x_t, t)
|
||||
|
||||
if pred_type == "eps":
|
||||
target = noise
|
||||
else: # v
|
||||
ab = alpha_bars[t].to(x0.device)[:, None, None, None]
|
||||
target = ab.sqrt() * noise - (1 - ab).sqrt() * x0
|
||||
|
||||
return F.mse_loss(pred, target)
|
||||
|
||||
|
||||
# ── DDIM deterministic sampling ───────────────────────────────────────────────
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sample(
|
||||
model,
|
||||
n: int,
|
||||
image_size: int,
|
||||
alpha_bars: torch.Tensor,
|
||||
n_steps: int = 100,
|
||||
pred_type: str = "eps",
|
||||
device: str = "cuda",
|
||||
batch_size: int = 32,
|
||||
) -> torch.Tensor:
|
||||
"""Generate n images via DDIM (eta=0, deterministic).
|
||||
|
||||
Batches internally to avoid OOM when n is large.
|
||||
Returns tensor shape (n, 3, image_size, image_size) in [-1, 1].
|
||||
"""
|
||||
model.eval()
|
||||
T = len(alpha_bars)
|
||||
|
||||
# Build reversed subsequence: [T-1, T-1-step, ..., 0]
|
||||
step = max(T // n_steps, 1)
|
||||
ts = list(range(T - 1, -1, -step))[:n_steps]
|
||||
if ts[-1] != 0:
|
||||
ts.append(0)
|
||||
|
||||
results = []
|
||||
remaining = n
|
||||
while remaining > 0:
|
||||
bsz = min(batch_size, remaining)
|
||||
x = torch.randn(bsz, 3, image_size, image_size, device=device)
|
||||
|
||||
for i, t_cur in enumerate(ts):
|
||||
t_prev = ts[i + 1] if i + 1 < len(ts) else -1
|
||||
|
||||
t_batch = torch.full((bsz,), t_cur, device=device, dtype=torch.long)
|
||||
ab_t = alpha_bars[t_cur].to(device)
|
||||
ab_prev = alpha_bars[t_prev].to(device) if t_prev >= 0 else torch.ones(1, device=device)
|
||||
|
||||
pred = model(x, t_batch)
|
||||
|
||||
# Reconstruct x0 from prediction
|
||||
if pred_type == "eps":
|
||||
x0_hat = (x - (1 - ab_t).sqrt() * pred) / ab_t.sqrt()
|
||||
else: # v
|
||||
x0_hat = ab_t.sqrt() * x - (1 - ab_t).sqrt() * pred
|
||||
x0_hat = x0_hat.clamp(-1, 1)
|
||||
|
||||
# DDIM step
|
||||
eps_hat = (x - ab_t.sqrt() * x0_hat) / (1 - ab_t).sqrt()
|
||||
x = ab_prev.sqrt() * x0_hat + (1 - ab_prev).sqrt() * eps_hat
|
||||
|
||||
results.append(x.cpu())
|
||||
remaining -= bsz
|
||||
|
||||
return torch.cat(results)[:n]
|
||||
@@ -0,0 +1,56 @@
|
||||
"""Extended generation quality metrics for Phase 5 cross-family comparison.
|
||||
|
||||
IS — Inception Score (Salimans et al., 2016): measures sample quality × diversity.
|
||||
LPIPS — average pairwise learned perceptual distance: measures sample diversity alone.
|
||||
|
||||
Both functions accept float tensors in [-1, 1].
|
||||
"""
|
||||
import torch
|
||||
from torchmetrics.image.inception import InceptionScore
|
||||
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
||||
|
||||
|
||||
def compute_is(
|
||||
imgs: torch.Tensor,
|
||||
device: str = "cuda",
|
||||
batch_size: int = 64,
|
||||
) -> tuple[float, float]:
|
||||
"""Inception Score (mean ± std) over 10 splits.
|
||||
|
||||
imgs: (N, 3, H, W) in [-1, 1]. N ≥ 2 048 for a reliable estimate.
|
||||
Returns (is_mean, is_std).
|
||||
"""
|
||||
metric = InceptionScore(normalize=True).to(device)
|
||||
imgs_01 = (imgs.clamp(-1, 1) * 0.5 + 0.5)
|
||||
for i in range(0, len(imgs_01), batch_size):
|
||||
metric.update(imgs_01[i : i + batch_size].to(device))
|
||||
mean, std = metric.compute()
|
||||
return float(mean), float(std)
|
||||
|
||||
|
||||
def compute_lpips_diversity(
|
||||
imgs: torch.Tensor,
|
||||
n_pairs: int = 200,
|
||||
device: str = "cuda",
|
||||
batch_size: int = 16,
|
||||
) -> float:
|
||||
"""Average pairwise LPIPS distance — higher means more diverse samples.
|
||||
|
||||
imgs: (N, 3, H, W) in [-1, 1]. Samples n_pairs random (i, j) pairs with i ≠ j.
|
||||
"""
|
||||
metric = LearnedPerceptualImagePatchSimilarity(net_type="alex", normalize=True).to(device)
|
||||
imgs_01 = (imgs.clamp(-1, 1) * 0.5 + 0.5)
|
||||
N = len(imgs_01)
|
||||
|
||||
# Sample random pairs (ensure i ≠ j by rejection)
|
||||
idx = torch.randperm(N * 2)[:n_pairs * 2].view(n_pairs, 2) % N
|
||||
same = idx[:, 0] == idx[:, 1]
|
||||
idx[same, 1] = (idx[same, 1] + 1) % N # shift duplicate indices
|
||||
|
||||
for start in range(0, n_pairs, batch_size):
|
||||
end = min(start + batch_size, n_pairs)
|
||||
i_batch = idx[start:end, 0]
|
||||
j_batch = idx[start:end, 1]
|
||||
metric.update(imgs_01[i_batch].to(device), imgs_01[j_batch].to(device))
|
||||
|
||||
return float(metric.compute())
|
||||
@@ -0,0 +1,57 @@
|
||||
"""VGG-16 perceptual loss for Phase 3.2 and 3.3.
|
||||
|
||||
Extracts features at relu1_2, relu2_2, relu3_3 and returns the
|
||||
L1 distance in feature space. VGG weights are frozen.
|
||||
|
||||
Input convention: images in [-1, 1] — the loss converts internally to
|
||||
[0, 1] and then applies ImageNet normalisation before passing to VGG.
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.models as tv_models
|
||||
|
||||
|
||||
class PerceptualLoss(nn.Module):
|
||||
"""L1 feature-matching loss at three VGG-16 layers.
|
||||
|
||||
VGG-16 feature indices:
|
||||
relu1_2: features[:4] (before first maxpool)
|
||||
relu2_2: features[4:9] (before second maxpool)
|
||||
relu3_3: features[9:16] (before third maxpool)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
vgg = tv_models.vgg16(weights=tv_models.VGG16_Weights.IMAGENET1K_V1)
|
||||
feats = vgg.features
|
||||
self.slice1 = nn.Sequential(*list(feats[:4])) # relu1_2
|
||||
self.slice2 = nn.Sequential(*list(feats[4:9])) # relu2_2
|
||||
self.slice3 = nn.Sequential(*list(feats[9:16])) # relu3_3
|
||||
|
||||
for p in self.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
self.register_buffer(
|
||||
"mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
|
||||
)
|
||||
self.register_buffer(
|
||||
"std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
|
||||
)
|
||||
|
||||
def _normalise(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert [-1, 1] → ImageNet-normalised [0, 1]."""
|
||||
x = x * 0.5 + 0.5 # → [0, 1]
|
||||
return (x - self.mean) / self.std
|
||||
|
||||
def forward(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
|
||||
"""L1 feature distance. real gradients are stopped — only fake trains."""
|
||||
f = self._normalise(fake)
|
||||
r = self._normalise(real)
|
||||
|
||||
loss = torch.tensor(0.0, device=fake.device)
|
||||
for layer in (self.slice1, self.slice2, self.slice3):
|
||||
f = layer(f)
|
||||
r = layer(r)
|
||||
loss = loss + F.l1_loss(f, r.detach())
|
||||
return loss
|
||||
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.utils import save_image
|
||||
from tqdm import tqdm
|
||||
@@ -19,12 +21,11 @@ else:
|
||||
_autocast = lambda device_type="", enabled=True, **kw: _AC(enabled=enabled, **kw)
|
||||
|
||||
|
||||
def _save_samples(generator_ema, samples_dir: Path, epoch: int, *, latent_dim: int, device) -> None:
|
||||
def _save_samples(generator_ema, samples_dir: Path, epoch: int, *, fixed_noise: torch.Tensor, device) -> None:
|
||||
samples_dir.mkdir(parents=True, exist_ok=True)
|
||||
with torch.no_grad():
|
||||
noise = torch.randn(16, latent_dim, 1, 1, device=device)
|
||||
imgs = generator_ema.model(noise) # EMA model, [-1, 1]
|
||||
imgs = (imgs.clamp(-1, 1) + 1.0) / 2.0 # -> [0, 1]
|
||||
imgs = generator_ema.model(fixed_noise.to(device)) # EMA model, [-1, 1]
|
||||
imgs = (imgs.clamp(-1, 1) + 1.0) / 2.0 # -> [0, 1]
|
||||
save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)
|
||||
|
||||
|
||||
@@ -78,6 +79,9 @@ def train_dcgan(
|
||||
|
||||
ema = EMA(generator, decay=ema_decay)
|
||||
|
||||
# Fixed noise for consistent sample tracking across epochs
|
||||
fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device)
|
||||
|
||||
save_dir = Path(save_dir)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
samples_dir = save_dir.parent / "samples" / run_name
|
||||
@@ -88,6 +92,15 @@ def train_dcgan(
|
||||
best_fid = float("inf")
|
||||
print(f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}")
|
||||
|
||||
# Linear LR decay from epoch epochs//2 to epochs
|
||||
decay_start = epochs // 2
|
||||
sched_g = torch.optim.lr_scheduler.LambdaLR(
|
||||
opt_g, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start)))
|
||||
sched_d = torch.optim.lr_scheduler.LambdaLR(
|
||||
opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start)))
|
||||
|
||||
t_start = time.time()
|
||||
|
||||
for epoch in range(1, epochs + 1):
|
||||
generator.train()
|
||||
discriminator.train()
|
||||
@@ -142,13 +155,13 @@ def train_dcgan(
|
||||
)
|
||||
|
||||
if epoch % sample_interval == 0:
|
||||
_save_samples(ema, samples_dir, epoch, latent_dim=latent_dim, device=device)
|
||||
_save_samples(ema, samples_dir, epoch, fixed_noise=fixed_noise, device=device)
|
||||
|
||||
if epoch % fid_interval == 0:
|
||||
generator.eval()
|
||||
ema.model.eval()
|
||||
with torch.no_grad():
|
||||
fake_imgs = torch.cat([
|
||||
generator(torch.randn(64, latent_dim, 1, 1, device=device))
|
||||
ema.model(torch.randn(64, latent_dim, 1, 1, device=device))
|
||||
for _ in range(fid_n_real // 64 + 1)
|
||||
])[:fid_n_real]
|
||||
fid_score = fid_eval.compute(fake_imgs)
|
||||
@@ -160,7 +173,586 @@ def train_dcgan(
|
||||
torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
|
||||
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
|
||||
|
||||
sched_g.step()
|
||||
sched_d.step()
|
||||
|
||||
torch.save(generator.state_dict(), save_dir / f"{run_name}_final_g.pt")
|
||||
torch.save(discriminator.state_dict(), save_dir / f"{run_name}_final_d.pt")
|
||||
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")
|
||||
history["train_time_s"] = time.time() - t_start
|
||||
return history
|
||||
|
||||
|
||||
def _gradient_penalty(critic, real: torch.Tensor, fake: torch.Tensor, device) -> torch.Tensor:
|
||||
"""Two-sided gradient penalty (Gulrajani et al., 2017)."""
|
||||
bsz = real.size(0)
|
||||
eps = torch.rand(bsz, 1, 1, 1, device=device)
|
||||
interp = (eps * real + (1.0 - eps) * fake).requires_grad_(True)
|
||||
d_interp = critic(interp)
|
||||
grad = torch.autograd.grad(
|
||||
outputs=d_interp,
|
||||
inputs=interp,
|
||||
grad_outputs=torch.ones_like(d_interp),
|
||||
create_graph=True,
|
||||
retain_graph=True,
|
||||
)[0]
|
||||
return ((grad.norm(2, dim=[1, 2, 3]) - 1) ** 2).mean()
|
||||
|
||||
|
||||
def train_wgan(
|
||||
generator,
|
||||
critic,
|
||||
train_dataset,
|
||||
cfg: dict,
|
||||
*,
|
||||
save_dir,
|
||||
run_name: str,
|
||||
device: str = "cuda",
|
||||
) -> dict:
|
||||
"""WGAN-GP training loop (Gulrajani et al., 2017).
|
||||
|
||||
Used for Phase 2.2–2.4. Gradient penalty replaces weight clipping.
|
||||
The critic runs in float32 to keep GP gradient computation numerically
|
||||
stable; AMP is used only for the generator forward/backward.
|
||||
"""
|
||||
device = torch.device(device if torch.cuda.is_available() else "cpu")
|
||||
generator = generator.to(device)
|
||||
critic = critic.to(device)
|
||||
|
||||
n_g = sum(p.numel() for p in generator.parameters() if p.requires_grad)
|
||||
n_c = sum(p.numel() for p in critic.parameters() if p.requires_grad)
|
||||
print(f"Generator: {n_g:,} params Critic: {n_c:,} params")
|
||||
|
||||
epochs = cfg["epochs"]
|
||||
batch_size = cfg["batch_size"]
|
||||
lr_g = cfg.get("lr_g", 1e-4)
|
||||
lr_d = cfg.get("lr_d", 1e-4)
|
||||
beta1 = cfg.get("beta1", 0.0)
|
||||
beta2 = cfg.get("beta2", 0.9)
|
||||
latent_dim = cfg.get("latent_dim", 128)
|
||||
n_critic = cfg.get("n_critic", 5)
|
||||
gp_lambda = cfg.get("gp_lambda", 10)
|
||||
ema_decay = cfg.get("ema_decay", 0.9999)
|
||||
sample_interval = cfg.get("sample_interval", 10)
|
||||
fid_interval = cfg.get("fid_interval", 25)
|
||||
fid_n_real = cfg.get("fid_n_real", 5000)
|
||||
|
||||
loader = DataLoader(
|
||||
train_dataset, batch_size=batch_size, shuffle=True,
|
||||
num_workers=min(4, os.cpu_count() or 1),
|
||||
pin_memory=(device.type == "cuda"), drop_last=True,
|
||||
)
|
||||
opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
|
||||
opt_c = torch.optim.Adam(critic.parameters(), lr=lr_d, betas=(beta1, beta2))
|
||||
|
||||
use_amp = device.type == "cuda"
|
||||
scaler_g = _GradScaler("cuda", enabled=use_amp)
|
||||
|
||||
ema = EMA(generator, decay=ema_decay)
|
||||
|
||||
# Fixed noise for consistent sample tracking across epochs
|
||||
fixed_noise = torch.randn(16, latent_dim, 1, 1, device=device)
|
||||
|
||||
save_dir = Path(save_dir)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
samples_dir = save_dir.parent / "samples" / run_name
|
||||
|
||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
|
||||
|
||||
history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}}
|
||||
best_fid = float("inf")
|
||||
print(f"Device: {device} AMP (G only): {use_amp} Batches/epoch: {len(loader)} n_critic: {n_critic}")
|
||||
|
||||
# Linear LR decay from epoch epochs//2 to epochs
|
||||
decay_start = epochs // 2
|
||||
sched_g = torch.optim.lr_scheduler.LambdaLR(
|
||||
opt_g, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start)))
|
||||
sched_c = torch.optim.lr_scheduler.LambdaLR(
|
||||
opt_c, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / (epochs - decay_start)))
|
||||
|
||||
t_start = time.time()
|
||||
|
||||
for epoch in range(1, epochs + 1):
|
||||
generator.train()
|
||||
critic.train()
|
||||
g_sum = w_sum = gp_sum = real_sum = fake_sum = 0.0
|
||||
n_c_steps = n_g_steps = 0
|
||||
|
||||
for batch_idx, real in enumerate(tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False)):
|
||||
real = real.to(device)
|
||||
bsz = real.size(0)
|
||||
|
||||
# ── Critic step (every batch) ─────────────────────────────────
|
||||
# Run critic in float32 — GP requires double-precision gradients
|
||||
# and AMP can degrade stability here.
|
||||
opt_c.zero_grad()
|
||||
with torch.no_grad():
|
||||
fake = generator(torch.randn(bsz, latent_dim, 1, 1, device=device))
|
||||
|
||||
real_f32 = real.float()
|
||||
fake_f32 = fake.float().detach()
|
||||
|
||||
d_real = critic(real_f32)
|
||||
d_fake = critic(fake_f32)
|
||||
gp = _gradient_penalty(critic, real_f32, fake_f32.detach(), device)
|
||||
c_loss = d_fake.mean() - d_real.mean() + gp_lambda * gp
|
||||
c_loss.backward()
|
||||
opt_c.step()
|
||||
|
||||
w_dist = (d_real.mean() - d_fake.mean()).item()
|
||||
w_sum += w_dist
|
||||
gp_sum += gp.item()
|
||||
real_sum += d_real.mean().item()
|
||||
fake_sum += d_fake.mean().item()
|
||||
n_c_steps += 1
|
||||
|
||||
# ── Generator step (every n_critic batches) ───────────────────
|
||||
if (batch_idx + 1) % n_critic == 0:
|
||||
opt_g.zero_grad()
|
||||
with _autocast("cuda", enabled=use_amp):
|
||||
fake = generator(torch.randn(bsz, latent_dim, 1, 1, device=device))
|
||||
g_loss = -critic(fake.float()).mean()
|
||||
scaler_g.scale(g_loss).backward()
|
||||
scaler_g.step(opt_g)
|
||||
scaler_g.update()
|
||||
ema.update(generator)
|
||||
g_sum += g_loss.item()
|
||||
n_g_steps += 1
|
||||
|
||||
avg_w = w_sum / max(n_c_steps, 1)
|
||||
avg_gp = gp_sum / max(n_c_steps, 1)
|
||||
avg_g = g_sum / max(n_g_steps, 1)
|
||||
avg_r = real_sum / max(n_c_steps, 1)
|
||||
avg_f = fake_sum / max(n_c_steps, 1)
|
||||
history["g_loss"].append(avg_g)
|
||||
history["w_dist"].append(avg_w)
|
||||
history["gp"].append(avg_gp)
|
||||
history["d_real"].append(avg_r)
|
||||
history["d_fake"].append(avg_f)
|
||||
print(
|
||||
f"[{epoch:03d}/{epochs}] "
|
||||
f"G: {avg_g:.4f} W-dist: {avg_w:.4f} GP: {avg_gp:.4f} "
|
||||
f"C(real): {avg_r:.4f} C(fake): {avg_f:.4f}"
|
||||
)
|
||||
|
||||
if epoch % sample_interval == 0:
|
||||
_save_samples(ema, samples_dir, epoch, fixed_noise=fixed_noise, device=device)
|
||||
|
||||
if epoch % fid_interval == 0:
|
||||
ema.model.eval()
|
||||
with torch.no_grad():
|
||||
fake_imgs = torch.cat([
|
||||
ema.model(torch.randn(64, latent_dim, 1, 1, device=device))
|
||||
for _ in range(fid_n_real // 64 + 1)
|
||||
])[:fid_n_real]
|
||||
fid_score = fid_eval.compute(fake_imgs)
|
||||
history["fid"][epoch] = fid_score
|
||||
print(f" FID @ epoch {epoch}: {fid_score:.2f}")
|
||||
|
||||
if fid_score < best_fid:
|
||||
best_fid = fid_score
|
||||
torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
|
||||
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
|
||||
|
||||
sched_g.step()
|
||||
sched_c.step()
|
||||
|
||||
torch.save(generator.state_dict(), save_dir / f"{run_name}_final_g.pt")
|
||||
torch.save(critic.state_dict(), save_dir / f"{run_name}_final_d.pt")
|
||||
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")
|
||||
history["train_time_s"] = time.time() - t_start
|
||||
return history
|
||||
|
||||
|
||||
# ────────────────────────────────────────────────────────────────────────────
|
||||
# Phase 3 — VAE (3.1 MSE+KL · 3.2 +Perceptual · 3.3 +PatchGAN)
|
||||
# ────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _save_vae_samples(
|
||||
vae,
|
||||
samples_dir: Path,
|
||||
epoch: int,
|
||||
*,
|
||||
fixed_z: torch.Tensor,
|
||||
fixed_real: torch.Tensor,
|
||||
device,
|
||||
) -> None:
|
||||
"""Save prior samples and a real-vs-reconstruction grid side by side."""
|
||||
samples_dir.mkdir(parents=True, exist_ok=True)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
prior = vae.decode(fixed_z.to(device))
|
||||
prior = (prior.clamp(-1, 1) + 1.0) / 2.0
|
||||
save_image(prior, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)
|
||||
|
||||
recon, _, _ = vae(fixed_real.to(device))
|
||||
recon = (recon.clamp(-1, 1) + 1.0) / 2.0
|
||||
real = (fixed_real.to(device) + 1.0) / 2.0
|
||||
# Interleave real / reconstruction pairs
|
||||
pairs = torch.stack([real, recon], dim=1).flatten(0, 1)
|
||||
save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4)
|
||||
vae.train()
|
||||
|
||||
|
||||
def train_vae(
|
||||
vae,
|
||||
train_dataset,
|
||||
cfg: dict,
|
||||
*,
|
||||
save_dir,
|
||||
run_name: str,
|
||||
device: str = "cuda",
|
||||
) -> dict:
|
||||
"""VAE training loop covering Phase 3.1 – 3.3.
|
||||
|
||||
Config toggles:
|
||||
lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+)
|
||||
lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3)
|
||||
|
||||
Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl
|
||||
"""
|
||||
device = torch.device(device if torch.cuda.is_available() else "cpu")
|
||||
vae = vae.to(device)
|
||||
|
||||
n_vae = sum(p.numel() for p in vae.parameters() if p.requires_grad)
|
||||
print(f"VAE: {n_vae:,} params")
|
||||
|
||||
epochs = cfg["epochs"]
|
||||
batch_size = cfg["batch_size"]
|
||||
lr = cfg.get("lr", 1e-3)
|
||||
latent_dim = cfg.get("latent_dim", 256)
|
||||
beta_kl = cfg.get("beta_kl", 1.0)
|
||||
lambda_perceptual = cfg.get("lambda_perceptual", 0.0)
|
||||
lambda_adversarial = cfg.get("lambda_adversarial", 0.0)
|
||||
lr_d = cfg.get("lr_d", 1e-4)
|
||||
ema_decay = cfg.get("ema_decay", 0.9999)
|
||||
sample_interval = cfg.get("sample_interval", 10)
|
||||
fid_interval = cfg.get("fid_interval", 25)
|
||||
fid_n_real = cfg.get("fid_n_real", 5000)
|
||||
|
||||
use_perceptual = lambda_perceptual > 0
|
||||
use_adversarial = lambda_adversarial > 0
|
||||
|
||||
loader = DataLoader(
|
||||
train_dataset, batch_size=batch_size, shuffle=True,
|
||||
num_workers=min(4, os.cpu_count() or 1),
|
||||
pin_memory=(device.type == "cuda"), drop_last=True,
|
||||
)
|
||||
|
||||
opt_vae = torch.optim.Adam(vae.parameters(), lr=lr)
|
||||
use_amp = device.type == "cuda"
|
||||
scaler = _GradScaler("cuda", enabled=use_amp)
|
||||
|
||||
# KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training
|
||||
kl_warmup_epochs = max(1, epochs // 5)
|
||||
|
||||
# Linear LR decay from epoch epochs//2 to epochs
|
||||
decay_start = epochs // 2
|
||||
sched_vae = torch.optim.lr_scheduler.LambdaLR(
|
||||
opt_vae, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
|
||||
sched_d = None # set below if adversarial
|
||||
|
||||
# ── Optional components ───────────────────────────────────────────────
|
||||
perc_fn = None
|
||||
patchgan = None
|
||||
opt_d = None
|
||||
scaler_d = None
|
||||
|
||||
if use_perceptual:
|
||||
from src.training.perceptual import PerceptualLoss
|
||||
perc_fn = PerceptualLoss().to(device)
|
||||
print("Perceptual loss: VGG-16 relu1_2 + relu2_2 + relu3_3")
|
||||
|
||||
if use_adversarial:
|
||||
from src.models.patchgan import PatchGANDiscriminator, hinge_d_loss, hinge_g_loss
|
||||
patchgan = PatchGANDiscriminator(
|
||||
ndf=cfg.get("ndf_patch", 64),
|
||||
image_size=cfg.get("image_size", 64),
|
||||
).to(device)
|
||||
opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999))
|
||||
scaler_d = _GradScaler("cuda", enabled=use_amp)
|
||||
sched_d = torch.optim.lr_scheduler.LambdaLR(
|
||||
opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
|
||||
n_d = sum(p.numel() for p in patchgan.parameters())
|
||||
print(f"PatchGAN: {n_d:,} params")
|
||||
else:
|
||||
hinge_d_loss = hinge_g_loss = None # satisfy linter, never called
|
||||
|
||||
# ── Fixed seeds for consistent visualisation ──────────────────────────
|
||||
fixed_z = torch.randn(16, latent_dim, device=device)
|
||||
# Grab first 16 real images from the loader for reconstruction tracking
|
||||
_it = iter(loader)
|
||||
fixed_real = next(_it)[:16].cpu()
|
||||
|
||||
ema = EMA(vae, decay=ema_decay)
|
||||
|
||||
save_dir = Path(save_dir)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
samples_dir = save_dir.parent / "samples" / run_name
|
||||
|
||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
|
||||
|
||||
history = {
|
||||
"recon_loss": [], "kl_loss": [], "perc_loss": [],
|
||||
"adv_g_loss": [], "adv_d_loss": [], "fid": {},
|
||||
}
|
||||
best_fid = float("inf")
|
||||
print(
|
||||
f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}"
|
||||
f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual} λ_adv={lambda_adversarial}"
|
||||
)
|
||||
|
||||
t_start = time.time()
|
||||
|
||||
for epoch in range(1, epochs + 1):
|
||||
vae.train()
|
||||
if patchgan is not None:
|
||||
patchgan.train()
|
||||
|
||||
recon_sum = kl_sum = perc_sum = adv_g_sum = adv_d_sum = 0.0
|
||||
n_batches = 0
|
||||
|
||||
for real in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
|
||||
real = real.to(device)
|
||||
|
||||
# KL warmup: ramp from 0 to beta_kl over kl_warmup_epochs
|
||||
current_beta = beta_kl * min(1.0, epoch / kl_warmup_epochs)
|
||||
|
||||
# ── VAE forward ───────────────────────────────────────────────
|
||||
with _autocast("cuda", enabled=use_amp):
|
||||
recon, mu, log_var = vae(real)
|
||||
mse = F.mse_loss(recon, real)
|
||||
kl = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(1).mean()
|
||||
perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze()
|
||||
vae_loss = mse + current_beta * kl + lambda_perceptual * perc
|
||||
|
||||
# ── PatchGAN discriminator step ───────────────────────────────
|
||||
adv_d = real.new_zeros(1).squeeze()
|
||||
if use_adversarial:
|
||||
opt_d.zero_grad()
|
||||
with _autocast("cuda", enabled=use_amp):
|
||||
d_real = patchgan(real)
|
||||
d_fake = patchgan(recon.detach())
|
||||
adv_d = hinge_d_loss(d_real, d_fake)
|
||||
scaler_d.scale(adv_d).backward()
|
||||
scaler_d.step(opt_d)
|
||||
scaler_d.update()
|
||||
|
||||
# ── PatchGAN generator adversarial loss ───────────────────────
|
||||
adv_g = real.new_zeros(1).squeeze()
|
||||
if use_adversarial:
|
||||
with _autocast("cuda", enabled=use_amp):
|
||||
adv_g = hinge_g_loss(patchgan(recon))
|
||||
vae_loss = vae_loss + lambda_adversarial * adv_g
|
||||
|
||||
# ── VAE backward ──────────────────────────────────────────────
|
||||
opt_vae.zero_grad()
|
||||
scaler.scale(vae_loss).backward()
|
||||
scaler.step(opt_vae)
|
||||
scaler.update()
|
||||
ema.update(vae)
|
||||
|
||||
recon_sum += mse.item()
|
||||
kl_sum += kl.item()
|
||||
perc_sum += perc.item()
|
||||
adv_g_sum += adv_g.item()
|
||||
adv_d_sum += adv_d.item()
|
||||
n_batches += 1
|
||||
|
||||
avg_r = recon_sum / n_batches
|
||||
avg_k = kl_sum / n_batches
|
||||
avg_p = perc_sum / n_batches
|
||||
avg_g = adv_g_sum / n_batches
|
||||
avg_d = adv_d_sum / n_batches
|
||||
history["recon_loss"].append(avg_r)
|
||||
history["kl_loss"].append(avg_k)
|
||||
history["perc_loss"].append(avg_p)
|
||||
history["adv_g_loss"].append(avg_g)
|
||||
history["adv_d_loss"].append(avg_d)
|
||||
|
||||
print(
|
||||
f"[{epoch:03d}/{epochs}] "
|
||||
f"MSE: {avg_r:.4f} KL: {avg_k:.2f} β={current_beta:.6f} "
|
||||
f"Perc: {avg_p:.4f} AdvG: {avg_g:.4f} AdvD: {avg_d:.4f}"
|
||||
)
|
||||
|
||||
if epoch % sample_interval == 0:
|
||||
_save_vae_samples(
|
||||
ema.model, samples_dir, epoch,
|
||||
fixed_z=fixed_z, fixed_real=fixed_real, device=device,
|
||||
)
|
||||
|
||||
if epoch % fid_interval == 0:
|
||||
ema.model.eval()
|
||||
with torch.no_grad():
|
||||
fake_imgs = torch.cat([
|
||||
ema.model.sample(64, device)
|
||||
for _ in range(fid_n_real // 64 + 1)
|
||||
])[:fid_n_real]
|
||||
fid_score = fid_eval.compute(fake_imgs)
|
||||
history["fid"][epoch] = fid_score
|
||||
print(f" FID @ epoch {epoch}: {fid_score:.2f}")
|
||||
|
||||
if fid_score < best_fid:
|
||||
best_fid = fid_score
|
||||
torch.save(vae.state_dict(), save_dir / f"{run_name}_best_vae.pt")
|
||||
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
|
||||
|
||||
sched_vae.step()
|
||||
if sched_d is not None:
|
||||
sched_d.step()
|
||||
|
||||
torch.save(vae.state_dict(), save_dir / f"{run_name}_final_vae.pt")
|
||||
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")
|
||||
if patchgan is not None:
|
||||
torch.save(patchgan.state_dict(), save_dir / f"{run_name}_final_patchgan.pt")
|
||||
history["train_time_s"] = time.time() - t_start
|
||||
return history
|
||||
|
||||
|
||||
# ────────────────────────────────────────────────────────────────────────────
|
||||
# Phase 4 — DDPM (4.1 linear·ε · 4.2 cosine·ε · 4.3 cosine·v · 4.4 wider)
|
||||
# ────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def train_ddpm(
|
||||
model,
|
||||
train_dataset,
|
||||
cfg: dict,
|
||||
*,
|
||||
save_dir,
|
||||
run_name: str,
|
||||
device: str = "cuda",
|
||||
) -> dict:
|
||||
"""DDPM training loop (Ho et al., 2020) covering Phase 4.1 – 4.4.
|
||||
|
||||
Config keys:
|
||||
noise_schedule — "linear" (4.1) or "cosine" (4.2+)
|
||||
pred_type — "eps" (4.1–4.2) or "v" (4.3+)
|
||||
T — diffusion timesteps (default 1000)
|
||||
base_ch / ch_mult / attn_resolutions — U-Net capacity (see unet.py)
|
||||
ddim_steps — DDIM steps for FID evaluation (default 100)
|
||||
"""
|
||||
from src.training.diffusion import (
|
||||
linear_betas, cosine_betas, make_alpha_bars,
|
||||
diffusion_loss, ddim_sample,
|
||||
)
|
||||
|
||||
device = torch.device(device if torch.cuda.is_available() else "cpu")
|
||||
model = model.to(device)
|
||||
|
||||
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print(f"U-Net: {n_params:,} params")
|
||||
|
||||
epochs = cfg["epochs"]
|
||||
batch_size = cfg["batch_size"]
|
||||
lr = cfg.get("lr", 2e-4)
|
||||
T = cfg.get("T", 1000)
|
||||
noise_schedule = cfg.get("noise_schedule", "linear")
|
||||
pred_type = cfg.get("pred_type", "eps")
|
||||
ddim_steps = cfg.get("ddim_steps", 100)
|
||||
image_size = cfg.get("image_size", 64)
|
||||
ema_decay = cfg.get("ema_decay", 0.9999)
|
||||
sample_interval = cfg.get("sample_interval", 10)
|
||||
fid_interval = cfg.get("fid_interval", 25)
|
||||
fid_n_real = cfg.get("fid_n_real", 5000)
|
||||
|
||||
# Build noise schedule and register on device
|
||||
betas = (cosine_betas(T) if noise_schedule == "cosine" else linear_betas(T)).to(device)
|
||||
alpha_bars = make_alpha_bars(betas) # on device
|
||||
|
||||
loader = DataLoader(
|
||||
train_dataset, batch_size=batch_size, shuffle=True,
|
||||
num_workers=min(4, os.cpu_count() or 1),
|
||||
pin_memory=(device.type == "cuda"), drop_last=True,
|
||||
)
|
||||
opt = torch.optim.AdamW(model.parameters(), lr=lr)
|
||||
|
||||
use_amp = device.type == "cuda"
|
||||
scaler = _GradScaler("cuda", enabled=use_amp)
|
||||
|
||||
ema = EMA(model, decay=ema_decay)
|
||||
|
||||
# Fixed noise for sample visualisation (same latents across epochs)
|
||||
fixed_noise = torch.randn(16, 3, image_size, image_size, device=device)
|
||||
|
||||
save_dir = Path(save_dir)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
samples_dir = save_dir.parent / "samples" / run_name
|
||||
|
||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
|
||||
|
||||
history = {"loss": [], "fid": {}}
|
||||
best_fid = float("inf")
|
||||
print(
|
||||
f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}"
|
||||
f" T={T} schedule={noise_schedule} pred={pred_type} ddim_steps={ddim_steps}"
|
||||
)
|
||||
|
||||
# Linear LR decay from epoch epochs//2 to epochs
|
||||
decay_start = epochs // 2
|
||||
sched = torch.optim.lr_scheduler.LambdaLR(
|
||||
opt, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
|
||||
|
||||
t_start = time.time()
|
||||
|
||||
for epoch in range(1, epochs + 1):
|
||||
model.train()
|
||||
loss_sum = 0.0
|
||||
n_batches = 0
|
||||
|
||||
for x0 in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
|
||||
x0 = x0.to(device)
|
||||
t = torch.randint(0, T, (x0.size(0),), device=device)
|
||||
|
||||
with _autocast("cuda", enabled=use_amp):
|
||||
loss = diffusion_loss(model, x0, t, alpha_bars, pred_type)
|
||||
|
||||
opt.zero_grad()
|
||||
scaler.scale(loss).backward()
|
||||
scaler.unscale_(opt)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||
scaler.step(opt)
|
||||
scaler.update()
|
||||
ema.update(model)
|
||||
|
||||
loss_sum += loss.item()
|
||||
n_batches += 1
|
||||
|
||||
avg_loss = loss_sum / n_batches
|
||||
history["loss"].append(avg_loss)
|
||||
print(f"[{epoch:03d}/{epochs}] Loss: {avg_loss:.5f}")
|
||||
|
||||
if epoch % sample_interval == 0:
|
||||
samples_dir.mkdir(parents=True, exist_ok=True)
|
||||
ema.model.eval()
|
||||
with torch.no_grad():
|
||||
# Quick visualisation: denoise fixed_noise via DDIM
|
||||
imgs = ddim_sample(
|
||||
ema.model, 16, image_size, alpha_bars,
|
||||
n_steps=50, pred_type=pred_type, device=str(device), batch_size=16,
|
||||
)
|
||||
imgs = (imgs.clamp(-1, 1) + 1.0) / 2.0
|
||||
save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)
|
||||
|
||||
if epoch % fid_interval == 0:
|
||||
ema.model.eval()
|
||||
fake_imgs = ddim_sample(
|
||||
ema.model, fid_n_real, image_size, alpha_bars,
|
||||
n_steps=ddim_steps, pred_type=pred_type,
|
||||
device=str(device), batch_size=32,
|
||||
)
|
||||
fid_score = fid_eval.compute(fake_imgs)
|
||||
history["fid"][epoch] = fid_score
|
||||
print(f" FID @ epoch {epoch}: {fid_score:.2f}")
|
||||
|
||||
if fid_score < best_fid:
|
||||
best_fid = fid_score
|
||||
torch.save(model.state_dict(), save_dir / f"{run_name}_best_unet.pt")
|
||||
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
|
||||
|
||||
sched.step()
|
||||
|
||||
torch.save(model.state_dict(), save_dir / f"{run_name}_final_unet.pt")
|
||||
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")
|
||||
history["train_time_s"] = time.time() - t_start
|
||||
return history
|
||||
|
||||
Reference in New Issue
Block a user