From 9e6497c27afdd30afe97d3bf5593cec502cf022e Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Fri, 13 Dec 2019 14:40:37 +0100
Subject: [PATCH] another test update

---
 test/test_model_modules/test_model_class.py | 10 ++++----
 test/test_modules/test_model_setup.py       | 26 +++++++++++++--------
 2 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/test/test_model_modules/test_model_class.py b/test/test_model_modules/test_model_class.py
index d370dea5..0af16012 100644
--- a/test/test_model_modules/test_model_class.py
+++ b/test/test_model_modules/test_model_class.py
@@ -11,19 +11,19 @@ class TestAbstractModelClass:
         return AbstractModelClass()
 
     def test_init(self, amc):
-        assert amc.__model is None
-        assert amc.__loss is None
+        assert amc.model is None
+        assert amc.loss is None
 
     def test_model_property(self, amc):
-        amc.__model = keras.Model()
+        amc.model = keras.Model()
         assert isinstance(amc.model, keras.Model) is True
 
     def test_loss_property(self, amc):
-        amc.__loss = keras.losses.mean_absolute_error
+        amc.loss = keras.losses.mean_absolute_error
         assert amc.loss == keras.losses.mean_absolute_error
 
     def test_getattr(self, amc):
-        amc.__model = keras.Model()
+        amc.model = keras.Model()
         assert hasattr(amc, "compile") is True
         assert hasattr(amc.model, "compile") is True
         assert amc.compile == amc.model.compile
diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py
index 5a5a7bbd..ca750304 100644
--- a/test/test_modules/test_model_setup.py
+++ b/test/test_modules/test_model_setup.py
@@ -31,16 +31,22 @@ class TestModelSetup:
         setup.data_store.set("generator", gen, "general.train")
         setup.data_store.set("window_history_size", gen.window_history_size, "general")
         setup.data_store.set("window_lead_time", gen.window_lead_time, "general")
+        setup.data_store.set("channels", 2, "general")
         yield setup
         RunEnvironment().__del__()
 
     @pytest.fixture
-    def setup_with_model(self, setup_with_gen):
-        setup_with_gen.data_store.set("channels", 2, "general")
-        setup_with_gen.model = AbstractModelClass()
-        setup_with_gen.model.epochs = 2
-        setup_with_gen.model.batch_size = int(256)
-        yield setup_with_gen
+    def setup_with_gen_tiny(self, setup, gen):
+        setup.data_store.set("generator", gen, "general.train")
+        yield setup
+        RunEnvironment().__del__()
+
+    @pytest.fixture
+    def setup_with_model(self, setup):
+        setup.model = AbstractModelClass()
+        setup.model.epochs = 2
+        setup.model.batch_size = int(256)
+        yield setup
         RunEnvironment().__del__()
 
     @staticmethod
@@ -67,10 +73,10 @@ class TestModelSetup:
                     "optimizer", "lr_decay", "epochs", "batch_size", "activation"}
         assert expected <= self.current_scope_as_set(setup_with_gen)
 
-    def test_set_channels(self, setup_with_gen):
-        assert len(setup_with_gen.data_store.search_name("channels")) == 0
-        setup_with_gen._set_channels()
-        assert setup_with_gen.data_store.get("channels", setup_with_gen.scope) == 2
+    def test_set_channels(self, setup_with_gen_tiny):
+        assert len(setup_with_gen_tiny.data_store.search_name("channels")) == 0
+        setup_with_gen_tiny._set_channels()
+        assert setup_with_gen_tiny.data_store.get("channels", setup_with_gen_tiny.scope) == 2
 
     def test_load_weights(self):
         pass
-- 
GitLab