"""
Smoke test: tiny wiki-only dataset -> 1 VAE epoch -> decode / sample (inference).

Run from the generator package root:

    cd generator && conda activate drl && python -m unittest tests.smoke_test -v
"""

import io
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

import torch
from PIL import Image

# Generator package root (the parent of this file's directory).
_GEN_ROOT = Path(__file__).resolve().parents[1]
# Config driving the smoke run; the test is skipped when it is absent.
_SMOKE_CFG = _GEN_ROOT / "configs" / "smoke" / "smoke.json"


def _write_wiki_only(data_root: Path, *, n_folders: int = 12) -> None:
    """Write a minimal layout for GeneratorDataset: ``data/wiki/<person>/*.jpg``.

    Creates ``n_folders`` directories ``<data_root>/wiki/person_NNN``, each
    holding one 64x64 solid-colour JPEG named ``face.jpg``.

    Args:
        data_root: Directory under which ``wiki/`` is created (need not exist).
        n_folders: Number of person folders (one image each) to create.
    """
    for i in range(n_folders):
        person_dir = data_root / "wiki" / f"person_{i:03d}"
        person_dir.mkdir(parents=True, exist_ok=True)
        # Vary the red channel per folder; wrap modulo 256 so that large
        # n_folders values stay inside the 8-bit channel range.
        red = (30 + i * 15) % 256
        buf = io.BytesIO()
        Image.new("RGB", (64, 64), color=(red, 100, 180)).save(buf, format="JPEG")
        (person_dir / "face.jpg").write_bytes(buf.getvalue())


def _no_compile(model=None, **_kwargs):
    """Stand-in for ``torch.compile`` that leaves the model uncompiled.

    Handles both ``torch.compile(model, ...)`` (returns ``model`` unchanged)
    and the decorator-factory form ``torch.compile(mode=...)`` called without
    a model (returns an identity decorator).
    """
    if model is None:
        return lambda fn: fn
    return model


class SmokeGeneratorTrainSampleTests(unittest.TestCase):
    """End-to-end smoke test: run the smoke config, then reload and sample."""

    def test_vae_one_epoch_then_sample(self):
        if not _SMOKE_CFG.is_file():
            self.skipTest("smoke config missing")

        # ``run`` and ``src`` are imported from the generator root. Prepend it
        # for this test only; the cleanup restores the original sys.path so
        # repeated runs do not accumulate duplicate entries.
        path_patcher = patch.object(sys, "path", [str(_GEN_ROOT), *sys.path])
        path_patcher.start()
        self.addCleanup(path_patcher.stop)

        with tempfile.TemporaryDirectory() as td:
            tmp = Path(td)
            data_dir = tmp / "data"
            out_root = tmp / "outputs"
            _write_wiki_only(data_dir)

            import run as gen_run

            # Avoid torch.compile in short smoke runs (simpler state_dict / fewer
            # edge cases). The stub stays active while the model is rebuilt, so
            # that if get_model() happens to compile it, the module still matches
            # the uncompiled state_dict keys written by the run.
            with patch("torch.compile", _no_compile):
                gen_run.main(
                    str(_SMOKE_CFG),
                    data_dir_override=str(data_dir),
                    output_root=str(out_root),
                )

                ckpt = out_root / "models" / "smoke_final_ema.pt"
                self.assertTrue(ckpt.is_file(), f"Expected EMA checkpoint at {ckpt}")

                from src.models import get_model
                from src.utils import load_config

                cfg = load_config(str(_SMOKE_CFG))
                vae, kind = get_model(cfg)
            self.assertEqual(kind, "vae")

            device = torch.device("cpu")
            # weights_only=True: the checkpoint is expected to be a plain state_dict.
            state = torch.load(ckpt, map_location=device, weights_only=True)
            vae.load_state_dict(state, strict=True)
            vae.eval()

            with torch.no_grad():
                grid = vae.sample(4, device)
            self.assertEqual(grid.shape, (4, 3, 64, 64))
            self.assertTrue(torch.isfinite(grid).all())


if __name__ == "__main__":
    unittest.main()