Skip to content
Snippets Groups Projects
Commit edc5e931 authored by lukas leufen's avatar lukas leufen
Browse files

updated generator tests

parent 1754830c
No related branches found
No related tags found
2 merge requests!37include new development,!33Lukas issue036 feat local temp data storage
Pipeline #29111 passed
import pytest import pytest
import os import os
import shutil import shutil
import numpy as np
import pickle
from src.data_handling.data_generator import DataGenerator from src.data_handling.data_generator import DataGenerator
from src.data_handling.data_preparation import DataPrep
class TestDataGenerator: class TestDataGenerator:
...@@ -17,6 +20,12 @@ class TestDataGenerator: ...@@ -17,6 +20,12 @@ class TestDataGenerator:
return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
'datetime', 'variables', 'o3') 'datetime', 'variables', 'o3')
class DummyDataPrep:
def __init__(self, data):
self.station = "DEBW107"
self.variables = ["o3", "temp"]
self.data = data
def test_init(self, gen): def test_init(self, gen):
assert gen.data_path == os.path.join(os.path.dirname(__file__), 'data') assert gen.data_path == os.path.join(os.path.dirname(__file__), 'data')
assert gen.network == 'AIRBASE' assert gen.network == 'AIRBASE'
...@@ -44,15 +53,6 @@ class TestDataGenerator: ...@@ -44,15 +53,6 @@ class TestDataGenerator:
gen.stations = ['station1', 'station2', 'station3'] gen.stations = ['station1', 'station2', 'station3']
assert len(gen) == 3 assert len(gen) == 3
def test_getitem(self, gen):
gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
station = gen["DEBW107"]
assert len(station) == 2
assert station[0].Stations.data == "DEBW107"
assert station[0].data.shape[1:] == (8, 1, 2)
assert station[1].data.shape[-1] == gen.window_lead_time
assert station[0].data.shape[1] == gen.window_history_size + 1
def test_iter(self, gen): def test_iter(self, gen):
assert hasattr(gen, '_iterator') is False assert hasattr(gen, '_iterator') is False
iter(gen) iter(gen)
...@@ -64,6 +64,15 @@ class TestDataGenerator: ...@@ -64,6 +64,15 @@ class TestDataGenerator:
for i, d in enumerate(gen, start=1): for i, d in enumerate(gen, start=1):
assert i == gen._iterator assert i == gen._iterator
def test_getitem(self, gen):
gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
station = gen["DEBW107"]
assert len(station) == 2
assert station[0].Stations.data == "DEBW107"
assert station[0].data.shape[1:] == (8, 1, 2)
assert station[1].data.shape[-1] == gen.window_lead_time
assert station[0].data.shape[1] == gen.window_history_size + 1
def test_get_station_key(self, gen): def test_get_station_key(self, gen):
gen.stations.append("DEBW108") gen.stations.append("DEBW108")
f = gen.get_station_key f = gen.get_station_key
...@@ -85,3 +94,38 @@ class TestDataGenerator: ...@@ -85,3 +94,38 @@ class TestDataGenerator:
with pytest.raises(KeyError) as e: with pytest.raises(KeyError) as e:
f(6.5) f(6.5)
assert "key has to be from Union[str, int]. Given was 6.5 (float)" assert "key has to be from Union[str, int]. Given was 6.5 (float)"
def test_get_data_generator(self, gen):
gen.kwargs = {"statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'}}
file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}.pickle")
if os.path.exists(file):
os.remove(file)
assert not os.path.exists(file)
assert isinstance(gen.get_data_generator("DEBW107", local_tmp_storage=False), DataPrep)
t = os.stat(file).st_ctime
assert os.path.exists(file)
assert isinstance(gen.get_data_generator("DEBW107"), DataPrep)
assert os.stat(file).st_mtime == t
os.remove(file)
assert isinstance(gen.get_data_generator("DEBW107"), DataPrep)
assert os.stat(file).st_ctime > t
def test_save_pickle_data(self, gen):
file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}.pickle")
if os.path.exists(file):
os.remove(file)
assert not os.path.exists(file)
data = self.DummyDataPrep(np.ones((10, 2)))
gen._save_pickle_data(data)
assert os.path.exists(file)
os.remove(file)
def test_load_pickle_data(self, gen):
file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}.pickle")
data = self.DummyDataPrep(np.ones((10, 2)))
with open(file, "wb") as f:
pickle.dump(data, f)
assert os.path.exists(file)
res = gen._load_pickle_data("DEBW107", ["o3", "temp"]).data
assert np.testing.assert_almost_equal(res, np.ones((10, 2))) is None
os.remove(file)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment