From 7c8918ca922bfb4572436120ce1bd9e7b6adcd79 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Mon, 9 Dec 2019 10:36:12 +0100 Subject: [PATCH] current status, not running! --- conftest.py | 25 ++++++++++++++++--- src/data_handling/data_distributor.py | 3 ++- .../test_data_distributor.py | 3 ++- .../test_data_handling/test_data_generator.py | 7 ++++++ .../test_data_preparation.py | 2 +- test/test_modules/test_pre_processing.py | 9 ++----- 6 files changed, 36 insertions(+), 13 deletions(-) diff --git a/conftest.py b/conftest.py index 70f1b124..d8205ae4 100644 --- a/conftest.py +++ b/conftest.py @@ -5,15 +5,34 @@ import py import pytest -@pytest.fixture(autouse=True, scope='class') +@pytest.fixture(scope='class') def teardown_module(pytestconfig): yield # dirname can be found at pytestconfig._assertstate.hook.session._initialparts[0][0].dirname but it is not clear for # me, if there can be more than 1 entry in each of the lists. Therefore just loop over all elements will definitely # catch all dirnames. d = pytestconfig._assertstate.hook.session._initialparts + f = open("testfile.log", "a+") + print(f"pytestconfig._assertstate.hook.session._initialparts {pytestconfig._assertstate.hook.session._initialparts}") + print(f"pytestconfig._assertstate.hook.session._initialpaths {pytestconfig._assertstate.hook.session._initialpaths}") + f.write(f"pytestconfig._assertstate.hook.session._initialparts {pytestconfig._assertstate.hook.session._initialparts}\n") + f.write(f"pytestconfig._assertstate.hook.session._initialpaths {pytestconfig._assertstate.hook.session._initialpaths}\n") + f.write(f"pytestconfig.invocation_params.dir {pytestconfig.invocation_params.dir}\n") + f.write(f"pytestconfig.invocation_dir {pytestconfig.invocation_dir}\n") + f.write(f"list(pytestconfig._assertstate.hook.session._bestrelpathcache.keys())[0].dirname {list(pytestconfig._assertstate.hook.session._bestrelpathcache.keys())[0].dirname}\n") + f.write(f"list(pytestconfig._assertstate.hook.session._bestrelpathcache.keys())[0].strpath {list(pytestconfig._assertstate.hook.session._bestrelpathcache.keys())[0].strpath}\n") for di in d: + print(f"di: {type(di)}: {di}") for dii in di: + print(f"dii: {type(dii)}: {dii}") + f.write(f"di: {type(di)}: {di}\n") + f.write(f"\tdii: {type(dii)}: {dii}\n") if isinstance(dii, py._path.local.LocalPath): - if "data" in os.listdir(dii.dirname): - shutil.rmtree(os.path.join(dii.dirname, "data"), ignore_errors=True) + print("is localpath") + f.write(f"\t\t{os.listdir(dii.strpath)}\n") + if "data" in os.listdir(dii.strpath): + print("found data") + f.write(f"\t\tfound data\n") + shutil.rmtree(os.path.join(dii.strpath, "data"), ignore_errors=True) + f.write("#############\n\n") + f.close() diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py index 0ca04bde..77f83536 100644 --- a/src/data_handling/data_distributor.py +++ b/src/data_handling/data_distributor.py @@ -1,3 +1,4 @@ +from __future__ import generator_stop import math import keras @@ -43,7 +44,7 @@ class Distributor(keras.utils.Sequence): if x is not None: yield (x, y) if (k + 1) == len(self.generator) and curr == num_mini_batches and not fit_call: - raise StopIteration + return def __len__(self): num_batch = 0 diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py index fc6ce9e9..cb51f20c 100644 --- a/test/test_data_handling/test_data_distributor.py +++ b/test/test_data_handling/test_data_distributor.py @@ -1,5 +1,6 @@ import math import os +import shutil import keras import numpy as np @@ -72,4 +73,4 @@ class TestDistributor: gen = generator_two_stations d = Distributor(gen, model) expected = math.ceil(len(gen[0][0]) / 256) + math.ceil(len(gen[1][0]) / 256) - assert len(d) == expected \ No newline at end of file + assert len(d) == expected diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py index b08145e9..879436af 100644 --- a/test/test_data_handling/test_data_generator.py +++ b/test/test_data_handling/test_data_generator.py @@ -1,10 +1,17 @@ import pytest import os +import shutil from src.data_handling.data_generator import DataGenerator class TestDataGenerator: + # @pytest.fixture(autouse=True, scope='module') + # def teardown_module(self): + # yield + # if "data" in os.listdir(os.path.dirname(__file__)): + # shutil.rmtree(os.path.join(os.path.dirname(__file__), "data"), ignore_errors=True) + @pytest.fixture def gen(self): return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py index 32d1937a..12b619d9 100644 --- a/test/test_data_handling/test_data_preparation.py +++ b/test/test_data_handling/test_data_preparation.py @@ -44,7 +44,7 @@ class TestDataPrep: def test_set_file_name_and_meta(self): d = object.__new__(DataPrep) - d.path = os.path.abspath('test/data/') + d.path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "data") d.station = 'TESTSTATION' d.variables = ['a', 'bc'] assert d._set_file_name() == os.path.join(os.path.abspath(os.path.dirname(__file__)), diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py index 7f4ed517..c333322a 100644 --- a/test/test_modules/test_pre_processing.py +++ b/test/test_modules/test_pre_processing.py @@ -11,11 +11,6 @@ from src.modules.run_environment import RunEnvironment class TestPreProcessing: - @pytest.fixture - def obj_no_init(self): - yield object.__new__(PreProcessing) - RunEnvironment().__del__() - @pytest.fixture def obj_super_init(self): obj = object.__new__(PreProcessing) @@ -96,9 +91,9 @@ class TestPreProcessing: assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 6 station\(s\). Found ' r'5/6 valid stations.')) - def test_split_set_indices(self, obj_no_init): + def test_split_set_indices(self, obj_super_init): dummy_list = list(range(0, 15)) - train, val, test = obj_no_init.split_set_indices(len(dummy_list), 0.9) + train, val, test = obj_super_init.split_set_indices(len(dummy_list), 0.9) assert dummy_list[train] == list(range(0, 10)) assert dummy_list[val] == list(range(10, 13)) assert dummy_list[test] == list(range(13, 15)) -- GitLab