Phase 3 classifier
This commit is contained in:
@@ -0,0 +1,27 @@
|
||||
import torch.nn as nn
|
||||
from torchvision import models
|
||||
|
||||
from src.models import register
|
||||
|
||||
|
||||
# ConvNeXt's classification head is a Sequential ending in (LayerNorm2d, Flatten, Linear); [-1] targets the Linear
|
||||
def build(cfg: dict) -> nn.Module:
|
||||
backbone = cfg.get("backbone", "convnext_tiny")
|
||||
pretrained = cfg.get("pretrained", True)
|
||||
|
||||
if backbone == "convnext_tiny":
|
||||
weights = models.ConvNeXt_Tiny_Weights.DEFAULT if pretrained else None
|
||||
model = models.convnext_tiny(weights=weights)
|
||||
elif backbone == "convnext_small":
|
||||
weights = models.ConvNeXt_Small_Weights.DEFAULT if pretrained else None
|
||||
model = models.convnext_small(weights=weights)
|
||||
else:
|
||||
raise ValueError(f"Unsupported ConvNeXt backbone: {backbone!r}. Supported: convnext_tiny, convnext_small")
|
||||
|
||||
in_features = model.classifier[-1].in_features
|
||||
model.classifier[-1] = nn.Linear(in_features, 1)
|
||||
return model
|
||||
|
||||
|
||||
register("convnext_tiny", build)
|
||||
register("convnext_small", build)
|
||||
Reference in New Issue
Block a user