Files
DRL_PROJ/classifier/tools/inference.py
T
Johnny Fernandes afd26f47d2 Final polish
2026-05-14 21:16:03 +01:00

113 lines
3.7 KiB
Python

"""
Run inference on a single image using a trained classifier.
Usage:
python tools/inference.py <image_path> <config.json>
python tools/inference.py <image_path> <config.json> --checkpoint <path>
"""
import argparse
import json
import re
import sys
from pathlib import Path
# Make the project root (the parent of this script's directory) importable so
# the `src.*` imports below resolve when the script is run directly.
ROOT = Path(__file__).resolve().parent.parent
if not any(entry == str(ROOT) for entry in sys.path):
    sys.path.insert(0, str(ROOT))
import torch
from PIL import Image
from src.models import get_model, load_checkpoint
from src.preprocessing import get_transforms
from src.utils import load_config
def _default_checkpoint(cfg: dict, checkpoint_path: Path | None) -> Path:
"""Resolve checkpoint: explicit path, single-fold `*_best.pt`, or CV `*_fold{k}_best.pt` / `*_final.pt`."""
run_name = cfg["run_name"]
models_dir = ROOT / "outputs" / "models"
if checkpoint_path is not None:
return Path(checkpoint_path)
candidates: list[Path] = [models_dir / f"{run_name}_best.pt"]
for k in range(32):
candidates.append(models_dir / f"{run_name}_fold{k}_best.pt")
for k in range(32):
candidates.append(models_dir / f"{run_name}_fold{k}_final.pt")
for p in candidates:
if p.is_file():
return p
return models_dir / f"{run_name}_best.pt"
def _fail(message):
    """Print ``Error: <message>`` to stderr and exit with status 1."""
    print(f"Error: {message}", file=sys.stderr)
    sys.exit(1)


def predict(image_path, config_path, checkpoint_path=None):
    """Classify a single image as REAL or FAKE and print the result.

    Unless ``checkpoint_path`` is given, the checkpoint is looked up under
    ``outputs/models/`` from the config's ``run_name`` (single-run
    ``*_best.pt`` first, then CV fold best/final checkpoints).

    Args:
        image_path: Path to the input image; it is converted to RGB.
        config_path: Path to the run's JSON config. ``run_name``,
            ``backbone`` and ``image_size`` are read from it here.
        checkpoint_path: Optional explicit checkpoint path.

    Returns:
        dict with ``label`` ("FAKE" or "REAL"), ``prob_fake`` (sigmoid of the
        model's single logit) and ``confidence`` (probability of ``label``).

    On any failure to load the config, model, checkpoint or image, an error is
    printed to stderr and the process exits with status 1.
    """
    image_path = Path(image_path)
    config_path = Path(config_path)
    if not image_path.exists():
        _fail(f"Image not found: {image_path}")
    if not config_path.exists():
        _fail(f"Config not found: {config_path}")

    try:
        cfg = load_config(str(config_path))
    except (json.JSONDecodeError, OSError, ValueError) as e:
        _fail(f"Failed to load config: {e}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        # pretrained=False: we're loading a saved checkpoint, not ImageNet weights.
        model = get_model({**cfg, "pretrained": False})
    except Exception as e:
        _fail(f"Failed to build model: {e}")

    checkpoint_path = _default_checkpoint(
        cfg, Path(checkpoint_path) if checkpoint_path else None
    )
    if not checkpoint_path.exists():
        _fail(f"Checkpoint not found: {checkpoint_path}")
    try:
        load_checkpoint(model, checkpoint_path, device)
    except Exception as e:
        _fail(f"Failed to load checkpoint: {e}")
    model.eval().to(device)

    try:
        transform = get_transforms(train=False, image_size=cfg["image_size"])
        # PIL opens files lazily and keeps the handle open; the context
        # manager closes it once the RGB copy has been made.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        tensor = transform(image).unsqueeze(0).to(device)
    except Exception as e:
        _fail(f"Failed to load/preprocess image: {e}")

    with torch.no_grad():
        logit = model(tensor).squeeze()
        prob = torch.sigmoid(logit).item()

    # prob is P(fake); report the more likely class and its probability.
    label = "FAKE" if prob >= 0.5 else "REAL"
    confidence = prob if prob >= 0.5 else 1 - prob

    print(f"Image : {image_path}")
    print(f"Model : {cfg['run_name']} ({cfg['backbone']})")
    print(f"Device: {device}")
    print(f"Result: {label} (confidence: {confidence:.1%})")
    print(f"P(fake): {prob:.4f} P(real): {1-prob:.4f}")
    return {"label": label, "prob_fake": prob, "confidence": confidence}
if __name__ == "__main__":
    # RawDescriptionHelpFormatter keeps the module docstring's line breaks, so
    # the usage examples stay on separate lines in --help output instead of
    # being re-wrapped into a single paragraph.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("image_path", help="Path to the input image")
    parser.add_argument("config_path", help="Path to the model config JSON")
    parser.add_argument("--checkpoint", help="Optional path to model checkpoint")
    args = parser.parse_args()
    predict(args.image_path, args.config_path, args.checkpoint)