Implementing inference sampling and quick tool polish
This commit is contained in:
@@ -0,0 +1,120 @@
|
||||
# Generate 4x4 sample grids from Phase 5 EMA checkpoints, matching training visualization.
|
||||
# Usage: python generator/tools/sampling.py [--samples N] [--models p5_gan p5_vae p5_ddpm]
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torchvision.utils import save_image
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
MODELS_DIR = ROOT / "outputs" / "models"
|
||||
CFG_DIR = ROOT / "configs" / "phase5"
|
||||
|
||||
# final_ema first: best_ema is the lowest-FID snapshot, which for slowly-converging models
|
||||
# (e.g. DDPM) can be saved while the EMA shadow is still close to random init.
|
||||
PHASE5_RUNS = {
|
||||
"p5_gan": {"type": "gan", "config": "p5_gan.json", "checkpoints": ["p5_gan_final_ema.pt", "p5_gan_best_ema.pt"]},
|
||||
"p5_vae": {"type": "vae", "config": "p5_vae.json", "checkpoints": ["p5_vae_final_ema.pt", "p5_vae_best_ema.pt"]},
|
||||
"p5_ddpm": {"type": "ddpm", "config": "p5_ddpm.json", "checkpoints": ["p5_ddpm_final_ema.pt", "p5_ddpm_best_ema.pt"]},
|
||||
}
|
||||
|
||||
|
||||
def _load_cfg(name: str) -> dict:
|
||||
with open(CFG_DIR / PHASE5_RUNS[name]["config"]) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def _load_model(name: str, cfg: dict, device: torch.device):
|
||||
from src.models import get_model
|
||||
|
||||
result, _ = get_model(cfg)
|
||||
# GAN returns (generator, critic); all others return the model directly
|
||||
model = result[0] if isinstance(result, tuple) else result
|
||||
|
||||
for ckpt_name in PHASE5_RUNS[name]["checkpoints"]:
|
||||
ckpt_path = MODELS_DIR / ckpt_name
|
||||
if not ckpt_path.exists():
|
||||
print(f" [{name}] {ckpt_name} not found, trying next")
|
||||
continue
|
||||
|
||||
state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)
|
||||
missing, unexpected = model.load_state_dict(state_dict, strict=False)
|
||||
if missing or unexpected:
|
||||
print(f" [{name}] {ckpt_name}: missing={len(missing)} unexpected={len(unexpected)} — trying next")
|
||||
continue
|
||||
print(f" [{name}] Loaded {ckpt_name}")
|
||||
return model.to(device).eval()
|
||||
|
||||
raise FileNotFoundError(f"No usable EMA checkpoint found for {name}")
|
||||
|
||||
|
||||
# Returns a (16, C, H, W) tensor in [0, 1] ready for save_image with nrow=4
|
||||
@torch.no_grad()
|
||||
def _generate_grid(
|
||||
name: str, model, cfg: dict, device: torch.device,
|
||||
*, truncation: float | None = None,
|
||||
) -> torch.Tensor:
|
||||
kind = PHASE5_RUNS[name]["type"]
|
||||
image_size = cfg.get("image_size", 64)
|
||||
|
||||
if kind == "gan":
|
||||
z = torch.randn(16, cfg.get("latent_dim", 128), 1, 1, device=device)
|
||||
if truncation is not None and truncation > 0:
|
||||
z = z.clamp(-truncation, truncation)
|
||||
imgs = model(z)
|
||||
elif kind == "vae":
|
||||
imgs = model.sample(16, device)
|
||||
elif kind == "ddpm":
|
||||
from src.training.diffusion import cosine_betas, linear_betas, make_alpha_bars, ddim_sample
|
||||
T = cfg.get("T", 1000)
|
||||
betas = (cosine_betas(T) if cfg.get("noise_schedule", "cosine") == "cosine" else linear_betas(T)).to(device)
|
||||
alpha_bars = make_alpha_bars(betas)
|
||||
# n_steps=50 mirrors the training-time preview; cfg's ddim_steps is for FID only
|
||||
imgs = ddim_sample(
|
||||
model, 16, image_size, alpha_bars,
|
||||
n_steps=50, pred_type=cfg.get("pred_type", "eps"),
|
||||
device=str(device), batch_size=16,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown model type: {kind}")
|
||||
|
||||
return (imgs.clamp(-1, 1) + 1.0) / 2.0
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--samples", type=int, default=10, help="Number of 4x4 grids per model")
|
||||
p.add_argument("--output-dir", type=Path,
|
||||
default=ROOT / "outputs" / "samples" / "final_comparison")
|
||||
p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
|
||||
p.add_argument("--truncation", type=float, default=None,
|
||||
help="Optional GAN latent truncation (lower=less diversity but sharper)")
|
||||
p.add_argument("--models", nargs="+", choices=list(PHASE5_RUNS.keys()),
|
||||
default=list(PHASE5_RUNS.keys()))
|
||||
args = p.parse_args()
|
||||
device = torch.device(args.device)
|
||||
args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for name in args.models:
|
||||
print(f"[{name}] Loading checkpoint...")
|
||||
cfg = _load_cfg(name)
|
||||
model = _load_model(name, cfg, device)
|
||||
|
||||
out_dir = args.output_dir / name
|
||||
out_dir.mkdir(exist_ok=True)
|
||||
|
||||
print(f"[{name}] Generating {args.samples} grids -> {out_dir}")
|
||||
for i in range(1, args.samples + 1):
|
||||
grid = _generate_grid(name, model, cfg, device, truncation=args.truncation)
|
||||
save_image(grid, out_dir / f"grid_{i:04d}.png", nrow=4)
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user