Skip to content
Snippets Groups Projects
Commit ef0ddf4e authored by lukas leufen
Browse files

tests for transformation

parent 47c9e0d3
Branches
Tags
2 merge requests: !50 release for v0.7.0, !49 Lukas issue054 feat transformation on entire dataset
Pipeline #30972 passed
...@@ -111,7 +111,13 @@ class DataGenerator(keras.utils.Sequence): ...@@ -111,7 +111,13 @@ class DataGenerator(keras.utils.Sequence):
elif mean == "estimate": elif mean == "estimate":
mean, std = self.calculate_estimated_transformation(method) mean, std = self.calculate_estimated_transformation(method)
else: else:
mean = mean raise ValueError(f"given mean attribute must either be equal to strings 'accurate' or 'estimate' or"
f"be an array with already calculated means. Given was: {mean}")
elif scope == "station":
raise NotImplementedError("This is currently not implemented. ")
else:
raise ValueError(f"Scope argument can either be 'station' or 'data'. Given was: {scope}")
transformation["method"] = method
transformation["mean"] = mean transformation["mean"] = mean
transformation["std"] = std transformation["std"] = std
return transformation return transformation
...@@ -138,8 +144,10 @@ class DataGenerator(keras.utils.Sequence): ...@@ -138,8 +144,10 @@ class DataGenerator(keras.utils.Sequence):
return mean, std return mean, std
def calculate_estimated_transformation(self, method): def calculate_estimated_transformation(self, method):
mean = xr.DataArray([[]]*len(self.variables),coords={"variables": self.variables, "Stations": range(0)}, dims=["variables", "Stations"]) data = [[]]*len(self.variables)
std = xr.DataArray([[]]*len(self.variables),coords={"variables": self.variables, "Stations": range(0)}, dims=["variables", "Stations"]) coords = {"variables": self.variables, "Stations": range(0)}
mean = xr.DataArray(data, coords=coords, dims=["variables", "Stations"])
std = xr.DataArray(data, coords=coords, dims=["variables", "Stations"])
for station in self.stations: for station in self.stations:
try: try:
data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type, data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
......
...@@ -378,7 +378,7 @@ class DataPrep(object): ...@@ -378,7 +378,7 @@ class DataPrep(object):
:param coord: name of axis to slice :param coord: name of axis to slice
:return: :return:
""" """
return data.loc[{coord: slice(start, end)}] return data.loc[{coord: slice(str(start), str(end))}]
def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray: def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray:
""" """
......
import os import os
import operator as op
import pytest import pytest
import shutil import shutil
import numpy as np import numpy as np
import xarray as xr
import pickle import pickle
from src.data_handling.data_generator import DataGenerator from src.data_handling.data_generator import DataGenerator
from src.data_handling.data_preparation import DataPrep from src.data_handling.data_preparation import DataPrep
from src.join import EmptyQueryResult
class TestDataGenerator: class TestDataGenerator:
...@@ -22,6 +25,21 @@ class TestDataGenerator: ...@@ -22,6 +25,21 @@ class TestDataGenerator:
return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
'datetime', 'variables', 'o3', start=2010, end=2014) 'datetime', 'variables', 'o3', start=2010, end=2014)
@pytest.fixture
def gen_no_init(self):
    """
    Provide a DataGenerator instance whose __init__ has not been executed.

    The object is allocated with object.__new__ and only the attributes assigned below are set, so individual
    methods can be called without running the regular constructor.

    :return: bare DataGenerator with hand-assigned attributes
    """
    generator = object.__new__(DataGenerator)
    path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data'))
    generator.data_path = path
    # exist_ok=True replaces the separate os.path.exists check and so avoids the check-then-create race
    os.makedirs(path, exist_ok=True)
    generator.stations = ["DEBW107", "DEBW013", "DEBW001"]
    generator.network = "AIRBASE"
    generator.variables = ["temp", "o3"]
    generator.station_type = "background"
    generator.kwargs = {"start": 2010, "end": 2014, "statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'}}
    return generator
class DummyDataPrep: class DummyDataPrep:
def __init__(self, data): def __init__(self, data):
self.station = "DEBW107" self.station = "DEBW107"
...@@ -41,7 +59,7 @@ class TestDataGenerator: ...@@ -41,7 +59,7 @@ class TestDataGenerator:
assert gen.limit_nan_fill == 1 assert gen.limit_nan_fill == 1
assert gen.window_history_size == 7 assert gen.window_history_size == 7
assert gen.window_lead_time == 4 assert gen.window_lead_time == 4
assert gen.transform_method == "standardise" assert gen.transformation is None
assert gen.kwargs == {"start": 2010, "end": 2014} assert gen.kwargs == {"start": 2010, "end": 2014}
def test_repr(self, gen): def test_repr(self, gen):
...@@ -76,6 +94,71 @@ class TestDataGenerator: ...@@ -76,6 +94,71 @@ class TestDataGenerator:
assert station[1].data.shape[-1] == gen.window_lead_time assert station[1].data.shape[-1] == gen.window_lead_time
assert station[0].data.shape[1] == gen.window_history_size + 1 assert station[0].data.shape[1] == gen.window_history_size + 1
def test_setup_transformation_no_transformation(self, gen_no_init):
    """
    Check the results of setup_transformation for None and for an empty dict.

    None is expected to be passed through as None; an empty dict is expected to come back with the method
    "standardise" and mean / std set to None.
    """
    setup = gen_no_init.setup_transformation
    assert setup(None) is None
    expected_defaults = dict(method="standardise", mean=None, std=None)
    assert setup({}) == expected_defaults
def test_setup_transformation_calculate_statistics(self, gen_no_init):
    """
    Check that both the "accurate" and the "estimate" mode return xarray statistics and that their std differ.

    For each mode the returned dict must hold exactly the keys scope, mean, std and method, with mean and std
    being xr.DataArray objects. Finally, the std values of both modes are required to be element-wise unequal.
    """
    expected_keys = sorted(["scope", "mean", "std", "method"])

    # Every call receives its own dict. NOTE(review): setup_transformation appears to write into the given dict
    # and return it - confirm. Reusing a single dict would then make both results the very same object and the
    # final comparison would check an array against itself.
    res_acc = gen_no_init.setup_transformation({"scope": "data", "mean": "accurate"})
    assert sorted(res_acc.keys()) == expected_keys
    assert isinstance(res_acc["mean"], xr.DataArray)
    assert isinstance(res_acc["std"], xr.DataArray)

    res_est = gen_no_init.setup_transformation({"scope": "data", "mean": "estimate"})
    assert sorted(res_est.keys()) == expected_keys
    assert isinstance(res_est["mean"], xr.DataArray)
    assert isinstance(res_est["std"], xr.DataArray)

    # assert_array_compare raises if the comparison (here: not equal) does not hold for the given arrays
    assert np.testing.assert_array_compare(op.__ne__, res_est["std"].values, res_acc["std"].values) is None
def test_setup_transformation_use_given_statistics(self, gen_no_init):
    """
    Check that an already calculated mean handed over for scope "data" and method "centre" is returned as given.

    The std entry of the result is expected to be None in this case.
    """
    given_mean = xr.DataArray([30, 15], coords={"variables": ["o3", "temp"]}, dims=["variables"])
    settings = dict(scope="data", method="centre", mean=given_mean)
    result = gen_no_init.setup_transformation(settings)
    assert np.testing.assert_equal(result["mean"].values, given_mean.values) is None
    assert result["std"] is None
def test_setup_transformation_errors(self, gen_no_init):
    """
    Check the exceptions raised by setup_transformation for unsupported settings.

    A dict without a scope entry is expected to raise NotImplementedError, an unknown scope ("random") and an
    unknown mean keyword ("fit") are both expected to raise ValueError.
    """
    setup = gen_no_init.setup_transformation
    with pytest.raises(NotImplementedError):
        setup({"mean": "accurate"})
    invalid_settings = (
        {"scope": "random", "mean": "accurate"},
        {"scope": "data", "mean": "fit"},
    )
    for settings in invalid_settings:
        with pytest.raises(ValueError):
            setup(settings)
def test_calculate_accurate_transformation(self, gen_no_init):
    """
    Compare calculate_accurate_transformation with statistics computed over the combined data of all stations.

    The data of every loadable station is merged via combine_first, stations raising EmptyQueryResult are
    skipped. Mean and std over the dimensions "Stations" and "datetime" of the merged data serve as reference.
    """
    merged = np.nan  # start value handed to the first combine_first call
    for station_id in gen_no_init.stations:
        try:
            prep = DataPrep(gen_no_init.data_path, gen_no_init.network, station_id, gen_no_init.variables,
                            station_type=gen_no_init.station_type, **gen_no_init.kwargs)
            merged = prep.data.combine_first(merged)
        except EmptyQueryResult:
            continue
    reduce_dims = ["Stations", "datetime"]
    expected_mean = merged.mean(dim=reduce_dims)
    expected_std = merged.std(dim=reduce_dims)
    actual_mean, actual_std = gen_no_init.calculate_accurate_transformation("standardise")
    assert np.testing.assert_almost_equal(expected_mean.values, actual_mean.values) is None
    assert np.testing.assert_almost_equal(expected_std.values, actual_std.values) is None
def test_calculate_estimated_transformation(self, gen_no_init):
    """
    Compare calculate_estimated_transformation with the average of per-station statistics.

    For every loadable station, mean and std are taken along axis 1 of its data and collected via combine_first;
    stations raising EmptyQueryResult are skipped. The collected values averaged along axis 0 serve as reference.
    """
    station_means = None
    station_stds = None
    for station_id in gen_no_init.stations:
        try:
            prep = DataPrep(gen_no_init.data_path, gen_no_init.network, station_id, gen_no_init.variables,
                            station_type=gen_no_init.station_type, **gen_no_init.kwargs)
            station_means = prep.data.mean(axis=1).combine_first(station_means)
            station_stds = prep.data.std(axis=1).combine_first(station_stds)
        except EmptyQueryResult:
            continue
    expected_mean = station_means.mean(axis=0)
    expected_std = station_stds.mean(axis=0)
    actual_mean, actual_std = gen_no_init.calculate_estimated_transformation("standardise")
    assert np.testing.assert_almost_equal(expected_mean.values, actual_mean.values) is None
    assert np.testing.assert_almost_equal(expected_std.values, actual_std.values) is None
def test_get_station_key(self, gen): def test_get_station_key(self, gen):
gen.stations.append("DEBW108") gen.stations.append("DEBW108")
f = gen.get_station_key f = gen.get_station_key
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment