""" Run inference on a single image using a trained classifier. Usage: python tools/inference.py python tools/inference.py --checkpoint """ import argparse import json import sys from pathlib import Path ROOT = Path(__file__).resolve().parents[1] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) import torch from PIL import Image from src.models import get_model, load_checkpoint from src.preprocessing import get_transforms # Defaults checkpoint to outputs/models/{run_name}_best.pt when not supplied def predict(image_path, config_path, checkpoint_path=None): image_path = Path(image_path) config_path = Path(config_path) if not image_path.exists(): print(f"Error: Image not found: {image_path}") sys.exit(1) if not config_path.exists(): print(f"Error: Config not found: {config_path}") sys.exit(1) try: with open(config_path) as f: cfg = json.load(f) except json.JSONDecodeError as e: print(f"Error: Invalid JSON in config: {e}") sys.exit(1) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") try: # pretrained=False — we're loading a saved checkpoint, not ImageNet weights model = get_model({**cfg, "pretrained": False}) except Exception as e: print(f"Error: Failed to build model: {e}") sys.exit(1) if checkpoint_path is None: checkpoint_path = ROOT / "outputs" / "models" / f"{cfg['run_name']}_best.pt" else: checkpoint_path = Path(checkpoint_path) if not checkpoint_path.exists(): print(f"Error: Checkpoint not found: {checkpoint_path}") sys.exit(1) try: load_checkpoint(model, checkpoint_path, device) except Exception as e: print(f"Error: Failed to load checkpoint: {e}") sys.exit(1) model.eval().to(device) try: transform = get_transforms(train=False, image_size=cfg["image_size"]) image = Image.open(image_path).convert("RGB") tensor = transform(image).unsqueeze(0).to(device) except Exception as e: print(f"Error: Failed to load/preprocess image: {e}") sys.exit(1) with torch.no_grad(): logit = model(tensor).squeeze() prob = torch.sigmoid(logit).item() label = "FAKE" if prob >= 0.5 else "REAL" confidence = prob if prob >= 0.5 else 1 - prob print(f"Image : {image_path}") print(f"Model : {cfg['run_name']} ({cfg['backbone']})") print(f"Device: {device}") print(f"Result: {label} (confidence: {confidence:.1%})") print(f"P(fake): {prob:.4f} P(real): {1-prob:.4f}") if __name__ == "__main__": parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("image_path", help="Path to the input image") parser.add_argument("config_path", help="Path to the model config JSON") parser.add_argument("--checkpoint", help="Optional path to model checkpoint") args = parser.parse_args() predict(args.image_path, args.config_path, args.checkpoint)