Clean state
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
from src.data.dataset import GeneratorDataset, get_transform
|
||||
|
||||
__all__ = ["GeneratorDataset", "get_transform"]
|
||||
@@ -0,0 +1,74 @@
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
import torchvision.transforms as T
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class GeneratorDataset(Dataset):
|
||||
"""Unlabeled image dataset for generative model training.
|
||||
|
||||
Loads images from source subdirectories and returns tensors only —
|
||||
no labels, since generation is unsupervised.
|
||||
"""
|
||||
|
||||
def __init__(self, data_dir, sources=None, subsample=1.0, transform=None, seed=42):
|
||||
self.transform = transform
|
||||
self.samples = []
|
||||
|
||||
# Accept either a single root or a list of roots (used by 1D to mix
|
||||
# raw + aligned crops in one dataset).
|
||||
roots = [data_dir] if isinstance(data_dir, (str, Path)) else list(data_dir)
|
||||
if sources is None:
|
||||
sources = ["wiki"]
|
||||
|
||||
for root in roots:
|
||||
root = Path(root)
|
||||
if not root.exists():
|
||||
raise FileNotFoundError(f"Dataset root not found: {root}")
|
||||
for source in sources:
|
||||
source_dir = root / source
|
||||
if not source_dir.exists():
|
||||
raise FileNotFoundError(f"Missing source directory: {source_dir}")
|
||||
for subdir in sorted(source_dir.iterdir()):
|
||||
if subdir.is_dir():
|
||||
for img_path in sorted(subdir.glob("*.jpg")):
|
||||
self.samples.append(img_path)
|
||||
|
||||
if subsample < 1.0:
|
||||
rng = random.Random(seed)
|
||||
n = max(1, int(len(self.samples) * subsample))
|
||||
self.samples = rng.sample(self.samples, n)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img = Image.open(self.samples[idx]).convert("RGB")
|
||||
if self.transform:
|
||||
img = self.transform(img)
|
||||
return img
|
||||
|
||||
|
||||
def get_transform(image_size: int, augment: bool = False) -> T.Compose:
|
||||
"""Build transform for generator training. Output is in [-1, 1].
|
||||
|
||||
augment=True adds horizontal flip + mild rotation + mild color jitter.
|
||||
Use augment=False for validation / FID real-image sets.
|
||||
"""
|
||||
ops = [
|
||||
T.Resize(image_size),
|
||||
T.CenterCrop(image_size),
|
||||
]
|
||||
if augment:
|
||||
ops += [
|
||||
T.RandomHorizontalFlip(p=0.5),
|
||||
T.RandomRotation(degrees=5, interpolation=T.InterpolationMode.BILINEAR),
|
||||
T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05),
|
||||
]
|
||||
ops += [
|
||||
T.ToTensor(),
|
||||
T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), # -> [-1, 1]
|
||||
]
|
||||
return T.Compose(ops)
|
||||
@@ -0,0 +1,26 @@
|
||||
from typing import Callable
|
||||
import torch.nn as nn
|
||||
|
||||
_REGISTRY: dict[str, tuple[Callable, str]] = {}
|
||||
|
||||
|
||||
def register(name: str, builder: Callable, *, kind: str) -> None:
|
||||
_REGISTRY[name] = (builder, kind)
|
||||
|
||||
|
||||
def get_model(cfg: dict) -> tuple:
|
||||
"""Return (model_or_pair, kind).
|
||||
|
||||
kind="dcgan" -> (generator, discriminator)
|
||||
"""
|
||||
name = cfg.get("model")
|
||||
entry = _REGISTRY.get(name)
|
||||
if entry is None:
|
||||
available = ", ".join(sorted(_REGISTRY))
|
||||
raise ValueError(f"Unknown model: {name!r}. Available: {available}")
|
||||
builder, kind = entry
|
||||
return builder(cfg), kind
|
||||
|
||||
|
||||
from src.models import dcgan # noqa: E402, F401
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
"""Vanilla DCGAN (Radford et al., 2015).
|
||||
|
||||
Used as the Phase 1 baseline for cheap pipeline ablations. Architecture is
|
||||
intentionally minimal — BatchNorm in both networks, no spectral norm, no
|
||||
attention, no gradient penalty. The whole point is to be the cheapest GAN
|
||||
we can run, so 1A–1D pipeline deltas show up in FID quickly.
|
||||
|
||||
Depth scales with image_size: each step doubles the spatial dimension,
|
||||
starting from 4×4 after the first transposed conv.
|
||||
64 -> 5 layers (1 -> 4 -> 8 -> 16 -> 32 -> 64)
|
||||
128 -> 6 layers (1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128)
|
||||
"""
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from src.models import register
|
||||
|
||||
|
||||
def _init_weights(m):
|
||||
classname = m.__class__.__name__
|
||||
if "Conv" in classname:
|
||||
nn.init.normal_(m.weight, 0.0, 0.02)
|
||||
elif "BatchNorm" in classname:
|
||||
nn.init.normal_(m.weight, 1.0, 0.02)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
|
||||
def _n_upsamples(image_size: int) -> int:
|
||||
"""Number of 2x upsampling steps from 4x4 to image_size."""
|
||||
if image_size < 8 or image_size & (image_size - 1):
|
||||
raise ValueError(f"image_size must be a power of two ≥ 8, got {image_size}")
|
||||
return int(math.log2(image_size)) - 2 # 64 -> 4, 128 -> 5
|
||||
|
||||
|
||||
class DCGANGenerator(nn.Module):
|
||||
"""Maps (latent_dim x 1 x 1) -> (3 x image_size x image_size) in [-1, 1]."""
|
||||
|
||||
def __init__(self, latent_dim: int = 100, ngf: int = 64, image_size: int = 64):
|
||||
super().__init__()
|
||||
n_up = _n_upsamples(image_size) # 64 -> 4 upsamples after the 1->4 init
|
||||
max_mult = 2 ** (n_up - 1) # channel multiplier at the 4x4 stage
|
||||
|
||||
layers: list[nn.Module] = [
|
||||
# 1x1 -> 4x4
|
||||
nn.ConvTranspose2d(latent_dim, ngf * max_mult, 4, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(ngf * max_mult),
|
||||
nn.ReLU(inplace=True),
|
||||
]
|
||||
# Each step halves the channel multiplier and doubles spatial size.
|
||||
mult = max_mult
|
||||
for _ in range(n_up - 1):
|
||||
layers += [
|
||||
nn.ConvTranspose2d(ngf * mult, ngf * mult // 2, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(ngf * mult // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
]
|
||||
mult //= 2
|
||||
# Final layer to 3 channels, no BN, Tanh.
|
||||
layers += [
|
||||
nn.ConvTranspose2d(ngf * mult, 3, 4, 2, 1, bias=False),
|
||||
nn.Tanh(),
|
||||
]
|
||||
self.net = nn.Sequential(*layers)
|
||||
self.apply(_init_weights)
|
||||
|
||||
def forward(self, z: torch.Tensor) -> torch.Tensor:
|
||||
return self.net(z)
|
||||
|
||||
|
||||
class DCGANDiscriminator(nn.Module):
|
||||
"""Maps (3 x image_size x image_size) -> scalar logit (no sigmoid)."""
|
||||
|
||||
def __init__(self, ndf: int = 64, image_size: int = 64):
|
||||
super().__init__()
|
||||
n_down = _n_upsamples(image_size)
|
||||
layers: list[nn.Module] = [
|
||||
# First layer: no BN
|
||||
nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
]
|
||||
mult = 1
|
||||
for _ in range(n_down - 1):
|
||||
layers += [
|
||||
nn.Conv2d(ndf * mult, ndf * mult * 2, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(ndf * mult * 2),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
]
|
||||
mult *= 2
|
||||
# 4x4 -> 1x1, scalar logit
|
||||
layers += [nn.Conv2d(ndf * mult, 1, 4, 1, 0, bias=False)]
|
||||
self.net = nn.Sequential(*layers)
|
||||
self.apply(_init_weights)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.net(x).view(x.size(0))
|
||||
|
||||
|
||||
def _build(cfg: dict):
|
||||
image_size = cfg.get("image_size", 64)
|
||||
return (
|
||||
DCGANGenerator(
|
||||
latent_dim=cfg.get("latent_dim", 100),
|
||||
ngf=cfg.get("ngf", 64),
|
||||
image_size=image_size,
|
||||
),
|
||||
DCGANDiscriminator(
|
||||
ndf=cfg.get("ndf", 64),
|
||||
image_size=image_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
register("dcgan", _build, kind="dcgan")
|
||||
@@ -0,0 +1,133 @@
|
||||
"""WGAN-GP with spectral normalization, self-attention, and GroupNorm.
|
||||
|
||||
Improvements over the original:
|
||||
- Generator: BatchNorm -> GroupNorm (no batch-size coupling, stable with varied content)
|
||||
- Critic: InstanceNorm -> spectral normalization (principled Lipschitz constraint)
|
||||
- Both: one SAGAN-style self-attention block at the 32x32 feature map
|
||||
- Larger capacity: ngf=128, ndf=128
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from src.models import register
|
||||
|
||||
|
||||
def _init_weights(m):
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
nn.init.normal_(m.weight, 0.0, 0.02)
|
||||
elif isinstance(m, nn.GroupNorm) and m.weight is not None:
|
||||
nn.init.normal_(m.weight, 1.0, 0.02)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, in_ch: int):
|
||||
super().__init__()
|
||||
mid = max(in_ch // 8, 1)
|
||||
self.q = nn.Conv2d(in_ch, mid, 1, bias=False)
|
||||
self.k = nn.Conv2d(in_ch, mid, 1, bias=False)
|
||||
self.v = nn.Conv2d(in_ch, in_ch, 1, bias=False)
|
||||
self.gamma = nn.Parameter(torch.zeros(1))
|
||||
self._mid = mid
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
b, c, h, w = x.shape
|
||||
q = self.q(x).view(b, self._mid, -1).transpose(-2, -1) # (b, hw, mid)
|
||||
k = self.k(x).view(b, self._mid, -1) # (b, mid, hw)
|
||||
v = self.v(x).view(b, c, -1) # (b, c, hw)
|
||||
attn = torch.softmax(q @ k * self._mid ** -0.5, dim=-1) # (b, hw, hw)
|
||||
out = (v @ attn.transpose(-2, -1)).view(b, c, h, w)
|
||||
return x + self.gamma * out
|
||||
|
||||
|
||||
def _sn(module):
|
||||
"""Apply spectral normalization to a conv layer."""
|
||||
return nn.utils.spectral_norm(module)
|
||||
|
||||
|
||||
class WGANGenerator(nn.Module):
|
||||
"""Maps (latent_dim x 1 x 1) -> (3 x 128 x 128) in [-1, 1].
|
||||
|
||||
Upsampling path: 1 -> 4 -> 8 -> 16 (+attn) -> 32 -> 64 -> 128
|
||||
Self-attention sits at 16x16 (attention matrix 256x256 vs 1024x1024 at 32x32).
|
||||
"""
|
||||
|
||||
def __init__(self, latent_dim: int = 128, ngf: int = 64):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
# 1x1 -> 4x4
|
||||
nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False),
|
||||
nn.GroupNorm(8, ngf * 8), nn.ReLU(True),
|
||||
# 4x4 -> 8x8
|
||||
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
|
||||
nn.GroupNorm(8, ngf * 4), nn.ReLU(True),
|
||||
# 8x8 -> 16x16
|
||||
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
|
||||
nn.GroupNorm(8, ngf * 2), nn.ReLU(True),
|
||||
)
|
||||
self.attn = SelfAttention(ngf * 2) # applied at 16x16
|
||||
self.out = nn.Sequential(
|
||||
# 16x16 -> 32x32
|
||||
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
|
||||
nn.GroupNorm(8, ngf), nn.ReLU(True),
|
||||
# 32x32 -> 64x64
|
||||
nn.ConvTranspose2d(ngf, ngf // 2, 4, 2, 1, bias=False),
|
||||
nn.GroupNorm(8, ngf // 2), nn.ReLU(True),
|
||||
# 64x64 -> 128x128
|
||||
nn.ConvTranspose2d(ngf // 2, 3, 4, 2, 1, bias=False),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.apply(_init_weights)
|
||||
|
||||
def forward(self, z: torch.Tensor) -> torch.Tensor:
|
||||
h = self.net(z)
|
||||
h = self.attn(h)
|
||||
return self.out(h)
|
||||
|
||||
|
||||
class WGANCritic(nn.Module):
|
||||
"""Critic (no sigmoid) for WGAN-GP. All conv layers are spectrally normalized.
|
||||
|
||||
Downsampling path: 128 -> 64 -> 32 -> 16 (+attn) -> 8 -> 4 -> score
|
||||
"""
|
||||
|
||||
def __init__(self, ndf: int = 64):
|
||||
super().__init__()
|
||||
self.down = nn.Sequential(
|
||||
# 128x128 -> 64x64 (no norm on first layer)
|
||||
_sn(nn.Conv2d(3, ndf // 2, 4, 2, 1, bias=False)),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
# 64x64 -> 32x32
|
||||
_sn(nn.Conv2d(ndf // 2, ndf, 4, 2, 1, bias=False)),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
# 32x32 -> 16x16
|
||||
_sn(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False)),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
)
|
||||
self.attn = SelfAttention(ndf * 2) # applied at 16x16
|
||||
self.tail = nn.Sequential(
|
||||
# 16x16 -> 8x8
|
||||
_sn(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False)),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
# 8x8 -> 4x4
|
||||
_sn(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False)),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
# 4x4 -> 1x1
|
||||
_sn(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False)),
|
||||
)
|
||||
self.apply(_init_weights)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h = self.down(x)
|
||||
h = self.attn(h)
|
||||
return self.tail(h).view(x.size(0))
|
||||
|
||||
|
||||
def _build(cfg: dict):
|
||||
return (
|
||||
WGANGenerator(latent_dim=cfg.get("latent_dim", 128), ngf=cfg.get("ngf", 128)),
|
||||
WGANCritic(ndf=cfg.get("ndf", 128)),
|
||||
)
|
||||
|
||||
|
||||
register("wgan", _build, kind="wgan")
|
||||
@@ -0,0 +1,3 @@
|
||||
from src.training.trainer import train_dcgan
|
||||
|
||||
__all__ = ["train_dcgan"]
|
||||
@@ -0,0 +1,22 @@
|
||||
import copy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class EMA:
|
||||
"""Exponential moving average of model weights.
|
||||
|
||||
Maintains a shadow copy of the model. Call update() after each
|
||||
optimizer step. Sample from ema.model, never from the training model.
|
||||
"""
|
||||
|
||||
def __init__(self, model: nn.Module, decay: float = 0.9999):
|
||||
self.decay = decay
|
||||
self.model = copy.deepcopy(model).eval()
|
||||
for p in self.model.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
@torch.no_grad()
|
||||
def update(self, model: nn.Module) -> None:
|
||||
for p_ema, p in zip(self.model.parameters(), model.parameters()):
|
||||
p_ema.data.mul_(self.decay).add_(p.data, alpha=1.0 - self.decay)
|
||||
@@ -0,0 +1,49 @@
|
||||
"""FID evaluation helper.
|
||||
|
||||
Computes Fréchet Inception Distance between a fixed set of real images
|
||||
and a batch of generated images. Real images are stored as a tensor on CPU
|
||||
and moved to device only during evaluation — this avoids re-reading disk
|
||||
every call while keeping GPU memory free between evaluations.
|
||||
"""
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torchmetrics.image.fid import FrechetInceptionDistance
|
||||
|
||||
|
||||
class FIDEvaluator:
|
||||
def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda"):
|
||||
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
|
||||
self.n_real = n_real
|
||||
|
||||
# Cache real images as a CPU tensor ([-1, 1] range)
|
||||
imgs_list = []
|
||||
loader = DataLoader(real_dataset, batch_size=256, shuffle=False,
|
||||
num_workers=4, drop_last=False)
|
||||
for batch in loader:
|
||||
imgs_list.append(batch.cpu())
|
||||
if sum(x.size(0) for x in imgs_list) >= n_real:
|
||||
break
|
||||
real = torch.cat(imgs_list)[:n_real]
|
||||
self._real = real # stored on CPU, shape (N, 3, H, W) in [-1, 1]
|
||||
|
||||
@torch.no_grad()
|
||||
def compute(self, fake_imgs: torch.Tensor) -> float:
|
||||
"""Compute FID score.
|
||||
|
||||
fake_imgs: float tensor in [-1, 1], shape (N, 3, H, W).
|
||||
N should be at least 2048 for a reliable score.
|
||||
"""
|
||||
fid = FrechetInceptionDistance(feature=2048, normalize=True).to(self.device)
|
||||
|
||||
# Feed real images in batches
|
||||
for i in range(0, self._real.size(0), 256):
|
||||
batch = (self._real[i:i + 256] * 0.5 + 0.5).clamp(0, 1).to(self.device)
|
||||
fid.update(batch, real=True)
|
||||
|
||||
# Feed fake images in batches
|
||||
fake = fake_imgs.cpu()
|
||||
for i in range(0, fake.size(0), 256):
|
||||
batch = (fake[i:i + 256] * 0.5 + 0.5).clamp(0, 1).to(self.device)
|
||||
fid.update(batch, real=False)
|
||||
|
||||
return float(fid.compute())
|
||||
@@ -0,0 +1,166 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.utils import save_image
|
||||
from tqdm import tqdm
|
||||
|
||||
from src.training.ema import EMA
|
||||
from src.training.fid import FIDEvaluator
|
||||
|
||||
if hasattr(torch.amp, "GradScaler"):
|
||||
_GradScaler = torch.amp.GradScaler
|
||||
_autocast = torch.amp.autocast
|
||||
else:
|
||||
from torch.cuda.amp import GradScaler as _GS, autocast as _AC
|
||||
_GradScaler = lambda device="", enabled=True, **kw: _GS(**kw)
|
||||
_autocast = lambda device_type="", enabled=True, **kw: _AC(enabled=enabled, **kw)
|
||||
|
||||
|
||||
def _save_samples(generator_ema, samples_dir: Path, epoch: int, *, latent_dim: int, device) -> None:
|
||||
samples_dir.mkdir(parents=True, exist_ok=True)
|
||||
with torch.no_grad():
|
||||
noise = torch.randn(16, latent_dim, 1, 1, device=device)
|
||||
imgs = generator_ema.model(noise) # EMA model, [-1, 1]
|
||||
imgs = (imgs.clamp(-1, 1) + 1.0) / 2.0 # -> [0, 1]
|
||||
save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)
|
||||
|
||||
|
||||
def train_dcgan(
|
||||
generator,
|
||||
discriminator,
|
||||
train_dataset,
|
||||
cfg: dict,
|
||||
*,
|
||||
save_dir,
|
||||
run_name: str,
|
||||
device: str = "cuda",
|
||||
) -> dict:
|
||||
"""Vanilla DCGAN training loop with BCE loss (Radford et al., 2015).
|
||||
|
||||
Used as the Phase 1 baseline for cheap pipeline ablations. No gradient
|
||||
penalty, no n_critic, single G/D step per batch.
|
||||
"""
|
||||
device = torch.device(device if torch.cuda.is_available() else "cpu")
|
||||
generator = generator.to(device)
|
||||
discriminator = discriminator.to(device)
|
||||
|
||||
n_g = sum(p.numel() for p in generator.parameters() if p.requires_grad)
|
||||
n_d = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
|
||||
print(f"Generator: {n_g:,} params Discriminator: {n_d:,} params")
|
||||
|
||||
epochs = cfg["epochs"]
|
||||
batch_size = cfg["batch_size"]
|
||||
lr_g = cfg.get("lr_g", 2e-4)
|
||||
lr_d = cfg.get("lr_d", 2e-4)
|
||||
beta1 = cfg.get("beta1", 0.5)
|
||||
beta2 = cfg.get("beta2", 0.999)
|
||||
latent_dim = cfg.get("latent_dim", 100)
|
||||
ema_decay = cfg.get("ema_decay", 0.9999)
|
||||
sample_interval = cfg.get("sample_interval", 10)
|
||||
fid_interval = cfg.get("fid_interval", 25)
|
||||
fid_n_real = cfg.get("fid_n_real", 5000)
|
||||
|
||||
loader = DataLoader(
|
||||
train_dataset, batch_size=batch_size, shuffle=True,
|
||||
num_workers=min(4, os.cpu_count() or 1),
|
||||
pin_memory=(device.type == "cuda"), drop_last=True,
|
||||
)
|
||||
opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
|
||||
opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr_d, betas=(beta1, beta2))
|
||||
bce = nn.BCEWithLogitsLoss()
|
||||
|
||||
use_amp = device.type == "cuda"
|
||||
scaler_g = _GradScaler("cuda", enabled=use_amp)
|
||||
scaler_d = _GradScaler("cuda", enabled=use_amp)
|
||||
|
||||
ema = EMA(generator, decay=ema_decay)
|
||||
|
||||
save_dir = Path(save_dir)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
samples_dir = save_dir.parent / "samples" / run_name
|
||||
|
||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
|
||||
|
||||
history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}}
|
||||
best_fid = float("inf")
|
||||
print(f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}")
|
||||
|
||||
for epoch in range(1, epochs + 1):
|
||||
generator.train()
|
||||
discriminator.train()
|
||||
g_sum = d_sum = real_sum = fake_sum = 0.0
|
||||
n_batches = 0
|
||||
|
||||
for imgs in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
|
||||
imgs = imgs.to(device)
|
||||
bsz = imgs.size(0)
|
||||
real_labels = torch.ones(bsz, device=device)
|
||||
fake_labels = torch.zeros(bsz, device=device)
|
||||
|
||||
# ── Discriminator step ────────────────────────────────────────
|
||||
noise = torch.randn(bsz, latent_dim, 1, 1, device=device)
|
||||
with _autocast("cuda", enabled=use_amp):
|
||||
fake = generator(noise).detach()
|
||||
d_real = discriminator(imgs)
|
||||
d_fake = discriminator(fake)
|
||||
d_loss = bce(d_real, real_labels) + bce(d_fake, fake_labels)
|
||||
opt_d.zero_grad()
|
||||
scaler_d.scale(d_loss).backward()
|
||||
scaler_d.step(opt_d)
|
||||
scaler_d.update()
|
||||
|
||||
# ── Generator step ────────────────────────────────────────────
|
||||
noise = torch.randn(bsz, latent_dim, 1, 1, device=device)
|
||||
with _autocast("cuda", enabled=use_amp):
|
||||
g_loss = bce(discriminator(generator(noise)), real_labels)
|
||||
opt_g.zero_grad()
|
||||
scaler_g.scale(g_loss).backward()
|
||||
scaler_g.step(opt_g)
|
||||
scaler_g.update()
|
||||
ema.update(generator)
|
||||
|
||||
g_sum += g_loss.item()
|
||||
d_sum += d_loss.item()
|
||||
real_sum += d_real.mean().item()
|
||||
fake_sum += d_fake.mean().item()
|
||||
n_batches += 1
|
||||
|
||||
avg_g = g_sum / n_batches
|
||||
avg_d = d_sum / n_batches
|
||||
avg_r = real_sum / n_batches
|
||||
avg_f = fake_sum / n_batches
|
||||
history["g_loss"].append(avg_g)
|
||||
history["d_loss"].append(avg_d)
|
||||
history["d_real"].append(avg_r)
|
||||
history["d_fake"].append(avg_f)
|
||||
print(
|
||||
f"[{epoch:03d}/{epochs}] "
|
||||
f"G: {avg_g:.4f} D: {avg_d:.4f} D(real): {avg_r:.4f} D(fake): {avg_f:.4f}"
|
||||
)
|
||||
|
||||
if epoch % sample_interval == 0:
|
||||
_save_samples(ema, samples_dir, epoch, latent_dim=latent_dim, device=device)
|
||||
|
||||
if epoch % fid_interval == 0:
|
||||
generator.eval()
|
||||
with torch.no_grad():
|
||||
fake_imgs = torch.cat([
|
||||
generator(torch.randn(64, latent_dim, 1, 1, device=device))
|
||||
for _ in range(fid_n_real // 64 + 1)
|
||||
])[:fid_n_real]
|
||||
fid_score = fid_eval.compute(fake_imgs)
|
||||
history["fid"][epoch] = fid_score
|
||||
print(f" FID @ epoch {epoch}: {fid_score:.2f}")
|
||||
|
||||
if fid_score < best_fid:
|
||||
best_fid = fid_score
|
||||
torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
|
||||
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
|
||||
|
||||
torch.save(generator.state_dict(), save_dir / f"{run_name}_final_g.pt")
|
||||
torch.save(discriminator.state_dict(), save_dir / f"{run_name}_final_d.pt")
|
||||
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")
|
||||
return history
|
||||
@@ -0,0 +1,3 @@
|
||||
from src.utils.config import load_config
|
||||
|
||||
__all__ = ["load_config"]
|
||||
@@ -0,0 +1,58 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
# Resolves the extends chain first, then overlays shared.json underneath so
|
||||
# experiment-level keys always win over shared defaults.
|
||||
def load_config(config_path: str, shared_path: Optional[str] = None) -> Dict[str, Any]:
|
||||
config_path = Path(config_path)
|
||||
cfg = _load_extends(config_path)
|
||||
|
||||
if shared_path is None:
|
||||
shared_path = config_path.parent.parent / "shared.json"
|
||||
else:
|
||||
shared_path = Path(shared_path)
|
||||
|
||||
if shared_path.exists():
|
||||
with open(shared_path) as f:
|
||||
shared_cfg = json.load(f)
|
||||
cfg = _deep_merge(shared_cfg, cfg)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
# Pops the "extends" key and recursively merges the parent config underneath;
|
||||
# the seen set catches circular inheritance before it recurses infinitely.
|
||||
def _load_extends(config_path: Path, seen: Optional[set[Path]] = None) -> Dict[str, Any]:
|
||||
if seen is None:
|
||||
seen = set()
|
||||
resolved_path = config_path.resolve()
|
||||
if resolved_path in seen:
|
||||
chain = " -> ".join(str(p) for p in [*seen, resolved_path])
|
||||
raise ValueError(f"Circular config inheritance detected: {chain}")
|
||||
seen.add(resolved_path)
|
||||
|
||||
with open(config_path) as f:
|
||||
cfg = json.load(f)
|
||||
|
||||
base_ref = cfg.pop("extends", None)
|
||||
if not base_ref:
|
||||
seen.remove(resolved_path)
|
||||
return cfg
|
||||
|
||||
base_path = (config_path.parent / base_ref).resolve()
|
||||
base_cfg = _load_extends(base_path, seen=seen)
|
||||
seen.remove(resolved_path)
|
||||
return _deep_merge(base_cfg, cfg)
|
||||
|
||||
|
||||
# Override always wins; nested dicts are merged recursively rather than replaced.
|
||||
def _deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
|
||||
result = base.copy()
|
||||
for key, value in override.items():
|
||||
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
|
||||
result[key] = _deep_merge(result[key], value)
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
Reference in New Issue
Block a user