diff --git a/requirements.txt b/requirements.txt index 9cd9ea44c3cd0068c985c52b07a7cfaa746d9b7c..e7c2f439966f6b085348af3078c814c7f0511024 100644 --- a/requirements.txt +++ b/requirements.txt @@ -43,6 +43,7 @@ pytest-cov==2.8.1 pytest-html==2.0.1 pytest-lazy-fixture==0.6.3 pytest-metadata==1.8.0 +pytest-sugar python-dateutil==2.8.1 pytz==2019.3 PyYAML==5.3 diff --git a/requirements_gpu.txt b/requirements_gpu.txt index 8e5a31e476e47b17d3f271199bbc151fc0dc0b50..9d1c2d62da0864d2626c7ada1aac4dcf6f633630 100644 --- a/requirements_gpu.txt +++ b/requirements_gpu.txt @@ -43,6 +43,7 @@ pytest-cov==2.8.1 pytest-html==2.0.1 pytest-lazy-fixture==0.6.3 pytest-metadata==1.8.0 +pytest-sugar python-dateutil==2.8.1 pytz==2019.3 PyYAML==5.3 diff --git a/src/helpers.py b/src/helpers.py index d108f3c30bbbe55965d3302d94571f740378503d..07e7e5dde3e20bcd016651cbd47d24970d38303d 100644 --- a/src/helpers.py +++ b/src/helpers.py @@ -166,6 +166,28 @@ class PyTestRegex: return self._regex.pattern +class PyTestAllEqual: + + def __init__(self, check_list): + self._list = check_list + + def _check_all_equal(self): + equal = True + for b in self._list: + equal *= xr.testing.assert_equal(self._list[0], b) is None + return equal == 1 + + def is_true(self): + return self._check_all_equal() + + +def xr_all_equal(check_list): + equal = True + for b in check_list: + equal *= xr.testing.assert_equal(check_list[0], b) is None + return equal == 1 + + def dict_to_xarray(d: Dict, coordinate_name: str) -> xr.DataArray: """ Convert a dictionary of 2D-xarrays to single 3D-xarray. The name of new coordinate axis follows <coordinate_name>. diff --git a/test/test_data_handling/test_bootstraps.py b/test/test_data_handling/test_bootstraps.py index 27494ee9918e1509787a2259cd07976627cb2b18..56b2a8bf9d4ce0d54c271e634b1bb8e171c80a6b 100644 --- a/test/test_data_handling/test_bootstraps.py +++ b/test/test_data_handling/test_bootstraps.py @@ -1,6 +1,7 @@ from src.data_handling.bootstraps import BootStraps, BootStrapGenerator from src.data_handling.data_generator import DataGenerator +from src.helpers import PyTestAllEqual, xr_all_equal import os import pytest @@ -81,6 +82,8 @@ class TestBootstrapGenerator: dummy_content = xr.DataArray([1, 2, 3], dims="dummy") dummy_content.to_netcdf(os.path.join(path, "DEBW107_o3_temp_hist7_nboots20_shuffled.nc")) dummy_content.to_netcdf(os.path.join(path, "DEBW013_o3_temp_hist7_nboots20_shuffled.nc")) + dummy_content = dummy_content.expand_dims({"type": ["CNN"]}) + dummy_content.to_netcdf(os.path.join(path, "forecasts_norm_DEBW107_test.nc")) return BootStrapGenerator(orig_generator, 20, path) def test_init(self, orig_generator): @@ -114,17 +117,15 @@ class TestBootstrapGenerator: res.append(label) assert len(res) == boot_gen.number_of_boots assert xr.testing.assert_equal(res[0], res[-1]) is None - - def all_equal(check_list): - equal = True - for b in check_list: - equal *= xr.testing.assert_equal(check_list[0], b) is None - return equal - assert all_equal(res) - + assert PyTestAllEqual(res).is_true() def test_get_orig_prediction(self, boot_gen): - pass + path = boot_gen.orig_generator.data_path + res = [] + for pred in boot_gen.get_orig_prediction(path, "forecasts_norm_DEBW107_test.nc"): + res.append(pred) + assert len(res) == boot_gen.number_of_boots+1 + assert PyTestAllEqual(res).is_true() def test_load_shuffled_data(self, boot_gen): shuffled_data = boot_gen.load_shuffled_data("DEBW107", ["o3", "temp"])