Commit a7ac1749 authored by leufen1

updated tests

parent 1320ef1f
4 merge requests:
- !468 first implementation of toar-data-v2, can load data (but cannot process these...
- !467 Resolve "release v2.2.0"
- !461 Merge Dev into issue400
- !459 Resolve "improve set keras generator speed"
Pipeline #106816 canceled
@@ -22,6 +22,9 @@ class AbstractDataHandler(object):
         """Return initialised class."""
         return cls(*args, **kwargs)
 
+    def __len__(self, upsampling=False):
+        raise NotImplementedError
+
     @classmethod
     def requirements(cls, skip_args=None):
         """Return requirements and own arguments without duplicates."""
@@ -55,6 +55,8 @@ class DefaultDataHandler(AbstractDataHandler):
         self._X_extreme = None
         self._Y_extreme = None
         self._data_intersection = None
+        self._len = None
+        self._len_upsampling = None
         self._use_multiprocessing = use_multiprocessing
         self._max_number_multiprocessing = max_number_multiprocessing
         _name_affix = str(f"{str(self.id_class)}_{name_affix}" if name_affix is not None else id(self))
@@ -134,9 +136,11 @@ class DefaultDataHandler(AbstractDataHandler):
     def __repr__(self):
         return str(self._collection[0])
 
-    def __len__(self):
-        if self._data_intersection is not None:
-            return len(self._data_intersection)
+    def __len__(self, upsampling=False):
+        if upsampling is False:
+            return self._len
+        else:
+            return self._len_upsampling
 
     def get_X_original(self):
         X = []
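To make the new length bookkeeping concrete: the builtin `len()` protocol cannot forward arguments, so `__len__` keeps a default of `upsampling=False`, and any caller that needs the upsampled count must invoke the dunder explicitly (as the KerasIterator hunk further below does). A minimal sketch of the pattern; the class name and sample counts here are illustrative, not taken from this commit:

```python
class DataHandlerSketch:
    def __init__(self):
        self._len = 100             # samples after intersecting X and Y
        self._len_upsampling = 130  # samples after appending extremes

    def __len__(self, upsampling=False):
        # builtin len() calls __len__ with no arguments, so it always
        # returns the plain (non-upsampled) length
        if upsampling is False:
            return self._len
        else:
            return self._len_upsampling


dh = DataHandlerSketch()
assert len(dh) == 100                      # builtin protocol, no flag possible
assert dh.__len__(upsampling=True) == 130  # explicit call to pass the flag
```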
@@ -178,6 +182,7 @@ class DefaultDataHandler(AbstractDataHandler):
             Y = Y_original.sel({dim: intersect})
         self._data_intersection = intersect
         self._X, self._Y = X, Y
+        self._len = len(self._data_intersection)
 
     def get_observation(self):
         dim = self.time_dim
@@ -212,6 +217,7 @@ class DefaultDataHandler(AbstractDataHandler):
         if extreme_values is None:
             logging.debug(f"No extreme values given, skip multiply extremes")
             self._X_extreme, self._Y_extreme = self._X, self._Y
+            self._len_upsampling = self._len
             return
 
         # check type of inputs
@@ -247,6 +253,7 @@ class DefaultDataHandler(AbstractDataHandler):
         self._Y_extreme = xr.concat([Y, extremes_Y], dim=dim)
         self._X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=dim), X, extremes_X))
+        self._len_upsampling = len(self._X_extreme[0].coords[dim])
 
     @staticmethod
     def _add_timedelta(data, dim, timedelta):
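The upsampled length recorded here is simply the size of the time dimension after the extreme samples are concatenated onto the regular data. A small xarray illustration of that bookkeeping (dimension name and sizes are made up, not taken from MLAir):

```python
import numpy as np
import xarray as xr

dim = "datetime"
regular = xr.DataArray(np.zeros((100, 5)), dims=(dim, "variables"),
                       coords={dim: np.arange(100)})
extremes = xr.DataArray(np.zeros((30, 5)), dims=(dim, "variables"),
                        coords={dim: np.arange(100, 130)})

# mirrors: self._len_upsampling = len(self._X_extreme[0].coords[dim])
combined = xr.concat([regular, extremes], dim=dim)
assert len(combined.coords[dim]) == 130
```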
@@ -86,7 +86,7 @@ class KerasIterator(keras.utils.Sequence):
         self.upsampling = upsampling
         self.indexes: list = []
         self._cleanup_path(batch_path)
-        self._prepare_batches_parallel(use_multiprocessing, max_number_multiprocessing)
+        self._prepare_batches(use_multiprocessing, max_number_multiprocessing)
 
     def __len__(self) -> int:
         return len(self.indexes)
@@ -160,7 +160,7 @@ class KerasIterator(keras.utils.Sequence):
         index += 1
         self.indexes = np.arange(0, index).tolist()
 
-    def _prepare_batches_parallel(self, use_multiprocessing=False, max_process=1) -> None:
+    def _prepare_batches(self, use_multiprocessing=False, max_process=1) -> None:
         """
         Prepare all batches as locally stored files.
@@ -183,7 +183,7 @@ class KerasIterator(keras.utils.Sequence):
         pool = None
         output = None
         for data in self._collection:
-            length = len(data)
+            length = data.__len__(self.upsampling)
             batches = _get_number_of_mini_batches(length, self.batch_size)
             if pool is None:
                 res = f_proc(data, self.upsampling, mod_rank, self.batch_size, self._path, index)
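`data.__len__(self.upsampling)` is spelled out instead of `len(data)` because the builtin cannot pass the upsampling flag. The batch count derived from that length appears to be plain floor division: judging from the test expectations later in this commit (30 → 0, 40 → 1, 72 → 2 at batch size 36, plus the f_proc remainders), a behaviourally equivalent sketch of `_get_number_of_mini_batches` would be:

```python
import math

def get_number_of_mini_batches(length: int, batch_size: int) -> int:
    # Count only the full mini-batches; any remainder is carried over and
    # written as a final smaller batch once all handlers are processed.
    return math.floor(length / batch_size)

assert get_number_of_mini_batches(30, 36) == 0
assert get_number_of_mini_batches(40, 36) == 1
assert get_number_of_mini_batches(72, 36) == 2
```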
@@ -209,7 +209,6 @@ class KerasIterator(keras.utils.Sequence):
             _save_to_pickle(self._path, X=remaining[0], Y=remaining[1], index=index)
             index += 1
         self.indexes = np.arange(0, index).tolist()
-        logging.warning(f"hightst index is {index}")
         if pool is not None:
             pool.join()
@@ -102,6 +102,7 @@ class Training(RunEnvironment):
         """
         self.model.make_predict_function()
 
+    @TimeTrackingWrapper
     def _set_gen(self, mode: str) -> None:
         """
         Set and distribute the generators for given mode regarding batch size.
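`@TimeTrackingWrapper` (an MLAir helper) now reports how long generator distribution takes. As a rough illustration of what a timing decorator of this kind does (a generic stand-in, not MLAir's implementation):

```python
import functools
import logging
import time

def time_tracking_wrapper(func):
    # Generic timing decorator: log the wall-clock duration of each call.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        try:
            return func(*args, **kwargs)
        finally:
            logging.info(f"{func.__name__} finished after {time.time() - start:.2f}s")
    return wrapper
```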
 from mlair.data_handler.iterator import DataCollection, StandardIterator, KerasIterator
+from mlair.data_handler.iterator import _get_number_of_mini_batches, _get_batch, _permute_data, _save_to_pickle, f_proc
 from mlair.helpers.testing import PyTestAllEqual
 from mlair.model_modules.model_class import MyBranchedModel
 from mlair.model_modules.fully_connected_networks import FCN_64_32_16
@@ -89,6 +90,14 @@ class DummyData:
     def __init__(self, number_of_samples=np.random.randint(100, 150)):
         np.random.seed(45)
         self.number_of_samples = number_of_samples
+        self._len = self.number_of_samples
+        self._len_upsampling = self.number_of_samples
+
+    def __len__(self, upsampling=False):
+        if upsampling is False:
+            return self._len
+        else:
+            return self._len_upsampling
 
     def get_X(self, upsampling=False, as_numpy=True):
         np.random.seed(45)
@@ -152,13 +161,6 @@ class TestKerasIterator:
         iterator._cleanup_path(path, create_new=False)
         assert os.path.exists(path) is False
 
-    def test_get_number_of_mini_batches(self):
-        iterator = object.__new__(KerasIterator)
-        iterator.batch_size = 36
-        assert iterator._get_number_of_mini_batches(30) == 0
-        assert iterator._get_number_of_mini_batches(40) == 1
-        assert iterator._get_number_of_mini_batches(72) == 2
-
     def test_len(self):
         iterator = object.__new__(KerasIterator)
         iterator.indexes = [0, 1, 2, 3, 4, 5]
@@ -175,25 +177,6 @@ class TestKerasIterator:
         for i in range(3):
             assert PyTestAllEqual([new_arr[i], test_arr[i]])
 
-    def test_get_batch(self):
-        arr = DummyData(20).get_X()
-        iterator = object.__new__(KerasIterator)
-        iterator.batch_size = 19
-        batch1 = iterator._get_batch(arr, 0)
-        assert batch1[0].shape[0] == 19
-        batch2 = iterator._get_batch(arr, 1)
-        assert batch2[0].shape[0] == 1
-
-    def test_save_to_pickle(self, path):
-        os.makedirs(path)
-        d = DummyData(20)
-        X, Y = d.get_X(), d.get_Y()
-        iterator = object.__new__(KerasIterator)
-        iterator._path = os.path.join(path, "%i.pickle")
-        assert os.path.exists(iterator._path % 2) is False
-        iterator._save_to_pickle(X=X, Y=Y, index=2)
-        assert os.path.exists(iterator._path % 2) is True
-
     def test_prepare_batches(self, collection, path):
         iterator = object.__new__(KerasIterator)
         iterator._collection = collection
@@ -292,14 +275,112 @@ class TestKerasIterator:
         with pytest.raises(TypeError):
             iterator._get_model_rank()
 
-    def test_permute(self):
-        iterator = object.__new__(KerasIterator)
+
+class TestGetNumberOfMiniBatches:
+
+    def test_get_number_of_mini_batches(self):
+        batch_size = 36
+        assert _get_number_of_mini_batches(30, batch_size) == 0
+        assert _get_number_of_mini_batches(40, batch_size) == 1
+        assert _get_number_of_mini_batches(72, batch_size) == 2
+
+
+class TestGetBatch:
+
+    def test_get_batch(self):
+        arr = DummyData(20).get_X()
+        batch_size = 19
+        batch1 = _get_batch(arr, 0, batch_size)
+        assert batch1[0].shape[0] == 19
+        batch2 = _get_batch(arr, 1, batch_size)
+        assert batch2[0].shape[0] == 1
+
+
+class TestSaveToPickle:
+
+    @pytest.fixture
+    def path(self):
+        p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
+        shutil.rmtree(p, ignore_errors=True) if os.path.exists(p) else None
+        yield p
+        shutil.rmtree(p, ignore_errors=True)
+
+    def test_save_to_pickle(self, path):
+        os.makedirs(path)
+        d = DummyData(20)
+        X, Y = d.get_X(), d.get_Y()
+        _path = os.path.join(path, "%i.pickle")
+        assert os.path.exists(_path % 2) is False
+        _save_to_pickle(_path, X=X, Y=Y, index=2)
+        assert os.path.exists(_path % 2) is True
+
+
+class TestPermuteData:
+
+    def test_permute_data(self):
         X = [np.array([[1, 2, 3, 4],
                        [1.1, 2.1, 3.1, 4.1],
                        [1.2, 2.2, 3.2, 4.2]], dtype="f2")]
         Y = [np.array([1, 2, 3])]
-        X_p, Y_p = iterator._permute_data(X, Y)
+        X_p, Y_p = _permute_data(X, Y)
         assert X_p[0].shape == X[0].shape
         assert Y_p[0].shape == Y[0].shape
         assert np.testing.assert_almost_equal(X_p[0].sum(), X[0].sum(), 2) is None
         assert np.testing.assert_almost_equal(Y_p[0].sum(), Y[0].sum(), 2) is None
+
+
+class TestFProc:
+
+    @pytest.fixture
+    def collection(self):
+        coll = []
+        for i in range(3):
+            coll.append(DummyData(50 + i))
+        data_coll = DataCollection(collection=coll)
+        return data_coll
+
+    @pytest.fixture
+    def collection_small(self):
+        coll = []
+        for i in range(3):
+            coll.append(DummyData(5 + i))
+        data_coll = DataCollection(collection=coll)
+        return data_coll
+
+    @pytest.fixture
+    def path(self):
+        p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
+        shutil.rmtree(p, ignore_errors=True) if os.path.exists(p) else None
+        os.makedirs(p)
+        yield p
+        shutil.rmtree(p, ignore_errors=True)
+
+    def test_f_proc(self, collection, path):
+        data = collection[0]
+        upsampling = False
+        mod_rank = 2
+        batch_size = 32
+        remaining = f_proc(data, upsampling, mod_rank, batch_size, os.path.join(path, "%i.pickle"), 0)
+        assert isinstance(remaining, tuple)
+        assert len(remaining) == 2
+        assert isinstance(remaining[0], list)
+        assert len(remaining[0]) == 3
+        assert remaining[0][0].shape == (18, 14, 5)
+
+    def test_f_proc_no_remaining(self, collection, path):
+        data = collection[0]
+        upsampling = False
+        mod_rank = 2
+        batch_size = 50
+        remaining = f_proc(data, upsampling, mod_rank, batch_size, os.path.join(path, "%i.pickle"), 0)
+        assert remaining is None
+
+    def test_f_proc_X_Y(self, collection, path):
+        data = collection[0]
+        X, Y = data.get_data()
+        upsampling = False
+        mod_rank = 2
+        batch_size = 40
+        remaining = f_proc((X, Y), upsampling, mod_rank, batch_size, os.path.join(path, "%i.pickle"), 0)
+        assert remaining[0][0].shape == (10, 14, 5)
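The expected shapes in TestFProc follow from the fixtures: DummyData(50) provides 50 samples with per-sample shape (14, 5) in each of its three input arrays, so f_proc writes the full mini-batches to disk and returns what is left over. The remainder arithmetic behind the three assertions:

```python
def remaining_samples(total: int, batch_size: int) -> int:
    # Samples left over after all full mini-batches have been written.
    return total - (total // batch_size) * batch_size

assert remaining_samples(50, 32) == 18  # test_f_proc: shape (18, 14, 5)
assert remaining_samples(50, 50) == 0   # test_f_proc_no_remaining: returns None
assert remaining_samples(50, 40) == 10  # test_f_proc_X_Y: shape (10, 14, 5)
```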