28 lines
995 B
Python
28 lines
995 B
Python
import torch.nn as nn
|
|
from torchvision import models
|
|
|
|
from src.models import register
|
|
|
|
|
|
def build(cfg: dict) -> nn.Module:
    """Construct a torchvision ConvNeXt whose final layer emits one value.

    Args:
        cfg: Model configuration. Two keys are read:
            ``"backbone"`` -- ``"convnext_tiny"`` or ``"convnext_small"``
            (defaults to ``"convnext_tiny"``);
            ``"pretrained"`` -- when truthy (the default), the variant's
            ``DEFAULT`` torchvision weights are requested, otherwise
            ``weights=None`` is passed to the constructor.

    Returns:
        The ConvNeXt model with the last layer of ``model.classifier``
        replaced by a new ``nn.Linear(in_features, 1)``.

    Raises:
        ValueError: If the requested backbone is not one of the two
            supported variants.
    """
    name = cfg.get("backbone", "convnext_tiny")
    use_pretrained = cfg.get("pretrained", True)

    # Reject unknown variants up front so nothing is constructed for them.
    is_tiny = name == "convnext_tiny"
    if not is_tiny and name != "convnext_small":
        raise ValueError(f"Unsupported ConvNeXt backbone: {name!r}. Supported: convnext_tiny, convnext_small")

    # Only touch the weights enum when pretrained weights are requested.
    weights = None
    if use_pretrained:
        weights_enum = models.ConvNeXt_Tiny_Weights if is_tiny else models.ConvNeXt_Small_Weights
        weights = weights_enum.DEFAULT

    constructor = models.convnext_tiny if is_tiny else models.convnext_small
    net = constructor(weights=weights)

    # ConvNeXt's classification head is a Sequential ending in
    # (LayerNorm2d, Flatten, Linear); [-1] targets the Linear, which is
    # swapped for a single-output layer of the same input width.
    head = net.classifier
    head[-1] = nn.Linear(head[-1].in_features, 1)
    return net
|
|
|
|
|
|
# Register the same builder under each supported variant name.
# NOTE(review): build() does not receive the registered name; it selects the
# architecture from cfg["backbone"] (default "convnext_tiny") -- confirm that
# callers looking up "convnext_small" also set that key.
for _variant in ("convnext_tiny", "convnext_small"):
    register(_variant, build)
|