diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py index 897d85c121fbe4b64b81bfc504209c7a3069f9f6..9394dedc68c54f81d3afba685b4170ac90a1a5fe 100644 --- a/src/data_handling/data_generator.py +++ b/src/data_handling/data_generator.py @@ -111,9 +111,15 @@ class DataGenerator(keras.utils.Sequence): elif mean == "estimate": mean, std = self.calculate_estimated_transformation(method) else: - mean = mean - transformation["mean"] = mean - transformation["std"] = std + raise ValueError(f"given mean attribute must either be equal to strings 'accurate' or 'estimate' or" + f"be an array with already calculated means. Given was: {mean}") + elif scope == "station": + raise NotImplementedError("This is currently not implemented. ") + else: + raise ValueError(f"Scope argument can either be 'station' or 'data'. Given was: {scope}") + transformation["method"] = method + transformation["mean"] = mean + transformation["std"] = std return transformation def calculate_accurate_transformation(self, method): @@ -138,8 +144,10 @@ class DataGenerator(keras.utils.Sequence): return mean, std def calculate_estimated_transformation(self, method): - mean = xr.DataArray([[]]*len(self.variables),coords={"variables": self.variables, "Stations": range(0)}, dims=["variables", "Stations"]) - std = xr.DataArray([[]]*len(self.variables),coords={"variables": self.variables, "Stations": range(0)}, dims=["variables", "Stations"]) + data = [[]]*len(self.variables) + coords = {"variables": self.variables, "Stations": range(0)} + mean = xr.DataArray(data, coords=coords, dims=["variables", "Stations"]) + std = xr.DataArray(data, coords=coords, dims=["variables", "Stations"]) for station in self.stations: try: data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type, diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py index 75a98ffbd5ee89a555b82c24ca88340ad3afeb65..98b47a6df3825581564fa9aaef7be8698408760e 100644 --- a/src/data_handling/data_preparation.py +++ b/src/data_handling/data_preparation.py @@ -378,7 +378,7 @@ class DataPrep(object): :param coord: name of axis to slice :return: """ - return data.loc[{coord: slice(start, end)}] + return data.loc[{coord: slice(str(start), str(end))}] def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray: """ diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py index 306dfa3079c306e46e05cc5b8fe2361acdcf281f..7f712952f5a5c0c8538984287c6cb37c63a6935a 100644 --- a/test/test_data_handling/test_data_generator.py +++ b/test/test_data_handling/test_data_generator.py @@ -1,12 +1,15 @@ import os +import operator as op import pytest import shutil import numpy as np +import xarray as xr import pickle from src.data_handling.data_generator import DataGenerator from src.data_handling.data_preparation import DataPrep +from src.join import EmptyQueryResult class TestDataGenerator: @@ -22,6 +25,21 @@ class TestDataGenerator: return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], 'datetime', 'variables', 'o3', start=2010, end=2014) + @pytest.fixture + def gen_no_init(self): + generator = object.__new__(DataGenerator) + path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data')) + generator.data_path = path + if not os.path.exists(path): + os.makedirs(path) + generator.stations = ["DEBW107", "DEBW013", "DEBW001"] + generator.network = "AIRBASE" + generator.variables = ["temp", "o3"] + generator.station_type = "background" + generator.kwargs = {"start": 2010, "end": 2014, "statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'}} + + return generator + class DummyDataPrep: def __init__(self, data): self.station = "DEBW107" @@ -41,7 +59,7 @@ class TestDataGenerator: assert gen.limit_nan_fill == 1 assert gen.window_history_size == 7 assert gen.window_lead_time == 4 - assert gen.transform_method == "standardise" + assert gen.transformation is None assert gen.kwargs == {"start": 2010, "end": 2014} def test_repr(self, gen): @@ -76,6 +94,71 @@ class TestDataGenerator: assert station[1].data.shape[-1] == gen.window_lead_time assert station[0].data.shape[1] == gen.window_history_size + 1 + def test_setup_transformation_no_transformation(self, gen_no_init): + assert gen_no_init.setup_transformation(None) is None + assert gen_no_init.setup_transformation({}) == {"method": "standardise", "mean": None, "std": None} + + def test_setup_transformation_calculate_statistics(self, gen_no_init): + transformation = {"scope": "data", "mean": "accurate"} + res_acc = gen_no_init.setup_transformation(transformation) + assert sorted(res_acc.keys()) == sorted(["scope", "mean", "std", "method"]) + assert isinstance(res_acc["mean"], xr.DataArray) + assert isinstance(res_acc["std"], xr.DataArray) + transformation["mean"] = "estimate" + res_est = gen_no_init.setup_transformation(transformation) + assert sorted(res_est.keys()) == sorted(["scope", "mean", "std", "method"]) + assert isinstance(res_est["mean"], xr.DataArray) + assert isinstance(res_est["std"], xr.DataArray) + assert np.testing.assert_array_compare(op.__ne__, res_est["std"].values, res_acc["std"].values) is None + + def test_setup_transformation_use_given_statistics(self, gen_no_init): + mean = xr.DataArray([30, 15], coords={"variables": ["o3", "temp"]}, dims=["variables"]) + transformation = {"scope": "data", "method": "centre", "mean": mean} + res = gen_no_init.setup_transformation(transformation) + assert np.testing.assert_equal(res["mean"].values, mean.values) is None + assert res["std"] is None + + def test_setup_transformation_errors(self, gen_no_init): + with pytest.raises(NotImplementedError): + gen_no_init.setup_transformation({"mean": "accurate"}) + transformation = {"scope": "random", "mean": "accurate"} + with pytest.raises(ValueError): + gen_no_init.setup_transformation(transformation) + transformation = {"scope": "data", "mean": "fit"} + with pytest.raises(ValueError): + gen_no_init.setup_transformation(transformation) + + def test_calculate_accurate_transformation(self, gen_no_init): + tmp = np.nan + for station in gen_no_init.stations: + try: + data_prep = DataPrep(gen_no_init.data_path, gen_no_init.network, station, gen_no_init.variables, + station_type=gen_no_init.station_type, **gen_no_init.kwargs) + tmp = data_prep.data.combine_first(tmp) + except EmptyQueryResult: + continue + mean_expected = tmp.mean(dim=["Stations", "datetime"]) + std_expected = tmp.std(dim=["Stations", "datetime"]) + mean, std = gen_no_init.calculate_accurate_transformation("standardise") + assert np.testing.assert_almost_equal(mean_expected.values, mean.values) is None + assert np.testing.assert_almost_equal(std_expected.values, std.values) is None + + def test_calculate_estimated_transformation(self, gen_no_init): + mean, std = None, None + for station in gen_no_init.stations: + try: + data_prep = DataPrep(gen_no_init.data_path, gen_no_init.network, station, gen_no_init.variables, + station_type=gen_no_init.station_type, **gen_no_init.kwargs) + mean = data_prep.data.mean(axis=1).combine_first(mean) + std = data_prep.data.std(axis=1).combine_first(std) + except EmptyQueryResult: + continue + mean_expected = mean.mean(axis=0) + std_expected = std.mean(axis=0) + mean, std = gen_no_init.calculate_estimated_transformation("standardise") + assert np.testing.assert_almost_equal(mean_expected.values, mean.values) is None + assert np.testing.assert_almost_equal(std_expected.values, std.values) is None + def test_get_station_key(self, gen): gen.stations.append("DEBW108") f = gen.get_station_key