Skip to content
Snippets Groups Projects
Commit 0a7bcdb7 authored by lukas leufen's avatar lukas leufen
Browse files

Merge branch 'lukas_issue345_store-transformation-locally-as-file' into 'develop'

Resolve "store transformation locally as file"

See merge request !370
parents 8578459c 399a383e
Branches
Tags
3 merge requests!413update release branch,!412Resolve "release v2.0.0",!370Resolve "store transformation locally as file"
Pipeline #85768 passed
sphinx==3.0.3 sphinx==3.0.3
sphinx-autoapi==1.3.0 sphinx-autoapi==1.8.4
sphinx-autodoc-typehints==1.10.3 sphinx-autodoc-typehints==1.12.0
sphinx-rtd-theme==0.4.3 sphinx-rtd-theme==0.4.3
#recommonmark==0.6.0 #recommonmark==0.6.0
m2r2==0.2.5 m2r2==0.3.1
docutils<0.18 docutils<0.18
mistune==0.8.4
setuptools>=59.5.0
\ No newline at end of file
...@@ -293,6 +293,7 @@ class DefaultDataHandler(AbstractDataHandler): ...@@ -293,6 +293,7 @@ class DefaultDataHandler(AbstractDataHandler):
transformation_dict = ({}, {}) transformation_dict = ({}, {})
max_process = kwargs.get("max_number_multiprocessing", 16) max_process = kwargs.get("max_number_multiprocessing", 16)
set_stations = to_list(set_stations)
n_process = min([psutil.cpu_count(logical=False), len(set_stations), max_process]) # use only physical cpus n_process = min([psutil.cpu_count(logical=False), len(set_stations), max_process]) # use only physical cpus
if n_process > 1 and kwargs.get("use_multiprocessing", True) is True: # parallel solution if n_process > 1 and kwargs.get("use_multiprocessing", True) is True: # parallel solution
logging.info("use parallel transformation approach") logging.info("use parallel transformation approach")
......
...@@ -86,3 +86,35 @@ def PyTestAllEqual(check_list: List): ...@@ -86,3 +86,35 @@ def PyTestAllEqual(check_list: List):
return self._check_all_equal() return self._check_all_equal()
return PyTestAllEqualClass(check_list).is_true() return PyTestAllEqualClass(check_list).is_true()
def test_nested_equality(obj1, obj2):
    """Deeply compare two (possibly nested) objects and report equality as a bool.

    Tuples/lists and dicts are compared element-wise via recursion, xarray and
    numpy objects via their respective testing helpers, and anything else with
    plain ``==``. Each comparison step is printed for debugging purposes.
    Returns True on full equality, False as soon as any check fails.
    """
    try:
        print(f"check type {type(obj1)} and {type(obj2)}")
        assert type(obj1) == type(obj2)
        if isinstance(obj1, (tuple, list)):
            print(f"check length {len(obj1)} and {len(obj2)}")
            assert len(obj1) == len(obj2)
            # lengths are equal at this point, so zip covers every position
            for sub1, sub2 in zip(obj1, obj2):
                print(f"check pos {sub1} and {sub2}")
                assert test_nested_equality(sub1, sub2) is True
        elif isinstance(obj1, dict):
            print(f"check keys {obj1.keys()} and {obj2.keys()}")
            assert sorted(obj1.keys()) == sorted(obj2.keys())
            for key in obj1:
                print(f"check pos {obj1[key]} and {obj2[key]}")
                assert test_nested_equality(obj1[key], obj2[key]) is True
        elif isinstance(obj1, xr.DataArray):
            print(f"check xr {obj1} and {obj2}")
            # assert_equal raises on mismatch and returns None on success
            assert xr.testing.assert_equal(obj1, obj2) is None
        elif isinstance(obj1, np.ndarray):
            print(f"check np {obj1} and {obj2}")
            # assert_array_equal raises on mismatch and returns None on success
            assert np.testing.assert_array_equal(obj1, obj2) is None
        else:
            print(f"check equal {obj1} and {obj2}")
            assert obj1 == obj2
    except AssertionError:
        return False
    return True
