diff --git a/mlair/model_modules/abstract_model_class.py b/mlair/model_modules/abstract_model_class.py index 692c7d8c3fcc2afce967c1f1ee380c769bbdff02..6cd79abe2212294095caea60f551d0288d74f431 100644 --- a/mlair/model_modules/abstract_model_class.py +++ b/mlair/model_modules/abstract_model_class.py @@ -139,6 +139,8 @@ class AbstractModelClass(ABC): for allow_k in self.__allowed_compile_options.keys(): if hasattr(self, allow_k): new_v_attr = getattr(self, allow_k) + if new_v_attr == list(): + new_v_attr = None else: new_v_attr = None if isinstance(value, dict): @@ -147,8 +149,10 @@ class AbstractModelClass(ABC): new_v_dic = None else: raise TypeError(f"`compile_options' must be `dict' or `None', but is {type(value)}.") - if (new_v_attr == new_v_dic or self.__compare_keras_optimizers(new_v_attr, new_v_dic)) or ( - (new_v_attr is None) ^ (new_v_dic is None)): + ## self.__compare_keras_optimizers() foremost disabled, because it does not work as expected + #if (new_v_attr == new_v_dic or self.__compare_keras_optimizers(new_v_attr, new_v_dic)) or ( + # (new_v_attr is None) ^ (new_v_dic is None)): + if (new_v_attr == new_v_dic) or ((new_v_attr is None) ^ (new_v_dic is None)): if new_v_attr is not None: self.__compile_options[allow_k] = new_v_attr else: @@ -171,18 +175,22 @@ class AbstractModelClass(ABC): :return True if optimisers are interchangeable, or False if optimisers are distinguishable. """ - if first.__class__ == second.__class__ and first.__module__ == 'keras.optimizers': - res = True - init = tf.compat.v1.global_variables_initializer() - with tf.compat.v1.Session() as sess: - sess.run(init) - for k, v in first.__dict__.items(): - try: - res *= sess.run(v) == sess.run(second.__dict__[k]) - except TypeError: - res *= v == second.__dict__[k] - else: + if isinstance(list, type(second)): res = False + else: + if first.__class__ == second.__class__ and '.'.join( + first.__module__.split('.')[0:4]) == 'tensorflow.python.keras.optimizer_v2': + res = True + init = tf.compat.v1.global_variables_initializer() + with tf.compat.v1.Session() as sess: + sess.run(init) + for k, v in first.__dict__.items(): + try: + res *= sess.run(v) == sess.run(second.__dict__[k]) + except TypeError: + res *= v == second.__dict__[k] + else: + res = False return bool(res) def get_settings(self) -> Dict: diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py index 62c8d19dc0b1560da6452f5e365a5fdb588ed6e0..27dd444531ba253c7bf7ae996bbea7d15318d32e 100644 --- a/mlair/run_modules/training.py +++ b/mlair/run_modules/training.py @@ -99,7 +99,7 @@ class Training(RunEnvironment): workers. To prevent this, the function is pre-compiled. See discussion @ https://stackoverflow.com/questions/40850089/is-keras-thread-safe/43393252#43393252 """ - self.model._make_predict_function() + self.model.make_predict_function() def _set_gen(self, mode: str) -> None: """