From 94bce9420a143cf0ee467609bf22c08adabcb90f Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Fri, 13 Dec 2019 12:19:30 +0100 Subject: [PATCH] renamed data_store.put -> set, model class must include now the settings (removed from model_setup.py) --- src/datastore.py | 12 +- src/model_modules/model_class.py | 62 +++++--- src/modules/experiment_setup.py | 2 +- src/modules/model_setup.py | 43 +----- src/modules/pre_processing.py | 6 +- test/test_datastore.py | 160 ++++++++++---------- test/test_model_modules/test_model_class.py | 10 +- test/test_modules/test_model_setup.py | 6 +- test/test_modules/test_pre_processing.py | 12 +- test/test_modules/test_training.py | 40 ++--- 10 files changed, 178 insertions(+), 175 deletions(-) diff --git a/src/datastore.py b/src/datastore.py index 5f0df675..d14ae07d 100644 --- a/src/datastore.py +++ b/src/datastore.py @@ -38,7 +38,7 @@ class AbstractDataStore(ABC): # empty initialise the data-store variables self._store: Dict = {} - def put(self, name: str, obj: Any, scope: str) -> None: + def set(self, name: str, obj: Any, scope: str) -> None: """ Abstract method to add an object to the data store :param name: Name of object to store @@ -89,7 +89,7 @@ class AbstractDataStore(ABC): def clear_data_store(self) -> None: self._store = {} - def create_args_dict(self, arg_list: List[str], scope: str = "general"): + def create_args_dict(self, arg_list: List[str], scope: str = "general") -> Dict: args = {} for arg in arg_list: try: @@ -98,6 +98,10 @@ class AbstractDataStore(ABC): pass return args + def set_args_from_dict(self, arg_dict: Dict, scope: str = "general") -> None: + for (k, v) in arg_dict.items(): + self.set(k, v, scope) + class DataStoreByVariable(AbstractDataStore): @@ -115,7 +119,7 @@ class DataStoreByVariable(AbstractDataStore): <scope3>: value """ - def put(self, name: str, obj: Any, scope: str) -> None: + def set(self, name: str, obj: Any, scope: str) -> None: """ Store an object `obj` with given `name` under `scope`. In the current implementation, existing entries are overwritten. @@ -239,7 +243,7 @@ class DataStoreByScope(AbstractDataStore): <variable3>: value """ - def put(self, name: str, obj: Any, scope: str) -> None: + def set(self, name: str, obj: Any, scope: str) -> None: """ Store an object `obj` with given `name` under `scope`. In the current implementation, existing entries are overwritten. diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py index 19ea629c..ecaef632 100644 --- a/src/model_modules/model_class.py +++ b/src/model_modules/model_class.py @@ -7,6 +7,8 @@ from typing import Any, Callable import keras +from src import helpers + class AbstractModelClass(ABC): @@ -23,8 +25,8 @@ class AbstractModelClass(ABC): Predefine internal attributes for model and loss. """ - self._model = None - self._loss = None + self.__model = None + self.__loss = None def __getattr__(self, name: str) -> Any: @@ -48,7 +50,11 @@ class AbstractModelClass(ABC): :return: the keras model """ - return self._model + return self.__model + + @model.setter + def model(self, value): + self.__model = value @property def loss(self) -> Callable: @@ -63,7 +69,14 @@ class AbstractModelClass(ABC): :return: the loss function """ - return self._loss + return self.__loss + + @loss.setter + def loss(self, value) -> None: + self.__loss = value + + def get_settings(self): + return dict((k, v) for (k, v) in self.__dict__.items() if not k.startswith("_AbstractModelClass__")) class MyLittleModel(AbstractModelClass): @@ -74,7 +87,7 @@ class MyLittleModel(AbstractModelClass): Dense layer. """ - def __init__(self, activation, window_history_size, channels, regularizer, dropout_rate, window_lead_time): + def __init__(self, input_x, window_history_size, window_lead_time): """ Sets model and loss depending on the given arguments. @@ -87,10 +100,25 @@ class MyLittleModel(AbstractModelClass): """ super().__init__() - self.set_model(activation, window_history_size, channels, dropout_rate, window_lead_time) + + # settings + self.window_history_size = window_history_size + self.window_lead_time = window_lead_time + self.channels = input_x.shape[-1] # input variables + self.dropout_rate = 0.1 + self.regularizer = keras.regularizers.l2(0.1) + self.initial_lr = 1e-2 + self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9) + self.lr_decay = helpers.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10) + self.epochs = 2 + self.batch_size = int(256) + self.activation = keras.layers.PReLU + + # apply to model + self.set_model() self.set_loss() - def set_model(self, activation, window_history_size, channels, dropout_rate, window_lead_time): + def set_model(self): """ Build the model. @@ -103,20 +131,20 @@ class MyLittleModel(AbstractModelClass): """ # add 1 to window_size to include current time step t0 - x_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels)) + x_input = keras.layers.Input(shape=(self.window_history_size + 1, 1, self.channels)) x_in = keras.layers.Conv2D(32, (1, 1), padding='same', name='{}_Conv_1x1'.format("major"))(x_input) - x_in = activation(name='{}_conv_act'.format("major"))(x_in) + x_in = self.activation(name='{}_conv_act'.format("major"))(x_in) x_in = keras.layers.Flatten(name='{}'.format("major"))(x_in) - x_in = keras.layers.Dropout(dropout_rate, name='{}_Dropout_1'.format("major"))(x_in) + x_in = keras.layers.Dropout(self.dropout_rate, name='{}_Dropout_1'.format("major"))(x_in) x_in = keras.layers.Dense(64, name='{}_Dense_64'.format("major"))(x_in) - x_in = activation()(x_in) + x_in = self.activation()(x_in) x_in = keras.layers.Dense(32, name='{}_Dense_32'.format("major"))(x_in) - x_in = activation()(x_in) + x_in = self.activation()(x_in) x_in = keras.layers.Dense(16, name='{}_Dense_16'.format("major"))(x_in) - x_in = activation()(x_in) - x_in = keras.layers.Dense(window_lead_time, name='{}_Dense'.format("major"))(x_in) - out_main = activation()(x_in) - self._model = keras.Model(inputs=x_input, outputs=[out_main]) + x_in = self.activation()(x_in) + x_in = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("major"))(x_in) + out_main = self.activation()(x_in) + self.model = keras.Model(inputs=x_input, outputs=[out_main]) def set_loss(self): @@ -125,4 +153,4 @@ class MyLittleModel(AbstractModelClass): :return: loss function """ - self._loss = keras.losses.mean_squared_error + self.loss = keras.losses.mean_squared_error diff --git a/src/modules/experiment_setup.py b/src/modules/experiment_setup.py index 5fdc1f1f..f726c66e 100644 --- a/src/modules/experiment_setup.py +++ b/src/modules/experiment_setup.py @@ -93,7 +93,7 @@ class ExperimentSetup(RunEnvironment): def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None: if value is None and default is not None: value = default - self.data_store.put(param, value, scope) + self.data_store.set(param, value, scope) logging.debug(f"set experiment attribute: {param}({scope})={value}") @staticmethod diff --git a/src/modules/model_setup.py b/src/modules/model_setup.py index 947b9fa3..f7f6b4f4 100644 --- a/src/modules/model_setup.py +++ b/src/modules/model_setup.py @@ -36,9 +36,6 @@ class ModelSetup(RunEnvironment): # create checkpoint self._set_checkpoint() - # set all model settings - self.my_model_settings() - # build model graph using settings from my_model_settings() self.build_model() @@ -56,11 +53,11 @@ class ModelSetup(RunEnvironment): optimizer = self.data_store.get("optimizer", self.scope) loss = self.model.loss self.model.compile(optimizer=optimizer, loss=loss, metrics=["mse", "mae"]) - self.data_store.put("model", self.model, self.scope) + self.data_store.set("model", self.model, self.scope) def _set_checkpoint(self): checkpoint = ModelCheckpoint(self.checkpoint_name, verbose=1, monitor='val_loss', save_best_only=True, mode='auto') - self.data_store.put("checkpoint", checkpoint, self.scope) + self.data_store.set("checkpoint", checkpoint, self.scope) def load_weights(self): try: @@ -70,9 +67,12 @@ class ModelSetup(RunEnvironment): logging.info('no weights to reload...') def build_model(self): - args_list = ["activation", "window_history_size", "channels", "regularizer", "dropout_rate", "window_lead_time"] + args_list = ["window_history_size", "window_lead_time"] args = self.data_store.create_args_dict(args_list, self.scope) - self.model = MyLittleModel(**args) + input_x = self.data_store.get("generator", "general.train")[0][0] + self.model = MyLittleModel(input_x, **args) + model_settings = self.model.get_settings() + self.data_store.set_args_from_dict(model_settings, self.scope) def plot_model(self): # pragma: no cover with tf.device("/cpu:0"): @@ -81,35 +81,6 @@ class ModelSetup(RunEnvironment): file_name = os.path.join(path, name) keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True) - def my_model_settings(self): - - # channels - X, _ = self.data_store.get("generator", "general.train")[0] - channels = X.shape[-1] # input variables - self.data_store.put("channels", channels, self.scope) - - # dropout - self.data_store.put("dropout_rate", 0.1, self.scope) - - # regularizer - self.data_store.put("regularizer", l2(0.1), self.scope) - - # learning rate - initial_lr = 1e-2 - self.data_store.put("initial_lr", initial_lr, self.scope) - optimizer = SGD(lr=initial_lr, momentum=0.9) - # optimizer=Adam(lr=initial_lr, amsgrad=True) - self.data_store.put("optimizer", optimizer, self.scope) - self.data_store.put("lr_decay", LearningRateDecay(base_lr=initial_lr, drop=.94, epochs_drop=10), self.scope) - - # learning settings - self.data_store.put("epochs", 2, self.scope) - self.data_store.put("batch_size", int(256), self.scope) - - # activation - activation = keras.layers.PReLU # ELU #LeakyReLU keras.activations.tanh # - self.data_store.put("activation", activation, self.scope) - def my_loss(): loss = l_p_loss(4) diff --git a/src/modules/pre_processing.py b/src/modules/pre_processing.py index 8fad9d1b..cce9ee58 100644 --- a/src/modules/pre_processing.py +++ b/src/modules/pre_processing.py @@ -36,7 +36,7 @@ class PreProcessing(RunEnvironment): args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST) kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST) valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general")) - self.data_store.put("stations", valid_stations, "general") + self.data_store.set("stations", valid_stations, "general") self.split_train_val_test() def report_pre_processing(self): @@ -87,10 +87,10 @@ class PreProcessing(RunEnvironment): set_stations = stations[index_list] logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}") set_stations = self.check_valid_stations(args, kwargs, set_stations) - self.data_store.put("stations", set_stations, scope) + self.data_store.set("stations", set_stations, scope) set_args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope) data_set = DataGenerator(**set_args, **kwargs) - self.data_store.put("generator", data_set, scope) + self.data_store.set("generator", data_set, scope) @staticmethod def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str]): diff --git a/test/test_datastore.py b/test/test_datastore.py index 5ba76bf8..3f61c227 100644 --- a/test/test_datastore.py +++ b/test/test_datastore.py @@ -24,95 +24,95 @@ class TestDataStoreByVariable: return DataStoreByVariable() def test_put(self, ds): - ds.put("number", 3, "general.subscope") + ds.set("number", 3, "general.subscope") assert ds._store["number"]["general.subscope"] == 3 def test_get(self, ds): - ds.put("number", 3, "general.subscope") + ds.set("number", 3, "general.subscope") assert ds.get("number", "general.subscope") == 3 def test_get_with_sub_scope(self, ds): - ds.put("number", 3, "general") - ds.put("number", 10, "general.subscope") + ds.set("number", 3, "general") + ds.set("number", 10, "general.subscope") assert ds.get("number", "general.subscope") == 10 assert ds.get("number", "general") == 3 def test_get_with_not_existing_sub_scope(self, ds): - ds.put("number", 3, "general") - ds.put("number2", 10, "general.subscope") - ds.put("number2", 1, "general") + ds.set("number", 3, "general") + ds.set("number2", 10, "general.subscope") + ds.set("number2", 1, "general") assert ds.get("number", "general.subscope") == 3 def test_raise_not_in_data_store(self, ds): - ds.put("number", 22, "general") + ds.set("number", 22, "general") with pytest.raises(NameNotFoundInDataStore) as e: ds.get("number3", "general") assert "Couldn't find number3 in data store" in e.value.args[0] def test_search(self, ds): - ds.put("number", 22, "general") - ds.put("number", 22, "general2") - ds.put("number", 22, "general.sub") + ds.set("number", 22, "general") + ds.set("number", 22, "general2") + ds.set("number", 22, "general.sub") assert ds.search_name("number") == ["general", "general.sub", "general2"] def test_raise_not_in_scope(self, ds): - ds.put("number", 11, "general.sub") + ds.set("number", 11, "general.sub") with pytest.raises(NameNotFoundInScope) as e: ds.get("number", "general.sub2") assert "Couldn't find number in scope general.sub2 . number is only defined in ['general.sub']" in e.value.args[0] def test_list_all_scopes(self, ds): - ds.put("number", 22, "general2") - ds.put("number", 11, "general.sub") - ds.put("number2", 2, "general.sub") - ds.put("number", 3, "general.sub3") - ds.put("number", 1, "general") + ds.set("number", 22, "general2") + ds.set("number", 11, "general.sub") + ds.set("number2", 2, "general.sub") + ds.set("number", 3, "general.sub3") + ds.set("number", 1, "general") assert ds.list_all_scopes() == ['general', 'general.sub', 'general.sub3', 'general2'] def test_search_scope(self, ds): - ds.put("number", 22, "general") - ds.put("number", 11, "general.sub") - ds.put("number1", 22, "general.sub") - ds.put("number2", 3, "general.sub.sub") + ds.set("number", 22, "general") + ds.set("number", 11, "general.sub") + ds.set("number1", 22, "general.sub") + ds.set("number2", 3, "general.sub.sub") assert ds.search_scope("general.sub") == ["number", "number1"] def test_search_empty_scope(self, ds): - ds.put("number", 22, "general2") - ds.put("number", 11, "general.sub") + ds.set("number", 22, "general2") + ds.set("number", 11, "general.sub") with pytest.raises(EmptyScope) as e: ds.search_scope("general.sub2") assert "Given scope general.sub2 is not part of the data store." in e.value.args[0] assert "Available scopes are: ['general.sub', 'general2']" in e.value.args[0] def test_list_all_names(self, ds): - ds.put("number", 22, "general") - ds.put("number", 11, "general.sub") - ds.put("number1", 22, "general.sub") - ds.put("number2", 3, "general.sub.sub") + ds.set("number", 22, "general") + ds.set("number", 11, "general.sub") + ds.set("number1", 22, "general.sub") + ds.set("number2", 3, "general.sub.sub") assert ds.list_all_names() == ["number", "number1", "number2"] def test_search_scope_and_all_superiors(self, ds): - ds.put("number", 22, "general") - ds.put("number", 11, "general.sub") - ds.put("number1", 22, "general.sub") - ds.put("number2", 3, "general.sub.sub") + ds.set("number", 22, "general") + ds.set("number", 11, "general.sub") + ds.set("number1", 22, "general.sub") + ds.set("number2", 3, "general.sub.sub") assert ds.search_scope("general.sub", current_scope_only=False) == ["number", "number1"] assert ds.search_scope("general.sub.sub", current_scope_only=False) == ["number", "number1", "number2"] def test_search_scope_return_all(self, ds): - ds.put("number", 22, "general") - ds.put("number", 11, "general.sub") - ds.put("number1", 22, "general.sub") - ds.put("number2", 3, "general.sub.sub") + ds.set("number", 22, "general") + ds.set("number", 11, "general.sub") + ds.set("number1", 22, "general.sub") + ds.set("number2", 3, "general.sub.sub") assert ds.search_scope("general.sub", return_all=True) == [("number", "general.sub", 11), ("number1", "general.sub", 22)] def test_search_scope_and_all_superiors_return_all(self, ds): - ds.put("number", 22, "general") - ds.put("number", 11, "general.sub") - ds.put("number1", 22, "general.sub") - ds.put("number2", 3, "general.sub.sub") - ds.put("number", "ABC", "general.sub.sub") + ds.set("number", 22, "general") + ds.set("number", 11, "general.sub") + ds.set("number1", 22, "general.sub") + ds.set("number2", 3, "general.sub.sub") + ds.set("number", "ABC", "general.sub.sub") assert ds.search_scope("general.sub", current_scope_only=False, return_all=True) == \ [("number", "general.sub", 11), ("number1", "general.sub", 22)] assert ds.search_scope("general.sub.sub", current_scope_only=False, return_all=True) == \ @@ -126,95 +126,95 @@ class TestDataStoreByScope: return DataStoreByScope() def test_put_with_scope(self, ds): - ds.put("number", 3, "general.subscope") + ds.set("number", 3, "general.subscope") assert ds._store["general.subscope"]["number"] == 3 def test_get(self, ds): - ds.put("number", 3, "general.subscope") + ds.set("number", 3, "general.subscope") assert ds.get("number", "general.subscope") == 3 def test_get_with_sub_scope(self, ds): - ds.put("number", 3, "general") - ds.put("number", 10, "general.subscope") + ds.set("number", 3, "general") + ds.set("number", 10, "general.subscope") assert ds.get("number", "general.subscope") == 10 assert ds.get("number", "general") == 3 def test_get_with_not_existing_sub_scope(self, ds): - ds.put("number", 3, "general") - ds.put("number2", 10, "general.subscope") - ds.put("number2", 1, "general") + ds.set("number", 3, "general") + ds.set("number2", 10, "general.subscope") + ds.set("number2", 1, "general") assert ds.get("number", "general.subscope") == 3 def test_raise_not_in_data_store(self, ds): - ds.put("number", 22, "general") + ds.set("number", 22, "general") with pytest.raises(NameNotFoundInDataStore) as e: ds.get("number3", "general") assert "Couldn't find number3 in data store" in e.value.args[0] def test_search(self, ds): - ds.put("number", 22, "general") - ds.put("number", 22, "general2") - ds.put("number", 22, "general.sub") + ds.set("number", 22, "general") + ds.set("number", 22, "general2") + ds.set("number", 22, "general.sub") assert ds.search_name("number") == ["general", "general.sub", "general2"] def test_raise_not_in_scope(self, ds): - ds.put("number", 11, "general.sub") + ds.set("number", 11, "general.sub") with pytest.raises(NameNotFoundInScope) as e: ds.get("number", "general.sub2") assert "Couldn't find number in scope general.sub2 . number is only defined in ['general.sub']" in e.value.args[0] def test_list_all_scopes(self, ds): - ds.put("number", 22, "general2") - ds.put("number", 11, "general.sub") - ds.put("number2", 2, "general.sub") - ds.put("number", 3, "general.sub3") - ds.put("number", 1, "general") + ds.set("number", 22, "general2") + ds.set("number", 11, "general.sub") + ds.set("number2", 2, "general.sub") + ds.set("number", 3, "general.sub3") + ds.set("number", 1, "general") assert ds.list_all_scopes() == ['general', 'general.sub', 'general.sub3', 'general2'] def test_search_scope(self, ds): - ds.put("number", 22, "general") - ds.put("number", 11, "general.sub") - ds.put("number1", 22, "general.sub") - ds.put("number2", 3, "general.sub.sub") + ds.set("number", 22, "general") + ds.set("number", 11, "general.sub") + ds.set("number1", 22, "general.sub") + ds.set("number2", 3, "general.sub.sub") assert ds.search_scope("general.sub") == ["number", "number1"] def test_search_empty_scope(self, ds): - ds.put("number", 22, "general2") - ds.put("number", 11, "general.sub") + ds.set("number", 22, "general2") + ds.set("number", 11, "general.sub") with pytest.raises(EmptyScope) as e: ds.search_scope("general.sub2") assert "Given scope general.sub2 is not part of the data store." in e.value.args[0] assert "Available scopes are: ['general.sub', 'general2']" in e.value.args[0] def test_list_all_names(self, ds): - ds.put("number", 22, "general") - ds.put("number", 11, "general.sub") - ds.put("number1", 22, "general.sub") - ds.put("number2", 3, "general.sub.sub") + ds.set("number", 22, "general") + ds.set("number", 11, "general.sub") + ds.set("number1", 22, "general.sub") + ds.set("number2", 3, "general.sub.sub") assert ds.list_all_names() == ["number", "number1", "number2"] def test_search_scope_and_all_superiors(self, ds): - ds.put("number", 22, "general") - ds.put("number", 11, "general.sub") - ds.put("number1", 22, "general.sub") - ds.put("number2", 3, "general.sub.sub") + ds.set("number", 22, "general") + ds.set("number", 11, "general.sub") + ds.set("number1", 22, "general.sub") + ds.set("number2", 3, "general.sub.sub") assert ds.search_scope("general.sub", current_scope_only=False) == ["number", "number1"] assert ds.search_scope("general.sub.sub", current_scope_only=False) == ["number", "number1", "number2"] def test_search_scope_return_all(self, ds): - ds.put("number", 22, "general") - ds.put("number", 11, "general.sub") - ds.put("number1", 22, "general.sub") - ds.put("number2", 3, "general.sub.sub") + ds.set("number", 22, "general") + ds.set("number", 11, "general.sub") + ds.set("number1", 22, "general.sub") + ds.set("number2", 3, "general.sub.sub") assert ds.search_scope("general.sub", return_all=True) == [("number", "general.sub", 11), ("number1", "general.sub", 22)] def test_search_scope_and_all_superiors_return_all(self, ds): - ds.put("number", 22, "general") - ds.put("number", 11, "general.sub") - ds.put("number1", 22, "general.sub") - ds.put("number2", 3, "general.sub.sub") - ds.put("number", "ABC", "general.sub.sub") + ds.set("number", 22, "general") + ds.set("number", 11, "general.sub") + ds.set("number1", 22, "general.sub") + ds.set("number2", 3, "general.sub.sub") + ds.set("number", "ABC", "general.sub.sub") assert ds.search_scope("general.sub", current_scope_only=False, return_all=True) == \ [("number", "general.sub", 11), ("number1", "general.sub", 22)] assert ds.search_scope("general.sub.sub", current_scope_only=False, return_all=True) == \ diff --git a/test/test_model_modules/test_model_class.py b/test/test_model_modules/test_model_class.py index 0c05e8bf..d370dea5 100644 --- a/test/test_model_modules/test_model_class.py +++ b/test/test_model_modules/test_model_class.py @@ -11,19 +11,19 @@ class TestAbstractModelClass: return AbstractModelClass() def test_init(self, amc): - assert amc._model is None - assert amc._loss is None + assert amc.__model is None + assert amc.__loss is None def test_model_property(self, amc): - amc._model = keras.Model() + amc.__model = keras.Model() assert isinstance(amc.model, keras.Model) is True def test_loss_property(self, amc): - amc._loss = keras.losses.mean_absolute_error + amc.__loss = keras.losses.mean_absolute_error assert amc.loss == keras.losses.mean_absolute_error def test_getattr(self, amc): - amc._model = keras.Model() + amc.__model = keras.Model() assert hasattr(amc, "compile") is True assert hasattr(amc.model, "compile") is True assert amc.compile == amc.model.compile diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py index 85cb24e3..95a242b0 100644 --- a/test/test_modules/test_model_setup.py +++ b/test/test_modules/test_model_setup.py @@ -25,9 +25,9 @@ class TestModelSetup: @pytest.fixture def setup_with_gen(self, setup, gen): - setup.data_store.put("generator", gen, "general.train") - setup.data_store.put("window_history_size", gen.window_history_size, "general") - setup.data_store.put("window_lead_time", gen.window_lead_time, "general") + setup.data_store.set("generator", gen, "general.train") + setup.data_store.set("window_history_size", gen.window_history_size, "general") + setup.data_store.set("window_lead_time", gen.window_lead_time, "general") yield setup RunEnvironment().__del__() diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py index c884b146..34c27ff1 100644 --- a/test/test_modules/test_pre_processing.py +++ b/test/test_modules/test_pre_processing.py @@ -15,11 +15,11 @@ class TestPreProcessing: def obj_super_init(self): obj = object.__new__(PreProcessing) super(PreProcessing, obj).__init__() - obj.data_store.put("NAME1", 1, "general") - obj.data_store.put("NAME2", 2, "general") - obj.data_store.put("NAME3", 3, "general") - obj.data_store.put("NAME1", 10, "general.sub") - obj.data_store.put("NAME4", 4, "general.sub.sub") + obj.data_store.set("NAME1", 1, "general") + obj.data_store.set("NAME2", 2, "general") + obj.data_store.set("NAME3", 3, "general") + obj.data_store.set("NAME1", 10, "general.sub") + obj.data_store.set("NAME4", 4, "general.sub.sub") yield obj RunEnvironment().__del__() @@ -58,7 +58,7 @@ class TestPreProcessing: def test_create_set_split_not_all_stations(self, caplog, obj_with_exp_setup): caplog.set_level(logging.DEBUG) - obj_with_exp_setup.data_store.put("use_all_stations_on_all_data_sets", False, "general.awesome") + obj_with_exp_setup.data_store.set("use_all_stations_on_all_data_sets", False, "general.awesome") obj_with_exp_setup.create_set_split(slice(0, 2), "awesome") assert caplog.record_tuples[0] == ('root', 10, "Awesome stations (len=2): ['DEBW107', 'DEBY081']") data_store = obj_with_exp_setup.data_store diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py index ddb301c6..e6e3571d 100644 --- a/test/test_modules/test_training.py +++ b/test/test_modules/test_training.py @@ -49,15 +49,15 @@ class TestTraining: obj.checkpoint = checkpoint obj.lr_sc = LearningRateDecay() obj.experiment_name = "TestExperiment" - obj.data_store.put("generator", mock.MagicMock(return_value="mock_train_gen"), "general.train") - obj.data_store.put("generator", mock.MagicMock(return_value="mock_val_gen"), "general.val") - obj.data_store.put("generator", mock.MagicMock(return_value="mock_test_gen"), "general.test") + obj.data_store.set("generator", mock.MagicMock(return_value="mock_train_gen"), "general.train") + obj.data_store.set("generator", mock.MagicMock(return_value="mock_val_gen"), "general.val") + obj.data_store.set("generator", mock.MagicMock(return_value="mock_test_gen"), "general.test") os.makedirs(path) - obj.data_store.put("experiment_path", path, "general") - obj.data_store.put("experiment_name", "TestExperiment", "general") + obj.data_store.set("experiment_path", path, "general") + obj.data_store.set("experiment_name", "TestExperiment", "general") path_plot = os.path.join(path, "plots") os.makedirs(path_plot) - obj.data_store.put("plot_path", path_plot, "general") + obj.data_store.set("plot_path", path_plot, "general") yield obj if os.path.exists(path): shutil.rmtree(path) @@ -112,9 +112,9 @@ class TestTraining: @pytest.fixture def ready_to_run(self, generator, init_without_run): obj = init_without_run - obj.data_store.put("generator", generator, "general.train") - obj.data_store.put("generator", generator, "general.val") - obj.data_store.put("generator", generator, "general.test") + obj.data_store.set("generator", generator, "general.train") + obj.data_store.set("generator", generator, "general.val") + obj.data_store.set("generator", generator, "general.test") obj.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error) return obj @@ -122,20 +122,20 @@ class TestTraining: def ready_to_init(self, generator, model, checkpoint, path): os.makedirs(path) obj = RunEnvironment() - obj.data_store.put("generator", generator, "general.train") - obj.data_store.put("generator", generator, "general.val") - obj.data_store.put("generator", generator, "general.test") + obj.data_store.set("generator", generator, "general.train") + obj.data_store.set("generator", generator, "general.val") + obj.data_store.set("generator", generator, "general.test") model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error) - obj.data_store.put("model", model, "general.model") - obj.data_store.put("batch_size", 256, "general.model") - obj.data_store.put("epochs", 2, "general.model") - obj.data_store.put("checkpoint", checkpoint, "general.model") - obj.data_store.put("lr_decay", LearningRateDecay(), "general.model") - obj.data_store.put("experiment_name", "TestExperiment", "general") - obj.data_store.put("experiment_path", path, "general") + obj.data_store.set("model", model, "general.model") + obj.data_store.set("batch_size", 256, "general.model") + obj.data_store.set("epochs", 2, "general.model") + obj.data_store.set("checkpoint", checkpoint, "general.model") + obj.data_store.set("lr_decay", LearningRateDecay(), "general.model") + obj.data_store.set("experiment_name", "TestExperiment", "general") + obj.data_store.set("experiment_path", path, "general") path_plot = os.path.join(path, "plots") os.makedirs(path_plot) - obj.data_store.put("plot_path", path_plot, "general") + obj.data_store.set("plot_path", path_plot, "general") yield obj if os.path.exists(path): shutil.rmtree(path) -- GitLab