diff --git a/.gitignore b/.gitignore index e109cec7c7622cee9d9a635cc458b9c662fc4761..366a3e3907a4b0bed1bd400cc2e377b7cdbe92bc 100644 --- a/.gitignore +++ b/.gitignore @@ -60,7 +60,7 @@ Thumbs.db htmlcov/ .pytest_cache /test/data/ -/test/test_modules/data/ +/test/test_run_modules/data/ report.html /TestExperiment/ /testrun_network*/ diff --git a/requirements.txt b/requirements.txt index 71bb1338effff38092510982d4a2c1f37f7b026a..7da29a05b748531fd4ec327ff17f432ff1ecaabb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -38,9 +38,9 @@ pydot==1.4.1 pyparsing==2.4.6 pyproj==2.5.0 pyshp==2.1.0 -pytest==5.3.5 -pytest-cov==2.8.1 -pytest-html==2.0.1 +pytest==6.0.0 +pytest-cov==2.10.0 +pytest-html==2.1.1 pytest-lazy-fixture==0.6.3 pytest-metadata==1.8.0 pytest-sugar diff --git a/src/data_handler/__init__.py b/src/data_handler/__init__.py index 9ce7307d87fea03c11066068d8eccd78a02ed0bf..451868b838ab7a0d165942e36b5ec6aa03e42721 100644 --- a/src/data_handler/__init__.py +++ b/src/data_handler/__init__.py @@ -11,5 +11,5 @@ __date__ = '2020-04-17' from .bootstraps import BootStraps from .iterator import KerasIterator, DataCollection -from .advanced_data_handling import DefaultDataPreparation, AbstractDataPreparation +from .advanced_data_handler import DefaultDataPreparation, AbstractDataPreparation from .data_preparation_neighbors import DataPreparationNeighbors diff --git a/src/data_handler/advanced_data_handling.py b/src/data_handler/advanced_data_handler.py similarity index 98% rename from src/data_handler/advanced_data_handling.py rename to src/data_handler/advanced_data_handler.py index 26cd3ca82d03d4f95703a6ed0ad0a4b1d28e09e4..c9c25ca7a6ce765db2eb67d9b6b7d9144e54987a 100644 --- a/src/data_handler/advanced_data_handling.py +++ b/src/data_handler/advanced_data_handler.py @@ -12,6 +12,7 @@ import pandas as pd import datetime as dt import shutil import inspect +import copy from typing import Union, List, Tuple, Dict import logging @@ -109,9 +110,9 @@ class DefaultDataPreparation(AbstractDataPreparation): @classmethod def build(cls, station, **kwargs): - sp_keys = {k: kwargs[k] for k in cls._requirements if k in kwargs} + sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs} sp = StationPrep(station, **sp_keys) - dp_args = {k: kwargs[k] for k in cls.own_args("id_class") if k in kwargs} + dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs} return cls(sp, **dp_args) def _create_collection(self): @@ -274,7 +275,7 @@ class DefaultDataPreparation(AbstractDataPreparation): @classmethod def transformation(cls, set_stations, **kwargs): - sp_keys = {k: kwargs[k] for k in cls._requirements if k in kwargs} + sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs} transformation_dict = sp_keys.pop("transformation") if transformation_dict is None: return diff --git a/src/data_handler/bootstraps.py b/src/data_handler/bootstraps.py index e3f7ff91f34ea4e6352a9809749973ca10bbc00f..4ccc1350ee36df49dba0683eea896fc4ed398b60 100644 --- a/src/data_handler/bootstraps.py +++ b/src/data_handler/bootstraps.py @@ -12,7 +12,6 @@ __author__ = 'Felix Kleinert, Lukas Leufen' __date__ = '2020-02-07' -import logging import os from collections import Iterator, Iterable from itertools import chain @@ -20,7 +19,7 @@ from itertools import chain import numpy as np import xarray as xr -from src.data_handler.advanced_data_handling import AbstractDataPreparation +from src.data_handler.advanced_data_handler import AbstractDataPreparation class BootstrapIterator(Iterator): diff --git a/src/data_handler/data_preparation_neighbors.py b/src/data_handler/data_preparation_neighbors.py index 6d2099da9c0e6cac6b0dec6a0ee7fdd090d9df33..508716b14d085ab1bb2aaaeb02471480608b6a27 100644 --- a/src/data_handler/data_preparation_neighbors.py +++ b/src/data_handler/data_preparation_neighbors.py @@ -5,7 +5,7 @@ __date__ = '2020-07-17' from src.helpers import to_list from src.data_handler.station_preparation import StationPrep -from src.data_handler.advanced_data_handling import DefaultDataPreparation +from src.data_handler.advanced_data_handler import DefaultDataPreparation import os from typing import Union, List diff --git a/src/run.py b/src/run.py index 8a4ade33c0e5b260fafab58e76cf753455077d50..1244c25d1b67d1f80b7da2b1e18210186ac3a9f0 100644 --- a/src/run.py +++ b/src/run.py @@ -39,4 +39,6 @@ def run(stations=None, if __name__ == "__main__": - run(stations=["DEBW013","DEBW025"], statistics_per_var={'o3': 'dma8eu', "temp": "maximum"}, trainable=True, create_new_model=True) + from src.model_modules.model_class import MyBranchedModel + run(stations=["DEBW013","DEBW025"], statistics_per_var={'o3': 'dma8eu', "temp": "maximum"}, trainable=True, + create_new_model=True, model=MyBranchedModel) diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index 5d0e3743ac81262a74f28fd291193a6f10b8eeac..6fec871327d218cb2f42b84c518807c558c9c53d 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -18,7 +18,7 @@ from src.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, D DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \ DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \ DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST -from src.data_handler.advanced_data_handling import DefaultDataPreparation +from src.data_handler.advanced_data_handler import DefaultDataPreparation from src.run_modules.run_environment import RunEnvironment from src.model_modules.model_class import MyLittleModel as VanillaModel diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py index c5cebcf39093ff035536e9ce7838999769dc0cbd..4b6de8253a58ce0b65184ae506f198a8a6b17aad 100644 --- a/src/run_modules/pre_processing.py +++ b/src/run_modules/pre_processing.py @@ -157,7 +157,7 @@ class PreProcessing(RunEnvironment): raise AssertionError(f"Make sure, that the train subset is always at first execution position! Given subset" f"order was: {subset_names}.") for (ind, scope) in zip([train_index, val_index, test_index, train_val_index], subset_names): - self.create_set_split_new(ind, scope) + self.create_set_split(ind, scope) @staticmethod def split_set_indices(total_length: int, fraction: float) -> Tuple[slice, slice, slice, slice]: @@ -181,7 +181,7 @@ class PreProcessing(RunEnvironment): train_val_index = slice(0, pos_test_split) return train_index, val_index, test_index, train_val_index - def create_set_split_new(self, index_list: slice, set_name: str) -> None: + def create_set_split(self, index_list: slice, set_name: str) -> None: # get set stations stations = self.data_store.get("stations", scope=set_name) if self.data_store.get("use_all_stations_on_all_data_sets"): @@ -212,7 +212,7 @@ class PreProcessing(RunEnvironment): :return: Corrected list containing only valid station IDs. """ t_outer = TimeTracking() - logging.info(f"check valid stations started{' (%s)' % set_name if set_name is not None else 'all'}") + logging.info(f"check valid stations started{' (%s)' % (set_name if set_name is not None else 'all')}") # calculate transformation using train data if set_name == "train": self.transformation(data_preparation, set_stations) diff --git a/test/test_data_handling/old_t_bootstraps.py b/test/test_data_handler/old_t_bootstraps.py similarity index 100% rename from test/test_data_handling/old_t_bootstraps.py rename to test/test_data_handler/old_t_bootstraps.py diff --git a/test/test_data_handling/old_t_data_generator.py b/test/test_data_handler/old_t_data_generator.py similarity index 100% rename from test/test_data_handling/old_t_data_generator.py rename to test/test_data_handler/old_t_data_generator.py diff --git a/test/test_data_handling/old_t_data_preparation.py b/test/test_data_handler/old_t_data_preparation.py similarity index 100% rename from test/test_data_handling/old_t_data_preparation.py rename to test/test_data_handler/old_t_data_preparation.py diff --git a/test/test_data_handling/test_iterator.py b/test/test_data_handler/test_iterator.py similarity index 100% rename from test/test_data_handling/test_iterator.py rename to test/test_data_handler/test_iterator.py diff --git a/test/test_modules/test_experiment_setup.py b/test/test_run_modules/test_experiment_setup.py similarity index 100% rename from test/test_modules/test_experiment_setup.py rename to test/test_run_modules/test_experiment_setup.py diff --git a/test/test_modules/test_model_setup.py b/test/test_run_modules/test_model_setup.py similarity index 100% rename from test/test_modules/test_model_setup.py rename to test/test_run_modules/test_model_setup.py diff --git a/test/test_modules/test_partition_check.py b/test/test_run_modules/test_partition_check.py similarity index 100% rename from test/test_modules/test_partition_check.py rename to test/test_run_modules/test_partition_check.py diff --git a/test/test_modules/test_post_processing.py b/test/test_run_modules/test_post_processing.py similarity index 100% rename from test/test_modules/test_post_processing.py rename to test/test_run_modules/test_post_processing.py diff --git a/test/test_modules/old_t_pre_processing.py b/test/test_run_modules/test_pre_processing.py similarity index 70% rename from test/test_modules/old_t_pre_processing.py rename to test/test_run_modules/test_pre_processing.py index 63035d3858adb7c64265b27f87f8042c7f15a997..d08e3302fd55b8708e964bc5873209cc7d2dbbde 100644 --- a/test/test_modules/old_t_pre_processing.py +++ b/test/test_run_modules/test_pre_processing.py @@ -2,8 +2,7 @@ import logging import pytest -from src.data_handler import DataPrepJoin -from src.data_handler.data_generator import DataGenerator +from src.data_handler import DefaultDataPreparation, DataCollection, AbstractDataPreparation from src.helpers.datastore import NameNotFoundInScope from src.helpers import PyTestRegex from src.run_modules.experiment_setup import ExperimentSetup @@ -29,7 +28,7 @@ class TestPreProcessing: def obj_with_exp_setup(self): ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background", - data_preparation=DataPrepJoin) + data_preparation=DefaultDataPreparation) pre = object.__new__(PreProcessing) super(PreProcessing, pre).__init__() yield pre @@ -48,19 +47,20 @@ class TestPreProcessing: RunEnvironment().__del__() def test_run(self, obj_with_exp_setup): - assert obj_with_exp_setup.data_store.search_name("generator") == [] + assert obj_with_exp_setup.data_store.search_name("data_collection") == [] assert obj_with_exp_setup._run() is None - assert obj_with_exp_setup.data_store.search_name("generator") == sorted(["general.train", "general.val", - "general.train_val", "general.test"]) + assert obj_with_exp_setup.data_store.search_name("data_collection") == sorted(["general.train", "general.val", + "general.train_val", + "general.test"]) def test_split_train_val_test(self, obj_with_exp_setup): - assert obj_with_exp_setup.data_store.search_name("generator") == [] + assert obj_with_exp_setup.data_store.search_name("data_collection") == [] obj_with_exp_setup.split_train_val_test() data_store = obj_with_exp_setup.data_store - expected_params = ["generator", "start", "end", "stations", "permute_data", "min_length", "extreme_values", - "extremes_on_right_tail_only", "upsampling"] + expected_params = ["data_collection", "start", "end", "stations", "permute_data", "min_length", + "extreme_values", "extremes_on_right_tail_only", "upsampling"] assert data_store.search_scope("general.train") == sorted(expected_params) - assert data_store.search_name("generator") == sorted(["general.train", "general.val", "general.test", + assert data_store.search_name("data_collection") == sorted(["general.train", "general.val", "general.test", "general.train_val"]) def test_create_set_split_not_all_stations(self, caplog, obj_with_exp_setup): @@ -69,9 +69,9 @@ class TestPreProcessing: obj_with_exp_setup.create_set_split(slice(0, 2), "awesome") assert ('root', 10, "Awesome stations (len=2): ['DEBW107', 'DEBY081']") in caplog.record_tuples data_store = obj_with_exp_setup.data_store - assert isinstance(data_store.get("generator", "general.awesome"), DataGenerator) + assert isinstance(data_store.get("data_collection", "general.awesome"), DataCollection) with pytest.raises(NameNotFoundInScope): - data_store.get("generator", "general") + data_store.get("data_collection", "general") assert data_store.get("stations", "general.awesome") == ["DEBW107", "DEBY081"] def test_create_set_split_all_stations(self, caplog, obj_with_exp_setup): @@ -80,22 +80,22 @@ class TestPreProcessing: message = "Awesome stations (len=6): ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001']" assert ('root', 10, message) in caplog.record_tuples data_store = obj_with_exp_setup.data_store - assert isinstance(data_store.get("generator", "general.awesome"), DataGenerator) + assert isinstance(data_store.get("data_collection", "general.awesome"), DataCollection) with pytest.raises(NameNotFoundInScope): - data_store.get("generator", "general") + data_store.get("data_collection", "general") assert data_store.get("stations", "general.awesome") == ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] @pytest.mark.parametrize("name", (None, "tester")) - def test_check_valid_stations(self, caplog, obj_with_exp_setup, name): + def test_validate_station(self, caplog, obj_with_exp_setup, name): pre = obj_with_exp_setup caplog.set_level(logging.INFO) - args = pre.data_store.create_args_dict(DEFAULT_ARGS_LIST) - kwargs = pre.data_store.create_args_dict(DEFAULT_KWARGS_LIST) stations = pre.data_store.get("stations", "general") - valid_stations = pre.check_valid_stations(args, kwargs, stations, name=name) + data_preparation = pre.data_store.get("data_preparation") + collection, valid_stations = pre.validate_station(data_preparation, stations, set_name=name) + assert isinstance(collection, DataCollection) assert len(valid_stations) < len(stations) assert valid_stations == stations[:-1] - expected = 'check valid stations started (tester)' if name else 'check valid stations started' + expected = "check valid stations started" + ' (%s)' % (name if name else 'all') assert caplog.record_tuples[0] == ('root', 20, expected) assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 6 ' r'station\(s\). Found 5/6 valid stations.')) @@ -107,3 +107,11 @@ class TestPreProcessing: assert dummy_list[val] == list(range(10, 13)) assert dummy_list[test] == list(range(13, 15)) assert dummy_list[train_val] == list(range(0, 13)) + + def test_transformation(self): + pre = object.__new__(PreProcessing) + data_preparation = AbstractDataPreparation + stations = ['DEBW107', 'DEBY081'] + assert pre.transformation(data_preparation, stations) is None + class data_preparation_no_trans: pass + assert pre.transformation(data_preparation_no_trans, stations) is None diff --git a/test/test_modules/test_run_environment.py b/test/test_run_modules/test_run_environment.py similarity index 100% rename from test/test_modules/test_run_environment.py rename to test/test_run_modules/test_run_environment.py diff --git a/test/test_modules/old_t_training.py b/test/test_run_modules/test_training.py similarity index 72% rename from test/test_modules/old_t_training.py rename to test/test_run_modules/test_training.py index 0998ff67cb6d52ad24efaea3d5e404632331735a..5885accc87e9cd1b95cdfbd5c2a4dff65b3a2c18 100644 --- a/test/test_modules/old_t_training.py +++ b/test/test_run_modules/test_training.py @@ -9,7 +9,7 @@ import mock import pytest from keras.callbacks import History -from src.data_handler import DataPrepJoin +from src.data_handler import DataCollection, KerasIterator, DefaultDataPreparation from src.helpers import PyTestRegex from src.model_modules.flatten import flatten_tail from src.model_modules.inception_model import InceptionModelBase @@ -18,7 +18,7 @@ from src.run_modules.run_environment import RunEnvironment from src.run_modules.training import Training -def my_test_model(activation, window_history_size, channels, dropout_rate, add_minor_branch=False): +def my_test_model(activation, window_history_size, channels, output_size, dropout_rate, add_minor_branch=False): inception_model = InceptionModelBase() conv_settings_dict1 = { 'tower_1': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (3, 1), 'activation': activation}, @@ -27,7 +27,6 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_m X_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels)) X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1) if add_minor_branch: - # out = [flatten_tail(X_in, 'Minor_1', activation=activation)] out = [flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=4, output_activation='linear', reduction_filter=64, name='Minor_1', dropout_rate=dropout_rate, @@ -35,8 +34,7 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_m else: out = [] X_in = keras.layers.Dropout(dropout_rate)(X_in) - # out.append(flatten_tail(X_in, 'Main', activation=activation)) - out.append(flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=4, + out.append(flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=output_size, output_activation='linear', reduction_filter=64, name='Main', dropout_rate=dropout_rate, )) @@ -46,7 +44,7 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_m class TestTraining: @pytest.fixture - def init_without_run(self, path: str, model: keras.Model, callbacks: CallbackHandler, model_path): + def init_without_run(self, path: str, model: keras.Model, callbacks: CallbackHandler, model_path, batch_path): obj = object.__new__(Training) super(Training, obj).__init__() obj.model = model @@ -60,15 +58,18 @@ class TestTraining: obj.lr_sc = lr obj.hist = hist obj.experiment_name = "TestExperiment" - obj.data_store.set("generator", mock.MagicMock(return_value="mock_train_gen"), "general.train") - obj.data_store.set("generator", mock.MagicMock(return_value="mock_val_gen"), "general.val") - obj.data_store.set("generator", mock.MagicMock(return_value="mock_test_gen"), "general.test") + obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_train_gen"), "general.train") + obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_val_gen"), "general.val") + obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_test_gen"), "general.test") os.makedirs(path) obj.data_store.set("experiment_path", path, "general") + os.makedirs(batch_path) + obj.data_store.set("batch_path", batch_path, "general") os.makedirs(model_path) obj.data_store.set("model_path", model_path, "general") obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model") obj.data_store.set("experiment_name", "TestExperiment", "general") + path_plot = os.path.join(path, "plots") os.makedirs(path_plot) obj.data_store.set("plot_path", path_plot, "general") @@ -106,14 +107,35 @@ class TestTraining: return os.path.join(path, "model") @pytest.fixture - def generator(self, path): - return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), ['DEBW107'], ['o3', 'temp'], 'datetime', - 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, - data_preparation=DataPrepJoin) + def batch_path(self, path): + return os.path.join(path, "batch") + + @pytest.fixture + def window_history_size(self): + return 7 + + @pytest.fixture + def window_lead_time(self): + return 2 @pytest.fixture - def model(self): - return my_test_model(keras.layers.PReLU, 7, 2, 0.1, False) + def statistics_per_var(self): + return {'o3': 'dma8eu', 'temp': 'maximum'} + + @pytest.fixture + def data_collection(self, path, window_history_size, window_lead_time, statistics_per_var): + data_prep = DefaultDataPreparation.build(['DEBW107'], data_path=os.path.join(os.path.dirname(__file__), 'data'), + statistics_per_var=statistics_per_var, station_type="background", + network="AIRBASE", sampling="daily", target_dim="variables", + target_var="o3", interpolate_dim="datetime", + window_history_size=window_history_size, + window_lead_time=window_lead_time, name_affix="train") + return DataCollection([data_prep]) + + @pytest.fixture + def model(self, window_history_size, window_lead_time, statistics_per_var): + channels = len(list(statistics_per_var.keys())) + return my_test_model(keras.layers.PReLU, window_history_size, channels, window_lead_time, 0.1, False) @pytest.fixture def callbacks(self, path): @@ -127,29 +149,31 @@ class TestTraining: return clbk, hist, lr @pytest.fixture - def ready_to_train(self, generator: DataGenerator, init_without_run: Training): - init_without_run.train_set = Distributor(generator, init_without_run.model, init_without_run.batch_size) - init_without_run.val_set = Distributor(generator, init_without_run.model, init_without_run.batch_size) + def ready_to_train(self, data_collection: DataCollection, init_without_run: Training, batch_path: str): + batch_size = init_without_run.batch_size + model = init_without_run.model + init_without_run.train_set = KerasIterator(data_collection, batch_size, batch_path, model=model, name="train") + init_without_run.val_set = KerasIterator(data_collection, batch_size, batch_path, model=model, name="val") init_without_run.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error) return init_without_run @pytest.fixture - def ready_to_run(self, generator, init_without_run): + def ready_to_run(self, data_collection, init_without_run): obj = init_without_run - obj.data_store.set("generator", generator, "general.train") - obj.data_store.set("generator", generator, "general.val") - obj.data_store.set("generator", generator, "general.test") + obj.data_store.set("data_collection", data_collection, "general.train") + obj.data_store.set("data_collection", data_collection, "general.val") + obj.data_store.set("data_collection", data_collection, "general.test") obj.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error) return obj @pytest.fixture - def ready_to_init(self, generator, model, callbacks, path, model_path): + def ready_to_init(self, data_collection, model, callbacks, path, model_path, batch_path): os.makedirs(path) os.makedirs(model_path) obj = RunEnvironment() - obj.data_store.set("generator", generator, "general.train") - obj.data_store.set("generator", generator, "general.val") - obj.data_store.set("generator", generator, "general.test") + obj.data_store.set("data_collection", data_collection, "general.train") + obj.data_store.set("data_collection", data_collection, "general.val") + obj.data_store.set("data_collection", data_collection, "general.test") model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error) obj.data_store.set("model", model, "general.model") obj.data_store.set("model_path", model_path, "general") @@ -164,6 +188,8 @@ class TestTraining: obj.data_store.set("experiment_path", path, "general") obj.data_store.set("trainable", True, "general") obj.data_store.set("create_new_model", True, "general") + os.makedirs(batch_path) + obj.data_store.set("batch_path", batch_path, "general") path_plot = os.path.join(path, "plots") os.makedirs(path_plot) obj.data_store.set("plot_path", path_plot, "general") @@ -174,6 +200,13 @@ class TestTraining: def test_init(self, ready_to_init): assert isinstance(Training(), Training) # just test, if nothing fails + def test_no_training(self, ready_to_init, caplog): + caplog.set_level(logging.INFO) + ready_to_init.data_store.set("trainable", False) + Training() + message = "No training has started, because trainable parameter was false." + assert caplog.record_tuples[-2] == ("root", 20, message) + def test_run(self, ready_to_run): assert ready_to_run._run() is None # just test, if nothing fails @@ -185,8 +218,8 @@ class TestTraining: def test_set_gen(self, init_without_run): assert init_without_run.train_set is None init_without_run._set_gen("train") - assert isinstance(init_without_run.train_set, Distributor) - assert init_without_run.train_set.generator.return_value == "mock_train_gen" + assert isinstance(init_without_run.train_set, KerasIterator) + assert init_without_run.train_set._collection.return_value == "mock_train_gen" def test_set_generators(self, init_without_run): sets = ["train", "val", "test"] @@ -194,7 +227,7 @@ class TestTraining: init_without_run.set_generators() assert not all([getattr(init_without_run, f"{obj}_set") is None for obj in sets]) assert all( - [getattr(init_without_run, f"{obj}_set").generator.return_value == f"mock_{obj}_gen" for obj in sets]) + [getattr(init_without_run, f"{obj}_set")._collection.return_value == f"mock_{obj}_gen" for obj in sets]) def test_train(self, ready_to_train, path): assert not hasattr(ready_to_train.model, "history")