Files
DRL_PROJ/generator/tests/smoke_test.py
T
Johnny Fernandes afd26f47d2 Final polish
2026-05-14 21:16:03 +01:00

79 lines
2.5 KiB
Python

"""
Smoke test: tiny wiki-only dataset -> 1 VAE epoch -> decode / sample (inference).
Run from the generator package root:
cd generator && conda activate drl && python -m unittest tests.smoke_test -v
"""
import io
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

import torch
from PIL import Image
# Generator package root: this file sits at <generator>/tests/smoke_test.py.
_GEN_ROOT = Path(__file__).resolve().parents[1]
# Dedicated tiny config consumed only by the smoke test below.
_SMOKE_CFG = _GEN_ROOT.joinpath("configs", "smoke", "smoke.json")
def _write_wiki_only(
    data_root: Path, *, n_folders: int = 12, image_size: int = 64
) -> None:
    """Write a minimal layout for GeneratorDataset: data/wiki/<id>/*.jpg.

    Creates ``n_folders`` folders named ``person_<NNN>`` under
    ``data_root / "wiki"``. Each folder gets one solid-colour RGB JPEG,
    ``face.jpg``, of ``image_size`` x ``image_size`` pixels. The red channel
    varies per folder, so the images are not all identical.

    Args:
        data_root: Directory that will contain the ``wiki`` sub-tree. It is
            created (with parents) if missing.
        n_folders: Number of identity folders to create.
        image_size: Side length in pixels of each square image.
    """
    wiki_root = data_root / "wiki"
    for i in range(n_folders):
        person_dir = wiki_root / f"person_{i:03d}"
        person_dir.mkdir(parents=True, exist_ok=True)
        # 30 + 15*i exceeds the 8-bit range (0..255) from i=16 onwards; wrap it
        # so every folder gets a valid channel value for any n_folders.
        red = (30 + i * 15) % 256
        image = Image.new("RGB", (image_size, image_size), color=(red, 100, 180))
        buf = io.BytesIO()
        image.save(buf, format="JPEG")
        (person_dir / "face.jpg").write_bytes(buf.getvalue())
class SmokeGeneratorTrainSampleTests(unittest.TestCase):
    """End-to-end smoke test of the generator: short VAE training, then sampling."""

    def test_vae_one_epoch_then_sample(self):
        """Train the smoke VAE on synthetic data, reload its EMA weights, sample.

        The test drives ``run.main`` with the smoke config, a throwaway
        wiki-only dataset and a throwaway output root. It then rebuilds the
        model from the same config, loads ``models/smoke_final_ema.pt``
        strictly on CPU, and checks that ``vae.sample(4, cpu)`` returns a
        finite ``(4, 3, 64, 64)`` tensor.
        """
        if not _SMOKE_CFG.is_file():
            self.skipTest("smoke config missing")

        # ``run`` and ``src`` are imported from the generator root. Prepend it
        # for the duration of this test only: the patcher restores the original
        # sys.path afterwards instead of leaking an extra entry on every run.
        path_patcher = patch.object(sys, "path", [str(_GEN_ROOT), *sys.path])
        path_patcher.start()
        self.addCleanup(path_patcher.stop)

        def _no_compile(model=None, **_kwargs):
            # Identity stand-in for torch.compile that accepts both call forms:
            # torch.compile(model, ...) / torch.compile(model=model, ...) return
            # the model unchanged; torch.compile(...) without a model returns an
            # identity decorator.
            if model is None:
                return lambda fn: fn
            return model

        with tempfile.TemporaryDirectory() as td:
            tmp = Path(td)
            data_dir = tmp / "data"
            out_root = tmp / "outputs"
            _write_wiki_only(data_dir)

            import run as gen_run

            # Avoid torch.compile in short smoke runs (simpler state_dict / fewer edge cases).
            with patch("torch.compile", _no_compile):
                gen_run.main(
                    str(_SMOKE_CFG),
                    data_dir_override=str(data_dir),
                    output_root=str(out_root),
                )

            ckpt = out_root / "models" / "smoke_final_ema.pt"
            self.assertTrue(ckpt.is_file(), f"Expected EMA checkpoint at {ckpt}")

            from src.models import get_model
            from src.utils import load_config

            cfg = load_config(str(_SMOKE_CFG))
            vae, kind = get_model(cfg)
            self.assertEqual(kind, "vae")

            device = torch.device("cpu")
            # The checkpoint must be read before the temporary directory is removed.
            state = torch.load(ckpt, map_location=device, weights_only=True)
            vae.load_state_dict(state, strict=True)
            vae.eval()
            with torch.no_grad():
                samples = vae.sample(4, device)

        self.assertEqual(tuple(samples.shape), (4, 3, 64, 64))
        self.assertTrue(
            bool(torch.isfinite(samples).all()),
            "vae.sample produced non-finite values",
        )
if __name__ == "__main__":
    # Executed as a script: discover and run the TestCase classes defined above.
    unittest.main(module="__main__")