26 lines
706 B
Python
26 lines
706 B
Python
"""
Tests for preprocessing transforms: eval pipeline is deterministic and test-safe.
"""
|
|
import unittest
|
|
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
from src.preprocessing.pipeline import get_transforms
|
|
|
|
|
|
class TransformTests(unittest.TestCase):
    """Checks on the transform returned by ``get_transforms(train=False, ...)``."""

    def test_eval_transform_is_deterministic(self):
        """Apply the eval transform twice to one image and compare the outputs.

        A seeded 128x128x3 uint8 array is wrapped as an RGB image and passed
        twice through the transform built with ``train=False`` and
        ``image_size=64``. The first output must have shape ``(3, 64, 64)``
        and the two outputs must agree within tolerance.
        """
        # Fixed seed so the input image is identical on every run.
        rng = np.random.RandomState(0)
        arr = (rng.rand(128, 128, 3) * 255).astype(np.uint8)
        img = Image.fromarray(arr, mode="RGB")

        tfm = get_transforms(train=False, image_size=64)
        a = tfm(img)
        b = tfm(img)

        self.assertEqual(tuple(a.shape), (3, 64, 64))
        # assert_allclose reports which elements differ and by how much;
        # assertTrue(np.allclose(...)) would only say "False is not true".
        # Tolerances are spelled out to match np.allclose's defaults.
        np.testing.assert_allclose(
            a.numpy(),
            b.numpy(),
            rtol=1e-05,
            atol=1e-08,
            equal_nan=False,
            err_msg="eval transform produced different outputs for the same image",
        )
|
|
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|