Final polish
This commit is contained in:
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
End-to-end smoke test: synthetic DFF layout -> short CV train -> inference.
|
||||
|
||||
Run from the classifier package root:
|
||||
|
||||
cd classifier && conda activate drl && python -m unittest tests.smoke_test -v
|
||||
"""
|
||||
import io
|
||||
import subprocess
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
|
||||
_CLASSIFIER_ROOT = Path(__file__).resolve().parents[1]
|
||||
_SMOKE_CFG = _CLASSIFIER_ROOT / "configs" / "smoke" / "smoke.json"
|
||||
|
||||
|
||||
def _write_synthetic_dff(data_root: Path, *, n_identities: int = 16) -> None:
|
||||
"""Minimal DeepFakeFace-style tree: wiki + three fake sources, shared basenames per identity."""
|
||||
sources = ("wiki", "inpainting", "text2img", "insight")
|
||||
for i in range(n_identities):
|
||||
stem = f"id{i:03d}"
|
||||
for src in sources:
|
||||
d = data_root / src / f"person_{i:03d}"
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
buf = io.BytesIO()
|
||||
Image.new(
|
||||
"RGB",
|
||||
(96, 96),
|
||||
color=(min(20 + i * 11, 255), 80, 120 if src == "wiki" else 40),
|
||||
).save(buf, format="JPEG")
|
||||
# Same filename stem across all sources → one CV group per identity (matches DFF).
|
||||
(d / f"{stem}.jpg").write_bytes(buf.getvalue())
|
||||
|
||||
|
||||
class SmokeTrainInferTests(unittest.TestCase):
|
||||
def test_local_smoke_train_then_inference(self):
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
tmp = Path(td)
|
||||
data_dir = tmp / "data"
|
||||
out_root = tmp / "outputs"
|
||||
_write_synthetic_dff(data_dir)
|
||||
|
||||
sys.path.insert(0, str(_CLASSIFIER_ROOT))
|
||||
import run as train_run
|
||||
|
||||
train_run.main(
|
||||
str(_SMOKE_CFG),
|
||||
data_dir_override=str(data_dir),
|
||||
output_root=str(out_root),
|
||||
use_gpu=False,
|
||||
)
|
||||
|
||||
models_dir = out_root / "models"
|
||||
ck_fold0 = models_dir / "smoke_fold0_best.pt"
|
||||
if not ck_fold0.is_file():
|
||||
ck_fold0 = models_dir / "smoke_fold0_final.pt"
|
||||
self.assertTrue(
|
||||
ck_fold0.is_file(),
|
||||
f"Expected fold-0 checkpoint under {models_dir}",
|
||||
)
|
||||
|
||||
probe = tmp / "probe.jpg"
|
||||
probe.write_bytes((data_dir / "wiki" / "person_000" / "id000.jpg").read_bytes())
|
||||
|
||||
cmd = [
|
||||
sys.executable,
|
||||
str(_CLASSIFIER_ROOT / "tools" / "inference.py"),
|
||||
str(probe),
|
||||
str(_SMOKE_CFG),
|
||||
"--checkpoint",
|
||||
str(ck_fold0),
|
||||
]
|
||||
proc = subprocess.run(
|
||||
cmd,
|
||||
cwd=str(_CLASSIFIER_ROOT),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300,
|
||||
)
|
||||
self.assertEqual(proc.returncode, 0, proc.stderr + proc.stdout)
|
||||
self.assertIn("P(fake)", proc.stdout)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user