From 9d87b9c6697c7077858fcea5957b4378fb339763 Mon Sep 17 00:00:00 2001
From: Falco Weichselbaum <f.weichselbaum@fz-juelich.de>
Date: Wed, 20 Oct 2021 19:49:25 +0200
Subject: [PATCH] Deactivate the broken advanced_paddings.py import; change
 remaining plain keras imports to tensorflow.keras; replace
 *.model.fit_generator() with *.model.fit(), which provides the same
 functionality. Known issue: this commit raises a PicklingError on RLocks
 when trying to save the Callback state at epoch 0002.

---
 mlair/model_modules/flatten.py     |  2 +-
 mlair/model_modules/model_class.py |  6 +++---
 mlair/run_modules/training.py      | 34 +++++++++++++++---------------
 run.py                             |  3 ++-
 4 files changed, 23 insertions(+), 22 deletions(-)

diff --git a/mlair/model_modules/flatten.py b/mlair/model_modules/flatten.py
index dd1e8e21..98a55bfc 100644
--- a/mlair/model_modules/flatten.py
+++ b/mlair/model_modules/flatten.py
@@ -3,7 +3,7 @@ __date__ = '2019-12-02'
 
 from typing import Union, Callable
 
-import keras
+import tensorflow.keras as keras
 
 
 def get_activation(input_to_activate: keras.layers, activation: Union[Callable, str], **kwargs):
diff --git a/mlair/model_modules/model_class.py b/mlair/model_modules/model_class.py
index 9a0e97db..be4f4b22 100644
--- a/mlair/model_modules/model_class.py
+++ b/mlair/model_modules/model_class.py
@@ -120,12 +120,12 @@ import mlair.model_modules.keras_extensions
 __author__ = "Lukas Leufen, Felix Kleinert"
 __date__ = '2020-05-12'
 
-import keras
+import tensorflow.keras as keras
 
 from mlair.model_modules import AbstractModelClass
-from mlair.model_modules.inception_model import InceptionModelBase
+#from mlair.model_modules.inception_model import InceptionModelBase
 from mlair.model_modules.flatten import flatten_tail
-from mlair.model_modules.advanced_paddings import PadUtils, Padding2D, SymmetricPadding2D
+#from mlair.model_modules.advanced_paddings import PadUtils, Padding2D, SymmetricPadding2D
 from mlair.model_modules.loss import l_p_loss
 
 
diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py
index 27dd4445..cb538abb 100644
--- a/mlair/run_modules/training.py
+++ b/mlair/run_modules/training.py
@@ -137,14 +137,14 @@ class Training(RunEnvironment):
 
         checkpoint = self.callbacks.get_checkpoint()
         if not os.path.exists(checkpoint.filepath) or self._create_new_model:
-            history = self.model.fit_generator(generator=self.train_set,
-                                               steps_per_epoch=len(self.train_set),
-                                               epochs=self.epochs,
-                                               verbose=2,
-                                               validation_data=self.val_set,
-                                               validation_steps=len(self.val_set),
-                                               callbacks=self.callbacks.get_callbacks(as_dict=False),
-                                               workers=psutil.cpu_count(logical=False))
+            history = self.model.fit(self.train_set,
+                                     steps_per_epoch=len(self.train_set),
+                                     epochs=self.epochs,
+                                     verbose=2,
+                                     validation_data=self.val_set,
+                                     validation_steps=len(self.val_set),
+                                     callbacks=self.callbacks.get_callbacks(as_dict=False),
+                                     workers=psutil.cpu_count(logical=False))
         else:
             logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.")
             self.callbacks.load_callbacks()
@@ -152,15 +152,15 @@ class Training(RunEnvironment):
             self.model = keras.models.load_model(checkpoint.filepath)
             hist: History = self.callbacks.get_callback_by_name("hist")
             initial_epoch = max(hist.epoch) + 1
-            _ = self.model.fit_generator(generator=self.train_set,
-                                         steps_per_epoch=len(self.train_set),
-                                         epochs=self.epochs,
-                                         verbose=2,
-                                         validation_data=self.val_set,
-                                         validation_steps=len(self.val_set),
-                                         callbacks=self.callbacks.get_callbacks(as_dict=False),
-                                         initial_epoch=initial_epoch,
-                                         workers=psutil.cpu_count(logical=False))
+            _ = self.model.fit(self.train_set,
+                               steps_per_epoch=len(self.train_set),
+                               epochs=self.epochs,
+                               verbose=2,
+                               validation_data=self.val_set,
+                               validation_steps=len(self.val_set),
+                               callbacks=self.callbacks.get_callbacks(as_dict=False),
+                               initial_epoch=initial_epoch,
+                               workers=psutil.cpu_count(logical=False))
             history = hist
         try:
             lr = self.callbacks.get_callback_by_name("lr")
diff --git a/run.py b/run.py
index fbe6aa26..954f8532 100644
--- a/run.py
+++ b/run.py
@@ -3,6 +3,7 @@ __date__ = '2020-06-29'
 
 import argparse
 from mlair.workflows import DefaultWorkflow
+from mlair.model_modules.model_class import MyLittleModelHourly as chosen_model
 from mlair.helpers import remove_items
 from mlair.configuration.defaults import DEFAULT_PLOT_LIST
 import os
@@ -28,7 +29,7 @@ def main(parser_args):
         stations=["DEBW013", "DEBW087", "DEBW107", "DEBW076"],
         train_model=False, create_new_model=True, network="UBA",
         evaluate_bootstraps=False,  # plot_list=["PlotCompetitiveSkillScore"],
-        competitors=["test_model", "test_model2"],
+        competitors=["test_model", "test_model2"], model=chosen_model,
         competitor_path=os.path.join(os.getcwd(), "data", "comp_test"),
         **parser_args.__dict__, start_script=__file__)
     workflow.run()
-- 
GitLab