Phase 3 classifier

This commit is contained in:
Johnny Fernandes
2026-05-05 00:36:37 +01:00
parent 799ec0c13a
commit 66913b2354
10 changed files with 111 additions and 29 deletions
+8
View File
@@ -0,0 +1,8 @@
{
"pretrained": true,
"epochs": 15,
"image_size": 224,
"subsample": 0.2,
"augment": false,
"data_dir": "cropped/classifier"
}
@@ -0,0 +1,5 @@
{
"extends": "_base.json",
"run_name": "p3_convnext_tiny",
"backbone": "convnext_tiny"
}
@@ -0,0 +1,5 @@
{
"extends": "_base.json",
"run_name": "p3_efficientnet_b0",
"backbone": "efficientnet_b0"
}
@@ -0,0 +1,5 @@
{
"extends": "_base.json",
"run_name": "p3_mobilenetv3_small",
"backbone": "mobilenet_v3_small"
}
@@ -0,0 +1,5 @@
{
"extends": "_base.json",
"run_name": "p3_resnet34",
"backbone": "resnet34"
}
@@ -0,0 +1,5 @@
{
"extends": "_base.json",
"run_name": "p3_resnet50",
"backbone": "resnet50"
}
+1 -1
View File
@@ -30,6 +30,6 @@ def load_checkpoint(model: nn.Module, path: Union[Path, str], device) -> nn.Modu
# Importing the modules triggers their register() calls
from src.models import simple_cnn, resnet, efficientnet # noqa: E402, F401
from src.models import simple_cnn, resnet, efficientnet, mobilenet, convnext # noqa: E402, F401
__all__ = ["get_model", "load_checkpoint", "register"]
+27
View File
@@ -0,0 +1,27 @@
import torch.nn as nn
from torchvision import models
from src.models import register
# ConvNeXt's classification head is a Sequential ending in (LayerNorm2d, Flatten, Linear); [-1] targets the Linear
def build(cfg: dict) -> nn.Module:
backbone = cfg.get("backbone", "convnext_tiny")
pretrained = cfg.get("pretrained", True)
if backbone == "convnext_tiny":
weights = models.ConvNeXt_Tiny_Weights.DEFAULT if pretrained else None
model = models.convnext_tiny(weights=weights)
elif backbone == "convnext_small":
weights = models.ConvNeXt_Small_Weights.DEFAULT if pretrained else None
model = models.convnext_small(weights=weights)
else:
raise ValueError(f"Unsupported ConvNeXt backbone: {backbone!r}. Supported: convnext_tiny, convnext_small")
in_features = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(in_features, 1)
return model
register("convnext_tiny", build)
register("convnext_small", build)
+27
View File
@@ -0,0 +1,27 @@
import torch.nn as nn
from torchvision import models
from src.models import register
# MobileNetV3's classification head is a Sequential; [-1] targets the final Linear
def build(cfg: dict) -> nn.Module:
backbone = cfg.get("backbone", "mobilenet_v3_small")
pretrained = cfg.get("pretrained", True)
if backbone == "mobilenet_v3_small":
weights = models.MobileNet_V3_Small_Weights.DEFAULT if pretrained else None
model = models.mobilenet_v3_small(weights=weights)
elif backbone == "mobilenet_v3_large":
weights = models.MobileNet_V3_Large_Weights.DEFAULT if pretrained else None
model = models.mobilenet_v3_large(weights=weights)
else:
raise ValueError(f"Unsupported MobileNet backbone: {backbone!r}. Supported: mobilenet_v3_small, mobilenet_v3_large")
in_features = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(in_features, 1)
return model
register("mobilenet_v3_small", build)
register("mobilenet_v3_large", build)