From 9d87b9c6697c7077858fcea5957b4378fb339763 Mon Sep 17 00:00:00 2001 From: Falco Weichselbaum <f.weichselbaum@fz-juelich.de> Date: Wed, 20 Oct 2021 19:49:25 +0200 Subject: [PATCH] deactivated not working advanced_paddings.py import, changed some other keras imports to tensorflow.keras as keras, *.model.fit_generator() changed to *.model.fit() which is said to have the same functionality - Commit suffers from PicklingError with RLocks when trying to save Callback in Epoch 0002 --- mlair/model_modules/flatten.py | 2 +- mlair/model_modules/model_class.py | 6 +++--- mlair/run_modules/training.py | 34 +++++++++++++++--------------- run.py | 3 ++- 4 files changed, 23 insertions(+), 22 deletions(-) diff --git a/mlair/model_modules/flatten.py b/mlair/model_modules/flatten.py index dd1e8e21..98a55bfc 100644 --- a/mlair/model_modules/flatten.py +++ b/mlair/model_modules/flatten.py @@ -3,7 +3,7 @@ __date__ = '2019-12-02' from typing import Union, Callable -import keras +import tensorflow.keras as keras def get_activation(input_to_activate: keras.layers, activation: Union[Callable, str], **kwargs): diff --git a/mlair/model_modules/model_class.py b/mlair/model_modules/model_class.py index 9a0e97db..be4f4b22 100644 --- a/mlair/model_modules/model_class.py +++ b/mlair/model_modules/model_class.py @@ -120,12 +120,12 @@ import mlair.model_modules.keras_extensions __author__ = "Lukas Leufen, Felix Kleinert" __date__ = '2020-05-12' -import keras +import tensorflow.keras as keras from mlair.model_modules import AbstractModelClass -from mlair.model_modules.inception_model import InceptionModelBase +#from mlair.model_modules.inception_model import InceptionModelBase from mlair.model_modules.flatten import flatten_tail -from mlair.model_modules.advanced_paddings import PadUtils, Padding2D, SymmetricPadding2D +#from mlair.model_modules.advanced_paddings import PadUtils, Padding2D, SymmetricPadding2D from mlair.model_modules.loss import l_p_loss diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py index 27dd4445..cb538abb 100644 --- a/mlair/run_modules/training.py +++ b/mlair/run_modules/training.py @@ -137,14 +137,14 @@ class Training(RunEnvironment): checkpoint = self.callbacks.get_checkpoint() if not os.path.exists(checkpoint.filepath) or self._create_new_model: - history = self.model.fit_generator(generator=self.train_set, - steps_per_epoch=len(self.train_set), - epochs=self.epochs, - verbose=2, - validation_data=self.val_set, - validation_steps=len(self.val_set), - callbacks=self.callbacks.get_callbacks(as_dict=False), - workers=psutil.cpu_count(logical=False)) + history = self.model.fit(self.train_set, + steps_per_epoch=len(self.train_set), + epochs=self.epochs, + verbose=2, + validation_data=self.val_set, + validation_steps=len(self.val_set), + callbacks=self.callbacks.get_callbacks(as_dict=False), + workers=psutil.cpu_count(logical=False)) else: logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.") self.callbacks.load_callbacks() @@ -152,15 +152,15 @@ class Training(RunEnvironment): self.model = keras.models.load_model(checkpoint.filepath) hist: History = self.callbacks.get_callback_by_name("hist") initial_epoch = max(hist.epoch) + 1 - _ = self.model.fit_generator(generator=self.train_set, - steps_per_epoch=len(self.train_set), - epochs=self.epochs, - verbose=2, - validation_data=self.val_set, - validation_steps=len(self.val_set), - callbacks=self.callbacks.get_callbacks(as_dict=False), - initial_epoch=initial_epoch, - workers=psutil.cpu_count(logical=False)) + _ = self.model.fit(self.train_set, + steps_per_epoch=len(self.train_set), + epochs=self.epochs, + verbose=2, + validation_data=self.val_set, + validation_steps=len(self.val_set), + callbacks=self.callbacks.get_callbacks(as_dict=False), + initial_epoch=initial_epoch, + workers=psutil.cpu_count(logical=False)) history = hist try: lr = self.callbacks.get_callback_by_name("lr") diff --git a/run.py b/run.py index fbe6aa26..954f8532 100644 --- a/run.py +++ b/run.py @@ -3,6 +3,7 @@ __date__ = '2020-06-29' import argparse from mlair.workflows import DefaultWorkflow +from mlair.model_modules.model_class import MyLittleModelHourly as chosen_model from mlair.helpers import remove_items from mlair.configuration.defaults import DEFAULT_PLOT_LIST import os @@ -28,7 +29,7 @@ def main(parser_args): stations=["DEBW013", "DEBW087", "DEBW107", "DEBW076"], train_model=False, create_new_model=True, network="UBA", evaluate_bootstraps=False, # plot_list=["PlotCompetitiveSkillScore"], - competitors=["test_model", "test_model2"], + competitors=["test_model", "test_model2"], model=chosen_model, competitor_path=os.path.join(os.getcwd(), "data", "comp_test"), **parser_args.__dict__, start_script=__file__) workflow.run() -- GitLab