Skip to content
Snippets Groups Projects
Commit 3b62aa57 authored by Felix Kleinert's avatar Felix Kleinert
Browse files

update gpu count and tests

parent 612f6eaf
No related branches found
No related tags found
1 merge request!259Draft: Resolve "WRF-Datahandler should inherit from SingleStationDatahandler"
Pipeline #71966 passed
...@@ -37,7 +37,7 @@ class AbstractModelClass(ABC): ...@@ -37,7 +37,7 @@ class AbstractModelClass(ABC):
self.__compile_options_is_set = False self.__compile_options_is_set = False
self._input_shape = input_shape self._input_shape = input_shape
self._output_shape = self.__extract_from_tuple(output_shape) self._output_shape = self.__extract_from_tuple(output_shape)
self.avail_gpus = K.tensorflow_backend._get_available_gpus() self.avail_gpus = len(K.tensorflow_backend._get_available_gpus())
def __getattr__(self, name: str) -> Any: def __getattr__(self, name: str) -> Any:
""" """
......
...@@ -2,6 +2,8 @@ import keras ...@@ -2,6 +2,8 @@ import keras
import pytest import pytest
from mlair import AbstractModelClass from mlair import AbstractModelClass
from keras import backend as K
class Paddings: class Paddings:
...@@ -25,6 +27,10 @@ class TestAbstractModelClass: ...@@ -25,6 +27,10 @@ class TestAbstractModelClass:
def amsc(self): def amsc(self):
return AbstractModelSubClass() return AbstractModelSubClass()
@pytest.fixture
def num_avail_gpus(self):
return len(K.tensorflow_backend._get_available_gpus())
def test_init(self, amc): def test_init(self, amc):
assert amc.model is None assert amc.model is None
# assert amc.loss is None # assert amc.loss is None
...@@ -179,11 +185,11 @@ class TestAbstractModelClass: ...@@ -179,11 +185,11 @@ class TestAbstractModelClass:
assert hasattr(amc.model, "compile") is True assert hasattr(amc.model, "compile") is True
assert amc.compile == amc.model.compile assert amc.compile == amc.model.compile
def test_get_settings(self, amc, amsc): def test_get_settings(self, amc, amsc, num_avail_gpus):
assert amc.get_settings() == {"model_name": "AbstractModelClass", "_input_shape": (14, 1, 2), assert amc.get_settings() == {"model_name": "AbstractModelClass", "_input_shape": (14, 1, 2),
"_output_shape": 3} "_output_shape": 3, "avail_gpus": num_avail_gpus}
assert amsc.get_settings() == {"test_attr": "testAttr", "model_name": "AbstractModelSubClass", assert amsc.get_settings() == {"test_attr": "testAttr", "model_name": "AbstractModelSubClass",
"_input_shape": (12, 1, 2), "_output_shape": 3} "_input_shape": (12, 1, 2), "_output_shape": 3, "avail_gpus": num_avail_gpus}
def test_custom_objects(self, amc): def test_custom_objects(self, amc):
amc.custom_objects = {"Test": 123} amc.custom_objects = {"Test": 123}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment