diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py index 879436afddb8da8d11d6cc585da7c703aa12ef8a..34cc60d7b6a9ccbd4d30463f185ce5cf6eff6f15 100644 --- a/test/test_data_handling/test_data_generator.py +++ b/test/test_data_handling/test_data_generator.py @@ -1,7 +1,10 @@ import pytest import os import shutil +import numpy as np +import pickle from src.data_handling.data_generator import DataGenerator +from src.data_handling.data_preparation import DataPrep class TestDataGenerator: @@ -17,6 +20,12 @@ class TestDataGenerator: return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], 'datetime', 'variables', 'o3') + class DummyDataPrep: + def __init__(self, data): + self.station = "DEBW107" + self.variables = ["o3", "temp"] + self.data = data + def test_init(self, gen): assert gen.data_path == os.path.join(os.path.dirname(__file__), 'data') assert gen.network == 'AIRBASE' @@ -44,15 +53,6 @@ class TestDataGenerator: gen.stations = ['station1', 'station2', 'station3'] assert len(gen) == 3 - def test_getitem(self, gen): - gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}} - station = gen["DEBW107"] - assert len(station) == 2 - assert station[0].Stations.data == "DEBW107" - assert station[0].data.shape[1:] == (8, 1, 2) - assert station[1].data.shape[-1] == gen.window_lead_time - assert station[0].data.shape[1] == gen.window_history_size + 1 - def test_iter(self, gen): assert hasattr(gen, '_iterator') is False iter(gen) @@ -64,6 +64,15 @@ class TestDataGenerator: for i, d in enumerate(gen, start=1): assert i == gen._iterator + def test_getitem(self, gen): + gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}} + station = gen["DEBW107"] + assert len(station) == 2 + assert station[0].Stations.data == "DEBW107" + assert station[0].data.shape[1:] == (8, 1, 2) + assert station[1].data.shape[-1] == gen.window_lead_time + assert station[0].data.shape[1] == gen.window_history_size + 1 + def test_get_station_key(self, gen): gen.stations.append("DEBW108") f = gen.get_station_key @@ -85,3 +94,38 @@ class TestDataGenerator: with pytest.raises(KeyError) as e: f(6.5) assert "key has to be from Union[str, int]. Given was 6.5 (float)" + + def test_get_data_generator(self, gen): + gen.kwargs = {"statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'}} + file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}.pickle") + if os.path.exists(file): + os.remove(file) + assert not os.path.exists(file) + assert isinstance(gen.get_data_generator("DEBW107", local_tmp_storage=False), DataPrep) + t = os.stat(file).st_ctime + assert os.path.exists(file) + assert isinstance(gen.get_data_generator("DEBW107"), DataPrep) + assert os.stat(file).st_mtime == t + os.remove(file) + assert isinstance(gen.get_data_generator("DEBW107"), DataPrep) + assert os.stat(file).st_ctime > t + + def test_save_pickle_data(self, gen): + file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}.pickle") + if os.path.exists(file): + os.remove(file) + assert not os.path.exists(file) + data = self.DummyDataPrep(np.ones((10, 2))) + gen._save_pickle_data(data) + assert os.path.exists(file) + os.remove(file) + + def test_load_pickle_data(self, gen): + file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}.pickle") + data = self.DummyDataPrep(np.ones((10, 2))) + with open(file, "wb") as f: + pickle.dump(data, f) + assert os.path.exists(file) + res = gen._load_pickle_data("DEBW107", ["o3", "temp"]).data + assert np.testing.assert_almost_equal(res, np.ones((10, 2))) is None + os.remove(file)