diff --git a/mlair/configuration/path_config.py b/mlair/configuration/path_config.py index 67c6bce4a3478443323b4ef49b5dc36258271ccd..e7418b984dab74b0527b8dca05a9f6c3636ac18f 100644 --- a/mlair/configuration/path_config.py +++ b/mlair/configuration/path_config.py @@ -4,13 +4,13 @@ import logging import os import re import socket -from typing import Tuple +from typing import Union # ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) ROOT_PATH = os.getcwd() -def prepare_host(create_new=True, data_path=None, sampling="daily") -> str: +def prepare_host(create_new=True, data_path=None) -> str: """ Set up host path. @@ -20,7 +20,6 @@ def prepare_host(create_new=True, data_path=None, sampling="daily") -> str: :param create_new: Create new path if enabled :param data_path: Parse your custom path (and therefore ignore preset paths fitting to known hosts) - :param sampling: sampling rate to separate data physically by temporal resolution (deprecated) :return: full path to data """ @@ -73,7 +72,7 @@ def set_experiment_path(name: str, path: str = None) -> str: return experiment_path -def set_experiment_name(name: str = None, sampling: str = None) -> str: +def set_experiment_name(name: str = None, sampling: Union[str, tuple] = None) -> str: """ Set name of experiment and its path. @@ -90,6 +89,8 @@ def set_experiment_name(name: str = None, sampling: str = None) -> str: else: experiment_name = f"{name}_network" if sampling is not None: + if not isinstance(sampling, str): + sampling = sampling[-1] experiment_name += f"_{sampling}" return experiment_name diff --git a/test/test_configuration/test_path_config.py b/test/test_configuration/test_path_config.py index fb8a2b1950cd07909543fbe564230ab73661c126..996550bf98d736b468edfe99b648e91885907dfb 100644 --- a/test/test_configuration/test_path_config.py +++ b/test/test_configuration/test_path_config.py @@ -68,6 +68,14 @@ class TestSetExperimentName: exp_name = set_experiment_name(sampling="daily") assert exp_name == "TestExperiment_daily" + def test_set_experiment_name_tuple_sampling(self): + exp_name = set_experiment_name(sampling=("hourly")) + assert exp_name == "TestExperiment_hourly" + exp_name = set_experiment_name(sampling=("hourly", "daily")) + assert exp_name == "TestExperiment_daily" + exp_name = set_experiment_name(sampling=("hourly", "dummy", "daily")) + assert exp_name == "TestExperiment_daily" + def test_set_experiment_path(self): exp_path = set_experiment_path("TestExperiment") assert exp_path == os.path.abspath(os.path.join(ROOT_PATH, "TestExperiment"))