From 778eec7bd2b8669c0e07d874efecf015085e6e03 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Mon, 25 Nov 2019 16:42:43 +0100
Subject: [PATCH] updated tests

---
 src/modules/experiment_setup.py            |  17 +--
 test/test_modules/test_experiment_setup.py | 117 ++++++++++++++++++++-
 2 files changed, 127 insertions(+), 7 deletions(-)

diff --git a/src/modules/experiment_setup.py b/src/modules/experiment_setup.py
index 1e20871e..0c9d9ed1 100644
--- a/src/modules/experiment_setup.py
+++ b/src/modules/experiment_setup.py
@@ -27,10 +27,10 @@ class ExperimentSetup(RunEnvironment):
     trainable: Train new model if true, otherwise try to load existing model
     """
 
-    def __init__(self, parser_args, var_all_dict=None, stations=None, network=None, variables=None, target_var="o3",
+    def __init__(self, parser_args=None, var_all_dict=None, stations=None, network=None, variables=None, target_var="o3",
                  target_dim=None, dimensions=None, interpolate_dim=None, train_start=None, train_end=None,
                  val_start=None, val_end=None, test_start=None, test_end=None, use_all_stations_on_all_data_sets=True,
-                 trainable=False, fraction_of_train=None):
+                 trainable=False, fraction_of_train=None, experiment_path=None):
 
         # create run framework
         super().__init__()
@@ -42,7 +42,7 @@ class ExperimentSetup(RunEnvironment):
 
         # set experiment name
         exp_date = self._get_parser_args(parser_args).get("experiment_date")
-        exp_name, exp_path = helpers.set_experiment_name(experiment_date=exp_date)
+        exp_name, exp_path = helpers.set_experiment_name(experiment_date=exp_date, experiment_path=experiment_path)
         self._set_param("experiment_name", exp_name)
         self._set_param("experiment_path", exp_path)
         helpers.check_path_and_create(self.data_store.get("experiment_path", "general"))
@@ -83,15 +83,20 @@ class ExperimentSetup(RunEnvironment):
         logging.debug(f"set experiment attribute: {param}({scope})={value}")
 
     @staticmethod
-    def _get_parser_args(args: Union[Dict, argparse.ArgumentParser]) -> Dict:
+    def _get_parser_args(args: Union[Dict, argparse.Namespace, argparse.ArgumentParser]) -> Dict:
         """
         Transform args to dict if given as argparse.Namespace
         :param args: either a dictionary or an argument parser instance
         :return: dictionary with all arguments
         """
-        if isinstance(args, argparse.Namespace):
+        if isinstance(args, argparse.ArgumentParser):
+            return args.parse_args().__dict__
+        elif isinstance(args, argparse.Namespace):
             return args.__dict__
-        return args
+        elif isinstance(args, dict):
+            return args
+        else:
+            return {}
 
 
 if __name__ == "__main__":
diff --git a/test/test_modules/test_experiment_setup.py b/test/test_modules/test_experiment_setup.py
index 65ec052b..cf634b2c 100644
--- a/test/test_modules/test_experiment_setup.py
+++ b/test/test_modules/test_experiment_setup.py
@@ -1,4 +1,119 @@
-
 import pytest
+import logging
+import argparse
+import os
+
+from src.modules.experiment_setup import ExperimentSetup
+from src.helpers import TimeTracking, prepare_host
+from src.datastore import NameNotFoundInScope, NameNotFoundInDataStore
+
+
+class TestExperimentSetup:
+
+    @pytest.fixture
+    def empty_obj(self, caplog):
+        caplog.set_level(logging.DEBUG)
+        obj = object.__new__(ExperimentSetup)
+        obj.time = TimeTracking()
+        return obj
+
+    def test_set_param_by_value(self, caplog, empty_obj):
+        empty_obj._set_param("23tester", 23)
+        assert caplog.record_tuples[-1] == ('root', 10, 'set experiment attribute: 23tester(general)=23')
+        assert empty_obj.data_store.get("23tester", "general") == 23
+
+    def test_set_param_by_value_and_scope(self, caplog, empty_obj):
+        empty_obj._set_param("109tester", 109, "general.testing")
+        assert empty_obj.data_store.get("109tester", "general.tester") == 109
+
+    def test_set_param_with_default(self, caplog, empty_obj):
+        empty_obj._set_param("NoneTester", None, "notNone", "general.testing")
+        assert empty_obj.data_store.get("NoneTester", "general.testing") == "notNone"
+        empty_obj._set_param("AnotherNoneTester", None)
+        assert empty_obj.data_store.get("AnotherNoneTester", "general") is None
+
+    def test_get_parser_args_from_dict(self, empty_obj):
+        res = empty_obj._get_parser_args({'test2': 2, 'test10str': "10"})
+        assert res == {'test2': 2, 'test10str': "10"}
+
+    def test_get_parser_args_from_argparser(self, empty_obj):
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--experiment_date', type=str, nargs=1, default="TODAY")
+        assert empty_obj._get_parser_args(parser) == {"experiment_date": "TODAY"}
+
+    def test_get_parser_args_from_parse_args(self, empty_obj):
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--experiment_date', type=str, nargs=1, default="TOMORROW")
+        parser_args = parser.parse_args()
+        assert empty_obj._get_parser_args(parser_args) == {"experiment_date": "TOMORROW"}
 
+    def test_init_default(self):
+        exp_setup = ExperimentSetup()
+        data_store = exp_setup.data_store
+        assert data_store.get("data_path", "general") == prepare_host()
+        assert data_store.get("trainable", "general") is False
+        assert data_store.get("fraction_of_train", "general") == 0.8
+        assert data_store.get("experiment_name", "general") == "TestExperiment"
+        path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "TestExperiment"))
+        assert data_store.get("experiment_path", "general") == path
+        default_var_all_dict = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
+                                'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
+                                'pblheight': 'maximum'}
+        assert data_store.get("var_all_dict", "general") == default_var_all_dict
+        default_stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBY052', 'DEBY032', 'DEBW022',
+                            'DEBY004', 'DEBY020', 'DEBW030', 'DEBW037', 'DEBW031', 'DEBW015', 'DEBW073', 'DEBY039',
+                            'DEBW038', 'DEBW081', 'DEBY075', 'DEBW040', 'DEBY053', 'DEBW059', 'DEBW027', 'DEBY072',
+                            'DEBW042', 'DEBW039', 'DEBY001', 'DEBY113', 'DEBY089', 'DEBW024', 'DEBW004', 'DEBY037',
+                            'DEBW056', 'DEBW029', 'DEBY068', 'DEBW010', 'DEBW026', 'DEBY002', 'DEBY079', 'DEBW084',
+                            'DEBY049', 'DEBY031', 'DEBW019', 'DEBW001', 'DEBY063', 'DEBY005', 'DEBW046', 'DEBW103',
+                            'DEBW052', 'DEBW034', 'DEBY088', ]
+        assert data_store.get("stations", "general") == default_stations
+        assert data_store.get("network", "general") == "AIRBASE"
+        assert data_store.get("variables", "general") == list(default_var_all_dict.keys())
+        assert data_store.get("target_var", "general") == "o3"
+        assert data_store.get("target_dim", "general") == "variables"
+        assert data_store.get("dimensions", "general") == {'new_index': ['datetime', 'Stations']}
+        assert data_store.get("interpolate_dim", "general") == "datetime"
+        with pytest.raises(NameNotFoundInScope):
+            data_store.get("start", "general")
+        with pytest.raises(NameNotFoundInScope):
+            data_store.get("end", "general")
+        assert data_store.get("start", "general.train") == "1997-01-01"
+        assert data_store.get("end", "general.train") == "2007-12-31"
+        assert data_store.get("start", "general.val") == "2008-01-01"
+        assert data_store.get("end", "general.val") == "2009-12-31"
+        assert data_store.get("start", "general.test") == "2010-01-01"
+        assert data_store.get("end", "general.test") == "2017-12-31"
 
+    def test_init_no_default(self):
+        experiment_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder"))
+        kwargs = dict(parser_args={"experiment_date": "TODAY"},
+                      var_all_dict={'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum'},
+                      stations=['DEBY053', 'DEBW059', 'DEBW027'], network="INTERNET", variables=["o3", "temp"],
+                      target_var="temp", target_dim="target", dimensions="dim1", interpolate_dim="int_dim",
+                      train_start="2000-01-01", train_end="2000-01-02", val_start="2000-01-03", val_end="2000-01-04",
+                      test_start="2000-01-05", test_end="2000-01-06", use_all_stations_on_all_data_sets=False,
+                      trainable=True, fraction_of_train=0.5, experiment_path=experiment_path)
+        exp_setup = ExperimentSetup(**kwargs)
+        data_store = exp_setup.data_store
+        assert data_store.get("data_path", "general") == prepare_host()
+        assert data_store.get("trainable", "general") is True
+        assert data_store.get("fraction_of_train", "general") == 0.5
+        assert data_store.get("experiment_name", "general") == "TODAY_network/"
+        path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder"))
+        assert data_store.get("experiment_path", "general") == path
+        assert data_store.get("var_all_dict", "general") == {'o3': 'dma8eu', 'relhum': 'average_values',
+                                                             'temp': 'maximum'}
+        assert data_store.get("stations", "general") == ['DEBY053', 'DEBW059', 'DEBW027']
+        assert data_store.get("network", "general") == "INTERNET"
+        assert data_store.get("variables", "general") == ["o3", "temp"]
+        assert data_store.get("target_var", "general") == "temp"
+        assert data_store.get("target_dim", "general") == "target"
+        assert data_store.get("dimensions", "general") == "dim1"
+        assert data_store.get("interpolate_dim", "general") == "int_dim"
+        assert data_store.get("start", "general.train") == "2000-01-01"
+        assert data_store.get("end", "general.train") == "2000-01-02"
+        assert data_store.get("start", "general.val") == "2000-01-03"
+        assert data_store.get("end", "general.val") == "2000-01-04"
+        assert data_store.get("start", "general.test") == "2000-01-05"
+        assert data_store.get("end", "general.test") == "2000-01-06"
-- 
GitLab