Final polish
This commit is contained in:
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Smoke test: tiny wiki-only dataset -> 1 VAE epoch -> decode / sample (inference).
|
||||
|
||||
Run from the generator package root:
|
||||
|
||||
cd generator && conda activate drl && python -m unittest tests.smoke_test -v
|
||||
"""
|
||||
import io
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
_GEN_ROOT = Path(__file__).resolve().parents[1]
|
||||
_SMOKE_CFG = _GEN_ROOT / "configs" / "smoke" / "smoke.json"
|
||||
|
||||
|
||||
def _write_wiki_only(data_root: Path, *, n_folders: int = 12) -> None:
|
||||
"""Minimal layout for GeneratorDataset: data/wiki/<id>/*.jpg"""
|
||||
for i in range(n_folders):
|
||||
d = data_root / "wiki" / f"person_{i:03d}"
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
buf = io.BytesIO()
|
||||
Image.new("RGB", (64, 64), color=(30 + i * 15, 100, 180)).save(buf, format="JPEG")
|
||||
(d / "face.jpg").write_bytes(buf.getvalue())
|
||||
|
||||
|
||||
class SmokeGeneratorTrainSampleTests(unittest.TestCase):
|
||||
def test_vae_one_epoch_then_sample(self):
|
||||
import tempfile
|
||||
|
||||
if not _SMOKE_CFG.is_file():
|
||||
self.skipTest("smoke config missing")
|
||||
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
tmp = Path(td)
|
||||
data_dir = tmp / "data"
|
||||
out_root = tmp / "outputs"
|
||||
_write_wiki_only(data_dir)
|
||||
|
||||
sys.path.insert(0, str(_GEN_ROOT))
|
||||
import run as gen_run
|
||||
|
||||
# Avoid torch.compile in short smoke runs (simpler state_dict / fewer edge cases).
|
||||
with patch("torch.compile", lambda m, **kwargs: m):
|
||||
gen_run.main(
|
||||
str(_SMOKE_CFG),
|
||||
data_dir_override=str(data_dir),
|
||||
output_root=str(out_root),
|
||||
)
|
||||
|
||||
models_dir = out_root / "models"
|
||||
ckpt = models_dir / "smoke_final_ema.pt"
|
||||
self.assertTrue(ckpt.is_file(), f"Expected EMA checkpoint at {ckpt}")
|
||||
|
||||
from src.models import get_model
|
||||
from src.utils import load_config
|
||||
|
||||
cfg = load_config(str(_SMOKE_CFG))
|
||||
vae, kind = get_model(cfg)
|
||||
self.assertEqual(kind, "vae")
|
||||
|
||||
device = torch.device("cpu")
|
||||
state = torch.load(ckpt, map_location=device, weights_only=True)
|
||||
vae.load_state_dict(state, strict=True)
|
||||
vae.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
grid = vae.sample(4, device)
|
||||
self.assertEqual(grid.shape, (4, 3, 64, 64))
|
||||
self.assertTrue(torch.isfinite(grid).all())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user