Clean state
This commit is contained in:
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
Tests for preprocessing transforms: eval pipeline is deterministic and test-safe.
|
||||
"""
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from src.preprocessing.pipeline import get_transforms
|
||||
|
||||
|
||||
class TransformTests(unittest.TestCase):
|
||||
def test_eval_transform_is_deterministic(self):
|
||||
rng = np.random.RandomState(0)
|
||||
arr = (rng.rand(128, 128, 3) * 255).astype(np.uint8)
|
||||
img = Image.fromarray(arr, mode="RGB")
|
||||
tfm = get_transforms(train=False, image_size=64)
|
||||
a = tfm(img)
|
||||
b = tfm(img)
|
||||
self.assertEqual(tuple(a.shape), (3, 64, 64))
|
||||
self.assertTrue(np.allclose(a.numpy(), b.numpy()))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user