Implementing inference sampling and quick tool polish
|
After Width: | Height: | Size: 110 KiB |
|
After Width: | Height: | Size: 109 KiB |
|
After Width: | Height: | Size: 108 KiB |
|
After Width: | Height: | Size: 110 KiB |
|
After Width: | Height: | Size: 110 KiB |
|
After Width: | Height: | Size: 110 KiB |
|
After Width: | Height: | Size: 112 KiB |
|
After Width: | Height: | Size: 112 KiB |
|
After Width: | Height: | Size: 111 KiB |
|
After Width: | Height: | Size: 112 KiB |
|
After Width: | Height: | Size: 112 KiB |
|
After Width: | Height: | Size: 110 KiB |
|
After Width: | Height: | Size: 111 KiB |
|
After Width: | Height: | Size: 112 KiB |
|
After Width: | Height: | Size: 108 KiB |
|
After Width: | Height: | Size: 110 KiB |
|
After Width: | Height: | Size: 108 KiB |
|
After Width: | Height: | Size: 112 KiB |
|
After Width: | Height: | Size: 109 KiB |
|
After Width: | Height: | Size: 106 KiB |
|
After Width: | Height: | Size: 110 KiB |
|
After Width: | Height: | Size: 110 KiB |
|
After Width: | Height: | Size: 112 KiB |
|
After Width: | Height: | Size: 111 KiB |
|
After Width: | Height: | Size: 112 KiB |
|
After Width: | Height: | Size: 113 KiB |
|
After Width: | Height: | Size: 111 KiB |
|
After Width: | Height: | Size: 110 KiB |
|
After Width: | Height: | Size: 107 KiB |
|
After Width: | Height: | Size: 113 KiB |
|
After Width: | Height: | Size: 113 KiB |
|
After Width: | Height: | Size: 112 KiB |
|
After Width: | Height: | Size: 114 KiB |
|
After Width: | Height: | Size: 112 KiB |
|
After Width: | Height: | Size: 114 KiB |
|
After Width: | Height: | Size: 112 KiB |
|
After Width: | Height: | Size: 112 KiB |
|
After Width: | Height: | Size: 112 KiB |
|
After Width: | Height: | Size: 114 KiB |
|
After Width: | Height: | Size: 111 KiB |
|
After Width: | Height: | Size: 111 KiB |
|
After Width: | Height: | Size: 110 KiB |
|
After Width: | Height: | Size: 96 KiB |
|
After Width: | Height: | Size: 98 KiB |
|
After Width: | Height: | Size: 101 KiB |
|
After Width: | Height: | Size: 97 KiB |
|
After Width: | Height: | Size: 102 KiB |
|
After Width: | Height: | Size: 101 KiB |
|
After Width: | Height: | Size: 100 KiB |
|
After Width: | Height: | Size: 101 KiB |
|
After Width: | Height: | Size: 100 KiB |
|
After Width: | Height: | Size: 104 KiB |
|
After Width: | Height: | Size: 96 KiB |
|
After Width: | Height: | Size: 103 KiB |
|
After Width: | Height: | Size: 100 KiB |
|
After Width: | Height: | Size: 97 KiB |
|
After Width: | Height: | Size: 99 KiB |
|
After Width: | Height: | Size: 101 KiB |
|
After Width: | Height: | Size: 101 KiB |
|
After Width: | Height: | Size: 94 KiB |
|
After Width: | Height: | Size: 97 KiB |
|
After Width: | Height: | Size: 100 KiB |
|
After Width: | Height: | Size: 97 KiB |
@@ -1,24 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
# Pre-align face images using MTCNN landmarks + similarity transform.
|
||||||
"""
|
# Generator-side counterpart to classifier/tools/facecrop.py — uses landmark-based alignment
|
||||||
Pre-align face images using MTCNN landmarks + similarity transform.
|
# (not bbox crop) so eyes, nose, and mouth land at fixed pixel positions in every image.
|
||||||
|
# Usage: python generator/tools/facecrop.py [--data-dir data] [--output-dir cropped/generator] [--size 128]
|
||||||
This is the generator-side counterpart to classifier/tools/facecrop.py.
|
|
||||||
Difference: classifier uses bbox crop+resize; the generator wants landmark-based
|
|
||||||
alignment so the eyes, nose, and mouth land at fixed pixel positions in every
|
|
||||||
training image — structurally consistent training data for the generator.
|
|
||||||
|
|
||||||
Output mirrors the source layout exactly:
|
|
||||||
data/wiki/14/37591914.jpg -> cropped/generator/wiki/14/37591914.jpg
|
|
||||||
|
|
||||||
Resumable: already-aligned images are skipped by default.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python generator/tools/facecrop.py
|
|
||||||
python generator/tools/facecrop.py --data-dir data --output-dir cropped/generator
|
|
||||||
python generator/tools/facecrop.py --sources wiki --device cpu
|
|
||||||
python generator/tools/facecrop.py --size 128
|
|
||||||
python generator/tools/facecrop.py --no-skip-existing # reprocess everything
|
|
||||||
"""
|
|
||||||
import argparse
|
import argparse
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
@@ -50,44 +33,37 @@ _DETECTORS: dict[str, object] = {}
|
|||||||
|
|
||||||
|
|
||||||
# Build and parse the command-line options of the face-alignment tool.
#   argv: optional list of argument strings. None (the default) lets argparse
#         read sys.argv[1:], so existing zero-argument callers are unaffected.
# Returns the argparse.Namespace with data_dir, output_dir, size, device,
# sources and skip_existing.
def parse_args(argv=None):
    p = argparse.ArgumentParser()
    p.add_argument("--data-dir", default="data", help="Source dataset root (default: data)")
    p.add_argument("--output-dir", default="cropped/generator", help="Output root (default: cropped/generator)")
    p.add_argument("--size", type=int, default=128, help="Output image size in px, square (default: 128)")
    p.add_argument("--device", default=None, help="'cpu' or 'cuda'. Default: auto-detect")
    p.add_argument("--sources", nargs="+", default=None, metavar="SOURCE",
                   help=f"Sources to process. Default: {', '.join(SOURCES)}. All: {', '.join(ALL_SOURCES)}")
    # Two flags share one destination: skipping is on unless --no-skip-existing is given.
    p.add_argument("--skip-existing", dest="skip_existing", action="store_true", default=True,
                   help="Skip images already present in output-dir (default: on)")
    p.add_argument("--no-skip-existing", dest="skip_existing", action="store_false",
                   help="Re-process all images even if already aligned")
    return p.parse_args(argv)
