Final polish
This commit is contained in:
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"extends": "../phase0/_base_phase0.json",
|
||||
"run_name": "smoke",
|
||||
"model": "vae",
|
||||
"latent_dim": 32,
|
||||
"ngf": 16,
|
||||
"epochs": 1,
|
||||
"batch_size": 8,
|
||||
"lr": 0.001,
|
||||
"beta_kl": 0.1,
|
||||
"lambda_perceptual": 0,
|
||||
"lambda_adversarial": 0,
|
||||
"subsample": 1.0,
|
||||
"sample_interval": 1,
|
||||
"fid_interval": 999,
|
||||
"fid_n_real": 64,
|
||||
"num_workers": 0,
|
||||
"ema_decay": 0.9
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Smoke test: tiny wiki-only dataset -> 1 VAE epoch -> decode / sample (inference).
|
||||
|
||||
Run from the generator package root:
|
||||
|
||||
cd generator && conda activate drl && python -m unittest tests.smoke_test -v
|
||||
"""
|
||||
import io
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
_GEN_ROOT = Path(__file__).resolve().parents[1]
|
||||
_SMOKE_CFG = _GEN_ROOT / "configs" / "smoke" / "smoke.json"
|
||||
|
||||
|
||||
def _write_wiki_only(data_root: Path, *, n_folders: int = 12) -> None:
|
||||
"""Minimal layout for GeneratorDataset: data/wiki/<id>/*.jpg"""
|
||||
for i in range(n_folders):
|
||||
d = data_root / "wiki" / f"person_{i:03d}"
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
buf = io.BytesIO()
|
||||
Image.new("RGB", (64, 64), color=(30 + i * 15, 100, 180)).save(buf, format="JPEG")
|
||||
(d / "face.jpg").write_bytes(buf.getvalue())
|
||||
|
||||
|
||||
class SmokeGeneratorTrainSampleTests(unittest.TestCase):
|
||||
def test_vae_one_epoch_then_sample(self):
|
||||
import tempfile
|
||||
|
||||
if not _SMOKE_CFG.is_file():
|
||||
self.skipTest("smoke config missing")
|
||||
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
tmp = Path(td)
|
||||
data_dir = tmp / "data"
|
||||
out_root = tmp / "outputs"
|
||||
_write_wiki_only(data_dir)
|
||||
|
||||
sys.path.insert(0, str(_GEN_ROOT))
|
||||
import run as gen_run
|
||||
|
||||
# Avoid torch.compile in short smoke runs (simpler state_dict / fewer edge cases).
|
||||
with patch("torch.compile", lambda m, **kwargs: m):
|
||||
gen_run.main(
|
||||
str(_SMOKE_CFG),
|
||||
data_dir_override=str(data_dir),
|
||||
output_root=str(out_root),
|
||||
)
|
||||
|
||||
models_dir = out_root / "models"
|
||||
ckpt = models_dir / "smoke_final_ema.pt"
|
||||
self.assertTrue(ckpt.is_file(), f"Expected EMA checkpoint at {ckpt}")
|
||||
|
||||
from src.models import get_model
|
||||
from src.utils import load_config
|
||||
|
||||
cfg = load_config(str(_SMOKE_CFG))
|
||||
vae, kind = get_model(cfg)
|
||||
self.assertEqual(kind, "vae")
|
||||
|
||||
device = torch.device("cpu")
|
||||
state = torch.load(ckpt, map_location=device, weights_only=True)
|
||||
vae.load_state_dict(state, strict=True)
|
||||
vae.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
grid = vae.sample(4, device)
|
||||
self.assertEqual(grid.shape, (4, 3, 64, 64))
|
||||
self.assertTrue(torch.isfinite(grid).all())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -25,8 +25,8 @@ PHASE5_RUNS = {
|
||||
|
||||
|
||||
def _load_cfg(name: str) -> dict:
|
||||
with open(CFG_DIR / PHASE5_RUNS[name]["config"]) as f:
|
||||
return json.load(f)
|
||||
from src.utils import load_config
|
||||
return load_config(str(CFG_DIR / PHASE5_RUNS[name]["config"]))
|
||||
|
||||
|
||||
def _load_model(name: str, cfg: dict, device: torch.device):
|
||||
|
||||
Reference in New Issue
Block a user