Skip to content
Snippets Groups Projects
Commit a0d24483 authored by Falco Weichselbaum's avatar Falco Weichselbaum
Browse files

make_keras_pickable hotfix added from:...

make_keras_pickable hotfix added from: https://github.com/tensorflow/tensorflow/issues/34697#issuecomment-627193883. It is placed into helpers.py and runs in __init__() of AbstractModelClass to always be called before any model configuration.
parent 9d87b9c6
No related branches found
No related tags found
3 merge requests!413update release branch,!412Resolve "release v2.0.0",!335Resolve "upgrade code to TensorFlow V2"
Pipeline #81176 failed
...@@ -3,4 +3,4 @@ ...@@ -3,4 +3,4 @@
from .testing import PyTestRegex, PyTestAllEqual from .testing import PyTestRegex, PyTestAllEqual
from .time_tracking import TimeTracking, TimeTrackingWrapper from .time_tracking import TimeTracking, TimeTrackingWrapper
from .logger import Logger from .logger import Logger
from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value, select_from_dict from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value, select_from_dict, make_keras_pickable
...@@ -12,6 +12,43 @@ import dask.array as da ...@@ -12,6 +12,43 @@ import dask.array as da
from typing import Dict, Callable, Union, List, Any, Tuple from typing import Dict, Callable, Union, List, Any, Tuple
from tensorflow.keras.models import Model
from tensorflow.python.keras.layers import deserialize, serialize
from tensorflow.python.keras.saving import saving_utils
"""
The following code is copied from: https://github.com/tensorflow/tensorflow/issues/34697#issuecomment-627193883
and is a hotfix to make keras.model.model models serializable/pickable
"""
def unpack(model, training_config, weights):
restored_model = deserialize(model)
if training_config is not None:
restored_model.compile(
**saving_utils.compile_args_from_training_config(
training_config
)
)
restored_model.set_weights(weights)
return restored_model
# Hotfix function
def make_keras_pickable():
def __reduce__(self):
model_metadata = saving_utils.model_metadata(self)
training_config = model_metadata.get("training_config", None)
model = serialize(self)
weights = self.get_weights()
return (unpack, (model, training_config, weights))
cls = Model
cls.__reduce__ = __reduce__
" end of hotfix "
def to_list(obj: Any) -> List: def to_list(obj: Any) -> List:
""" """
......
...@@ -5,7 +5,7 @@ from typing import Any, Dict, Callable ...@@ -5,7 +5,7 @@ from typing import Any, Dict, Callable
import tensorflow.keras as keras import tensorflow.keras as keras
import tensorflow as tf import tensorflow as tf
from mlair.helpers import remove_items from mlair.helpers import remove_items, make_keras_pickable
class AbstractModelClass(ABC): class AbstractModelClass(ABC):
...@@ -21,6 +21,7 @@ class AbstractModelClass(ABC): ...@@ -21,6 +21,7 @@ class AbstractModelClass(ABC):
def __init__(self, input_shape, output_shape) -> None: def __init__(self, input_shape, output_shape) -> None:
"""Predefine internal attributes for model and loss.""" """Predefine internal attributes for model and loss."""
make_keras_pickable()
self.__model = None self.__model = None
self.model_name = self.__class__.__name__ self.model_name = self.__class__.__name__
self.__custom_objects = {} self.__custom_objects = {}
......
...@@ -22,7 +22,7 @@ def load_stations(): ...@@ -22,7 +22,7 @@ def load_stations():
def main(parser_args): def main(parser_args):
tf.compat.v1.disable_v2_behavior() # tf.compat.v1.disable_v2_behavior()
plots = remove_items(DEFAULT_PLOT_LIST, ["PlotConditionalQuantiles", "PlotPeriodogram"]) plots = remove_items(DEFAULT_PLOT_LIST, ["PlotConditionalQuantiles", "PlotPeriodogram"])
workflow = DefaultWorkflow( # stations=load_stations(), workflow = DefaultWorkflow( # stations=load_stations(),
# stations=["DEBW087","DEBW013", "DEBW107", "DEBW076"], # stations=["DEBW087","DEBW013", "DEBW107", "DEBW076"],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment