Phase 3 classifier
This commit is contained in:
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"pretrained": true,
|
||||
"epochs": 15,
|
||||
"image_size": 224,
|
||||
"subsample": 0.2,
|
||||
"augment": false,
|
||||
"data_dir": "cropped/classifier"
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"extends": "p3_base.json",
|
||||
"run_name": "p3_convnext_tiny",
|
||||
"backbone": "convnext_tiny"
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"extends": "p3_base.json",
|
||||
"run_name": "p3_efficientnet_b0",
|
||||
"backbone": "efficientnet_b0"
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"extends": "p3_base.json",
|
||||
"run_name": "p3_mobilenetv3_small",
|
||||
"backbone": "mobilenet_v3_small"
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"extends": "p3_base.json",
|
||||
"run_name": "p3_resnet34",
|
||||
"backbone": "resnet34"
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"extends": "p3_base.json",
|
||||
"run_name": "p3_resnet50",
|
||||
"backbone": "resnet50"
|
||||
}
|
||||
@@ -30,6 +30,6 @@ def load_checkpoint(model: nn.Module, path: Union[Path, str], device) -> nn.Modu
|
||||
|
||||
|
||||
# Importing the modules triggers their register() calls
|
||||
from src.models import simple_cnn, resnet, efficientnet # noqa: E402, F401
|
||||
from src.models import simple_cnn, resnet, efficientnet, mobilenet, convnext # noqa: E402, F401
|
||||
|
||||
__all__ = ["get_model", "load_checkpoint", "register"]
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
import torch.nn as nn
|
||||
from torchvision import models
|
||||
|
||||
from src.models import register
|
||||
|
||||
|
||||
# ConvNeXt's classification head is a Sequential ending in (LayerNorm2d, Flatten, Linear); [-1] targets the Linear
|
||||
def build(cfg: dict) -> nn.Module:
|
||||
backbone = cfg.get("backbone", "convnext_tiny")
|
||||
pretrained = cfg.get("pretrained", True)
|
||||
|
||||
if backbone == "convnext_tiny":
|
||||
weights = models.ConvNeXt_Tiny_Weights.DEFAULT if pretrained else None
|
||||
model = models.convnext_tiny(weights=weights)
|
||||
elif backbone == "convnext_small":
|
||||
weights = models.ConvNeXt_Small_Weights.DEFAULT if pretrained else None
|
||||
model = models.convnext_small(weights=weights)
|
||||
else:
|
||||
raise ValueError(f"Unsupported ConvNeXt backbone: {backbone!r}. Supported: convnext_tiny, convnext_small")
|
||||
|
||||
in_features = model.classifier[-1].in_features
|
||||
model.classifier[-1] = nn.Linear(in_features, 1)
|
||||
return model
|
||||
|
||||
|
||||
register("convnext_tiny", build)
|
||||
register("convnext_small", build)
|
||||
@@ -0,0 +1,27 @@
|
||||
import torch.nn as nn
|
||||
from torchvision import models
|
||||
|
||||
from src.models import register
|
||||
|
||||
|
||||
# MobileNetV3's classification head is a Sequential; [-1] targets the final Linear
|
||||
def build(cfg: dict) -> nn.Module:
|
||||
backbone = cfg.get("backbone", "mobilenet_v3_small")
|
||||
pretrained = cfg.get("pretrained", True)
|
||||
|
||||
if backbone == "mobilenet_v3_small":
|
||||
weights = models.MobileNet_V3_Small_Weights.DEFAULT if pretrained else None
|
||||
model = models.mobilenet_v3_small(weights=weights)
|
||||
elif backbone == "mobilenet_v3_large":
|
||||
weights = models.MobileNet_V3_Large_Weights.DEFAULT if pretrained else None
|
||||
model = models.mobilenet_v3_large(weights=weights)
|
||||
else:
|
||||
raise ValueError(f"Unsupported MobileNet backbone: {backbone!r}. Supported: mobilenet_v3_small, mobilenet_v3_large")
|
||||
|
||||
in_features = model.classifier[-1].in_features
|
||||
model.classifier[-1] = nn.Linear(in_features, 1)
|
||||
return model
|
||||
|
||||
|
||||
register("mobilenet_v3_small", build)
|
||||
register("mobilenet_v3_large", build)
|
||||
Reference in New Issue
Block a user