Clean state
This commit is contained in:
@@ -0,0 +1,262 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Pre-crop face images using MTCNN and save to a new directory.
|
||||
|
||||
Runs face detection once over the dataset and saves cropped images to disk.
|
||||
Training configs can then point at the pre-cropped directory — no per-epoch
|
||||
MTCNN overhead during training.
|
||||
|
||||
The output mirrors the source structure exactly:
|
||||
data/wiki/14/37591914.jpg -> cropped/classifier/wiki/14/37591914.jpg
|
||||
|
||||
Resumable: already-cropped images are skipped by default.
|
||||
|
||||
Usage:
|
||||
python tools/facecrop.py
|
||||
python tools/facecrop.py --data-dir data --output-dir cropped/classifier
|
||||
python tools/facecrop.py --sources wiki inpainting --device cpu
|
||||
python tools/facecrop.py --no-skip-existing # reprocess everything
|
||||
"""
|
||||
import argparse
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
# Suppress facenet_pytorch's torch.load FutureWarning — not fixable externally.
|
||||
warnings.filterwarnings("ignore", message=".*weights_only.*", category=FutureWarning)
|
||||
|
||||
ROOT = Path(__file__).resolve().parent.parent
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
SOURCES = ["wiki", "inpainting", "text2img", "insight"]
|
||||
_DETECTORS: dict[tuple[str, str], object] = {}
|
||||
|
||||
|
||||
def parse_args():
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
p.add_argument("--data-dir", default="data", help="Source dataset root (default: data)")
|
||||
p.add_argument("--output-dir", default="cropped/classifier", help="Output root (default: cropped/classifier)")
|
||||
p.add_argument("--margin", type=float, default=0.6, help="Face box margin as fraction of box size (default: 0.6)")
|
||||
p.add_argument("--size", type=int, default=224, help="Output image size in px, square (default: 224)")
|
||||
p.add_argument("--device", default=None, help="'cpu' or 'cuda'. Default: auto-detect")
|
||||
p.add_argument("--sources", nargs="+", default=None, metavar="SOURCE",
|
||||
help=f"Only process these sources. Default: all ({', '.join(SOURCES)})")
|
||||
p.add_argument("--skip-existing", dest="skip_existing", action="store_true", default=True,
|
||||
help="Skip images already present in output-dir (default: on, resumable)")
|
||||
p.add_argument("--no-skip-existing", dest="skip_existing", action="store_false",
|
||||
help="Re-process all images even if already cropped")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
# ── crop helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _crop_face(img, box, margin: float, size: int):
|
||||
from PIL import Image as PILImage
|
||||
x1, y1, x2, y2 = [float(v) for v in box]
|
||||
bw, bh = x2 - x1, y2 - y1
|
||||
mx, my = bw * margin / 2, bh * margin / 2
|
||||
x1 -= mx; y1 -= my; x2 += mx; y2 += my
|
||||
cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
|
||||
side = max(x2 - x1, y2 - y1)
|
||||
x1, y1 = cx - side / 2, cy - side / 2
|
||||
x2, y2 = x1 + side, y1 + side
|
||||
w, h = img.size
|
||||
x1, y1 = max(0, x1), max(0, y1)
|
||||
x2, y2 = min(w, x2), min(h, y2)
|
||||
return img.crop((int(x1), int(y1), int(x2), int(y2))).resize((size, size), PILImage.BILINEAR)
|
||||
|
||||
|
||||
def _center_crop(img, size: int):
|
||||
from PIL import Image as PILImage
|
||||
w, h = img.size
|
||||
side = min(w, h)
|
||||
left, top = (w - side) // 2, (h - side) // 2
|
||||
return img.crop((left, top, left + side, top + side)).resize((size, size), PILImage.BILINEAR)
|
||||
|
||||
|
||||
def _get_detectors(device: str):
|
||||
key_std = ("std", device)
|
||||
key_relaxed = ("relaxed", device)
|
||||
if key_std in _DETECTORS and key_relaxed in _DETECTORS:
|
||||
return _DETECTORS[key_std], _DETECTORS[key_relaxed]
|
||||
|
||||
from facenet_pytorch import MTCNN
|
||||
|
||||
detector = MTCNN(
|
||||
keep_all=False, select_largest=True,
|
||||
min_face_size=15,
|
||||
device=device, post_process=False,
|
||||
)
|
||||
detector_relaxed = MTCNN(
|
||||
keep_all=False, select_largest=True,
|
||||
min_face_size=10,
|
||||
thresholds=[0.5, 0.6, 0.6],
|
||||
device=device, post_process=False,
|
||||
)
|
||||
_DETECTORS[key_std] = detector
|
||||
_DETECTORS[key_relaxed] = detector_relaxed
|
||||
return detector, detector_relaxed
|
||||
|
||||
|
||||
class FaceCropper:
|
||||
"""Reusable face cropper for notebooks/tools (not training pipeline)."""
|
||||
|
||||
def __init__(self, margin: float = 0.6, size: int = 224, device: str | None = None):
|
||||
import torch
|
||||
|
||||
self.margin = margin
|
||||
self.size = size
|
||||
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
def __call__(self, img):
|
||||
from PIL import Image as PILImage
|
||||
|
||||
detector, detector_relaxed = _get_detectors(self.device)
|
||||
|
||||
boxes, _ = detector.detect(img)
|
||||
if boxes is not None and len(boxes) > 0:
|
||||
return _crop_face(img, boxes[0], self.margin, self.size)
|
||||
|
||||
w, h = img.size
|
||||
img2x = img.resize((w * 2, h * 2), PILImage.BILINEAR)
|
||||
boxes2, _ = detector_relaxed.detect(img2x)
|
||||
if boxes2 is not None and len(boxes2) > 0:
|
||||
box_orig = [v / 2 for v in boxes2[0]]
|
||||
return _crop_face(img, box_orig, self.margin, self.size)
|
||||
return _center_crop(img, self.size)
|
||||
|
||||
|
||||
# ── main ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
data_dir = Path(args.data_dir)
|
||||
output_dir = Path(args.output_dir)
|
||||
device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
sources = args.sources or SOURCES
|
||||
|
||||
if not data_dir.exists():
|
||||
print(f"Error: data directory not found: {data_dir}")
|
||||
sys.exit(1)
|
||||
|
||||
# Validate requested sources
|
||||
for src in sources:
|
||||
if not (data_dir / src).exists():
|
||||
print(f"Error: source directory not found: {data_dir / src}")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import facenet_pytorch # noqa: F401
|
||||
except ImportError:
|
||||
print("Error: facenet_pytorch not installed.")
|
||||
print(" Run: pip install facenet-pytorch")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Data dir: {data_dir.resolve()}")
|
||||
print(f"Output dir: {output_dir.resolve()}")
|
||||
print(f"Sources: {', '.join(sources)}")
|
||||
print(f"Device: {device}")
|
||||
print(f"Margin: {args.margin} | Size: {args.size}px")
|
||||
print(f"Skip exist: {args.skip_existing}")
|
||||
|
||||
detector, detector_relaxed = _get_detectors(device)
|
||||
|
||||
# Collect all image paths, grouped by source for per-source stats
|
||||
all_paths: list[Path] = []
|
||||
for src in sources:
|
||||
for subdir in sorted((data_dir / src).iterdir()):
|
||||
if subdir.is_dir():
|
||||
all_paths.extend(sorted(subdir.glob("*.jpg")))
|
||||
|
||||
print(f"\nTotal images: {len(all_paths):,}\n")
|
||||
|
||||
n_processed = n_skipped = n_error = 0
|
||||
# track per-source: detected / retry_detected / fallback
|
||||
src_stats: dict[str, dict] = {s: {"detected": 0, "retry": 0, "fallback": 0} for s in sources}
|
||||
|
||||
for img_path in tqdm(all_paths, desc="Pre-cropping", unit="img"):
|
||||
rel = img_path.relative_to(data_dir)
|
||||
out_path = output_dir / rel
|
||||
src_name = img_path.parent.parent.name # data/wiki/14/file.jpg -> wiki
|
||||
|
||||
if args.skip_existing and out_path.exists():
|
||||
n_skipped += 1
|
||||
continue
|
||||
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
img = Image.open(img_path).convert("RGB")
|
||||
except Exception as exc:
|
||||
tqdm.write(f"[WARN] Cannot open {img_path.name}: {exc}")
|
||||
n_error += 1
|
||||
continue
|
||||
|
||||
cropped = None
|
||||
try:
|
||||
# Pass 1: detect on original image
|
||||
boxes, _ = detector.detect(img)
|
||||
if boxes is not None and len(boxes) > 0:
|
||||
cropped = _crop_face(img, boxes[0], args.margin, args.size)
|
||||
src_stats[src_name]["detected"] += 1
|
||||
else:
|
||||
# Pass 2: upscale 2x and retry with relaxed thresholds
|
||||
w, h = img.size
|
||||
img2x = img.resize((w * 2, h * 2), Image.BILINEAR)
|
||||
boxes2, _ = detector_relaxed.detect(img2x)
|
||||
if boxes2 is not None and len(boxes2) > 0:
|
||||
# boxes are in upscaled coords — divide by 2 to get original coords
|
||||
box_orig = [v / 2 for v in boxes2[0]]
|
||||
cropped = _crop_face(img, box_orig, args.margin, args.size)
|
||||
src_stats[src_name]["retry"] += 1
|
||||
else:
|
||||
cropped = _center_crop(img, args.size)
|
||||
src_stats[src_name]["fallback"] += 1
|
||||
except Exception as exc:
|
||||
tqdm.write(f"[WARN] Detection failed for {img_path.name}: {exc}")
|
||||
cropped = _center_crop(img, args.size)
|
||||
src_stats[src_name]["fallback"] += 1
|
||||
|
||||
cropped.save(out_path, quality=95)
|
||||
n_processed += 1
|
||||
|
||||
total = n_processed + n_skipped
|
||||
n_detected = sum(s["detected"] for s in src_stats.values())
|
||||
n_retry = sum(s["retry"] for s in src_stats.values())
|
||||
n_fallback = sum(s["fallback"] for s in src_stats.values())
|
||||
denom = max(n_processed, 1)
|
||||
|
||||
print(f"\n{'─' * 55}")
|
||||
print(f" Total images : {total:>8,}")
|
||||
print(f" Processed : {n_processed:>8,}")
|
||||
print(f" Skipped (existed) : {n_skipped:>8,}")
|
||||
print(f" Errors : {n_error:>8,}")
|
||||
print(f" Pass-1 detected : {n_detected:>8,} ({n_detected / denom:.1%})")
|
||||
print(f" Pass-2 detected : {n_retry:>8,} ({n_retry / denom:.1%}) ← 2x upscale retry")
|
||||
print(f" Centre fallback : {n_fallback:>8,} ({n_fallback / denom:.1%})")
|
||||
print()
|
||||
print(f" {'Source':<12} {'pass-1':>8} {'pass-2':>8} {'fallback':>8} {'fallback%':>10}")
|
||||
print(f" {'─'*12} {'─'*8} {'─'*8} {'─'*8} {'─'*10}")
|
||||
for src in sources:
|
||||
s = src_stats[src]
|
||||
total_src = s["detected"] + s["retry"] + s["fallback"]
|
||||
fb_pct = s["fallback"] / max(total_src, 1)
|
||||
print(f" {src:<12} {s['detected']:>8,} {s['retry']:>8,} {s['fallback']:>8,} {fb_pct:>9.1%}")
|
||||
print(f"{'─' * 55}")
|
||||
print(f" Output: {output_dir.resolve()}")
|
||||
print()
|
||||
print("Next step — update your config:")
|
||||
print(f' "data_dir": "{output_dir}"')
|
||||
print(f' remove "face_crop": true (images are already cropped)')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user