Skip to content
Snippets Groups Projects
Commit 3ebbabba authored by lukas leufen's avatar lukas leufen
Browse files

include bug fix, /close #154

parents 95226859 3ebd31de
Branches
Tags
3 merge requests!146Develop,!145Resolve "new release v0.12.0",!138Resolve "Advanced Documentation"
......@@ -23,13 +23,13 @@ How to create a customised model?
class MyCustomisedModel(AbstractModelClass):
def __init__(self, window_history_size, window_lead_time, channels):
super.__init__()
def __init__(self, shape_inputs: list, shape_outputs: list):
super().__init__(shape_inputs[0], shape_outputs[0])
# settings
self.window_history_size = window_history_size
self.window_lead_time = window_lead_time
self.channels = channels
self.dropout_rate = 0.1
self.activation = keras.layers.PReLU
# apply to model
self.set_model()
......@@ -49,14 +49,14 @@ How to create a customised model?
class MyCustomisedModel(AbstractModelClass):
def set_model(self):
x_input = keras.layers.Input(shape=(self.window_history_size + 1, 1, self.channels))
x_input = keras.layers.Input(shape=self.shape_inputs)
x_in = keras.layers.Conv2D(32, (1, 1), padding='same', name='{}_Conv_1x1'.format("major"))(x_input)
x_in = self.activation(name='{}_conv_act'.format("major"))(x_in)
x_in = keras.layers.Flatten(name='{}'.format("major"))(x_in)
x_in = keras.layers.Dropout(self.dropout_rate, name='{}_Dropout_1'.format("major"))(x_in)
x_in = keras.layers.Dense(16, name='{}_Dense_16'.format("major"))(x_in)
x_in = self.activation()(x_in)
x_in = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("major"))(x_in)
x_in = keras.layers.Dense(self.shape_outputs, name='{}_Dense'.format("major"))(x_in)
out_main = self.activation()(x_in)
self.model = keras.Model(inputs=x_input, outputs=[out_main])
......@@ -153,6 +153,7 @@ class AbstractModelClass(ABC):
'target_tensors': None
}
self.__compile_options = self.__allowed_compile_options
self.__compile_options_is_set = False
self.shape_inputs = shape_inputs
self.shape_outputs = self.__extract_from_tuple(shape_outputs)
......@@ -204,7 +205,8 @@ class AbstractModelClass(ABC):
def compile_options(self) -> Callable:
"""
The compile options property allows the user to use all keras.compile() arguments. They can ether be passed as
dictionary (1), as attribute, with compile_options=None (2) or as mixture of both of them (3).
dictionary (1), as attribute, without setting compile_options (2) or as mixture (partly defined as instance
attributes and partly parsing a dictionary) of both of them (3).
The method will raise an Error when the same parameter is set differently.
Example (1) Recommended (includes check for valid keywords which are used as args in keras.compile)
......@@ -220,7 +222,6 @@ class AbstractModelClass(ABC):
self.optimizer = keras.optimizers.SGD()
self.loss = keras.losses.mean_squared_error
self.metrics = ["mse", "mae"]
self.compile_options = None # make sure to use this line
Example (3)
Correct:
......@@ -245,6 +246,8 @@ class AbstractModelClass(ABC):
:return:
"""
if self.__compile_options_is_set is False:
self.compile_options = None
return self.__compile_options
@compile_options.setter
......@@ -274,6 +277,7 @@ class AbstractModelClass(ABC):
else:
raise ValueError(
f"Got different values or arguments for same argument: self.{allow_k}={new_v_attr.__class__} and '{allow_k}': {new_v_dic.__class__}")
self.__compile_options_is_set = True
@staticmethod
def __extract_from_tuple(tup):
......@@ -282,6 +286,11 @@ class AbstractModelClass(ABC):
@staticmethod
def __compare_keras_optimizers(first, second):
"""
Compares if optimiser and all settings of the optimisers are exactly equal.
:return True if optimisers are interchangeable, or False if optimisers are distinguishable.
"""
if first.__class__ == second.__class__ and first.__module__ == 'keras.optimizers':
res = True
init = tf.global_variables_initializer()
......@@ -688,3 +697,8 @@ class MyPaperModel(AbstractModelClass):
self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
self.compile_options = {"loss": [keras.losses.mean_squared_error, keras.losses.mean_squared_error],
"metrics": ['mse', 'mea']}
if __name__ == "__main__":
model = MyLittleModel([(1, 3, 10)], [2])
print(model.compile_options)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment