Implementing inference sampling and quick tool polish

This commit is contained in:
Johnny Fernandes
2026-05-04 21:05:06 +01:00
parent a750d177fa
commit a235dfd5f8
65 changed files with 151 additions and 64 deletions
+31 -64
View File
@@ -1,24 +1,7 @@
#!/usr/bin/env python3
"""
Pre-align face images using MTCNN landmarks + similarity transform.
This is the generator-side counterpart to classifier/tools/facecrop.py.
Difference: classifier uses bbox crop+resize; the generator wants landmark-based
alignment so the eyes, nose, and mouth land at fixed pixel positions in every
training image — structurally consistent training data for the generator.
Output mirrors the source layout exactly:
data/wiki/14/37591914.jpg -> cropped/generator/wiki/14/37591914.jpg
Resumable: already-aligned images are skipped by default.
Usage:
python generator/tools/facecrop.py
python generator/tools/facecrop.py --data-dir data --output-dir cropped/generator
python generator/tools/facecrop.py --sources wiki --device cpu
python generator/tools/facecrop.py --size 128
python generator/tools/facecrop.py --no-skip-existing # reprocess everything
"""
# Pre-align face images using MTCNN landmarks + similarity transform.
# Generator-side counterpart to classifier/tools/facecrop.py — uses landmark-based alignment
# (not bbox crop) so eyes, nose, and mouth land at fixed pixel positions in every image.
# Usage: python generator/tools/facecrop.py [--data-dir data] [--output-dir cropped/generator] [--size 128]
import argparse
import sys
import warnings
@@ -32,7 +15,7 @@ sys.path.insert(0, str(ROOT))
# Generator trains on real images only (wiki). The other sources are AI-generated
# and aren't used as training targets for the generator, so we don't align them
# by default. Pass --sources to override.
SOURCES = ["wiki"]
SOURCES = ["wiki"]
ALL_SOURCES = ["wiki", "inpainting", "text2img", "insight"]
# Reference landmark positions for a 128px aligned face.
@@ -50,44 +33,37 @@ _DETECTORS: dict[str, object] = {}
def parse_args():
p = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
p.add_argument("--data-dir", default="data", help="Source dataset root (default: data)")
p = argparse.ArgumentParser()
p.add_argument("--data-dir", default="data", help="Source dataset root (default: data)")
p.add_argument("--output-dir", default="cropped/generator", help="Output root (default: cropped/generator)")
p.add_argument("--size", type=int, default=128, help="Output image size in px, square (default: 128)")
p.add_argument("--device", default=None, help="'cpu' or 'cuda'. Default: auto-detect")
p.add_argument("--size", type=int, default=128, help="Output image size in px, square (default: 128)")
p.add_argument("--device", default=None, help="'cpu' or 'cuda'. Default: auto-detect")
p.add_argument("--sources", nargs="+", default=None, metavar="SOURCE",
help=f"Sources to process. Default: {', '.join(SOURCES)} (real images only). "
f"All available: {', '.join(ALL_SOURCES)}")
help=f"Sources to process. Default: {', '.join(SOURCES)}. All: {', '.join(ALL_SOURCES)}")
p.add_argument("--skip-existing", dest="skip_existing", action="store_true", default=True,
help="Skip images already present in output-dir (default: on, resumable)")
help="Skip images already present in output-dir (default: on)")
p.add_argument("--no-skip-existing", dest="skip_existing", action="store_false",
help="Re-process all images even if already aligned")
return p.parse_args()
# ── alignment helpers ─────────────────────────────────────────────────────────
# ── Alignment helpers ─────────────────────────────────────────────────────────
# Scale the 128px reference template to match the target size
def _ref_landmarks(size: int):
"""Reference landmarks scaled from the 128px template to `size`."""
import numpy as np
scale = size / 128.0
return np.asarray(
[(x * scale, y * scale) for x, y in REF_LANDMARKS_128],
dtype=np.float32,
)
return np.asarray([(x * scale, y * scale) for x, y in REF_LANDMARKS_128], dtype=np.float32)
# Apply similarity transform so detected landmarks map to reference positions
def _align_from_landmarks(img, landmarks, size: int):
"""Apply similarity transform so detected landmarks map to reference positions."""
import numpy as np
from PIL import Image as PILImage
from skimage.transform import SimilarityTransform, warp
src = np.asarray(landmarks, dtype=np.float32) # 5x2 detected
dst = _ref_landmarks(size) # 5x2 reference
src = np.asarray(landmarks, dtype=np.float32)
dst = _ref_landmarks(size)
try:
tform = SimilarityTransform.from_estimate(src, dst)
@@ -105,35 +81,28 @@ def _align_from_landmarks(img, landmarks, size: int):
def _center_crop(img, size: int):
from PIL import Image as PILImage
w, h = img.size
side = min(w, h)
w, h = img.size
side = min(w, h)
left, top = (w - side) // 2, (h - side) // 2
return img.crop((left, top, left + side, top + side)).resize((size, size), PILImage.BILINEAR)
# Returns (standard, relaxed) MTCNN detectors; cached per device
def _get_detectors(device: str):
"""Return (standard, relaxed) MTCNN detectors with landmarks enabled."""
if device in _DETECTORS:
return _DETECTORS[device]
from facenet_pytorch import MTCNN
standard = MTCNN(
keep_all=False, select_largest=True,
min_face_size=15,
device=device, post_process=False,
)
relaxed = MTCNN(
keep_all=False, select_largest=True,
min_face_size=10,
thresholds=[0.5, 0.6, 0.6],
device=device, post_process=False,
)
standard = MTCNN(keep_all=False, select_largest=True, min_face_size=15,
device=device, post_process=False)
relaxed = MTCNN(keep_all=False, select_largest=True, min_face_size=10,
thresholds=[0.5, 0.6, 0.6], device=device, post_process=False)
_DETECTORS[device] = (standard, relaxed)
return standard, relaxed
# ── main ──────────────────────────────────────────────────────────────────────
# ── Main ──────────────────────────────────────────────────────────────────────
def main():
args = parse_args()
@@ -158,7 +127,7 @@ def main():
try:
import facenet_pytorch # noqa: F401
import skimage # noqa: F401
import skimage # noqa: F401
except ImportError as exc:
print(f"Error: missing dependency ({exc}).")
print(" Run: pip install facenet-pytorch scikit-image")
@@ -182,9 +151,7 @@ def main():
print(f"\nTotal images: {len(all_paths):,}\n")
n_processed = n_skipped = n_error = 0
src_stats: dict[str, dict] = {
s: {"aligned": 0, "retry": 0, "fallback": 0} for s in sources
}
src_stats: dict[str, dict] = {s: {"aligned": 0, "retry": 0, "fallback": 0} for s in sources}
for img_path in tqdm(all_paths, desc="Aligning", unit="img"):
rel = img_path.relative_to(data_dir)
@@ -215,8 +182,8 @@ def main():
# Pass 2: upscale 2x and retry with relaxed thresholds
if aligned is None:
w, h = img.size
img2x = img.resize((w * 2, h * 2), Image.BILINEAR)
w, h = img.size
img2x = img.resize((w * 2, h * 2), Image.BILINEAR)
_, _, landmarks2 = relaxed.detect(img2x, landmarks=True)
if landmarks2 is not None and len(landmarks2) > 0:
lm_orig = [(x / 2, y / 2) for x, y in landmarks2[0]]
@@ -253,9 +220,9 @@ def main():
print(f" {'Source':<12} {'pass-1':>8} {'pass-2':>8} {'fallback':>8} {'fallback%':>10}")
print(f" {''*12} {''*8} {''*8} {''*8} {''*10}")
for src in sources:
s = src_stats[src]
s = src_stats[src]
total_src = s["aligned"] + s["retry"] + s["fallback"]
fb_pct = s["fallback"] / max(total_src, 1)
fb_pct = s["fallback"] / max(total_src, 1)
print(f" {src:<12} {s['aligned']:>8,} {s['retry']:>8,} {s['fallback']:>8,} {fb_pct:>9.1%}")
print(f"{'' * 55}")
print(f" Output: {output_dir.resolve()}")