diff --git a/src/helpers.py b/src/helpers.py index 40a3f9762cd649651631e45d94b78c19562b9749..172a8dd3cf04a15e9069347dac7f06c6d2d8ed60 100644 --- a/src/helpers.py +++ b/src/helpers.py @@ -5,16 +5,19 @@ __date__ = '2019-10-21' import logging -import keras -import keras.backend as K import math -from typing import Union -import numpy as np import os import time import socket import datetime as dt +import keras +import keras.backend as K +import numpy as np +import xarray as xr + +from typing import Union, Dict, Callable + def to_list(arg): if not isinstance(arg, list): @@ -197,3 +200,35 @@ class PyTestRegex: def __repr__(self) -> str: return self._regex.pattern + + +def dict_to_xarray(d: Dict, coordinate_name: str) -> xr.DataArray: + """ + Convert a dictionary of 2D-xarrays to single 3D-xarray. The name of new coordinate axis follows <coordinate_name>. + :param d: dictionary with 2D-xarrays + :param coordinate_name: name of the new created axis (2D -> 3D) + :return: combined xarray + """ + xarray = None + for k, v in d.items(): + if xarray is None: + xarray = v + xarray.coords[coordinate_name] = k + else: + tmp_xarray = v + tmp_xarray.coords[coordinate_name] = k + xarray = xr.concat([xarray, tmp_xarray], coordinate_name) + return xarray + + +def float_round(number: float, decimals: int = 0, round_type: Callable = math.ceil) -> float: + """ + Perform given rounding operation on number with the precision of decimals. + :param number: the number to round + :param decimals: numbers of decimals of the rounding operations (default 0 -> round to next integer value) + :param round_type: the actual rounding operation. Can be any callable function like math.ceil, math.floor or python + built-in round operation. + :return: rounded number with desired precision + """ + multiplier = 10. ** decimals + return round_type(number * multiplier) / multiplier diff --git a/test/test_helpers.py b/test/test_helpers.py index ce5d28a63d63dc4a793e6e07c60f95cb411ae97e..e98a46fad6365a3a05ab28c9d118a119e35ff86a 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -189,3 +189,55 @@ class TestSetExperimentName: def test_set_experiment_from_sys(self): exp_name, _ = set_experiment_name(experiment_date="2019-11-14") assert exp_name == "2019-11-14_network" + + +class TestPytestRegex: + + @pytest.fixture + def regex(self): + return PyTestRegex("teststring") + + def test_pytest_regex_init(self, regex): + assert regex._regex.pattern == "teststring" + + def test_pytest_regex_eq(self, regex): + assert regex == "teststringabcd" + assert regex != "teststgabcd" + + def test_pytest_regex_repr(self, regex): + assert regex.__repr__() == "teststring" + + +class TestDictToXarray: + + def test_dict_to_xarray(self): + array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]}) + array2 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]}) + d = {"number1": array1, "number2": array2} + res = dict_to_xarray(d, "merge_dim") + assert type(res) == xr.DataArray + assert sorted(list(res.coords)) == ["merge_dim", "x"] + assert res.shape == (2, 2, 3) + + +class TestFloatRound: + + def test_float_round_ceil(self): + assert float_round(4.6) == 5 + assert float_round(239.3992) == 240 + + def test_float_round_decimals(self): + assert float_round(23.0091, 2) == 23.01 + assert float_round(23.1091, 3) == 23.11 + + def test_float_round_type(self): + assert float_round(34.9221, 2, math.floor) == 34.92 + assert float_round(34.9221, 0, math.floor) == 34. + assert float_round(34.9221, 2, round) == 34.92 + assert float_round(34.9221, 0, round) == 35. + + def test_float_round_negative(self): + assert float_round(-34.9221, 2, math.floor) == -34.93 + assert float_round(-34.9221, 0, math.floor) == -35. + assert float_round(-34.9221, 2) == -34.92 + assert float_round(-34.9221, 0) == -34.