|
||||||
|
|
||||||
|
|
||||||
# ── alignment helpers ─────────────────────────────────────────────────────────
|
# ── Alignment helpers ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Scale the 128px reference template to match the target size.
#   size: edge length in pixels of the square output image.
# Returns a float32 array holding one scaled (x, y) pair per entry of
# REF_LANDMARKS_128.
def _ref_landmarks(size: int):
    import numpy as np

    # The template is defined for a 128px image, so 128 maps to a factor of 1.0.
    scale = size / 128.0
    return np.asarray([(x * scale, y * scale) for x, y in REF_LANDMARKS_128], dtype=np.float32)
|
|
||||||
|
|
||||||
|
|
||||||
|
# Apply similarity transform so detected landmarks map to reference positions
|
||||||
def _align_from_landmarks(img, landmarks, size: int):
|
def _align_from_landmarks(img, landmarks, size: int):
|
||||||
"""Apply similarity transform so detected landmarks map to reference positions."""
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
from skimage.transform import SimilarityTransform, warp
|
from skimage.transform import SimilarityTransform, warp
|
||||||
|
|
||||||
src = np.asarray(landmarks, dtype=np.float32) # 5x2 detected
|
src = np.asarray(landmarks, dtype=np.float32)
|
||||||
dst = _ref_landmarks(size) # 5x2 reference
|
dst = _ref_landmarks(size)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tform = SimilarityTransform.from_estimate(src, dst)
|
tform = SimilarityTransform.from_estimate(src, dst)
|
||||||
@@ -111,29 +87,22 @@ def _center_crop(img, size: int):
|
|||||||
return img.crop((left, top, left + side, top + side)).resize((size, size), PILImage.BILINEAR)
|
return img.crop((left, top, left + side, top + side)).resize((size, size), PILImage.BILINEAR)
|
||||||
|
|
||||||
|
|
||||||
|
# Returns (standard, relaxed) MTCNN detectors; cached per device.
#   device: device string passed to MTCNN; also the cache key in _DETECTORS.
# The relaxed detector is built with a smaller min_face_size (10 instead of 15)
# and an explicit thresholds list; every other argument is the same.
def _get_detectors(device: str):
    if device in _DETECTORS:
        return _DETECTORS[device]

    # Function-level import: facenet_pytorch is only touched on a cache miss.
    from facenet_pytorch import MTCNN

    standard = MTCNN(keep_all=False, select_largest=True, min_face_size=15,
                     device=device, post_process=False)
    relaxed = MTCNN(keep_all=False, select_largest=True, min_face_size=10,
                    thresholds=[0.5, 0.6, 0.6], device=device, post_process=False)
    _DETECTORS[device] = (standard, relaxed)
    return standard, relaxed
|
||||||
|
|
||||||
|
|
||||||
# ── main ──────────────────────────────────────────────────────────────────────
|
# ── Main ──────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
@@ -182,9 +151,7 @@ def main():
|
|||||||
print(f"\nTotal images: {len(all_paths):,}\n")
|
print(f"\nTotal images: {len(all_paths):,}\n")
|
||||||
|
|
||||||
n_processed = n_skipped = n_error = 0
|
n_processed = n_skipped = n_error = 0
|
||||||
src_stats: dict[str, dict] = {
|
src_stats: dict[str, dict] = {s: {"aligned": 0, "retry": 0, "fallback": 0} for s in sources}
|
||||||
s: {"aligned": 0, "retry": 0, "fallback": 0} for s in sources
|
|
||||||
}
|
|
||||||
|
|
||||||
for img_path in tqdm(all_paths, desc="Aligning", unit="img"):
|
for img_path in tqdm(all_paths, desc="Aligning", unit="img"):
|
||||||
rel = img_path.relative_to(data_dir)
|
rel = img_path.relative_to(data_dir)
|
||||||
|
|||||||
@@ -0,0 +1,120 @@
|
|||||||
|
# Generate 4x4 sample grids from Phase 5 EMA checkpoints, matching training visualization.
|
||||||
|
# Usage: python generator/tools/sampling.py [--samples N] [--models p5_gan p5_vae p5_ddpm]
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torchvision.utils import save_image
|
||||||
|
|
||||||
|
# Parent of the directory holding this file; added to sys.path so the
# function-level `src.*` imports in this module can be resolved.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

MODELS_DIR = ROOT / "outputs" / "models"
CFG_DIR = ROOT / "configs" / "phase5"

# Per-run registry: model type, config file name, and checkpoint file names in
# the order they are tried.
# final_ema first: best_ema is the lowest-FID snapshot, which for slowly-converging models
# (e.g. DDPM) can be saved while the EMA shadow is still close to random init.
PHASE5_RUNS = {
    "p5_gan": {
        "type": "gan",
        "config": "p5_gan.json",
        "checkpoints": ["p5_gan_final_ema.pt", "p5_gan_best_ema.pt"],
    },
    "p5_vae": {
        "type": "vae",
        "config": "p5_vae.json",
        "checkpoints": ["p5_vae_final_ema.pt", "p5_vae_best_ema.pt"],
    },
    "p5_ddpm": {
        "type": "ddpm",
        "config": "p5_ddpm.json",
        "checkpoints": ["p5_ddpm_final_ema.pt", "p5_ddpm_best_ema.pt"],
    },
}
|
||||||
|
|
||||||
|
|
||||||
|
# Read the JSON config registered for a Phase 5 run.
#   name: key into PHASE5_RUNS (e.g. "p5_gan"); an unknown key raises KeyError.
# Returns the parsed JSON document found at CFG_DIR / <config file name>.
def _load_cfg(name: str) -> dict:
    cfg_path = CFG_DIR / PHASE5_RUNS[name]["config"]
    # JSON text is UTF-8; do not let the platform's locale encoding decide.
    with open(cfg_path, encoding="utf-8") as f:
        return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
