Final polish
This commit is contained in:
@@ -18,8 +18,9 @@ def parse_args():
|
||||
|
||||
|
||||
def iter_config_paths(config_root: Path):
|
||||
for sub in ("phase1", "phase2"):
|
||||
yield from sorted((config_root / sub).glob("*.json"))
|
||||
for sub in sorted(config_root.iterdir()):
|
||||
if sub.is_dir() and sub.name not in ("smoke",):
|
||||
yield from sorted(sub.glob("*.json"))
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
Download the DeepFakeFace dataset from HuggingFace and extract it.
|
||||
|
||||
Usage:
|
||||
python tools/download_data.py
|
||||
python tools/download_data.py --data-dir /mnt/data/DFF
|
||||
python tools/fetch_ds.py
|
||||
python tools/fetch_ds.py --data-dir /mnt/data/DFF
|
||||
"""
|
||||
import argparse
|
||||
import zipfile
|
||||
|
||||
@@ -19,9 +19,27 @@ from PIL import Image
|
||||
|
||||
from src.models import get_model, load_checkpoint
|
||||
from src.preprocessing import get_transforms
|
||||
from src.utils import load_config
|
||||
|
||||
|
||||
# Defaults checkpoint to outputs/models/{run_name}_best.pt when not supplied
|
||||
def _default_checkpoint(cfg: dict, checkpoint_path: Path | None) -> Path:
|
||||
"""Resolve checkpoint: explicit path, single-fold `*_best.pt`, or CV `*_fold{k}_best.pt` / `*_final.pt`."""
|
||||
run_name = cfg["run_name"]
|
||||
models_dir = ROOT / "outputs" / "models"
|
||||
if checkpoint_path is not None:
|
||||
return Path(checkpoint_path)
|
||||
candidates: list[Path] = [models_dir / f"{run_name}_best.pt"]
|
||||
for k in range(32):
|
||||
candidates.append(models_dir / f"{run_name}_fold{k}_best.pt")
|
||||
for k in range(32):
|
||||
candidates.append(models_dir / f"{run_name}_fold{k}_final.pt")
|
||||
for p in candidates:
|
||||
if p.is_file():
|
||||
return p
|
||||
return models_dir / f"{run_name}_best.pt"
|
||||
|
||||
|
||||
# Defaults checkpoint under outputs/models/ (single-run or CV best/final).
|
||||
def predict(image_path, config_path, checkpoint_path=None):
|
||||
image_path = Path(image_path)
|
||||
config_path = Path(config_path)
|
||||
@@ -35,10 +53,9 @@ def predict(image_path, config_path, checkpoint_path=None):
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
cfg = json.load(f)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error: Invalid JSON in config: {e}")
|
||||
cfg = load_config(str(config_path))
|
||||
except (json.JSONDecodeError, OSError, ValueError) as e:
|
||||
print(f"Error: Failed to load config: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
@@ -50,10 +67,7 @@ def predict(image_path, config_path, checkpoint_path=None):
|
||||
print(f"Error: Failed to build model: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if checkpoint_path is None:
|
||||
checkpoint_path = ROOT / "outputs" / "models" / f"{cfg['run_name']}_best.pt"
|
||||
else:
|
||||
checkpoint_path = Path(checkpoint_path)
|
||||
checkpoint_path = _default_checkpoint(cfg, Path(checkpoint_path) if checkpoint_path else None)
|
||||
|
||||
if not checkpoint_path.exists():
|
||||
print(f"Error: Checkpoint not found: {checkpoint_path}")
|
||||
|
||||
Reference in New Issue
Block a user