Final polish
This commit is contained in:
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"extends": "../shared.json",
|
||||
"run_name": "smoke",
|
||||
"backbone": "simple_cnn",
|
||||
"cnn_preset": "micro",
|
||||
"dropout": 0.0,
|
||||
"epochs": 1,
|
||||
"cv_folds": 2,
|
||||
"image_size": 64,
|
||||
"batch_size": 8,
|
||||
"num_workers": 0,
|
||||
"early_stopping_patience": 0,
|
||||
"subsample": 1.0,
|
||||
"augment": false,
|
||||
"lr": 0.001,
|
||||
"T_max": 1,
|
||||
"data_dir": "data"
|
||||
}
|
||||
+912
-924
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
End-to-end smoke test: synthetic DFF layout -> short CV train -> inference.
|
||||
|
||||
Run from the classifier package root:
|
||||
|
||||
cd classifier && conda activate drl && python -m unittest tests.smoke_test -v
|
||||
"""
|
||||
import io
|
||||
import subprocess
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
|
||||
_CLASSIFIER_ROOT = Path(__file__).resolve().parents[1]
|
||||
_SMOKE_CFG = _CLASSIFIER_ROOT / "configs" / "smoke" / "smoke.json"
|
||||
|
||||
|
||||
def _write_synthetic_dff(data_root: Path, *, n_identities: int = 16) -> None:
    """Populate ``data_root`` with a tiny synthetic DeepFakeFace-style tree.

    For each identity index ``i`` in ``range(n_identities)`` a solid-colour
    96x96 JPEG named ``id{i:03d}.jpg`` is written to
    ``<data_root>/<source>/person_{i:03d}/`` for each of the four sources:
    ``wiki`` plus the three fake sources ``inpainting``, ``text2img`` and
    ``insight``.

    Args:
        data_root: Directory under which the per-source folders are created.
        n_identities: Number of synthetic identities to generate.
    """
    for idx in range(n_identities):
        # The file name is identical across all sources so that, per the
        # original author's note, each identity forms one CV group (as in DFF).
        filename = f"id{idx:03d}.jpg"
        person_folder = f"person_{idx:03d}"
        red = min(20 + idx * 11, 255)

        for source in ("wiki", "inpainting", "text2img", "insight"):
            target_dir = data_root / source / person_folder
            target_dir.mkdir(parents=True, exist_ok=True)

            # Only the blue channel distinguishes the "wiki" source from the rest.
            blue = 120 if source == "wiki" else 40
            picture = Image.new("RGB", (96, 96), color=(red, 80, blue))

            encoded = io.BytesIO()
            picture.save(encoded, format="JPEG")
            (target_dir / filename).write_bytes(encoded.getvalue())
|
||||
|
||||
|
||||
class SmokeTrainInferTests(unittest.TestCase):
    """End-to-end check: train on synthetic data, then run the inference script."""

    @staticmethod
    def _drop_sys_path_entry(entry: str) -> None:
        """Remove one occurrence of ``entry`` from ``sys.path`` if it is still there."""
        try:
            sys.path.remove(entry)
        except ValueError:
            # Something else already removed it; nothing left to undo.
            pass

    def test_local_smoke_train_then_inference(self):
        """Run ``run.main`` with the smoke config, then ``tools/inference.py`` on a probe image.

        Passes when a fold-0 checkpoint (``*_best.pt``, falling back to
        ``*_final.pt``) exists under ``<output_root>/models`` and the inference
        subprocess exits with status 0 and prints ``P(fake)`` on stdout.
        """
        import tempfile

        with tempfile.TemporaryDirectory() as td:
            tmp = Path(td)
            data_dir = tmp / "data"
            out_root = tmp / "outputs"
            _write_synthetic_dff(data_dir)

            # ``run`` is imported from the classifier root, so that directory is
            # put at the front of sys.path. The cleanup undoes the mutation once
            # the test finishes so it does not leak into later tests.
            root_entry = str(_CLASSIFIER_ROOT)
            sys.path.insert(0, root_entry)
            self.addCleanup(self._drop_sys_path_entry, root_entry)
            import run as train_run

            train_run.main(
                str(_SMOKE_CFG),
                data_dir_override=str(data_dir),
                output_root=str(out_root),
                use_gpu=False,
            )

            models_dir = out_root / "models"
            ck_fold0 = models_dir / "smoke_fold0_best.pt"
            if not ck_fold0.is_file():
                # Fall back to the "final" checkpoint name when no "best" one exists.
                ck_fold0 = models_dir / "smoke_fold0_final.pt"
            self.assertTrue(
                ck_fold0.is_file(),
                f"Expected fold-0 checkpoint under {models_dir}",
            )

            # Reuse one of the synthetic training images as the inference input.
            probe = tmp / "probe.jpg"
            probe.write_bytes((data_dir / "wiki" / "person_000" / "id000.jpg").read_bytes())

            cmd = [
                sys.executable,
                str(_CLASSIFIER_ROOT / "tools" / "inference.py"),
                str(probe),
                str(_SMOKE_CFG),
                "--checkpoint",
                str(ck_fold0),
            ]
            proc = subprocess.run(
                cmd,
                cwd=str(_CLASSIFIER_ROOT),
                capture_output=True,
                text=True,
                timeout=300,
            )
            # Include both streams in the failure message to ease debugging.
            self.assertEqual(proc.returncode, 0, proc.stderr + proc.stdout)
            self.assertIn("P(fake)", proc.stdout)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -18,8 +18,9 @@ def parse_args():
|
||||
|
||||
|
||||
def iter_config_paths(config_root: Path):
    """Yield each ``*.json`` file found directly inside the sub-directories of ``config_root``.

    Sub-directories are visited in sorted order and the files inside each one
    are yielded in sorted order. The ``smoke`` directory is skipped, and
    non-directory entries of ``config_root`` are ignored.

    A single scan replaces the former hard-coded ``phase1``/``phase2`` loop;
    keeping both would yield every phase config twice.

    Args:
        config_root: Directory whose immediate sub-directories hold the configs.

    Yields:
        Path objects for the matching JSON files.
    """
    excluded = {"smoke"}
    for sub in sorted(config_root.iterdir()):
        if not sub.is_dir() or sub.name in excluded:
            continue
        yield from sorted(sub.glob("*.json"))
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
Download the DeepFakeFace dataset from HuggingFace and extract it.
|
||||
|
||||
Usage:
|
||||
python tools/download_data.py
|
||||
python tools/download_data.py --data-dir /mnt/data/DFF
|
||||
python tools/fetch_ds.py
|
||||
python tools/fetch_ds.py --data-dir /mnt/data/DFF
|
||||
"""
|
||||
import argparse
|
||||
import zipfile
|
||||
|
||||
@@ -19,9 +19,27 @@ from PIL import Image
|
||||
|
||||
from src.models import get_model, load_checkpoint
|
||||
from src.preprocessing import get_transforms
|
||||
from src.utils import load_config
|
||||
|
||||
|
||||
# Defaults checkpoint to outputs/models/{run_name}_best.pt when not supplied
|
||||
def _default_checkpoint(cfg: dict, checkpoint_path: Path | None) -> Path:
|
||||
"""Resolve checkpoint: explicit path, single-fold `*_best.pt`, or CV `*_fold{k}_best.pt` / `*_final.pt`."""
|
||||
run_name = cfg["run_name"]
|
||||
models_dir = ROOT / "outputs" / "models"
|
||||
if checkpoint_path is not None:
|
||||
return Path(checkpoint_path)
|
||||
candidates: list[Path] = [models_dir / f"{run_name}_best.pt"]
|
||||
for k in range(32):
|
||||
candidates.append(models_dir / f"{run_name}_fold{k}_best.pt")
|
||||
for k in range(32):
|
||||
candidates.append(models_dir / f"{run_name}_fold{k}_final.pt")
|
||||
for p in candidates:
|
||||
if p.is_file():
|
||||
return p
|
||||
return models_dir / f"{run_name}_best.pt"
|
||||
|
||||
|
||||
# Defaults checkpoint under outputs/models/ (single-run or CV best/final).
|
||||
def predict(image_path, config_path, checkpoint_path=None):
|
||||
image_path = Path(image_path)
|
||||
config_path = Path(config_path)
|
||||
@@ -35,10 +53,9 @@ def predict(image_path, config_path, checkpoint_path=None):
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
cfg = json.load(f)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error: Invalid JSON in config: {e}")
|
||||
cfg = load_config(str(config_path))
|
||||
except (json.JSONDecodeError, OSError, ValueError) as e:
|
||||
print(f"Error: Failed to load config: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
@@ -50,10 +67,7 @@ def predict(image_path, config_path, checkpoint_path=None):
|
||||
print(f"Error: Failed to build model: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if checkpoint_path is None:
|
||||
checkpoint_path = ROOT / "outputs" / "models" / f"{cfg['run_name']}_best.pt"
|
||||
else:
|
||||
checkpoint_path = Path(checkpoint_path)
|
||||
checkpoint_path = _default_checkpoint(cfg, Path(checkpoint_path) if checkpoint_path else None)
|
||||
|
||||
if not checkpoint_path.exists():
|
||||
print(f"Error: Checkpoint not found: {checkpoint_path}")
|
||||
|
||||
Reference in New Issue
Block a user