Final polish

This commit is contained in:
Johnny Fernandes
2026-05-14 21:16:03 +01:00
parent 3bff7eefb0
commit afd26f47d2
732 changed files with 4149 additions and 79134 deletions
+19
View File
@@ -0,0 +1,19 @@
{
"extends": "../phase0/_base_phase0.json",
"run_name": "smoke",
"model": "vae",
"latent_dim": 32,
"ngf": 16,
"epochs": 1,
"batch_size": 8,
"lr": 0.001,
"beta_kl": 0.1,
"lambda_perceptual": 0,
"lambda_adversarial": 0,
"subsample": 1.0,
"sample_interval": 1,
"fid_interval": 999,
"fid_n_real": 64,
"num_workers": 0,
"ema_decay": 0.9
}
View File
+78
View File
@@ -0,0 +1,78 @@
"""
Smoke test: tiny wiki-only dataset -> 1 VAE epoch -> decode / sample (inference).
Run from the generator package root:
cd generator && conda activate drl && python -m unittest tests.smoke_test -v
"""
import io
import sys
import unittest
from pathlib import Path
from unittest.mock import patch
import torch
from PIL import Image
_GEN_ROOT = Path(__file__).resolve().parents[1]
_SMOKE_CFG = _GEN_ROOT / "configs" / "smoke" / "smoke.json"
def _write_wiki_only(data_root: Path, *, n_folders: int = 12) -> None:
"""Minimal layout for GeneratorDataset: data/wiki/<id>/*.jpg"""
for i in range(n_folders):
d = data_root / "wiki" / f"person_{i:03d}"
d.mkdir(parents=True, exist_ok=True)
buf = io.BytesIO()
Image.new("RGB", (64, 64), color=(30 + i * 15, 100, 180)).save(buf, format="JPEG")
(d / "face.jpg").write_bytes(buf.getvalue())
class SmokeGeneratorTrainSampleTests(unittest.TestCase):
def test_vae_one_epoch_then_sample(self):
import tempfile
if not _SMOKE_CFG.is_file():
self.skipTest("smoke config missing")
with tempfile.TemporaryDirectory() as td:
tmp = Path(td)
data_dir = tmp / "data"
out_root = tmp / "outputs"
_write_wiki_only(data_dir)
sys.path.insert(0, str(_GEN_ROOT))
import run as gen_run
# Avoid torch.compile in short smoke runs (simpler state_dict / fewer edge cases).
with patch("torch.compile", lambda m, **kwargs: m):
gen_run.main(
str(_SMOKE_CFG),
data_dir_override=str(data_dir),
output_root=str(out_root),
)
models_dir = out_root / "models"
ckpt = models_dir / "smoke_final_ema.pt"
self.assertTrue(ckpt.is_file(), f"Expected EMA checkpoint at {ckpt}")
from src.models import get_model
from src.utils import load_config
cfg = load_config(str(_SMOKE_CFG))
vae, kind = get_model(cfg)
self.assertEqual(kind, "vae")
device = torch.device("cpu")
state = torch.load(ckpt, map_location=device, weights_only=True)
vae.load_state_dict(state, strict=True)
vae.eval()
with torch.no_grad():
grid = vae.sample(4, device)
self.assertEqual(grid.shape, (4, 3, 64, 64))
self.assertTrue(torch.isfinite(grid).all())
if __name__ == "__main__":
unittest.main()
+2 -2
View File
@@ -25,8 +25,8 @@ PHASE5_RUNS = {
def _load_cfg(name: str) -> dict:
with open(CFG_DIR / PHASE5_RUNS[name]["config"]) as f:
return json.load(f)
from src.utils import load_config
return load_config(str(CFG_DIR / PHASE5_RUNS[name]["config"]))
def _load_model(name: str, cfg: dict, device: torch.device):