Skip to content
Snippets Groups Projects
Select Git revision
  • Issue_61
  • develop default
  • 102-Utest
  • 96-method_i
  • 89-str_to_array
  • 91-output
  • master
  • 87-method-d-swallowed-obstacles
  • feature_trajectory_correction
  • Issue_63
  • refactor_SteadyState
  • v0.8.3
  • v0.8.2
  • v0.8.1
  • v0.8
  • v0.7
  • v0.6
17 results

PedData.h

Blame
  • test_iterator.py 11.24 KiB
    from mlair.data_handler.iterator import DataCollection, StandardIterator, KerasIterator
    from mlair.helpers.testing import PyTestAllEqual
    from mlair.model_modules.model_class import MyBranchedModel
    from mlair.model_modules.fully_connected_networks import FCN_64_32_16
    
    import numpy as np
    import pytest
    import mock
    import os
    import shutil
    
    
    class TestStandardIterator:
        """Unit tests for ``StandardIterator`` driven over a plain list of ints."""

        @pytest.fixture
        def collection(self):
            """Ten consecutive integers, 0..9."""
            return [n for n in range(10)]

        def test_blank(self):
            # Skip __init__ entirely: _position must nevertheless resolve to None
            # (presumably a class-level default on StandardIterator).
            bare = object.__new__(StandardIterator)
            assert bare._position is None

        def test_init(self, collection):
            it = StandardIterator(collection)
            assert it._collection == list(range(10))
            assert it._position == 0

        def test_next(self, collection):
            it = StandardIterator(collection)
            # Manual stepping yields every element in order ...
            for expected in range(10):
                assert next(it) == expected
            # ... and the iterator is exhausted once the collection is consumed.
            with pytest.raises(StopIteration):
                next(it)
            # A fresh iterator must also work with iter() inside a for-loop.
            fresh = StandardIterator(collection)
            for position, value in enumerate(iter(fresh)):
                assert value == position
    
    
    class TestDataCollection:
        """Unit tests for ``DataCollection``: construction, iteration, lookup and keys."""

        @pytest.fixture
        def collection(self):
            """Ten consecutive integers, 0..9."""
            return list(range(10))

        def test_init(self, collection):
            dc = DataCollection(collection)
            assert dc._collection == collection

        def test_iter(self, collection):
            dc = DataCollection(collection)
            # iter() must hand back the project's StandardIterator ...
            assert isinstance(iter(dc), StandardIterator)
            # ... and iteration yields the elements in their original order.
            for idx, item in enumerate(dc):
                assert item == idx

        def test_add(self):
            dc = DataCollection()
            element = "first_element"
            dc.add(element)
            assert len(dc) == 1
            # The added element is reachable both by key and by position.
            assert dc[element] == element
            assert dc[0] == element

        def test_name(self):
            dc = DataCollection(name="testcase")
            assert dc._name == "testcase"
            assert dc.name == "testcase"

        def test_set_mapping(self):
            # Build a bare instance so that only _set_mapping populates _mapping.
            dc = object.__new__(DataCollection)
            dc._collection = ["a", "b", "c"]
            dc._mapping = {}
            dc._set_mapping()
            assert dc._mapping == {"a": 0, "b": 1, "c": 2}

        def test_getitem(self):
            dc = DataCollection(["a", "b", "c"])
            assert dc["a"] == "a"
            assert dc[1] == "b"

        def test_keys(self):
            initial = ["a", "b", "c"]
            dc = DataCollection(initial)
            assert dc.keys() == initial
            dc.add("another")
            assert dc.keys() == initial + ["another"]
    
    
    class DummyData:
        """Deterministic stand-in for a data handler exposing ``get_X`` / ``get_Y``.

        Every array is drawn right after re-seeding NumPy's *global* RNG with 45,
        so repeated calls return identical arrays. Note that this resets the global
        RNG state as a side effect of every call (and of construction).
        """

        _SEED = 45

        def __init__(self, number_of_samples=None):
            """Store how many samples each generated array should contain.

            :param number_of_samples: size of the leading (sample) axis. When
                omitted, a value in [100, 150) is drawn *after* seeding, so the
                default is reproducible. (Previously the default was evaluated once
                at import time from an unseeded RNG.)
            """
            np.random.seed(self._SEED)
            if number_of_samples is None:
                number_of_samples = np.random.randint(100, 150)
            self.number_of_samples = number_of_samples

        def _draw(self, low, high, window, variables):
            """Re-seed the global RNG and draw ints in [low, high) of shape (samples, window, variables)."""
            np.random.seed(self._SEED)
            return np.random.randint(low, high, size=(self.number_of_samples, window, variables))

        def get_X(self, upsampling=False, as_numpy=True):
            """Return three input arrays shaped (samples, window, variables).

            ``upsampling`` and ``as_numpy`` are accepted only for interface
            compatibility with real data handlers and are ignored here.
            """
            return [self._draw(0, 10, 14, 5),
                    self._draw(21, 30, 10, 2),
                    self._draw(-5, 0, 1, 2)]

        def get_Y(self, upsampling=False, as_numpy=True):
            """Return two target arrays shaped (samples, 5, 1).

            ``upsampling`` and ``as_numpy`` are accepted only for interface
            compatibility and are ignored here.
            """
            return [self._draw(0, 10, 5, 1),
                    self._draw(21, 30, 5, 1)]
    
    
    class TestKerasIterator:
        """Unit tests for ``KerasIterator``: batching, pickling, shuffling and model rank."""

        @pytest.fixture
        def collection(self):
            """Three DummyData handlers with 50, 51 and 52 samples (153 in total)."""
            return DataCollection(collection=[DummyData(50 + i) for i in range(3)])

        @pytest.fixture
        def collection_small(self):
            """Three DummyData handlers with 5, 6 and 7 samples (18 in total)."""
            return DataCollection(collection=[DummyData(5 + i) for i in range(3)])

        @pytest.fixture
        def path(self, tmp_path):
            """Yield a not-yet-existing scratch directory path and remove it afterwards.

            The directory lives under pytest's per-test ``tmp_path`` rather than next
            to this source file, so runs never write into the source tree and
            concurrent runs cannot collide on a shared ``testdata`` folder.
            """
            p = str(tmp_path / "testdata")
            yield p
            shutil.rmtree(p, ignore_errors=True)

        @staticmethod
        def _bare_iterator(collection, batch_size, path):
            """Build a KerasIterator without __init__, pre-set for _prepare_batches tests."""
            iterator = object.__new__(KerasIterator)
            iterator._collection = collection
            iterator.batch_size = batch_size
            iterator.indexes = []
            iterator.model = None
            iterator.upsampling = False
            iterator._path = os.path.join(path, "%i.pickle")
            return iterator

        def test_init(self, collection, path):
            iterator = KerasIterator(collection, 25, path)
            assert isinstance(iterator._collection, DataCollection)
            # Pickles go to a per-instance subdirectory named after id(iterator).
            assert iterator._path == os.path.join(path, str(id(iterator)), "%i.pickle")
            assert iterator.batch_size == 25
            assert iterator.shuffle is False

        def test_cleanup_path(self, path):
            assert os.path.exists(path) is False
            iterator = object.__new__(KerasIterator)
            # create_new=False only removes; the default also (re)creates the directory.
            iterator._cleanup_path(path, create_new=False)
            assert os.path.exists(path) is False
            iterator._cleanup_path(path)
            assert os.path.exists(path) is True
            iterator._cleanup_path(path, create_new=False)
            assert os.path.exists(path) is False

        def test_get_number_of_mini_batches(self):
            iterator = object.__new__(KerasIterator)
            iterator.batch_size = 36
            # Only complete batches are counted.
            assert iterator._get_number_of_mini_batches(30) == 0
            assert iterator._get_number_of_mini_batches(40) == 1
            assert iterator._get_number_of_mini_batches(72) == 2

        def test_len(self):
            iterator = object.__new__(KerasIterator)
            iterator.indexes = [0, 1, 2, 3, 4, 5]
            assert len(iterator) == 6

        def test_concatenate(self):
            arr1 = DummyData(10).get_X()
            arr2 = DummyData(50).get_X()
            iterator = object.__new__(KerasIterator)
            # _concatenate(new, old) is expected to append `new` after `old`.
            new_arr = iterator._concatenate(arr2, arr1)
            test_arr = [np.concatenate((arr1[0], arr2[0]), axis=0),
                        np.concatenate((arr1[1], arr2[1]), axis=0),
                        np.concatenate((arr1[2], arr2[2]), axis=0)]
            for i in range(3):
                assert PyTestAllEqual([new_arr[i], test_arr[i]])

        def test_get_batch(self):
            arr = DummyData(20).get_X()
            iterator = object.__new__(KerasIterator)
            iterator.batch_size = 19
            batch1 = iterator._get_batch(arr, 0)
            assert batch1[0].shape[0] == 19
            # The second batch holds the single remaining sample.
            batch2 = iterator._get_batch(arr, 1)
            assert batch2[0].shape[0] == 1

        def test_save_to_pickle(self, path):
            os.makedirs(path)
            d = DummyData(20)
            X, Y = d.get_X(), d.get_Y()
            iterator = object.__new__(KerasIterator)
            iterator._path = os.path.join(path, "%i.pickle")
            assert os.path.exists(iterator._path % 2) is False
            iterator._save_to_pickle(X=X, Y=Y, index=2)
            assert os.path.exists(iterator._path % 2) is True

        def test_prepare_batches(self, collection, path):
            iterator = self._bare_iterator(collection, 50, path)
            os.makedirs(path)
            iterator._prepare_batches()
            # 153 samples at batch size 50 are expected to yield 4 pickled batches.
            assert len(os.listdir(path)) == 4
            assert len(iterator.indexes) == 4
            assert len(iterator) == 4
            assert iterator.indexes == [0, 1, 2, 3]

        def test_prepare_batches_upsampling(self, collection_small, path):
            iterator = self._bare_iterator(collection_small, 100, path)
            os.makedirs(path)
            iterator._prepare_batches()
            X1, Y1 = iterator[0]
            iterator.upsampling = True
            iterator._prepare_batches()
            X1p, Y1p = iterator[0]
            # DummyData ignores the upsampling flag, so both runs see the same samples:
            # shapes and sums must match while the element order is expected to differ.
            assert X1[0].shape == X1p[0].shape
            assert Y1[0].shape == Y1p[0].shape
            np.testing.assert_almost_equal(X1[0].sum(), X1p[0].sum(), 2)
            np.testing.assert_almost_equal(Y1[0].sum(), Y1p[0].sum(), 2)
            with pytest.raises(AssertionError):
                np.testing.assert_array_almost_equal(X1[0], X1p[0])

        def test_prepare_batches_no_remaining(self, path):
            iterator = self._bare_iterator(DataCollection([DummyData(50)]), 50, path)
            os.makedirs(path)
            iterator._prepare_batches()
            # Exactly one full batch and no remainder batch.
            assert len(os.listdir(path)) == 1
            assert len(iterator.indexes) == 1
            assert len(iterator) == 1
            assert iterator.indexes == [0]

        def test_data_generation(self, collection, path):
            iterator = KerasIterator(collection, 50, path)
            X, Y = iterator._KerasIterator__data_generation(0)
            expected = next(iter(collection))
            assert PyTestAllEqual([X, expected.get_X()])
            assert PyTestAllEqual([Y, expected.get_Y()])

        def test_getitem(self, collection, path):
            iterator = KerasIterator(collection, 50, path)
            X, Y = iterator[0]
            expected = next(iter(collection))
            assert PyTestAllEqual([X, expected.get_X()])
            assert PyTestAllEqual([Y, expected.get_Y()])

        def test_on_epoch_end(self):
            iterator = object.__new__(KerasIterator)
            iterator.indexes = [0, 1, 2, 3, 4]
            iterator.shuffle = False
            iterator.on_epoch_end()
            assert iterator.indexes == [0, 1, 2, 3, 4]
            iterator.shuffle = True
            # A shuffle may return the identity permutation by chance (p = 1/120), so
            # retry a bounded number of times instead of looping forever if shuffling
            # is broken.
            for _ in range(100):
                iterator.on_epoch_end()
                if iterator.indexes != [0, 1, 2, 3, 4]:
                    break
            assert iterator.indexes != [0, 1, 2, 3, 4]
            assert sorted(iterator.indexes) == [0, 1, 2, 3, 4]

        def test_get_model_rank_no_model(self):
            iterator = object.__new__(KerasIterator)
            iterator.model = None
            assert iterator._get_model_rank() == 1

        def test_get_model_rank_single_output_branch(self):
            iterator = object.__new__(KerasIterator)
            iterator.model = FCN_64_32_16(input_shape=[(14, 1, 2)], output_shape=[(3,)])
            assert iterator._get_model_rank() == 1

        def test_get_model_rank_multiple_output_branch(self):
            iterator = object.__new__(KerasIterator)
            iterator.model = MyBranchedModel(input_shape=[(14, 1, 2)], output_shape=[(3,)])
            assert iterator._get_model_rank() == 3

        def test_get_model_rank_error(self):
            iterator = object.__new__(KerasIterator)
            iterator.model = mock.MagicMock(return_value=1)
            with pytest.raises(TypeError):
                iterator._get_model_rank()

        def test_permute(self):
            iterator = object.__new__(KerasIterator)
            X = [np.array([[1, 2, 3, 4],
                           [1.1, 2.1, 3.1, 4.1],
                           [1.2, 2.2, 3.2, 4.2]], dtype="f2")]
            Y = [np.array([1, 2, 3])]
            X_p, Y_p = iterator._permute_data(X, Y)
            # A permutation keeps shapes and (up to float16 rounding) the sums.
            assert X_p[0].shape == X[0].shape
            assert Y_p[0].shape == Y[0].shape
            np.testing.assert_almost_equal(X_p[0].sum(), X[0].sum(), 2)
            np.testing.assert_almost_equal(Y_p[0].sum(), Y[0].sum(), 2)