Skip to content
Snippets Groups Projects
Commit 6c160572 authored by Falco Weichselbaum's avatar Falco Weichselbaum
Browse files

make_predict_function() changed to non-private version in training:102,...

make_predict_function() changed to non-private version in training:102, abstract_model_class:134+ disabled __compare_keras_optimizers and rewrite empty lists with None
parent 9b27fd76
No related branches found
No related tags found
3 merge requests!413update release branch,!412Resolve "release v2.0.0",!335Resolve "upgrade code to TensorFlow V2"
Pipeline #80688 failed
......@@ -139,6 +139,8 @@ class AbstractModelClass(ABC):
for allow_k in self.__allowed_compile_options.keys():
if hasattr(self, allow_k):
new_v_attr = getattr(self, allow_k)
if new_v_attr == list():
new_v_attr = None
else:
new_v_attr = None
if isinstance(value, dict):
......@@ -147,8 +149,10 @@ class AbstractModelClass(ABC):
new_v_dic = None
else:
raise TypeError(f"`compile_options' must be `dict' or `None', but is {type(value)}.")
if (new_v_attr == new_v_dic or self.__compare_keras_optimizers(new_v_attr, new_v_dic)) or (
(new_v_attr is None) ^ (new_v_dic is None)):
## self.__compare_keras_optimizers() foremost disabled, because it does not work as expected
#if (new_v_attr == new_v_dic or self.__compare_keras_optimizers(new_v_attr, new_v_dic)) or (
# (new_v_attr is None) ^ (new_v_dic is None)):
if (new_v_attr == new_v_dic) or ((new_v_attr is None) ^ (new_v_dic is None)):
if new_v_attr is not None:
self.__compile_options[allow_k] = new_v_attr
else:
......@@ -171,7 +175,11 @@ class AbstractModelClass(ABC):
:return True if optimisers are interchangeable, or False if optimisers are distinguishable.
"""
if first.__class__ == second.__class__ and first.__module__ == 'keras.optimizers':
if isinstance(list, type(second)):
res = False
else:
if first.__class__ == second.__class__ and '.'.join(
first.__module__.split('.')[0:4]) == 'tensorflow.python.keras.optimizer_v2':
res = True
init = tf.compat.v1.global_variables_initializer()
with tf.compat.v1.Session() as sess:
......
......@@ -99,7 +99,7 @@ class Training(RunEnvironment):
workers. To prevent this, the function is pre-compiled. See discussion @
https://stackoverflow.com/questions/40850089/is-keras-thread-safe/43393252#43393252
"""
self.model._make_predict_function()
self.model.make_predict_function()
def _set_gen(self, mode: str) -> None:
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment