diff --git a/mlair/data_handler/iterator.py b/mlair/data_handler/iterator.py index 18466fd9ede1e666f52b49fa461585a7e38410dd..30c45417a64e949b0c0535a96a20c933641fdcbb 100644 --- a/mlair/data_handler/iterator.py +++ b/mlair/data_handler/iterator.py @@ -37,7 +37,7 @@ class DataCollection(Iterable): if collection is None: collection = [] assert isinstance(collection, list) - self._collection = collection + self._collection = collection.copy() self._mapping = {} self._set_mapping() self._name = name @@ -119,9 +119,10 @@ class KerasIterator(keras.utils.Sequence): def _get_batch(self, data_list: List[np.ndarray], b: int) -> List[np.ndarray]: """Get batch according to batch size from data list.""" - return list(map(lambda data: data[b * self.batch_size:(b+1) * self.batch_size, ...], data_list)) + return list(map(lambda data: data[b * self.batch_size:(b + 1) * self.batch_size, ...], data_list)) - def _permute_data(self, X, Y): + @staticmethod + def _permute_data(X, Y): p = np.random.permutation(len(X[0])) # equiv to .shape[0] X = list(map(lambda x: x[p], X)) Y = list(map(lambda x: x[p], Y)) @@ -184,35 +185,3 @@ class KerasIterator(keras.utils.Sequence): """Randomly shuffle indexes if enabled.""" if self.shuffle is True: np.random.shuffle(self.indexes) - - -class DummyData: # pragma: no cover - - def __init__(self, number_of_samples=np.random.randint(100, 150)): - self.number_of_samples = number_of_samples - - def get_X(self): - X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5)) # samples, window, variables - X2 = np.random.randint(21, 30, size=(self.number_of_samples, 10, 2)) # samples, window, variables - X3 = np.random.randint(-5, 0, size=(self.number_of_samples, 1, 2)) # samples, window, variables - return [X1, X2, X3] - - def get_Y(self): - Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5, 1)) # samples, window, variables - Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 5, 1)) # samples, window, variables - return [Y1, Y2] - - -if __name__ == "__main__": - - collection = [] - for _ in range(3): - collection.append(DummyData(50)) - - data_collection = DataCollection(collection=collection) - - path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata") - iterator = KerasIterator(data_collection, 25, path, shuffle=True) - - for data in data_collection: - print(data) \ No newline at end of file diff --git a/test/test_data_handler/test_iterator.py b/test/test_data_handler/test_iterator.py index 678f3d369d4b6424f94557d7d739fc65a995aacc..2bd33cc3aeea6bc631323e3d75d0011baacabad3 100644 --- a/test/test_data_handler/test_iterator.py +++ b/test/test_data_handler/test_iterator.py @@ -59,20 +59,50 @@ class TestDataCollection: assert data_collection["first_element"] == "first_element" assert data_collection[0] == "first_element" + def test_name(self): + data_collection = DataCollection(name="testcase") + assert data_collection._name == "testcase" + assert data_collection.name == "testcase" + + def test_set_mapping(self): + data_collection = object.__new__(DataCollection) + data_collection._collection = ["a", "b", "c"] + data_collection._mapping = {} + data_collection._set_mapping() + assert data_collection._mapping == {"a": 0, "b": 1, "c": 2} + + def test_getitem(self): + data_collection = DataCollection(["a", "b", "c"]) + assert data_collection["a"] == "a" + assert data_collection[1] == "b" + + def test_keys(self): + collection = ["a", "b", "c"] + data_collection = DataCollection(collection) + assert data_collection.keys() == collection + data_collection.add("another") + assert data_collection.keys() == collection + ["another"] + class DummyData: def __init__(self, number_of_samples=np.random.randint(100, 150)): + np.random.seed(45) self.number_of_samples = number_of_samples def get_X(self, upsampling=False, as_numpy=True): + np.random.seed(45) X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5)) # samples, window, variables + np.random.seed(45) X2 = np.random.randint(21, 30, size=(self.number_of_samples, 10, 2)) # samples, window, variables + np.random.seed(45) X3 = np.random.randint(-5, 0, size=(self.number_of_samples, 1, 2)) # samples, window, variables return [X1, X2, X3] def get_Y(self, upsampling=False, as_numpy=True): + np.random.seed(45) Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5, 1)) # samples, window, variables + np.random.seed(45) Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 5, 1)) # samples, window, variables return [Y1, Y2] @@ -87,6 +117,14 @@ class TestKerasIterator: data_coll = DataCollection(collection=coll) return data_coll + @pytest.fixture + def collection_small(self): + coll = [] + for i in range(3): + coll.append(DummyData(5 + i)) + data_coll = DataCollection(collection=coll) + return data_coll + @pytest.fixture def path(self): p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata") @@ -168,6 +206,27 @@ class TestKerasIterator: assert len(iterator) == 4 assert iterator.indexes == [0, 1, 2, 3] + def test_prepare_batches_upsampling(self, collection_small, path): + iterator = object.__new__(KerasIterator) + iterator._collection = collection_small + iterator.batch_size = 100 + iterator.indexes = [] + iterator.model = None + iterator.upsampling = False + iterator._path = os.path.join(path, "%i.pickle") + os.makedirs(path) + iterator._prepare_batches() + X1, Y1 = iterator[0] + iterator.upsampling = True + iterator._prepare_batches() + X1p, Y1p = iterator[0] + assert X1[0].shape == X1p[0].shape + assert Y1[0].shape == Y1p[0].shape + assert np.testing.assert_almost_equal(X1[0].sum(), X1p[0].sum(), 2) is None + assert np.testing.assert_almost_equal(Y1[0].sum(), Y1p[0].sum(), 2) is None + f = np.testing.assert_array_almost_equal + assert np.testing.assert_raises(AssertionError, f, X1[0], X1p[0]) is None + def test_prepare_batches_no_remaining(self, path): iterator = object.__new__(KerasIterator) iterator._collection = DataCollection([DummyData(50)]) @@ -233,3 +292,15 @@ class TestKerasIterator: iterator.model = mock.MagicMock(return_value=1) with pytest.raises(TypeError): iterator._get_model_rank() + + def test_permute(self): + iterator = object.__new__(KerasIterator) + X = [np.array([[1, 2, 3, 4], + [1.1, 2.1, 3.1, 4.1], + [1.2, 2.2, 3.2, 4.2]], dtype="f2")] + Y = [np.array([1, 2, 3])] + X_p, Y_p = iterator._permute_data(X, Y) + assert X_p[0].shape == X[0].shape + assert Y_p[0].shape == Y[0].shape + assert np.testing.assert_almost_equal(X_p[0].sum(), X[0].sum(), 2) is None + assert np.testing.assert_almost_equal(Y_p[0].sum(), Y[0].sum(), 2) is None