"""
Smoke test: tiny wiki-only dataset -> 1 VAE epoch -> decode / sample (inference).

Run from the generator package root:

    cd generator && conda activate drl && python -m unittest tests.smoke_test -v
"""

import io
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

import torch
from PIL import Image

# Generator package root: two levels up from this file (i.e. the parent of tests/).
_GEN_ROOT = Path(__file__).resolve().parent.parent
# Config consumed by the smoke training run.
_SMOKE_CFG = _GEN_ROOT.joinpath("configs", "smoke", "smoke.json")

def _write_wiki_only(
    data_root: Path, *, n_folders: int = 12, image_size: int = 64
) -> None:
    """Write a minimal GeneratorDataset layout: ``<data_root>/wiki/<id>/face.jpg``.

    Creates ``n_folders`` identity folders named ``person_000``, ``person_001``, ...
    Each folder holds one solid-colour RGB JPEG of ``image_size`` x ``image_size``
    pixels. The red channel varies per folder so the images are not all identical.

    Args:
        data_root: Directory to populate; it and any missing parents are created.
        n_folders: Number of identity folders to create (0 creates nothing).
        image_size: Side length, in pixels, of each square image.
    """
    wiki_root = data_root / "wiki"
    for i in range(n_folders):
        person_dir = wiki_root / f"person_{i:03d}"
        person_dir.mkdir(parents=True, exist_ok=True)
        # Wrap into 0..255 so every folder gets a valid channel value for any
        # n_folders. This is the same ramp as before for i <= 15.
        red = (30 + i * 15) % 256
        # Pillow writes directly to the path; no in-memory buffer is needed.
        Image.new("RGB", (image_size, image_size), color=(red, 100, 180)).save(
            person_dir / "face.jpg", format="JPEG"
        )

def _no_compile(model=None, **_options):
    """Pass-through stand-in for ``torch.compile``.

    Accepts the same call forms as ``torch.compile``:

    - ``torch.compile(m, ...)`` and ``torch.compile(model=m, ...)`` return ``m``
      unchanged.
    - ``torch.compile(...)`` with no model returns an identity decorator.

    Compile options are ignored because nothing is compiled.
    """
    if model is None:
        return lambda fn: fn
    return model


class SmokeGeneratorTrainSampleTests(unittest.TestCase):
    """End-to-end smoke test: train one VAE epoch on synthetic data, then sample."""

    def _put_generator_root_on_path(self) -> None:
        """Prepend the generator root to ``sys.path`` for the duration of this test.

        ``run`` and ``src.*`` are imported from the generator root. The entry is
        removed again in cleanup so the mutation does not leak into other tests.
        """
        gen_root = str(_GEN_ROOT)
        sys.path.insert(0, gen_root)
        self.addCleanup(sys.path.remove, gen_root)

    def test_vae_one_epoch_then_sample(self):
        if not _SMOKE_CFG.is_file():
            self.skipTest("smoke config missing")

        self._put_generator_root_on_path()

        with tempfile.TemporaryDirectory() as td:
            tmp = Path(td)
            data_dir = tmp / "data"
            out_root = tmp / "outputs"
            _write_wiki_only(data_dir)

            import run as gen_run

            # Avoid torch.compile in short smoke runs (simpler state_dict / fewer edge cases).
            with patch("torch.compile", _no_compile):
                gen_run.main(
                    str(_SMOKE_CFG),
                    data_dir_override=str(data_dir),
                    output_root=str(out_root),
                )

            ckpt = out_root / "models" / "smoke_final_ema.pt"
            if not ckpt.is_file():
                # List what the run did write, to make a failure diagnosable.
                produced = (
                    sorted(str(p.relative_to(out_root)) for p in out_root.rglob("*"))
                    if out_root.exists()
                    else []
                )
                self.fail(f"Expected EMA checkpoint at {ckpt}; run produced: {produced}")

            from src.models import get_model
            from src.utils import load_config

            vae, kind = get_model(load_config(str(_SMOKE_CFG)))
            self.assertEqual(kind, "vae")

            device = torch.device("cpu")
            # weights_only=True restricts unpickling to tensors and plain containers.
            # The result is fed straight to load_state_dict below.
            state = torch.load(ckpt, map_location=device, weights_only=True)
            vae.load_state_dict(state, strict=True)
            vae.eval()

            with torch.no_grad():
                grid = vae.sample(4, device)

            # NOTE(review): assumes smoke.json configures 3-channel 64x64 images,
            # matching the synthetic data written above.
            self.assertEqual(tuple(grid.shape), (4, 3, 64, 64))
            self.assertTrue(
                bool(torch.isfinite(grid).all()), "sampled images contain NaN/Inf"
            )

if __name__ == "__main__":
    # Allow running this file directly as a script, in addition to `python -m unittest`.
    unittest.main(module="__main__")