#!/usr/bin/env python3
"""
Pre-align face images using MTCNN landmarks + similarity transform.

This is the generator-side counterpart to classifier/tools/facecrop.py.
Difference: classifier uses bbox crop+resize; the generator wants landmark-based
alignment so the eyes, nose, and mouth land at fixed pixel positions in every
training image — structurally consistent training data for the generator.

Output mirrors the source layout exactly:
    data/wiki/14/37591914.jpg -> cropped/generator/wiki/14/37591914.jpg

Resumable: already-aligned images are skipped by default.

Usage:
    python generator/tools/facecrop.py
    python generator/tools/facecrop.py --data-dir data --output-dir cropped/generator
    python generator/tools/facecrop.py --sources wiki --device cpu
    python generator/tools/facecrop.py --size 128
    python generator/tools/facecrop.py --no-skip-existing   # reprocess everything
"""

import argparse
import sys
import warnings
from pathlib import Path

warnings.filterwarnings("ignore", message=".*weights_only.*", category=FutureWarning)

ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))

# Generator trains on real images only (wiki). The other sources are AI-generated
# and aren't used as training targets for the generator, so we don't align them
# by default. Pass --sources to override.
SOURCES = ["wiki"]
ALL_SOURCES = ["wiki", "inpainting", "text2img", "insight"]

# Reference landmark positions for a 128px aligned face.
# Source: standard FFHQ-style alignment template (eyes at y=51, nose at y=71, mouth at y=95).
# Scaled at runtime to match --size.
REF_LANDMARKS_128 = [
    (38.0, 51.0),  # left eye
    (90.0, 51.0),  # right eye
    (64.0, 71.0),  # nose tip
    (45.0, 95.0),  # left mouth
    (83.0, 95.0),  # right mouth
]

# Cache of (standard, relaxed) detector pairs keyed by device string, so the
# MTCNN weights are only loaded once per device.
_DETECTORS: dict[str, tuple[object, object]] = {}


def parse_args():
    """Parse command-line options.

    Exits with an argparse usage error (status 2) when --size is not a
    positive integer, instead of failing later inside the alignment loop.
    """
    p = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--data-dir", default="data",
                   help="Source dataset root (default: data)")
    p.add_argument("--output-dir", default="cropped/generator",
                   help="Output root (default: cropped/generator)")
    p.add_argument("--size", type=int, default=128,
                   help="Output image size in px, square (default: 128)")
    p.add_argument("--device", default=None,
                   help="'cpu' or 'cuda'. Default: auto-detect")
    p.add_argument("--sources", nargs="+", default=None, metavar="SOURCE",
                   help=f"Sources to process. Default: {', '.join(SOURCES)} (real images only). "
                        f"All available: {', '.join(ALL_SOURCES)}")
    p.add_argument("--skip-existing", dest="skip_existing", action="store_true", default=True,
                   help="Skip images already present in output-dir (default: on, resumable)")
    p.add_argument("--no-skip-existing", dest="skip_existing", action="store_false",
                   help="Re-process all images even if already aligned")
    args = p.parse_args()
    if args.size <= 0:
        p.error(f"--size must be a positive integer (got {args.size})")
    return args


# ── alignment helpers ─────────────────────────────────────────────────────────

def _ref_landmarks(size: int):
    """Reference landmarks scaled from the 128px template to `size`.

    Returns a float32 array with one (x, y) row per entry of REF_LANDMARKS_128.
    """
    import numpy as np

    scale = size / 128.0
    return np.asarray(
        [(x * scale, y * scale) for x, y in REF_LANDMARKS_128],
        dtype=np.float32,
    )


def _to_uint8(pixels):
    """Round and clamp float pixel values into [0, 255] and cast to uint8.

    Bicubic interpolation overshoots near sharp edges; casting those values
    straight to uint8 wraps around (e.g. 260 -> 4) and leaves speckles.
    """
    import numpy as np

    return np.clip(np.rint(pixels), 0, 255).astype(np.uint8)


def _estimate_similarity(src, dst):
    """Estimate the similarity transform mapping `src` points onto `dst`.

    Returns the fitted transform, or None when the estimation fails.
    """
    from skimage.transform import SimilarityTransform

    from_estimate = getattr(SimilarityTransform, "from_estimate", None)
    if from_estimate is not None:
        tform = from_estimate(src, dst)
        # from_estimate reports failure with a falsy FailedEstimation object
        # rather than raising, so the result has to be truth-tested.
        return tform if tform else None

    # scikit-image releases without from_estimate: estimate() fits in place
    # and returns a success flag.
    tform = SimilarityTransform()
    return tform if tform.estimate(src, dst) else None


def _align_from_landmarks(img, landmarks, size: int):
    """Apply similarity transform so detected landmarks map to reference positions.

    Args:
        img: PIL image to warp.
        landmarks: five (x, y) points in `img` pixel coordinates, in the same
            order as REF_LANDMARKS_128.
        size: side length of the square output in pixels.

    Returns:
        The aligned `size` x `size` PIL image, or None when no transform could
        be estimated from the landmarks.
    """
    import numpy as np
    from PIL import Image as PILImage
    from skimage.transform import warp

    src = np.asarray(landmarks, dtype=np.float32)  # 5x2 detected
    dst = _ref_landmarks(size)                     # 5x2 reference

    try:
        tform = _estimate_similarity(src, dst)
    except (ValueError, ArithmeticError):
        # Degenerate or malformed landmark sets; the caller falls through to
        # its next strategy.
        return None
    if tform is None:
        return None

    warped = warp(
        np.asarray(img),
        tform.inverse,
        output_shape=(size, size),
        order=3,
        preserve_range=True,
    )
    return PILImage.fromarray(_to_uint8(warped))


def _center_crop(img, size: int):
    """Crop the largest centred square from `img` and resize it to `size` px."""
    from PIL import Image as PILImage

    w, h = img.size
    side = min(w, h)
    left, top = (w - side) // 2, (h - side) // 2
    return img.crop((left, top, left + side, top + side)).resize((size, size), PILImage.BILINEAR)


def _get_detectors(device: str):
    """Return (standard, relaxed) MTCNN detectors with landmarks enabled.

    The pair is built once per device and cached in _DETECTORS. The relaxed
    detector uses a smaller minimum face size and explicit thresholds and is
    meant for the second, upscaled detection pass.
    """
    if device in _DETECTORS:
        return _DETECTORS[device]

    from facenet_pytorch import MTCNN

    standard = MTCNN(
        keep_all=False,
        select_largest=True,
        min_face_size=15,
        device=device,
        post_process=False,
    )
    relaxed = MTCNN(
        keep_all=False,
        select_largest=True,
        min_face_size=10,
        thresholds=[0.5, 0.6, 0.6],
        device=device,
        post_process=False,
    )
    _DETECTORS[device] = (standard, relaxed)
    return standard, relaxed


def _collect_image_paths(data_dir: Path, sources) -> list[Path]:
    """List every <data_dir>/<source>/<subdir>/*.jpg in sorted, repeatable order."""
    paths: list[Path] = []
    for src in sources:
        for subdir in sorted((data_dir / src).iterdir()):
            if subdir.is_dir():
                paths.extend(sorted(subdir.glob("*.jpg")))
    return paths


def _align_image(img, standard, relaxed, size: int):
    """Align one image and report which strategy produced the result.

    Returns (image, stage) where stage is "aligned" (pass 1), "retry"
    (pass 2 on a 2x upscale) or "fallback" (centre crop).
    """
    from PIL import Image

    # Pass 1: detect landmarks on the original image
    _, _, landmarks = standard.detect(img, landmarks=True)
    if landmarks is not None and len(landmarks) > 0:
        aligned = _align_from_landmarks(img, landmarks[0], size)
        if aligned is not None:
            return aligned, "aligned"

    # Pass 2: upscale 2x and retry with relaxed thresholds
    w, h = img.size
    img2x = img.resize((w * 2, h * 2), Image.BILINEAR)
    _, _, landmarks2 = relaxed.detect(img2x, landmarks=True)
    if landmarks2 is not None and len(landmarks2) > 0:
        # Landmarks were found on the upscaled image; halve them so they
        # index into the original-resolution image that actually gets warped.
        lm_orig = [(x / 2, y / 2) for x, y in landmarks2[0]]
        aligned = _align_from_landmarks(img, lm_orig, size)
        if aligned is not None:
            return aligned, "retry"

    return _center_crop(img, size), "fallback"


def _save_atomic(image, out_path: Path) -> None:
    """Write `image` as JPEG through a temp file so `out_path` is never partial.

    A truncated output left by an interrupted run would otherwise be treated
    as finished by the skip-existing check on every later run.
    """
    tmp_path = out_path.with_name(out_path.name + ".part")
    try:
        # The temp name has no image extension, so the format is given explicitly.
        image.save(tmp_path, format="JPEG", quality=95)
        tmp_path.replace(out_path)
    finally:
        # No-op after a successful rename; removes the partial file otherwise.
        tmp_path.unlink(missing_ok=True)


