"""
Tests for binary_metrics edge cases: single-class inputs return None for AUC-ROC and F1.
"""
import unittest

import torch

from src.evaluation.metrics import binary_metrics
class OneClassMetricTests(unittest.TestCase):
    """Edge cases for ``binary_metrics`` where the labels contain only one class."""

    def test_one_class_returns_none_for_auc_and_f1(self):
        """With every label equal to 1.0, ``auc_roc`` and ``f1`` are expected
        to be ``None``, while an ``accuracy`` entry is still present."""
        scores = torch.tensor([0.1, -0.2, 0.3], dtype=torch.float32)
        # Three labels, all 1.0: only a single class appears in the targets.
        targets = torch.ones(3, dtype=torch.float32)

        result = binary_metrics(scores, targets)

        # Checked in the same order as before: auc_roc first, then f1.
        for missing_key in ("auc_roc", "f1"):
            self.assertIsNone(result[missing_key])
        self.assertIn("accuracy", result)
if __name__ == "__main__":
    # Only when run as a script (not on import): discover and run this module's tests.
    unittest.main(module="__main__")