diff --git a/mlair/model_modules/abstract_model_class.py b/mlair/model_modules/abstract_model_class.py index 8af858923207acdc4c471fec0ae7d03f990291b8..5248b2634666a9405e37a09ac01f93daa739a228 100644 --- a/mlair/model_modules/abstract_model_class.py +++ b/mlair/model_modules/abstract_model_class.py @@ -37,7 +37,7 @@ class AbstractModelClass(ABC): self.__compile_options_is_set = False self._input_shape = input_shape self._output_shape = self.__extract_from_tuple(output_shape) - self.avail_gpus = K.tensorflow_backend._get_available_gpus() + self.avail_gpus = len(K.tensorflow_backend._get_available_gpus()) def __getattr__(self, name: str) -> Any: """ diff --git a/test/test_model_modules/test_abstract_model_class.py b/test/test_model_modules/test_abstract_model_class.py index dfef68d550b07f824ed38e5c7809c00e5386d115..ddc3e2e5b87c5eeb1d5fe07d171cea0ade796507 100644 --- a/test/test_model_modules/test_abstract_model_class.py +++ b/test/test_model_modules/test_abstract_model_class.py @@ -2,6 +2,8 @@ import keras import pytest from mlair import AbstractModelClass +from keras import backend as K + class Paddings: @@ -25,6 +27,10 @@ class TestAbstractModelClass: def amsc(self): return AbstractModelSubClass() + @pytest.fixture + def num_avail_gpus(self): + return len(K.tensorflow_backend._get_available_gpus()) + def test_init(self, amc): assert amc.model is None # assert amc.loss is None @@ -179,11 +185,11 @@ class TestAbstractModelClass: assert hasattr(amc.model, "compile") is True assert amc.compile == amc.model.compile - def test_get_settings(self, amc, amsc): + def test_get_settings(self, amc, amsc, num_avail_gpus): assert amc.get_settings() == {"model_name": "AbstractModelClass", "_input_shape": (14, 1, 2), - "_output_shape": 3} + "_output_shape": 3, "avail_gpus": num_avail_gpus} assert amsc.get_settings() == {"test_attr": "testAttr", "model_name": "AbstractModelSubClass", - "_input_shape": (12, 1, 2), "_output_shape": 3} + "_input_shape": (12, 1, 2), "_output_shape": 3, "avail_gpus": num_avail_gpus} def test_custom_objects(self, amc): amc.custom_objects = {"Test": 123}