def _print_summary(sources, src_stats, n_processed, n_skipped, n_error, output_dir) -> None:
    """Print the overall and per-source alignment statistics."""
    # Unreadable files are part of the dataset too, so they count towards the total.
    total = n_processed + n_skipped + n_error
    n_aligned = sum(s["aligned"] for s in src_stats.values())
    n_retry = sum(s["retry"] for s in src_stats.values())
    n_fallback = sum(s["fallback"] for s in src_stats.values())
    denom = max(n_processed, 1)  # avoid dividing by zero when nothing was processed

    print(f"\n{'─' * 55}")
    print(f" Total images : {total:>8,}")
    print(f" Processed : {n_processed:>8,}")
    print(f" Skipped (existed) : {n_skipped:>8,}")
    print(f" Errors : {n_error:>8,}")
    print(f" Pass-1 aligned : {n_aligned:>8,} ({n_aligned / denom:.1%})")
    print(f" Pass-2 aligned : {n_retry:>8,} ({n_retry / denom:.1%}) ← 2x upscale retry")
    print(f" Centre fallback : {n_fallback:>8,} ({n_fallback / denom:.1%})")
    print()
    print(f" {'Source':<12} {'pass-1':>8} {'pass-2':>8} {'fallback':>8} {'fallback%':>10}")
    print(f" {'─'*12} {'─'*8} {'─'*8} {'─'*8} {'─'*10}")
    for src in sources:
        s = src_stats[src]
        total_src = s["aligned"] + s["retry"] + s["fallback"]
        fb_pct = s["fallback"] / max(total_src, 1)
        print(f" {src:<12} {s['aligned']:>8,} {s['retry']:>8,} {s['fallback']:>8,} {fb_pct:>9.1%}")
    print(f"{'─' * 55}")
    print(f" Output: {output_dir.resolve()}")
    print()
    print("Next step — point your config at the aligned dataset:")
    print(f' "data_dir": "{output_dir}"')


# ── main ──────────────────────────────────────────────────────────────────────

def main():
    """Align every source image into the output tree and print a summary.

    Exits with status 1 when the data directory, a source directory, or a
    required third-party package is missing.
    """
    args = parse_args()

    import torch
    from PIL import Image
    from tqdm import tqdm

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    sources = args.sources or SOURCES

    if not data_dir.exists():
        print(f"Error: data directory not found: {data_dir}")
        sys.exit(1)

    for src in sources:
        if not (data_dir / src).exists():
            print(f"Error: source directory not found: {data_dir / src}")
            sys.exit(1)

    try:
        import facenet_pytorch  # noqa: F401
        import skimage  # noqa: F401
    except ImportError as exc:
        print(f"Error: missing dependency ({exc}).")
        print(" Run: pip install facenet-pytorch scikit-image")
        sys.exit(1)

    print(f"Data dir: {data_dir.resolve()}")
    print(f"Output dir: {output_dir.resolve()}")
    print(f"Sources: {', '.join(sources)}")
    print(f"Device: {device}")
    print(f"Size: {args.size}px")
    print(f"Skip exist: {args.skip_existing}")

    standard, relaxed = _get_detectors(device)

    all_paths = _collect_image_paths(data_dir, sources)
    print(f"\nTotal images: {len(all_paths):,}\n")

    n_processed = n_skipped = n_error = 0
    src_stats: dict[str, dict] = {
        s: {"aligned": 0, "retry": 0, "fallback": 0} for s in sources
    }

    for img_path in tqdm(all_paths, desc="Aligning", unit="img"):
        rel = img_path.relative_to(data_dir)
        out_path = output_dir / rel
        src_name = img_path.parent.parent.name

        if args.skip_existing and out_path.exists():
            n_skipped += 1
            continue

        out_path.parent.mkdir(parents=True, exist_ok=True)

        try:
            # convert() returns an independent copy, so the source file handle
            # can be closed as soon as the block exits.
            with Image.open(img_path) as raw:
                img = raw.convert("RGB")
        except Exception as exc:
            tqdm.write(f"[WARN] Cannot open {img_path.name}: {exc}")
            n_error += 1
            continue

        try:
            aligned, stage = _align_image(img, standard, relaxed, args.size)
        except Exception as exc:
            # Best effort: a detector failure must not lose the image.
            tqdm.write(f"[WARN] Detection failed for {img_path.name}: {exc}")
            aligned, stage = _center_crop(img, args.size), "fallback"
        src_stats[src_name][stage] += 1

        _save_atomic(aligned, out_path)
        n_processed += 1

    _print_summary(sources, src_stats, n_processed, n_skipped, n_error, output_dir)


if __name__ == "__main__":
    main()