From 3b62aa5779fb5b75068a0df4166205d6a2ea9e9e Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Fri, 2 Jul 2021 08:46:10 +0200
Subject: [PATCH] update gpu count and tests

---
 mlair/model_modules/abstract_model_class.py          |  2 +-
 test/test_model_modules/test_abstract_model_class.py | 12 +++++++++---
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/mlair/model_modules/abstract_model_class.py b/mlair/model_modules/abstract_model_class.py
index 8af85892..5248b263 100644
--- a/mlair/model_modules/abstract_model_class.py
+++ b/mlair/model_modules/abstract_model_class.py
@@ -37,7 +37,7 @@ class AbstractModelClass(ABC):
         self.__compile_options_is_set = False
         self._input_shape = input_shape
         self._output_shape = self.__extract_from_tuple(output_shape)
-        self.avail_gpus = K.tensorflow_backend._get_available_gpus()
+        self.avail_gpus = len(K.tensorflow_backend._get_available_gpus())
 
     def __getattr__(self, name: str) -> Any:
         """
diff --git a/test/test_model_modules/test_abstract_model_class.py b/test/test_model_modules/test_abstract_model_class.py
index dfef68d5..ddc3e2e5 100644
--- a/test/test_model_modules/test_abstract_model_class.py
+++ b/test/test_model_modules/test_abstract_model_class.py
@@ -2,6 +2,8 @@ import keras
 import pytest
 
 from mlair import AbstractModelClass
+from keras import backend as K
+
 
 
 class Paddings:
@@ -25,6 +27,10 @@ class TestAbstractModelClass:
     def amsc(self):
         return AbstractModelSubClass()
 
+    @pytest.fixture
+    def num_avail_gpus(self):
+        return len(K.tensorflow_backend._get_available_gpus())
+
     def test_init(self, amc):
         assert amc.model is None
         # assert amc.loss is None
@@ -179,11 +185,11 @@ class TestAbstractModelClass:
         assert hasattr(amc.model, "compile") is True
         assert amc.compile == amc.model.compile
 
-    def test_get_settings(self, amc, amsc):
+    def test_get_settings(self, amc, amsc, num_avail_gpus):
         assert amc.get_settings() == {"model_name": "AbstractModelClass", "_input_shape": (14, 1, 2),
-                                      "_output_shape": 3}
+                                      "_output_shape": 3, "avail_gpus": num_avail_gpus}
         assert amsc.get_settings() == {"test_attr": "testAttr", "model_name": "AbstractModelSubClass",
-                                       "_input_shape": (12, 1, 2), "_output_shape": 3}
+                                       "_input_shape": (12, 1, 2), "_output_shape": 3, "avail_gpus": num_avail_gpus}
 
     def test_custom_objects(self, amc):
         amc.custom_objects = {"Test": 123}
-- 
GitLab