# Generate 4x4 sample grids from Phase 5 EMA checkpoints, matching training visualization.
# Usage: python generator/tools/sampling.py [--samples N] [--models p5_gan p5_vae p5_ddpm]

import argparse
import json
import sys
from pathlib import Path

import torch
from torchvision.utils import save_image

# Make the project root importable so the lazy `src.*` imports below resolve
# when this file is run directly as a script.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

MODELS_DIR = ROOT / "outputs" / "models"
CFG_DIR = ROOT / "configs" / "phase5"

# Every saved PNG is a square grid: _GRID_SIDE images per row, _GRID_IMAGES in total.
_GRID_SIDE = 4
_GRID_IMAGES = _GRID_SIDE * _GRID_SIDE

# n_steps=50 mirrors the training-time preview; cfg's ddim_steps is for FID only
_PREVIEW_DDIM_STEPS = 50

# final_ema first: best_ema is the lowest-FID snapshot, which for slowly-converging models
# (e.g. DDPM) can be saved while the EMA shadow is still close to random init.
PHASE5_RUNS = {
    "p5_gan": {
        "type": "gan",
        "config": "p5_gan.json",
        "checkpoints": ["p5_gan_final_ema.pt", "p5_gan_best_ema.pt"],
    },
    "p5_vae": {
        "type": "vae",
        "config": "p5_vae.json",
        "checkpoints": ["p5_vae_final_ema.pt", "p5_vae_best_ema.pt"],
    },
    "p5_ddpm": {
        "type": "ddpm",
        "config": "p5_ddpm.json",
        "checkpoints": ["p5_ddpm_final_ema.pt", "p5_ddpm_best_ema.pt"],
    },
}


def _load_cfg(name: str) -> dict:
    """Return the parsed JSON config registered for run *name*.

    Args:
        name: A key of ``PHASE5_RUNS``.

    Raises:
        KeyError: If *name* is not a key of ``PHASE5_RUNS``.

    Errors from opening or parsing the file propagate unchanged.
    """
    cfg_path = CFG_DIR / PHASE5_RUNS[name]["config"]
    # Explicit encoding: do not depend on the platform's locale default for JSON.
    with open(cfg_path, encoding="utf-8") as f:
        return json.load(f)


def _load_model(name: str, cfg: dict, device: torch.device):
    """Build the model for run *name* and load the first usable checkpoint.

    Candidates are tried in the order listed under
    ``PHASE5_RUNS[name]["checkpoints"]``. A candidate is skipped (with a
    message on stdout) when its file does not exist, when its state dict has
    missing or unexpected keys, or when ``load_state_dict`` raises
    ``RuntimeError``.

    Args:
        name: A key of ``PHASE5_RUNS``.
        cfg: Config dict passed to ``src.models.get_model``.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        The model, moved to *device* and switched to eval mode.

    Raises:
        FileNotFoundError: If no candidate checkpoint could be used.
    """
    # Imported lazily: the package only resolves after ROOT is put on sys.path.
    from src.models import get_model

    result, _ = get_model(cfg)
    # GAN returns (generator, critic); all others return the model directly
    model = result[0] if isinstance(result, tuple) else result

    candidates = PHASE5_RUNS[name]["checkpoints"]
    for ckpt_name in candidates:
        ckpt_path = MODELS_DIR / ckpt_name
        if not ckpt_path.exists():
            print(f" [{name}] {ckpt_name} not found, trying next")
            continue

        state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)
        try:
            missing, unexpected = model.load_state_dict(state_dict, strict=False)
        except RuntimeError as err:
            # strict=False only relaxes key matching; load_state_dict can still
            # raise (e.g. on tensor shape mismatch). Treat that like any other
            # incompatible checkpoint and fall through to the next candidate.
            detail = str(err).splitlines()
            reason = detail[0] if detail else type(err).__name__
            print(f" [{name}] {ckpt_name}: incompatible state dict ({reason}) — trying next")
            continue

        if missing or unexpected:
            print(f" [{name}] {ckpt_name}: missing={len(missing)} unexpected={len(unexpected)} — trying next")
            continue

        print(f" [{name}] Loaded {ckpt_name}")
        return model.to(device).eval()

    raise FileNotFoundError(
        f"No usable EMA checkpoint found for {name}"
        f" (tried {', '.join(candidates)} in {MODELS_DIR})"
    )


@torch.no_grad()
def _generate_grid(
    name: str,
    model,
    cfg: dict,
    device: torch.device,
    *,
    truncation: float | None = None,
) -> torch.Tensor:
    """Sample one batch of ``_GRID_IMAGES`` images from *model*.

    Args:
        name: A key of ``PHASE5_RUNS``; its ``"type"`` selects the sampling path.
        model: The loaded model (GAN generator, VAE, or DDPM network).
        cfg: Run config. Keys read (with defaults): ``image_size`` (64),
            ``latent_dim`` (128), ``T`` (1000), ``noise_schedule``
            (``"cosine"``), ``pred_type`` (``"eps"``).
        device: Device on which latents and schedules are created.
        truncation: GAN only. When not ``None`` and positive, the latent is
            clamped element-wise to ``[-truncation, truncation]``.

    Returns:
        The sampled images clamped to [-1, 1] and rescaled to [0, 1], ready
        for ``save_image`` with ``nrow=_GRID_SIDE``. Expected shape is
        (16, C, H, W); the actual shape is whatever the model/sampler returns.

    Raises:
        ValueError: If the run's ``"type"`` is not one of gan / vae / ddpm.
    """
    kind = PHASE5_RUNS[name]["type"]
    image_size = cfg.get("image_size", 64)

    if kind == "gan":
        z = torch.randn(_GRID_IMAGES, cfg.get("latent_dim", 128), 1, 1, device=device)
        if truncation is not None and truncation > 0:
            z = z.clamp(-truncation, truncation)
        imgs = model(z)
    elif kind == "vae":
        imgs = model.sample(_GRID_IMAGES, device)
    elif kind == "ddpm":
        from src.training.diffusion import cosine_betas, linear_betas, make_alpha_bars, ddim_sample

        T = cfg.get("T", 1000)
        make_betas = cosine_betas if cfg.get("noise_schedule", "cosine") == "cosine" else linear_betas
        betas = make_betas(T).to(device)
        alpha_bars = make_alpha_bars(betas)
        imgs = ddim_sample(
            model,
            _GRID_IMAGES,
            image_size,
            alpha_bars,
            n_steps=_PREVIEW_DDIM_STEPS,
            pred_type=cfg.get("pred_type", "eps"),
            device=str(device),
            batch_size=_GRID_IMAGES,
        )
    else:
        raise ValueError(f"Unknown model type: {kind}")

    # Map from [-1, 1] to the [0, 1] range that save_image writes out.
    return (imgs.clamp(-1, 1) + 1.0) / 2.0


def main() -> None:
    """Parse CLI arguments and write sample grids for each selected run.

    For every run in ``--models`` this loads its config and checkpoint, then
    writes ``--samples`` PNG grids named ``grid_NNNN.png`` into
    ``<output-dir>/<run name>/``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--samples", type=int, default=10, help="Number of 4x4 grids per model")
    p.add_argument("--output-dir", type=Path, default=ROOT / "outputs" / "samples" / "final_comparison")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument(
        "--truncation",
        type=float,
        default=None,
        help="Optional GAN latent truncation (lower=less diversity but sharper)",
    )
    p.add_argument(
        "--models",
        nargs="+",
        choices=list(PHASE5_RUNS.keys()),
        default=list(PHASE5_RUNS.keys()),
    )
    args = p.parse_args()

    device = torch.device(args.device)
    args.output_dir.mkdir(parents=True, exist_ok=True)

    for name in args.models:
        print(f"[{name}] Loading checkpoint...")
        cfg = _load_cfg(name)
        model = _load_model(name, cfg, device)

        out_dir = args.output_dir / name
        out_dir.mkdir(exist_ok=True)

        print(f"[{name}] Generating {args.samples} grids -> {out_dir}")
        for i in range(1, args.samples + 1):
            grid = _generate_grid(name, model, cfg, device, truncation=args.truncation)
            save_image(grid, out_dir / f"grid_{i:04d}.png", nrow=_GRID_SIDE)

    print("Done.")


if __name__ == "__main__":
    main()