diff --git a/conftest.py b/conftest.py index 70f1b1243f2796d46540bc5ce9b51bf48a15a4b0..d8205ae4349bdbcb6a249c453a9a773c78939756 100644 --- a/conftest.py +++ b/conftest.py @@ -5,15 +5,34 @@ import py import pytest -@pytest.fixture(autouse=True, scope='class') +@pytest.fixture(scope='class') def teardown_module(pytestconfig): yield # dirname can be found at pytestconfig._assertstate.hook.session._initialparts[0][0].dirname but it is not clear for # me, if there can be more than 1 entry in each of the lists. Therefore just loop over all elements will definitely # catch all dirnames. d = pytestconfig._assertstate.hook.session._initialparts + f = open("testfile.log", "a+") + print(f"pytestconfig._assertstate.hook.session._initialparts {pytestconfig._assertstate.hook.session._initialparts}") + print(f"pytestconfig._assertstate.hook.session._initialpaths {pytestconfig._assertstate.hook.session._initialpaths}") + f.write(f"pytestconfig._assertstate.hook.session._initialparts {pytestconfig._assertstate.hook.session._initialparts}\n") + f.write(f"pytestconfig._assertstate.hook.session._initialpaths {pytestconfig._assertstate.hook.session._initialpaths}\n") + f.write(f"pytestconfig.invocation_params.dir {pytestconfig.invocation_params.dir}\n") + f.write(f"pytestconfig.invocation_dir {pytestconfig.invocation_dir}\n") + f.write(f"list(pytestconfig._assertstate.hook.session._bestrelpathcache.keys())[0].dirname {list(pytestconfig._assertstate.hook.session._bestrelpathcache.keys())[0].dirname}\n") + f.write(f"list(pytestconfig._assertstate.hook.session._bestrelpathcache.keys())[0].strpath {list(pytestconfig._assertstate.hook.session._bestrelpathcache.keys())[0].strpath}\n") for di in d: + print(f"di: {type(di)}: {di}") for dii in di: + print(f"dii: {type(dii)}: {dii}") + f.write(f"di: {type(di)}: {di}\n") + f.write(f"\tdii: {type(dii)}: {dii}\n") if isinstance(dii, py._path.local.LocalPath): - if "data" in os.listdir(dii.dirname): - shutil.rmtree(os.path.join(dii.dirname, "data"), ignore_errors=True) + print("is localpath") + f.write(f"\t\t{os.listdir(dii.strpath)}\n") + if "data" in os.listdir(dii.strpath): + print("found data") + f.write(f"\t\tfound data\n") + shutil.rmtree(os.path.join(dii.strpath, "data"), ignore_errors=True) + f.write("#############\n\n") + f.close() diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py index 0ca04bdefb2c1fe6085b0471a8df5c58cbe5ac19..77f83536db5eaed3545d609e1d33a042c7ad23dd 100644 --- a/src/data_handling/data_distributor.py +++ b/src/data_handling/data_distributor.py @@ -1,3 +1,4 @@ +from __future__ import generator_stop import math import keras @@ -43,7 +44,7 @@ class Distributor(keras.utils.Sequence): if x is not None: yield (x, y) if (k + 1) == len(self.generator) and curr == num_mini_batches and not fit_call: - raise StopIteration + return def __len__(self): num_batch = 0 diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py index fc6ce9e9ff9da41eb8caf64f059226903e9d020c..cb51f20c8771ec49116731f02c7b462a62405394 100644 --- a/test/test_data_handling/test_data_distributor.py +++ b/test/test_data_handling/test_data_distributor.py @@ -1,5 +1,6 @@ import math import os +import shutil import keras import numpy as np @@ -72,4 +73,4 @@ class TestDistributor: gen = generator_two_stations d = Distributor(gen, model) expected = math.ceil(len(gen[0][0]) / 256) + math.ceil(len(gen[1][0]) / 256) - assert len(d) == expected \ No newline at end of file + assert len(d) == expected diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py index b08145e987278a7ab19a3b4e3567541e91f432c5..879436afddb8da8d11d6cc585da7c703aa12ef8a 100644 --- a/test/test_data_handling/test_data_generator.py +++ b/test/test_data_handling/test_data_generator.py @@ -1,10 +1,17 @@ import pytest import os +import shutil from src.data_handling.data_generator import DataGenerator class TestDataGenerator: + # @pytest.fixture(autouse=True, scope='module') + # def teardown_module(self): + # yield + # if "data" in os.listdir(os.path.dirname(__file__)): + # shutil.rmtree(os.path.join(os.path.dirname(__file__), "data"), ignore_errors=True) + @pytest.fixture def gen(self): return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py index 32d1937ae0fe217b2adaa5a07bf0051b6098af6d..12b619d9e31990f6cc24216ff84ad9d030265e36 100644 --- a/test/test_data_handling/test_data_preparation.py +++ b/test/test_data_handling/test_data_preparation.py @@ -44,7 +44,7 @@ class TestDataPrep: def test_set_file_name_and_meta(self): d = object.__new__(DataPrep) - d.path = os.path.abspath('test/data/') + d.path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "data") d.station = 'TESTSTATION' d.variables = ['a', 'bc'] assert d._set_file_name() == os.path.join(os.path.abspath(os.path.dirname(__file__)), diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py index 7f4ed517f83ce07b2d5e82a05f63e9f4c60375fd..c333322a911732470fc25f413c10f2db14514515 100644 --- a/test/test_modules/test_pre_processing.py +++ b/test/test_modules/test_pre_processing.py @@ -11,11 +11,6 @@ from src.modules.run_environment import RunEnvironment class TestPreProcessing: - @pytest.fixture - def obj_no_init(self): - yield object.__new__(PreProcessing) - RunEnvironment().__del__() - @pytest.fixture def obj_super_init(self): obj = object.__new__(PreProcessing) @@ -96,9 +91,9 @@ class TestPreProcessing: assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 6 station\(s\). Found ' r'5/6 valid stations.')) - def test_split_set_indices(self, obj_no_init): + def test_split_set_indices(self, obj_super_init): dummy_list = list(range(0, 15)) - train, val, test = obj_no_init.split_set_indices(len(dummy_list), 0.9) + train, val, test = obj_super_init.split_set_indices(len(dummy_list), 0.9) assert dummy_list[train] == list(range(0, 10)) assert dummy_list[val] == list(range(10, 13)) assert dummy_list[test] == list(range(13, 15))