diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py index 9394dedc68c54f81d3afba685b4170ac90a1a5fe..842846ae0753c52e254138edceae7ac0c0bc5e5a 100644 --- a/src/data_handling/data_generator.py +++ b/src/data_handling/data_generator.py @@ -114,7 +114,7 @@ class DataGenerator(keras.utils.Sequence): raise ValueError(f"given mean attribute must either be equal to strings 'accurate' or 'estimate' or" f"be an array with already calculated means. Given was: {mean}") elif scope == "station": - raise NotImplementedError("This is currently not implemented. ") + mean, std = None, None else: raise ValueError(f"Scope argument can either be 'station' or 'data'. Given was: {scope}") transformation["method"] = method @@ -141,6 +141,8 @@ class DataGenerator(keras.utils.Sequence): if method == "standardise": std = da.nanstd(tmp, axis=1).compute() std = xr.DataArray(std.flatten(), coords={"variables": sorted(self.variables)}, dims=["variables"]) + else: + raise NotImplementedError return mean, std def calculate_estimated_transformation(self, method): diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py index 7f712952f5a5c0c8538984287c6cb37c63a6935a..9bf11154609afa9ada2b488455f7a341a41d21ae 100644 --- a/test/test_data_handling/test_data_generator.py +++ b/test/test_data_handling/test_data_generator.py @@ -25,6 +25,13 @@ class TestDataGenerator: return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], 'datetime', 'variables', 'o3', start=2010, end=2014) + @pytest.fixture + def gen_with_transformation(self): + return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], + 'datetime', 'variables', 'o3', start=2010, end=2014, + transformation={"scope": "data", "mean": "estimate"}, + statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) + @pytest.fixture def gen_no_init(self): generator = object.__new__(DataGenerator) @@ -37,9 +44,37 @@ class TestDataGenerator: generator.variables = ["temp", "o3"] generator.station_type = "background" generator.kwargs = {"start": 2010, "end": 2014, "statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'}} - return generator + @pytest.fixture + def accurate_transformation(self, gen_no_init): + tmp = np.nan + for station in gen_no_init.stations: + try: + data_prep = DataPrep(gen_no_init.data_path, gen_no_init.network, station, gen_no_init.variables, + station_type=gen_no_init.station_type, **gen_no_init.kwargs) + tmp = data_prep.data.combine_first(tmp) + except EmptyQueryResult: + continue + mean_expected = tmp.mean(dim=["Stations", "datetime"]) + std_expected = tmp.std(dim=["Stations", "datetime"]) + return mean_expected, std_expected + + @pytest.fixture + def estimated_transformation(self, gen_no_init): + mean, std = None, None + for station in gen_no_init.stations: + try: + data_prep = DataPrep(gen_no_init.data_path, gen_no_init.network, station, gen_no_init.variables, + station_type=gen_no_init.station_type, **gen_no_init.kwargs) + mean = data_prep.data.mean(axis=1).combine_first(mean) + std = data_prep.data.std(axis=1).combine_first(std) + except EmptyQueryResult: + continue + mean_expected = mean.mean(axis=0) + std_expected = std.mean(axis=0) + return mean_expected, std_expected + class DummyDataPrep: def __init__(self, data): self.station = "DEBW107" @@ -97,6 +132,8 @@ class TestDataGenerator: def test_setup_transformation_no_transformation(self, gen_no_init): assert gen_no_init.setup_transformation(None) is None assert gen_no_init.setup_transformation({}) == {"method": "standardise", "mean": None, "std": None} + assert gen_no_init.setup_transformation({"scope": "station", "mean": "accurate"}) == \ + {"scope": "station", "method": "standardise", "mean": None, "std": None} def test_setup_transformation_calculate_statistics(self, gen_no_init): transformation = {"scope": "data", "mean": "accurate"} @@ -119,8 +156,6 @@ class TestDataGenerator: assert res["std"] is None def test_setup_transformation_errors(self, gen_no_init): - with pytest.raises(NotImplementedError): - gen_no_init.setup_transformation({"mean": "accurate"}) transformation = {"scope": "random", "mean": "accurate"} with pytest.raises(ValueError): gen_no_init.setup_transformation(transformation) @@ -128,37 +163,38 @@ class TestDataGenerator: with pytest.raises(ValueError): gen_no_init.setup_transformation(transformation) - def test_calculate_accurate_transformation(self, gen_no_init): - tmp = np.nan - for station in gen_no_init.stations: - try: - data_prep = DataPrep(gen_no_init.data_path, gen_no_init.network, station, gen_no_init.variables, - station_type=gen_no_init.station_type, **gen_no_init.kwargs) - tmp = data_prep.data.combine_first(tmp) - except EmptyQueryResult: - continue - mean_expected = tmp.mean(dim=["Stations", "datetime"]) - std_expected = tmp.std(dim=["Stations", "datetime"]) + def test_calculate_accurate_transformation_standardise(self, gen_no_init, accurate_transformation): + mean_expected, std_expected = accurate_transformation mean, std = gen_no_init.calculate_accurate_transformation("standardise") assert np.testing.assert_almost_equal(mean_expected.values, mean.values) is None assert np.testing.assert_almost_equal(std_expected.values, std.values) is None - def test_calculate_estimated_transformation(self, gen_no_init): - mean, std = None, None - for station in gen_no_init.stations: - try: - data_prep = DataPrep(gen_no_init.data_path, gen_no_init.network, station, gen_no_init.variables, - station_type=gen_no_init.station_type, **gen_no_init.kwargs) - mean = data_prep.data.mean(axis=1).combine_first(mean) - std = data_prep.data.std(axis=1).combine_first(std) - except EmptyQueryResult: - continue - mean_expected = mean.mean(axis=0) - std_expected = std.mean(axis=0) + def test_calculate_accurate_transformation_centre(self, gen_no_init, accurate_transformation): + mean_expected, _ = accurate_transformation + mean, std = gen_no_init.calculate_accurate_transformation("centre") + assert np.testing.assert_almost_equal(mean_expected.values, mean.values) is None + assert std is None + + def test_calculate_accurate_transformation_all_others(self, gen_no_init): + with pytest.raises(NotImplementedError): + gen_no_init.calculate_accurate_transformation("normalise") + + def test_calculate_estimated_transformation_standardise(self, gen_no_init, estimated_transformation): + mean_expected, std_expected = estimated_transformation mean, std = gen_no_init.calculate_estimated_transformation("standardise") assert np.testing.assert_almost_equal(mean_expected.values, mean.values) is None assert np.testing.assert_almost_equal(std_expected.values, std.values) is None + def test_calculate_estimated_transformation_centre(self, gen_no_init, estimated_transformation): + mean_expected, _ = estimated_transformation + mean, std = gen_no_init.calculate_estimated_transformation("centre") + assert np.testing.assert_almost_equal(mean_expected.values, mean.values) is None + assert std is None + + def test_calculate_estimated_transformation_all_others(self, gen_no_init): + with pytest.raises(NotImplementedError): + gen_no_init.calculate_estimated_transformation("normalise") + def test_get_station_key(self, gen): gen.stations.append("DEBW108") f = gen.get_station_key @@ -196,6 +232,12 @@ class TestDataGenerator: assert isinstance(gen.get_data_generator("DEBW107"), DataPrep) assert os.stat(file).st_ctime > t + def test_get_data_generator_transform(self, gen_with_transformation): + gen = gen_with_transformation + data = gen.get_data_generator("DEBW107", load_local_tmp_storage=False, save_local_tmp_storage=False) + assert data._transform_method == "standardise" + assert data.mean is not None + def test_save_pickle_data(self, gen): file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}_2010_2014_.pickle") if os.path.exists(file):