import torch.nn as nn
from torchvision import models

from src.models import register

# Supported backbones, keyed by config name. Each entry takes `pretrained` and
# returns a torchvision ConvNeXt. The weight enums are referenced inside the
# lambdas, so they are only looked up when a model is actually built.
_BACKBONES = {
    "convnext_tiny": lambda pretrained: models.convnext_tiny(
        weights=models.ConvNeXt_Tiny_Weights.DEFAULT if pretrained else None
    ),
    "convnext_small": lambda pretrained: models.convnext_small(
        weights=models.ConvNeXt_Small_Weights.DEFAULT if pretrained else None
    ),
}


def build(cfg: dict, *, default_backbone: str = "convnext_tiny") -> nn.Module:
    """Build a torchvision ConvNeXt whose classifier emits a single output.

    Args:
        cfg: Model config. Reads ``"backbone"`` (falls back to
            *default_backbone*) and ``"pretrained"`` (default ``True``, which
            loads the torchvision ``DEFAULT`` weights; falsy builds an
            uninitialised-from-checkpoint model).
        default_backbone: Backbone name used when ``cfg`` has no
            ``"backbone"`` key.

    Returns:
        The ConvNeXt model with its final ``nn.Linear`` replaced by a freshly
        initialised ``nn.Linear(in_features, 1)``.

    Raises:
        ValueError: If the backbone is not one of the supported names.
    """
    backbone = cfg.get("backbone", default_backbone)
    pretrained = cfg.get("pretrained", True)

    # isinstance guard keeps unhashable/non-str values on the ValueError path
    # instead of surfacing a TypeError from the dict lookup.
    factory = _BACKBONES.get(backbone) if isinstance(backbone, str) else None
    if factory is None:
        raise ValueError(
            f"Unsupported ConvNeXt backbone: {backbone!r}. "
            f"Supported: {', '.join(_BACKBONES)}"
        )
    model = factory(pretrained)

    # ConvNeXt's classification head is a Sequential ending in
    # (LayerNorm2d, Flatten, Linear); [-1] targets the Linear.
    in_features = model.classifier[-1].in_features
    model.classifier[-1] = nn.Linear(in_features, 1)
    return model


def _build_convnext_small(cfg: dict) -> nn.Module:
    """Registry entry for "convnext_small": defaults the backbone to convnext_small."""
    return build(cfg, default_backbone="convnext_small")


# Each registered name defaults to its own backbone, so picking "convnext_small"
# without an explicit cfg["backbone"] no longer builds convnext_tiny.
# NOTE(review): assumes the registry invokes the builder as fn(cfg) — confirm.
register("convnext_tiny", build)
register("convnext_small", _build_convnext_small)