Skip to content
Snippets Groups Projects
Commit a0d24483 authored by Falco Weichselbaum's avatar Falco Weichselbaum
Browse files

make_keras_pickable hotfix added from:...

make_keras_pickable hotfix added from: https://github.com/tensorflow/tensorflow/issues/34697#issuecomment-627193883. It is placed into helpers.py and runs in __init__() of AbstractModelClass to always be called before any model configuration.
parent 9d87b9c6
Branches
Tags
3 merge requests!413update release branch,!412Resolve "release v2.0.0",!335Resolve "upgrade code to TensorFlow V2"
Pipeline #81176 failed
......@@ -3,4 +3,4 @@
from .testing import PyTestRegex, PyTestAllEqual
from .time_tracking import TimeTracking, TimeTrackingWrapper
from .logger import Logger
from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value, select_from_dict
from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value, select_from_dict, make_keras_pickable
......@@ -12,6 +12,43 @@ import dask.array as da
from typing import Dict, Callable, Union, List, Any, Tuple
from tensorflow.keras.models import Model
from tensorflow.python.keras.layers import deserialize, serialize
from tensorflow.python.keras.saving import saving_utils
"""
The following code is copied from: https://github.com/tensorflow/tensorflow/issues/34697#issuecomment-627193883
and is a hotfix to make keras.model.model models serializable/pickable
"""
def unpack(model, training_config, weights):
restored_model = deserialize(model)
if training_config is not None:
restored_model.compile(
**saving_utils.compile_args_from_training_config(
training_config
)
)
restored_model.set_weights(weights)
return restored_model
# Hotfix function
def make_keras_pickable():
def __reduce__(self):
model_metadata = saving_utils.model_metadata(self)
training_config = model_metadata.get("training_config", None)
model = serialize(self)
weights = self.get_weights()
return (unpack, (model, training_config, weights))
cls = Model
cls.__reduce__ = __reduce__
" end of hotfix "
def to_list(obj: Any) -> List:
"""
......
......@@ -5,7 +5,7 @@ from typing import Any, Dict, Callable
import tensorflow.keras as keras
import tensorflow as tf
from mlair.helpers import remove_items
from mlair.helpers import remove_items, make_keras_pickable
class AbstractModelClass(ABC):
......@@ -21,6 +21,7 @@ class AbstractModelClass(ABC):
def __init__(self, input_shape, output_shape) -> None:
"""Predefine internal attributes for model and loss."""
make_keras_pickable()
self.__model = None
self.model_name = self.__class__.__name__
self.__custom_objects = {}
......
......@@ -22,7 +22,7 @@ def load_stations():
def main(parser_args):
tf.compat.v1.disable_v2_behavior()
# tf.compat.v1.disable_v2_behavior()
plots = remove_items(DEFAULT_PLOT_LIST, ["PlotConditionalQuantiles", "PlotPeriodogram"])
workflow = DefaultWorkflow( # stations=load_stations(),
# stations=["DEBW087","DEBW013", "DEBW107", "DEBW076"],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment