diff --git a/requirements.txt b/requirements.txt index cdf035784475dac51d17173e7863dbf483e20101..d2a7200b492fc3dc61f84f4432a97b05051ca184 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ pytest-lazy-fixture==0.6.1 pytest-cov pytest-html pydot +mock diff --git a/run.py b/run.py index 9304c42f365f1d9119e50beb700780171c9d905d..2781c14f5b642e57efc47cc0c7e303eb89bda07e 100644 --- a/run.py +++ b/run.py @@ -3,7 +3,25 @@ __date__ = '2019-11-14' import logging +from src.helpers import TimeTracking +from src import helpers +import argparse + if __name__ == "__main__": - logging.info("start run script") \ No newline at end of file + parser = argparse.ArgumentParser() + parser.add_argument('experiment date', metavar='exp_date', type=str, nargs=1, help='set experiment date as string', + default=None) + + logging.info("start run script") + total_time = TimeTracking() + + # set data path of this experiment + data_path = helpers.prepare_host() + + # set experiment name + experiment_date = parser.parse_args(["experiment_date"]) + experiment_name, experiment_path = helpers.set_experiment_name(experiment_date=experiment_date) + + # set if model shall be trained or not \ No newline at end of file diff --git a/src/helpers.py b/src/helpers.py index 9c7ab255ef0a5170197f9c0daed76ac3bc08476e..b087e73035497b7befdb04f7f2c17b643216ad8e 100644 --- a/src/helpers.py +++ b/src/helpers.py @@ -10,6 +10,8 @@ from typing import Union import numpy as np import os import time +import socket +import sys def to_list(arg): @@ -123,3 +125,40 @@ class TimeTracking(object): def duration(self): return self._duration() + + +def prepare_host(): + hostname = socket.gethostname() + user = os.getlogin() + if hostname == 'ZAM144': # pragma: no branch + path = f'/home/{user}/Data/toar_daily/' + elif hostname == 'zam347': + path = f'/home/{user}/Data/toar_daily/' + elif hostname == 'linux-gzsx': + path = f'/home/{user}/machinelearningtools' + elif (len(hostname) > 2) and (hostname[:2] == 'jr'): + path = f'/p/project/cjjsc42/{user}/DATA/toar_daily/' + elif (len(hostname) > 2) and (hostname[:2] == 'jw'): + path = f'/p/home/jusers/{user}/juwels/intelliaq/DATA/toar_daily/' + else: + logging.error(f"unknown host '{hostname}'") + raise OSError(f"unknown host '{hostname}'") + if not os.path.exists(path): + logging.error(f"path '{path}' does not exist for host '{hostname}'.") + raise NotADirectoryError(f"path '{path}' does not exist for host '{hostname}'.") + else: + logging.info(f"set path to: {path}") + return path + + +def set_experiment_name(experiment_date=None, experiment_path=None): + + if experiment_date is None: + experiment_name = "" + else: + experiment_name = f"{experiment_date}_network/" + if experiment_path is None: + experiment_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", experiment_name)) + else: + experiment_path = os.path.abspath(experiment_path) + return experiment_name, experiment_path diff --git a/test/test_helpers.py b/test/test_helpers.py index aa80ec3841c59ab4cb86a70dd2074d9dab6b4d33..78e1a34f193ac3a84da7e887bbe48badbd056536 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -4,6 +4,7 @@ import logging import os import keras import numpy as np +import mock class TestToList: @@ -135,3 +136,39 @@ class TestTimeTracking: assert duration is not None duration = t.stop() assert duration == t.duration() + + +class TestPrepareHost: + + @mock.patch("socket.gethostname", return_value="linux-gzsx") + @mock.patch("os.getlogin", return_value="testUser") + @mock.patch("os.path.exists", return_value=True) + def test_prepare_host(self, mock_host, mock_user, mock_path): + path = prepare_host() + assert path == "/home/testUser/machinelearningtools" + + @mock.patch("socket.gethostname", return_value="NotExistingHostName") + @mock.patch("os.getlogin", return_value="zombie21") + def test_error_handling(self, mock_user, mock_host): + with pytest.raises(OSError) as e: + prepare_host() + assert "unknown host 'NotExistingHostName'" in e.value.args[0] + mock_host.return_value = "linux-gzsx" + with pytest.raises(NotADirectoryError) as e: + prepare_host() + assert "path '/home/zombie21/machinelearningtools' does not exist for host 'linux-gzsx'" in e.value.args[0] + + +class TestSetExperimentName: + + def test_set_experiment(self): + exp_name, exp_path = set_experiment_name() + assert exp_name == "" + assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "")) + exp_name, exp_path = set_experiment_name(experiment_date="2019-11-14", experiment_path="./test2") + assert exp_name == "2019-11-14_network/" + assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "test2")) + + def test_set_experiment_from_sys(self): + exp_name, _ = set_experiment_name(experiment_date="2019-11-14") + assert exp_name == "2019-11-14_network/"