diff --git a/mlair/helpers/__init__.py b/mlair/helpers/__init__.py index 4671334c16267be819ab8ee0ad96b7135ee01531..bb30a594fca5b5b161571d2b3485b48467018900 100644 --- a/mlair/helpers/__init__.py +++ b/mlair/helpers/__init__.py @@ -3,4 +3,4 @@ from .testing import PyTestRegex, PyTestAllEqual from .time_tracking import TimeTracking, TimeTrackingWrapper from .logger import Logger -from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value, select_from_dict +from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value, select_from_dict, make_keras_pickable diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py index 4cc7310db32c2ef3bbdb9f70896a2f8455a974fc..16b36921773a4af131065063de23963a56cb4c65 100644 --- a/mlair/helpers/helpers.py +++ b/mlair/helpers/helpers.py @@ -12,6 +12,43 @@ import dask.array as da from typing import Dict, Callable, Union, List, Any, Tuple +from tensorflow.keras.models import Model +from tensorflow.python.keras.layers import deserialize, serialize +from tensorflow.python.keras.saving import saving_utils + +""" +The following code is copied from: https://github.com/tensorflow/tensorflow/issues/34697#issuecomment-627193883 +and is a hotfix to make keras.model.model models serializable/pickable +""" + + +def unpack(model, training_config, weights): + restored_model = deserialize(model) + if training_config is not None: + restored_model.compile( + **saving_utils.compile_args_from_training_config( + training_config + ) + ) + restored_model.set_weights(weights) + return restored_model + +# Hotfix function +def make_keras_pickable(): + + def __reduce__(self): + model_metadata = saving_utils.model_metadata(self) + training_config = model_metadata.get("training_config", None) + model = serialize(self) + weights = self.get_weights() + return (unpack, (model, training_config, weights)) + + cls = Model + cls.__reduce__ = __reduce__ + + +" end of hotfix " + def to_list(obj: Any) -> List: """ diff --git a/mlair/model_modules/abstract_model_class.py b/mlair/model_modules/abstract_model_class.py index 6cd79abe2212294095caea60f551d0288d74f431..8898a6b2d0591328f2bb7010ccbfe144a48ca40b 100644 --- a/mlair/model_modules/abstract_model_class.py +++ b/mlair/model_modules/abstract_model_class.py @@ -5,7 +5,7 @@ from typing import Any, Dict, Callable import tensorflow.keras as keras import tensorflow as tf -from mlair.helpers import remove_items +from mlair.helpers import remove_items, make_keras_pickable class AbstractModelClass(ABC): @@ -21,6 +21,7 @@ class AbstractModelClass(ABC): def __init__(self, input_shape, output_shape) -> None: """Predefine internal attributes for model and loss.""" + make_keras_pickable() self.__model = None self.model_name = self.__class__.__name__ self.__custom_objects = {} diff --git a/run.py b/run.py index 954f8532f9f1260921133ebe7f588a523181b780..49537c8fe595984114cf80bf250bda7d71de4f67 100644 --- a/run.py +++ b/run.py @@ -22,7 +22,7 @@ def load_stations(): def main(parser_args): - tf.compat.v1.disable_v2_behavior() + # tf.compat.v1.disable_v2_behavior() plots = remove_items(DEFAULT_PLOT_LIST, ["PlotConditionalQuantiles", "PlotPeriodogram"]) workflow = DefaultWorkflow( # stations=load_stations(), # stations=["DEBW087","DEBW013", "DEBW107", "DEBW076"],