diff --git a/mlair/model_modules/flatten.py b/mlair/model_modules/flatten.py index dd1e8e21eeb96f75372add0208b03dc06f5dc25c..98a55bfcfbe51ff0757479704f8e30738f7db705 100644 --- a/mlair/model_modules/flatten.py +++ b/mlair/model_modules/flatten.py @@ -3,7 +3,7 @@ __date__ = '2019-12-02' from typing import Union, Callable -import keras +import tensorflow.keras as keras def get_activation(input_to_activate: keras.layers, activation: Union[Callable, str], **kwargs): diff --git a/mlair/model_modules/model_class.py b/mlair/model_modules/model_class.py index 9a0e97dbd1f3a3a52f5717c88d09702e5d0d7928..be4f4b22715d8a8e75cd52b9f819cb391c15b354 100644 --- a/mlair/model_modules/model_class.py +++ b/mlair/model_modules/model_class.py @@ -120,12 +120,12 @@ import mlair.model_modules.keras_extensions __author__ = "Lukas Leufen, Felix Kleinert" __date__ = '2020-05-12' -import keras +import tensorflow.keras as keras from mlair.model_modules import AbstractModelClass -from mlair.model_modules.inception_model import InceptionModelBase +#from mlair.model_modules.inception_model import InceptionModelBase from mlair.model_modules.flatten import flatten_tail -from mlair.model_modules.advanced_paddings import PadUtils, Padding2D, SymmetricPadding2D +#from mlair.model_modules.advanced_paddings import PadUtils, Padding2D, SymmetricPadding2D from mlair.model_modules.loss import l_p_loss diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py index 27dd444531ba253c7bf7ae996bbea7d15318d32e..cb538abbbcae2f1c4afdad70c8f621746fc26fbb 100644 --- a/mlair/run_modules/training.py +++ b/mlair/run_modules/training.py @@ -137,14 +137,14 @@ class Training(RunEnvironment): checkpoint = self.callbacks.get_checkpoint() if not os.path.exists(checkpoint.filepath) or self._create_new_model: - history = self.model.fit_generator(generator=self.train_set, - steps_per_epoch=len(self.train_set), - epochs=self.epochs, - verbose=2, - validation_data=self.val_set, - validation_steps=len(self.val_set), - callbacks=self.callbacks.get_callbacks(as_dict=False), - workers=psutil.cpu_count(logical=False)) + history = self.model.fit(self.train_set, + steps_per_epoch=len(self.train_set), + epochs=self.epochs, + verbose=2, + validation_data=self.val_set, + validation_steps=len(self.val_set), + callbacks=self.callbacks.get_callbacks(as_dict=False), + workers=psutil.cpu_count(logical=False)) else: logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.") self.callbacks.load_callbacks() @@ -152,15 +152,15 @@ class Training(RunEnvironment): self.model = keras.models.load_model(checkpoint.filepath) hist: History = self.callbacks.get_callback_by_name("hist") initial_epoch = max(hist.epoch) + 1 - _ = self.model.fit_generator(generator=self.train_set, - steps_per_epoch=len(self.train_set), - epochs=self.epochs, - verbose=2, - validation_data=self.val_set, - validation_steps=len(self.val_set), - callbacks=self.callbacks.get_callbacks(as_dict=False), - initial_epoch=initial_epoch, - workers=psutil.cpu_count(logical=False)) + _ = self.model.fit(self.train_set, + steps_per_epoch=len(self.train_set), + epochs=self.epochs, + verbose=2, + validation_data=self.val_set, + validation_steps=len(self.val_set), + callbacks=self.callbacks.get_callbacks(as_dict=False), + initial_epoch=initial_epoch, + workers=psutil.cpu_count(logical=False)) history = hist try: lr = self.callbacks.get_callback_by_name("lr") diff --git a/run.py b/run.py index fbe6aa262a31d8902f5722699e787a93f8488c12..954f8532f9f1260921133ebe7f588a523181b780 100644 --- a/run.py +++ b/run.py @@ -3,6 +3,7 @@ __date__ = '2020-06-29' import argparse from mlair.workflows import DefaultWorkflow +from mlair.model_modules.model_class import MyLittleModelHourly as chosen_model from mlair.helpers import remove_items from mlair.configuration.defaults import DEFAULT_PLOT_LIST import os @@ -28,7 +29,7 @@ def main(parser_args): stations=["DEBW013", "DEBW087", "DEBW107", "DEBW076"], train_model=False, create_new_model=True, network="UBA", evaluate_bootstraps=False, # plot_list=["PlotCompetitiveSkillScore"], - competitors=["test_model", "test_model2"], + competitors=["test_model", "test_model2"], model=chosen_model, competitor_path=os.path.join(os.getcwd(), "data", "comp_test"), **parser_args.__dict__, start_script=__file__) workflow.run()