Files
DRL_PROJ/classifier/tools/inference.py
T
Johnny Fernandes bb3dfb92d5 Clean state
2026-04-30 01:25:39 +01:00

99 lines
3.1 KiB
Python

"""
Run inference on a single image using a trained classifier.
Usage:
python tools/inference.py <image_path> <config.json>
python tools/inference.py <image_path> <config.json> --checkpoint <path>
"""
import argparse
import json
import sys
from pathlib import Path
# Repository root: tools/inference.py -> tools/ -> <root>.
ROOT = Path(__file__).resolve().parent.parent
# When run as `python tools/inference.py ...`, Python only puts tools/ on
# sys.path; prepend the root so the local `src` package is importable.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
import torch
from PIL import Image
from src.models import get_model, load_checkpoint
from src.preprocessing import get_transforms
def _fail(message):
    """Print ``Error: <message>`` to stderr and exit with status 1 (never returns)."""
    print(f"Error: {message}", file=sys.stderr)
    sys.exit(1)


def _load_config(config_path):
    """Load the JSON config at ``config_path`` and check the keys this script reads.

    Exits via ``_fail`` if the file cannot be read, is not valid UTF-8 JSON,
    is not a JSON object, or lacks ``run_name``, ``backbone`` or ``image_size``.
    """
    try:
        with open(config_path, encoding="utf-8") as f:
            cfg = json.load(f)
    except OSError as e:
        # e.g. the path exists but is a directory or is not readable.
        _fail(f"Could not read config: {e}")
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        _fail(f"Invalid JSON in config: {e}")
    if not isinstance(cfg, dict):
        _fail(f"Config must be a JSON object, got {type(cfg).__name__}")
    # These keys are read directly by predict(); checking them up front avoids a
    # KeyError traceback after the model has already been built and run.
    missing = [key for key in ("run_name", "backbone", "image_size") if key not in cfg]
    if missing:
        _fail(f"Config is missing required key(s): {', '.join(missing)}")
    return cfg


def _load_image_tensor(image_path, image_size, device):
    """Open ``image_path`` as RGB, apply the eval transforms and add a batch dim.

    Returns the batched tensor on ``device``; exits via ``_fail`` on any error.
    """
    try:
        transform = get_transforms(train=False, image_size=image_size)
        # The context manager closes the underlying file handle; convert()
        # returns a copy that stays usable after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        return transform(image).unsqueeze(0).to(device)
    except Exception as e:
        _fail(f"Failed to load/preprocess image: {e}")


def predict(image_path, config_path, checkpoint_path=None):
    """Classify a single image as REAL or FAKE with a trained binary classifier.

    Prints a short report to stdout. On any error, prints ``Error: ...`` to
    stderr and exits the process with status 1.

    Args:
        image_path: Path to the input image (str or Path).
        config_path: Path to the model config JSON. It must contain
            ``run_name``, ``backbone`` and ``image_size``. The whole config
            (with ``pretrained`` forced to False) is passed to ``get_model``.
        checkpoint_path: Optional checkpoint path. Defaults to
            ``ROOT/outputs/models/{run_name}_best.pt`` when not supplied.

    Returns:
        dict with ``label`` ("FAKE" or "REAL"), ``prob_fake`` (sigmoid of the
        model's single output logit) and ``confidence`` (probability assigned
        to the predicted label).
    """
    image_path = Path(image_path)
    config_path = Path(config_path)
    if not image_path.exists():
        _fail(f"Image not found: {image_path}")
    if not config_path.exists():
        _fail(f"Config not found: {config_path}")
    cfg = _load_config(config_path)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        # pretrained=False — we're loading a saved checkpoint, not ImageNet weights
        model = get_model({**cfg, "pretrained": False})
    except Exception as e:
        _fail(f"Failed to build model: {e}")

    if checkpoint_path is None:
        checkpoint_path = ROOT / "outputs" / "models" / f"{cfg['run_name']}_best.pt"
    else:
        checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        _fail(f"Checkpoint not found: {checkpoint_path}")
    try:
        load_checkpoint(model, checkpoint_path, device)
    except Exception as e:
        _fail(f"Failed to load checkpoint: {e}")
    model.eval().to(device)

    tensor = _load_image_tensor(image_path, cfg["image_size"], device)
    with torch.no_grad():
        output = model(tensor)
    # The FAKE/REAL decision below assumes a single-logit binary head; any other
    # output size would otherwise make .item() fail with an unhelpful message.
    if output.numel() != 1:
        _fail(
            "Expected a single logit from the model, "
            f"got output of shape {tuple(output.shape)}"
        )
    prob = torch.sigmoid(output).item()
    label = "FAKE" if prob >= 0.5 else "REAL"
    confidence = prob if prob >= 0.5 else 1 - prob

    print(f"Image : {image_path}")
    print(f"Model : {cfg['run_name']} ({cfg['backbone']})")
    print(f"Device: {device}")
    print(f"Result: {label} (confidence: {confidence:.1%})")
    print(f"P(fake): {prob:.4f} P(real): {1-prob:.4f}")
    return {"label": label, "prob_fake": prob, "confidence": confidence}
if __name__ == "__main__":
    # RawDescriptionHelpFormatter keeps the module docstring's line breaks, so the
    # usage examples render one per line in --help instead of being re-wrapped.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("image_path", help="Path to the input image")
    parser.add_argument("config_path", help="Path to the model config JSON")
    parser.add_argument(
        "--checkpoint",
        help="Optional path to model checkpoint "
        "(default: outputs/models/<run_name>_best.pt under the project root)",
    )
    args = parser.parse_args()
    predict(args.image_path, args.config_path, args.checkpoint)