Final polish

This commit is contained in:
Johnny Fernandes
2026-05-14 21:16:03 +01:00
parent 3bff7eefb0
commit afd26f47d2
732 changed files with 4149 additions and 79134 deletions
+23 -9
View File
@@ -19,9 +19,27 @@ from PIL import Image
from src.models import get_model, load_checkpoint
from src.preprocessing import get_transforms
from src.utils import load_config
# Defaults checkpoint to outputs/models/{run_name}_best.pt when not supplied
def _default_checkpoint(cfg: dict, checkpoint_path: Path | None) -> Path:
"""Resolve checkpoint: explicit path, single-fold `*_best.pt`, or CV `*_fold{k}_best.pt` / `*_final.pt`."""
run_name = cfg["run_name"]
models_dir = ROOT / "outputs" / "models"
if checkpoint_path is not None:
return Path(checkpoint_path)
candidates: list[Path] = [models_dir / f"{run_name}_best.pt"]
for k in range(32):
candidates.append(models_dir / f"{run_name}_fold{k}_best.pt")
for k in range(32):
candidates.append(models_dir / f"{run_name}_fold{k}_final.pt")
for p in candidates:
if p.is_file():
return p
return models_dir / f"{run_name}_best.pt"
# Defaults checkpoint under outputs/models/ (single-run or CV best/final).
def predict(image_path, config_path, checkpoint_path=None):
image_path = Path(image_path)
config_path = Path(config_path)
@@ -35,10 +53,9 @@ def predict(image_path, config_path, checkpoint_path=None):
sys.exit(1)
try:
with open(config_path) as f:
cfg = json.load(f)
except json.JSONDecodeError as e:
print(f"Error: Invalid JSON in config: {e}")
cfg = load_config(str(config_path))
except (json.JSONDecodeError, OSError, ValueError) as e:
print(f"Error: Failed to load config: {e}")
sys.exit(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -50,10 +67,7 @@ def predict(image_path, config_path, checkpoint_path=None):
print(f"Error: Failed to build model: {e}")
sys.exit(1)
if checkpoint_path is None:
checkpoint_path = ROOT / "outputs" / "models" / f"{cfg['run_name']}_best.pt"
else:
checkpoint_path = Path(checkpoint_path)
checkpoint_path = _default_checkpoint(cfg, Path(checkpoint_path) if checkpoint_path else None)
if not checkpoint_path.exists():
print(f"Error: Checkpoint not found: {checkpoint_path}")