Skip to content
Snippets Groups Projects
Commit f9d943c4 authored by Felix Kleinert's avatar Felix Kleinert
Browse files

update IntelliO3 model

parent 0a105e02
No related branches found
No related tags found
1 merge request!259Draft: Resolve "WRF-Datahandler should inherit from SingleStationDatahandler"
......@@ -6,6 +6,7 @@ import keras
import tensorflow as tf
from mlair.helpers import remove_items
from keras import backend as K
class AbstractModelClass(ABC):
......@@ -36,6 +37,7 @@ class AbstractModelClass(ABC):
self.__compile_options_is_set = False
self._input_shape = input_shape
self._output_shape = self.__extract_from_tuple(output_shape)
self.avail_gpus = K.tensorflow_backend._get_available_gpus()
def __getattr__(self, name: str) -> Any:
"""
......
......@@ -350,7 +350,7 @@ class MyTowerModel(AbstractModelClass):
self.compile_options = {"loss": [keras.losses.mean_squared_error], "metrics": ["mse"]}
class IntelliO3_ts_architecture(AbstractModelClass):
class IntelliO3TsArchitecture(AbstractModelClass):
def __init__(self, input_shape: list, output_shape: list):
"""
......@@ -454,7 +454,11 @@ class IntelliO3_ts_architecture(AbstractModelClass):
kernel_regularizer=self.regularizer
)
self.model = keras.Model(inputs=X_input, outputs=[out_minor1, out_main])
model = keras.Model(inputs=X_input, outputs=[out_minor1, out_main])
if self.avail_gpus == 0:
self.model = model
else:
self.model = keras.utils.multi_gpu_model(model, self.avail_gpus)
def set_compile_options(self):
self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment