diff --git a/mlair/data_handler/abstract_data_handler.py b/mlair/data_handler/abstract_data_handler.py
index 9ea163fcad2890580e9c44e4bda0627d6419dc9f..a82e5005e8b30f9e3978ae61859e6b80746d95f1 100644
--- a/mlair/data_handler/abstract_data_handler.py
+++ b/mlair/data_handler/abstract_data_handler.py
@@ -22,6 +22,9 @@ class AbstractDataHandler(object):
         """Return initialised class."""
         return cls(*args, **kwargs)
 
+    def __len__(self, upsampling=False):
+        raise NotImplementedError
+
     @classmethod
     def requirements(cls, skip_args=None):
         """Return requirements and own arguments without duplicates."""
diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index 06fd5772db2e18ff70792e8126da83ccfac46f82..69c9537b10ca583adf84480636680a99ab265a67 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -55,6 +55,8 @@ class DefaultDataHandler(AbstractDataHandler):
         self._X_extreme = None
         self._Y_extreme = None
         self._data_intersection = None
+        self._len = None
+        self._len_upsampling = None
         self._use_multiprocessing = use_multiprocessing
         self._max_number_multiprocessing = max_number_multiprocessing
         _name_affix = str(f"{str(self.id_class)}_{name_affix}" if name_affix is not None else id(self))
@@ -134,9 +136,11 @@ class DefaultDataHandler(AbstractDataHandler):
     def __repr__(self):
         return str(self._collection[0])
 
-    def __len__(self):
-        if self._data_intersection is not None:
-            return len(self._data_intersection)
+    def __len__(self, upsampling=False):
+        if upsampling is False:
+            return self._len
+        else:
+            return self._len_upsampling
 
     def get_X_original(self):
         X = []
@@ -178,6 +182,7 @@ class DefaultDataHandler(AbstractDataHandler):
             Y = Y_original.sel({dim: intersect})
         self._data_intersection = intersect
         self._X, self._Y = X, Y
+        self._len = len(self._data_intersection)
 
     def get_observation(self):
         dim = self.time_dim
@@ -212,6 +217,7 @@ class DefaultDataHandler(AbstractDataHandler):
         if extreme_values is None:
             logging.debug(f"No extreme values given, skip multiply extremes")
             self._X_extreme, self._Y_extreme = self._X, self._Y
+            self._len_upsampling = self._len
             return
 
         # check type if inputs
@@ -247,6 +253,7 @@ class DefaultDataHandler(AbstractDataHandler):
 
             self._Y_extreme = xr.concat([Y, extremes_Y], dim=dim)
             self._X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=dim), X, extremes_X))
+            self._len_upsampling = len(self._X_extreme[0].coords[dim])
 
     @staticmethod
     def _add_timedelta(data, dim, timedelta):
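
Note on the extended length protocol introduced above: Python's built-in len() cannot forward
extra arguments to __len__, which is why callers that need the upsampled sample count must
invoke the dunder directly, as the iterator change below does with data.__len__(self.upsampling).
A minimal, self-contained sketch of the idea; the class name and sample counts here are
illustrative, not taken from mlair:

    class LengthCachingHandler:
        """Toy stand-in for DefaultDataHandler's cached-length bookkeeping."""

        def __init__(self):
            self._len = None             # plain sample count, set once data is intersected
            self._len_upsampling = None  # sample count after extreme values are duplicated

        def __len__(self, upsampling=False):
            # len(obj) always passes the default upsampling=False; the upsampled
            # size is only reachable via an explicit obj.__len__(True) call
            return self._len_upsampling if upsampling else self._len

    handler = LengthCachingHandler()
    handler._len, handler._len_upsampling = 100, 130
    assert len(handler) == 100           # built-in len() uses the default flag
    assert handler.__len__(True) == 130  # explicit dunder call for the upsampled length
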
diff --git a/mlair/data_handler/iterator.py b/mlair/data_handler/iterator.py
index b838626221b18c3f3b55ba15e43ae152d83ce5c6..cedb06409a99bf66f4a3a1e695de6059bcd9e143 100644
--- a/mlair/data_handler/iterator.py
+++ b/mlair/data_handler/iterator.py
@@ -86,7 +86,7 @@ class KerasIterator(keras.utils.Sequence):
         self.upsampling = upsampling
         self.indexes: list = []
         self._cleanup_path(batch_path)
-        self._prepare_batches_parallel(use_multiprocessing, max_number_multiprocessing)
+        self._prepare_batches(use_multiprocessing, max_number_multiprocessing)
 
     def __len__(self) -> int:
         return len(self.indexes)
@@ -160,7 +160,7 @@ class KerasIterator(keras.utils.Sequence):
             index += 1
         self.indexes = np.arange(0, index).tolist()
 
-    def _prepare_batches_parallel(self, use_multiprocessing=False, max_process=1) -> None:
+    def _prepare_batches(self, use_multiprocessing=False, max_process=1) -> None:
         """
         Prepare all batches as locally stored files.
 
@@ -183,7 +183,7 @@ class KerasIterator(keras.utils.Sequence):
         pool = None
         output = None
         for data in self._collection:
-            length = len(data)
+            length = data.__len__(self.upsampling)
             batches = _get_number_of_mini_batches(length, self.batch_size)
             if pool is None:
                 res = f_proc(data, self.upsampling, mod_rank, self.batch_size, self._path, index)
@@ -209,7 +209,6 @@ class KerasIterator(keras.utils.Sequence):
                 _save_to_pickle(self._path, X=remaining[0], Y=remaining[1], index=index)
                 index += 1
         self.indexes = np.arange(0, index).tolist()
-        logging.warning(f"hightst index is {index}")
         if pool is not None:
             pool.join()
 
diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py
index d23d68a823c03a82f88591df5c66e762889f8c93..0d7bb98f109b612cf3cffc3dc31541bb1733c541 100644
--- a/mlair/run_modules/training.py
+++ b/mlair/run_modules/training.py
@@ -102,6 +102,7 @@ class Training(RunEnvironment):
         """
         self.model.make_predict_function()
 
+    @TimeTrackingWrapper
     def _set_gen(self, mode: str) -> None:
         """
         Set and distribute the generators for given mode regarding batch size.
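
Note on the helpers referenced above: _get_number_of_mini_batches, _get_batch, _permute_data,
_save_to_pickle and f_proc are now module-level functions in mlair/data_handler/iterator.py
(they were KerasIterator methods before), which among other things lets f_proc be dispatched
to a multiprocessing pool. The test changes below pin down their call signatures. A hedged
sketch of the two batching helpers, consistent with the assertions in the tests but not
necessarily identical to the real implementations:

    def _get_number_of_mini_batches(length, batch_size):
        # number of *full* mini batches in length samples; a shorter trailing
        # batch is handled separately (30 // 36 == 0, 40 // 36 == 1, 72 // 36 == 2)
        return length // batch_size

    def _get_batch(data, b, batch_size):
        # slice mini batch b out of every input array in the list; the last
        # slice of an array may contain fewer than batch_size samples
        return list(map(lambda x: x[b * batch_size:(b + 1) * batch_size, ...], data))
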
diff --git a/test/test_data_handler/test_iterator.py b/test/test_data_handler/test_iterator.py
index bb8ecb5d216519b3662a5baa4d463780b4c29d8c..fe740c094b41df7da45bd0f76d678830f95e1902 100644
--- a/test/test_data_handler/test_iterator.py
+++ b/test/test_data_handler/test_iterator.py
@@ -1,4 +1,5 @@
 from mlair.data_handler.iterator import DataCollection, StandardIterator, KerasIterator
+from mlair.data_handler.iterator import _get_number_of_mini_batches, _get_batch, _permute_data, _save_to_pickle, f_proc
 from mlair.helpers.testing import PyTestAllEqual
 from mlair.model_modules.model_class import MyBranchedModel
 from mlair.model_modules.fully_connected_networks import FCN_64_32_16
@@ -89,6 +90,14 @@ class DummyData:
     def __init__(self, number_of_samples=np.random.randint(100, 150)):
         np.random.seed(45)
         self.number_of_samples = number_of_samples
+        self._len = self.number_of_samples
+        self._len_upsampling = self.number_of_samples
+
+    def __len__(self, upsampling=False):
+        if upsampling is False:
+            return self._len
+        else:
+            return self._len_upsampling
 
     def get_X(self, upsampling=False, as_numpy=True):
         np.random.seed(45)
@@ -152,13 +161,6 @@ class TestKerasIterator:
         iterator._cleanup_path(path, create_new=False)
         assert os.path.exists(path) is False
 
-    def test_get_number_of_mini_batches(self):
-        iterator = object.__new__(KerasIterator)
-        iterator.batch_size = 36
-        assert iterator._get_number_of_mini_batches(30) == 0
-        assert iterator._get_number_of_mini_batches(40) == 1
-        assert iterator._get_number_of_mini_batches(72) == 2
-
     def test_len(self):
         iterator = object.__new__(KerasIterator)
         iterator.indexes = [0, 1, 2, 3, 4, 5]
@@ -175,25 +177,6 @@ class TestKerasIterator:
         for i in range(3):
             assert PyTestAllEqual([new_arr[i], test_arr[i]])
 
-    def test_get_batch(self):
-        arr = DummyData(20).get_X()
-        iterator = object.__new__(KerasIterator)
-        iterator.batch_size = 19
-        batch1 = iterator._get_batch(arr, 0)
-        assert batch1[0].shape[0] == 19
-        batch2 = iterator._get_batch(arr, 1)
-        assert batch2[0].shape[0] == 1
-
-    def test_save_to_pickle(self, path):
-        os.makedirs(path)
-        d = DummyData(20)
-        X, Y = d.get_X(), d.get_Y()
-        iterator = object.__new__(KerasIterator)
-        iterator._path = os.path.join(path, "%i.pickle")
-        assert os.path.exists(iterator._path % 2) is False
-        iterator._save_to_pickle(X=X, Y=Y, index=2)
-        assert os.path.exists(iterator._path % 2) is True
-
     def test_prepare_batches(self, collection, path):
         iterator = object.__new__(KerasIterator)
         iterator._collection = collection
@@ -292,14 +275,112 @@ class TestKerasIterator:
         with pytest.raises(TypeError):
             iterator._get_model_rank()
 
-    def test_permute(self):
-        iterator = object.__new__(KerasIterator)
+
+class TestGetNumberOfMiniBatches:
+
+    def test_get_number_of_mini_batches(self):
+        batch_size = 36
+        assert _get_number_of_mini_batches(30, batch_size) == 0
+        assert _get_number_of_mini_batches(40, batch_size) == 1
+        assert _get_number_of_mini_batches(72, batch_size) == 2
+
+
+class TestGetBatch:
+
+    def test_get_batch(self):
+        arr = DummyData(20).get_X()
+        batch_size = 19
+        batch1 = _get_batch(arr, 0, batch_size)
+        assert batch1[0].shape[0] == 19
+        batch2 = _get_batch(arr, 1, batch_size)
+        assert batch2[0].shape[0] == 1
+
+
+class TestSaveToPickle:
+
+    @pytest.fixture
+    def path(self):
+        p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
+        shutil.rmtree(p, ignore_errors=True) if os.path.exists(p) else None
+        yield p
+        shutil.rmtree(p, ignore_errors=True)
+
+    def test_save_to_pickle(self, path):
+        os.makedirs(path)
+        d = DummyData(20)
+        X, Y = d.get_X(), d.get_Y()
+        _path = os.path.join(path, "%i.pickle")
+        assert os.path.exists(_path % 2) is False
+        _save_to_pickle(_path, X=X, Y=Y, index=2)
+        assert os.path.exists(_path % 2) is True
+
+
+class TestPermuteData:
+
+    def test_permute_data(self):
         X = [np.array([[1, 2, 3, 4], [1.1, 2.1, 3.1, 4.1], [1.2, 2.2, 3.2, 4.2]], dtype="f2")]
         Y = [np.array([1, 2, 3])]
-        X_p, Y_p = iterator._permute_data(X, Y)
+        X_p, Y_p = _permute_data(X, Y)
         assert X_p[0].shape == X[0].shape
         assert Y_p[0].shape == Y[0].shape
         assert np.testing.assert_almost_equal(X_p[0].sum(), X[0].sum(), 2) is None
         assert np.testing.assert_almost_equal(Y_p[0].sum(), Y[0].sum(), 2) is None
+
+
+class TestFProc:
+
+    @pytest.fixture
+    def collection(self):
+        coll = []
+        for i in range(3):
+            coll.append(DummyData(50 + i))
+        data_coll = DataCollection(collection=coll)
+        return data_coll
+
+    @pytest.fixture
+    def collection_small(self):
+        coll = []
+        for i in range(3):
+            coll.append(DummyData(5 + i))
+        data_coll = DataCollection(collection=coll)
+        return data_coll
+
+    @pytest.fixture
+    def path(self):
+        p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
+        shutil.rmtree(p, ignore_errors=True) if os.path.exists(p) else None
+        os.makedirs(p)
+        yield p
+        shutil.rmtree(p, ignore_errors=True)
+
+    def test_f_proc(self, collection, path):
+        data = collection[0]
+        upsampling = False
+        mod_rank = 2
+        batch_size = 32
+        remaining = f_proc(data, upsampling, mod_rank, batch_size, os.path.join(path, "%i.pickle"), 0)
+        assert isinstance(remaining, tuple)
+        assert len(remaining) == 2
+        assert isinstance(remaining[0], list)
+        assert len(remaining[0]) == 3
+        assert remaining[0][0].shape == (18, 14, 5)
+
+    def test_f_proc_no_remaining(self, collection, path):
+        data = collection[0]
+        upsampling = False
+        mod_rank = 2
+        batch_size = 50
+        remaining = f_proc(data, upsampling, mod_rank, batch_size, os.path.join(path, "%i.pickle"), 0)
+        assert remaining is None
+
+    def test_f_proc_X_Y(self, collection, path):
+        data = collection[0]
+        X, Y = data.get_data()
+        upsampling = False
+        mod_rank = 2
+        batch_size = 40
+        remaining = f_proc((X, Y), upsampling, mod_rank, batch_size, os.path.join(path, "%i.pickle"), 0)
+        assert remaining[0][0].shape == (10, 14, 5)
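
Note on the shapes asserted in TestFProc: DummyData evidently yields inputs of shape
(number_of_samples, 14, 5), so the expected remainder batch follows directly from the
full-batch count. A quick check of the arithmetic behind the three tests above (pure
Python, no mlair imports needed):

    for n_samples, batch_size, expected in [(50, 32, 18), (50, 50, 0), (50, 40, 10)]:
        full_batches = n_samples // batch_size             # full batches written to disk
        remainder = n_samples - full_batches * batch_size  # samples left over
        assert remainder == expected
    # a remainder of 0 is why test_f_proc_no_remaining expects f_proc to return None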