Clean state
This commit is contained in:
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Run inference on a single image using a trained classifier.
|
||||
|
||||
Usage:
|
||||
python tools/inference.py <image_path> <config.json>
|
||||
python tools/inference.py <image_path> <config.json> --checkpoint <path>
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from src.models import get_model, load_checkpoint
|
||||
from src.preprocessing import get_transforms
|
||||
|
||||
|
||||
# Defaults checkpoint to outputs/models/{run_name}_best.pt when not supplied
|
||||
def predict(image_path, config_path, checkpoint_path=None):
|
||||
image_path = Path(image_path)
|
||||
config_path = Path(config_path)
|
||||
|
||||
if not image_path.exists():
|
||||
print(f"Error: Image not found: {image_path}")
|
||||
sys.exit(1)
|
||||
|
||||
if not config_path.exists():
|
||||
print(f"Error: Config not found: {config_path}")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
cfg = json.load(f)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error: Invalid JSON in config: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
try:
|
||||
# pretrained=False — we're loading a saved checkpoint, not ImageNet weights
|
||||
model = get_model({**cfg, "pretrained": False})
|
||||
except Exception as e:
|
||||
print(f"Error: Failed to build model: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if checkpoint_path is None:
|
||||
checkpoint_path = ROOT / "outputs" / "models" / f"{cfg['run_name']}_best.pt"
|
||||
else:
|
||||
checkpoint_path = Path(checkpoint_path)
|
||||
|
||||
if not checkpoint_path.exists():
|
||||
print(f"Error: Checkpoint not found: {checkpoint_path}")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
load_checkpoint(model, checkpoint_path, device)
|
||||
except Exception as e:
|
||||
print(f"Error: Failed to load checkpoint: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
model.eval().to(device)
|
||||
|
||||
try:
|
||||
transform = get_transforms(train=False, image_size=cfg["image_size"])
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
tensor = transform(image).unsqueeze(0).to(device)
|
||||
except Exception as e:
|
||||
print(f"Error: Failed to load/preprocess image: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
with torch.no_grad():
|
||||
logit = model(tensor).squeeze()
|
||||
prob = torch.sigmoid(logit).item()
|
||||
|
||||
label = "FAKE" if prob >= 0.5 else "REAL"
|
||||
confidence = prob if prob >= 0.5 else 1 - prob
|
||||
|
||||
print(f"Image : {image_path}")
|
||||
print(f"Model : {cfg['run_name']} ({cfg['backbone']})")
|
||||
print(f"Device: {device}")
|
||||
print(f"Result: {label} (confidence: {confidence:.1%})")
|
||||
print(f"P(fake): {prob:.4f} P(real): {1-prob:.4f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("image_path", help="Path to the input image")
|
||||
parser.add_argument("config_path", help="Path to the model config JSON")
|
||||
parser.add_argument("--checkpoint", help="Optional path to model checkpoint")
|
||||
args = parser.parse_args()
|
||||
predict(args.image_path, args.config_path, args.checkpoint)
|
||||
Reference in New Issue
Block a user