Clean state
@@ -0,0 +1,13 @@
{
  "model": "dcgan",
  "image_size": 64,
  "latent_dim": 100,
  "ngf": 64,
  "ndf": 64,
  "epochs": 50,
  "lr_g": 2e-4,
  "lr_d": 2e-4,
  "beta1": 0.5,
  "beta2": 0.999,
  "augment": false
}
@@ -0,0 +1,5 @@
{
  "extends": "_base_dcgan.json",
  "run_name": "p1a_dcgan_128",
  "image_size": 128
}
@@ -0,0 +1,5 @@
{
  "extends": "_base_dcgan.json",
  "run_name": "p1a_dcgan_64",
  "image_size": 64
}
@@ -0,0 +1,5 @@
{
  "extends": "_base_dcgan.json",
  "run_name": "p1b_dcgan_aligned",
  "data_dir": "cropped/generator"
}
@@ -0,0 +1,5 @@
{
  "extends": "_base_dcgan.json",
  "run_name": "p1b_dcgan_full",
  "data_dir": "data"
}
@@ -0,0 +1,6 @@
{
  "extends": "_base_dcgan.json",
  "run_name": "p1c_dcgan_full_aug",
  "data_dir": "cropped/generator",
  "augment": true
}
@@ -0,0 +1,6 @@
{
  "extends": "_base_dcgan.json",
  "run_name": "p1c_dcgan_hflip",
  "data_dir": "cropped/generator",
  "augment": false
}
@@ -0,0 +1,5 @@
{
  "extends": "_base_dcgan.json",
  "run_name": "p1d_dcgan_combined",
  "data_dir": ["data", "cropped/generator"]
}
@@ -0,0 +1,10 @@
{
  "batch_size": 64,
  "ema_decay": 0.9999,
  "data_dir": "data",
  "sources": ["wiki"],
  "subsample": 1.0,
  "sample_interval": 10,
  "fid_interval": 25,
  "fid_n_real": 5000
}
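Worked example of how these files compose (a sketch; the config path is hypothetical, and load_config is defined in the src/utils hunk below). The extends chain is resolved first, then these shared defaults are overlaid underneath, so experiment-level keys always win:

from src.utils import load_config

cfg = load_config("configs/phase1/p1a_dcgan_128.json")  # hypothetical location, shared.json one level up
assert cfg["image_size"] == 128   # experiment file wins over the base's 64
assert cfg["epochs"] == 50        # inherited from _base_dcgan.json
assert cfg["batch_size"] == 64    # filled in from shared.json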
@@ -0,0 +1,81 @@
"""
Train a generative model from a config file.

Usage:
    python run.py <config.json>
    python run.py <config.json> --data-dir /path/to/data --output-root generator/outputs
"""
import argparse
import json
import sys
import warnings
from pathlib import Path

# Allow running from project root (python3 generator/run.py ...) or from inside generator/
_here = Path(__file__).resolve().parent
if str(_here) not in sys.path:
    sys.path.insert(0, str(_here))

warnings.filterwarnings("ignore", message="Corrupt EXIF data", category=UserWarning)


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("config_path")
    parser.add_argument("--data-dir", default=None)
    parser.add_argument("--output-root", default="generator/outputs")
    parser.add_argument("--use-gpu", action="store_true", help="Accepted for pipeline compatibility (GPU auto-detected).")
    return parser.parse_args(argv)


def main(config_path, *, data_dir_override=None, output_root="generator/outputs"):
    import torch
    from src.data import GeneratorDataset, get_transform
    from src.models import get_model
    from src.training import train_dcgan
    from src.utils import load_config

    cfg = load_config(config_path)

    run_name = cfg["run_name"]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data_dir = data_dir_override or cfg.get("data_dir", "data")
    output_root = Path(output_root)
    models_dir = output_root / "models"
    logs_dir = output_root / "logs"

    print(f"Run: {run_name}")
    print(f"Config: {cfg}")
    print(f"Device: {device}  Data: {data_dir}")

    model, kind = get_model(cfg)

    augment = cfg.get("augment", True)
    transform = get_transform(cfg.get("image_size", 128), augment=augment)
    dataset = GeneratorDataset(
        data_dir,
        sources=cfg.get("sources", ["wiki"]),
        subsample=cfg.get("subsample", 1.0),
        transform=transform,
    )
    print(f"Dataset size: {len(dataset)}")

    if kind == "dcgan":
        generator, discriminator = model
        history = train_dcgan(
            generator, discriminator, dataset, cfg,
            save_dir=models_dir, run_name=run_name, device=device,
        )
    else:
        raise NotImplementedError(f"kind={kind!r} not yet implemented in this phase")

    logs_dir.mkdir(parents=True, exist_ok=True)
    out = logs_dir / f"{run_name}.json"
    with open(out, "w") as f:
        json.dump({"run_name": run_name, "config": cfg, "history": history}, f, indent=2)
    print(f"\nSaved log to {out}")


if __name__ == "__main__":
    args = parse_args(sys.argv[1:])
    main(args.config_path, data_dir_override=args.data_dir, output_root=args.output_root)
@@ -0,0 +1,3 @@
from src.data.dataset import GeneratorDataset, get_transform

__all__ = ["GeneratorDataset", "get_transform"]
@@ -0,0 +1,74 @@
import random
from pathlib import Path

from PIL import Image
import torchvision.transforms as T
from torch.utils.data import Dataset


class GeneratorDataset(Dataset):
    """Unlabeled image dataset for generative model training.

    Loads images from source subdirectories and returns tensors only —
    no labels, since generation is unsupervised.
    """

    def __init__(self, data_dir, sources=None, subsample=1.0, transform=None, seed=42):
        self.transform = transform
        self.samples = []

        # Accept either a single root or a list of roots (used by 1D to mix
        # raw + aligned crops in one dataset).
        roots = [data_dir] if isinstance(data_dir, (str, Path)) else list(data_dir)
        if sources is None:
            sources = ["wiki"]

        for root in roots:
            root = Path(root)
            if not root.exists():
                raise FileNotFoundError(f"Dataset root not found: {root}")
            for source in sources:
                source_dir = root / source
                if not source_dir.exists():
                    raise FileNotFoundError(f"Missing source directory: {source_dir}")
                for subdir in sorted(source_dir.iterdir()):
                    if subdir.is_dir():
                        for img_path in sorted(subdir.glob("*.jpg")):
                            self.samples.append(img_path)

        if subsample < 1.0:
            rng = random.Random(seed)
            n = max(1, int(len(self.samples) * subsample))
            self.samples = rng.sample(self.samples, n)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img = Image.open(self.samples[idx]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img


def get_transform(image_size: int, augment: bool = False) -> T.Compose:
    """Build transform for generator training. Output is in [-1, 1].

    augment=True adds horizontal flip + mild rotation + mild color jitter.
    Use augment=False for validation / FID real-image sets.
    """
    ops = [
        T.Resize(image_size),
        T.CenterCrop(image_size),
    ]
    if augment:
        ops += [
            T.RandomHorizontalFlip(p=0.5),
            T.RandomRotation(degrees=5, interpolation=T.InterpolationMode.BILINEAR),
            T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05),
        ]
    ops += [
        T.ToTensor(),
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # -> [-1, 1]
    ]
    return T.Compose(ops)
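Quick sanity check for the transform contract (a sketch, not part of the commit): Resize scales the short edge, CenterCrop squares it, and Normalize((0.5,)*3, (0.5,)*3) maps ToTensor's [0, 1] output to [-1, 1]:

# Assumes get_transform from the file above is importable.
from PIL import Image

t = get_transform(64, augment=False)
x = t(Image.new("RGB", (96, 80), color=(255, 0, 0)))
assert x.shape == (3, 64, 64)
assert x.min().item() == -1.0 and x.max().item() == 1.0  # pure red hits both extremes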
@@ -0,0 +1,26 @@
from typing import Callable
import torch.nn as nn

_REGISTRY: dict[str, tuple[Callable, str]] = {}


def register(name: str, builder: Callable, *, kind: str) -> None:
    _REGISTRY[name] = (builder, kind)


def get_model(cfg: dict) -> tuple:
    """Return (model_or_pair, kind).

    kind="dcgan" -> (generator, discriminator)
    """
    name = cfg.get("model")
    entry = _REGISTRY.get(name)
    if entry is None:
        available = ", ".join(sorted(_REGISTRY))
        raise ValueError(f"Unknown model: {name!r}. Available: {available}")
    builder, kind = entry
    return builder(cfg), kind


from src.models import dcgan  # noqa: E402, F401
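New architectures plug in by import side effect; only dcgan is imported here, and wgan.py below registers itself the same way when imported. A hedged sketch with hypothetical toy names showing the contract:

import torch.nn as nn

class _ToyGen(nn.Module):  # hypothetical placeholder, not part of this commit
    def forward(self, z):
        return z

def _build_toy(cfg: dict):
    return _ToyGen(), _ToyGen()

register("toygan", _build_toy, kind="toygan")
assert get_model({"model": "toygan"})[1] == "toygan"
# Unknown names raise ValueError listing everything registered.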
@@ -0,0 +1,115 @@
"""Vanilla DCGAN (Radford et al., 2015).

Used as the Phase 1 baseline for cheap pipeline ablations. Architecture is
intentionally minimal — BatchNorm in both networks, no spectral norm, no
attention, no gradient penalty. The whole point is to be the cheapest GAN
we can run, so 1A–1D pipeline deltas show up in FID quickly.

Depth scales with image_size: each step doubles the spatial dimension,
starting from 4×4 after the first transposed conv.
    64  -> 5 layers (1 -> 4 -> 8 -> 16 -> 32 -> 64)
    128 -> 6 layers (1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128)
"""
import math

import torch
import torch.nn as nn

from src.models import register


def _init_weights(m):
    classname = m.__class__.__name__
    if "Conv" in classname:
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)


def _n_upsamples(image_size: int) -> int:
    """Number of 2x upsampling steps from 4x4 to image_size."""
    if image_size < 8 or image_size & (image_size - 1):
        raise ValueError(f"image_size must be a power of two ≥ 8, got {image_size}")
    return int(math.log2(image_size)) - 2  # 64 -> 4, 128 -> 5


class DCGANGenerator(nn.Module):
    """Maps (latent_dim x 1 x 1) -> (3 x image_size x image_size) in [-1, 1]."""

    def __init__(self, latent_dim: int = 100, ngf: int = 64, image_size: int = 64):
        super().__init__()
        n_up = _n_upsamples(image_size)  # 64 -> 4 upsamples after the 1->4 init
        max_mult = 2 ** (n_up - 1)       # channel multiplier at the 4x4 stage

        layers: list[nn.Module] = [
            # 1x1 -> 4x4
            nn.ConvTranspose2d(latent_dim, ngf * max_mult, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * max_mult),
            nn.ReLU(inplace=True),
        ]
        # Each step halves the channel multiplier and doubles spatial size.
        mult = max_mult
        for _ in range(n_up - 1):
            layers += [
                nn.ConvTranspose2d(ngf * mult, ngf * mult // 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * mult // 2),
                nn.ReLU(inplace=True),
            ]
            mult //= 2
        # Final layer to 3 channels, no BN, Tanh.
        layers += [
            nn.ConvTranspose2d(ngf * mult, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)
        self.apply(_init_weights)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.net(z)


class DCGANDiscriminator(nn.Module):
    """Maps (3 x image_size x image_size) -> scalar logit (no sigmoid)."""

    def __init__(self, ndf: int = 64, image_size: int = 64):
        super().__init__()
        n_down = _n_upsamples(image_size)
        layers: list[nn.Module] = [
            # First layer: no BN
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        mult = 1
        for _ in range(n_down - 1):
            layers += [
                nn.Conv2d(ndf * mult, ndf * mult * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * mult * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            mult *= 2
        # 4x4 -> 1x1, scalar logit
        layers += [nn.Conv2d(ndf * mult, 1, 4, 1, 0, bias=False)]
        self.net = nn.Sequential(*layers)
        self.apply(_init_weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x).view(x.size(0))


def _build(cfg: dict):
    image_size = cfg.get("image_size", 64)
    return (
        DCGANGenerator(
            latent_dim=cfg.get("latent_dim", 100),
            ngf=cfg.get("ngf", 64),
            image_size=image_size,
        ),
        DCGANDiscriminator(
            ndf=cfg.get("ndf", 64),
            image_size=image_size,
        ),
    )


register("dcgan", _build, kind="dcgan")
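A minimal shape check against this file's _build (a sketch, not in the commit): the depth logic lines up for 64px, the generator ends in Tanh, and the discriminator emits one raw logit per image:

import torch

g, d = _build({"image_size": 64})
z = torch.randn(2, 100, 1, 1)
imgs = g(z)
assert imgs.shape == (2, 3, 64, 64) and imgs.abs().max() <= 1.0
assert d(imgs).shape == (2,)  # no sigmoid; BCEWithLogitsLoss expects raw logits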
@@ -0,0 +1,133 @@
"""WGAN-GP with spectral normalization, self-attention, and GroupNorm.

Improvements over the original:
- Generator: BatchNorm -> GroupNorm (no batch-size coupling, stable with varied content)
- Critic: InstanceNorm -> spectral normalization (principled Lipschitz constraint)
- Both: one SAGAN-style self-attention block at the 16x16 feature map
- Larger capacity: ngf=128, ndf=128
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.models import register


def _init_weights(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.GroupNorm) and m.weight is not None:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)


class SelfAttention(nn.Module):
    def __init__(self, in_ch: int):
        super().__init__()
        mid = max(in_ch // 8, 1)
        self.q = nn.Conv2d(in_ch, mid, 1, bias=False)
        self.k = nn.Conv2d(in_ch, mid, 1, bias=False)
        self.v = nn.Conv2d(in_ch, in_ch, 1, bias=False)
        self.gamma = nn.Parameter(torch.zeros(1))
        self._mid = mid

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        q = self.q(x).view(b, self._mid, -1).transpose(-2, -1)   # (b, hw, mid)
        k = self.k(x).view(b, self._mid, -1)                     # (b, mid, hw)
        v = self.v(x).view(b, c, -1)                             # (b, c, hw)
        attn = torch.softmax(q @ k * self._mid ** -0.5, dim=-1)  # (b, hw, hw)
        out = (v @ attn.transpose(-2, -1)).view(b, c, h, w)
        return x + self.gamma * out


def _sn(module):
    """Apply spectral normalization to a conv layer."""
    return nn.utils.spectral_norm(module)


class WGANGenerator(nn.Module):
    """Maps (latent_dim x 1 x 1) -> (3 x 128 x 128) in [-1, 1].

    Upsampling path: 1 -> 4 -> 8 -> 16 (+attn) -> 32 -> 64 -> 128
    Self-attention sits at 16x16 (attention matrix 256x256 vs 1024x1024 at 32x32).
    """

    def __init__(self, latent_dim: int = 128, ngf: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            # 1x1 -> 4x4
            nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False),
            nn.GroupNorm(8, ngf * 8), nn.ReLU(True),
            # 4x4 -> 8x8
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.GroupNorm(8, ngf * 4), nn.ReLU(True),
            # 8x8 -> 16x16
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.GroupNorm(8, ngf * 2), nn.ReLU(True),
        )
        self.attn = SelfAttention(ngf * 2)  # applied at 16x16
        self.out = nn.Sequential(
            # 16x16 -> 32x32
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.GroupNorm(8, ngf), nn.ReLU(True),
            # 32x32 -> 64x64
            nn.ConvTranspose2d(ngf, ngf // 2, 4, 2, 1, bias=False),
            nn.GroupNorm(8, ngf // 2), nn.ReLU(True),
            # 64x64 -> 128x128
            nn.ConvTranspose2d(ngf // 2, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        )
        self.apply(_init_weights)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        h = self.net(z)
        h = self.attn(h)
        return self.out(h)


class WGANCritic(nn.Module):
    """Critic (no sigmoid) for WGAN-GP. All conv layers are spectrally normalized.

    Downsampling path: 128 -> 64 -> 32 -> 16 (+attn) -> 8 -> 4 -> score
    """

    def __init__(self, ndf: int = 64):
        super().__init__()
        self.down = nn.Sequential(
            # 128x128 -> 64x64 (no norm on first layer)
            _sn(nn.Conv2d(3, ndf // 2, 4, 2, 1, bias=False)),
            nn.LeakyReLU(0.2, True),
            # 64x64 -> 32x32
            _sn(nn.Conv2d(ndf // 2, ndf, 4, 2, 1, bias=False)),
            nn.LeakyReLU(0.2, True),
            # 32x32 -> 16x16
            _sn(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False)),
            nn.LeakyReLU(0.2, True),
        )
        self.attn = SelfAttention(ndf * 2)  # applied at 16x16
        self.tail = nn.Sequential(
            # 16x16 -> 8x8
            _sn(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False)),
            nn.LeakyReLU(0.2, True),
            # 8x8 -> 4x4
            _sn(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False)),
            nn.LeakyReLU(0.2, True),
            # 4x4 -> 1x1
            _sn(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False)),
        )
        self.apply(_init_weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.down(x)
        h = self.attn(h)
        return self.tail(h).view(x.size(0))


def _build(cfg: dict):
    return (
        WGANGenerator(latent_dim=cfg.get("latent_dim", 128), ngf=cfg.get("ngf", 128)),
        WGANCritic(ndf=cfg.get("ndf", 128)),
    )


register("wgan", _build, kind="wgan")
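Same sanity sketch for this file's _build pair, assuming the builder defaults (latent_dim=128, ngf=ndf=128); unlike the DCGAN, output here is fixed at 128px rather than scaling with image_size:

import torch

g, c = _build({})
z = torch.randn(2, 128, 1, 1)
x = g(z)
assert x.shape == (2, 3, 128, 128)
assert c(x).shape == (2,)  # critic score, no sigmoid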
@@ -0,0 +1,3 @@
from src.training.trainer import train_dcgan

__all__ = ["train_dcgan"]
@@ -0,0 +1,26 @@
import copy
import torch
import torch.nn as nn


class EMA:
    """Exponential moving average of model weights.

    Maintains a shadow copy of the model. Call update() after each
    optimizer step. Sample from ema.model, never from the training model.
    """

    def __init__(self, model: nn.Module, decay: float = 0.9999):
        self.decay = decay
        self.model = copy.deepcopy(model).eval()
        for p in self.model.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: nn.Module) -> None:
        for p_ema, p in zip(self.model.parameters(), model.parameters()):
            p_ema.data.mul_(self.decay).add_(p.data, alpha=1.0 - self.decay)
        # Buffers (e.g. BatchNorm running stats) are copied rather than averaged,
        # so the shadow model stays usable in eval mode.
        for b_ema, b in zip(self.model.buffers(), model.buffers()):
            b_ema.copy_(b)
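Usage sketch with a hypothetical toy model (not in the commit): the update rule is p_ema <- decay * p_ema + (1 - decay) * p after every optimizer step:

import torch
import torch.nn as nn

net = nn.Linear(4, 4)
ema = EMA(net, decay=0.99)
opt = torch.optim.SGD(net.parameters(), lr=0.1)
for _ in range(10):
    loss = net(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema.update(net)
# ema.model now lags the raw weights; sample/evaluate from it, per the docstring.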
@@ -0,0 +1,49 @@
"""FID evaluation helper.

Computes Fréchet Inception Distance between a fixed set of real images
and a batch of generated images. Real images are stored as a tensor on CPU
and moved to device only during evaluation — this avoids re-reading disk
every call while keeping GPU memory free between evaluations.
"""
import torch
from torch.utils.data import DataLoader
from torchmetrics.image.fid import FrechetInceptionDistance


class FIDEvaluator:
    def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda"):
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.n_real = n_real

        # Cache real images as a CPU tensor ([-1, 1] range)
        imgs_list = []
        loader = DataLoader(real_dataset, batch_size=256, shuffle=False,
                            num_workers=4, drop_last=False)
        for batch in loader:
            imgs_list.append(batch.cpu())
            if sum(x.size(0) for x in imgs_list) >= n_real:
                break
        real = torch.cat(imgs_list)[:n_real]
        self._real = real  # stored on CPU, shape (N, 3, H, W) in [-1, 1]

    @torch.no_grad()
    def compute(self, fake_imgs: torch.Tensor) -> float:
        """Compute FID score.

        fake_imgs: float tensor in [-1, 1], shape (N, 3, H, W).
        N should be at least 2048 for a reliable score.
        """
        fid = FrechetInceptionDistance(feature=2048, normalize=True).to(self.device)

        # Feed real images in batches
        for i in range(0, self._real.size(0), 256):
            batch = (self._real[i:i + 256] * 0.5 + 0.5).clamp(0, 1).to(self.device)
            fid.update(batch, real=True)

        # Feed fake images in batches
        fake = fake_imgs.cpu()
        for i in range(0, fake.size(0), 256):
            batch = (fake[i:i + 256] * 0.5 + 0.5).clamp(0, 1).to(self.device)
            fid.update(batch, real=False)

        return float(fid.compute())
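Usage sketch with hypothetical names (dataset, ema): torchmetrics' normalize=True expects floats in [0, 1], which is why compute() rescales both sides from [-1, 1] before updating the metric:

import torch

evaluator = FIDEvaluator(dataset, n_real=5000, device="cuda")
with torch.no_grad():
    fake = torch.cat([
        ema.model(torch.randn(64, 100, 1, 1, device="cuda"))
        for _ in range(79)          # 79 * 64 = 5056 >= 5000
    ])[:5000]
print(evaluator.compute(fake))      # lower is better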
@@ -0,0 +1,166 @@
import os
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm

from src.training.ema import EMA
from src.training.fid import FIDEvaluator

if hasattr(torch.amp, "GradScaler"):
    _GradScaler = torch.amp.GradScaler
    _autocast = torch.amp.autocast
else:
    from torch.cuda.amp import GradScaler as _GS, autocast as _AC
    _GradScaler = lambda device="", enabled=True, **kw: _GS(enabled=enabled, **kw)
    _autocast = lambda device_type="", enabled=True, **kw: _AC(enabled=enabled, **kw)


def _save_samples(generator_ema, samples_dir: Path, epoch: int, *, latent_dim: int, device) -> None:
    samples_dir.mkdir(parents=True, exist_ok=True)
    with torch.no_grad():
        noise = torch.randn(16, latent_dim, 1, 1, device=device)
        imgs = generator_ema.model(noise)        # EMA model, [-1, 1]
        imgs = (imgs.clamp(-1, 1) + 1.0) / 2.0   # -> [0, 1]
        save_image(imgs, samples_dir / f"epoch_{epoch:04d}.png", nrow=4)


def train_dcgan(
    generator,
    discriminator,
    train_dataset,
    cfg: dict,
    *,
    save_dir,
    run_name: str,
    device: str = "cuda",
) -> dict:
    """Vanilla DCGAN training loop with BCE loss (Radford et al., 2015).

    Used as the Phase 1 baseline for cheap pipeline ablations. No gradient
    penalty, no n_critic, single G/D step per batch.
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    generator = generator.to(device)
    discriminator = discriminator.to(device)

    n_g = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    n_d = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
    print(f"Generator: {n_g:,} params  Discriminator: {n_d:,} params")

    epochs = cfg["epochs"]
    batch_size = cfg["batch_size"]
    lr_g = cfg.get("lr_g", 2e-4)
    lr_d = cfg.get("lr_d", 2e-4)
    beta1 = cfg.get("beta1", 0.5)
    beta2 = cfg.get("beta2", 0.999)
    latent_dim = cfg.get("latent_dim", 100)
    ema_decay = cfg.get("ema_decay", 0.9999)
    sample_interval = cfg.get("sample_interval", 10)
    fid_interval = cfg.get("fid_interval", 25)
    fid_n_real = cfg.get("fid_n_real", 5000)

    loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=min(4, os.cpu_count() or 1),
        pin_memory=(device.type == "cuda"), drop_last=True,
    )
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr_d, betas=(beta1, beta2))
    bce = nn.BCEWithLogitsLoss()

    use_amp = device.type == "cuda"
    scaler_g = _GradScaler("cuda", enabled=use_amp)
    scaler_d = _GradScaler("cuda", enabled=use_amp)

    ema = EMA(generator, decay=ema_decay)

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    samples_dir = save_dir.parent / "samples" / run_name

    fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))

    history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}}
    best_fid = float("inf")
    print(f"Device: {device}  AMP: {use_amp}  Batches/epoch: {len(loader)}")

    for epoch in range(1, epochs + 1):
        generator.train()
        discriminator.train()
        g_sum = d_sum = real_sum = fake_sum = 0.0
        n_batches = 0

        for imgs in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            imgs = imgs.to(device)
            bsz = imgs.size(0)
            real_labels = torch.ones(bsz, device=device)
            fake_labels = torch.zeros(bsz, device=device)

            # ── Discriminator step ────────────────────────────────────────
            noise = torch.randn(bsz, latent_dim, 1, 1, device=device)
            with _autocast("cuda", enabled=use_amp):
                fake = generator(noise).detach()
                d_real = discriminator(imgs)
                d_fake = discriminator(fake)
                d_loss = bce(d_real, real_labels) + bce(d_fake, fake_labels)
            opt_d.zero_grad()
            scaler_d.scale(d_loss).backward()
            scaler_d.step(opt_d)
            scaler_d.update()

            # ── Generator step ────────────────────────────────────────────
            noise = torch.randn(bsz, latent_dim, 1, 1, device=device)
            with _autocast("cuda", enabled=use_amp):
                g_loss = bce(discriminator(generator(noise)), real_labels)
            opt_g.zero_grad()
            scaler_g.scale(g_loss).backward()
            scaler_g.step(opt_g)
            scaler_g.update()
            ema.update(generator)

            g_sum += g_loss.item()
            d_sum += d_loss.item()
            real_sum += d_real.mean().item()
            fake_sum += d_fake.mean().item()
            n_batches += 1

        avg_g = g_sum / n_batches
        avg_d = d_sum / n_batches
        avg_r = real_sum / n_batches
        avg_f = fake_sum / n_batches
        history["g_loss"].append(avg_g)
        history["d_loss"].append(avg_d)
        history["d_real"].append(avg_r)
        history["d_fake"].append(avg_f)
        print(
            f"[{epoch:03d}/{epochs}] "
            f"G: {avg_g:.4f}  D: {avg_d:.4f}  D(real): {avg_r:.4f}  D(fake): {avg_f:.4f}"
        )

        if epoch % sample_interval == 0:
            _save_samples(ema, samples_dir, epoch, latent_dim=latent_dim, device=device)

        if epoch % fid_interval == 0:
            generator.eval()
            with torch.no_grad():
                fake_imgs = torch.cat([
                    generator(torch.randn(64, latent_dim, 1, 1, device=device))
                    for _ in range(fid_n_real // 64 + 1)
                ])[:fid_n_real]
            fid_score = fid_eval.compute(fake_imgs)
            history["fid"][epoch] = fid_score
            print(f"  FID @ epoch {epoch}: {fid_score:.2f}")

            if fid_score < best_fid:
                best_fid = fid_score
                torch.save(generator.state_dict(), save_dir / f"{run_name}_best_g.pt")
                torch.save(ema.model.state_dict(), save_dir / f"{run_name}_best_ema.pt")

    torch.save(generator.state_dict(), save_dir / f"{run_name}_final_g.pt")
    torch.save(discriminator.state_dict(), save_dir / f"{run_name}_final_d.pt")
    torch.save(ema.model.state_dict(), save_dir / f"{run_name}_final_ema.pt")
    return history
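For reference, the history dict returned here is what run.py dumps as the run log; a sketch with made-up numbers:

# {
#   "g_loss": [3.12, ...], "d_loss": [0.84, ...],
#   "d_real": [1.91, ...], "d_fake": [-1.75, ...],
#   "fid":    {"25": 96.4, "50": 71.2}
# }
# Note: json.dump turns the int epoch keys of history["fid"] into strings.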
@@ -0,0 +1,3 @@
from src.utils.config import load_config

__all__ = ["load_config"]
@@ -0,0 +1,58 @@
import json
from pathlib import Path
from typing import Any, Dict, Optional


# Resolves the extends chain first, then overlays shared.json underneath so
# experiment-level keys always win over shared defaults.
def load_config(config_path: str, shared_path: Optional[str] = None) -> Dict[str, Any]:
    config_path = Path(config_path)
    cfg = _load_extends(config_path)

    if shared_path is None:
        shared_path = config_path.parent.parent / "shared.json"
    else:
        shared_path = Path(shared_path)

    if shared_path.exists():
        with open(shared_path) as f:
            shared_cfg = json.load(f)
        cfg = _deep_merge(shared_cfg, cfg)

    return cfg


# Pops the "extends" key and recursively merges the parent config underneath;
# the seen set catches circular inheritance before it recurses infinitely.
def _load_extends(config_path: Path, seen: Optional[set[Path]] = None) -> Dict[str, Any]:
    if seen is None:
        seen = set()
    resolved_path = config_path.resolve()
    if resolved_path in seen:
        chain = " -> ".join(str(p) for p in [*seen, resolved_path])
        raise ValueError(f"Circular config inheritance detected: {chain}")
    seen.add(resolved_path)

    with open(config_path) as f:
        cfg = json.load(f)

    base_ref = cfg.pop("extends", None)
    if not base_ref:
        seen.remove(resolved_path)
        return cfg

    base_path = (config_path.parent / base_ref).resolve()
    base_cfg = _load_extends(base_path, seen=seen)
    seen.remove(resolved_path)
    return _deep_merge(base_cfg, cfg)


# Override always wins; nested dicts are merged recursively rather than replaced.
def _deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    result = base.copy()
    for key, value in override.items():
        if key in result and isinstance(result[key], dict) and isinstance(value, dict):
            result[key] = _deep_merge(result[key], value)
        else:
            result[key] = value
    return result
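Behavior sketch for _deep_merge (not in the commit): override keys win, and nested dicts merge key-by-key instead of being replaced wholesale:

base = {"lr_g": 2e-4, "opt": {"beta1": 0.5, "beta2": 0.999}}
override = {"opt": {"beta1": 0.0}}
assert _deep_merge(base, override) == {"lr_g": 2e-4, "opt": {"beta1": 0.0, "beta2": 0.999}}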
@@ -0,0 +1,267 @@
#!/usr/bin/env python3
"""
Pre-align face images using MTCNN landmarks + similarity transform.

This is the generator-side counterpart to classifier/tools/facecrop.py.
Difference: classifier uses bbox crop+resize; the generator wants landmark-based
alignment so the eyes, nose, and mouth land at fixed pixel positions in every
training image — structurally consistent training data for the generator.

Output mirrors the source layout exactly:
    data/wiki/14/37591914.jpg -> cropped/generator/wiki/14/37591914.jpg

Resumable: already-aligned images are skipped by default.

Usage:
    python generator/tools/facecrop.py
    python generator/tools/facecrop.py --data-dir data --output-dir cropped/generator
    python generator/tools/facecrop.py --sources wiki --device cpu
    python generator/tools/facecrop.py --size 128
    python generator/tools/facecrop.py --no-skip-existing   # reprocess everything
"""
import argparse
import sys
import warnings
from pathlib import Path

warnings.filterwarnings("ignore", message=".*weights_only.*", category=FutureWarning)

ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))

# Generator trains on real images only (wiki). The other sources are AI-generated
# and aren't used as training targets for the generator, so we don't align them
# by default. Pass --sources to override.
SOURCES = ["wiki"]
ALL_SOURCES = ["wiki", "inpainting", "text2img", "insight"]

# Reference landmark positions for a 128px aligned face.
# Source: standard FFHQ-style alignment template (eyes at y=51, nose at y=71, mouth at y=95).
# Scaled at runtime to match --size.
REF_LANDMARKS_128 = [
    (38.0, 51.0),   # left eye
    (90.0, 51.0),   # right eye
    (64.0, 71.0),   # nose tip
    (45.0, 95.0),   # left mouth
    (83.0, 95.0),   # right mouth
]

_DETECTORS: dict[str, object] = {}


def parse_args():
    p = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--data-dir", default="data", help="Source dataset root (default: data)")
    p.add_argument("--output-dir", default="cropped/generator", help="Output root (default: cropped/generator)")
    p.add_argument("--size", type=int, default=128, help="Output image size in px, square (default: 128)")
    p.add_argument("--device", default=None, help="'cpu' or 'cuda'. Default: auto-detect")
    p.add_argument("--sources", nargs="+", default=None, metavar="SOURCE",
                   help=f"Sources to process. Default: {', '.join(SOURCES)} (real images only). "
                        f"All available: {', '.join(ALL_SOURCES)}")
    p.add_argument("--skip-existing", dest="skip_existing", action="store_true", default=True,
                   help="Skip images already present in output-dir (default: on, resumable)")
    p.add_argument("--no-skip-existing", dest="skip_existing", action="store_false",
                   help="Re-process all images even if already aligned")
    return p.parse_args()


# ── alignment helpers ─────────────────────────────────────────────────────────

def _ref_landmarks(size: int):
    """Reference landmarks scaled from the 128px template to `size`."""
    import numpy as np
    scale = size / 128.0
    return np.asarray(
        [(x * scale, y * scale) for x, y in REF_LANDMARKS_128],
        dtype=np.float32,
    )


def _align_from_landmarks(img, landmarks, size: int):
    """Apply similarity transform so detected landmarks map to reference positions."""
    import numpy as np
    from PIL import Image as PILImage
    from skimage.transform import SimilarityTransform, warp

    src = np.asarray(landmarks, dtype=np.float32)  # 5x2 detected
    dst = _ref_landmarks(size)                     # 5x2 reference

    tform = SimilarityTransform()
    if not tform.estimate(src, dst):  # estimate() returns False on degenerate input
        return None
    aligned = warp(
        np.asarray(img),
        tform.inverse,
        output_shape=(size, size),
        order=3,
        preserve_range=True,
    ).astype(np.uint8)
    return PILImage.fromarray(aligned)


def _center_crop(img, size: int):
    from PIL import Image as PILImage
    w, h = img.size
    side = min(w, h)
    left, top = (w - side) // 2, (h - side) // 2
    return img.crop((left, top, left + side, top + side)).resize((size, size), PILImage.BILINEAR)


def _get_detectors(device: str):
    """Return (standard, relaxed) MTCNN detectors with landmarks enabled."""
    if device in _DETECTORS:
        return _DETECTORS[device]

    from facenet_pytorch import MTCNN

    standard = MTCNN(
        keep_all=False, select_largest=True,
        min_face_size=15,
        device=device, post_process=False,
    )
    relaxed = MTCNN(
        keep_all=False, select_largest=True,
        min_face_size=10,
        thresholds=[0.5, 0.6, 0.6],
        device=device, post_process=False,
    )
    _DETECTORS[device] = (standard, relaxed)
    return standard, relaxed


# ── main ──────────────────────────────────────────────────────────────────────

def main():
    args = parse_args()

    import torch
    from PIL import Image
    from tqdm import tqdm

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    sources = args.sources or SOURCES

    if not data_dir.exists():
        print(f"Error: data directory not found: {data_dir}")
        sys.exit(1)

    for src in sources:
        if not (data_dir / src).exists():
            print(f"Error: source directory not found: {data_dir / src}")
            sys.exit(1)

    try:
        import facenet_pytorch  # noqa: F401
        import skimage  # noqa: F401
    except ImportError as exc:
        print(f"Error: missing dependency ({exc}).")
        print("       Run: pip install facenet-pytorch scikit-image")
        sys.exit(1)

    print(f"Data dir:   {data_dir.resolve()}")
    print(f"Output dir: {output_dir.resolve()}")
    print(f"Sources:    {', '.join(sources)}")
    print(f"Device:     {device}")
    print(f"Size:       {args.size}px")
    print(f"Skip exist: {args.skip_existing}")

    standard, relaxed = _get_detectors(device)

    all_paths: list[Path] = []
    for src in sources:
        for subdir in sorted((data_dir / src).iterdir()):
            if subdir.is_dir():
                all_paths.extend(sorted(subdir.glob("*.jpg")))

    print(f"\nTotal images: {len(all_paths):,}\n")

    n_processed = n_skipped = n_error = 0
    src_stats: dict[str, dict] = {
        s: {"aligned": 0, "retry": 0, "fallback": 0} for s in sources
    }

    for img_path in tqdm(all_paths, desc="Aligning", unit="img"):
        rel = img_path.relative_to(data_dir)
        out_path = output_dir / rel
        src_name = img_path.parent.parent.name

        if args.skip_existing and out_path.exists():
            n_skipped += 1
            continue

        out_path.parent.mkdir(parents=True, exist_ok=True)

        try:
            img = Image.open(img_path).convert("RGB")
        except Exception as exc:
            tqdm.write(f"[WARN] Cannot open {img_path.name}: {exc}")
            n_error += 1
            continue

        aligned = None
        try:
            # Pass 1: detect landmarks on the original image
            _, _, landmarks = standard.detect(img, landmarks=True)
            if landmarks is not None and len(landmarks) > 0:
                aligned = _align_from_landmarks(img, landmarks[0], args.size)
                if aligned is not None:
                    src_stats[src_name]["aligned"] += 1

            # Pass 2: upscale 2x and retry with relaxed thresholds
            if aligned is None:
                w, h = img.size
                img2x = img.resize((w * 2, h * 2), Image.BILINEAR)
                _, _, landmarks2 = relaxed.detect(img2x, landmarks=True)
                if landmarks2 is not None and len(landmarks2) > 0:
                    lm_orig = [(x / 2, y / 2) for x, y in landmarks2[0]]
                    aligned = _align_from_landmarks(img, lm_orig, args.size)
                    if aligned is not None:
                        src_stats[src_name]["retry"] += 1

            if aligned is None:
                aligned = _center_crop(img, args.size)
                src_stats[src_name]["fallback"] += 1
        except Exception as exc:
            tqdm.write(f"[WARN] Detection failed for {img_path.name}: {exc}")
            aligned = _center_crop(img, args.size)
            src_stats[src_name]["fallback"] += 1

        aligned.save(out_path, quality=95)
        n_processed += 1

    total = n_processed + n_skipped
    n_aligned = sum(s["aligned"] for s in src_stats.values())
    n_retry = sum(s["retry"] for s in src_stats.values())
    n_fallback = sum(s["fallback"] for s in src_stats.values())
    denom = max(n_processed, 1)

    print(f"\n{'─' * 55}")
    print(f"  Total images      : {total:>8,}")
    print(f"  Processed         : {n_processed:>8,}")
    print(f"  Skipped (existed) : {n_skipped:>8,}")
    print(f"  Errors            : {n_error:>8,}")
    print(f"  Pass-1 aligned    : {n_aligned:>8,} ({n_aligned / denom:.1%})")
    print(f"  Pass-2 aligned    : {n_retry:>8,} ({n_retry / denom:.1%})  ← 2x upscale retry")
    print(f"  Centre fallback   : {n_fallback:>8,} ({n_fallback / denom:.1%})")
    print()
    print(f"  {'Source':<12} {'pass-1':>8} {'pass-2':>8} {'fallback':>8} {'fallback%':>10}")
    print(f"  {'─'*12} {'─'*8} {'─'*8} {'─'*8} {'─'*10}")
    for src in sources:
        s = src_stats[src]
        total_src = s["aligned"] + s["retry"] + s["fallback"]
        fb_pct = s["fallback"] / max(total_src, 1)
        print(f"  {src:<12} {s['aligned']:>8,} {s['retry']:>8,} {s['fallback']:>8,} {fb_pct:>9.1%}")
    print(f"{'─' * 55}")
    print(f"  Output: {output_dir.resolve()}")
    print()
    print("Next step — point your config at the aligned dataset:")
    print(f'    "data_dir": "{output_dir}"')


if __name__ == "__main__":
    main()
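Sanity sketch for the template scaling (not in the commit): _ref_landmarks scales the 128px template linearly, so at 256px the left eye lands at (76, 102):

lm = _ref_landmarks(256)
assert lm.shape == (5, 2)
assert tuple(lm[0]) == (76.0, 102.0)  # (38, 51) * (256 / 128)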