"""
Run inference on a single image using a trained classifier.

Usage:
    python tools/inference.py <image_path> <config.json>
    python tools/inference.py <image_path> <config.json> --checkpoint <path>

Without --checkpoint, weights are loaded from
outputs/models/<run_name>_best.pt relative to the project root (the parent
of tools/), where run_name is read from the config.
"""

import argparse
import json
import sys
from pathlib import Path

# Project root: the parent of the directory containing this script (the
# script lives in tools/ per the usage text). It is prepended to sys.path so
# that the `src.*` imports below resolve when the file is run directly.
ROOT = Path(__file__).resolve().parent.parent
_root_entry = str(ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
del _root_entry  # keep the module namespace limited to ROOT

import torch  # noqa: E402 -- placed after the sys.path bootstrap
from PIL import Image  # noqa: E402

from src.models import get_model, load_checkpoint  # noqa: E402
from src.preprocessing import get_transforms  # noqa: E402

def _fail(message):
    """Print ``Error: <message>`` to stderr and exit the process with status 1."""
    print(f"Error: {message}", file=sys.stderr)
    sys.exit(1)


def _load_config(config_path):
    """Parse the JSON config at *config_path* and check the keys predict() reads.

    Exits via ``_fail`` if the file cannot be read or decoded, is not a JSON
    object, or lacks any of ``run_name``, ``backbone`` or ``image_size``.
    """
    try:
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(config_path, encoding="utf-8") as f:
            cfg = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        _fail(f"Invalid JSON in config: {e}")
    except OSError as e:
        _fail(f"Could not read config: {e}")

    if not isinstance(cfg, dict):
        _fail("Config must be a JSON object")
    # Validate up front so a bad config fails before any model work, rather
    # than with a KeyError traceback (for "backbone", only after inference).
    missing = [
        key for key in ("run_name", "backbone", "image_size") if key not in cfg
    ]
    if missing:
        _fail(f"Config is missing required key(s): {', '.join(missing)}")
    return cfg


def _resolve_checkpoint(cfg, checkpoint_path):
    """Return the checkpoint Path; default is ROOT/outputs/models/<run_name>_best.pt."""
    if checkpoint_path is None:
        return ROOT / "outputs" / "models" / f"{cfg['run_name']}_best.pt"
    return Path(checkpoint_path)


def _load_image_tensor(image_path, image_size, device):
    """Load *image_path* as RGB, apply the eval transforms, prepend a size-1 dim."""
    try:
        transform = get_transforms(train=False, image_size=image_size)
        # The context manager closes the file handle promptly; convert()
        # returns a fully loaded copy that outlives the `with` block.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        return transform(image).unsqueeze(0).to(device)
    except Exception as e:
        _fail(f"Failed to load/preprocess image: {e}")


def predict(image_path, config_path, checkpoint_path=None, *, threshold=0.5):
    """Classify one image as FAKE or REAL and print the result.

    Args:
        image_path: Path to the input image file.
        config_path: Path to the model's JSON config. It must be a JSON object
            with at least ``run_name``, ``backbone`` and ``image_size``; the
            whole dict (with ``pretrained`` forced to False) is passed to
            ``get_model``.
        checkpoint_path: Weights to load. When ``None``, defaults to
            ``ROOT/outputs/models/<run_name>_best.pt``.
        threshold: Minimum P(fake) for the image to be labelled FAKE
            (default 0.5, the original behaviour).

    Returns:
        dict with ``label`` ("FAKE" or "REAL"), ``prob_fake`` (sigmoid of the
        model output) and ``confidence`` (the probability of the predicted
        label).

    Exits the process with status 1 (message on stderr) if an input path is
    missing or invalid, or the model, checkpoint or image fails to load.
    """
    image_path = Path(image_path)
    config_path = Path(config_path)

    # is_file() rather than exists(): a directory would otherwise slip through.
    if not image_path.is_file():
        _fail(f"Image not found: {image_path}")
    if not config_path.is_file():
        _fail(f"Config not found: {config_path}")

    cfg = _load_config(config_path)

    # Resolve and check the checkpoint before building the model so a wrong
    # path fails fast.
    checkpoint_path = _resolve_checkpoint(cfg, checkpoint_path)
    if not checkpoint_path.exists():
        _fail(f"Checkpoint not found: {checkpoint_path}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        # pretrained=False — we're loading a saved checkpoint, not ImageNet weights
        model = get_model({**cfg, "pretrained": False})
    except Exception as e:
        _fail(f"Failed to build model: {e}")

    try:
        load_checkpoint(model, checkpoint_path, device)
    except Exception as e:
        _fail(f"Failed to load checkpoint: {e}")

    model.eval().to(device)

    tensor = _load_image_tensor(image_path, cfg["image_size"], device)

    with torch.no_grad():
        # Assumes the model emits a single logit per image (binary head);
        # .item() raises otherwise. TODO confirm against src.models.
        logit = model(tensor).squeeze()
        prob = torch.sigmoid(logit).item()

    is_fake = prob >= threshold
    label = "FAKE" if is_fake else "REAL"
    confidence = prob if is_fake else 1 - prob

    print(f"Image : {image_path}")
    print(f"Model : {cfg['run_name']} ({cfg['backbone']})")
    print(f"Device: {device}")
    print(f"Result: {label} (confidence: {confidence:.1%})")
    print(f"P(fake): {prob:.4f} P(real): {1-prob:.4f}")

    return {"label": label, "prob_fake": prob, "confidence": confidence}

def main(argv=None):
    """Parse command-line arguments and run ``predict`` on them.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's line breaks so the Usage lines stay
        # readable in --help (the default formatter re-wraps them).
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("image_path", help="Path to the input image")
    parser.add_argument("config_path", help="Path to the model config JSON")
    parser.add_argument("--checkpoint", help="Optional path to model checkpoint")
    args = parser.parse_args(argv)
    predict(args.image_path, args.config_path, args.checkpoint)


if __name__ == "__main__":
    main()