import numpy as np
import xarray as xr
import datetime as dt
import logging
import math
import time
import mock
import pytest
from src.helpers import to_list, dict_to_xarray, float_round, remove_items
from src.helpers import PyTestRegex
from src.helpers import Logger, TimeTracking
class TestToList:
    """Tests for ``to_list``: strings are wrapped as a single element, lists pass through."""

    def test_to_list(self):
        # Each entry is (input, expected result of to_list).
        expectations = [
            ('a', ['a']),
            ('abcd', ['abcd']),
            ([1, 2, 3], [1, 2, 3]),
            ([45], [45]),
        ]
        for given, expected in expectations:
            assert to_list(given) == expected
class TestTimeTracking:
    """Tests for the ``TimeTracking`` stopwatch helper."""

    def test_init(self):
        # Default construction starts the clock immediately but does not stop it.
        t = TimeTracking()
        assert t.start is not None
        assert t.start < time.time()
        assert t.end is None
        # start=False defers starting the clock.
        t2 = TimeTracking(start=False)
        assert t2.start is None

    def test__start(self):
        t = TimeTracking(start=False)
        t._start()
        assert t.start < time.time()

    def test__end(self):
        t = TimeTracking()
        t._end()
        assert t.end > t.start

    def test__duration(self):
        t = TimeTracking()
        # While running, the measured duration grows between calls.
        d1 = t._duration()
        assert d1 > 0
        d2 = t._duration()
        assert d2 > d1
        # Once ended, the duration is frozen at its final value.
        t._end()
        d3 = t._duration()
        assert d3 > d2
        assert d3 == t._duration()

    def test_repr(self):
        t = TimeTracking()
        t._end()
        duration = t._duration()
        # repr rounds the elapsed seconds up and renders them as a timedelta.
        expected = f"{dt.timedelta(seconds=math.ceil(duration))} (hh:mm:ss)"
        assert repr(t).rstrip() == expected

    def test_run(self):
        t = TimeTracking(start=False)
        assert t.start is None
        t.run()
        assert t.start is not None

    def test_stop(self):
        t = TimeTracking()
        assert t.end is None
        duration = t.stop(get_duration=True)
        assert duration == t._duration()
        # Stopping an already stopped tracker is an error.
        with pytest.raises(AssertionError) as e:
            t.stop()
        assert "Time was already stopped" in e.value.args[0]
        # run() restarts the clock, after which stop() works again (returning None by default).
        t.run()
        assert t.end is None
        assert t.stop() is None
        assert t.end is not None

    def test_duration(self):
        t = TimeTracking()
        # duration() must be usable while the clock is still running.
        # (Previously this assigned the tracker itself, making the check vacuous.)
        duration = t.duration()
        assert duration is not None
        assert duration >= 0
        # After stopping, stop()'s returned value matches duration().
        duration = t.stop(get_duration=True)
        assert duration == t.duration()

    def test_enter_exit(self, caplog):
        caplog.set_level(logging.INFO)
        with TimeTracking() as t:
            assert t.start is not None
            assert t.end is None
        # Leaving the context logs the elapsed time under the default job name.
        expression = PyTestRegex(r"undefined job finished after \d+:\d+:\d+ \(hh:mm:ss\)")
        assert caplog.record_tuples[-1] == ('root', logging.INFO, expression)

    def test_name_enter_exit(self, caplog):
        caplog.set_level(logging.INFO)
        with TimeTracking(name="my job") as t:
            assert t.start is not None
            assert t.end is None
        # The custom name replaces the default in the log message.
        expression = PyTestRegex(r"my job finished after \d+:\d+:\d+ \(hh:mm:ss\)")
        assert caplog.record_tuples[-1] == ('root', logging.INFO, expression)
class TestPytestRegex:
    """Tests for ``PyTestRegex``, which compares equal to strings its pattern matches."""

    @pytest.fixture
    def regex(self):
        return PyTestRegex("teststring")

    def test_pytest_regex_init(self, regex):
        # The raw pattern is kept on the wrapped ``_regex`` object.
        assert regex._regex.pattern == "teststring"

    def test_pytest_regex_eq(self, regex):
        # A string containing the pattern compares equal; a broken one does not.
        matching, non_matching = "teststringabcd", "teststgabcd"
        assert regex == matching
        assert regex != non_matching

    def test_pytest_regex_repr(self, regex):
        # repr shows just the pattern text.
        assert repr(regex) == "teststring"
class TestDictToXarray:
    """Tests for ``dict_to_xarray``, which merges a dict of DataArrays along a new dimension."""

    def test_dict_to_xarray(self):
        array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]})
        array2 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]})
        d = {"number1": array1, "number2": array2}
        res = dict_to_xarray(d, "merge_dim")
        assert isinstance(res, xr.DataArray)
        # The merge dimension shows up as a coordinate next to the existing 'x'.
        assert sorted(res.coords) == ["merge_dim", "x"]
        # Two (2, 3) arrays merged along the new dimension.
        assert res.shape == (2, 2, 3)
class TestFloatRound:
    """Tests for ``float_round``: round to a number of decimals with a selectable rounding function."""

    def test_float_round_ceil(self):
        # With only a value given, the result is rounded up to a whole number.
        for value, expected in ((4.6, 5), (239.3992, 240)):
            assert float_round(value) == expected

    def test_float_round_decimals(self):
        # The second argument selects the number of decimals (still rounding up).
        for value, decimals, expected in ((23.0091, 2, 23.01), (23.1091, 3, 23.11)):
            assert float_round(value, decimals) == expected

    def test_float_round_type(self):
        # The third argument swaps in a different rounding function.
        cases = (
            (2, math.floor, 34.92),
            (0, math.floor, 34.),
            (2, round, 34.92),
            (0, round, 35.),
        )
        for decimals, round_type, expected in cases:
            assert float_round(34.9221, decimals, round_type) == expected

    def test_float_round_negative(self):
        value = -34.9221
        # floor moves negative values away from zero ...
        assert float_round(value, 2, math.floor) == -34.93
        assert float_round(value, 0, math.floor) == -35.
        # ... while the default rounding moves them towards zero.
        assert float_round(value, 2) == -34.92
        assert float_round(value, 0) == -34.
class TestRemoveItems:
    """Tests for ``remove_items`` on dicts (removal by key) and lists (removal by value)."""

    @pytest.fixture
    def custom_list(self):
        return [1, 2, 3, 'a', 'bc']

    @pytest.fixture
    def custom_dict(self):
        return {'a': 1, 'b': 2, 2: 'ab'}

    def test_dict_remove_single(self, custom_dict):
        without_b = {'a': 1, 2: 'ab'}
        # An absent key leaves the dict unchanged, whether given in a list or bare.
        for absent in ([4], '4'):
            assert remove_items(custom_dict, absent) == custom_dict
        # A present key is removed, whether given bare or in a list.
        for present in ('b', ['b']):
            assert remove_items(custom_dict, present) == without_b

    def test_dict_remove_multiple(self, custom_dict):
        cases = [
            ([4, 'mykey'], custom_dict),     # all keys absent
            (['a', 2], {'b': 2}),            # all keys present
            ([2, '10'], {'a': 1, 'b': 2}),   # one present, one absent
        ]
        for keys, expected in cases:
            assert remove_items(custom_dict, keys) == expected

    def test_list_remove_single(self, custom_list):
        assert remove_items(custom_list, 1) == [2, 3, 'a', 'bc']
        assert remove_items(custom_list, 'bc') == [1, 2, 3, 'a']
        # A value that is not in the list leaves it unchanged.
        assert remove_items(custom_list, 5) == custom_list

    def test_list_remove_multiple(self, custom_list):
        cases = [
            ([2, 'a'], [1, 3, 'bc']),        # all values present
            (['bc', 10], [1, 2, 3, 'a']),    # one present, one absent
            ([10, 'aa'], custom_list),       # all values absent
        ]
        for values, expected in cases:
            assert remove_items(custom_list, values) == expected

    def test_remove_missing_argument(self, custom_dict, custom_list):
        both_missing = "remove_items() missing 2 required positional arguments: 'obj' and 'items'"
        items_missing = "remove_items() missing 1 required positional argument: 'items'"
        with pytest.raises(TypeError) as exc_info:
            remove_items()
        assert both_missing in exc_info.value.args[0]
        # Passing only the container (dict or list) still lacks 'items'.
        for only_obj in (custom_dict, custom_list):
            with pytest.raises(TypeError) as exc_info:
                remove_items(only_obj)
            assert items_missing in exc_info.value.args[0]
class TestLogger:
    """Tests for the ``Logger`` setup helper."""

    @pytest.fixture
    def logger(self):
        return Logger()

    def test_init_default(self):
        # The log file name carries a second-resolution timestamp, so a single
        # comparison against a later setup_logging_path() call is flaky when a
        # second boundary is crossed in between; bracket Logger() instead.
        before = Logger.setup_logging_path()
        log = Logger()
        after = Logger.setup_logging_path()
        assert log.formatter == "%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]"
        assert log.log_file in (before, after)

    def test_setup_logging_path_none(self):
        log_file = Logger.setup_logging_path(None)
        # Raw string: '\d' and '\.' are regex escapes, not Python string escapes.
        assert PyTestRegex(
            r".*mlair/logging/logging_\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\.log") == log_file

    @mock.patch("os.makedirs", side_effect=None)
    def test_setup_logging_path_given(self, mock_makedirs):
        # os.makedirs is patched; presumably setup_logging_path creates the
        # directory, so this keeps the test from touching disk.
        path = "my/test/path"
        log_path = Logger.setup_logging_path(path)
        assert PyTestRegex(r"my/test/path/logging_\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\.log") == log_path

    def test_logger_console_level0(self, logger):
        consol = logger.logger_console(0)
        assert isinstance(consol, logging.StreamHandler)
        assert consol.level == 0
        # NOTE(review): this only checks that a Formatter can be built from
        # logger.formatter; it does not inspect the handler's own formatter.
        formatter = logging.Formatter(logger.formatter)
        assert isinstance(formatter, logging.Formatter)

    def test_logger_console_level1(self, logger):
        consol = logger.logger_console(1)
        assert isinstance(consol, logging.StreamHandler)
        assert consol.level == 1
        # NOTE(review): same caveat as level0 -- the handler's formatter is not inspected.
        formatter = logging.Formatter(logger.formatter)
        assert isinstance(formatter, logging.Formatter)

    def test_logger_console_level_wrong_type(self, logger):
        # A float level is rejected with the TypeError message asserted below.
        with pytest.raises(TypeError) as e:
            logger.logger_console(1.5)
        assert "Level not an integer or a valid string: 1.5" == e.value.args[0]