Skip to content
Snippets Groups Projects
Commit a5d3ec28 authored by lukas leufen's avatar lukas leufen
Browse files

add failing test to check new pytest-sugar package

parent 157af210
Branches
Tags
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!61Resolve "REFAC: clean-up bootstrap workflow"
Pipeline #32935 failed
...@@ -43,6 +43,7 @@ pytest-cov==2.8.1 ...@@ -43,6 +43,7 @@ pytest-cov==2.8.1
pytest-html==2.0.1 pytest-html==2.0.1
pytest-lazy-fixture==0.6.3 pytest-lazy-fixture==0.6.3
pytest-metadata==1.8.0 pytest-metadata==1.8.0
pytest-sugar
python-dateutil==2.8.1 python-dateutil==2.8.1
pytz==2019.3 pytz==2019.3
PyYAML==5.3 PyYAML==5.3
......
...@@ -43,6 +43,7 @@ pytest-cov==2.8.1 ...@@ -43,6 +43,7 @@ pytest-cov==2.8.1
pytest-html==2.0.1 pytest-html==2.0.1
pytest-lazy-fixture==0.6.3 pytest-lazy-fixture==0.6.3
pytest-metadata==1.8.0 pytest-metadata==1.8.0
pytest-sugar
python-dateutil==2.8.1 python-dateutil==2.8.1
pytz==2019.3 pytz==2019.3
PyYAML==5.3 PyYAML==5.3
......
...@@ -166,6 +166,28 @@ class PyTestRegex: ...@@ -166,6 +166,28 @@ class PyTestRegex:
return self._regex.pattern return self._regex.pattern
class PyTestAllEqual:
def __init__(self, check_list):
self._list = check_list
def _check_all_equal(self):
equal = True
for b in self._list:
equal *= xr.testing.assert_equal(self._list[0], b) is None
return equal == 1
def is_true(self):
return self._check_all_equal()
def xr_all_equal(check_list):
equal = True
for b in check_list:
equal *= xr.testing.assert_equal(check_list[0], b) is None
return equal == 1
def dict_to_xarray(d: Dict, coordinate_name: str) -> xr.DataArray: def dict_to_xarray(d: Dict, coordinate_name: str) -> xr.DataArray:
""" """
Convert a dictionary of 2D-xarrays to single 3D-xarray. The name of new coordinate axis follows <coordinate_name>. Convert a dictionary of 2D-xarrays to single 3D-xarray. The name of new coordinate axis follows <coordinate_name>.
......
from src.data_handling.bootstraps import BootStraps, BootStrapGenerator from src.data_handling.bootstraps import BootStraps, BootStrapGenerator
from src.data_handling.data_generator import DataGenerator from src.data_handling.data_generator import DataGenerator
from src.helpers import PyTestAllEqual, xr_all_equal
import os import os
import pytest import pytest
...@@ -81,6 +82,8 @@ class TestBootstrapGenerator: ...@@ -81,6 +82,8 @@ class TestBootstrapGenerator:
dummy_content = xr.DataArray([1, 2, 3], dims="dummy") dummy_content = xr.DataArray([1, 2, 3], dims="dummy")
dummy_content.to_netcdf(os.path.join(path, "DEBW107_o3_temp_hist7_nboots20_shuffled.nc")) dummy_content.to_netcdf(os.path.join(path, "DEBW107_o3_temp_hist7_nboots20_shuffled.nc"))
dummy_content.to_netcdf(os.path.join(path, "DEBW013_o3_temp_hist7_nboots20_shuffled.nc")) dummy_content.to_netcdf(os.path.join(path, "DEBW013_o3_temp_hist7_nboots20_shuffled.nc"))
dummy_content = dummy_content.expand_dims({"type": ["CNN"]})
dummy_content.to_netcdf(os.path.join(path, "forecasts_norm_DEBW107_test.nc"))
return BootStrapGenerator(orig_generator, 20, path) return BootStrapGenerator(orig_generator, 20, path)
def test_init(self, orig_generator): def test_init(self, orig_generator):
...@@ -114,17 +117,15 @@ class TestBootstrapGenerator: ...@@ -114,17 +117,15 @@ class TestBootstrapGenerator:
res.append(label) res.append(label)
assert len(res) == boot_gen.number_of_boots assert len(res) == boot_gen.number_of_boots
assert xr.testing.assert_equal(res[0], res[-1]) is None assert xr.testing.assert_equal(res[0], res[-1]) is None
assert PyTestAllEqual(res).is_true()
def all_equal(check_list):
equal = True
for b in check_list:
equal *= xr.testing.assert_equal(check_list[0], b) is None
return equal
assert all_equal(res)
def test_get_orig_prediction(self, boot_gen): def test_get_orig_prediction(self, boot_gen):
pass path = boot_gen.orig_generator.data_path
res = []
for pred in boot_gen.get_orig_prediction(path, "forecasts_norm_DEBW107_test.nc"):
res.append(pred)
assert len(res) == boot_gen.number_of_boots+1
assert PyTestAllEqual(res).is_true()
def test_load_shuffled_data(self, boot_gen): def test_load_shuffled_data(self, boot_gen):
shuffled_data = boot_gen.load_shuffled_data("DEBW107", ["o3", "temp"]) shuffled_data = boot_gen.load_shuffled_data("DEBW107", ["o3", "temp"])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment