Final polish

This commit is contained in:
Johnny Fernandes
2026-05-14 21:16:03 +01:00
parent 3bff7eefb0
commit afd26f47d2
732 changed files with 4149 additions and 79134 deletions
+90
View File
@@ -0,0 +1,90 @@
"""
End-to-end smoke test: synthetic DFF layout -> short CV train -> inference.
Run from the classifier package root:
cd classifier && conda activate drl && python -m unittest tests.smoke_test -v
"""
import io
import subprocess
import sys
import unittest
from pathlib import Path
from PIL import Image
# Classifier package root: the parent of the directory that holds this file.
_CLASSIFIER_ROOT = Path(__file__).resolve().parents[1]
# Config file that drives both the short training run and the inference call.
_SMOKE_CFG = _CLASSIFIER_ROOT.joinpath("configs", "smoke", "smoke.json")
def _write_synthetic_dff(data_root: Path, *, n_identities: int = 16) -> None:
    """Populate ``data_root`` with a tiny DeepFakeFace-style image tree.

    For every identity index ``i`` in ``range(n_identities)`` and every source
    (``wiki`` plus the three fake sources), one flat-colour 96x96 JPEG is
    written to ``<data_root>/<source>/person_NNN/idNNN.jpg``.

    Args:
        data_root: Directory under which the per-source folders are created
            (missing parents are created as needed).
        n_identities: Number of synthetic identities to generate.
    """
    source_names = ("wiki", "inpainting", "text2img", "insight")
    for idx in range(n_identities):
        tag = f"{idx:03d}"
        # Red channel varies per identity and is capped at the 8-bit maximum.
        red = min(20 + idx * 11, 255)
        for source in source_names:
            # Blue channel distinguishes the "wiki" source from the fake ones.
            if source == "wiki":
                blue = 120
            else:
                blue = 40

            person_dir = data_root / source / f"person_{tag}"
            person_dir.mkdir(parents=True, exist_ok=True)

            image = Image.new("RGB", (96, 96), color=(red, 80, blue))
            encoded = io.BytesIO()
            image.save(encoded, format="JPEG")

            # Same filename stem across all sources → one CV group per identity (matches DFF).
            person_dir.joinpath(f"id{tag}.jpg").write_bytes(encoded.getvalue())
class SmokeTrainInferTests(unittest.TestCase):
    """End-to-end smoke test: short training run followed by CLI inference."""

    def test_local_smoke_train_then_inference(self):
        """Train on synthetic data with the smoke config, then score one image.

        Steps:
          1. Write a synthetic dataset into a temporary directory.
          2. Call ``run.main`` in-process with the smoke config, CPU only,
             redirecting data and outputs into the temporary directory.
          3. Require a fold-0 checkpoint (``*_best.pt``, else ``*_final.pt``)
             under ``<outputs>/models``.
          4. Run ``tools/inference.py`` in a subprocess on one generated image
             and require exit code 0 and ``P(fake)`` in its stdout.
        """
        import tempfile

        with tempfile.TemporaryDirectory() as td:
            tmp = Path(td)
            data_dir = tmp / "data"
            out_root = tmp / "outputs"
            _write_synthetic_dff(data_dir)

            # ``run`` is imported from the classifier root, so that directory
            # must be importable. sys.path is interpreter-global state: undo
            # the insertion once the test finishes so it cannot leak into
            # other tests executed in the same process.
            root_entry = str(_CLASSIFIER_ROOT)
            sys.path.insert(0, root_entry)

            def _drop_root_entry() -> None:
                try:
                    sys.path.remove(root_entry)
                except ValueError:
                    # Entry already removed by someone else; nothing to undo.
                    pass

            self.addCleanup(_drop_root_entry)

            import run as train_run

            train_run.main(
                str(_SMOKE_CFG),
                data_dir_override=str(data_dir),
                output_root=str(out_root),
                use_gpu=False,
            )

            # Prefer the "best" checkpoint; fall back to the "final" one.
            models_dir = out_root / "models"
            ck_fold0 = models_dir / "smoke_fold0_best.pt"
            if not ck_fold0.is_file():
                ck_fold0 = models_dir / "smoke_fold0_final.pt"
            self.assertTrue(
                ck_fold0.is_file(),
                f"Expected fold-0 checkpoint under {models_dir}",
            )

            # Reuse one of the generated images as the inference probe.
            probe = tmp / "probe.jpg"
            probe.write_bytes((data_dir / "wiki" / "person_000" / "id000.jpg").read_bytes())

            cmd = [
                sys.executable,
                str(_CLASSIFIER_ROOT / "tools" / "inference.py"),
                str(probe),
                str(_SMOKE_CFG),
                "--checkpoint",
                str(ck_fold0),
            ]
            proc = subprocess.run(
                cmd,
                cwd=str(_CLASSIFIER_ROOT),
                capture_output=True,
                text=True,
                timeout=300,
            )
            # Include both streams in the failure message to ease debugging.
            self.assertEqual(proc.returncode, 0, proc.stderr + proc.stdout)
            self.assertIn("P(fake)", proc.stdout)
# Allow running this file directly as a script, not only via ``python -m unittest``.
if __name__ == "__main__":
    unittest.main()