# Build the model described by cfg and load the first usable EMA checkpoint.
#   name:   key into PHASE5_RUNS; selects the checkpoint file names to try, in order.
#   cfg:    config dict handed to src.models.get_model.
#   device: target device; used as map_location and for the final .to().
# Returns the model moved to `device` and switched to eval mode.
# Raises FileNotFoundError when every candidate is absent or does not fit the model.
def _load_model(name: str, cfg: dict, device: torch.device):
    from src.models import get_model

    result, _ = get_model(cfg)
    # GAN returns (generator, critic); all others return the model directly
    model = result[0] if isinstance(result, tuple) else result

    for ckpt_name in PHASE5_RUNS[name]["checkpoints"]:
        ckpt_path = MODELS_DIR / ckpt_name
        if not ckpt_path.exists():
            print(f" [{name}] {ckpt_name} not found, trying next")
            continue

        state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)
        try:
            missing, unexpected = model.load_state_dict(state_dict, strict=False)
        except RuntimeError as err:
            # strict=False only tolerates missing/unexpected keys; a tensor-shape
            # mismatch still raises. Treat it like any other unusable checkpoint
            # so the fallback file gets its turn instead of aborting the run.
            print(f" [{name}] {ckpt_name}: incompatible ({err}) — trying next")
            continue
        if missing or unexpected:
            # A rejected checkpoint may have copied some tensors already; a later
            # checkpoint is only accepted when no key is missing, so it overwrites
            # every parameter and nothing from this one survives.
            print(f" [{name}] {ckpt_name}: missing={len(missing)} unexpected={len(unexpected)} — trying next")
            continue
        print(f" [{name}] Loaded {ckpt_name}")
        return model.to(device).eval()

    raise FileNotFoundError(f"No usable EMA checkpoint found for {name}")
|
||||||
|
|
||||||
|
|
||||||
|
# Returns a (16, C, H, W) tensor in [0, 1] ready for save_image with nrow=4
|
||||||
|
@torch.no_grad()
|
||||||
|
def _generate_grid(
|
||||||
|
name: str, model, cfg: dict, device: torch.device,
|
||||||
|
*, truncation: float | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
kind = PHASE5_RUNS[name]["type"]
|
||||||
|
image_size = cfg.get("image_size", 64)
|
||||||
|
|
||||||
|
if kind == "gan":
|
||||||
|
z = torch.randn(16, cfg.get("latent_dim", 128), 1, 1, device=device)
|
||||||
|
if truncation is not None and truncation > 0:
|
||||||
|
z = z.clamp(-truncation, truncation)
|
||||||
|
imgs = model(z)
|
||||||
|
elif kind == "vae":
|
||||||
|
imgs = model.sample(16, device)
|
||||||
|
elif kind == "ddpm":
|
||||||
|
from src.training.diffusion import cosine_betas, linear_betas, make_alpha_bars, ddim_sample
|
||||||
|
T = cfg.get("T", 1000)
|
||||||
|
betas = (cosine_betas(T) if cfg.get("noise_schedule", "cosine") == "cosine" else linear_betas(T)).to(device)
|
||||||
|
alpha_bars = make_alpha_bars(betas)
|
||||||
|
# n_steps=50 mirrors the training-time preview; cfg's ddim_steps is for FID only
|
||||||
|
imgs = ddim_sample(
|
||||||
|
model, 16, image_size, alpha_bars,
|
||||||
|
n_steps=50, pred_type=cfg.get("pred_type", "eps"),
|
||||||
|
device=str(device), batch_size=16,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown model type: {kind}")
|
||||||
|
|
||||||
|
return (imgs.clamp(-1, 1) + 1.0) / 2.0
|
||||||
|
|
||||||
|
|
||||||
|
# Command-line entry point: for every selected Phase 5 run, load its config and
# EMA checkpoint, then write `--samples` 4x4 grids as
# <output-dir>/<run name>/grid_0001.png, grid_0002.png, ...
def main():
    p = argparse.ArgumentParser()
    p.add_argument("--samples", type=int, default=10, help="Number of 4x4 grids per model")
    p.add_argument("--output-dir", type=Path,
                   default=ROOT / "outputs" / "samples" / "final_comparison")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--truncation", type=float, default=None,
                   help="Optional GAN latent truncation (lower=less diversity but sharper)")
    p.add_argument("--models", nargs="+", choices=list(PHASE5_RUNS.keys()),
                   default=list(PHASE5_RUNS.keys()))
    args = p.parse_args()
    device = torch.device(args.device)
    args.output_dir.mkdir(parents=True, exist_ok=True)

    for name in args.models:
        print(f"[{name}] Loading checkpoint...")
        cfg = _load_cfg(name)
        model = _load_model(name, cfg, device)

        # The parent was created above, so a plain mkdir is enough here.
        out_dir = args.output_dir / name
        out_dir.mkdir(exist_ok=True)

        print(f"[{name}] Generating {args.samples} grids -> {out_dir}")
        # File numbering is 1-based.
        for i in range(1, args.samples + 1):
            grid = _generate_grid(name, model, cfg, device, truncation=args.truncation)
            save_image(grid, out_dir / f"grid_{i:04d}.png", nrow=4)

    print("Done.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||