diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index b93e23d52f54b6357832ee8d33862dfd56368eca..5de7ef5f788ddcee591c20f6f9125813cec5205a 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -4,7 +4,7 @@ __date__ = '2019-11-15' import argparse import logging import os -from typing import Union, Dict, Any, List +from typing import Union, Dict, Any, List, Callable from mlair.configuration import path_config from mlair import helpers @@ -279,7 +279,7 @@ class ExperimentSetup(RunEnvironment): path_config.check_path_and_create(self.data_store.get("logging_path")) # setup for data - self._set_param("stations", helpers.to_list(stations), default=DEFAULT_STATIONS) + self._set_param("stations", stations, default=DEFAULT_STATIONS, apply=helpers.to_list) self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT) self._set_param("variables", variables, default=list(self.data_store.get("statistics_per_var").keys())) self._set_param("start", start, default=DEFAULT_START) @@ -355,10 +355,14 @@ class ExperimentSetup(RunEnvironment): raise KeyError(f"Given argument {k} with value {v} cannot be set for this experiment due to a " f"conflict with an existing entry with same naming: {k}={self.data_store.get(k)}") - def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None: - """Set given parameter and log in debug.""" + def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general", + apply: Callable = None) -> None: + """Set given parameter and log in debug. Use apply parameter to adjust the stored value (e.g. to transform value + to a list use apply=helpers.to_list).""" if value is None and default is not None: value = default + if apply is not None: + value = apply(value) self.data_store.set(param, value, scope) logging.debug(f"set experiment attribute: {param}({scope})={value}") diff --git a/test/test_run_modules/test_experiment_setup.py b/test/test_run_modules/test_experiment_setup.py index ff35508542b694eb1def0ba791d9a5f70043f19c..7c63d3d101176a40749ce903f569263b9c884d5e 100644 --- a/test/test_run_modules/test_experiment_setup.py +++ b/test/test_run_modules/test_experiment_setup.py @@ -4,7 +4,7 @@ import os import pytest -from mlair.helpers import TimeTracking +from mlair.helpers import TimeTracking, to_list from mlair.configuration.path_config import prepare_host from mlair.run_modules.experiment_setup import ExperimentSetup @@ -33,6 +33,16 @@ class TestExperimentSetup: empty_obj._set_param("AnotherNoneTester", None) assert empty_obj.data_store.get("AnotherNoneTester", "general") is None + def test_set_param_with_apply(self, caplog, empty_obj): + empty_obj._set_param("NoneTester", None, default="notNone", apply=None) + assert empty_obj.data_store.get("NoneTester") == "notNone" + empty_obj._set_param("NoneTester", None, default="notNone", apply=to_list) + assert empty_obj.data_store.get("NoneTester") == ["notNone"] + empty_obj._set_param("NoneTester", None, apply=to_list) + assert empty_obj.data_store.get("NoneTester") == [None] + empty_obj._set_param("NoneTester", 2.3, apply=int) + assert empty_obj.data_store.get("NoneTester") == 2 + def test_init_default(self): exp_setup = ExperimentSetup() data_store = exp_setup.data_store