Clean state

commit bb3dfb92d5
Author: Johnny Fernandes
Date:   2026-04-30 01:25:39 +01:00

266 changed files with 37043 additions and 0 deletions
@@ -0,0 +1,13 @@
{
"model": "dcgan",
"image_size": 64,
"latent_dim": 100,
"ngf": 64,
"ndf": 64,
"epochs": 50,
"lr_g": 2e-4,
"lr_d": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"augment": false
}
@@ -0,0 +1,5 @@
{
"extends": "_base_dcgan.json",
"run_name": "p1a_dcgan_128",
"image_size": 128
}
@@ -0,0 +1,5 @@
{
"extends": "_base_dcgan.json",
"run_name": "p1a_dcgan_64",
"image_size": 64
}
@@ -0,0 +1,5 @@
{
"extends": "_base_dcgan.json",
"run_name": "p1b_dcgan_aligned",
"data_dir": "cropped/generator"
}
@@ -0,0 +1,5 @@
{
"extends": "_base_dcgan.json",
"run_name": "p1b_dcgan_full",
"data_dir": "data"
}
@@ -0,0 +1,6 @@
{
"extends": "_base_dcgan.json",
"run_name": "p1c_dcgan_full_aug",
"data_dir": "cropped/generator",
"augment": true
}
@@ -0,0 +1,6 @@
{
"extends": "_base_dcgan.json",
"run_name": "p1c_dcgan_hflip",
"data_dir": "cropped/generator",
"augment": false
}
@@ -0,0 +1,5 @@
{
"extends": "_base_dcgan.json",
"run_name": "p1d_dcgan_combined",
"data_dir": ["data", "cropped/generator"]
}
@@ -0,0 +1,10 @@
{
"batch_size": 64,
"ema_decay": 0.9999,
"data_dir": "data",
"sources": ["wiki"],
"subsample": 1.0,
"sample_interval": 10,
"fid_interval": 25,
"fid_n_real": 5000
}
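For orientation: an experiment config is resolved as experiment > base (via "extends") > shared.json (see src/utils/config.py below). A sketch of the effective config for p1a_dcgan_128, written as a Python dict:

    {
        "model": "dcgan", "run_name": "p1a_dcgan_128", "image_size": 128,
        "latent_dim": 100, "ngf": 64, "ndf": 64, "epochs": 50,
        "lr_g": 2e-4, "lr_d": 2e-4, "beta1": 0.5, "beta2": 0.999,
        "augment": False,
        "batch_size": 64, "ema_decay": 0.9999, "data_dir": "data",
        "sources": ["wiki"], "subsample": 1.0, "sample_interval": 10,
        "fid_interval": 25, "fid_n_real": 5000,
    }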
@@ -0,0 +1,81 @@
"""
Train a generative model from a config file.
Usage:
python run.py <config.json>
python run.py <config.json> --data-dir /path/to/data --output-root generator/outputs
"""
import argparse
import json
import sys
import warnings
from pathlib import Path
# Allow running from project root (python3 generator/run.py ...) or from inside generator/
_here = Path(__file__).resolve().parent
if str(_here) not in sys.path:
sys.path.insert(0, str(_here))
warnings.filterwarnings("ignore", message="Corrupt EXIF data", category=UserWarning)
def parse_args(argv=None):
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("config_path")
parser.add_argument("--data-dir", default=None)
parser.add_argument("--output-root", default="generator/outputs")
parser.add_argument("--use-gpu", action="store_true", help="Accepted for pipeline compatibility (GPU auto-detected).")
return parser.parse_args(argv)
def main(config_path, *, data_dir_override=None, output_root="generator/outputs"):
import torch
from src.data import GeneratorDataset, get_transform
from src.models import get_model
from src.training import train_dcgan
from src.utils import load_config
cfg = load_config(config_path)
run_name = cfg["run_name"]
device = "cuda" if torch.cuda.is_available() else "cpu"
data_dir = data_dir_override or cfg.get("data_dir", "data")
output_root = Path(output_root)
models_dir = output_root / "models"
logs_dir = output_root / "logs"
print(f"Run: {run_name}")
print(f"Config: {cfg}")
print(f"Device: {device} Data: {data_dir}")
model, kind = get_model(cfg)
augment = cfg.get("augment", True)
transform = get_transform(cfg.get("image_size", 128), augment=augment)
dataset = GeneratorDataset(
data_dir,
sources=cfg.get("sources", ["wiki"]),
subsample=cfg.get("subsample", 1.0),
transform=transform,
)
print(f"Dataset size: {len(dataset)}")
if kind == "dcgan":
generator, discriminator = model
history = train_dcgan(
generator, discriminator, dataset, cfg,
save_dir=models_dir, run_name=run_name, device=device,
)
else:
raise NotImplementedError(f"kind={kind!r} not yet implemented in this phase")
logs_dir.mkdir(parents=True, exist_ok=True)
out = logs_dir / f"{run_name}.json"
with open(out, "w") as f:
json.dump({"run_name": run_name, "config": cfg, "history": history}, f, indent=2)
print(f"\nSaved log to {out}")
if __name__ == "__main__":
args = parse_args(sys.argv[1:])
main(args.config_path, data_dir_override=args.data_dir, output_root=args.output_root)
@@ -0,0 +1,3 @@
from src.data.dataset import GeneratorDataset, get_transform
__all__ = ["GeneratorDataset", "get_transform"]
@@ -0,0 +1,74 @@
import random
from pathlib import Path
from PIL import Image
import torchvision.transforms as T
from torch.utils.data import Dataset
class GeneratorDataset(Dataset):
"""Unlabeled image dataset for generative model training.
Loads images from source subdirectories and returns tensors only —
no labels, since generation is unsupervised.
"""
def __init__(self, data_dir, sources=None, subsample=1.0, transform=None, seed=42):
self.transform = transform
self.samples = []
# Accept either a single root or a list of roots (used by 1D to mix
# raw + aligned crops in one dataset).
roots = [data_dir] if isinstance(data_dir, (str, Path)) else list(data_dir)
if sources is None:
sources = ["wiki"]
for root in roots:
root = Path(root)
if not root.exists():
raise FileNotFoundError(f"Dataset root not found: {root}")
for source in sources:
source_dir = root / source
if not source_dir.exists():
raise FileNotFoundError(f"Missing source directory: {source_dir}")
for subdir in sorted(source_dir.iterdir()):
if subdir.is_dir():
for img_path in sorted(subdir.glob("*.jpg")):
self.samples.append(img_path)
if subsample < 1.0:
rng = random.Random(seed)
n = max(1, int(len(self.samples) * subsample))
self.samples = rng.sample(self.samples, n)
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
img = Image.open(self.samples[idx]).convert("RGB")
if self.transform:
img = self.transform(img)
return img
def get_transform(image_size: int, augment: bool = False) -> T.Compose:
"""Build transform for generator training. Output is in [-1, 1].
augment=True adds horizontal flip + mild rotation + mild color jitter.
Use augment=False for validation / FID real-image sets.
"""
ops = [
T.Resize(image_size),
T.CenterCrop(image_size),
]
if augment:
ops += [
T.RandomHorizontalFlip(p=0.5),
T.RandomRotation(degrees=5, interpolation=T.InterpolationMode.BILINEAR),
T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05),
]
ops += [
T.ToTensor(),
T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), # -> [-1, 1]
]
return T.Compose(ops)
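A quick contract check for the transform (a sketch; the dummy image size is arbitrary):

    from PIL import Image

    tf = get_transform(128, augment=False)
    img = Image.new("RGB", (200, 160))             # dummy input
    x = tf(img)
    assert x.shape == (3, 128, 128)                # resized + center-cropped
    assert x.min() >= -1.0 and x.max() <= 1.0      # normalized to [-1, 1]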
@@ -0,0 +1,26 @@
from typing import Callable
import torch.nn as nn
_REGISTRY: dict[str, tuple[Callable, str]] = {}
def register(name: str, builder: Callable, *, kind: str) -> None:
_REGISTRY[name] = (builder, kind)
def get_model(cfg: dict) -> tuple:
"""Return (model_or_pair, kind).
kind="dcgan" -> (generator, discriminator)
"""
name = cfg.get("model")
entry = _REGISTRY.get(name)
if entry is None:
available = ", ".join(sorted(_REGISTRY))
raise ValueError(f"Unknown model: {name!r}. Available: {available}")
builder, kind = entry
return builder(cfg), kind
# Importing the model modules executes their register() calls.
from src.models import dcgan, wgan  # noqa: E402, F401
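A usage sketch for the registry (toy builder, hypothetical name):

    def _build_toy(cfg: dict):
        return nn.Identity(), nn.Identity()

    register("toy", _build_toy, kind="dcgan")
    (g, d), kind = get_model({"model": "toy"})
    assert kind == "dcgan"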
@@ -0,0 +1,115 @@
"""Vanilla DCGAN (Radford et al., 2015).
Used as the Phase 1 baseline for cheap pipeline ablations. Architecture is
intentionally minimal — BatchNorm in both networks, no spectral norm, no
attention, no gradient penalty. The whole point is to be the cheapest GAN
we can run, so Phase 1A–1D pipeline deltas show up in FID quickly.
Depth scales with image_size: each step doubles the spatial dimension,
starting from 4×4 after the first transposed conv.
64 -> 5 layers (1 -> 4 -> 8 -> 16 -> 32 -> 64)
128 -> 6 layers (1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128)
"""
import math
import torch
import torch.nn as nn
from src.models import register
def _init_weights(m):
classname = m.__class__.__name__
if "Conv" in classname:
nn.init.normal_(m.weight, 0.0, 0.02)
elif "BatchNorm" in classname:
nn.init.normal_(m.weight, 1.0, 0.02)
nn.init.zeros_(m.bias)
def _n_upsamples(image_size: int) -> int:
"""Number of 2x upsampling steps from 4x4 to image_size."""
if image_size < 8 or image_size & (image_size - 1):
raise ValueError(f"image_size must be a power of two ≥ 8, got {image_size}")
return int(math.log2(image_size)) - 2 # 64 -> 4, 128 -> 5
class DCGANGenerator(nn.Module):
"""Maps (latent_dim x 1 x 1) -> (3 x image_size x image_size) in [-1, 1]."""
def __init__(self, latent_dim: int = 100, ngf: int = 64, image_size: int = 64):
super().__init__()
n_up = _n_upsamples(image_size) # 64 -> 4 upsamples after the 1->4 init
max_mult = 2 ** (n_up - 1) # channel multiplier at the 4x4 stage
layers: list[nn.Module] = [
# 1x1 -> 4x4
nn.ConvTranspose2d(latent_dim, ngf * max_mult, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * max_mult),
nn.ReLU(inplace=True),
]
# Each step halves the channel multiplier and doubles spatial size.
mult = max_mult
for _ in range(n_up - 1):
layers += [
nn.ConvTranspose2d(ngf * mult, ngf * mult // 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * mult // 2),
nn.ReLU(inplace=True),
]
mult //= 2
# Final layer to 3 channels, no BN, Tanh.
layers += [
nn.ConvTranspose2d(ngf * mult, 3, 4, 2, 1, bias=False),
nn.Tanh(),
]
self.net = nn.Sequential(*layers)
self.apply(_init_weights)
def forward(self, z: torch.Tensor) -> torch.Tensor:
return self.net(z)
class DCGANDiscriminator(nn.Module):
"""Maps (3 x image_size x image_size) -> scalar logit (no sigmoid)."""
def __init__(self, ndf: int = 64, image_size: int = 64):
super().__init__()
n_down = _n_upsamples(image_size)
layers: list[nn.Module] = [
# First layer: no BN
nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
]
mult = 1
for _ in range(n_down - 1):
layers += [
nn.Conv2d(ndf * mult, ndf * mult * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * mult * 2),
nn.LeakyReLU(0.2, inplace=True),
]
mult *= 2
# 4x4 -> 1x1, scalar logit
layers += [nn.Conv2d(ndf * mult, 1, 4, 1, 0, bias=False)]
self.net = nn.Sequential(*layers)
self.apply(_init_weights)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x).view(x.size(0))
def _build(cfg: dict):
image_size = cfg.get("image_size", 64)
return (
DCGANGenerator(
latent_dim=cfg.get("latent_dim", 100),
ngf=cfg.get("ngf", 64),
image_size=image_size,
),
DCGANDiscriminator(
ndf=cfg.get("ndf", 64),
image_size=image_size,
),
)
register("dcgan", _build, kind="dcgan")
@@ -0,0 +1,133 @@
"""WGAN-GP with spectral normalization, self-attention, and GroupNorm.
Improvements over the original:
- Generator: BatchNorm -> GroupNorm (no batch-size coupling, stable with varied content)
- Critic: InstanceNorm -> spectral normalization (principled Lipschitz constraint)
- Both: one SAGAN-style self-attention block at the 16x16 feature map
- Larger capacity: ngf=128, ndf=128
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.models import register
def _init_weights(m):
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
nn.init.normal_(m.weight, 0.0, 0.02)
elif isinstance(m, nn.GroupNorm) and m.weight is not None:
nn.init.normal_(m.weight, 1.0, 0.02)
nn.init.zeros_(m.bias)
class SelfAttention(nn.Module):
def __init__(self, in_ch: int):
super().__init__()
mid = max(in_ch // 8, 1)
self.q = nn.Conv2d(in_ch, mid, 1, bias=False)
self.k = nn.Conv2d(in_ch, mid, 1, bias=False)
self.v = nn.Conv2d(in_ch, in_ch, 1, bias=False)
self.gamma = nn.Parameter(torch.zeros(1))
self._mid = mid
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, c, h, w = x.shape
q = self.q(x).view(b, self._mid, -1).transpose(-2, -1) # (b, hw, mid)
k = self.k(x).view(b, self._mid, -1) # (b, mid, hw)
v = self.v(x).view(b, c, -1) # (b, c, hw)
attn = torch.softmax(q @ k * self._mid ** -0.5, dim=-1) # (b, hw, hw)
out = (v @ attn.transpose(-2, -1)).view(b, c, h, w)
return x + self.gamma * out
def _sn(module):
"""Apply spectral normalization to a conv layer."""
return nn.utils.spectral_norm(module)
class WGANGenerator(nn.Module):
"""Maps (latent_dim x 1 x 1) -> (3 x 128 x 128) in [-1, 1].
Upsampling path: 1 -> 4 -> 8 -> 16 (+attn) -> 32 -> 64 -> 128
Self-attention sits at 16x16 (attention matrix 256x256 vs 1024x1024 at 32x32).
"""
def __init__(self, latent_dim: int = 128, ngf: int = 64):
super().__init__()
self.net = nn.Sequential(
# 1x1 -> 4x4
nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False),
nn.GroupNorm(8, ngf * 8), nn.ReLU(True),
# 4x4 -> 8x8
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.GroupNorm(8, ngf * 4), nn.ReLU(True),
# 8x8 -> 16x16
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.GroupNorm(8, ngf * 2), nn.ReLU(True),
)
self.attn = SelfAttention(ngf * 2) # applied at 16x16
self.out = nn.Sequential(
# 16x16 -> 32x32
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
nn.GroupNorm(8, ngf), nn.ReLU(True),
# 32x32 -> 64x64
nn.ConvTranspose2d(ngf, ngf // 2, 4, 2, 1, bias=False),
nn.GroupNorm(8, ngf // 2), nn.ReLU(True),
# 64x64 -> 128x128
nn.ConvTranspose2d(ngf // 2, 3, 4, 2, 1, bias=False),
nn.Tanh(),
)
self.apply(_init_weights)
def forward(self, z: torch.Tensor) -> torch.Tensor:
h = self.net(z)
h = self.attn(h)
return self.out(h)
class WGANCritic(nn.Module):
"""Critic (no sigmoid) for WGAN-GP. All conv layers are spectrally normalized.
Downsampling path: 128 -> 64 -> 32 -> 16 (+attn) -> 8 -> 4 -> score
"""
def __init__(self, ndf: int = 64):
super().__init__()
self.down = nn.Sequential(
# 128x128 -> 64x64 (no norm on first layer)
_sn(nn.Conv2d(3, ndf // 2, 4, 2, 1, bias=False)),
nn.LeakyReLU(0.2, True),
# 64x64 -> 32x32
_sn(nn.Conv2d(ndf // 2, ndf, 4, 2, 1, bias=False)),
nn.LeakyReLU(0.2, True),
# 32x32 -> 16x16
_sn(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False)),
nn.LeakyReLU(0.2, True),
)
self.attn = SelfAttention(ndf * 2) # applied at 16x16
self.tail = nn.Sequential(
# 16x16 -> 8x8
_sn(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False)),
nn.LeakyReLU(0.2, True),
# 8x8 -> 4x4
_sn(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False)),
nn.LeakyReLU(0.2, True),
# 4x4 -> 1x1
_sn(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False)),
)
self.apply(_init_weights)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.down(x)
h = self.attn(h)
return self.tail(h).view(x.size(0))
def _build(cfg: dict):
return (
WGANGenerator(latent_dim=cfg.get("latent_dim", 128), ngf=cfg.get("ngf", 128)),
WGANCritic(ndf=cfg.get("ndf", 128)),
)
register("wgan", _build, kind="wgan")
@@ -0,0 +1,3 @@
from src.training.trainer import train_dcgan
__all__ = ["train_dcgan"]
@@ -0,0 +1,22 @@
import copy
import torch
import torch.nn as nn
class EMA:
"""Exponential moving average of model weights.
Maintains a shadow copy of the model. Call update() after each
optimizer step. Sample from ema.model, never from the training model.
"""
def __init__(self, model: nn.Module, decay: float = 0.9999):
self.decay = decay
self.model = copy.deepcopy(model).eval()
for p in self.model.parameters():
p.requires_grad_(False)
@torch.no_grad()
    def update(self, model: nn.Module) -> None:
        for p_ema, p in zip(self.model.parameters(), model.parameters()):
            p_ema.data.mul_(self.decay).add_(p.data, alpha=1.0 - self.decay)
        # BatchNorm running stats live in buffers, not parameters; copy them
        # across too, so the shadow model doesn't sample with stale statistics.
        for b_ema, b in zip(self.model.buffers(), model.buffers()):
            b_ema.copy_(b)
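Usage sketch (names hypothetical):

    g = nn.Linear(4, 4)                    # stand-in for the generator
    ema = EMA(g, decay=0.999)
    opt = torch.optim.SGD(g.parameters(), lr=0.1)
    for _ in range(10):
        loss = g(torch.randn(8, 4)).pow(2).mean()
        opt.zero_grad(); loss.backward(); opt.step()
        ema.update(g)                      # after every optimizer step
    weights = ema.model.state_dict()       # sample/evaluate from the shadow copy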
@@ -0,0 +1,49 @@
"""FID evaluation helper.
Computes Fréchet Inception Distance between a fixed set of real images
and a batch of generated images. Real images are stored as a tensor on CPU
and moved to the device only during evaluation — this avoids re-reading from
disk on every call while keeping GPU memory free between evaluations.
"""
import torch
from torch.utils.data import DataLoader
from torchmetrics.image.fid import FrechetInceptionDistance
class FIDEvaluator:
def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda"):
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
self.n_real = n_real
# Cache real images as a CPU tensor ([-1, 1] range)
imgs_list = []
loader = DataLoader(real_dataset, batch_size=256, shuffle=False,
num_workers=4, drop_last=False)
for batch in loader:
imgs_list.append(batch.cpu())
if sum(x.size(0) for x in imgs_list) >= n_real:
break
real = torch.cat(imgs_list)[:n_real]
self._real = real # stored on CPU, shape (N, 3, H, W) in [-1, 1]
@torch.no_grad()
def compute(self, fake_imgs: torch.Tensor) -> float:
"""Compute FID score.
fake_imgs: float tensor in [-1, 1], shape (N, 3, H, W).
N should be at least 2048 for a reliable score.
"""
fid = FrechetInceptionDistance(feature=2048, normalize=True).to(self.device)
# Feed real images in batches
for i in range(0, self._real.size(0), 256):
batch = (self._real[i:i + 256] * 0.5 + 0.5).clamp(0, 1).to(self.device)
fid.update(batch, real=True)
# Feed fake images in batches
fake = fake_imgs.cpu()
for i in range(0, fake.size(0), 256):
batch = (fake[i:i + 256] * 0.5 + 0.5).clamp(0, 1).to(self.device)
fid.update(batch, real=False)
return float(fid.compute())
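A usage sketch wired to the rest of the pipeline (dataset, ema, and latent_dim assumed from the surrounding files):

    evaluator = FIDEvaluator(dataset, n_real=5000, device="cuda")
    with torch.no_grad():
        fakes = torch.cat([
            ema.model(torch.randn(100, latent_dim, 1, 1, device="cuda"))
            for _ in range(50)
        ])
    score = evaluator.compute(fakes)       # lower is better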
@@ -0,0 +1,166 @@
import os
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm
from src.training.ema import EMA
from src.training.fid import FIDEvaluator
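# Version-compat note (assumption: torch >= 2.3 exposes torch.amp.GradScaler /
# torch.amp.autocast with a device argument; older builds only have the
# torch.cuda.amp variants). The fallback lambdas below accept and discard the
# positional device string so both call sites can use the new-style signature.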
if hasattr(torch.amp, "GradScaler"):
_GradScaler = torch.amp.GradScaler
_autocast = torch.amp.autocast
else:
from torch.cuda.amp import GradScaler as _GS, autocast as _AC
    _GradScaler = lambda device="", enabled=True, **kw: _GS(enabled=enabled, **kw)
_autocast = lambda device_type="", enabled=True, **kw: _AC(enabled=enabled, **kw)
def _save_samples(generator_ema, samples_dir: Path, epoch: int, *, latent_dim: int, device) -> None:
samples_dir.mkdir(parents=True, exist_ok=True)
with torch.no_grad():
noise = torch.randn(16, latent_dim, 1, 1, device=device)
imgs = generator_ema.model(noise) # EMA model, [-1, 1]
imgs = (imgs.clamp(-1, 1) + 1.0) / 2.0 # -> [0, 1]
save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)
def train_dcgan(
generator,
discriminator,
train_dataset,
cfg: dict,
*,
save_dir,
run_name: str,
device: str = "cuda",
) -> dict:
"""Vanilla DCGAN training loop with BCE loss (Radford et al., 2015).
Used as the Phase 1 baseline for cheap pipeline ablations. No gradient
penalty, no n_critic, single G/D step per batch.
"""
device = torch.device(device if torch.cuda.is_available() else "cpu")
generator = generator.to(device)
discriminator = discriminator.to(device)
n_g = sum(p.numel() for p in generator.parameters() if p.requires_grad)
n_d = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
print(f"Generator: {n_g:,} params Discriminator: {n_d:,} params")
epochs = cfg["epochs"]
batch_size = cfg["batch_size"]
lr_g = cfg.get("lr_g", 2e-4)
lr_d = cfg.get("lr_d", 2e-4)
beta1 = cfg.get("beta1", 0.5)
beta2 = cfg.get("beta2", 0.999)
latent_dim = cfg.get("latent_dim", 100)
ema_decay = cfg.get("ema_decay", 0.9999)
sample_interval = cfg.get("sample_interval", 10)
fid_interval = cfg.get("fid_interval", 25)
fid_n_real = cfg.get("fid_n_real", 5000)
loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
num_workers=min(4, os.cpu_count() or 1),
pin_memory=(device.type == "cuda"), drop_last=True,
)
opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr_d, betas=(beta1, beta2))
bce = nn.BCEWithLogitsLoss()
use_amp = device.type == "cuda"
scaler_g = _GradScaler("cuda", enabled=use_amp)
scaler_d = _GradScaler("cuda", enabled=use_amp)
ema = EMA(generator, decay=ema_decay)
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
samples_dir = save_dir.parent / "samples" / run_name
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}}
best_fid = float("inf")
print(f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}")
for epoch in range(1, epochs + 1):
generator.train()
discriminator.train()
g_sum = d_sum = real_sum = fake_sum = 0.0
n_batches = 0
for imgs in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
imgs = imgs.to(device)
bsz = imgs.size(0)
real_labels = torch.ones(bsz, device=device)
fake_labels = torch.zeros(bsz, device=device)
# ── Discriminator step ────────────────────────────────────────
noise = torch.randn(bsz, latent_dim, 1, 1, device=device)
with _autocast("cuda", enabled=use_amp):
fake = generator(noise).detach()
d_real = discriminator(imgs)
d_fake = discriminator(fake)
d_loss = bce(d_real, real_labels) + bce(d_fake, fake_labels)
opt_d.zero_grad()
scaler_d.scale(d_loss).backward()
scaler_d.step(opt_d)
scaler_d.update()
# ── Generator step ────────────────────────────────────────────
noise = torch.randn(bsz, latent_dim, 1, 1, device=device)
with _autocast("cuda", enabled=use_amp):
g_loss = bce(discriminator(generator(noise)), real_labels)
opt_g.zero_grad()
scaler_g.scale(g_loss).backward()
scaler_g.step(opt_g)
scaler_g.update()
ema.update(generator)
g_sum += g_loss.item()
d_sum += d_loss.item()
real_sum += d_real.mean().item()
fake_sum += d_fake.mean().item()
n_batches += 1
avg_g = g_sum / n_batches
avg_d = d_sum / n_batches
avg_r = real_sum / n_batches
avg_f = fake_sum / n_batches
history["g_loss"].append(avg_g)
history["d_loss"].append(avg_d)
history["d_real"].append(avg_r)
history["d_fake"].append(avg_f)
print(
f"[{epoch:03d}/{epochs}] "
f"G: {avg_g:.4f} D: {avg_d:.4f} D(real): {avg_r:.4f} D(fake): {avg_f:.4f}"
)
if epoch % sample_interval == 0:
_save_samples(ema, samples_dir, epoch, latent_dim=latent_dim, device=device)
if epoch % fid_interval == 0:
generator.eval()
with torch.no_grad():
fake_imgs = torch.cat([
generator(torch.randn(64, latent_dim, 1, 1, device=device))
for _ in range(fid_n_real // 64 + 1)
])[:fid_n_real]
fid_score = fid_eval.compute(fake_imgs)
history["fid"][epoch] = fid_score
print(f" FID @ epoch {epoch}: {fid_score:.2f}")
if fid_score < best_fid:
best_fid = fid_score
torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")
torch.save(generator.state_dict(), save_dir / f"{run_name}_final_g.pt")
torch.save(discriminator.state_dict(), save_dir / f"{run_name}_final_d.pt")
torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")
return history
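For reference, the per-batch objectives above form the standard non-saturating BCE pair, with the sigmoid applied inside BCEWithLogitsLoss:

    L_D = -E_x[log σ(D(x))] - E_z[log(1 - σ(D(G(z))))]
    L_G = -E_z[log σ(D(G(z)))]

The generator step labels fakes as real instead of minimizing log(1 - σ(D(G(z)))), which keeps gradients alive early in training when the discriminator wins easily.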
@@ -0,0 +1,3 @@
from src.utils.config import load_config
__all__ = ["load_config"]
@@ -0,0 +1,58 @@
import json
from pathlib import Path
from typing import Any, Dict, Optional
# Resolves the extends chain first, then overlays shared.json underneath so
# experiment-level keys always win over shared defaults.
def load_config(config_path: str, shared_path: Optional[str] = None) -> Dict[str, Any]:
config_path = Path(config_path)
cfg = _load_extends(config_path)
if shared_path is None:
shared_path = config_path.parent.parent / "shared.json"
else:
shared_path = Path(shared_path)
if shared_path.exists():
with open(shared_path) as f:
shared_cfg = json.load(f)
cfg = _deep_merge(shared_cfg, cfg)
return cfg
# Pops the "extends" key and recursively merges the parent config underneath;
# the seen set catches circular inheritance before it recurses infinitely.
def _load_extends(config_path: Path, seen: Optional[set[Path]] = None) -> Dict[str, Any]:
if seen is None:
seen = set()
resolved_path = config_path.resolve()
if resolved_path in seen:
chain = " -> ".join(str(p) for p in [*seen, resolved_path])
raise ValueError(f"Circular config inheritance detected: {chain}")
seen.add(resolved_path)
with open(config_path) as f:
cfg = json.load(f)
base_ref = cfg.pop("extends", None)
if not base_ref:
seen.remove(resolved_path)
return cfg
base_path = (config_path.parent / base_ref).resolve()
base_cfg = _load_extends(base_path, seen=seen)
seen.remove(resolved_path)
return _deep_merge(base_cfg, cfg)
# Override always wins; nested dicts are merged recursively rather than replaced.
def _deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
result = base.copy()
for key, value in override.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
result[key] = _deep_merge(result[key], value)
else:
result[key] = value
return result
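A worked example of the resolution order, using the phase-1 configs above (the paths are an assumption about the repo layout):

    cfg = load_config("generator/configs/phase1/p1a_dcgan_128.json")
    assert cfg["image_size"] == 128    # experiment overrides the base's 64
    assert cfg["latent_dim"] == 100    # inherited via extends from _base_dcgan.json
    assert cfg["batch_size"] == 64     # overlaid from shared.json
    assert "extends" not in cfg        # consumed during resolution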
@@ -0,0 +1,268 @@
#!/usr/bin/env python3
"""
Pre-align face images using MTCNN landmarks + similarity transform.
This is the generator-side counterpart to classifier/tools/facecrop.py.
Difference: classifier uses bbox crop+resize; the generator wants landmark-based
alignment so the eyes, nose, and mouth land at fixed pixel positions in every
training image — structurally consistent training data for the generator.
Output mirrors the source layout exactly:
data/wiki/14/37591914.jpg -> cropped/generator/wiki/14/37591914.jpg
Resumable: already-aligned images are skipped by default.
Usage:
python generator/tools/facecrop.py
python generator/tools/facecrop.py --data-dir data --output-dir cropped/generator
python generator/tools/facecrop.py --sources wiki --device cpu
python generator/tools/facecrop.py --size 128
python generator/tools/facecrop.py --no-skip-existing # reprocess everything
"""
import argparse
import sys
import warnings
from pathlib import Path
warnings.filterwarnings("ignore", message=".*weights_only.*", category=FutureWarning)
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))
# Generator trains on real images only (wiki). The other sources are AI-generated
# and aren't used as training targets for the generator, so we don't align them
# by default. Pass --sources to override.
SOURCES = ["wiki"]
ALL_SOURCES = ["wiki", "inpainting", "text2img", "insight"]
# Reference landmark positions for a 128px aligned face.
# Source: standard FFHQ-style alignment template (eyes at y=51, nose at y=71, mouth at y=95).
# Scaled at runtime to match --size.
REF_LANDMARKS_128 = [
(38.0, 51.0), # left eye
(90.0, 51.0), # right eye
(64.0, 71.0), # nose tip
(45.0, 95.0), # left mouth
(83.0, 95.0), # right mouth
]
_DETECTORS: dict[str, object] = {}
def parse_args():
p = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
p.add_argument("--data-dir", default="data", help="Source dataset root (default: data)")
p.add_argument("--output-dir", default="cropped/generator", help="Output root (default: cropped/generator)")
p.add_argument("--size", type=int, default=128, help="Output image size in px, square (default: 128)")
p.add_argument("--device", default=None, help="'cpu' or 'cuda'. Default: auto-detect")
p.add_argument("--sources", nargs="+", default=None, metavar="SOURCE",
help=f"Sources to process. Default: {', '.join(SOURCES)} (real images only). "
f"All available: {', '.join(ALL_SOURCES)}")
p.add_argument("--skip-existing", dest="skip_existing", action="store_true", default=True,
help="Skip images already present in output-dir (default: on, resumable)")
p.add_argument("--no-skip-existing", dest="skip_existing", action="store_false",
help="Re-process all images even if already aligned")
return p.parse_args()
# ── alignment helpers ─────────────────────────────────────────────────────────
def _ref_landmarks(size: int):
"""Reference landmarks scaled from the 128px template to `size`."""
import numpy as np
scale = size / 128.0
return np.asarray(
[(x * scale, y * scale) for x, y in REF_LANDMARKS_128],
dtype=np.float32,
)
def _align_from_landmarks(img, landmarks, size: int):
"""Apply similarity transform so detected landmarks map to reference positions."""
import numpy as np
from PIL import Image as PILImage
from skimage.transform import SimilarityTransform, warp
src = np.asarray(landmarks, dtype=np.float32) # 5x2 detected
dst = _ref_landmarks(size) # 5x2 reference
    try:
        tform = SimilarityTransform.from_estimate(src, dst)
    except Exception:
        return None
    if not tform:
        # Recent scikit-image returns a falsy FailedEstimation instead of raising.
        return None
aligned = warp(
np.asarray(img),
tform.inverse,
output_shape=(size, size),
order=3,
preserve_range=True,
).astype(np.uint8)
return PILImage.fromarray(aligned)
def _center_crop(img, size: int):
from PIL import Image as PILImage
w, h = img.size
side = min(w, h)
left, top = (w - side) // 2, (h - side) // 2
return img.crop((left, top, left + side, top + side)).resize((size, size), PILImage.BILINEAR)
def _get_detectors(device: str):
"""Return (standard, relaxed) MTCNN detectors with landmarks enabled."""
if device in _DETECTORS:
return _DETECTORS[device]
from facenet_pytorch import MTCNN
standard = MTCNN(
keep_all=False, select_largest=True,
min_face_size=15,
device=device, post_process=False,
)
relaxed = MTCNN(
keep_all=False, select_largest=True,
min_face_size=10,
thresholds=[0.5, 0.6, 0.6],
device=device, post_process=False,
)
_DETECTORS[device] = (standard, relaxed)
return standard, relaxed
# ── main ──────────────────────────────────────────────────────────────────────
def main():
args = parse_args()
import torch
from PIL import Image
from tqdm import tqdm
data_dir = Path(args.data_dir)
output_dir = Path(args.output_dir)
device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
sources = args.sources or SOURCES
if not data_dir.exists():
print(f"Error: data directory not found: {data_dir}")
sys.exit(1)
for src in sources:
if not (data_dir / src).exists():
print(f"Error: source directory not found: {data_dir / src}")
sys.exit(1)
try:
import facenet_pytorch # noqa: F401
import skimage # noqa: F401
except ImportError as exc:
print(f"Error: missing dependency ({exc}).")
print(" Run: pip install facenet-pytorch scikit-image")
sys.exit(1)
print(f"Data dir: {data_dir.resolve()}")
print(f"Output dir: {output_dir.resolve()}")
print(f"Sources: {', '.join(sources)}")
print(f"Device: {device}")
print(f"Size: {args.size}px")
print(f"Skip exist: {args.skip_existing}")
standard, relaxed = _get_detectors(device)
all_paths: list[Path] = []
for src in sources:
for subdir in sorted((data_dir / src).iterdir()):
if subdir.is_dir():
all_paths.extend(sorted(subdir.glob("*.jpg")))
print(f"\nTotal images: {len(all_paths):,}\n")
n_processed = n_skipped = n_error = 0
src_stats: dict[str, dict] = {
s: {"aligned": 0, "retry": 0, "fallback": 0} for s in sources
}
for img_path in tqdm(all_paths, desc="Aligning", unit="img"):
rel = img_path.relative_to(data_dir)
out_path = output_dir / rel
src_name = img_path.parent.parent.name
if args.skip_existing and out_path.exists():
n_skipped += 1
continue
out_path.parent.mkdir(parents=True, exist_ok=True)
try:
img = Image.open(img_path).convert("RGB")
except Exception as exc:
tqdm.write(f"[WARN] Cannot open {img_path.name}: {exc}")
n_error += 1
continue
aligned = None
try:
# Pass 1: detect landmarks on the original image
_, _, landmarks = standard.detect(img, landmarks=True)
if landmarks is not None and len(landmarks) > 0:
aligned = _align_from_landmarks(img, landmarks[0], args.size)
if aligned is not None:
src_stats[src_name]["aligned"] += 1
# Pass 2: upscale 2x and retry with relaxed thresholds
if aligned is None:
w, h = img.size
img2x = img.resize((w * 2, h * 2), Image.BILINEAR)
_, _, landmarks2 = relaxed.detect(img2x, landmarks=True)
if landmarks2 is not None and len(landmarks2) > 0:
lm_orig = [(x / 2, y / 2) for x, y in landmarks2[0]]
aligned = _align_from_landmarks(img, lm_orig, args.size)
if aligned is not None:
src_stats[src_name]["retry"] += 1
if aligned is None:
aligned = _center_crop(img, args.size)
src_stats[src_name]["fallback"] += 1
except Exception as exc:
tqdm.write(f"[WARN] Detection failed for {img_path.name}: {exc}")
aligned = _center_crop(img, args.size)
src_stats[src_name]["fallback"] += 1
aligned.save(out_path, quality=95)
n_processed += 1
    total = n_processed + n_skipped + n_error
n_aligned = sum(s["aligned"] for s in src_stats.values())
n_retry = sum(s["retry"] for s in src_stats.values())
n_fallback = sum(s["fallback"] for s in src_stats.values())
denom = max(n_processed, 1)
print(f"\n{'' * 55}")
print(f" Total images : {total:>8,}")
print(f" Processed : {n_processed:>8,}")
print(f" Skipped (existed) : {n_skipped:>8,}")
print(f" Errors : {n_error:>8,}")
print(f" Pass-1 aligned : {n_aligned:>8,} ({n_aligned / denom:.1%})")
print(f" Pass-2 aligned : {n_retry:>8,} ({n_retry / denom:.1%}) ← 2x upscale retry")
print(f" Centre fallback : {n_fallback:>8,} ({n_fallback / denom:.1%})")
print()
print(f" {'Source':<12} {'pass-1':>8} {'pass-2':>8} {'fallback':>8} {'fallback%':>10}")
print(f" {''*12} {''*8} {''*8} {''*8} {''*10}")
for src in sources:
s = src_stats[src]
total_src = s["aligned"] + s["retry"] + s["fallback"]
fb_pct = s["fallback"] / max(total_src, 1)
print(f" {src:<12} {s['aligned']:>8,} {s['retry']:>8,} {s['fallback']:>8,} {fb_pct:>9.1%}")
print(f"{'' * 55}")
print(f" Output: {output_dir.resolve()}")
print()
print("Next step — point your config at the aligned dataset:")
print(f' "data_dir": "{output_dir}"')
if __name__ == "__main__":
main()
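A self-contained sketch of the core alignment step on synthetic landmarks (uses the same from_estimate call as _align_from_landmarks above; the shift is arbitrary):

    import numpy as np
    from skimage.transform import SimilarityTransform

    det = np.asarray(REF_LANDMARKS_128, dtype=np.float32) + [10.0, 5.0]  # shifted template
    ref = _ref_landmarks(128)
    tform = SimilarityTransform.from_estimate(det, ref)
    assert np.allclose(tform(det), ref, atol=1e-3)   # detected -> reference space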