#!/usr/bin/env python3
|
|
"""
|
|
Pre-align face images using MTCNN landmarks + similarity transform.
|
|
|
|
This is the generator-side counterpart to classifier/tools/facecrop.py.
|
|
Difference: classifier uses bbox crop+resize; the generator wants landmark-based
|
|
alignment so the eyes, nose, and mouth land at fixed pixel positions in every
|
|
training image — structurally consistent training data for the generator.
|
|
|
|
Output mirrors the source layout exactly:
|
|
data/wiki/14/37591914.jpg -> cropped/generator/wiki/14/37591914.jpg
|
|
|
|
Resumable: already-aligned images are skipped by default.
|
|
|
|
Usage:
|
|
python generator/tools/facecrop.py
|
|
python generator/tools/facecrop.py --data-dir data --output-dir cropped/generator
|
|
python generator/tools/facecrop.py --sources wiki --device cpu
|
|
python generator/tools/facecrop.py --size 128
|
|
python generator/tools/facecrop.py --no-skip-existing # reprocess everything
|
|
"""
|
|
import argparse
|
|
import sys
|
|
import warnings
|
|
from pathlib import Path
|
|
|
|
# Silence the FutureWarning about torch.load's `weights_only` default —
# presumably raised when facenet-pytorch loads its pretrained checkpoints
# (TODO confirm the emitter; the filter matches on message text only).
warnings.filterwarnings("ignore", message=".*weights_only.*", category=FutureWarning)

# Repo root (two levels up from this file) — prepended to sys.path so
# sibling modules are importable when this file is run as a script.
ROOT = Path(__file__).resolve().parent.parent

sys.path.insert(0, str(ROOT))
|
|
|
|
# Generator trains on real images only (wiki). The other sources are AI-generated
# and aren't used as training targets for the generator, so we don't align them
# by default. Pass --sources to override.
SOURCES = ["wiki"]

# Every source directory this tool knows how to process (see --sources help).
ALL_SOURCES = ["wiki", "inpainting", "text2img", "insight"]

# Reference landmark positions for a 128px aligned face.
# Source: standard FFHQ-style alignment template (eyes at y=51, nose at y=71, mouth at y=95).
# Scaled at runtime to match --size.
# Order matters: it must match MTCNN's landmark order
# (left eye, right eye, nose, left mouth corner, right mouth corner).
REF_LANDMARKS_128 = [
    (38.0, 51.0),  # left eye
    (90.0, 51.0),  # right eye
    (64.0, 71.0),  # nose tip
    (45.0, 95.0),  # left mouth
    (83.0, 95.0),  # right mouth
]

# Per-device cache of constructed MTCNN detectors, keyed by device string
# ("cpu"/"cuda"); populated lazily by _get_detectors().
_DETECTORS: dict[str, object] = {}
|
|
|
|
|
|
def parse_args():
    """Build the command-line interface and parse sys.argv into a namespace."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    add = parser.add_argument
    add("--data-dir", default="data", help="Source dataset root (default: data)")
    add("--output-dir", default="cropped/generator", help="Output root (default: cropped/generator)")
    add("--size", type=int, default=128, help="Output image size in px, square (default: 128)")
    add("--device", default=None, help="'cpu' or 'cuda'. Default: auto-detect")
    add("--sources", nargs="+", default=None, metavar="SOURCE",
        help=f"Sources to process. Default: {', '.join(SOURCES)} (real images only). "
             f"All available: {', '.join(ALL_SOURCES)}")
    # --skip-existing / --no-skip-existing share one destination flag; the
    # default is on so interrupted runs resume where they left off.
    add("--skip-existing", dest="skip_existing", action="store_true", default=True,
        help="Skip images already present in output-dir (default: on, resumable)")
    add("--no-skip-existing", dest="skip_existing", action="store_false",
        help="Re-process all images even if already aligned")
    return parser.parse_args()
|
|
|
|
|
|
# ── alignment helpers ─────────────────────────────────────────────────────────
|
|
|
|
def _ref_landmarks(size: int):
    """Return the 128px landmark template rescaled to a `size`-px canvas.

    Output is a 5x2 float32 array of (x, y) reference positions.
    """
    import numpy as np

    scale = size / 128.0
    points = []
    for x, y in REF_LANDMARKS_128:
        points.append((x * scale, y * scale))
    return np.asarray(points, dtype=np.float32)
|
|
|
|
|
|
def _align_from_landmarks(img, landmarks, size: int):
    """Warp `img` so the detected landmarks land on the reference template.

    Estimates a similarity transform (rotation + uniform scale + translation)
    mapping the 5 detected landmarks onto `_ref_landmarks(size)` and resamples
    the image through it.

    Args:
        img: PIL image to align (assumed RGB — see caller; TODO confirm).
        landmarks: 5x2 sequence of detected (x, y) points in `img` pixels.
        size: output side length in pixels (square output).

    Returns:
        A `size`x`size` PIL image, or None when the transform cannot be
        estimated (caller falls back to a centre crop).
    """
    import numpy as np
    from PIL import Image as PILImage
    from skimage.transform import SimilarityTransform, warp

    src = np.asarray(landmarks, dtype=np.float32)  # 5x2 detected
    dst = _ref_landmarks(size)                     # 5x2 reference

    try:
        tform = SimilarityTransform.from_estimate(src, dst)
    except Exception:
        return None
    # BUG FIX: from_estimate signals failure by returning a *falsy*
    # FailedEstimation object rather than raising, so the except above never
    # fires on degenerate landmarks. Check truthiness explicitly so the
    # caller can fall back instead of crashing inside warp().
    if not tform:
        return None

    aligned = warp(
        np.asarray(img),
        tform.inverse,           # warp pulls output pixels via the inverse map
        output_shape=(size, size),
        order=3,                 # bicubic interpolation
        preserve_range=True,     # keep uint8 value range (no [0, 1] rescale)
    ).astype(np.uint8)
    return PILImage.fromarray(aligned)
|
|
|
|
|
|
def _center_crop(img, size: int):
    """Fallback alignment: largest centred square of `img`, bilinearly resized to `size`x`size`."""
    from PIL import Image as PILImage

    width, height = img.size
    edge = min(width, height)
    x0 = (width - edge) // 2
    y0 = (height - edge) // 2
    box = (x0, y0, x0 + edge, y0 + edge)
    return img.crop(box).resize((size, size), PILImage.BILINEAR)
|
|
|
|
|
|
def _get_detectors(device: str):
    """Return (standard, relaxed) MTCNN detectors for `device`, constructing them once.

    Both keep only the largest detected face and skip post-processing; the
    relaxed variant lowers min_face_size and the cascade thresholds for the
    second-pass retry on upscaled images. Results are memoised in _DETECTORS.
    """
    cached = _DETECTORS.get(device)
    if cached is not None:
        return cached

    from facenet_pytorch import MTCNN

    common = dict(keep_all=False, select_largest=True, device=device, post_process=False)
    standard = MTCNN(min_face_size=15, **common)
    relaxed = MTCNN(min_face_size=10, thresholds=[0.5, 0.6, 0.6], **common)

    _DETECTORS[device] = (standard, relaxed)
    return standard, relaxed
|
|
|
|
|
|
# ── main ──────────────────────────────────────────────────────────────────────
|
|
|
|
def main():
    """Align every image under the selected sources and mirror the tree into output-dir.

    Per image: pass-1 MTCNN landmark detection → similarity-transform
    alignment; pass-2 retries on a 2x upscale with relaxed thresholds;
    final fallback is a centre crop. Prints per-source statistics at the end.
    """
    args = parse_args()

    # Heavy imports deferred so `--help` stays fast.
    import torch
    from PIL import Image
    from tqdm import tqdm

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    sources = args.sources or SOURCES

    # Validate inputs up front so we fail before constructing detectors.
    if not data_dir.exists():
        print(f"Error: data directory not found: {data_dir}")
        sys.exit(1)

    for src in sources:
        if not (data_dir / src).exists():
            print(f"Error: source directory not found: {data_dir / src}")
            sys.exit(1)

    try:
        import facenet_pytorch  # noqa: F401
        import skimage  # noqa: F401
    except ImportError as exc:
        print(f"Error: missing dependency ({exc}).")
        print("  Run: pip install facenet-pytorch scikit-image")
        sys.exit(1)

    print(f"Data dir: {data_dir.resolve()}")
    print(f"Output dir: {output_dir.resolve()}")
    print(f"Sources: {', '.join(sources)}")
    print(f"Device: {device}")
    print(f"Size: {args.size}px")
    print(f"Skip exist: {args.skip_existing}")

    standard, relaxed = _get_detectors(device)

    # Collect all .jpg paths one subdirectory deep (data/<source>/<subdir>/*.jpg),
    # sorted for a deterministic processing order.
    all_paths: list[Path] = []
    for src in sources:
        for subdir in sorted((data_dir / src).iterdir()):
            if subdir.is_dir():
                all_paths.extend(sorted(subdir.glob("*.jpg")))

    print(f"\nTotal images: {len(all_paths):,}\n")

    n_processed = n_skipped = n_error = 0
    # Per-source counters: "aligned" = pass-1, "retry" = pass-2 upscale,
    # "fallback" = centre crop.
    src_stats: dict[str, dict] = {
        s: {"aligned": 0, "retry": 0, "fallback": 0} for s in sources
    }

    for img_path in tqdm(all_paths, desc="Aligning", unit="img"):
        # Mirror the source layout: data/<src>/<subdir>/x.jpg -> output/<src>/<subdir>/x.jpg
        rel = img_path.relative_to(data_dir)
        out_path = output_dir / rel
        src_name = img_path.parent.parent.name  # <src> component of the path

        # Resumability: an existing output file means this image is done.
        if args.skip_existing and out_path.exists():
            n_skipped += 1
            continue

        out_path.parent.mkdir(parents=True, exist_ok=True)

        try:
            img = Image.open(img_path).convert("RGB")
        except Exception as exc:
            tqdm.write(f"[WARN] Cannot open {img_path.name}: {exc}")
            n_error += 1
            continue

        aligned = None
        try:
            # Pass 1: detect landmarks on the original image
            _, _, landmarks = standard.detect(img, landmarks=True)
            if landmarks is not None and len(landmarks) > 0:
                aligned = _align_from_landmarks(img, landmarks[0], args.size)
                if aligned is not None:
                    src_stats[src_name]["aligned"] += 1

            # Pass 2: upscale 2x and retry with relaxed thresholds
            if aligned is None:
                w, h = img.size
                img2x = img.resize((w * 2, h * 2), Image.BILINEAR)
                _, _, landmarks2 = relaxed.detect(img2x, landmarks=True)
                if landmarks2 is not None and len(landmarks2) > 0:
                    # Landmarks were found in 2x coordinates; halve them back
                    # to original-image space before aligning.
                    lm_orig = [(x / 2, y / 2) for x, y in landmarks2[0]]
                    aligned = _align_from_landmarks(img, lm_orig, args.size)
                    if aligned is not None:
                        src_stats[src_name]["retry"] += 1

            # Last resort: centre crop so every input yields an output image.
            if aligned is None:
                aligned = _center_crop(img, args.size)
                src_stats[src_name]["fallback"] += 1
        except Exception as exc:
            # Detection blew up entirely — warn but still emit a fallback crop.
            tqdm.write(f"[WARN] Detection failed for {img_path.name}: {exc}")
            aligned = _center_crop(img, args.size)
            src_stats[src_name]["fallback"] += 1

        aligned.save(out_path, quality=95)
        n_processed += 1

    # ── summary ──
    total = n_processed + n_skipped
    n_aligned = sum(s["aligned"] for s in src_stats.values())
    n_retry = sum(s["retry"] for s in src_stats.values())
    n_fallback = sum(s["fallback"] for s in src_stats.values())
    denom = max(n_processed, 1)  # avoid div-by-zero when everything was skipped

    print(f"\n{'─' * 55}")
    print(f" Total images : {total:>8,}")
    print(f" Processed : {n_processed:>8,}")
    print(f" Skipped (existed) : {n_skipped:>8,}")
    print(f" Errors : {n_error:>8,}")
    print(f" Pass-1 aligned : {n_aligned:>8,} ({n_aligned / denom:.1%})")
    print(f" Pass-2 aligned : {n_retry:>8,} ({n_retry / denom:.1%}) ← 2x upscale retry")
    print(f" Centre fallback : {n_fallback:>8,} ({n_fallback / denom:.1%})")
    print()
    print(f" {'Source':<12} {'pass-1':>8} {'pass-2':>8} {'fallback':>8} {'fallback%':>10}")
    print(f" {'─'*12} {'─'*8} {'─'*8} {'─'*8} {'─'*10}")
    for src in sources:
        s = src_stats[src]
        total_src = s["aligned"] + s["retry"] + s["fallback"]
        fb_pct = s["fallback"] / max(total_src, 1)
        print(f" {src:<12} {s['aligned']:>8,} {s['retry']:>8,} {s['fallback']:>8,} {fb_pct:>9.1%}")
    print(f"{'─' * 55}")
    print(f" Output: {output_dir.resolve()}")
    print()
    print("Next step — point your config at the aligned dataset:")
    print(f' "data_dir": "{output_dir}"')


if __name__ == "__main__":
    main()
|