91 lines
2.9 KiB
Python
91 lines
2.9 KiB
Python
"""
End-to-end smoke test: synthetic DFF layout -> short CV train -> inference.

Run from the classifier package root:

    cd classifier && conda activate drl && python -m unittest tests.smoke_test -v
"""
|
|
import io
|
|
import subprocess
|
|
import sys
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
from PIL import Image
|
|
|
|
# Directory one level above the folder that contains this test file.
_CLASSIFIER_ROOT = Path(__file__).resolve().parent.parent
# Config file passed to both the training entry point and the inference script.
_SMOKE_CFG = _CLASSIFIER_ROOT.joinpath("configs", "smoke", "smoke.json")
|
|
|
|
|
|
def _write_synthetic_dff(data_root: Path, *, n_identities: int = 16) -> None:
    """Create a tiny DeepFakeFace-style image tree under ``data_root``.

    For every identity index ``i`` in ``range(n_identities)`` and every source
    (``wiki``, ``inpainting``, ``text2img``, ``insight``) one 96x96 solid-colour
    JPEG is written to ``<data_root>/<source>/person_<iii>/id<iii>.jpg``.

    Args:
        data_root: Directory under which the per-source folders are created.
        n_identities: Number of identities (one image per source each).
    """
    for idx in range(n_identities):
        stem = f"id{idx:03d}"
        # Red channel grows with the identity index and is capped at 255.
        red = min(20 + idx * 11, 255)
        for source in ("wiki", "inpainting", "text2img", "insight"):
            # Only the "wiki" source gets the brighter blue channel.
            if source == "wiki":
                blue = 120
            else:
                blue = 40

            person_dir = data_root / source / f"person_{idx:03d}"
            person_dir.mkdir(parents=True, exist_ok=True)

            # Encode the JPEG in memory, then write the bytes out in one go.
            with io.BytesIO() as jpeg_buf:
                picture = Image.new("RGB", (96, 96), color=(red, 80, blue))
                picture.save(jpeg_buf, format="JPEG")
                payload = jpeg_buf.getvalue()

            # Same filename stem across all sources → one CV group per identity (matches DFF).
            (person_dir / f"{stem}.jpg").write_bytes(payload)
|
|
|
|
|
|
class SmokeTrainInferTests(unittest.TestCase):
    """Smoke test: train on a synthetic dataset, then run the inference script."""

    @staticmethod
    def _remove_sys_path_entry(entry: str) -> None:
        """Remove the first occurrence of ``entry`` from ``sys.path`` if present."""
        try:
            sys.path.remove(entry)
        except ValueError:
            # Something else already removed it; nothing left to undo.
            pass

    def test_local_smoke_train_then_inference(self):
        """Run a short training in-process and check that inference succeeds.

        Steps:
          1. Write a synthetic image tree into a temporary directory.
          2. Call ``run.main`` with the smoke config, CPU only, writing outputs
             into the same temporary directory.
          3. Require a fold-0 checkpoint (``*_best.pt``, else ``*_final.pt``).
          4. Run ``tools/inference.py`` in a subprocess on one generated image
             and require exit code 0 and ``P(fake)`` in its stdout.
        """
        import tempfile

        with tempfile.TemporaryDirectory() as td:
            tmp = Path(td)
            data_dir = tmp / "data"
            out_root = tmp / "outputs"
            _write_synthetic_dff(data_dir)

            # ``run`` is a top-level module in the classifier root, so that
            # directory must be importable. Undo the sys.path change when the
            # test finishes so it does not leak into other tests that share
            # this interpreter.
            root_entry = str(_CLASSIFIER_ROOT)
            sys.path.insert(0, root_entry)
            self.addCleanup(self._remove_sys_path_entry, root_entry)
            import run as train_run

            train_run.main(
                str(_SMOKE_CFG),
                data_dir_override=str(data_dir),
                output_root=str(out_root),
                use_gpu=False,
            )

            # Prefer the "best" checkpoint; fall back to the "final" one.
            models_dir = out_root / "models"
            candidates = (
                models_dir / "smoke_fold0_best.pt",
                models_dir / "smoke_fold0_final.pt",
            )
            ck_fold0 = next(
                (path for path in candidates if path.is_file()),
                candidates[-1],
            )
            self.assertTrue(
                ck_fold0.is_file(),
                f"Expected fold-0 checkpoint under {models_dir}",
            )

            # Use one of the generated images as the inference input.
            probe = tmp / "probe.jpg"
            probe.write_bytes(
                (data_dir / "wiki" / "person_000" / "id000.jpg").read_bytes()
            )

            cmd = [
                sys.executable,
                str(_CLASSIFIER_ROOT / "tools" / "inference.py"),
                str(probe),
                str(_SMOKE_CFG),
                "--checkpoint",
                str(ck_fold0),
            ]
            proc = subprocess.run(
                cmd,
                cwd=str(_CLASSIFIER_ROOT),
                capture_output=True,
                text=True,
                timeout=300,
            )
            # Include both streams in the failure message to ease debugging.
            self.assertEqual(proc.returncode, 0, proc.stderr + proc.stdout)
            self.assertIn("P(fake)", proc.stdout)
|
|
|
|
|
|
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|