Skip to content
Snippets Groups Projects
Commit f6ad4736 authored by Felix Kleinert's avatar Felix Kleinert
Browse files

update model load fkts

parent ce8616af
No related branches found
No related tags found
No related merge requests found
Pipeline #86987 failed
......@@ -37,7 +37,13 @@ class AbstractModelClass(ABC):
self.__compile_options_is_set = False
self._input_shape = input_shape
self._output_shape = self.__extract_from_tuple(output_shape)
# self.avail_gpus = len(K.tensorflow_backend._get_available_gpus())
def load_model(self, name: str, compile: bool = False) -> None:
hist = self.model.history
self.model.load_weights(name)
self.model.history = hist
if compile is True:
self.model.compile(**self.compile_options)
def __getattr__(self, name: str) -> Any:
"""
......
......@@ -454,12 +454,12 @@ class IntelliO3TsArchitecture(AbstractModelClass):
kernel_regularizer=self.regularizer
)
model = keras.Model(inputs=X_input, outputs=[out_minor1, out_main])
if self.avail_gpus <= 1:
self.model = model
else:
self.model = keras.utils.multi_gpu_model(model, self.avail_gpus)
print(f"Set multi_gpu model with {self.avail_gpus} GPUs")
self.model = keras.Model(inputs=X_input, outputs=[out_minor1, out_main])
# if self.avail_gpus <= 1:
# self.model = model
# else:
# self.model = keras.utils.multi_gpu_model(model, self.avail_gpus)
# print(f"Set multi_gpu model with {self.avail_gpus} GPUs")
def set_compile_options(self):
self.compile_options = {"optimizer": keras.optimizers.Adam(lr=self.initial_lr, amsgrad=True),
......@@ -762,6 +762,11 @@ class MyUnet(AbstractModelClass):
self.compile_options = {"metrics": ["mse", "mae"]}
class NN3s(MyUnet):
def __init__(self, input_shape: list, output_shape: list):
super().__init__(input_shape, output_shape)
class MySimpleConv2D(AbstractModelClass):
"""
Example adopted from https://www.kaggle.com/dimitreoliveira/deep-learning-for-time-series-forecasting
......
......@@ -84,7 +84,7 @@ class ModelSetup(RunEnvironment):
# load weights if no training shall be performed
if not self._train_model and not self._create_new_model:
self.load_weights()
self.load_model()
# create checkpoint
self._set_callbacks()
......@@ -131,13 +131,13 @@ class ModelSetup(RunEnvironment):
save_best_only=True, mode='auto')
self.data_store.set("callbacks", callbacks, self.scope)
def load_weights(self):
"""Try to load weights from existing model or skip if not possible."""
def load_model(self):
"""Try to load model from disk or skip if not possible."""
try:
self.model.load_weights(self.model_name)
logging.info(f"reload weights from model {self.model_name} ...")
self.model.load_model(self.model_name)
logging.info(f"reload model {self.model_name} from disk ...")
except OSError:
logging.info('no weights to reload...')
logging.info('no local model to load...')
def build_model(self):
"""Build model using input and output shapes from data store."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment