Files
DRL_PROJ/classifier/src/models/convnext.py
T
Johnny Fernandes b1e0e61431 Phase 3 classifier
2026-05-05 00:36:37 +01:00

28 lines
995 B
Python

import torch.nn as nn
from torchvision import models
from src.models import register
def build(cfg: dict) -> nn.Module:
    """Construct a torchvision ConvNeXt whose final Linear layer has one output.

    Args:
        cfg: Model configuration. Two keys are read:
            ``"backbone"``: which variant to build, either ``"convnext_tiny"``
            (the default when the key is absent) or ``"convnext_small"``.
            ``"pretrained"``: when truthy (the default is ``True``), the
            variant's torchvision ``DEFAULT`` weights are loaded. When falsy,
            ``weights=None`` is passed instead.

    Returns:
        The ConvNeXt model with ``classifier[-1]`` swapped for
        ``nn.Linear(in_features, 1)``. The single output is presumably a
        binary logit; confirm against the training loss.

    Raises:
        ValueError: If ``backbone`` names a variant other than the two above.
    """
    variant = cfg.get("backbone", "convnext_tiny")
    use_pretrained = cfg.get("pretrained", True)

    # Reject unknown variants up front, before anything from torchvision is touched.
    if variant not in ("convnext_tiny", "convnext_small"):
        raise ValueError(
            f"Unsupported ConvNeXt backbone: {variant!r}. Supported: convnext_tiny, convnext_small"
        )

    if variant == "convnext_small":
        make_net, weights_enum = models.convnext_small, models.ConvNeXt_Small_Weights
    else:  # the guard above leaves "convnext_tiny" as the only other option
        make_net, weights_enum = models.convnext_tiny, models.ConvNeXt_Tiny_Weights

    net = make_net(weights=weights_enum.DEFAULT if use_pretrained else None)

    # ConvNeXt's classification head is a Sequential ending in
    # (LayerNorm2d, Flatten, Linear). Index -1 is that Linear, so only the
    # output projection is replaced and its input width is kept.
    head = net.classifier
    head[-1] = nn.Linear(head[-1].in_features, 1)
    return net
def _build_with_default_backbone(default_backbone: str):
    """Return a builder that calls ``build`` with ``default_backbone`` as the fallback.

    ``build`` falls back to ``"convnext_tiny"`` when ``cfg`` has no
    ``"backbone"`` key. Registering the bare ``build`` under every name
    therefore meant the ``"convnext_small"`` entry built ConvNeXt-Tiny unless
    the config also spelled out the backbone. An explicit ``cfg["backbone"]``
    still wins over the default supplied here.
    """

    def _builder(cfg: dict) -> nn.Module:
        # cfg's own "backbone" (if any) overrides the registered default.
        # build() receives a shallow copy, so the caller's cfg is left untouched.
        return build({"backbone": default_backbone, **cfg})

    return _builder


register("convnext_tiny", _build_with_default_backbone("convnext_tiny"))
register("convnext_small", _build_with_default_backbone("convnext_small"))