Commit 11e155dd authored by leufen1

first implementation

parent 3ce4920f
4 merge requests: !468 first implementation of toar-data-v2, can load data (but cannot process these...), !467 Resolve "release v2.2.0", !461 Merge Dev into issue400, !459 Resolve "improve set keras generator speed"
Pipeline #106110 passed
@@ -8,10 +8,12 @@ import numpy as np
 import math
 import os
 import shutil
 import pickle
+import psutil
+import multiprocessing
 import logging
+import dill
 from typing import Tuple, List
 from mlair.helpers import TimeTrackingWrapper


 class StandardIterator(Iterator):
@@ -75,7 +77,7 @@ class DataCollection(Iterable):
 class KerasIterator(keras.utils.Sequence):

     def __init__(self, collection: DataCollection, batch_size: int, batch_path: str, shuffle_batches: bool = False,
-                 model=None, upsampling=False, name=None):
+                 model=None, upsampling=False, name=None, use_multiprocessing=False, max_number_multiprocessing=1):
         self._collection = collection
         batch_path = os.path.join(batch_path, str(name if name is not None else id(self)))
         self._path = os.path.join(batch_path, "%i.pickle")
@@ -85,7 +87,7 @@ class KerasIterator(keras.utils.Sequence):
         self.upsampling = upsampling
         self.indexes: list = []
         self._cleanup_path(batch_path)
-        self._prepare_batches()
+        self._prepare_batches(use_multiprocessing, max_number_multiprocessing)

     def __len__(self) -> int:
         return len(self.indexes)
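For orientation, a hedged sketch of how the extended constructor would be called; `collection` and `model` stand for real mlair objects and are not part of this commit. Leaving the two new keyword arguments at their defaults keeps the previous serial behaviour:

```python
# Illustrative call shape only: `collection` and `model` are placeholders for
# real mlair objects. Defaults (use_multiprocessing=False,
# max_number_multiprocessing=1) reproduce the old serial batch preparation.
def make_train_iterator(collection, model, batch_path="/tmp/batches"):
    return KerasIterator(collection, batch_size=512, batch_path=batch_path,
                         model=model, name="train",
                         use_multiprocessing=True,      # opt in to the parallel path
                         max_number_multiprocessing=4)  # cap on worker processes
```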
@@ -130,7 +132,8 @@ class KerasIterator(keras.utils.Sequence):
         Y = list(map(lambda x: x[p], Y))
         return X, Y

-    def _prepare_batches(self) -> None:
+    @TimeTrackingWrapper
+    def _prepare_batches(self, use_multiprocessing=False, max_process=1) -> None:
         """
         Prepare all batches as locally stored files.
@@ -142,6 +145,12 @@ class KerasIterator(keras.utils.Sequence):
         index = 0
         remaining = None
         mod_rank = self._get_model_rank()
+        n_process = min([psutil.cpu_count(logical=False), len(self._collection), max_process])  # use only physical cpus
+        if n_process > 1 and use_multiprocessing is True:  # parallel solution
+            pool = multiprocessing.Pool(n_process)
+        else:
+            pool = None
         for data in self._collection:
             logging.debug(f"prepare batches for {str(data)}")
             X, _Y = data.get_data(upsampling=self.upsampling)
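The pool sizing above combines three upper bounds: the number of physical cores (`logical=False` excludes hyper-threads), the number of independent data elements, and the user-supplied cap. A self-contained toy version of the same pattern, with a made-up `square` task standing in for the batch worker:

```python
import multiprocessing
import psutil

def square(x):  # stand-in task; the commit uses f_proc_keras_gen here
    return x * x

if __name__ == "__main__":
    jobs = list(range(10))
    max_process = 4  # analogous to max_number_multiprocessing
    # never start more workers than physical cores, work items, or the cap
    n_process = min(psutil.cpu_count(logical=False), len(jobs), max_process)
    with multiprocessing.Pool(n_process) as pool:
        results = [pool.apply_async(square, args=(j,)) for j in jobs]
        print([r.get() for r in results])  # .get() blocks until each task is done
```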
@@ -152,18 +161,28 @@ class KerasIterator(keras.utils.Sequence):
                 X, Y = self._concatenate(X, remaining[0]), self._concatenate(Y, remaining[1])
             length = X[0].shape[0]
             batches = self._get_number_of_mini_batches(length)
+            output = []
             for b in range(batches):
-                batch_X, batch_Y = self._get_batch(X, b), self._get_batch(Y, b)
-                self._save_to_pickle(X=batch_X, Y=batch_Y, index=index)
+                if pool is None:  # serial fallback: create and store the batch directly
+                    output.append(f_proc_keras_gen(X, Y, b, self.batch_size, index, self._path))
+                else:  # hand batch creation and storage over to a worker process
+                    output.append(pool.apply_async(f_proc_keras_gen, args=(X, Y, b, self.batch_size, index, self._path)))
                 index += 1
+            if pool is not None:
+                [p.get() for p in output]  # block until all batches of this data element are written
             if (batches * self.batch_size) < length:  # keep remaining to concatenate with next data element
-                remaining = (self._get_batch(X, batches), self._get_batch(Y, batches))
+                remaining = (get_batch(X, batches, self.batch_size), get_batch(Y, batches, self.batch_size))
             else:
                 remaining = None
         if remaining is not None:  # add remaining as smaller batch
-            self._save_to_pickle(X=remaining[0], Y=remaining[1], index=index)
+            save_to_pickle(self._path, X=remaining[0], Y=remaining[1], index=index)
             index += 1
         self.indexes = np.arange(0, index).tolist()
+        if pool is not None:
+            pool.close()  # no more tasks will be submitted
+            pool.join()   # wait for worker shutdown

     def _save_to_pickle(self, X: List[np.ndarray], Y: List[np.ndarray], index: int) -> None:
         """Save data as pickle file with variables X and Y and given index as <index>.pickle."""
@@ -188,3 +207,21 @@ class KerasIterator(keras.utils.Sequence):
         """Randomly shuffle indexes if enabled."""
         if self.shuffle is True:
             np.random.shuffle(self.indexes)
+
+
+def f_proc_keras_gen(X, Y, batch_number, batch_size, index, path):
+    """Cut a single batch out of X and Y and store it on disk (worker function)."""
+    batch_X, batch_Y = get_batch(X, batch_number, batch_size), get_batch(Y, batch_number, batch_size)
+    save_to_pickle(path, X=batch_X, Y=batch_Y, index=index)
+
+
+def save_to_pickle(path, X: List[np.ndarray], Y: List[np.ndarray], index: int) -> None:
+    """Save data as pickle file with variables X and Y and given index as <index>.pickle."""
+    data = {"X": X, "Y": Y}
+    file = path % index
+    with open(file, "wb") as f:
+        dill.dump(data, f)
+
+
+def get_batch(data_list: List[np.ndarray], b: int, batch_size: int) -> List[np.ndarray]:
+    """Get batch b according to batch size from each array in the data list."""
+    return list(map(lambda data: data[b * batch_size:(b + 1) * batch_size, ...], data_list))
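Assuming the three helpers above are in scope, a quick round trip shows the `"%i.pickle"` path template and dill serialization at work; the path and array shapes are throwaway examples:

```python
import dill
import numpy as np

path = "/tmp/%i.pickle"               # same template as self._path above
X = [np.random.rand(10, 3)]           # list of arrays, as in the iterator
Y = [np.random.rand(10, 1)]
save_to_pickle(path, X=get_batch(X, 0, 4), Y=get_batch(Y, 0, 4), index=0)
with open(path % 0, "rb") as f:       # read back, as __getitem__ would
    data = dill.load(f)
print(data["X"][0].shape)             # (4, 3): first batch of 4 samples
```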
@@ -109,7 +109,8 @@ class Training(RunEnvironment):
         :param mode: name of set, should be from ["train", "val", "test"]
         """
         collection = self.data_store.get("data_collection", mode)
-        kwargs = self.data_store.create_args_dict(["upsampling", "shuffle_batches", "batch_path"], scope=mode)
+        kwargs = self.data_store.create_args_dict(["upsampling", "shuffle_batches", "batch_path", "use_multiprocessing",
+                                                   "max_number_multiprocessing"], scope=mode)
         setattr(self, f"{mode}_set", KerasIterator(collection, self.batch_size, model=self.model, name=mode, **kwargs))

     def set_generators(self) -> None:
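`create_args_dict` is an mlair data-store internal; the stand-in below only illustrates the general pattern of collecting named settings into a dict and unpacking them into the constructor, now including the two new keys:

```python
# Hypothetical stand-in, not mlair's implementation: gather requested keys
# from a settings source and forward them as keyword arguments.
settings = {"upsampling": False, "shuffle_batches": True, "batch_path": "/tmp",
            "use_multiprocessing": True, "max_number_multiprocessing": 2}

def create_args_dict(names, source):
    return {k: source[k] for k in names if k in source}

kwargs = create_args_dict(["use_multiprocessing", "max_number_multiprocessing"], settings)
print(kwargs)  # {'use_multiprocessing': True, 'max_number_multiprocessing': 2}
```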