...@@ -187,6 +187,9 @@ class ExperimentSetup(RunEnvironment): ...@@ -187,6 +187,9 @@ class ExperimentSetup(RunEnvironment):
:param use_multiprocessing: Enable parallel preprocessing (postprocessing not implemented yet) by setting this :param use_multiprocessing: Enable parallel preprocessing (postprocessing not implemented yet) by setting this
parameter to `True` (default). If set to `False` the computation is performed in an serial approach. parameter to `True` (default). If set to `False` the computation is performed in an serial approach.
Multiprocessing is disabled when running in debug mode and cannot be switched on. Multiprocessing is disabled when running in debug mode and cannot be switched on.
:param transformation_file: Use transformation options from this file for transformation
:param calculate_fresh_transformation: can either be True or False, indicates if new transformation options should
be calculated in any case (transformation_file is not used in this case!).
""" """
...@@ -224,7 +227,8 @@ class ExperimentSetup(RunEnvironment): ...@@ -224,7 +227,8 @@ class ExperimentSetup(RunEnvironment):
max_number_multiprocessing: int = None, start_script: Union[Callable, str] = None, max_number_multiprocessing: int = None, start_script: Union[Callable, str] = None,
overwrite_lazy_data: bool = None, uncertainty_estimate_block_length: str = None, overwrite_lazy_data: bool = None, uncertainty_estimate_block_length: str = None,
uncertainty_estimate_evaluate_competitors: bool = None, uncertainty_estimate_n_boots: int = None, uncertainty_estimate_evaluate_competitors: bool = None, uncertainty_estimate_n_boots: int = None,
do_uncertainty_estimate: bool = None, model_display_name: str = None, **kwargs): do_uncertainty_estimate: bool = None, model_display_name: str = None, transformation_file: str = None,
calculate_fresh_transformation: bool = None, **kwargs):
# create run framework # create run framework
super().__init__() super().__init__()
...@@ -311,6 +315,9 @@ class ExperimentSetup(RunEnvironment): ...@@ -311,6 +315,9 @@ class ExperimentSetup(RunEnvironment):
scope="preprocessing") scope="preprocessing")
self._set_param("transformation", transformation, default={}) self._set_param("transformation", transformation, default={})
self._set_param("transformation", None, scope="preprocessing") self._set_param("transformation", None, scope="preprocessing")
self._set_param("transformation_file", transformation_file, default=None)
if calculate_fresh_transformation is not None:
self._set_param("calculate_fresh_transformation", calculate_fresh_transformation)
self._set_param("data_handler", data_handler, default=DefaultDataHandler) self._set_param("data_handler", data_handler, default=DefaultDataHandler)
# iter and window dimension # iter and window dimension
......
...@@ -295,12 +295,43 @@ class PreProcessing(RunEnvironment): ...@@ -295,12 +295,43 @@ class PreProcessing(RunEnvironment):
self.data_store.set(k, v) self.data_store.set(k, v)
def transformation(self, data_handler: AbstractDataHandler, stations):
    """Set up the transformation options and store them in the data store.

    Transformation options are either loaded from a previously stored file
    (when ``calculate_fresh_transformation`` is False and
    :meth:`_load_transformation` finds a usable file) or freshly calculated
    from the train data. The resulting options are written to the data store
    and additionally dumped to disk for later reuse.

    :param data_handler: data handler class that may provide a ``transformation``
        method; if it does not, this method is a no-op
    :param stations: stations of the train subset used for a fresh calculation
    """
    calculate_fresh_transformation = self.data_store.get_default("calculate_fresh_transformation", True)
    if hasattr(data_handler, "transformation"):
        # a forced fresh calculation bypasses any stored transformation file
        transformation_opts = None if calculate_fresh_transformation is True else self._load_transformation()
        if transformation_opts is None:
            logging.info("start to calculate transformation parameters.")
            kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope="train")
            tmp_path = self.data_store.get_default("tmp_path", default=None)
            transformation_opts = data_handler.transformation(stations, tmp_path=tmp_path, **kwargs)
        else:
            logging.info("In case no valid train data could be found due to problems with transformation, please "
                         "check your provided transformation file for compatibility with your data.")
        self.data_store.set("transformation", transformation_opts)
        if transformation_opts is not None:
            self._store_transformation(transformation_opts)
