Implementing inference sampling and quick tool polish

This commit is contained in:
Johnny Fernandes
2026-05-04 21:05:06 +01:00
parent a750d177fa
commit a235dfd5f8
65 changed files with 151 additions and 64 deletions
Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 109 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 108 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 111 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 111 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 108 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 108 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 109 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 106 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 111 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 113 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 111 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 107 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 113 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 113 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 114 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 114 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 114 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 111 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 111 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 96 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 98 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 101 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 97 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 101 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 100 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 101 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 100 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 104 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 96 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 103 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 100 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 97 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 99 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 101 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 101 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 94 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 97 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 100 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 97 KiB

+31 -64
View File
@@ -1,24 +1,7 @@
#!/usr/bin/env python3 # Pre-align face images using MTCNN landmarks + similarity transform.
""" # Generator-side counterpart to classifier/tools/facecrop.py — uses landmark-based alignment
Pre-align face images using MTCNN landmarks + similarity transform. # (not bbox crop) so eyes, nose, and mouth land at fixed pixel positions in every image.
# Usage: python generator/tools/facecrop.py [--data-dir data] [--output-dir cropped/generator] [--size 128]
This is the generator-side counterpart to classifier/tools/facecrop.py.
Difference: classifier uses bbox crop+resize; the generator wants landmark-based
alignment so the eyes, nose, and mouth land at fixed pixel positions in every
training image — structurally consistent training data for the generator.
Output mirrors the source layout exactly:
data/wiki/14/37591914.jpg -> cropped/generator/wiki/14/37591914.jpg
Resumable: already-aligned images are skipped by default.
Usage:
python generator/tools/facecrop.py
python generator/tools/facecrop.py --data-dir data --output-dir cropped/generator
python generator/tools/facecrop.py --sources wiki --device cpu
python generator/tools/facecrop.py --size 128
python generator/tools/facecrop.py --no-skip-existing # reprocess everything
"""
import argparse import argparse
import sys import sys
import warnings import warnings
@@ -32,7 +15,7 @@ sys.path.insert(0, str(ROOT))
# Generator trains on real images only (wiki). The other sources are AI-generated # Generator trains on real images only (wiki). The other sources are AI-generated
# and aren't used as training targets for the generator, so we don't align them # and aren't used as training targets for the generator, so we don't align them
# by default. Pass --sources to override. # by default. Pass --sources to override.
SOURCES = ["wiki"] SOURCES = ["wiki"]
ALL_SOURCES = ["wiki", "inpainting", "text2img", "insight"] ALL_SOURCES = ["wiki", "inpainting", "text2img", "insight"]
# Reference landmark positions for a 128px aligned face. # Reference landmark positions for a 128px aligned face.
@@ -50,44 +33,37 @@ _DETECTORS: dict[str, object] = {}
def parse_args(): def parse_args():
p = argparse.ArgumentParser( p = argparse.ArgumentParser()
description=__doc__, p.add_argument("--data-dir", default="data", help="Source dataset root (default: data)")
formatter_class=argparse.RawDescriptionHelpFormatter,
)
p.add_argument("--data-dir", default="data", help="Source dataset root (default: data)")
p.add_argument("--output-dir", default="cropped/generator", help="Output root (default: cropped/generator)") p.add_argument("--output-dir", default="cropped/generator", help="Output root (default: cropped/generator)")
p.add_argument("--size", type=int, default=128, help="Output image size in px, square (default: 128)") p.add_argument("--size", type=int, default=128, help="Output image size in px, square (default: 128)")
p.add_argument("--device", default=None, help="'cpu' or 'cuda'. Default: auto-detect") p.add_argument("--device", default=None, help="'cpu' or 'cuda'. Default: auto-detect")
p.add_argument("--sources", nargs="+", default=None, metavar="SOURCE", p.add_argument("--sources", nargs="+", default=None, metavar="SOURCE",
help=f"Sources to process. Default: {', '.join(SOURCES)} (real images only). " help=f"Sources to process. Default: {', '.join(SOURCES)}. All: {', '.join(ALL_SOURCES)}")
f"All available: {', '.join(ALL_SOURCES)}")
p.add_argument("--skip-existing", dest="skip_existing", action="store_true", default=True, p.add_argument("--skip-existing", dest="skip_existing", action="store_true", default=True,
help="Skip images already present in output-dir (default: on, resumable)") help="Skip images already present in output-dir (default: on)")
p.add_argument("--no-skip-existing", dest="skip_existing", action="store_false", p.add_argument("--no-skip-existing", dest="skip_existing", action="store_false",
help="Re-process all images even if already aligned") help="Re-process all images even if already aligned")
return p.parse_args() return p.parse_args()
# ── alignment helpers ───────────────────────────────────────────────────────── # ── Alignment helpers ─────────────────────────────────────────────────────────
# Scale the 128px reference template to match the target size
def _ref_landmarks(size: int): def _ref_landmarks(size: int):
"""Reference landmarks scaled from the 128px template to `size`."""
import numpy as np import numpy as np
scale = size / 128.0 scale = size / 128.0
return np.asarray( return np.asarray([(x * scale, y * scale) for x, y in REF_LANDMARKS_128], dtype=np.float32)
[(x * scale, y * scale) for x, y in REF_LANDMARKS_128],
dtype=np.float32,
)
# Apply similarity transform so detected landmarks map to reference positions
def _align_from_landmarks(img, landmarks, size: int): def _align_from_landmarks(img, landmarks, size: int):
"""Apply similarity transform so detected landmarks map to reference positions."""
import numpy as np import numpy as np
from PIL import Image as PILImage from PIL import Image as PILImage
from skimage.transform import SimilarityTransform, warp from skimage.transform import SimilarityTransform, warp
src = np.asarray(landmarks, dtype=np.float32) # 5x2 detected src = np.asarray(landmarks, dtype=np.float32)
dst = _ref_landmarks(size) # 5x2 reference dst = _ref_landmarks(size)
try: try:
tform = SimilarityTransform.from_estimate(src, dst) tform = SimilarityTransform.from_estimate(src, dst)
@@ -105,35 +81,28 @@ def _align_from_landmarks(img, landmarks, size: int):
def _center_crop(img, size: int): def _center_crop(img, size: int):
from PIL import Image as PILImage from PIL import Image as PILImage
w, h = img.size w, h = img.size
side = min(w, h) side = min(w, h)
left, top = (w - side) // 2, (h - side) // 2 left, top = (w - side) // 2, (h - side) // 2
return img.crop((left, top, left + side, top + side)).resize((size, size), PILImage.BILINEAR) return img.crop((left, top, left + side, top + side)).resize((size, size), PILImage.BILINEAR)
# Returns (standard, relaxed) MTCNN detectors; cached per device
def _get_detectors(device: str): def _get_detectors(device: str):
"""Return (standard, relaxed) MTCNN detectors with landmarks enabled."""
if device in _DETECTORS: if device in _DETECTORS:
return _DETECTORS[device] return _DETECTORS[device]
from facenet_pytorch import MTCNN from facenet_pytorch import MTCNN
standard = MTCNN( standard = MTCNN(keep_all=False, select_largest=True, min_face_size=15,
keep_all=False, select_largest=True, device=device, post_process=False)
min_face_size=15, relaxed = MTCNN(keep_all=False, select_largest=True, min_face_size=10,
device=device, post_process=False, thresholds=[0.5, 0.6, 0.6], device=device, post_process=False)
)
relaxed = MTCNN(
keep_all=False, select_largest=True,
min_face_size=10,
thresholds=[0.5, 0.6, 0.6],
device=device, post_process=False,
)
_DETECTORS[device] = (standard, relaxed) _DETECTORS[device] = (standard, relaxed)
return standard, relaxed return standard, relaxed
# ── main ────────────────────────────────────────────────────────────────────── # ── Main ──────────────────────────────────────────────────────────────────────
def main(): def main():
args = parse_args() args = parse_args()
@@ -158,7 +127,7 @@ def main():
try: try:
import facenet_pytorch # noqa: F401 import facenet_pytorch # noqa: F401
import skimage # noqa: F401 import skimage # noqa: F401
except ImportError as exc: except ImportError as exc:
print(f"Error: missing dependency ({exc}).") print(f"Error: missing dependency ({exc}).")
print(" Run: pip install facenet-pytorch scikit-image") print(" Run: pip install facenet-pytorch scikit-image")
@@ -182,9 +151,7 @@ def main():
print(f"\nTotal images: {len(all_paths):,}\n") print(f"\nTotal images: {len(all_paths):,}\n")
n_processed = n_skipped = n_error = 0 n_processed = n_skipped = n_error = 0
src_stats: dict[str, dict] = { src_stats: dict[str, dict] = {s: {"aligned": 0, "retry": 0, "fallback": 0} for s in sources}
s: {"aligned": 0, "retry": 0, "fallback": 0} for s in sources
}
for img_path in tqdm(all_paths, desc="Aligning", unit="img"): for img_path in tqdm(all_paths, desc="Aligning", unit="img"):
rel = img_path.relative_to(data_dir) rel = img_path.relative_to(data_dir)
@@ -215,8 +182,8 @@ def main():
# Pass 2: upscale 2x and retry with relaxed thresholds # Pass 2: upscale 2x and retry with relaxed thresholds
if aligned is None: if aligned is None:
w, h = img.size w, h = img.size
img2x = img.resize((w * 2, h * 2), Image.BILINEAR) img2x = img.resize((w * 2, h * 2), Image.BILINEAR)
_, _, landmarks2 = relaxed.detect(img2x, landmarks=True) _, _, landmarks2 = relaxed.detect(img2x, landmarks=True)
if landmarks2 is not None and len(landmarks2) > 0: if landmarks2 is not None and len(landmarks2) > 0:
lm_orig = [(x / 2, y / 2) for x, y in landmarks2[0]] lm_orig = [(x / 2, y / 2) for x, y in landmarks2[0]]
@@ -253,9 +220,9 @@ def main():
print(f" {'Source':<12} {'pass-1':>8} {'pass-2':>8} {'fallback':>8} {'fallback%':>10}") print(f" {'Source':<12} {'pass-1':>8} {'pass-2':>8} {'fallback':>8} {'fallback%':>10}")
print(f" {''*12} {''*8} {''*8} {''*8} {''*10}") print(f" {''*12} {''*8} {''*8} {''*8} {''*10}")
for src in sources: for src in sources:
s = src_stats[src] s = src_stats[src]
total_src = s["aligned"] + s["retry"] + s["fallback"] total_src = s["aligned"] + s["retry"] + s["fallback"]
fb_pct = s["fallback"] / max(total_src, 1) fb_pct = s["fallback"] / max(total_src, 1)
print(f" {src:<12} {s['aligned']:>8,} {s['retry']:>8,} {s['fallback']:>8,} {fb_pct:>9.1%}") print(f" {src:<12} {s['aligned']:>8,} {s['retry']:>8,} {s['fallback']:>8,} {fb_pct:>9.1%}")
print(f"{'' * 55}") print(f"{'' * 55}")
print(f" Output: {output_dir.resolve()}") print(f" Output: {output_dir.resolve()}")
+120
View File
@@ -0,0 +1,120 @@
# Generate 4x4 sample grids from Phase 5 EMA checkpoints, matching training visualization.
# Usage: python generator/tools/sampling.py [--samples N] [--models p5_gan p5_vae p5_ddpm]
import argparse
import json
import sys
from pathlib import Path
import torch
from torchvision.utils import save_image
# Two levels up from this file: parents[0] is the directory holding the script,
# parents[1] is that directory's parent.
ROOT = Path(__file__).resolve().parents[1]
# Put ROOT on sys.path once (guarded so repeated imports do not add duplicates).
# NOTE(review): presumably this is what lets the lazy `from src...` imports further down
# resolve when the script is run directly — confirm against the repo layout.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
# Directory the EMA checkpoint files are looked up in
MODELS_DIR = ROOT / "outputs" / "models"
# Directory the per-run JSON configs are read from
CFG_DIR = ROOT / "configs" / "phase5"
# Registry of Phase 5 runs, keyed by run name. Each entry records the model type, the
# config file name, and the checkpoint file names in the order they should be tried.
# final_ema is listed ahead of best_ema: best_ema is the lowest-FID snapshot, which for
# slowly-converging models (e.g. DDPM) can be saved while the EMA shadow is still close
# to random init.
PHASE5_RUNS = {
    f"p5_{kind}": {
        "type": kind,
        "config": f"p5_{kind}.json",
        "checkpoints": [f"p5_{kind}_final_ema.pt", f"p5_{kind}_best_ema.pt"],
    }
    for kind in ("gan", "vae", "ddpm")
}
# Read and parse the JSON config registered for run `name` (a key of PHASE5_RUNS).
# Returns the parsed JSON document. Raises KeyError for an unknown run name and
# FileNotFoundError when the config file does not exist under CFG_DIR.
def _load_cfg(name: str) -> dict:
    # JSON is UTF-8 by specification; naming the encoding keeps parsing independent of
    # the platform's locale default (which is not UTF-8 on every system).
    with open(CFG_DIR / PHASE5_RUNS[name]["config"], encoding="utf-8") as f:
        return json.load(f)
