diff --git a/requirements.txt b/requirements.txt index 9ccd09ef9234bafff4559aa9fc325bef7d8bf3ea..227da2b61976d6147902e05a54a3e414dc3f40cc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,7 +20,7 @@ pyshp six pyproj shapely -cartopy==0.16.0 +Cartopy==0.16.0 matplotlib pillow scipy \ No newline at end of file diff --git a/src/datastore.py b/src/datastore.py index 92623baf7c01f8199f653b7220db77a931986708..69e486796cc4320d88b3c3f5fd4a970d4ee814e9 100644 --- a/src/datastore.py +++ b/src/datastore.py @@ -144,6 +144,22 @@ class DataStoreByVariable(AbstractDataStore): """ return self._stride_through_scopes(name, scope)[2] + def get_default(self, name: str, scope: str, default: Any) -> Any: + """ + Same functionality as the standard get method. But this method adds a default argument that is returned if no + data was stored in the data store. Use this function with care, because it will not report any errors and just + return the given default value. Currently, there is no statement that reports if the returned value comes from + the data store or the default value. + :param name: Name to look for + :param scope: scope to search the name for + :param default: default value that is returned if no data was found for given name and scope + :return: the stored object or the default value + """ + try: + return self._stride_through_scopes(name, scope)[2] + except (NameNotFoundInDataStore, NameNotFoundInScope): + return default + def _stride_through_scopes(self, name, scope, depth=0): if depth <= scope.count("."): local_scope = scope.rsplit(".", maxsplit=depth)[0] @@ -267,6 +283,22 @@ class DataStoreByScope(AbstractDataStore): """ return self._stride_through_scopes(name, scope)[2] + def get_default(self, name: str, scope: str, default: Any) -> Any: + """ + Same functionality as the standard get method. But this method adds a default argument that is returned if no + data was stored in the data store. 
Use this function with care, because it will not report any errors and just + return the given default value. Currently, there is no statement that reports if the returned value comes from + the data store or the default value. + :param name: Name to look for + :param scope: scope to search the name for + :param default: default value that is returned if no data was found for given name and scope + :return: the stored object or the default value + """ + try: + return self._stride_through_scopes(name, scope)[2] + except (NameNotFoundInDataStore, NameNotFoundInScope): + return default + def _stride_through_scopes(self, name, scope, depth=0): if depth <= scope.count("."): local_scope = scope.rsplit(".", maxsplit=depth)[0] diff --git a/src/helpers.py b/src/helpers.py index a4ce625c8ae9bbf3c03425116a6bc10abf328bc9..5646eb94dbd43941b5673e64f6b70a7ed0e51c26 100644 --- a/src/helpers.py +++ b/src/helpers.py @@ -88,6 +88,13 @@ class TimeTracking(object): def duration(self): return self._duration() + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + logging.info(f"undefined job finished after {self}") + def prepare_host(create_new=True): hostname = socket.gethostname() @@ -99,7 +106,7 @@ def prepare_host(create_new=True): path = f'/home/{user}/Data/toar_daily/' elif hostname == 'zam347': path = f'/home/{user}/Data/toar_daily/' - elif hostname == 'linux-gzsx': + elif hostname == 'linux-aa9b': path = f'/home/{user}/machinelearningtools/data/toar_daily/' elif (len(hostname) > 2) and (hostname[:2] == 'jr'): path = f'/p/project/cjjsc42/{user}/DATA/toar_daily/' diff --git a/src/run_modules/modules.py b/src/run_modules/modules.py deleted file mode 100644 index 5f0f12c19a87de7ba4ad1e0508906e75d1605563..0000000000000000000000000000000000000000 --- a/src/run_modules/modules.py +++ /dev/null @@ -1,33 +0,0 @@ -import logging -import argparse - -from src.run_modules.run_environment import RunEnvironment -from 
src.run_modules.experiment_setup import ExperimentSetup -from src.run_modules.pre_processing import PreProcessing - - -class Training(RunEnvironment): - - def __init__(self): - super().__init__() - - -class PostProcessing(RunEnvironment): - - def __init__(self): - super().__init__() - - -if __name__ == "__main__": - - formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]' - logging.basicConfig(format=formatter, level=logging.DEBUG) - - parser = argparse.ArgumentParser() - parser.add_argument('--experiment_date', metavar='--exp_date', type=str, nargs=1, default=None, - help="set experiment date as string") - parser_args = parser.parse_args() - with RunEnvironment(): - ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'], - station_type='background') - PreProcessing() diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index a9695064e1d2864d4a367a297ba94cc404d46538..e6f271ce3cc6cf2548ff5b06ba40e2fd509f8c8d 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -18,6 +18,7 @@ from src import statistics from src.plotting.postprocessing_plotting import plot_conditional_quantiles from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, PlotCompetitiveSkillScore from src.datastore import NameNotFoundInDataStore +from src.helpers import TimeTracking class PostProcessing(RunEnvironment): @@ -26,7 +27,7 @@ class PostProcessing(RunEnvironment): super().__init__() self.model: keras.Model = self._load_model() self.ols_model = None - self.batch_size: int = self.data_store.get("batch_size", "general.model") + self.batch_size: int = self.data_store.get_default("batch_size", "general.model", 64) self.test_data: DataGenerator = self.data_store.get("generator", "general.test") self.test_data_distributed = Distributor(self.test_data, self.model, self.batch_size) 
self.train_data: DataGenerator = self.data_store.get("generator", "general.train") @@ -36,8 +37,14 @@ class PostProcessing(RunEnvironment): self._run() def _run(self): - self.train_ols_model() - preds_for_all_stations = self.make_prediction() + with TimeTracking(): + self.train_ols_model() + logging.info("take a look on the next reported time measure. If this increases a lot, one should think to " + "skip make_prediction() whenever it is possible to save time.") + with TimeTracking(): + preds_for_all_stations = self.make_prediction() + logging.info("take a look on the next reported time measure. If this increases a lot, one should think to " + "skip make_prediction() whenever it is possible to save time.") self.skill_scores = self.calculate_skill_scores() self.plot() diff --git a/test/test_datastore.py b/test/test_datastore.py index 3f61c227be3a05d78f825eca77b0d6cbbc617ce1..e7510cffacafb4a6db6006097f48992fb6a10e55 100644 --- a/test/test_datastore.py +++ b/test/test_datastore.py @@ -16,6 +16,12 @@ class TestAbstractDataStore: def test_init(self, ds): assert ds._store == {} + def test_clear_data_store(self, ds): + ds._store["test"] = "test2" + assert len(ds._store.keys()) == 1 + ds.clear_data_store() + assert len(ds._store.keys()) == 0 + class TestDataStoreByVariable: @@ -49,6 +55,12 @@ class TestDataStoreByVariable: ds.get("number3", "general") assert "Couldn't find number3 in data store" in e.value.args[0] + def test_get_default(self, ds): + ds.set("number", 3, "general") + assert ds.get_default("number", "general", 45) == 3 + assert ds.get_default("number", "general.sub", 45) == 3 + assert ds.get_default("number", "other", 45) == 45 + def test_search(self, ds): ds.set("number", 22, "general") ds.set("number", 22, "general2") @@ -118,6 +130,25 @@ class TestDataStoreByVariable: assert ds.search_scope("general.sub.sub", current_scope_only=False, return_all=True) == \ [("number", "general.sub.sub", "ABC"), ("number1", "general.sub", 22), ("number2", "general.sub.sub", 
3)] + def test_create_args_dict(self, ds): + ds.set("tester1", 1, "general") + ds.set("tester2", 11, "general") + ds.set("tester2", 10, "general.sub") + ds.set("tester3", 21, "general") + args = ["tester1", "tester2", "tester3", "tester4"] + assert ds.create_args_dict(args) == {"tester1": 1, "tester2": 11, "tester3": 21} + assert ds.create_args_dict(args, "general.sub") == {"tester1": 1, "tester2": 10, "tester3": 21} + assert ds.create_args_dict(["notAvail", "alsonot"]) == {} + + def test_set_args_from_dict(self, ds): + ds.set_args_from_dict({"tester1": 1, "tester2": 10, "tester3": 21}) + assert ds.get("tester1", "general") == 1 + assert ds.get("tester2", "general") == 10 + assert ds.get("tester3", "general") == 21 + ds.set_args_from_dict({"tester1": 111}, "general.sub") + assert ds.get("tester1", "general.sub") == 111 + assert ds.get("tester3", "general.sub") == 21 + class TestDataStoreByScope: @@ -151,6 +182,12 @@ class TestDataStoreByScope: ds.get("number3", "general") assert "Couldn't find number3 in data store" in e.value.args[0] + def test_get_default(self, ds): + ds.set("number", 3, "general") + assert ds.get_default("number", "general", 45) == 3 + assert ds.get_default("number", "general.sub", 45) == 3 + assert ds.get_default("number", "other", 45) == 45 + def test_search(self, ds): ds.set("number", 22, "general") ds.set("number", 22, "general2") @@ -220,3 +257,21 @@ class TestDataStoreByScope: assert ds.search_scope("general.sub.sub", current_scope_only=False, return_all=True) == \ [("number", "general.sub.sub", "ABC"), ("number1", "general.sub", 22), ("number2", "general.sub.sub", 3)] + def test_create_args_dict(self, ds): + ds.set("tester1", 1, "general") + ds.set("tester2", 11, "general") + ds.set("tester2", 10, "general.sub") + ds.set("tester3", 21, "general") + args = ["tester1", "tester2", "tester3", "tester4"] + assert ds.create_args_dict(args) == {"tester1": 1, "tester2": 11, "tester3": 21} + assert ds.create_args_dict(args, "general.sub") == 
{"tester1": 1, "tester2": 10, "tester3": 21} + assert ds.create_args_dict(["notAvail", "alsonot"]) == {} + + def test_set_args_from_dict(self, ds): + ds.set_args_from_dict({"tester1": 1, "tester2": 10, "tester3": 21}) + assert ds.get("tester1", "general") == 1 + assert ds.get("tester2", "general") == 10 + assert ds.get("tester3", "general") == 21 + ds.set_args_from_dict({"tester1": 111}, "general.sub") + assert ds.get("tester1", "general.sub") == 111 + assert ds.get("tester3", "general.sub") == 21 \ No newline at end of file diff --git a/test/test_helpers.py b/test/test_helpers.py index c909960b4e5e053b9291c12e64e3649e957886bc..463007dc361d21df934bb239b3ecac2fc86882ad 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -109,10 +109,18 @@ class TestTimeTracking: duration = t.stop(get_duration=True) assert duration == t.duration() + def test_enter_exit(self, caplog): + caplog.set_level(logging.INFO) + with TimeTracking() as t: + assert t.start is not None + assert t.end is None + expression = PyTestRegex(r"undefined job finished after \d+:\d+:\d+ \(hh:mm:ss\)") + assert caplog.record_tuples[-1] == ('root', 20, expression) + class TestPrepareHost: - @mock.patch("socket.gethostname", side_effect=["linux-gzsx", "ZAM144", "zam347", "jrtest", "jwtest"]) + @mock.patch("socket.gethostname", side_effect=["linux-aa9b", "ZAM144", "zam347", "jrtest", "jwtest"]) @mock.patch("os.getlogin", return_value="testUser") @mock.patch("os.path.exists", return_value=True) def test_prepare_host(self, mock_host, mock_user, mock_path): @@ -134,10 +142,10 @@ class TestPrepareHost: prepare_host() assert "unknown host 'NotExistingHostName'" in e.value.args[0] if "runner-6HmDp9Qd-project-2411-concurrent" not in platform.node(): - mock_host.return_value = "linux-gzsx" + mock_host.return_value = "linux-aa9b" with pytest.raises(NotADirectoryError) as e: prepare_host() - assert "does not exist for host 'linux-gzsx'" in e.value.args[0] + assert "does not exist for host 'linux-aa9b'" in 
e.value.args[0] class TestSetExperimentName: