diff --git a/src/datastore.py b/src/datastore.py index 5f0df67573dd510fdc4f04d0cc632b36c5082959..d14ae07d70f44b6c62987ee8c3fcbce8eee0ce46 100644 --- a/src/datastore.py +++ b/src/datastore.py @@ -38,7 +38,7 @@ class AbstractDataStore(ABC): # empty initialise the data-store variables self._store: Dict = {} - def put(self, name: str, obj: Any, scope: str) -> None: + def set(self, name: str, obj: Any, scope: str) -> None: """ Abstract method to add an object to the data store :param name: Name of object to store @@ -89,7 +89,7 @@ class AbstractDataStore(ABC): def clear_data_store(self) -> None: self._store = {} - def create_args_dict(self, arg_list: List[str], scope: str = "general"): + def create_args_dict(self, arg_list: List[str], scope: str = "general") -> Dict: args = {} for arg in arg_list: try: @@ -98,6 +98,10 @@ class AbstractDataStore(ABC): pass return args + def set_args_from_dict(self, arg_dict: Dict, scope: str = "general") -> None: + for (k, v) in arg_dict.items(): + self.set(k, v, scope) + class DataStoreByVariable(AbstractDataStore): @@ -115,7 +119,7 @@ class DataStoreByVariable(AbstractDataStore): <scope3>: value """ - def put(self, name: str, obj: Any, scope: str) -> None: + def set(self, name: str, obj: Any, scope: str) -> None: """ Store an object `obj` with given `name` under `scope`. In the current implementation, existing entries are overwritten. @@ -239,7 +243,7 @@ class DataStoreByScope(AbstractDataStore): <variable3>: value """ - def put(self, name: str, obj: Any, scope: str) -> None: + def set(self, name: str, obj: Any, scope: str) -> None: """ Store an object `obj` with given `name` under `scope`. In the current implementation, existing entries are overwritten. diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py new file mode 100644 index 0000000000000000000000000000000000000000..1a8f7c4c400eaf75bdd1dc6af2e0993f662eac49 --- /dev/null +++ b/src/model_modules/model_class.py @@ -0,0 +1,156 @@ +__author__ = "Lukas Leufen" +__date__ = '2019-12-12' + + +from abc import ABC +from typing import Any, Callable + +import keras + +from src import helpers + + +class AbstractModelClass(ABC): + + """ + The AbstractModelClass provides a unified skeleton for any model provided to the machine learning workflow. The + model can always be accessed by calling ModelClass.model or directly by an model method without parsing the model + attribute name (e.g. ModelClass.model.compile -> ModelClass.compile). Beside the model, this class provides the + corresponding loss function. + """ + + def __init__(self) -> None: + + """ + Predefine internal attributes for model and loss. + """ + + self.__model = None + self.__loss = None + + def __getattr__(self, name: str) -> Any: + + """ + Is called if __getattribute__ is not able to find requested attribute. Normally, the model class is saved into + a variable like `model = ModelClass()`. To bypass a call like `model.model` to access the _model attribute, + this method tries to search for the named attribute in the self.model namespace and returns this attribute if + available. Therefore, following expression is true: `ModelClass().compile == ModelClass().model.compile` as long + the called attribute/method is not part if the ModelClass itself. + :param name: name of the attribute or method to call + :return: attribute or method from self.model namespace + """ + + return self.model.__getattribute__(name) + + @property + def model(self) -> keras.Model: + + """ + The model property containing a keras.Model instance. + :return: the keras model + """ + + return self.__model + + @model.setter + def model(self, value): + self.__model = value + + @property + def loss(self) -> Callable: + + """ + The loss property containing a callable loss function. The loss function can be any keras loss or a customised + function. If the loss is a customised function, it must contain the internal loss(y_true, y_pred) function: + def customised_loss(args): + def loss(y_true, y_pred): + return actual_function(y_true, y_pred, args) + return loss + :return: the loss function + """ + + return self.__loss + + @loss.setter + def loss(self, value) -> None: + self.__loss = value + + def get_settings(self): + return dict((k, v) for (k, v) in self.__dict__.items() if not k.startswith("_AbstractModelClass__")) + + +class MyLittleModel(AbstractModelClass): + + """ + A customised model with a 1x1 Conv, and 4 Dense layers (64, 32, 16, window_lead_time), where the last layer is the + output layer depending on the window_lead_time parameter. Dropout is used between the Convolution and the first + Dense layer. + """ + + def __init__(self, window_history_size, window_lead_time, channels): + + """ + Sets model and loss depending on the given arguments. + :param activation: activation function + :param window_history_size: number of historical time steps included in the input data + :param channels: number of variables used in input data + :param regularizer: <not used here> + :param dropout_rate: dropout rate used in the model [0, 1) + :param window_lead_time: number of time steps to forecast in the output layer + """ + + super().__init__() + + # settings + self.window_history_size = window_history_size + self.window_lead_time = window_lead_time + self.channels = channels + self.dropout_rate = 0.1 + self.regularizer = keras.regularizers.l2(0.1) + self.initial_lr = 1e-2 + self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9) + self.lr_decay = helpers.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10) + self.epochs = 2 + self.batch_size = int(256) + self.activation = keras.layers.PReLU + + # apply to model + self.set_model() + self.set_loss() + + def set_model(self): + + """ + Build the model. + :param activation: activation function + :param window_history_size: number of historical time steps included in the input data + :param channels: number of variables used in input data + :param dropout_rate: dropout rate used in the model [0, 1) + :param window_lead_time: number of time steps to forecast in the output layer + :return: built keras model + """ + + # add 1 to window_size to include current time step t0 + x_input = keras.layers.Input(shape=(self.window_history_size + 1, 1, self.channels)) + x_in = keras.layers.Conv2D(32, (1, 1), padding='same', name='{}_Conv_1x1'.format("major"))(x_input) + x_in = self.activation(name='{}_conv_act'.format("major"))(x_in) + x_in = keras.layers.Flatten(name='{}'.format("major"))(x_in) + x_in = keras.layers.Dropout(self.dropout_rate, name='{}_Dropout_1'.format("major"))(x_in) + x_in = keras.layers.Dense(64, name='{}_Dense_64'.format("major"))(x_in) + x_in = self.activation()(x_in) + x_in = keras.layers.Dense(32, name='{}_Dense_32'.format("major"))(x_in) + x_in = self.activation()(x_in) + x_in = keras.layers.Dense(16, name='{}_Dense_16'.format("major"))(x_in) + x_in = self.activation()(x_in) + x_in = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("major"))(x_in) + out_main = self.activation()(x_in) + self.model = keras.Model(inputs=x_input, outputs=[out_main]) + + def set_loss(self): + + """ + Set the loss + :return: loss function + """ + + self.loss = keras.losses.mean_squared_error diff --git a/src/modules/experiment_setup.py b/src/modules/experiment_setup.py index 5fdc1f1f478f15619d7e21726e80699b0497d695..f726c66e116d3bf978281805915f571a50f0cf2f 100644 --- a/src/modules/experiment_setup.py +++ b/src/modules/experiment_setup.py @@ -93,7 +93,7 @@ class ExperimentSetup(RunEnvironment): def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None: if value is None and default is not None: value = default - self.data_store.put(param, value, scope) + self.data_store.set(param, value, scope) logging.debug(f"set experiment attribute: {param}({scope})={value}") @staticmethod diff --git a/src/modules/model_setup.py b/src/modules/model_setup.py index f6c25aff51372f84ede5cda0884fd7c603ffaa6b..a62b53b86651109c4c1dd10d4a7dfccbaf3cf9c2 100644 --- a/src/modules/model_setup.py +++ b/src/modules/model_setup.py @@ -15,6 +15,7 @@ from src.modules.run_environment import RunEnvironment from src.helpers import l_p_loss, LearningRateDecay from src.inception_model import InceptionModelBase from src.flatten import flatten_tail +from src.model_modules.model_class import MyLittleModel class ModelSetup(RunEnvironment): @@ -35,8 +36,8 @@ class ModelSetup(RunEnvironment): # create checkpoint self._set_checkpoint() - # set all model settings - self.my_model_settings() + # set channels depending on inputs + self._set_channels() # build model graph using settings from my_model_settings() self.build_model() @@ -51,15 +52,19 @@ class ModelSetup(RunEnvironment): # compile model self.compile_model() + def _set_channels(self): + channels = self.data_store.get("generator", "general.train")[0][0].shape[-1] + self.data_store.set("channels", channels, self.scope) + def compile_model(self): optimizer = self.data_store.get("optimizer", self.scope) - loss = self.data_store.get("loss", self.scope) + loss = self.model.loss self.model.compile(optimizer=optimizer, loss=loss, metrics=["mse", "mae"]) - self.data_store.put("model", self.model, self.scope) + self.data_store.set("model", self.model, self.scope) def _set_checkpoint(self): checkpoint = ModelCheckpoint(self.checkpoint_name, verbose=1, monitor='val_loss', save_best_only=True, mode='auto') - self.data_store.put("checkpoint", checkpoint, self.scope) + self.data_store.set("checkpoint", checkpoint, self.scope) def load_weights(self): try: @@ -69,9 +74,14 @@ class ModelSetup(RunEnvironment): logging.info('no weights to reload...') def build_model(self): - args_list = ["activation", "window_history_size", "channels", "regularizer", "dropout_rate", "window_lead_time"] + args_list = ["window_history_size", "window_lead_time", "channels"] args = self.data_store.create_args_dict(args_list, self.scope) - self.model = my_little_model(**args) + self.model = MyLittleModel(**args) + self.get_model_settings() + + def get_model_settings(self): + model_settings = self.model.get_settings() + self.data_store.set_args_from_dict(model_settings, self.scope) def plot_model(self): # pragma: no cover with tf.device("/cpu:0"): @@ -80,39 +90,6 @@ class ModelSetup(RunEnvironment): file_name = os.path.join(path, name) keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True) - def my_model_settings(self): - - # channels - X, _ = self.data_store.get("generator", "general.train")[0] - channels = X.shape[-1] # input variables - self.data_store.put("channels", channels, self.scope) - - # dropout - self.data_store.put("dropout_rate", 0.1, self.scope) - - # regularizer - self.data_store.put("regularizer", l2(0.1), self.scope) - - # learning rate - initial_lr = 1e-2 - self.data_store.put("initial_lr", initial_lr, self.scope) - optimizer = SGD(lr=initial_lr, momentum=0.9) - # optimizer=Adam(lr=initial_lr, amsgrad=True) - self.data_store.put("optimizer", optimizer, self.scope) - self.data_store.put("lr_decay", LearningRateDecay(base_lr=initial_lr, drop=.94, epochs_drop=10), self.scope) - - # learning settings - self.data_store.put("epochs", 2, self.scope) - self.data_store.put("batch_size", int(256), self.scope) - - # activation - activation = keras.layers.PReLU # ELU #LeakyReLU keras.activations.tanh # - self.data_store.put("activation", activation, self.scope) - - # set los - loss_all = my_little_loss() - self.data_store.put("loss", loss_all, self.scope) - def my_loss(): loss = l_p_loss(4) diff --git a/src/modules/pre_processing.py b/src/modules/pre_processing.py index 8fad9d1bf756baf830c236f16102878ba83515c2..cce9ee587c7a9b70b9cce8064cb4b77aa1bf3386 100644 --- a/src/modules/pre_processing.py +++ b/src/modules/pre_processing.py @@ -36,7 +36,7 @@ class PreProcessing(RunEnvironment): args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST) kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST) valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general")) - self.data_store.put("stations", valid_stations, "general") + self.data_store.set("stations", valid_stations, "general") self.split_train_val_test() def report_pre_processing(self): @@ -87,10 +87,10 @@ class PreProcessing(RunEnvironment): set_stations = stations[index_list] logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}") set_stations = self.check_valid_stations(args, kwargs, set_stations) - self.data_store.put("stations", set_stations, scope) + self.data_store.set("stations", set_stations, scope) set_args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope) data_set = DataGenerator(**set_args, **kwargs) - self.data_store.put("generator", data_set, scope) + self.data_store.set("generator", data_set, scope) @staticmethod def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str]): diff --git a/test/test_datastore.py b/test/test_datastore.py index 5ba76bf8fc9c21553723cba3b2125be2d758e23b..3f61c227be3a05d78f825eca77b0d6cbbc617ce1 100644 --- a/test/test_datastore.py +++ b/test/test_datastore.py @@ -24,95 +24,95 @@ class TestDataStoreByVariable: return DataStoreByVariable() def test_put(self, ds): - ds.put("number", 3, "general.subscope") + ds.set("number", 3, "general.subscope") assert ds._store["number"]["general.subscope"] == 3 def test_get(self, ds): - ds.put("number", 3, "general.subscope") + ds.set("number", 3, "general.subscope") assert ds.get("number", "general.subscope") == 3 def test_get_with_sub_scope(self, ds): - ds.put("number", 3, "general") - ds.put("number", 10, "general.subscope") + ds.set("number", 3, "general") + ds.set("number", 10, "general.subscope") assert ds.get("number", "general.subscope") == 10 assert ds.get("number", "general") == 3 def test_get_with_not_existing_sub_scope(self, ds): - ds.put("number", 3, "general") - ds.put("number2", 10, "general.subscope") - ds.put("number2", 1, "general") + ds.set("number", 3, "general") + ds.set("number2", 10, "general.subscope") + ds.set("number2", 1, "general") assert ds.get("number", "general.subscope") == 3 def test_raise_not_in_data_store(self, ds): - ds.put("number", 22, "general") + ds.set("number", 22, "general") with pytest.raises(NameNotFoundInDataStore) as e: ds.get("number3", "general") assert "Couldn't find number3 in data store" in e.value.args[0] def test_search(self, ds): - ds.put("number", 22, "general") - ds.put("number", 22, "general2") - ds.put("number", 22, "general.sub") + ds.set("number", 22, "general") + ds.set("number", 22, "general2") + ds.set("number", 22, "general.sub") assert ds.search_name("number") == ["general", "general.sub", "general2"] def test_raise_not_in_scope(self, ds): - ds.put("number", 11, "general.sub") + ds.set("number", 11, "general.sub") with pytest.raises(NameNotFoundInScope) as e: ds.get("number", "general.sub2") assert "Couldn't find number in scope general.sub2 . number is only defined in ['general.sub']" in e.value.args[0] def test_list_all_scopes(self, ds): - ds.put("number", 22, "general2") - ds.put("number", 11, "general.sub") - ds.put("number2", 2, "general.sub") - ds.put("number", 3, "general.sub3") - ds.put("number", 1, "general") + ds.set("number", 22, "general2") + ds.set("number", 11, "general.sub") + ds.set("number2", 2, "general.sub") + ds.set("number", 3, "general.sub3") + ds.set("number", 1, "general") assert ds.list_all_scopes() == ['general', 'general.sub', 'general.sub3', 'general2'] def test_search_scope(self, ds): - ds.put("number", 22, "general") - ds.put("number", 11, "general.sub") - ds.put("number1", 22, "general.sub") - ds.put("number2", 3, "general.sub.sub") + ds.set("number", 22, "general") + ds.set("number", 11, "general.sub") + ds.set("number1", 22, "general.sub") + ds.set("number2", 3, "general.sub.sub") assert ds.search_scope("general.sub") == ["number", "number1"] def test_search_empty_scope(self, ds): - ds.put("number", 22, "general2") - ds.put("number", 11, "general.sub") + ds.set("number", 22, "general2") + ds.set("number", 11, "general.sub") with pytest.raises(EmptyScope) as e: ds.search_scope("general.sub2") assert "Given scope general.sub2 is not part of the data store." in e.value.args[0] assert "Available scopes are: ['general.sub', 'general2']" in e.value.args[0] def test_list_all_names(self, ds): - ds.put("number", 22, "general") - ds.put("number", 11, "general.sub") - ds.put("number1", 22, "general.sub") - ds.put("number2", 3, "general.sub.sub") + ds.set("number", 22, "general") + ds.set("number", 11, "general.sub") + ds.set("number1", 22, "general.sub") + ds.set("number2", 3, "general.sub.sub") assert ds.list_all_names() == ["number", "number1", "number2"] def test_search_scope_and_all_superiors(self, ds): - ds.put("number", 22, "general") - ds.put("number", 11, "general.sub") - ds.put("number1", 22, "general.sub") - ds.put("number2", 3, "general.sub.sub") + ds.set("number", 22, "general") + ds.set("number", 11, "general.sub") + ds.set("number1", 22, "general.sub") + ds.set("number2", 3, "general.sub.sub") assert ds.search_scope("general.sub", current_scope_only=False) == ["number", "number1"] assert ds.search_scope("general.sub.sub", current_scope_only=False) == ["number", "number1", "number2"] def test_search_scope_return_all(self, ds): - ds.put("number", 22, "general") - ds.put("number", 11, "general.sub") - ds.put("number1", 22, "general.sub") - ds.put("number2", 3, "general.sub.sub") + ds.set("number", 22, "general") + ds.set("number", 11, "general.sub") + ds.set("number1", 22, "general.sub") + ds.set("number2", 3, "general.sub.sub") assert ds.search_scope("general.sub", return_all=True) == [("number", "general.sub", 11), ("number1", "general.sub", 22)] def test_search_scope_and_all_superiors_return_all(self, ds): - ds.put("number", 22, "general") - ds.put("number", 11, "general.sub") - ds.put("number1", 22, "general.sub") - ds.put("number2", 3, "general.sub.sub") - ds.put("number", "ABC", "general.sub.sub") + ds.set("number", 22, "general") + ds.set("number", 11, "general.sub") + ds.set("number1", 22, "general.sub") + ds.set("number2", 3, "general.sub.sub") + ds.set("number", "ABC", "general.sub.sub") assert ds.search_scope("general.sub", current_scope_only=False, return_all=True) == \ [("number", "general.sub", 11), ("number1", "general.sub", 22)] assert ds.search_scope("general.sub.sub", current_scope_only=False, return_all=True) == \ @@ -126,95 +126,95 @@ class TestDataStoreByScope: return DataStoreByScope() def test_put_with_scope(self, ds): - ds.put("number", 3, "general.subscope") + ds.set("number", 3, "general.subscope") assert ds._store["general.subscope"]["number"] == 3 def test_get(self, ds): - ds.put("number", 3, "general.subscope") + ds.set("number", 3, "general.subscope") assert ds.get("number", "general.subscope") == 3 def test_get_with_sub_scope(self, ds): - ds.put("number", 3, "general") - ds.put("number", 10, "general.subscope") + ds.set("number", 3, "general") + ds.set("number", 10, "general.subscope") assert ds.get("number", "general.subscope") == 10 assert ds.get("number", "general") == 3 def test_get_with_not_existing_sub_scope(self, ds): - ds.put("number", 3, "general") - ds.put("number2", 10, "general.subscope") - ds.put("number2", 1, "general") + ds.set("number", 3, "general") + ds.set("number2", 10, "general.subscope") + ds.set("number2", 1, "general") assert ds.get("number", "general.subscope") == 3 def test_raise_not_in_data_store(self, ds): - ds.put("number", 22, "general") + ds.set("number", 22, "general") with pytest.raises(NameNotFoundInDataStore) as e: ds.get("number3", "general") assert "Couldn't find number3 in data store" in e.value.args[0] def test_search(self, ds): - ds.put("number", 22, "general") - ds.put("number", 22, "general2") - ds.put("number", 22, "general.sub") + ds.set("number", 22, "general") + ds.set("number", 22, "general2") + ds.set("number", 22, "general.sub") assert ds.search_name("number") == ["general", "general.sub", "general2"] def test_raise_not_in_scope(self, ds): - ds.put("number", 11, "general.sub") + ds.set("number", 11, "general.sub") with pytest.raises(NameNotFoundInScope) as e: ds.get("number", "general.sub2") assert "Couldn't find number in scope general.sub2 . number is only defined in ['general.sub']" in e.value.args[0] def test_list_all_scopes(self, ds): - ds.put("number", 22, "general2") - ds.put("number", 11, "general.sub") - ds.put("number2", 2, "general.sub") - ds.put("number", 3, "general.sub3") - ds.put("number", 1, "general") + ds.set("number", 22, "general2") + ds.set("number", 11, "general.sub") + ds.set("number2", 2, "general.sub") + ds.set("number", 3, "general.sub3") + ds.set("number", 1, "general") assert ds.list_all_scopes() == ['general', 'general.sub', 'general.sub3', 'general2'] def test_search_scope(self, ds): - ds.put("number", 22, "general") - ds.put("number", 11, "general.sub") - ds.put("number1", 22, "general.sub") - ds.put("number2", 3, "general.sub.sub") + ds.set("number", 22, "general") + ds.set("number", 11, "general.sub") + ds.set("number1", 22, "general.sub") + ds.set("number2", 3, "general.sub.sub") assert ds.search_scope("general.sub") == ["number", "number1"] def test_search_empty_scope(self, ds): - ds.put("number", 22, "general2") - ds.put("number", 11, "general.sub") + ds.set("number", 22, "general2") + ds.set("number", 11, "general.sub") with pytest.raises(EmptyScope) as e: ds.search_scope("general.sub2") assert "Given scope general.sub2 is not part of the data store." in e.value.args[0] assert "Available scopes are: ['general.sub', 'general2']" in e.value.args[0] def test_list_all_names(self, ds): - ds.put("number", 22, "general") - ds.put("number", 11, "general.sub") - ds.put("number1", 22, "general.sub") - ds.put("number2", 3, "general.sub.sub") + ds.set("number", 22, "general") + ds.set("number", 11, "general.sub") + ds.set("number1", 22, "general.sub") + ds.set("number2", 3, "general.sub.sub") assert ds.list_all_names() == ["number", "number1", "number2"] def test_search_scope_and_all_superiors(self, ds): - ds.put("number", 22, "general") - ds.put("number", 11, "general.sub") - ds.put("number1", 22, "general.sub") - ds.put("number2", 3, "general.sub.sub") + ds.set("number", 22, "general") + ds.set("number", 11, "general.sub") + ds.set("number1", 22, "general.sub") + ds.set("number2", 3, "general.sub.sub") assert ds.search_scope("general.sub", current_scope_only=False) == ["number", "number1"] assert ds.search_scope("general.sub.sub", current_scope_only=False) == ["number", "number1", "number2"] def test_search_scope_return_all(self, ds): - ds.put("number", 22, "general") - ds.put("number", 11, "general.sub") - ds.put("number1", 22, "general.sub") - ds.put("number2", 3, "general.sub.sub") + ds.set("number", 22, "general") + ds.set("number", 11, "general.sub") + ds.set("number1", 22, "general.sub") + ds.set("number2", 3, "general.sub.sub") assert ds.search_scope("general.sub", return_all=True) == [("number", "general.sub", 11), ("number1", "general.sub", 22)] def test_search_scope_and_all_superiors_return_all(self, ds): - ds.put("number", 22, "general") - ds.put("number", 11, "general.sub") - ds.put("number1", 22, "general.sub") - ds.put("number2", 3, "general.sub.sub") - ds.put("number", "ABC", "general.sub.sub") + ds.set("number", 22, "general") + ds.set("number", 11, "general.sub") + ds.set("number1", 22, "general.sub") + ds.set("number2", 3, "general.sub.sub") + ds.set("number", "ABC", "general.sub.sub") assert ds.search_scope("general.sub", current_scope_only=False, return_all=True) == \ [("number", "general.sub", 11), ("number1", "general.sub", 22)] assert ds.search_scope("general.sub.sub", current_scope_only=False, return_all=True) == \ diff --git a/test/test_model_modules/test_model_class.py b/test/test_model_modules/test_model_class.py new file mode 100644 index 0000000000000000000000000000000000000000..0af16012336c9dbcc3133d7dac1f365e276d11bc --- /dev/null +++ b/test/test_model_modules/test_model_class.py @@ -0,0 +1,29 @@ +import pytest +import keras + +from src.model_modules.model_class import AbstractModelClass + + +class TestAbstractModelClass: + + @pytest.fixture + def amc(self): + return AbstractModelClass() + + def test_init(self, amc): + assert amc.model is None + assert amc.loss is None + + def test_model_property(self, amc): + amc.model = keras.Model() + assert isinstance(amc.model, keras.Model) is True + + def test_loss_property(self, amc): + amc.loss = keras.losses.mean_absolute_error + assert amc.loss == keras.losses.mean_absolute_error + + def test_getattr(self, amc): + amc.model = keras.Model() + assert hasattr(amc, "compile") is True + assert hasattr(amc.model, "compile") is True + assert amc.compile == amc.model.compile diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py index 85cb24e3a31af69c8ae9735b4c85173318866339..ca7503040ecf8e45636fddebccdcebd7242dbaec 100644 --- a/test/test_modules/test_model_setup.py +++ b/test/test_modules/test_model_setup.py @@ -1,10 +1,13 @@ import pytest import os import keras +import mock from src.modules.model_setup import ModelSetup from src.modules.run_environment import RunEnvironment from src.data_handling.data_generator import DataGenerator +from src.model_modules.model_class import AbstractModelClass +from src.datastore import EmptyScope class TestModelSetup: @@ -25,29 +28,55 @@ class TestModelSetup: @pytest.fixture def setup_with_gen(self, setup, gen): - setup.data_store.put("generator", gen, "general.train") - setup.data_store.put("window_history_size", gen.window_history_size, "general") - setup.data_store.put("window_lead_time", gen.window_lead_time, "general") + setup.data_store.set("generator", gen, "general.train") + setup.data_store.set("window_history_size", gen.window_history_size, "general") + setup.data_store.set("window_lead_time", gen.window_lead_time, "general") + setup.data_store.set("channels", 2, "general") yield setup RunEnvironment().__del__() + @pytest.fixture + def setup_with_gen_tiny(self, setup, gen): + setup.data_store.set("generator", gen, "general.train") + yield setup + RunEnvironment().__del__() + + @pytest.fixture + def setup_with_model(self, setup): + setup.model = AbstractModelClass() + setup.model.epochs = 2 + setup.model.batch_size = int(256) + yield setup + RunEnvironment().__del__() + + @staticmethod + def current_scope_as_set(model_cls): + return set(model_cls.data_store.search_scope(model_cls.scope, current_scope_only=True)) + def test_set_checkpoint(self, setup): assert "general.modeltest" not in setup.data_store.search_name("checkpoint") setup.checkpoint_name = "TestName" setup._set_checkpoint() assert "general.modeltest" in setup.data_store.search_name("checkpoint") - def test_my_model_settings(self, setup_with_gen): - setup_with_gen.my_model_settings() - expected = {"channels", "dropout_rate", "regularizer", "initial_lr", "optimizer", "lr_decay", "epochs", - "batch_size", "activation", "loss"} - assert expected <= set(setup_with_gen.data_store.search_scope(setup_with_gen.scope, current_scope_only=True)) + def test_get_model_settings(self, setup_with_model): + with pytest.raises(EmptyScope): + self.current_scope_as_set(setup_with_model) # will fail because scope is not created + setup_with_model.get_model_settings() # this saves now the parameters epochs and batch_size into scope + assert {"epochs", "batch_size"} <= self.current_scope_as_set(setup_with_model) def test_build_model(self, setup_with_gen): - setup_with_gen.my_model_settings() assert setup_with_gen.model is None setup_with_gen.build_model() - assert isinstance(setup_with_gen.model, keras.Model) + assert isinstance(setup_with_gen.model, AbstractModelClass) + expected = {"window_history_size", "window_lead_time", "channels", "dropout_rate", "regularizer", "initial_lr", + "optimizer", "lr_decay", "epochs", "batch_size", "activation"} + assert expected <= self.current_scope_as_set(setup_with_gen) + + def test_set_channels(self, setup_with_gen_tiny): + assert len(setup_with_gen_tiny.data_store.search_name("channels")) == 0 + setup_with_gen_tiny._set_channels() + assert setup_with_gen_tiny.data_store.get("channels", setup_with_gen_tiny.scope) == 2 def test_load_weights(self): pass @@ -55,3 +84,9 @@ class TestModelSetup: def test_compile_model(self): pass + def test_run(self): + pass + + def test_init(self): + pass + diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py index c884b14657447a50377dc38ec2dea10ba300f4d7..34c27ff1f08eaa3b223dab5d3bcc6e3cb9a09a97 100644 --- a/test/test_modules/test_pre_processing.py +++ b/test/test_modules/test_pre_processing.py @@ -15,11 +15,11 @@ class TestPreProcessing: def obj_super_init(self): obj = object.__new__(PreProcessing) super(PreProcessing, obj).__init__() - obj.data_store.put("NAME1", 1, "general") - obj.data_store.put("NAME2", 2, "general") - obj.data_store.put("NAME3", 3, "general") - obj.data_store.put("NAME1", 10, "general.sub") - obj.data_store.put("NAME4", 4, "general.sub.sub") + obj.data_store.set("NAME1", 1, "general") + obj.data_store.set("NAME2", 2, "general") + obj.data_store.set("NAME3", 3, "general") + obj.data_store.set("NAME1", 10, "general.sub") + obj.data_store.set("NAME4", 4, "general.sub.sub") yield obj RunEnvironment().__del__() @@ -58,7 +58,7 @@ class TestPreProcessing: def test_create_set_split_not_all_stations(self, caplog, obj_with_exp_setup): caplog.set_level(logging.DEBUG) - obj_with_exp_setup.data_store.put("use_all_stations_on_all_data_sets", False, "general.awesome") + obj_with_exp_setup.data_store.set("use_all_stations_on_all_data_sets", False, "general.awesome") obj_with_exp_setup.create_set_split(slice(0, 2), "awesome") assert caplog.record_tuples[0] == ('root', 10, "Awesome stations (len=2): ['DEBW107', 'DEBY081']") data_store = obj_with_exp_setup.data_store diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py index ddb301c685fa1a5dc46ba37251bc5068fe9c4a37..e6e3571d9beb03671a8e49f7f3988501b5eaa674 100644 --- a/test/test_modules/test_training.py +++ b/test/test_modules/test_training.py @@ -49,15 +49,15 @@ class TestTraining: obj.checkpoint = checkpoint obj.lr_sc = LearningRateDecay() obj.experiment_name = "TestExperiment" - obj.data_store.put("generator", mock.MagicMock(return_value="mock_train_gen"), "general.train") - obj.data_store.put("generator", mock.MagicMock(return_value="mock_val_gen"), "general.val") - obj.data_store.put("generator", mock.MagicMock(return_value="mock_test_gen"), "general.test") + obj.data_store.set("generator", mock.MagicMock(return_value="mock_train_gen"), "general.train") + obj.data_store.set("generator", mock.MagicMock(return_value="mock_val_gen"), "general.val") + obj.data_store.set("generator", mock.MagicMock(return_value="mock_test_gen"), "general.test") os.makedirs(path) - obj.data_store.put("experiment_path", path, "general") - obj.data_store.put("experiment_name", "TestExperiment", "general") + obj.data_store.set("experiment_path", path, "general") + obj.data_store.set("experiment_name", "TestExperiment", "general") path_plot = os.path.join(path, "plots") os.makedirs(path_plot) - obj.data_store.put("plot_path", path_plot, "general") + obj.data_store.set("plot_path", path_plot, "general") yield obj if os.path.exists(path): shutil.rmtree(path) @@ -112,9 +112,9 @@ class TestTraining: @pytest.fixture def ready_to_run(self, generator, init_without_run): obj = init_without_run - obj.data_store.put("generator", generator, "general.train") - obj.data_store.put("generator", generator, "general.val") - obj.data_store.put("generator", generator, "general.test") + obj.data_store.set("generator", generator, "general.train") + obj.data_store.set("generator", generator, "general.val") + obj.data_store.set("generator", generator, "general.test") obj.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error) return obj @@ -122,20 +122,20 @@ class TestTraining: def ready_to_init(self, generator, model, checkpoint, path): os.makedirs(path) obj = RunEnvironment() - obj.data_store.put("generator", generator, "general.train") - obj.data_store.put("generator", generator, "general.val") - obj.data_store.put("generator", generator, "general.test") + obj.data_store.set("generator", generator, "general.train") + obj.data_store.set("generator", generator, "general.val") + obj.data_store.set("generator", generator, "general.test") model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error) - obj.data_store.put("model", model, "general.model") - obj.data_store.put("batch_size", 256, "general.model") - obj.data_store.put("epochs", 2, "general.model") - obj.data_store.put("checkpoint", checkpoint, "general.model") - obj.data_store.put("lr_decay", LearningRateDecay(), "general.model") - obj.data_store.put("experiment_name", "TestExperiment", "general") - obj.data_store.put("experiment_path", path, "general") + obj.data_store.set("model", model, "general.model") + obj.data_store.set("batch_size", 256, "general.model") + obj.data_store.set("epochs", 2, "general.model") + obj.data_store.set("checkpoint", checkpoint, "general.model") + obj.data_store.set("lr_decay", LearningRateDecay(), "general.model") + obj.data_store.set("experiment_name", "TestExperiment", "general") + obj.data_store.set("experiment_path", path, "general") path_plot = os.path.join(path, "plots") os.makedirs(path_plot) - obj.data_store.put("plot_path", path_plot, "general") + obj.data_store.set("plot_path", path_plot, "general") yield obj if os.path.exists(path): shutil.rmtree(path)