Commit 1c57e71b authored by leufen1

fix for uncompiled model when resuming training from an epoch other than 0

parent 2e4b29b1
Pipeline #83362 failed
@@ -38,10 +38,12 @@ class AbstractModelClass(ABC):
         self._input_shape = input_shape
         self._output_shape = self.__extract_from_tuple(output_shape)
 
-    def load_model(self, name: str):
+    def load_model(self, name: str, compile: bool = False):
         hist = self.model.history
         self.model = keras.models.load_model(name)
         self.model.history = hist
+        if compile is True:
+            self.model.compile(**self.compile_options)
 
     def __getattr__(self, name: str) -> Any:
         """
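For illustration only, a minimal standalone sketch of the load-then-recompile pattern the new compile flag enables; the file name "model-latest.h5", the hand-written compile options, and the commented-out fit() call are assumptions, whereas MLAir itself passes self.compile_options as shown in the diff above.

# Minimal sketch, not MLAir code: hypothetical checkpoint file and hand-written compile options.
import tensorflow.keras as keras

compile_options = {"optimizer": keras.optimizers.Adam(), "loss": "mean_squared_error",
                   "metrics": ["mean_squared_error"]}

model = keras.models.load_model("model-latest.h5")   # a restored model may come back uncompiled
model.compile(**compile_options)                     # re-attach optimizer, loss and metrics
# model.fit(train_set, initial_epoch=10, epochs=20)  # ...so that training can resume mid-run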
@@ -45,8 +45,9 @@ class PlotModelHistory:
         self._additional_columns = self._filter_columns(history)
         self._plot(filename)
 
-    @staticmethod
-    def _get_plot_metric(history, plot_metric, main_branch):
-        if plot_metric.lower() == "mse":
-            plot_metric = "mean_squared_error"
-        elif plot_metric.lower() == "mae":
+    def _get_plot_metric(self, history, plot_metric, main_branch, correct_names=True):
+        _plot_metric = plot_metric
+        if correct_names is True:
+            if plot_metric.lower() == "mse":
+                plot_metric = "mean_squared_error"
+            elif plot_metric.lower() == "mae":
@@ -54,6 +55,8 @@ class PlotModelHistory:
         available_keys = [k for k in history.keys() if
                           plot_metric in k and ("main" in k.lower() if main_branch else True)]
         available_keys.sort(key=len)
+        if len(available_keys) == 0 and correct_names is True:
+            return self._get_plot_metric(history, _plot_metric, main_branch, correct_names=False)
         return available_keys[0]
 
     def _filter_columns(self, history: Dict) -> List[str]:
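The recursion above is a fallback for TensorFlow 2, which logs the short metric aliases ("mse", "mae") where TensorFlow 1 used the long names. A standalone sketch of the same lookup, using a hypothetical history dict and the helper name pick_metric_key (not part of MLAir):

# Sketch with a made-up history dict; the real dict comes from keras' History callback.
history = {"loss": [1.0, 0.5], "val_loss": [1.1, 0.6], "mse": [1.0, 0.5], "val_mse": [1.1, 0.6]}

def pick_metric_key(history, plot_metric, main_branch, correct_names=True):
    raw_metric = plot_metric
    if correct_names is True:
        # first pass: translate the short alias into the long TF1-style name
        if plot_metric.lower() == "mse":
            plot_metric = "mean_squared_error"
        elif plot_metric.lower() == "mae":
            plot_metric = "mean_absolute_error"
    keys = [k for k in history if plot_metric in k and ("main" in k.lower() if main_branch else True)]
    keys.sort(key=len)
    if len(keys) == 0 and correct_names is True:
        # second pass: nothing matched, so retry with the unmodified metric name
        return pick_metric_key(history, raw_metric, main_branch, correct_names=False)
    return keys[0]

print(pick_metric_key(history, "mse", main_branch=False))  # -> "mse" under TF2-style logging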
@@ -149,7 +149,7 @@ class Training(RunEnvironment):
             logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.")
             self.callbacks.load_callbacks()
             self.callbacks.update_checkpoint()
-            self.model.load_model(checkpoint.filepath)
+            self.model.load_model(checkpoint.filepath, compile=True)
             hist: History = self.callbacks.get_callback_by_name("hist")
             initial_epoch = max(hist.epoch) + 1
             _ = self.model.fit(self.train_set,
@@ -190,8 +190,8 @@ class Training(RunEnvironment):
         """
         logging.debug(f"load best model: {name}")
         try:
-            self.model.load_model(name)
-            logging.info('reload weights...')
+            self.model.load_model(name, compile=True)
+            logging.info('reload model...')
         except OSError:
             logging.info('no weights to reload...')
@@ -236,9 +236,11 @@ class Training(RunEnvironment):
         if multiple_branches_used:
             filename = os.path.join(path, f"{name}_history_main_loss.pdf")
             PlotModelHistory(filename=filename, history=history, main_branch=True)
-        if len([e for e in history.model.metrics_names if "mean_squared_error" in e]) > 0:
+        mse_indicator = list(set(history.model.metrics_names).intersection(["mean_squared_error", "mse"]))
+        if len(mse_indicator) > 0:
             filename = os.path.join(path, f"{name}_history_main_mse.pdf")
-            PlotModelHistory(filename=filename, history=history, plot_metric="mse", main_branch=multiple_branches_used)
+            PlotModelHistory(filename=filename, history=history, plot_metric=mse_indicator[0],
+                             main_branch=multiple_branches_used)
 
         # plot learning rate
         if lr_sc:
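The set intersection above accepts either spelling that Keras may report in model.metrics_names. A short standalone sketch with made-up metric name lists:

# Sketch with hypothetical metrics_names lists; history.model.metrics_names provides the real one.
for metrics_names in (["loss", "mean_squared_error", "mean_absolute_error"],  # TF1-style names
                      ["loss", "mse", "mae"]):                                # TF2-style aliases
    mse_indicator = list(set(metrics_names).intersection(["mean_squared_error", "mse"]))
    if len(mse_indicator) > 0:
        print(mse_indicator[0])  # whichever spelling is present gets passed on as plot_metric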
@@ -4,8 +4,8 @@ __date__ = '2019-11-14'
 import argparse
 
 from mlair.workflows import DefaultWorkflow
-from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, DataHandlerMixedSamplingWithFilter, \
-    DataHandlerSeparationOfScales
+from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling
 
 stats = {'o3': 'dma8eu', 'no': 'dma8eu', 'no2': 'dma8eu',
          'relhum': 'average_values', 'u': 'average_values', 'v': 'average_values',
@@ -20,7 +20,7 @@ data_origin = {'o3': '', 'no': '', 'no2': '',
 def main(parser_args):
     args = dict(stations=["DEBW107", "DEBW013"],
                 network="UBA",
-                evaluate_feature_importance=False, plot_list=[],
+                evaluate_feature_importance=True,  # plot_list=[],
                 data_origin=data_origin, data_handler=DataHandlerMixedSampling,
                 interpolation_limit=(3, 1), overwrite_local_data=False,
                 sampling=("hourly", "daily"),
@@ -28,8 +28,6 @@ def main(parser_args):
                 create_new_model=True, train_model=False, epochs=1,
                 window_history_size=6 * 24 + 16,
                 window_history_offset=16,
-                kz_filter_length=[100 * 24, 15 * 24],
-                kz_filter_iter=[4, 5],
                 start="2006-01-01",
                 train_start="2006-01-01",
                 end="2011-12-31",