Skip to content
Snippets Groups Projects
Commit f9d943c4 authored by Felix Kleinert's avatar Felix Kleinert
Browse files

update IntelliO3 model

parent 0a105e02
Branches
Tags
1 merge request!259Draft: Resolve "WRF-Datahandler should inherit from SingleStationDatahandler"
...@@ -6,6 +6,7 @@ import keras ...@@ -6,6 +6,7 @@ import keras
import tensorflow as tf import tensorflow as tf
from mlair.helpers import remove_items from mlair.helpers import remove_items
from keras import backend as K
class AbstractModelClass(ABC): class AbstractModelClass(ABC):
...@@ -36,6 +37,7 @@ class AbstractModelClass(ABC): ...@@ -36,6 +37,7 @@ class AbstractModelClass(ABC):
self.__compile_options_is_set = False self.__compile_options_is_set = False
self._input_shape = input_shape self._input_shape = input_shape
self._output_shape = self.__extract_from_tuple(output_shape) self._output_shape = self.__extract_from_tuple(output_shape)
self.avail_gpus = K.tensorflow_backend._get_available_gpus()
def __getattr__(self, name: str) -> Any: def __getattr__(self, name: str) -> Any:
""" """
......
...@@ -350,7 +350,7 @@ class MyTowerModel(AbstractModelClass): ...@@ -350,7 +350,7 @@ class MyTowerModel(AbstractModelClass):
self.compile_options = {"loss": [keras.losses.mean_squared_error], "metrics": ["mse"]} self.compile_options = {"loss": [keras.losses.mean_squared_error], "metrics": ["mse"]}
class IntelliO3_ts_architecture(AbstractModelClass): class IntelliO3TsArchitecture(AbstractModelClass):
def __init__(self, input_shape: list, output_shape: list): def __init__(self, input_shape: list, output_shape: list):
""" """
...@@ -454,7 +454,11 @@ class IntelliO3_ts_architecture(AbstractModelClass): ...@@ -454,7 +454,11 @@ class IntelliO3_ts_architecture(AbstractModelClass):
kernel_regularizer=self.regularizer kernel_regularizer=self.regularizer
) )
self.model = keras.Model(inputs=X_input, outputs=[out_minor1, out_main]) model = keras.Model(inputs=X_input, outputs=[out_minor1, out_main])
if self.avail_gpus == 0:
self.model = model
else:
self.model = keras.utils.multi_gpu_model(model, self.avail_gpus)
def set_compile_options(self): def set_compile_options(self):
self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9) self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment