From 3b62aa5779fb5b75068a0df4166205d6a2ea9e9e Mon Sep 17 00:00:00 2001 From: Felix Kleinert <f.kleinert@fz-juelich.de> Date: Fri, 2 Jul 2021 08:46:10 +0200 Subject: [PATCH] update gpu count and tests --- mlair/model_modules/abstract_model_class.py | 2 +- test/test_model_modules/test_abstract_model_class.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/mlair/model_modules/abstract_model_class.py b/mlair/model_modules/abstract_model_class.py index 8af85892..5248b263 100644 --- a/mlair/model_modules/abstract_model_class.py +++ b/mlair/model_modules/abstract_model_class.py @@ -37,7 +37,7 @@ class AbstractModelClass(ABC): self.__compile_options_is_set = False self._input_shape = input_shape self._output_shape = self.__extract_from_tuple(output_shape) - self.avail_gpus = K.tensorflow_backend._get_available_gpus() + self.avail_gpus = len(K.tensorflow_backend._get_available_gpus()) def __getattr__(self, name: str) -> Any: """ diff --git a/test/test_model_modules/test_abstract_model_class.py b/test/test_model_modules/test_abstract_model_class.py index dfef68d5..ddc3e2e5 100644 --- a/test/test_model_modules/test_abstract_model_class.py +++ b/test/test_model_modules/test_abstract_model_class.py @@ -2,6 +2,8 @@ import keras import pytest from mlair import AbstractModelClass +from keras import backend as K + class Paddings: @@ -25,6 +27,10 @@ class TestAbstractModelClass: def amsc(self): return AbstractModelSubClass() + @pytest.fixture + def num_avail_gpus(self): + return len(K.tensorflow_backend._get_available_gpus()) + def test_init(self, amc): assert amc.model is None # assert amc.loss is None @@ -179,11 +185,11 @@ class TestAbstractModelClass: assert hasattr(amc.model, "compile") is True assert amc.compile == amc.model.compile - def test_get_settings(self, amc, amsc): + def test_get_settings(self, amc, amsc, num_avail_gpus): assert amc.get_settings() == {"model_name": "AbstractModelClass", "_input_shape": (14, 1, 2), - "_output_shape": 3} + "_output_shape": 3, "avail_gpus": num_avail_gpus} assert amsc.get_settings() == {"test_attr": "testAttr", "model_name": "AbstractModelSubClass", - "_input_shape": (12, 1, 2), "_output_shape": 3} + "_input_shape": (12, 1, 2), "_output_shape": 3, "avail_gpus": num_avail_gpus} def test_custom_objects(self, amc): amc.custom_objects = {"Test": 123} -- GitLab