diff --git a/mlair/data_handler/iterator.py b/mlair/data_handler/iterator.py
index 83090d6f5f1a79b48e37cf8c38fd4bbd804031a4..1a3268206b33c4cae583b12102f500f854b82467 100644
--- a/mlair/data_handler/iterator.py
+++ b/mlair/data_handler/iterator.py
@@ -87,7 +87,8 @@ class KerasIterator(keras.utils.Sequence):
         self.upsampling = upsampling
         self.indexes: list = []
         self._cleanup_path(batch_path)
-        self._prepare_batches(use_multiprocessing, max_number_multiprocessing)
+        self._prepare_batches_parallel(use_multiprocessing, max_number_multiprocessing)
+        self._prepare_batches(False, max_number_multiprocessing)
 
     def __len__(self) -> int:
         return len(self.indexes)
@@ -121,6 +122,11 @@ class KerasIterator(keras.utils.Sequence):
         """Concatenate two lists of data along axis=0."""
         return list(map(lambda n1, n2: np.concatenate((n1, n2), axis=0), old, new))
 
+    @staticmethod
+    def _concatenate_multi(*args: List[np.ndarray]) -> List[np.ndarray]:
+        """Concatenate multiple lists of data element-wise along axis=0."""
+        return list(map(lambda *_args: np.concatenate(_args, axis=0), *args))
+
     def _get_batch(self, data_list: List[np.ndarray], b: int) -> List[np.ndarray]:
         """Get batch according to batch size from data list."""
         return list(map(lambda data: data[b * self.batch_size:(b + 1) * self.batch_size, ...], data_list))
@@ -132,6 +138,51 @@ class KerasIterator(keras.utils.Sequence):
         Y = list(map(lambda x: x[p], Y))
         return X, Y
 
+    @TimeTrackingWrapper
+    def _prepare_batches_parallel(self, use_multiprocessing=False, max_process=1) -> None:
+        index = 0
+        remaining = []
+        mod_rank = self._get_model_rank()
+        # max_process = 12
+        n_process = min([psutil.cpu_count(logical=False), len(self._collection), max_process])  # use only physical cpus
+        if n_process > 1 and use_multiprocessing is True:  # parallel solution
+            pool = multiprocessing.Pool(n_process)
+            output = []
+        else:
+            pool = None
+            output = None
+        for data in self._collection:
+            X, _Y = data.get_data(upsampling=self.upsampling)
+            length = X[0].shape[0]
+            batches = _get_number_of_mini_batches(length, self.batch_size)
+            if pool is None:
+                res = f_proc(X, _Y, self.upsampling, mod_rank, self.batch_size, self._path, index)
+                if res is not None:
+                    remaining.append(res)
+            else:
+                output.append(pool.apply_async(f_proc, args=(X, _Y, self.upsampling, mod_rank, self.batch_size, self._path, index)))
+            index += batches
+        if output is not None:
+            for p in output:
+                res = p.get()
+                if res is not None:
+                    remaining.append(res)
+            pool.close()
+        if len(remaining) > 0:
+            X = self._concatenate_multi(*[e[0] for e in remaining])
+            Y = self._concatenate_multi(*[e[1] for e in remaining])
+            length = X[0].shape[0]
+            batches = _get_number_of_mini_batches(length, self.batch_size)
+            remaining = f_proc(X, Y, self.upsampling, mod_rank, self.batch_size, self._path, index)
+            index += batches
+            if remaining is not None:
+                save_to_pickle(self._path, X=remaining[0], Y=remaining[1], index=index)
+                index += 1
+        self.indexes = np.arange(0, index).tolist()
+        logging.warning(f"highest index is {index}")
+        if pool is not None:
+            pool.join()
+
     @TimeTrackingWrapper
     def _prepare_batches(self, use_multiprocessing=False, max_process=1) -> None:
         """
@@ -180,6 +231,7 @@ class KerasIterator(keras.utils.Sequence):
             save_to_pickle(self._path, X=remaining[0], Y=remaining[1], index=index)
             index += 1
         self.indexes = np.arange(0, index).tolist()
+        logging.warning(f"highest index is {index}")
         if pool is not None:
             pool.close()
             pool.join()
@@ -225,3 +277,30 @@ def save_to_pickle(path, X: List[np.ndarray], Y: List[np.ndarray], index: int) -> None:
 def get_batch(data_list: List[np.ndarray], b: int, batch_size: int) -> List[np.ndarray]:
     """Get batch according to batch size from data list."""
     return list(map(lambda data: data[b * batch_size:(b + 1) * batch_size, ...], data_list))
+
+
+def _permute_data(X, Y):
+    p = np.random.permutation(len(X[0]))  # equiv to .shape[0]
+    X = list(map(lambda x: x[p], X))
+    Y = list(map(lambda x: x[p], Y))
+    return X, Y
+
+def _get_number_of_mini_batches(number_of_samples: int, batch_size: int) -> int:
+    """Return number of mini batches as the floored ratio of number of samples to batch size."""
+    return math.floor(number_of_samples / batch_size)
+
+
+def f_proc(X, _Y, upsampling, mod_rank, batch_size, _path, index):
+    Y = [_Y[0] for _ in range(mod_rank)]
+    if upsampling:
+        X, Y = _permute_data(X, Y)
+    length = X[0].shape[0]
+    batches = _get_number_of_mini_batches(length, batch_size)
+    for b in range(batches):
+        f_proc_keras_gen(X, Y, b, batch_size, index, _path)
+        index += 1
+    if (batches * batch_size) < length:  # keep remaining to concatenate with next data element
+        remaining = (get_batch(X, batches, batch_size), get_batch(Y, batches, batch_size))
+    else:
+        remaining = None
+    return remaining
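
Notes on the new helpers (reviewer sketches, not part of the patch):

`_concatenate_multi` merges several "remaining" fragments element-wise: the i-th arrays of all fragments are stacked along axis 0, so each model input keeps its own feature shape. The double unpacking (outer `*args` vs. the lambda's `*_args`) is easy to misread, so here is a minimal sketch with toy arrays (shapes are illustrative, not taken from MLAir):

```python
import numpy as np
from typing import List

def _concatenate_multi(*args: List[np.ndarray]) -> List[np.ndarray]:
    """Concatenate multiple lists of data element-wise along axis=0."""
    return list(map(lambda *_args: np.concatenate(_args, axis=0), *args))

# Two leftover fragments, each a list with one array per model input:
rem1 = [np.zeros((3, 2)), np.zeros((3, 5))]  # 3 leftover samples
rem2 = [np.ones((4, 2)), np.ones((4, 5))]    # 4 leftover samples

merged = _concatenate_multi(rem1, rem2)
print([a.shape for a in merged])  # [(7, 2), (7, 5)]
```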
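
`f_proc` floors the sample count to full mini batches and hands the short tail back to the caller, which either concatenates it with the next element's data or writes it as one final short batch. A worked example of that arithmetic (numbers are illustrative):

```python
import math

def _get_number_of_mini_batches(number_of_samples: int, batch_size: int) -> int:
    """Return number of mini batches as the floored ratio of samples to batch size."""
    return math.floor(number_of_samples / batch_size)

length, batch_size = 110, 32
batches = _get_number_of_mini_batches(length, batch_size)
leftover = length - batches * batch_size
print(batches, leftover)  # 3 14 -> three full batches, 14 samples carried over
```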
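
The parallel variant can let workers write batch files concurrently because the parent computes each element's batch count up front and advances `index` by that amount before dispatching the next job, so every worker owns a disjoint index range. A toy sketch of that scheduling idea (`f_proc_stub` is a hypothetical stand-in for `f_proc`):

```python
import math
import multiprocessing

def f_proc_stub(n_samples: int, batch_size: int, index: int) -> list:
    # Pretend to write one pickle file per mini batch: index, index + 1, ...
    return list(range(index, index + math.floor(n_samples / batch_size)))

if __name__ == "__main__":
    batch_size, lengths = 32, [100, 70, 200]  # samples per collection element
    with multiprocessing.Pool(2) as pool:
        index, jobs = 0, []
        for n in lengths:
            jobs.append(pool.apply_async(f_proc_stub, args=(n, batch_size, index)))
            index += math.floor(n / batch_size)  # reserve this element's range up front
        # Disjoint ranges: [[0, 1, 2], [3, 4], [5, 6, 7, 8, 9, 10]]
        print([job.get() for job in jobs])
```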