def _load_transformation(self):
    """Try to load transformation options from file if transformation_file is provided."""
    transformation_file = self.data_store.get_default("transformation_file", None)
    if transformation_file is None:
        # no file configured -> caller falls back to a fresh calculation
        return
    if not os.path.exists(transformation_file):
        logging.info(f"cannot load transformation file: {transformation_file}. Use fresh calculation of "
                     f"transformation from train data.")
        return
    logging.info(f"use transformation from given transformation file: {transformation_file}")
    with open(transformation_file, "rb") as pickle_file:
        return dill.load(pickle_file)
def _store_transformation(self, transformation_opts):
    """Store transformation options locally inside experiment_path if not exists already."""
    experiment_path = self.data_store.get("experiment_path")
    transformation_path = os.path.join(experiment_path, "data", "transformation")
    transformation_file = os.path.join(transformation_path, "transformation.pickle")
    if os.path.exists(transformation_file):
        # keep an already stored file untouched
        return
    path_config.check_path_and_create(transformation_path)
    with open(transformation_file, "wb") as f:
        dill.dump(transformation_opts, f, protocol=4)
    logging.info(f"Store transformation options locally for later use at: {transformation_file}")
def prepare_competitors(self): def prepare_competitors(self):
""" """
......
from mlair.helpers.testing import PyTestRegex, PyTestAllEqual from mlair.helpers.testing import PyTestRegex, PyTestAllEqual, test_nested_equality
import re import re
import xarray as xr import xarray as xr
...@@ -46,3 +46,35 @@ class TestPyTestAllEqual: ...@@ -46,3 +46,35 @@ class TestPyTestAllEqual:
[xr.DataArray([1, 2, 3]), xr.DataArray([12, 22, 32])]]) [xr.DataArray([1, 2, 3]), xr.DataArray([12, 22, 32])]])
assert PyTestAllEqual([["test", "test2"], assert PyTestAllEqual([["test", "test2"],
["test", "test2"]]) ["test", "test2"]])
class TestNestedEquality:
    """Unit tests for the recursive comparison helper ``test_nested_equality``."""

    def test_nested_equality_single_entries(self):
        """Scalar values compare by type and value; None equals None."""
        assert test_nested_equality(3, 3) is True
        assert test_nested_equality(3.9, 3.9) is True
        assert test_nested_equality(3.91, 3.9) is False
        assert test_nested_equality("3", 3) is False
        assert test_nested_equality("3", "3") is True
        assert test_nested_equality(None, None) is True

    def test_nested_equality_xarray(self):
        """Two value-equal DataArrays are reported as equal."""
        reference = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]})
        duplicate = xr.ones_like(reference) * reference
        assert test_nested_equality(reference, duplicate) is True

    def test_nested_equality_numpy(self):
        """Two value-equal numpy arrays are reported as equal."""
        reference = np.random.randn(2, 3)
        duplicate = reference * 1
        assert test_nested_equality(reference, duplicate) is True

    def test_nested_equality_list_tuple(self):
        """Sequences must match element-wise in order and length."""
        assert test_nested_equality([3, 3], [3, 3]) is True
        assert test_nested_equality((2, 6), (2, 6)) is True
        assert test_nested_equality([3, 3.5], [3.5, 3]) is False
        assert test_nested_equality([3, 3.5, 10], [3, 3.5]) is False

    def test_nested_equality_dict(self):
        """Dicts must have the same key set and equal values; order is irrelevant."""
        assert test_nested_equality({"a": 3, "b": 10}, {"b": 10, "a": 3}) is True
        assert test_nested_equality({"a": 3, "b": [10, 100]}, {"b": [10, 100], "a": 3}) is True
        assert test_nested_equality({"a": 3, "b": 10, "c": "c"}, {"b": 10, "a": 3}) is False
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment