Phase 3 classifier

This commit is contained in:
Johnny Fernandes
2026-05-05 00:36:37 +01:00
parent 799ec0c13a
commit b1e0e61431
10 changed files with 111 additions and 29 deletions
+27
View File
@@ -0,0 +1,27 @@
import torch.nn as nn
from torchvision import models
from src.models import register
# ConvNeXt's classification head is a Sequential ending in (LayerNorm2d, Flatten, Linear); [-1] targets the Linear
def build(cfg: dict) -> nn.Module:
backbone = cfg.get("backbone", "convnext_tiny")
pretrained = cfg.get("pretrained", True)
if backbone == "convnext_tiny":
weights = models.ConvNeXt_Tiny_Weights.DEFAULT if pretrained else None
model = models.convnext_tiny(weights=weights)
elif backbone == "convnext_small":
weights = models.ConvNeXt_Small_Weights.DEFAULT if pretrained else None
model = models.convnext_small(weights=weights)
else:
raise ValueError(f"Unsupported ConvNeXt backbone: {backbone!r}. Supported: convnext_tiny, convnext_small")
in_features = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(in_features, 1)
return model
register("convnext_tiny", build)
register("convnext_small", build)