# Build the model described by `cfg` and fill it from the first usable checkpoint listed
# for `name` in PHASE5_RUNS; the result is moved to `device` and switched to eval mode.
# A candidate is passed over when its file is absent from MODELS_DIR or when loading
# reports any missing or unexpected keys. Raises FileNotFoundError once every candidate
# has been rejected.
def _load_model(name: str, cfg: dict, device: torch.device):
    from src.models import get_model

    built, _ = get_model(cfg)
    # GAN returns (generator, critic); all others return the model directly
    net = built[0] if isinstance(built, tuple) else built

    for ckpt_name in PHASE5_RUNS[name]["checkpoints"]:
        path = MODELS_DIR / ckpt_name
        if path.exists():
            weights = torch.load(path, map_location=device, weights_only=True)
            # strict=False hands back the key mismatches so they can be counted and reported
            missing, unexpected = net.load_state_dict(weights, strict=False)
            if not (missing or unexpected):
                print(f" [{name}] Loaded {ckpt_name}")
                return net.to(device).eval()
            print(f" [{name}] {ckpt_name}: missing={len(missing)} unexpected={len(unexpected)} — trying next")
        else:
            print(f" [{name}] {ckpt_name} not found, trying next")

    raise FileNotFoundError(f"No usable EMA checkpoint found for {name}")
# Returns a (16, C, H, W) tensor in [0, 1] ready for save_image with nrow=4
@torch.no_grad()
def _generate_grid(
name: str, model, cfg: dict, device: torch.device,
*, truncation: float | None = None,
) -> torch.Tensor:
kind = PHASE5_RUNS[name]["type"]
image_size = cfg.get("image_size", 64)
if kind == "gan":
z = torch.randn(16, cfg.get("latent_dim", 128), 1, 1, device=device)
if truncation is not None and truncation > 0:
z = z.clamp(-truncation, truncation)
imgs = model(z)
elif kind == "vae":
imgs = model.sample(16, device)
elif kind == "ddpm":
from src.training.diffusion import cosine_betas, linear_betas, make_alpha_bars, ddim_sample
T = cfg.get("T", 1000)
betas = (cosine_betas(T) if cfg.get("noise_schedule", "cosine") == "cosine" else linear_betas(T)).to(device)
alpha_bars = make_alpha_bars(betas)
# n_steps=50 mirrors the training-time preview; cfg's ddim_steps is for FID only
imgs = ddim_sample(
model, 16, image_size, alpha_bars,
n_steps=50, pred_type=cfg.get("pred_type", "eps"),
device=str(device), batch_size=16,
)
else:
raise ValueError(f"Unknown model type: {kind}")
return (imgs.clamp(-1, 1) + 1.0) / 2.0
# CLI entry point. For each selected Phase 5 run: load its config and EMA checkpoint, then
# write `--samples` grids named grid_0001.png, grid_0002.png, ... into
# <output-dir>/<run name>/, creating the directories as needed.
def main():
    run_names = list(PHASE5_RUNS.keys())
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    default_out = ROOT / "outputs" / "samples" / "final_comparison"

    parser = argparse.ArgumentParser()
    parser.add_argument("--samples", type=int, default=10, help="Number of 4x4 grids per model")
    parser.add_argument("--output-dir", type=Path, default=default_out)
    parser.add_argument("--device", default=default_device)
    parser.add_argument(
        "--truncation", type=float, default=None,
        help="Optional GAN latent truncation (lower=less diversity but sharper)",
    )
    # choices and default get separate list objects, as argparse hands the default back as-is
    parser.add_argument("--models", nargs="+", choices=run_names, default=list(run_names))
    args = parser.parse_args()

    target = torch.device(args.device)
    args.output_dir.mkdir(parents=True, exist_ok=True)

    for name in args.models:
        print(f"[{name}] Loading checkpoint...")
        config = _load_cfg(name)
        net = _load_model(name, config, target)

        # The parent was created above, so a plain mkdir is enough for the per-run folder
        out_dir = args.output_dir / name
        out_dir.mkdir(exist_ok=True)

        print(f"[{name}] Generating {args.samples} grids -> {out_dir}")
        for i in range(1, args.samples + 1):
            sheet = _generate_grid(name, net, config, target, truncation=args.truncation)
            save_image(sheet, out_dir / f"grid_{i:04d}.png", nrow=4)

    print("Done.")


if __name__ == "__main__":
    main()