diff --git a/.gitignore b/.gitignore index 305a5d1b9420eb62da24772fc1f4b263c1f3efe1..f5e425f752a1de0de0c68036a54e0d19450320bb 100644 --- a/.gitignore +++ b/.gitignore @@ -42,6 +42,15 @@ ehthumbs.db Thumbs.db .idea/ /venv/ +/venv*/ +/build/ + +# ignore HPC related skripts # +############################## +run_*_develgpus.bash +run_*_gpus.bash +run_*_batch.bash +activate_env.sh # don't check data and plot folder # #################################### @@ -77,4 +86,4 @@ report.html # ignore locally build documentation # ###################################### -/docs/_build \ No newline at end of file +/docs/_build diff --git a/.gitlab/issue_templates/bug.md b/.gitlab/issue_templates/bug.md new file mode 100644 index 0000000000000000000000000000000000000000..60cf04b086eac998f78fce03a20dd7757221f57f --- /dev/null +++ b/.gitlab/issue_templates/bug.md @@ -0,0 +1,19 @@ +<!-- Use this template for a bug in MLAir. --> + +# Bug + +## Error description +<!-- Provide a context when the bug / error arises --> + +## Error message +<!-- Provide the error log if available --> + +## First guess on error origin +<!-- Add first ideas where the error could come from --> + +## Error origin +<!-- Fill this up when the bug / error origin has been found --> + +## Solution +<!-- Short description how to solve the error --> + diff --git a/.gitlab/issue_templates/release.md b/.gitlab/issue_templates/release.md new file mode 100644 index 0000000000000000000000000000000000000000..618738d3184c68514fe32602af32188e001d228b --- /dev/null +++ b/.gitlab/issue_templates/release.md @@ -0,0 +1,40 @@ +<!-- Use this template for a new release of MLAir. --> + +# Release +<!-- add your release version here --> + +vX.Y.Z + +## checklist + +* [ ] Create Release Issue +* [ ] Create merge request: branch `release_vX.Y.Z` into `master` +* [ ] Merge `develop` into `release_vX.Y.Z` +* [ ] Checkout `release_vX.Y.Z` +* [ ] Adjust `changelog.md` (see template for changelog) +* [ ] Update version number in `mlair/__ init__.py` +* [ ] Create new dist file: `python3 setup.py sdist bdist_wheel` +* [ ] Update file link `distribution file (current version)` in `README.md` +* [ ] Update file link in `docs/_source/get-started.rst` +* [ ] Commit + push +* [ ] Merge `release_vX.Y.Z` into `master` +* [ ] Create new tag with + * [ ] distribution file (.whl) + * [ ] link to Documentation + * [ ] Example Jupyter Notebook + * [ ] changelog + + +## template for changelog +<!-- use this structure for the changelog. Link all issue to at least one item. --> + +``` +## vX.Y.Z - yyyy-mm-dd - <release description> + +### general: +* text +### new features: +* words (issue) +### technical: +* +``` \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index bf0c2b6b3ab672522630b28a1865e020b64ac86b..4f59375d8ee3c245e7d8008e7e8c6d6ff13b3d96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,38 @@ # Changelog All notable changes to this project will be documented in this file. 
+## v1.2.0 - 2020-12-18 - parallel preprocessing and improved data handlers + +### general: + + * new plots + * parallelism for faster preprocessing + * improved data handler with mixed sampling types + * enhanced test coverage + +### new features: + +* station map plot highlights now subsets on the map and displays number of stations for each subset (#227, #231) +* two new data availability plots `PlotAvailabilityHistogram` (#191, #192, #223) +* introduced parallel code in preprocessing if system supports parallelism (#164, #224, #225) +* data handler `DataHandlerMixedSampling` (and inheritances) supports an offset parameter to end inputs at a different time than 00 hours (#220) +* args for data handler `DataHandlerMixedSampling` (and inheritances) that differ for input and target can now be parsed as tuple (#229) + +### technical: + +* added templates for release and bug issues (#189) +* improved test coverage (#236, #238, #239, #240, #241, #242, #243, #244, #245) +* station map plot includes now number of stations for each subset (#231) +* postprocessing plots are encapsulated in try except statements (#107) +* updated git settings (#213) +* bug fix for data handler (#235) +* reordering and bug fix for preprocessing reporting (#207, #232) +* bug fix for outdated system path style (#226) +* new plots are included in default plot list (#211) +* `helpers/join` connection to ToarDB (e.g. used by DefaultDataHandler) reports now which variable could not be loaded (#222) +* plot `PlotBootstrapSkillScore` can now additionally highlight specific variables, but not included in postprocessing up to now (#201) +* data handler `DataHandlerMixedSampling` has now a reduced data loading (#221) + ## v1.1.0 - 2020-11-18 - hourly resolution support and new data handlers ### general: diff --git a/README.md b/README.md index 2e7b0cff48ba92143263c65c7a3fa82c139b86c8..c48b7cdb44b6f98a6a1f12a81c0a4717cc1e0d41 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ HPC systems, see [here](#special-instructions-for-installation-on-jülich-hpc-sy * Installation of **MLAir**: * Either clone MLAir from the [gitlab repository](https://gitlab.version.fz-juelich.de/toar/mlair.git) and use it without installation (beside the requirements) - * or download the distribution file ([current version](https://gitlab.version.fz-juelich.de/toar/mlair/-/blob/master/dist/mlair-1.1.0-py3-none-any.whl)) + * or download the distribution file ([current version](https://gitlab.version.fz-juelich.de/toar/mlair/-/blob/master/dist/mlair-1.2.0-py3-none-any.whl)) and install it via `pip install <dist_file>.whl`. In this case, you can simply import MLAir in any python script inside your virtual environment using `import mlair`. * (tf) Currently, TensorFlow-1.13 is mentioned in the requirements. 
We already tested the TensorFlow-1.15 version and couldn't diff --git a/conftest.py b/conftest.py index 207606e6ec111459302360f5f2c4f917771bf80d..08641ff36543dbfba7109f84616ead8d2b472891 100644 --- a/conftest.py +++ b/conftest.py @@ -1,6 +1,8 @@ import os import re import shutil +import pytest +import mock def pytest_runtest_teardown(item, nextitem): @@ -48,3 +50,18 @@ def remove_files_from_regex(list_dir, path, regex): shutil.rmtree(os.path.join(path, e), ignore_errors=True) except: pass + + +@pytest.fixture(scope="session", autouse=True) +def default_session_fixture(request): + """ + :type request: _pytest.python.SubRequest + :return: + """ + patched = mock.patch("multiprocessing.cpu_count", return_value=1) + patched.__enter__() + + def unpatch(): + patched.__exit__() + + request.addfinalizer(unpatch) diff --git a/dist/mlair-1.2.0-py3-none-any.whl b/dist/mlair-1.2.0-py3-none-any.whl new file mode 100644 index 0000000000000000000000000000000000000000..7b4c3eff904e45a20591e80eccd3f3720d3d339a Binary files /dev/null and b/dist/mlair-1.2.0-py3-none-any.whl differ diff --git a/docs/_source/_plots/data_availability_histogram_hist.png b/docs/_source/_plots/data_availability_histogram_hist.png new file mode 100644 index 0000000000000000000000000000000000000000..30f940145b5db8a6af1b8ce55c6737d9022da419 Binary files /dev/null and b/docs/_source/_plots/data_availability_histogram_hist.png differ diff --git a/docs/_source/_plots/data_availability_histogram_hist_cum.png b/docs/_source/_plots/data_availability_histogram_hist_cum.png new file mode 100644 index 0000000000000000000000000000000000000000..afb9796161739d060eebcd980258bc1bb1b9a14f Binary files /dev/null and b/docs/_source/_plots/data_availability_histogram_hist_cum.png differ diff --git a/docs/_source/_plots/skill_score_bootstrap_separated.png b/docs/_source/_plots/skill_score_bootstrap_separated.png new file mode 100644 index 0000000000000000000000000000000000000000..cddb0f005286e08617bf43fd55fed9b5862856d5 Binary files /dev/null and b/docs/_source/_plots/skill_score_bootstrap_separated.png differ diff --git a/docs/_source/get-started.rst b/docs/_source/get-started.rst index 477b4b89e5d56d1ec7a94301a4f9378dc1dce7dd..ede3cebfb7e1d9f673da3751c0cc2ab4dfba12ea 100644 --- a/docs/_source/get-started.rst +++ b/docs/_source/get-started.rst @@ -31,7 +31,7 @@ Installation of MLAir * Install all requirements from `requirements.txt <https://gitlab.version.fz-juelich.de/toar/machinelearningtools/-/blob/master/requirements.txt>`_ preferably in a virtual environment * Either clone MLAir from the `gitlab repository <https://gitlab.version.fz-juelich.de/toar/machinelearningtools.git>`_ -* or download the distribution file (`current version <https://gitlab.version.fz-juelich.de/toar/mlair/-/blob/master/dist/mlair-1.1.0-py3-none-any.whl>`_) +* or download the distribution file (`current version <https://gitlab.version.fz-juelich.de/toar/mlair/-/blob/master/dist/mlair-1.2.0-py3-none-any.whl>`_) and install it via :py:`pip install <dist_file>.whl`. In this case, you can simply import MLAir in any python script inside your virtual environment using :py:`import mlair`. * (tf) Currently, TensorFlow-1.13 is mentioned in the requirements. 
We already tested the TensorFlow-1.15 version and couldn't diff --git a/mlair/__init__.py b/mlair/__init__.py index 41b258eb7a0ef445718cb7c45cc01bbc3092cadc..e9a157ca5bba11b22e80df0f3f18092fb0f32db6 100644 --- a/mlair/__init__.py +++ b/mlair/__init__.py @@ -1,6 +1,6 @@ __version_info__ = { 'major': 1, - 'minor': 1, + 'minor': 2, 'micro': 0, } @@ -13,7 +13,7 @@ from mlair.model_modules import AbstractModelClass def get_version(): assert set(__version_info__.keys()) >= {"major", "minor"} vers = [f"{__version_info__['major']}.{__version_info__['minor']}"] - if "micro" in __version_info__: + if "micro" in __version_info__: # pragma: no branch vers.append(f".{__version_info__['micro']}") return "".join(vers) diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py index ce42fc0eed6e891bc0a0625666da3dccfcc8a3ee..1862d6734430d42a2d0cda0b199acef97b58bebb 100644 --- a/mlair/configuration/defaults.py +++ b/mlair/configuration/defaults.py @@ -48,7 +48,7 @@ DEFAULT_CREATE_NEW_BOOTSTRAPS = False DEFAULT_NUMBER_OF_BOOTSTRAPS = 20 DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries", "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "PlotConditionalQuantiles", - "PlotAvailability", "PlotSeparationOfScales"] + "PlotAvailability", "PlotAvailabilityHistogram", "PlotSeparationOfScales"] DEFAULT_SAMPLING = "daily" DEFAULT_DATA_ORIGIN = {"cloudcover": "REA", "humidity": "REA", "pblheight": "REA", "press": "REA", "relhum": "REA", "temp": "REA", "totprecip": "REA", "u": "REA", "v": "REA", "no": "", "no2": "", "o3": "", diff --git a/mlair/configuration/path_config.py b/mlair/configuration/path_config.py index bf40c361e121c409efec08b85fdf4e19848049ee..67c6bce4a3478443323b4ef49b5dc36258271ccd 100644 --- a/mlair/configuration/path_config.py +++ b/mlair/configuration/path_config.py @@ -29,11 +29,11 @@ def prepare_host(create_new=True, data_path=None, sampling="daily") -> str: user = getpass.getuser() runner_regex = re.compile(r"runner-.*-project-2411-concurrent-\d+") if hostname == "ZAM144": - data_path = f"/home/{user}/Data/toar_{sampling}/" + data_path = f"/home/{user}/Data/toar/" elif hostname == "zam347": - data_path = f"/home/{user}/Data/toar_{sampling}/" + data_path = f"/home/{user}/Data/toar/" elif (len(hostname) > 2) and (hostname[:2] == "jr"): - data_path = f"/p/project/cjjsc42/{user}/DATA/toar_{sampling}/" + data_path = f"/p/project/cjjsc42/{user}/DATA/toar/" elif (len(hostname) > 2) and (hostname[:2] in ['jw', 'ju'] or hostname[:5] in ['hdfml']): data_path = f"/p/project/deepacf/intelliaq/{user}/DATA/MLAIR/" elif runner_regex.match(hostname) is not None: diff --git a/mlair/data_handler/advanced_data_handler.py b/mlair/data_handler/advanced_data_handler.py deleted file mode 100644 index f04748e82f11116b265796afba7f401c1cad9342..0000000000000000000000000000000000000000 --- a/mlair/data_handler/advanced_data_handler.py +++ /dev/null @@ -1,112 +0,0 @@ - -__author__ = 'Lukas Leufen' -__date__ = '2020-07-08' - -import numpy as np -import xarray as xr -import os -import pandas as pd -import datetime as dt - -from mlair.data_handler import AbstractDataHandler - -from typing import Union, List, Tuple, Dict -import logging -from functools import reduce -from mlair.helpers.join import EmptyQueryResult -from mlair.helpers import TimeTracking - -number = Union[float, int] -num_or_list = Union[number, List[number]] - - -def run_data_prep(): - from .data_handler_neighbors import DataHandlerNeighbors - data = 
DummyDataHandler("main_class") - data.get_X() - data.get_Y() - - path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata") - data_prep = DataHandlerNeighbors(DummyDataHandler("main_class"), - path, - neighbors=[DummyDataHandler("neighbor1"), - DummyDataHandler("neighbor2")], - extreme_values=[1., 1.2]) - data_prep.get_data(upsampling=False) - - -def create_data_prep(): - from .data_handler_neighbors import DataHandlerNeighbors - path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata") - station_type = None - network = 'UBA' - sampling = 'daily' - target_dim = 'variables' - target_var = 'o3' - interpolation_dim = 'datetime' - window_history_size = 7 - window_lead_time = 3 - central_station = DataHandlerSingleStation("DEBW011", path, {'o3': 'dma8eu', 'temp': 'maximum'}, {}, station_type, network, sampling, target_dim, - target_var, interpolation_dim, window_history_size, window_lead_time) - neighbor1 = DataHandlerSingleStation("DEBW013", path, {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'}, {}, station_type, network, sampling, target_dim, - target_var, interpolation_dim, window_history_size, window_lead_time) - neighbor2 = DataHandlerSingleStation("DEBW034", path, {'o3': 'dma8eu', 'temp': 'maximum'}, {}, station_type, network, sampling, target_dim, - target_var, interpolation_dim, window_history_size, window_lead_time) - - data_prep = [] - data_prep.append(DataHandlerNeighbors(central_station, path, neighbors=[neighbor1, neighbor2])) - data_prep.append(DataHandlerNeighbors(neighbor1, path, neighbors=[central_station, neighbor2])) - data_prep.append(DataHandlerNeighbors(neighbor2, path, neighbors=[neighbor1, central_station])) - return data_prep - - -class DummyDataHandler(AbstractDataHandler): - - def __init__(self, name, number_of_samples=None): - """This data handler takes a name argument and the number of samples to generate. 
If not provided, a random - number between 100 and 150 is set.""" - super().__init__() - self.name = name - self.number_of_samples = number_of_samples if number_of_samples is not None else np.random.randint(100, 150) - self._X = self.create_X() - self._Y = self.create_Y() - - def create_X(self): - """Inputs are random numbers between 0 and 10 with shape (no_samples, window=14, variables=5).""" - X = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5)) # samples, window, variables - datelist = pd.date_range(dt.datetime.today().date(), periods=self.number_of_samples, freq="H").tolist() - return xr.DataArray(X, dims=['datetime', 'window', 'variables'], coords={"datetime": datelist, - "window": range(14), - "variables": range(5)}) - - def create_Y(self): - """Targets are normal distributed random numbers with shape (no_samples, window=5, variables=1).""" - Y = np.round(0.5 * np.random.randn(self.number_of_samples, 5, 1), 1) # samples, window, variables - datelist = pd.date_range(dt.datetime.today().date(), periods=self.number_of_samples, freq="H").tolist() - return xr.DataArray(Y, dims=['datetime', 'window', 'variables'], coords={"datetime": datelist, - "window": range(5), - "variables": range(1)}) - - def get_X(self, upsampling=False, as_numpy=False): - """Upsampling parameter is not used for X.""" - return np.copy(self._X) if as_numpy is True else self._X - - def get_Y(self, upsampling=False, as_numpy=False): - """Upsampling parameter is not used for Y.""" - return np.copy(self._Y) if as_numpy is True else self._Y - - def __str__(self): - return self.name - - -if __name__ == "__main__": - from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation - from mlair.data_handler.iterator import KerasIterator, DataCollection - data_prep = create_data_prep() - data_collection = DataCollection(data_prep) - for data in data_collection: - print(data) - path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata", "keras") - keras_it = KerasIterator(data_collection, 100, path, upsampling=True) - keras_it[2] - diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py index 80890b6f45dcde80aa75e9203a4a44ba25c7db01..19fc26fe78f4aaec034d6593e3b4628b85fc5644 100644 --- a/mlair/data_handler/data_handler_mixed_sampling.py +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -6,11 +6,12 @@ from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilterSingleS from mlair.data_handler import DefaultDataHandler from mlair import helpers from mlair.helpers import remove_items -from mlair.configuration.defaults import DEFAULT_SAMPLING +from mlair.configuration.defaults import DEFAULT_SAMPLING, DEFAULT_INTERPOLATION_LIMIT, DEFAULT_INTERPOLATION_METHOD import inspect from typing import Callable import datetime as dt +from typing import Any import numpy as np import pandas as pd @@ -20,11 +21,39 @@ import xarray as xr class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation): _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"]) - def __init__(self, *args, sampling_inputs, **kwargs): - sampling = (sampling_inputs, kwargs.get("sampling", DEFAULT_SAMPLING)) - kwargs.update({"sampling": sampling}) + def __init__(self, *args, **kwargs): + """ + This data handler requires the kwargs sampling, interpolation_limit, and interpolation_method to be a 2D tuple + for input and target data. 
If one of these kwargs is only a single argument, it will be applied to inputs and + targets with this value. If one of these kwargs is a 2-dim tuple, the first element is applied to inputs and the + second to targets respectively. If one of these kwargs is not provided, it is filled up with the same default + value for inputs and targets. + """ + self.update_kwargs("sampling", DEFAULT_SAMPLING, kwargs) + self.update_kwargs("interpolation_limit", DEFAULT_INTERPOLATION_LIMIT, kwargs) + self.update_kwargs("interpolation_method", DEFAULT_INTERPOLATION_METHOD, kwargs) super().__init__(*args, **kwargs) + @staticmethod + def update_kwargs(parameter_name: str, default: Any, kwargs: dict): + """ + Update a single element of kwargs inplace to be usable for inputs and targets. + + The updated value in the kwargs dictionary is a tuple consisting on the value applicable to the inputs as first + element and the target's value as second element: (<value_input>, <value_target>). If the value for the given + parameter_name is already a tuple, it is checked to have exact two entries. If the paramter_name is not + included in kwargs, the given default value is used and applied to both elements of the update tuple. + + :param parameter_name: name of the parameter that should be transformed to 2-dim + :param default: the default value to fill if parameter is not in kwargs + :param kwargs: the kwargs dictionary containing parameters + """ + parameter = kwargs.get(parameter_name, default) + if not isinstance(parameter, tuple): + parameter = (parameter, parameter) + assert len(parameter) == 2 # (inputs, targets) + kwargs.update({parameter_name: parameter}) + def setup_samples(self): """ Setup samples. This method prepares and creates samples X, and labels Y. @@ -36,11 +65,13 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation): self.make_samples() def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]: - data, self.meta = self.load_data(self.path[ind], self.station, self.statistics_per_var, self.sampling[ind], + vars = [self.variables, self.target_var] + stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars[ind]) + data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind], self.station_type, self.network, self.store_data_locally, self.data_origin, self.start, self.end) - data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, - limit=self.interpolation_limit) + data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind], + limit=self.interpolation_limit[ind]) return data def set_inputs_and_targets(self): @@ -110,11 +141,14 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi else: # target start, end = self.start, self.end - data, self.meta = self.load_data(self.path[ind], self.station, self.statistics_per_var, self.sampling[ind], + vars = [self.variables, self.target_var] + stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars[ind]) + + data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind], self.station_type, self.network, self.store_data_locally, self.data_origin, start, end) - data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, - limit=self.interpolation_limit) + data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind], + limit=self.interpolation_limit[ind]) return data diff --git 
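# Illustrative sketch (hypothetical values) of the kwarg normalization performed by
# DataHandlerMixedSamplingSingleStation.update_kwargs above: scalars are duplicated
# into an (inputs, targets) tuple, 2-tuples pass through unchanged, and missing keys
# fall back to the given default. The helper name below is a stand-in, not MLAir API.
from typing import Any

def normalize_kwarg(parameter_name: str, default: Any, kwargs: dict) -> None:
    parameter = kwargs.get(parameter_name, default)
    if not isinstance(parameter, tuple):
        parameter = (parameter, parameter)
    assert len(parameter) == 2  # (inputs, targets)
    kwargs.update({parameter_name: parameter})

kwargs = {"sampling": ("hourly", "daily"), "interpolation_limit": 72}
normalize_kwarg("sampling", "daily", kwargs)               # 2-tuple -> kept as is
normalize_kwarg("interpolation_limit", 0, kwargs)          # scalar -> (72, 72)
normalize_kwarg("interpolation_method", "linear", kwargs)  # missing -> ("linear", "linear")
assert kwargs["interpolation_method"] == ("linear", "linear")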
a/mlair/data_handler/data_handler_neighbors.py b/mlair/data_handler/data_handler_neighbors.py index a004e659969232a080d49eb6905007d353bbe99c..6c87946eaad5568e1ff59c3988bf8fe469442641 100644 --- a/mlair/data_handler/data_handler_neighbors.py +++ b/mlair/data_handler/data_handler_neighbors.py @@ -1,10 +1,15 @@ - __author__ = 'Lukas Leufen' __date__ = '2020-07-17' +import datetime as dt + +import numpy as np +import pandas as pd +import xarray as xr +from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation from mlair.helpers import to_list -from mlair.data_handler import DefaultDataHandler +from mlair.data_handler import DefaultDataHandler, AbstractDataHandler import os import copy @@ -43,8 +48,90 @@ class DataHandlerNeighbors(DefaultDataHandler): return [super(DataHandlerNeighbors, self).get_coordinates()].append(neighbors) -if __name__ == "__main__": +def run_data_prep(): + """Comment: methods just to start write meaningful test routines.""" + data = DummyDataHandler("main_class") + data.get_X() + data.get_Y() + + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata") + data_prep = DataHandlerNeighbors(DummyDataHandler("main_class"), + path, + neighbors=[DummyDataHandler("neighbor1"), + DummyDataHandler("neighbor2")], + extreme_values=[1., 1.2]) + data_prep.get_data(upsampling=False) + + +def create_data_prep(): + """Comment: methods just to start write meaningful test routines.""" + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata") + station_type = None + network = 'UBA' + sampling = 'daily' + target_dim = 'variables' + target_var = 'o3' + interpolation_dim = 'datetime' + window_history_size = 7 + window_lead_time = 3 + central_station = DataHandlerSingleStation("DEBW011", path, {'o3': 'dma8eu', 'temp': 'maximum'}, {}, station_type, + network, sampling, target_dim, + target_var, interpolation_dim, window_history_size, window_lead_time) + neighbor1 = DataHandlerSingleStation("DEBW013", path, {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'}, {}, + station_type, network, sampling, target_dim, + target_var, interpolation_dim, window_history_size, window_lead_time) + neighbor2 = DataHandlerSingleStation("DEBW034", path, {'o3': 'dma8eu', 'temp': 'maximum'}, {}, station_type, + network, sampling, target_dim, + target_var, interpolation_dim, window_history_size, window_lead_time) + + data_prep = [] + data_prep.append(DataHandlerNeighbors(central_station, path, neighbors=[neighbor1, neighbor2])) + data_prep.append(DataHandlerNeighbors(neighbor1, path, neighbors=[central_station, neighbor2])) + data_prep.append(DataHandlerNeighbors(neighbor2, path, neighbors=[neighbor1, central_station])) + return data_prep + + +class DummyDataHandler(AbstractDataHandler): + + def __init__(self, name, number_of_samples=None): + """This data handler takes a name argument and the number of samples to generate. 
If not provided, a random + number between 100 and 150 is set.""" + super().__init__() + self.name = name + self.number_of_samples = number_of_samples if number_of_samples is not None else np.random.randint(100, 150) + self._X = self.create_X() + self._Y = self.create_Y() + def create_X(self): + """Inputs are random numbers between 0 and 10 with shape (no_samples, window=14, variables=5).""" + X = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5)) # samples, window, variables + datelist = pd.date_range(dt.datetime.today().date(), periods=self.number_of_samples, freq="H").tolist() + return xr.DataArray(X, dims=['datetime', 'window', 'variables'], coords={"datetime": datelist, + "window": range(14), + "variables": range(5)}) + + def create_Y(self): + """Targets are normal distributed random numbers with shape (no_samples, window=5, variables=1).""" + Y = np.round(0.5 * np.random.randn(self.number_of_samples, 5, 1), 1) # samples, window, variables + datelist = pd.date_range(dt.datetime.today().date(), periods=self.number_of_samples, freq="H").tolist() + return xr.DataArray(Y, dims=['datetime', 'window', 'variables'], coords={"datetime": datelist, + "window": range(5), + "variables": range(1)}) + + def get_X(self, upsampling=False, as_numpy=False): + """Upsampling parameter is not used for X.""" + return np.copy(self._X) if as_numpy is True else self._X + + def get_Y(self, upsampling=False, as_numpy=False): + """Upsampling parameter is not used for Y.""" + return np.copy(self._Y) if as_numpy is True else self._Y + + def __str__(self): + return self.name + + +if __name__ == "__main__": + """Comment: This is more for testing. Maybe reuse parts of this code for the testing routines.""" a = DataHandlerNeighbors requirements = a.requirements() @@ -59,7 +146,17 @@ if __name__ == "__main__": "window_lead_time": 3, "neighbors": ["DEBW034"], "data_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata"), - "statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'}, - "transformation": None,} + "statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'}, + "transformation": None, } a_inst = a.build("DEBW011", **kwargs) print(a_inst) + + from mlair.data_handler.iterator import KerasIterator, DataCollection + + data_prep = create_data_prep() + data_collection = DataCollection(data_prep) + for data in data_collection: + print(data) + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata", "keras") + keras_it = KerasIterator(data_collection, 100, path, upsampling=True) + keras_it[2] diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py index e554a3b32d8e4e2f5482a388374cfba87f7add15..654f489fab8ee6ed8eb360be54be7c755da061e1 100644 --- a/mlair/data_handler/data_handler_single_station.py +++ b/mlair/data_handler/data_handler_single_station.py @@ -34,20 +34,24 @@ DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'max 'pblheight': 'maximum'} DEFAULT_WINDOW_LEAD_TIME = 3 DEFAULT_WINDOW_HISTORY_SIZE = 13 +DEFAULT_WINDOW_HISTORY_OFFSET = 0 DEFAULT_TIME_DIM = "datetime" DEFAULT_TARGET_VAR = "o3" DEFAULT_TARGET_DIM = "variables" DEFAULT_SAMPLING = "daily" +DEFAULT_INTERPOLATION_LIMIT = 0 DEFAULT_INTERPOLATION_METHOD = "linear" class DataHandlerSingleStation(AbstractDataHandler): def __init__(self, station, data_path, statistics_per_var, station_type=DEFAULT_STATION_TYPE, - network=DEFAULT_NETWORK, sampling=DEFAULT_SAMPLING, target_dim=DEFAULT_TARGET_DIM, - 
target_var=DEFAULT_TARGET_VAR, time_dim=DEFAULT_TIME_DIM, - window_history_size=DEFAULT_WINDOW_HISTORY_SIZE, window_lead_time=DEFAULT_WINDOW_LEAD_TIME, - interpolation_limit: int = 0, interpolation_method: str = DEFAULT_INTERPOLATION_METHOD, + network=DEFAULT_NETWORK, sampling: Union[str, Tuple[str]] = DEFAULT_SAMPLING, + target_dim=DEFAULT_TARGET_DIM, target_var=DEFAULT_TARGET_VAR, time_dim=DEFAULT_TIME_DIM, + window_history_size=DEFAULT_WINDOW_HISTORY_SIZE, window_history_offset=DEFAULT_WINDOW_HISTORY_OFFSET, + window_lead_time=DEFAULT_WINDOW_LEAD_TIME, + interpolation_limit: Union[int, Tuple[int]] = DEFAULT_INTERPOLATION_LIMIT, + interpolation_method: Union[str, Tuple[str]] = DEFAULT_INTERPOLATION_METHOD, overwrite_local_data: bool = False, transformation=None, store_data_locally: bool = True, min_length: int = 0, start=None, end=None, variables=None, data_origin: Dict = None, **kwargs): super().__init__() # path, station, statistics_per_var, transformation, **kwargs) @@ -65,6 +69,7 @@ class DataHandlerSingleStation(AbstractDataHandler): self.target_var = target_var self.time_dim = time_dim self.window_history_size = window_history_size + self.window_history_offset = window_history_offset self.window_lead_time = window_lead_time self.interpolation_limit = interpolation_limit @@ -271,20 +276,23 @@ class DataHandlerSingleStation(AbstractDataHandler): chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5", "propane", "so2", "toluene"] # used_chem_vars = list(set(chem_vars) & set(self.statistics_per_var.keys())) - used_chem_vars = list(set(chem_vars) & set(self.variables)) + used_chem_vars = list(set(chem_vars) & set(data.variables.values)) data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum) return data def setup_data_path(self, data_path: str, sampling: str): return os.path.join(os.path.abspath(data_path), sampling) - def shift(self, data: xr.DataArray, dim: str, window: int) -> xr.DataArray: + def shift(self, data: xr.DataArray, dim: str, window: int, offset: int = 0) -> xr.DataArray: """ Shift data multiple times to represent history (if window <= 0) or lead time (if window > 0). :param data: data set to shift :param dim: dimension along shift is applied :param window: number of steps to shift (corresponds to the window length) + :param offset: use offset to move the window by as many time steps as given in offset. This can be used, if the + index time of a history element is not the last timestamp. E.g. you could use offset=23 when dealing with + hourly data in combination with daily data (values from 00 to 23 are aggregated on 00 the same day). 
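# Small numeric sketch (hypothetical window/offset) of how the offset described above
# moves the history window built by shift(): every shift step is increased by `offset`,
# so with hourly inputs an offset of 23 lets the window end at hour 23 instead of hour 00.
window, offset = -3, 23
start, end = (window, 1) if window <= 0 else (1, window + 1)
_range = list(map(lambda x: x + offset, range(start, end)))
print(_range)  # [20, 21, 22, 23] -> last history step is hour 23 of the same day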
:return: shifted data """ @@ -295,9 +303,10 @@ class DataHandlerSingleStation(AbstractDataHandler): else: end = window + 1 res = [] - for w in range(start, end): + _range = list(map(lambda x: x + offset, range(start, end))) + for w in _range: res.append(data.shift({dim: -w})) - window_array = self.create_index_array('window', range(start, end), squeeze_dim=self.target_dim) + window_array = self.create_index_array('window', _range, squeeze_dim=self.target_dim) res = xr.concat(res, dim=window_array) return res @@ -387,7 +396,7 @@ class DataHandlerSingleStation(AbstractDataHandler): """ window = -abs(window) data = self.input_data.data - self.history = self.shift(data, dim_name_of_shift, window) + self.history = self.shift(data, dim_name_of_shift, window, offset=self.window_history_offset) def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str, window: int) -> None: diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py index 584151e36fd0c9621d089e88b8ad61cffa0c5925..291bbc6616314db61282c380a6b3e105d8b6248a 100644 --- a/mlair/data_handler/default_data_handler.py +++ b/mlair/data_handler/default_data_handler.py @@ -11,6 +11,7 @@ import pickle import shutil from functools import reduce from typing import Tuple, Union, List +import multiprocessing import numpy as np import xarray as xr @@ -251,14 +252,29 @@ class DefaultDataHandler(AbstractDataHandler): return means = [None, None] stds = [None, None] - for station in set_stations: - try: - sp = cls.data_handler_transformation(station, **sp_keys) - for i, data in enumerate([sp.input_data, sp.target_data]): - means[i] = data.mean.copy(deep=True) if means[i] is None else means[i].combine_first(data.mean) - stds[i] = data.std.copy(deep=True) if stds[i] is None else stds[i].combine_first(data.std) - except (AttributeError, EmptyQueryResult): - continue + + if multiprocessing.cpu_count() > 1: # parallel solution + logging.info("use parallel transformation approach") + pool = multiprocessing.Pool() + logging.info(f"running {getattr(pool, '_processes')} processes in parallel") + output = [ + pool.apply_async(f_proc, args=(cls.data_handler_transformation, station), kwds=sp_keys) + for station in set_stations] + for p in output: + dh, s = p.get() + if dh is not None: + for i, data in enumerate([dh.input_data, dh.target_data]): + means[i] = data.mean.copy(deep=True) if means[i] is None else means[i].combine_first(data.mean) + stds[i] = data.std.copy(deep=True) if stds[i] is None else stds[i].combine_first(data.std) + else: # serial solution + logging.info("use serial transformation approach") + for station in set_stations: + dh, s = f_proc(cls.data_handler_transformation, station, **sp_keys) + if dh is not None: + for i, data in enumerate([dh.input_data, dh.target_data]): + means[i] = data.mean.copy(deep=True) if means[i] is None else means[i].combine_first(data.mean) + stds[i] = data.std.copy(deep=True) if stds[i] is None else stds[i].combine_first(data.std) + if means[0] is None: return None transformation_class.inputs.mean = means[0].mean("Stations") @@ -268,4 +284,18 @@ class DefaultDataHandler(AbstractDataHandler): return transformation_class def get_coordinates(self): - return self.id_class.get_coordinates() \ No newline at end of file + return self.id_class.get_coordinates() + + +def f_proc(data_handler, station, **sp_keys): + """ + Try to create a data handler for given arguments. 
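# Minimal sketch of the multiprocessing pattern used in the transformation setup above:
# a module-level worker (so it can be pickled), one apply_async call per station, and
# results collected via get(). The worker body and station ids below are placeholders,
# not MLAir code.
import multiprocessing

def worker(station, **kwargs):
    return station.upper(), station  # stands in for building a data handler

if __name__ == "__main__":
    stations = ["DEBW011", "DEBW013"]
    if multiprocessing.cpu_count() > 1:  # parallel branch
        pool = multiprocessing.Pool()
        output = [pool.apply_async(worker, args=(s,)) for s in stations]
        results = [p.get() for p in output]
        pool.close()
    else:  # serial fallback
        results = [worker(s) for s in stations]
    print(results)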
If build fails, this station does not fulfil all requirements and + therefore f_proc will return None as indication. On a successful build, f_proc returns the built data handler and + the station that was used. This function must be implemented globally to work together with multiprocessing. + """ + try: + res = data_handler(station, **sp_keys) + except (AttributeError, EmptyQueryResult, KeyError, ValueError) as e: + logging.info(f"remove station {station} because it raised an error: {e}") + res = None + return res, station diff --git a/mlair/data_handler/iterator.py b/mlair/data_handler/iterator.py index 49569405a587920da795820d48f8d968a8142cc7..30c45417a64e949b0c0535a96a20c933641fdcbb 100644 --- a/mlair/data_handler/iterator.py +++ b/mlair/data_handler/iterator.py @@ -33,13 +33,18 @@ class StandardIterator(Iterator): class DataCollection(Iterable): - def __init__(self, collection: list = None): + def __init__(self, collection: list = None, name: str = None): if collection is None: collection = [] assert isinstance(collection, list) - self._collection = collection + self._collection = collection.copy() self._mapping = {} self._set_mapping() + self._name = name + + @property + def name(self): + return self._name def __len__(self): return len(self._collection) @@ -55,7 +60,7 @@ class DataCollection(Iterable): def add(self, element): self._collection.append(element) - self._mapping[str(element)] = len(self._collection) + self._mapping[str(element)] = len(self._collection) - 1 def _set_mapping(self): for i, e in enumerate(self._collection): @@ -114,9 +119,10 @@ class KerasIterator(keras.utils.Sequence): def _get_batch(self, data_list: List[np.ndarray], b: int) -> List[np.ndarray]: """Get batch according to batch size from data list.""" - return list(map(lambda data: data[b * self.batch_size:(b+1) * self.batch_size, ...], data_list)) + return list(map(lambda data: data[b * self.batch_size:(b + 1) * self.batch_size, ...], data_list)) - def _permute_data(self, X, Y): + @staticmethod + def _permute_data(X, Y): p = np.random.permutation(len(X[0])) # equiv to .shape[0] X = list(map(lambda x: x[p], X)) Y = list(map(lambda x: x[p], Y)) @@ -179,35 +185,3 @@ class KerasIterator(keras.utils.Sequence): """Randomly shuffle indexes if enabled.""" if self.shuffle is True: np.random.shuffle(self.indexes) - - -class DummyData: # pragma: no cover - - def __init__(self, number_of_samples=np.random.randint(100, 150)): - self.number_of_samples = number_of_samples - - def get_X(self): - X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5)) # samples, window, variables - X2 = np.random.randint(21, 30, size=(self.number_of_samples, 10, 2)) # samples, window, variables - X3 = np.random.randint(-5, 0, size=(self.number_of_samples, 1, 2)) # samples, window, variables - return [X1, X2, X3] - - def get_Y(self): - Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5, 1)) # samples, window, variables - Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 5, 1)) # samples, window, variables - return [Y1, Y2] - - -if __name__ == "__main__": - - collection = [] - for _ in range(3): - collection.append(DummyData(50)) - - data_collection = DataCollection(collection=collection) - - path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata") - iterator = KerasIterator(data_collection, 25, path, shuffle=True) - - for data in data_collection: - print(data) \ No newline at end of file diff --git a/mlair/helpers/__init__.py b/mlair/helpers/__init__.py index 
9e2f612c86dc0477693567210493fbdcf3002954..4671334c16267be819ab8ee0ad96b7135ee01531 100644 --- a/mlair/helpers/__init__.py +++ b/mlair/helpers/__init__.py @@ -3,4 +3,4 @@ from .testing import PyTestRegex, PyTestAllEqual from .time_tracking import TimeTracking, TimeTrackingWrapper from .logger import Logger -from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value +from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value, select_from_dict diff --git a/mlair/helpers/datastore.py b/mlair/helpers/datastore.py index b4615216000d887f16e6ed30d97215a261e12c6d..d6c977c717c5ef869fdba517fb36fcd55cfe3961 100644 --- a/mlair/helpers/datastore.py +++ b/mlair/helpers/datastore.py @@ -65,7 +65,7 @@ class CorrectScope: return self.wrapper(*args, **kwargs) def __get__(self, instance, cls): - """Create bound method object and supply self argument to the decorated method.""" + """Create bound method object and supply self argument to the decorated method. <Python Cookbook, p.347>""" return types.MethodType(self, instance) @staticmethod @@ -101,6 +101,7 @@ class CorrectScope: class TrackParameter: + """Hint: Tracking is not working for static methods.""" def __init__(self, func): """Construct decorator.""" @@ -114,7 +115,7 @@ class TrackParameter: return self.__wrapped__(*args, **kwargs) def __get__(self, instance, cls): - """Create bound method object and supply self argument to the decorated method.""" + """Create bound method object and supply self argument to the decorated method. <Python Cookbook, p.347>""" return types.MethodType(self, instance) def track(self, tracker_obj, *args): @@ -312,7 +313,7 @@ class DataStoreByVariable(AbstractDataStore): if name not in self._store.keys(): self._store[name] = {} self._store[name][scope] = obj - if log: + if log: # pragma: no cover logging.debug(f"set: {name}({scope})={obj}") @CorrectScope @@ -463,7 +464,7 @@ class DataStoreByScope(AbstractDataStore): if scope not in self._store.keys(): self._store[scope] = {} self._store[scope][name] = obj - if log: + if log: # pragma: no cover logging.debug(f"set: {name}({scope})={obj}") @CorrectScope diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py index 3ecf1f6213bf39d2e3571a1b451173b981a3dadf..42b66dcb68b184112a321473e3aae250d697c452 100644 --- a/mlair/helpers/helpers.py +++ b/mlair/helpers/helpers.py @@ -12,13 +12,15 @@ from typing import Dict, Callable, Union, List, Any def to_list(obj: Any) -> List: """ - Transform given object to list if obj is not already a list. + Transform given object to list if obj is not already a list. Sets are also transformed to a list. :param obj: object to transform to list :return: list containing obj, or obj itself (if obj was already a list) """ - if not isinstance(obj, list): + if isinstance(obj, (set, tuple)): + obj = list(obj) + elif not isinstance(obj, list): obj = [obj] return obj @@ -99,8 +101,27 @@ def remove_items(obj: Union[List, Dict], items: Any): raise TypeError(f"{inspect.stack()[0][3]} does not support type {type(obj)}.") +def select_from_dict(dict_obj: dict, sel_list: Any): + """ + Extract all key values pairs whose key is contained in the sel_list. + + Does not perform a check if all elements of sel_list are keys of dict_obj. Therefore the number of pairs in the + returned dict is always smaller or equal to the number of elements in the sel_list. 
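# Hedged usage sketch for the helpers touched in this hunk (values are hypothetical):
# select_from_dict keeps only pairs whose key is in the selection and silently ignores
# unknown keys; to_list now also converts sets and tuples instead of wrapping them.
from mlair.helpers import select_from_dict, to_list

stats_per_var = {"o3": "dma8eu", "temp": "maximum", "relhum": "average_values"}
assert select_from_dict(stats_per_var, ["o3", "temp", "press"]) == {"o3": "dma8eu", "temp": "maximum"}
assert to_list(("o3", "temp")) == ["o3", "temp"] and to_list("o3") == ["o3"]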
+ """ + sel_list = to_list(sel_list) + assert isinstance(dict_obj, dict) + sel_dict = {k: v for k, v in dict_obj.items() if k in sel_list} + return sel_dict + + def extract_value(encapsulated_value): try: - return extract_value(encapsulated_value[0]) + if isinstance(encapsulated_value, str): + raise TypeError + if len(encapsulated_value) == 1: + return extract_value(encapsulated_value[0]) + else: + raise NotImplementedError("Trying to extract an encapsulated value from objects with more than a single " + "entry is not supported by this function.") except TypeError: return encapsulated_value diff --git a/mlair/helpers/join.py b/mlair/helpers/join.py index 43a0176811b54fba2983c1dba108f4c7977f1431..8a8ca0b8c964268aa6043312cd1cc88bc0d50544 100644 --- a/mlair/helpers/join.py +++ b/mlair/helpers/join.py @@ -45,7 +45,15 @@ def download_join(station_name: Union[str, List[str]], stat_var: dict, station_t join_url_base, headers = join_settings(sampling) # load series information - vars_dict = load_series_information(station_name, station_type, network_name, join_url_base, headers, data_origin) + vars_dict, data_origin = load_series_information(station_name, station_type, network_name, join_url_base, headers, + data_origin) + + # check if all requested variables are available + if set(stat_var).issubset(vars_dict) is False: + missing_variables = set(stat_var).difference(vars_dict) + origin = helpers.select_from_dict(data_origin, missing_variables) + options = f"station={station_name}, type={station_type}, network={network_name}, origin={origin}" + raise EmptyQueryResult(f"No data found for variables {missing_variables} and options {options} in JOIN.") # correct stat_var values if data is not aggregated (hourly) if sampling == "hourly": @@ -58,11 +66,11 @@ def download_join(station_name: Union[str, List[str]], stat_var: dict, station_t for var in _lower_list(sorted(vars_dict.keys())): if var in stat_var.keys(): - logging.debug('load: {}'.format(var)) # ToDo start here for #206 + logging.debug('load: {}'.format(var)) # create data link opts = {'base': join_url_base, 'service': 'stats', 'id': vars_dict[var], 'statistics': stat_var[var], - 'sampling': sampling, 'capture': 0, 'min_data_length': 1460, 'format': 'json'} + 'sampling': sampling, 'capture': 0, 'format': 'json'} # load data data = get_data(opts, headers) @@ -122,11 +130,14 @@ def get_data(opts: Dict, headers: Dict) -> Union[Dict, List]: """ url = create_url(**opts) response = requests.get(url, headers=headers) - return response.json() + if response.status_code == 200: + return response.json() + else: + raise EmptyQueryResult(f"There was an error (STATUS {response.status_code}) for request {url}") def load_series_information(station_name: List[str], station_type: str_or_none, network_name: str_or_none, - join_url_base: str, headers: Dict, data_origin: Dict = None) -> Dict: + join_url_base: str, headers: Dict, data_origin: Dict = None) -> [Dict, Dict]: """ List all series ids that are available for given station id and network name. 
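# Hedged sketch of the revised extract_value behaviour above (inputs are hypothetical):
# single-element containers are unwrapped recursively, strings and scalars pass through
# unchanged, and containers with more than one entry now raise instead of being unwrapped.
from mlair.helpers import extract_value

assert extract_value([[23]]) == 23
assert extract_value("abc") == "abc"
try:
    extract_value([1, 2])
except NotImplementedError:
    pass  # more than a single entry is not supported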
@@ -144,27 +155,30 @@ def load_series_information(station_name: List[str], station_type: str_or_none, "network_name": network_name, "as_dict": "true", "columns": "id,network_name,station_id,parameter_name,parameter_label,parameter_attribute"} station_vars = get_data(opts, headers) - logging.debug(f"{station_name}: {station_vars}") # ToDo start here for #206 + logging.debug(f"{station_name}: {station_vars}") return _select_distinct_series(station_vars, data_origin) -def _select_distinct_series(vars: List[Dict], data_origin: Dict = None): +def _select_distinct_series(vars: List[Dict], data_origin: Dict = None) -> [Dict, Dict]: """ Select distinct series ids for all variables. Also check if a parameter is from REA or not. """ + data_origin_default = {"cloudcover": "REA", "humidity": "REA", "pblheight": "REA", "press": "REA", "relhum": "REA", + "temp": "REA", "totprecip": "REA", "u": "REA", "v": "REA", + "no": "", "no2": "", "o3": "", "pm10": "", "so2": ""} if data_origin is None: - data_origin = {"cloudcover": "REA", "humidity": "REA", "pblheight": "REA", "press": "REA", "relhum": "REA", - "temp": "REA", "totprecip": "REA", "u": "REA", "v": "REA", - "no": "", "no2": "", "o3": "", "pm10": "", "so2": ""} + data_origin = {} # ToDo: maybe press, wdir, wspeed from obs? or also temp, ... ? selected = {} for var in vars: name = var["parameter_name"].lower() var_attr = var["parameter_attribute"].lower() + if name not in data_origin.keys(): + data_origin.update({name: data_origin_default.get(name, "")}) attr = data_origin.get(name, "").lower() if var_attr == attr: selected[name] = var["id"] - return selected + return selected, data_origin def _save_to_pandas(df: Union[pd.DataFrame, None], data: dict, stat: str, var: str) -> pd.DataFrame: diff --git a/mlair/helpers/testing.py b/mlair/helpers/testing.py index 244eb69fdc46dcadaeb3ada5779f09d44aa83e2a..abb50883c7af49a0c1571d99f737e310abff9b13 100644 --- a/mlair/helpers/testing.py +++ b/mlair/helpers/testing.py @@ -35,54 +35,54 @@ class PyTestRegex: return self._regex.pattern -class PyTestAllEqual: - """ - Check if all elements in list are the same. - - :param check_list: list with elements to check - """ - - def __init__(self, check_list: List): - """Construct class.""" - self._list = check_list - self._test_function = None - - def _set_test_function(self): - if isinstance(self._list[0], np.ndarray): - self._test_function = np.testing.assert_array_equal - else: - self._test_function = xr.testing.assert_equal - - def _check_all_equal(self) -> bool: - """ - Check if all elements are equal. - - :return boolean if elements are equal +def PyTestAllEqual(check_list: List): + class PyTestAllEqualClass: """ - equal = True - self._set_test_function() - for b in self._list: - equal *= self._test_function(self._list[0], b) is None - return bool(equal == 1) + Check if all elements in list are the same. - def is_true(self) -> bool: + :param check_list: list with elements to check """ - Start equality check. - :return: true if equality test is passed, false otherwise - """ - return self._check_all_equal() - - -def xr_all_equal(check_list: List) -> bool: - """ - Check if all given elements (preferably xarray's) in list are equal. 
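# Sketch of the data_origin completion performed in _select_distinct_series above:
# user-supplied origins win, anything missing is filled from the REA/"" defaults, and
# the completed mapping is returned alongside the selected series ids. Values below
# are hypothetical.
data_origin_default = {"cloudcover": "REA", "temp": "REA", "o3": "", "no2": ""}
data_origin = {"temp": ""}                       # caller explicitly asks for observed temp
for name in ["temp", "o3", "press"]:             # variables reported by JOIN
    if name not in data_origin.keys():
        data_origin.update({name: data_origin_default.get(name, "")})
print(data_origin)  # {'temp': '', 'o3': '', 'press': ''}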
- - :param check_list: list with elements to check - - :return: boolean if all elements are the same or not - """ - equal = True - for b in check_list: - equal *= xr.testing.assert_equal(check_list[0], b) is None - return equal == 1 \ No newline at end of file + def __init__(self, check_list: List): + """Construct class.""" + self._list = check_list + self._test_function = None + + def _set_test_function(self, _list): + if isinstance(_list[0], list): + _test_function = self._set_test_function(_list[0]) + self._test_function = lambda r, s: all(map(lambda x, y: _test_function(x, y) is None, r, s)) + elif isinstance(_list[0], np.ndarray): + self._test_function = np.testing.assert_array_equal + elif isinstance(_list[0], xr.DataArray): + self._test_function = xr.testing.assert_equal + else: + self._test_function = lambda x, y: self._assert(x, y) + # raise TypeError(f"given type {type(_list[0])} is not supported by PyTestAllEqual.") + return self._test_function + + @staticmethod + def _assert(x, y): + assert x == y + + def _check_all_equal(self) -> bool: + """ + Check if all elements are equal. + + :return boolean if elements are equal + """ + equal = True + self._set_test_function(self._list) + for b in self._list: + equal *= self._test_function(self._list[0], b) in [None, True] + return bool(equal == 1) + + def is_true(self) -> bool: + """ + Start equality check. + + :return: true if equality test is passed, false otherwise + """ + return self._check_all_equal() + + return PyTestAllEqualClass(check_list).is_true() diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index f0b6baeb0b56126ccccb80c9da993fb406428d93..f775f7419dba7530d8dbfbde9250d38c312496fa 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -1,6 +1,6 @@ """Collection of plots to evaluate a model, create overviews on data or forecasts.""" __author__ = "Lukas Leufen, Felix Kleinert" -__date__ = '2019-12-17' +__date__ = '2020-11-23' import logging import math @@ -8,10 +8,11 @@ import os import warnings from typing import Dict, List, Tuple - import matplotlib import matplotlib.patches as mpatches +import matplotlib.lines as mlines import matplotlib.pyplot as plt +import matplotlib.dates as mdates import numpy as np import pandas as pd import seaborn as sns @@ -75,7 +76,7 @@ class AbstractPlotClass: """ - def __init__(self, plot_folder, plot_name, resolution=500): + def __init__(self, plot_folder, plot_name, resolution=500, rc_params=None): """Set up plot folder and name, and plot resolution (default 500dpi).""" plot_folder = os.path.abspath(plot_folder) if not os.path.exists(plot_folder): @@ -83,6 +84,15 @@ class AbstractPlotClass: self.plot_folder = plot_folder self.plot_name = plot_name self.resolution = resolution + if rc_params is None: + rc_params = {'axes.labelsize': 'large', + 'xtick.labelsize': 'large', + 'ytick.labelsize': 'large', + 'legend.fontsize': 'large', + 'axes.titlesize': 'large', + } + self.rc_params = rc_params + self._update_rc_params() def _plot(self, *args): """Abstract plot class needs to be implemented in inheritance.""" @@ -95,6 +105,25 @@ class AbstractPlotClass: plt.savefig(plot_name, dpi=self.resolution, **kwargs) plt.close('all') + def _update_rc_params(self): + plt.rcParams.update(self.rc_params) + + @staticmethod + def _get_sampling(sampling): + if sampling == "daily": + return "D" + elif sampling == "hourly": + return "h" + + @staticmethod + def get_dataset_colors(): + """ + Standard colors used for 
train-, val-, and test-sets during postprocessing + """ + colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9", "train_val": "#000000"} # hex code + return colors + + @TimeTrackingWrapper class PlotMonthlySummary(AbstractPlotClass): @@ -113,18 +142,19 @@ class PlotMonthlySummary(AbstractPlotClass): :param window_lead_time: lead time to plot, if window_lead_time is higher than the available lead time or not given the maximum lead time from data is used. (default None -> use maximum lead time from data). :param plot_folder: path to save the plot (default: current directory) + :param target_var_unit: unit of target var for plot legend (default= ppb) """ def __init__(self, stations: List, data_path: str, name: str, target_var: str, window_lead_time: int = None, - plot_folder: str = "."): + plot_folder: str = ".", target_var_unit: str = 'ppb'): """Set attributes and create plot.""" super().__init__(plot_folder, "monthly_summary_box_plot") self._data_path = data_path self._data_name = name self._data = self._prepare_data(stations) self._window_lead_time = self._get_window_lead_time(window_lead_time) - self._plot(target_var) + self._plot(target_var, target_var_unit) self._save() def _prepare_data(self, stations: List) -> xr.DataArray: @@ -176,7 +206,12 @@ class PlotMonthlySummary(AbstractPlotClass): window_lead_time = ahead_steps return min(ahead_steps, window_lead_time) - def _plot(self, target_var: str): + @staticmethod + def _spell_out_chemical_concentrations(short_name: str): + short2long = {'o3': 'ozone', 'no': 'nitrogen oxide', 'no2': 'nitrogen dioxide', 'nox': 'nitrogen dioxides'} + return f"{short2long[short_name]} concentration" + + def _plot(self, target_var: str, target_var_unit: str): """ Create a monthly grouped box plot over all stations but with separate boxes for each lead time step. @@ -189,7 +224,8 @@ class PlotMonthlySummary(AbstractPlotClass): ax = sns.boxplot(x='index', y='values', hue='ahead', data=data.compute(), whis=1., palette=color_palette, flierprops={'marker': '.', 'markersize': 1}, showmeans=True, meanprops={'markersize': 1, 'markeredgecolor': 'k'}) - ax.set(xlabel='month', ylabel=f'{target_var}') + ylabel = self._spell_out_chemical_concentrations(target_var) + ax.set(xlabel='month', ylabel=f'{ylabel} (in {target_var_unit})') plt.tight_layout() @@ -206,7 +242,7 @@ class PlotStationMap(AbstractPlotClass): :width: 400 """ - def __init__(self, generators: Dict, plot_folder: str = "."): + def __init__(self, generators: List, plot_folder: str = ".", plot_name="station_map"): """ Set attributes and create plot. @@ -214,11 +250,11 @@ class PlotStationMap(AbstractPlotClass): as value. 
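# Sketch of the new y-axis labelling in PlotMonthlySummary above; target_var and
# target_var_unit are hypothetical example values.
short2long = {'o3': 'ozone', 'no': 'nitrogen oxide', 'no2': 'nitrogen dioxide', 'nox': 'nitrogen dioxides'}
target_var, target_var_unit = 'o3', 'ppb'
ylabel = f"{short2long[target_var]} concentration"
print(f"{ylabel} (in {target_var_unit})")  # -> "ozone concentration (in ppb)"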
:param plot_folder: path to save the plot (default: current directory) """ - super().__init__(plot_folder, "station_map") + super().__init__(plot_folder, plot_name) self._ax = None self._gl = None self._plot(generators) - self._save() + self._save(bbox_inches="tight") def _draw_background(self): """Draw coastline, lakes, ocean, rivers and country borders as background on the map.""" @@ -245,13 +281,44 @@ class PlotStationMap(AbstractPlotClass): import cartopy.crs as ccrs if generators is not None: - for color, data_collection in generators.items(): + legend_elements = [] + default_colors = self.get_dataset_colors() + for element in generators: + data_collection, plot_opts = self._get_collection_and_opts(element) + name = data_collection.name or "unknown" + marker = plot_opts.get("marker", "s") + ms = plot_opts.get("ms", 6) + mec = plot_opts.get("mec", "k") + mfc = plot_opts.get("mfc", default_colors.get(name, "b")) + legend_elements.append( + mlines.Line2D([], [], mfc=mfc, mec=mec, marker=self._adjust_marker(marker), ms=ms, linestyle='None', + label=f"{name} ({len(data_collection)})")) for station in data_collection: coords = station.get_coordinates() IDx, IDy = coords["lon"], coords["lat"] - self._ax.plot(IDx, IDy, mfc=color, mec='k', marker='s', markersize=6, transform=ccrs.PlateCarree()) + self._ax.plot(IDx, IDy, mfc=mfc, mec=mec, marker=marker, ms=ms, transform=ccrs.PlateCarree()) + if len(legend_elements) > 0: + self._ax.legend(handles=legend_elements, loc='best') + + @staticmethod + def _adjust_marker(marker): + _adjust = {4: "<", 5: ">", 6: "^", 7: "v", 8: "<", 9: ">", 10: "^", 11: "v"} + if isinstance(marker, int) and marker in _adjust.keys(): + return _adjust[marker] + else: + return marker + + @staticmethod + def _get_collection_and_opts(element): + if isinstance(element, tuple): + if len(element) == 1: + return element[0], {} + else: + return element + else: + return element, {} - def _plot(self, generators: Dict): + def _plot(self, generators: List): """ Create the station map plot. @@ -453,7 +520,8 @@ class PlotConditionalQuantiles(AbstractPlotClass): def _plot(self): """Start plotting routines: overall plot and seasonal (if enabled).""" - logging.info(f"start plotting {self.__class__.__name__}, scheduled number of plots: {(len(self._seasons) + 1) * 2}") + logging.info( + f"start plotting {self.__class__.__name__}, scheduled number of plots: {(len(self._seasons) + 1) * 2}") if len(self._seasons) > 0: self._plot_seasons() @@ -504,7 +572,7 @@ class PlotConditionalQuantiles(AbstractPlotClass): # add histogram of the segmented data (pred_name) handles, labels = ax.get_legend_handles_labels() segmented_data.loc[x_model, d, :].to_pandas().hist(bins=self._bins, ax=ax2, color='k', alpha=.3, grid=False, - rwidth=1) + rwidth=1) # add legend plt.legend(handles[:3] + [handles[-1]], self._opts["legend"], loc='upper left', fontsize='large') # adjust limits and set labels @@ -688,25 +756,37 @@ class PlotBootstrapSkillScore(AbstractPlotClass): (score_only=True, default). Y-axis is adjusted following the data and not hard coded. The plot is saved under plot_folder path with name skill_score_clim_{extra_name_tag}{model_setup}.pdf and resolution of 500dpi. + By passing a list `separate_vars` containing variable names, a second plot is created showing the `separate_vars` + and the remaining variables side by side with different scaling. + .. image:: ../../../../../_source/_plots/skill_score_bootstrap.png :width: 400 + .. 
image:: ../../../../../_source/_plots/skill_score_bootstrap_separated.png + :width: 400 + """ - def __init__(self, data: Dict, plot_folder: str = ".", model_setup: str = ""): + def __init__(self, data: Dict, plot_folder: str = ".", model_setup: str = "", separate_vars: List = None): """ Set attributes and create plot. :param data: dictionary with station names as keys and 2D xarrays as values, consist on axis ahead and terms. :param plot_folder: path to save the plot (default: current directory) :param model_setup: architecture type to specify plot name (default "CNN") + :param separate_vars: variables to plot separated (default: ['o3']) """ super().__init__(plot_folder, f"skill_score_bootstrap_{model_setup}") + if separate_vars is None: + separate_vars = ['o3'] self._labels = None self._x_name = "boot_var" self._data = self._prepare_data(data) self._plot() self._save() + self.plot_name += '_separated' + self._plot(separate_vars=separate_vars) + self._save(bbox_inches='tight') def _prepare_data(self, data: Dict) -> pd.DataFrame: """ @@ -719,11 +799,30 @@ class PlotBootstrapSkillScore(AbstractPlotClass): :return: pre-processed data set """ data = helpers.dict_to_xarray(data, "station").sortby(self._x_name) + new_boot_coords = self._return_vars_without_number_tag(data.coords['boot_var'].values, split_by='_', keep=1) + data = data.assign_coords({'boot_var': new_boot_coords}) self._labels = [str(i) + "d" for i in data.coords["ahead"].values] if "station" not in data.dims: data = data.expand_dims("station") return data.to_dataframe("data").reset_index(level=[0, 1, 2]) + def _return_vars_without_number_tag(self, values, split_by, keep): + arr = np.array([v.split(split_by) for v in values]) + num = arr[:, 0] + new_val = arr[:, keep] + if self._all_values_are_equal(num, axis=0): + return new_val + else: + raise NotImplementedError + + + @staticmethod + def _all_values_are_equal(arr, axis=0): + if np.all(arr == arr[0], axis=axis): + return True + else: + return False + def _label_add(self, score_only: bool): """ Add the phrase "terms and " if score_only is disabled or empty string (if score_only=True). @@ -733,12 +832,111 @@ class PlotBootstrapSkillScore(AbstractPlotClass): """ return "" if score_only else "terms and " - def _plot(self): + def _plot(self, separate_vars=None): """Plot climatological skill score.""" + if separate_vars is None: + self._plot_all_variables() + else: + self._plot_selected_variables(separate_vars) + + def _plot_selected_variables(self, separate_vars: List): + # if separate_vars is None: + # separate_vars = ['o3'] + data = self._data + self.raise_error_if_separate_vars_do_not_exist(data, separate_vars) + all_variables = self._get_unique_values_from_column_of_df(data, 'boot_var') + # remaining_vars = helpers.list_pop(all_variables, separate_vars) #remove_items + remaining_vars = helpers.remove_items(all_variables, separate_vars) + data_first = self._select_data(df=data, variables=separate_vars, column_name='boot_var') + data_second = self._select_data(df=data, variables=remaining_vars, column_name='boot_var') + + fig, ax = plt.subplots(nrows=1, ncols=2, + gridspec_kw={'width_ratios': [len(separate_vars), + len(remaining_vars) + ] + } + ) + if len(separate_vars) > 1: + first_box_width = .8 + else: + first_box_width = 2. 
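# Sketch of what _select_data (defined a bit further below) does on a hypothetical
# frame: rows whose boot_var matches one of the requested variables are concatenated
# in the order the variables are given.
import pandas as pd

df = pd.DataFrame({"boot_var": ["o3", "temp", "o3", "relhum"], "data": [0.1, -0.2, 0.3, 0.0]})
selected = pd.concat([df.loc[df["boot_var"] == v] for v in ["o3", "relhum"]], axis=0)
print(selected)  # both o3 rows first, then the relhum row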
+ + sns.boxplot(x=self._x_name, y="data", hue="ahead", data=data_first, ax=ax[0], whis=1., palette="Blues_d", + showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, + flierprops={"marker": "."}, width=first_box_width + ) + ax[0].set(ylabel=f"skill score", xlabel="") + + sns.boxplot(x=self._x_name, y="data", hue="ahead", data=data_second, ax=ax[1], whis=1., palette="Blues_d", + showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, + flierprops={"marker": "."}, + ) + ax[1].set(ylabel="", xlabel="") + ax[1].yaxis.tick_right() + handles, _ = ax[1].get_legend_handles_labels() + for sax in ax: + matplotlib.pyplot.sca(sax) + sax.axhline(y=0, color="grey", linewidth=.5) + plt.xticks(rotation=45, ha='right') + sax.legend_.remove() + + fig.legend(handles, self._labels, loc='upper center', ncol=len(handles) + 1, ) + + def align_yaxis(ax1, ax2): + """ + Align zeros of the two axes, zooming them out by same ratio + + This function is copy pasted from https://stackoverflow.com/a/41259922 + """ + axes = (ax1, ax2) + extrema = [ax.get_ylim() for ax in axes] + tops = [extr[1] / (extr[1] - extr[0]) for extr in extrema] + # Ensure that plots (intervals) are ordered bottom to top: + if tops[0] > tops[1]: + axes, extrema, tops = [list(reversed(l)) for l in (axes, extrema, tops)] + + # How much would the plot overflow if we kept current zoom levels? + tot_span = tops[1] + 1 - tops[0] + + b_new_t = extrema[0][0] + tot_span * (extrema[0][1] - extrema[0][0]) + t_new_b = extrema[1][1] - tot_span * (extrema[1][1] - extrema[1][0]) + axes[0].set_ylim(extrema[0][0], b_new_t) + axes[1].set_ylim(t_new_b, extrema[1][1]) + + align_yaxis(ax[0], ax[1]) + align_yaxis(ax[0], ax[1]) + + @staticmethod + def _select_data(df: pd.DataFrame, variables: List[str], column_name: str) -> pd.DataFrame: + for i, variable in enumerate(variables): + if i == 0: + selected_data = df.loc[df[column_name] == variable] + else: + tmp_var = df.loc[df[column_name] == variable] + selected_data = pd.concat([selected_data, tmp_var], axis=0) + return selected_data + + def raise_error_if_separate_vars_do_not_exist(self, data, separate_vars): + if not self._variables_exist_in_df(df=data, variables=separate_vars): + raise ValueError(f"At least one entry of `separate_vars' does not exist in `self.data' ") + + @staticmethod + def _get_unique_values_from_column_of_df(df: pd.DataFrame, column_name: str) -> List: + return list(df[column_name].unique()) + + def _variables_exist_in_df(self, df: pd.DataFrame, variables: List[str], column_name: str = 'boot_var'): + vars_in_df = set(self._get_unique_values_from_column_of_df(df, column_name)) + return set(variables).issubset(vars_in_df) + + def _plot_all_variables(self): + """ + + """ fig, ax = plt.subplots() sns.boxplot(x=self._x_name, y="data", hue="ahead", data=self._data, ax=ax, whis=1., palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, flierprops={"marker": "."}) ax.axhline(y=0, color="grey", linewidth=.5) + plt.xticks(rotation=45) ax.set(ylabel=f"skill score", xlabel="", title="summary of all stations") handles, _ = ax.get_legend_handles_labels() ax.legend(handles, self._labels) @@ -859,6 +1057,7 @@ class PlotTimeSeries: def _get_time_range(data): def f(x, f_x): return pd.to_datetime(f_x(x.index.values)).year + return f(data, min), f(data, max) @staticmethod @@ -911,8 +1110,10 @@ class PlotAvailability(AbstractPlotClass): # create standard Gantt plot for all stations (currently in single pdf file with single page) super().__init__(plot_folder, 
"data_availability") self.dim = time_dimension - self.linewidth = None self.sampling = self._get_sampling(sampling) + self.linewidth = None + if self.sampling == 'h': + self.linewidth = 0.001 plot_dict = self._prepare_data(generators) lgd = self._plot(plot_dict) self._save(bbox_extra_artists=(lgd,), bbox_inches="tight") @@ -927,13 +1128,6 @@ class PlotAvailability(AbstractPlotClass): lgd = self._plot(plot_dict_summary) self._save(bbox_extra_artists=(lgd,), bbox_inches="tight") - def _get_sampling(self, sampling): - if sampling == "daily": - return "D" - elif sampling == "hourly": - self.linewidth = 0.001 - return "h" - def _prepare_data(self, generators: Dict[str, DataCollection]): plt_dict = {} for subset, data_collection in generators.items(): @@ -978,9 +1172,7 @@ class PlotAvailability(AbstractPlotClass): return plt_dict def _plot(self, plt_dict): - # colors = {"train": "orange", "val": "blueishgreen", "test": "skyblue"} # color names - colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9"} # hex code - # colors = {"train": (230, 159, 0), "val": (0, 158, 115), "test": (86, 180, 233)} # in rgb but as abs values + colors = self.get_dataset_colors() pos = 0 height = 0.8 # should be <= 1 yticklabels = [] @@ -1025,6 +1217,178 @@ class PlotSeparationOfScales(AbstractPlotClass): self._save() +@TimeTrackingWrapper +class PlotAvailabilityHistogram(AbstractPlotClass): + """ + Create data availability plots as histogram. + + Each entry of each generator is checked for `notnull()` values along all the datetime axis (boolean). + Calling this class creates two different types of histograms where each generator + + 1) data_availability_histogram: datetime (xaxis) vs. number of stations with availabile data (yaxis) + 2) data_availability_histogram_cumulative: number of samples (xaxis) vs. number of stations having at least number + of samples (yaxis) + + .. image:: ../../../../../_source/_plots/data_availability_histogram_hist.png + :width: 400 + + .. image:: ../../../../../_source/_plots/data_availability_histogram_hist_cum.png + :width: 400 + + """ + + def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", + subset_dim: str = 'DataSet', history_dim: str = 'window', + station_dim: str = 'Stations',): + + super().__init__(plot_folder, "data_availability_histogram") + + self.subset_dim = subset_dim + self.history_dim = history_dim + self.station_dim = station_dim + + self.freq = None + self.temporal_dim = None + self.target_dim = None + self._prepare_data(generators) + + for plt_type in self.allowed_plot_types: + plot_name_tmp = self.plot_name + self.plot_name += '_' + plt_type + self._plot(plt_type=plt_type) + self._save() + self.plot_name = plot_name_tmp + + def _set_dims_from_datahandler(self, data_handler): + self.temporal_dim = data_handler.id_class.time_dim + self.target_dim = data_handler.id_class.target_dim + self.freq = self._get_sampling(data_handler.id_class.sampling) + + @property + def allowed_plot_types(self): + plot_types = ['hist', 'hist_cum'] + return plot_types + + def _prepare_data(self, generators: Dict[str, DataCollection]): + """ + Prepares data to be used by plot methods. 
+ + Creates xarrays which are sums of valid data (boolean sums) across i) station_dim and ii) temporal_dim + """ + avail_data_time_sum = {} + avail_data_station_sum = {} + dataset_time_interval = {} + for subset, generator in generators.items(): + avail_list = [] + for station in generator: + self._set_dims_from_datahandler(data_handler=station) + station_data_x = station.get_X(as_numpy=False)[0] + station_data_x = station_data_x.loc[{self.history_dim: 0, # select recent window frame + self.target_dim: station_data_x[self.target_dim].values[0]}] + station_data_x = self._reduce_dims(station_data_x) + avail_list.append(station_data_x.notnull()) + avail_data = xr.concat(avail_list, dim=self.station_dim).notnull() + avail_data_time_sum[subset] = avail_data.sum(dim=self.station_dim) + avail_data_station_sum[subset] = avail_data.sum(dim=self.temporal_dim) + dataset_time_interval[subset] = self._get_first_and_last_indexelement_from_xarray( + avail_data_time_sum[subset], dim_name=self.temporal_dim, return_type='as_dict' + ) + avail_data_amount = xr.concat(avail_data_time_sum.values(), pd.Index(avail_data_time_sum.keys(), + name=self.subset_dim) + ) + full_time_index = self._make_full_time_index(avail_data_amount.coords[self.temporal_dim].values, freq=self.freq) + self.avail_data_cum_sum = xr.concat(avail_data_station_sum.values(), pd.Index(avail_data_station_sum.keys(), + name=self.subset_dim)) + self.avail_data_amount = avail_data_amount.reindex({self.temporal_dim: full_time_index}) + self.dataset_time_interval = dataset_time_interval + + def _reduce_dims(self, dataset): + if len(dataset.dims) > 2: + required = {self.temporal_dim, self.station_dim} + unimportant = set(dataset.dims).difference(required) + sel_dict = {un: dataset[un].values[0] for un in unimportant} + dataset = dataset.loc[sel_dict] + return dataset + + @staticmethod + def _get_first_and_last_indexelement_from_xarray(xarray, dim_name, return_type='as_tuple'): + if isinstance(xarray, xr.DataArray): + first = xarray.coords[dim_name].values[0] + last = xarray.coords[dim_name].values[-1] + if return_type == 'as_tuple': + return first, last + elif return_type == 'as_dict': + return {'first': first, 'last': last} + else: + raise TypeError(f"return_type must be 'as_tuple' or 'as_dict', but is '{return_type}'") + else: + raise TypeError(f"xarray must be of type xr.DataArray, but is of type {type(xarray)}") + + @staticmethod + def _make_full_time_index(irregular_time_index, freq): + full_time_index = pd.date_range(start=irregular_time_index[0], end=irregular_time_index[-1], freq=freq) + return full_time_index + + def _plot(self, plt_type='hist', *args): + if plt_type == 'hist': + self._plot_hist() + elif plt_type == 'hist_cum': + self._plot_hist_cum() + else: + raise ValueError(f"plt_type mus be 'hist' or 'hist_cum', but is {type}") + + def _plot_hist(self, *args): + colors = self.get_dataset_colors() + fig, axes = plt.subplots(figsize=(10, 3)) + for i, subset in enumerate(self.dataset_time_interval.keys()): + plot_dataset = self.avail_data_amount.sel({self.subset_dim: subset, + self.temporal_dim: slice( + self.dataset_time_interval[subset]['first'], + self.dataset_time_interval[subset]['last'] + ) + } + ) + + plot_dataset.plot.step(color=colors[subset], ax=axes, label=subset) + plt.fill_between(plot_dataset.coords[self.temporal_dim].values, plot_dataset.values, color=colors[subset]) + + lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval), + facecolor='white', framealpha=1, edgecolor='black') + for lgd_line in 
lgd.get_lines(): + lgd_line.set_linewidth(4.0) + plt.gca().xaxis.set_major_locator(mdates.YearLocator()) + plt.title('') + plt.ylabel('Number of samples') + plt.tight_layout() + + def _plot_hist_cum(self, *args): + colors = self.get_dataset_colors() + fig, axes = plt.subplots(figsize=(10, 3)) + n_bins = int(self.avail_data_cum_sum.max().values) + bins = np.arange(0, n_bins+1) + descending_subsets = self.avail_data_cum_sum.max(dim=self.station_dim).sortby( + self.avail_data_cum_sum.max(dim=self.station_dim), ascending=False + ).coords[self.subset_dim].values + + for subset in descending_subsets: + self.avail_data_cum_sum.sel({self.subset_dim: subset}).plot.hist(ax=axes, + bins=bins, + label=subset, + cumulative=-1, + color=colors[subset], + # alpha=.5 + ) + + lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval), + facecolor='white', framealpha=1, edgecolor='black') + plt.title('') + plt.ylabel('Number of stations') + plt.xlabel('Number of samples') + plt.xlim((bins[0], bins[-1])) + plt.tight_layout() + + + if __name__ == "__main__": stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] path = "../../testrun_network/forecasts" diff --git a/mlair/plotting/tracker_plot.py b/mlair/plotting/tracker_plot.py index 406c32feb1ebda2d32d886051e32778d6c17f5db..53ec7496e7e04da0f53b1d0ce817793dea732963 100644 --- a/mlair/plotting/tracker_plot.py +++ b/mlair/plotting/tracker_plot.py @@ -119,11 +119,13 @@ class TrackChain: control_obj = control[variable][scope] if method == "set": track_objects = self._add_set_object(track_objects, tr, control_obj) - elif method == "get": + elif method == "get": # pragma: no branch track_objects, skip_control_update = self._add_get_object(track_objects, tr, control_obj, control, scope, variable) if skip_control_update is True: continue + else: # pragma: no cover + raise ValueError(f"method must be either set or get but given was {method}.") self._update_control(control, variable, scope, tr) return track_objects, control diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index 9a9253eda522c39f348dd96700ed38730e87f9a8..54d2307718bf083cfbfb8296682c9c545157eb72 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -225,8 +225,8 @@ class ExperimentSetup(RunEnvironment): extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None, number_of_bootstraps=None, create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None, - hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, sampling_inputs=None, - sampling_outputs=None, data_origin: Dict = None, **kwargs): + hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, + data_origin: Dict = None, **kwargs): # create run framework super().__init__() @@ -294,7 +294,6 @@ class ExperimentSetup(RunEnvironment): self._set_param("window_history_size", window_history_size, default=DEFAULT_WINDOW_HISTORY_SIZE) self._set_param("overwrite_local_data", overwrite_local_data, default=DEFAULT_OVERWRITE_LOCAL_DATA, scope="preprocessing") - self._set_param("sampling_inputs", sampling_inputs, default=sampling) self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION) self._set_param("transformation", None, scope="preprocessing") self._set_param("data_handler", data_handler, default=DefaultDataHandler) diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 
3dc91cbd54094f116f0d959fb9c845751e998464..cd8ee266a092a6afb5111ad7b241b38cdbfb6fa7 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -19,8 +19,8 @@ from mlair.helpers import TimeTracking, statistics, extract_value from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel from mlair.model_modules.model_class import AbstractModelClass from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \ - PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotConditionalQuantiles, \ - PlotSeparationOfScales + PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotAvailabilityHistogram, \ + PlotConditionalQuantiles, PlotSeparationOfScales from mlair.run_modules.run_environment import RunEnvironment @@ -239,6 +239,7 @@ class PostProcessing(RunEnvironment): model = keras.models.load_model(model_name, custom_objects=model_class.custom_objects) return model + # noinspection PyBroadException def plot(self): """ Create all plots. @@ -257,42 +258,89 @@ class PostProcessing(RunEnvironment): .. note:: Bootstrap plots are only created if bootstraps are evaluated. """ - logging.debug("Run plotting routines...") + logging.info("Run plotting routines...") path = self.data_store.get("forecast_path") plot_list = self.data_store.get("plot_list", "postprocessing") time_dimension = self.data_store.get("time_dim") - if ("filter" in self.test_data[0].get_X(as_numpy=False)[0].coords) and ("PlotSeparationOfScales" in plot_list): - PlotSeparationOfScales(self.test_data, plot_folder=self.plot_path) - - if (self.bootstrap_skill_scores is not None) and ("PlotBootstrapSkillScore" in plot_list): - PlotBootstrapSkillScore(self.bootstrap_skill_scores, plot_folder=self.plot_path, model_setup="CNN") - - if "PlotConditionalQuantiles" in plot_list: - PlotConditionalQuantiles(self.test_data.keys(), data_pred_path=path, plot_folder=self.plot_path) - if "PlotStationMap" in plot_list: - if self.data_store.get("hostname")[:2] in self.data_store.get("hpc_hosts") or self.data_store.get( - "hostname")[:6] in self.data_store.get("hpc_hosts"): - logging.warning( - f"Skip 'PlotStationMap` because running on a hpc node: {self.data_store.get('hostname')}") - else: - PlotStationMap(generators={'b': self.test_data}, plot_folder=self.plot_path) - if "PlotMonthlySummary" in plot_list: - PlotMonthlySummary(self.test_data.keys(), path, r"forecasts_%s_test.nc", self.target_var, - plot_folder=self.plot_path) - if "PlotClimatologicalSkillScore" in plot_list: - PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, model_setup="CNN") - PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False, - extra_name_tag="all_terms_", model_setup="CNN") - if "PlotCompetitiveSkillScore" in plot_list: - PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, model_setup="CNN") - if "PlotTimeSeries" in plot_list: - PlotTimeSeries(self.test_data.keys(), path, r"forecasts_%s_test.nc", plot_folder=self.plot_path, - sampling=self._sampling) - if "PlotAvailability" in plot_list: - avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data} - PlotAvailability(avail_data, plot_folder=self.plot_path, time_dimension=time_dimension) + try: + if ("filter" in self.test_data[0].get_X(as_numpy=False)[0].coords) and ( + "PlotSeparationOfScales" in plot_list): + 
PlotSeparationOfScales(self.test_data, plot_folder=self.plot_path) + except Exception as e: + logging.error(f"Could not create plot PlotSeparationOfScales due to the following error: {e}") + + try: + if (self.bootstrap_skill_scores is not None) and ("PlotBootstrapSkillScore" in plot_list): + PlotBootstrapSkillScore(self.bootstrap_skill_scores, plot_folder=self.plot_path, model_setup="CNN") + except Exception as e: + logging.error(f"Could not create plot PlotBootstrapSkillScore due to the following error: {e}") + + try: + if "PlotConditionalQuantiles" in plot_list: + PlotConditionalQuantiles(self.test_data.keys(), data_pred_path=path, plot_folder=self.plot_path) + except Exception as e: + logging.error(f"Could not create plot PlotConditionalQuantiles due to the following error: {e}") + + try: + if "PlotStationMap" in plot_list: + if self.data_store.get("hostname")[:2] in self.data_store.get("hpc_hosts") or self.data_store.get( + "hostname")[:6] in self.data_store.get("hpc_hosts"): + logging.warning( + f"Skip 'PlotStationMap` because running on a hpc node: {self.data_store.get('hostname')}") + else: + gens = [(self.train_data, {"marker": 5, "ms": 9}), + (self.val_data, {"marker": 6, "ms": 9}), + (self.test_data, {"marker": 4, "ms": 9})] + PlotStationMap(generators=gens, plot_folder=self.plot_path) + gens = [(self.train_val_data, {"marker": 8, "ms": 9}), + (self.test_data, {"marker": 9, "ms": 9})] + PlotStationMap(generators=gens, plot_folder=self.plot_path, plot_name="station_map_var") + except Exception as e: + logging.error(f"Could not create plot PlotStationMap due to the following error: {e}") + + try: + if "PlotMonthlySummary" in plot_list: + PlotMonthlySummary(self.test_data.keys(), path, r"forecasts_%s_test.nc", self.target_var, + plot_folder=self.plot_path) + except Exception as e: + logging.error(f"Could not create plot PlotMonthlySummary due to the following error: {e}") + + try: + if "PlotClimatologicalSkillScore" in plot_list: + PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, model_setup="CNN") + PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False, + extra_name_tag="all_terms_", model_setup="CNN") + except Exception as e: + logging.error(f"Could not create plot PlotClimatologicalSkillScore due to the following error: {e}") + + try: + if "PlotCompetitiveSkillScore" in plot_list: + PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, model_setup="CNN") + except Exception as e: + logging.error(f"Could not create plot PlotCompetitiveSkillScore due to the following error: {e}") + + try: + if "PlotTimeSeries" in plot_list: + PlotTimeSeries(self.test_data.keys(), path, r"forecasts_%s_test.nc", plot_folder=self.plot_path, + sampling=self._sampling) + except Exception as e: + logging.error(f"Could not create plot PlotTimeSeries due to the following error: {e}") + + try: + if "PlotAvailability" in plot_list: + avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data} + PlotAvailability(avail_data, plot_folder=self.plot_path, time_dimension=time_dimension) + except Exception as e: + logging.error(f"Could not create plot PlotAvailability due to the following error: {e}") + + try: + if "PlotAvailabilityHistogram" in plot_list: + avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data} + PlotAvailabilityHistogram(avail_data, plot_folder=self.plot_path, ) # time_dimension=time_dimension) + except Exception as e: + logging.error(f"Could not 
create plot PlotAvailabilityHistogram due to the following error: {e}") def calculate_test_score(self): """Evaluate test score of model and save locally.""" diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index 4cee4a9744f33c86e8802aad27125cf0e0b30f3a..21aebd62bab490363797c0ef0624daa1d488097b 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -6,6 +6,8 @@ __date__ = '2019-11-25' import logging import os from typing import Tuple +import multiprocessing +import requests import numpy as np import pandas as pd @@ -113,9 +115,48 @@ class PreProcessing(RunEnvironment): precision = 4 path = os.path.join(self.data_store.get("experiment_path"), "latex_report") path_config.check_path_and_create(path) - set_names = ["train", "val", "test"] - df = pd.DataFrame(columns=meta_data + set_names) - for set_name in set_names: + names_of_set = ["train", "val", "test"] + df = self.create_info_df(meta_data, meta_round, names_of_set, precision) + column_format = self.create_column_format_for_tex(df) + self.save_to_tex(path=path, filename="station_sample_size.tex", column_format=column_format, df=df) + self.save_to_md(path=path, filename="station_sample_size.md", df=df) + df_nometa = df.drop(meta_data, axis=1) + column_format = self.create_column_format_for_tex(df) + self.save_to_tex(path=path, filename="station_sample_size_short.tex", column_format=column_format, df=df_nometa) + self.save_to_md(path=path, filename="station_sample_size_short.md", df=df_nometa) + # df_nometa.to_latex(os.path.join(path, "station_sample_size_short.tex"), na_rep='---', + # column_format=column_format) + df_descr = self.create_describe_df(df_nometa) + column_format = self.create_column_format_for_tex(df_descr) + self.save_to_tex(path=path, filename="station_describe_short.tex", column_format=column_format, df=df_descr) + self.save_to_md(path=path, filename="station_describe_short.md", df=df_descr) + # df_descr.to_latex(os.path.join(path, "station_describe_short.tex"), na_rep='---', column_format=column_format) + + @staticmethod + def create_describe_df(df, percentiles=None, ignore_last_lines: int = 2): + if percentiles is None: + percentiles = [.05, .1, .25, .5, .75, .9, .95] + df_descr = df.iloc[:-ignore_last_lines].astype('float32').describe( + percentiles=percentiles).astype("int32", errors="ignore") + df_descr = pd.concat([df.loc[['# Samples']], df_descr]).T + df_descr.rename(columns={"# Samples": "no. samples", "count": "no. 
stations"}, inplace=True) + df_descr_colnames = list(df_descr.columns) + df_descr_colnames = [df_descr_colnames[1]] + [df_descr_colnames[0]] + df_descr_colnames[2:] + df_descr = df_descr[df_descr_colnames] + return df_descr + + @staticmethod + def save_to_tex(path, filename, column_format, df, na_rep='---'): + df.to_latex(os.path.join(path, filename), na_rep=na_rep, column_format=column_format) + + @staticmethod + def save_to_md(path, filename, df, mode="w", encoding='utf-8', tablefmt="github"): + df.to_markdown(open(os.path.join(path, filename), mode=mode, encoding=encoding), + tablefmt=tablefmt) + + def create_info_df(self, meta_data, meta_round, names_of_set, precision): + df = pd.DataFrame(columns=meta_data + names_of_set) + for set_name in names_of_set: data = self.data_store.get("data_collection", set_name) for station in data: station_name = str(station.id_class) @@ -123,20 +164,27 @@ class PreProcessing(RunEnvironment): if df.loc[station_name, meta_data].isnull().any(): df.loc[station_name, meta_data] = station.id_class.meta.loc[meta_data].values.flatten() df.loc["# Samples", set_name] = df.loc[:, set_name].sum() - df.loc["# Stations", set_name] = df.loc[:, set_name].count() + assert len(data) == df.loc[:, set_name].count() - 1 + df.loc["# Stations", set_name] = len(data) df[meta_round] = df[meta_round].astype(float).round(precision) df.sort_index(inplace=True) df = df.reindex(df.index.drop(["# Stations", "# Samples"]).to_list() + ["# Stations", "# Samples"], ) df.index.name = 'stat. ID' + return df + + @staticmethod + def create_column_format_for_tex(df: pd.DataFrame) -> str: + """ + Creates column format for latex table based on the shape of a given DataFrame. + + Calculates number of columns and uses 'c' as column position. First element is set to 'l', last to 'r' + """ column_format = np.repeat('c', df.shape[1] + 1) column_format[0] = 'l' column_format[-1] = 'r' column_format = ''.join(column_format.tolist()) - df.to_latex(os.path.join(path, "station_sample_size.tex"), na_rep='---', column_format=column_format) - df.to_markdown(open(os.path.join(path, "station_sample_size.md"), mode="w", encoding='utf-8'), - tablefmt="github") - df.drop(meta_data, axis=1).to_latex(os.path.join(path, "station_sample_size_short.tex"), na_rep='---', - column_format=column_format) + return column_format + def split_train_val_test(self) -> None: """ @@ -201,6 +249,51 @@ class PreProcessing(RunEnvironment): Valid means, that there is data available for the given time range (is included in `kwargs`). The shape and the loading time are logged in debug mode. + :return: Corrected list containing only valid station IDs. 
+ """ + t_outer = TimeTracking() + logging.info(f"check valid stations started{' (%s)' % (set_name if set_name is not None else 'all')}") + # calculate transformation using train data + if set_name == "train": + logging.info("setup transformation using train data exclusively") + self.transformation(data_handler, set_stations) + # start station check + collection = DataCollection(name=set_name) + valid_stations = [] + kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope=set_name) + + if multiprocessing.cpu_count() > 1: # parallel solution + logging.info("use parallel validate station approach") + pool = multiprocessing.Pool() + logging.info(f"running {getattr(pool, '_processes')} processes in parallel") + output = [ + pool.apply_async(f_proc, args=(data_handler, station, set_name, store_processed_data), kwds=kwargs) + for station in set_stations] + for p in output: + dh, s = p.get() + if dh is not None: + collection.add(dh) + valid_stations.append(s) + else: # serial solution + logging.info("use serial validate station approach") + for station in set_stations: + dh, s = f_proc(data_handler, station, set_name, store_processed_data, **kwargs) + if dh is not None: + collection.add(dh) + valid_stations.append(s) + + logging.info(f"run for {t_outer} to check {len(set_stations)} station(s). Found {len(collection)}/" + f"{len(set_stations)} valid stations.") + return collection, valid_stations + + def validate_station_old(self, data_handler: AbstractDataHandler, set_stations, set_name=None, + store_processed_data=True): + """ + Check if all given stations in `all_stations` are valid. + + Valid means, that there is data available for the given time range (is included in `kwargs`). The shape and the + loading time are logged in debug mode. + :return: Corrected list containing only valid station IDs. """ t_outer = TimeTracking() @@ -231,3 +324,18 @@ class PreProcessing(RunEnvironment): transformation_dict = data_handler.transformation(stations, **kwargs) if transformation_dict is not None: self.data_store.set("transformation", transformation_dict) + + +def f_proc(data_handler, station, name_affix, store, **kwargs): + """ + Try to create a data handler for given arguments. If build fails, this station does not fulfil all requirements and + therefore f_proc will return None as indication. On a successfull build, f_proc returns the built data handler and + the station that was used. This function must be implemented globally to work together with multiprocessing. 
+ """ + try: + res = data_handler.build(station, name_affix=name_affix, store_processed_data=store, + **kwargs) + except (AttributeError, EmptyQueryResult, KeyError, requests.ConnectionError, ValueError) as e: + logging.info(f"remove station {station} because it raised an error: {e}") + res = None + return res, station diff --git a/mlair/workflows/abstract_workflow.py b/mlair/workflows/abstract_workflow.py index 3a627d9f72a5c1c97c35b464af1b0944bc397ea5..c969aa35ebca60aa749a294bcaa5de727407a461 100644 --- a/mlair/workflows/abstract_workflow.py +++ b/mlair/workflows/abstract_workflow.py @@ -3,8 +3,6 @@ __author__ = "Lukas Leufen" __date__ = '2020-06-26' -from collections import OrderedDict - from mlair import RunEnvironment diff --git a/mlair/workflows/default_workflow.py b/mlair/workflows/default_workflow.py index 4d113190fdc90ec852d7db2b33459b9162867a24..5894555a6af52299efcd8d88d76c0d3791a1599e 100644 --- a/mlair/workflows/default_workflow.py +++ b/mlair/workflows/default_workflow.py @@ -54,41 +54,10 @@ class DefaultWorkflow(Workflow): self.add(PostProcessing) -class DefaultWorkflowHPC(Workflow): +class DefaultWorkflowHPC(DefaultWorkflow): """A default workflow for Jülich HPC systems executing ExperimentSetup, PreProcessing, PartitionCheck, ModelSetup, Training and PostProcessing in exact the mentioned ordering.""" - def __init__(self, stations=None, - train_model=None, create_new_model=None, - window_history_size=None, - experiment_date="testrun", - variables=None, statistics_per_var=None, - start=None, end=None, - target_var=None, target_dim=None, - window_lead_time=None, - dimensions=None, - interpolation_method=None, time_dim=None, limit_nan_fill=None, - train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, test_end=None, - use_all_stations_on_all_data_sets=None, fraction_of_train=None, - experiment_path=None, plot_path=None, forecast_path=None, bootstrap_path=None, - overwrite_local_data=None, - sampling=None, - permute_data_on_training=None, extreme_values=None, extremes_on_right_tail_only=None, - transformation=None, - train_min_length=None, val_min_length=None, test_min_length=None, - evaluate_bootstraps=None, number_of_bootstraps=None, create_new_bootstraps=None, - plot_list=None, - model=None, - batch_size=None, - epochs=None, - data_handler=None, **kwargs): - super().__init__() - - # extract all given kwargs arguments - params = remove_items(inspect.getfullargspec(self.__init__).args, "self") - kwargs_default = {k: v for k, v in locals().items() if k in params and v is not None} - self._setup(**kwargs_default, **kwargs) - def _setup(self, **kwargs): """Set up default workflow.""" self.add(ExperimentSetup, **kwargs) diff --git a/run_hourly.py b/run_hourly.py index b831cf1e1ee733a3c652c6cea364013b44cf2c0d..a21c779bc007c7fbe67c98584687be3954e1d62c 100644 --- a/run_hourly.py +++ b/run_hourly.py @@ -6,6 +6,17 @@ import argparse from mlair.workflows import DefaultWorkflow +def load_stations(): + import json + try: + filename = 'supplement/station_list_north_german_plain.json' + with open(filename, 'r') as jfile: + stations = json.load(jfile) + except FileNotFoundError: + stations = None + return stations + + def main(parser_args): workflow = DefaultWorkflow(sampling="hourly", window_history_size=48, **parser_args.__dict__) diff --git a/run_mixed_sampling.py b/run_mixed_sampling.py index a87e9f38e9379d10f3472009934b61acb2d147ff..04683a17ede641a5370aaeef741d2f4546f966b7 100644 --- a/run_mixed_sampling.py +++ b/run_mixed_sampling.py @@ -7,22 +7,34 @@ from 
mlair.workflows import DefaultWorkflow from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, DataHandlerMixedSamplingWithFilter, \ DataHandlerSeparationOfScales +stats = {'o3': 'dma8eu', 'no': 'dma8eu', 'no2': 'dma8eu', + 'relhum': 'average_values', 'u': 'average_values', 'v': 'average_values', + 'cloudcover': 'average_values', 'pblheight': 'maximum', + 'temp': 'maximum'} +data_origin = {'o3': '', 'no': '', 'no2': '', + 'relhum': 'REA', 'u': 'REA', 'v': 'REA', + 'cloudcover': 'REA', 'pblheight': 'REA', + 'temp': 'REA'} + def main(parser_args): - args = dict(sampling="daily", - sampling_inputs="hourly", - window_history_size=24, - **parser_args.__dict__, - data_handler=DataHandlerSeparationOfScales, + args = dict(stations=["DEBW107", "DEBW013"], + network="UBA", + evaluate_bootstraps=False, plot_list=[], + data_origin=data_origin, data_handler=DataHandlerMixedSampling, + interpolation_limit=(3, 1), overwrite_local_data=False, + sampling=("hourly", "daily"), + statistics_per_var=stats, + create_new_model=False, train_model=False, epochs=1, + window_history_size=48, + window_history_offset=17, kz_filter_length=[100 * 24, 15 * 24], kz_filter_iter=[4, 5], start="2006-01-01", train_start="2006-01-01", end="2011-12-31", test_end="2011-12-31", - stations=["DEBW107", "DEBW013"], - epochs=100, - network="UBA", + **parser_args.__dict__, ) workflow = DefaultWorkflow(**args) workflow.run() diff --git a/supplement/station_list_north_german_plain.json b/supplement/station_list_north_german_plain.json new file mode 100644 index 0000000000000000000000000000000000000000..5e92dee5facdd26f0ac044a3c8cbfeac4256bf56 --- /dev/null +++ b/supplement/station_list_north_german_plain.json @@ -0,0 +1,81 @@ +[ +"DENI031", +"DESH016", +"DEBB050", +"DEHH022", +"DEHH049", +"DEHH021", +"DEMV007", +"DESH015", +"DEBE062", +"DEHH012", +"DESH004", +"DENI062", +"DEBE051", +"DEHH011", +"DEHH023", +"DEUB020", +"DESH005", +"DEBB039", +"DEHH050", +"DENI029", +"DESH001", +"DEBE001", +"DEHH030", +"DEHH018", +"DEUB022", +"DEBB038", +"DEBB053", +"DEMV017", +"DENI063", +"DENI058", +"DESH014", +"DEUB007", +"DEUB005", +"DEBB051", +"DEUB034", +"DEST089", +"DEHH005", +"DESH003", +"DEUB028", +"DESH017", +"DEUB030", +"DEMV012", +"DENI052", +"DENI059", +"DENI060", +"DESH013", +"DEUB006", +"DEMV018", +"DEUB027", +"DEUB026", +"DEUB038", +"DEMV001", +"DEUB024", +"DEUB037", +"DESH008", +"DEMV004", +"DEUB040", +"DEMV024", +"DEMV026", +"DESH056", +"DEHH063", +"DEUB001", +"DEST069", +"DEBB040", +"DEBB028", +"DEBB048", +"DEBB063", +"DEBB067", +"DESH006", +"DEBE008", +"DESH012", +"DEHH004", +"DEBE009", +"DEHH007", +"DEBE005", +"DEHH057", +"DEHH047", +"DEBE006", +"DEBB110" +] diff --git a/test/test_configuration/test_defaults.py b/test/test_configuration/test_defaults.py index fffe7c84075eeeab37ebf59d52bc42dbf87bf522..ae81ef2ef0a15ad08f14ad19312f04040ab71263 100644 --- a/test/test_configuration/test_defaults.py +++ b/test/test_configuration/test_defaults.py @@ -70,4 +70,5 @@ class TestAllDefaults: assert DEFAULT_NUMBER_OF_BOOTSTRAPS == 20 assert DEFAULT_PLOT_LIST == ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries", "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", - "PlotConditionalQuantiles", "PlotAvailability", "PlotSeparationOfScales"] + "PlotConditionalQuantiles", "PlotAvailability", "PlotAvailabilityHistogram", + "PlotSeparationOfScales"] diff --git a/test/test_configuration/test_join_settings.py b/test/test_configuration/test_join_settings.py new file mode 100644 index 
0000000000000000000000000000000000000000..8d977f450b9fca0bc691d13f63965c71f7228cb1 --- /dev/null +++ b/test/test_configuration/test_join_settings.py @@ -0,0 +1,25 @@ +from mlair.configuration.join_settings import join_settings + +import pytest + + +class TestJoinSettings: + + def test_no_args(self): + url, headers = join_settings() + assert url == 'https://join.fz-juelich.de/services/rest/surfacedata/' + assert headers == {} + + def test_daily(self): + url, headers = join_settings("daily") + assert url == 'https://join.fz-juelich.de/services/rest/surfacedata/' + assert headers == {} + + def test_hourly(self): + url, headers = join_settings("hourly") + assert "Authorization" in headers.keys() + + def test_unknown_sampling(self): + with pytest.raises(NameError) as e: + join_settings("monthly") + assert "Given sampling monthly is not supported, choose from either daily or hourly sampling" in e.value.args[0] diff --git a/test/test_configuration/test_path_config.py b/test/test_configuration/test_path_config.py index 2ba80a3bdf62b7fdf10b645da75769435cf7b6b9..fb8a2b1950cd07909543fbe564230ab73661c126 100644 --- a/test/test_configuration/test_path_config.py +++ b/test/test_configuration/test_path_config.py @@ -16,9 +16,9 @@ class TestPrepareHost: @mock.patch("getpass.getuser", return_value="testUser") @mock.patch("os.path.exists", return_value=True) def test_prepare_host(self, mock_host, mock_user, mock_path): - assert prepare_host() == "/home/testUser/Data/toar_daily/" - assert prepare_host() == "/home/testUser/Data/toar_daily/" - assert prepare_host() == "/p/project/cjjsc42/testUser/DATA/toar_daily/" + assert prepare_host() == "/home/testUser/Data/toar/" + assert prepare_host() == "/home/testUser/Data/toar/" + assert prepare_host() == "/p/project/cjjsc42/testUser/DATA/toar/" assert prepare_host() == "/p/project/deepacf/intelliaq/testUser/DATA/MLAIR/" assert prepare_host() == '/home/testUser/mlair/data/' @@ -27,6 +27,10 @@ class TestPrepareHost: def test_prepare_host_unknown(self, mock_user, mock_host): assert prepare_host() == os.path.join(os.path.abspath(os.getcwd()), 'data') + def test_prepare_host_given_path(self): + path = os.path.join(os.path.abspath(os.getcwd()), 'data') + assert prepare_host(data_path=path) == path + @mock.patch("getpass.getuser", return_value="zombie21") @mock.patch("mlair.configuration.path_config.check_path_and_create", side_effect=PermissionError) @mock.patch("os.path.exists", return_value=False) @@ -47,7 +51,7 @@ class TestPrepareHost: @mock.patch("os.makedirs", side_effect=None) def test_os_path_exists(self, mock_host, mock_user, mock_path, mock_check): path = prepare_host() - assert path == "/home/testUser/Data/toar_daily/" + assert path == "/home/testUser/Data/toar/" class TestSetExperimentName: diff --git a/test/test_data_handler/test_data_handler_mixed_sampling.py b/test/test_data_handler/test_data_handler_mixed_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..d2f9ce00224a61815c89e44b7c37a667d239b2f5 --- /dev/null +++ b/test/test_data_handler/test_data_handler_mixed_sampling.py @@ -0,0 +1,130 @@ +__author__ = 'Lukas Leufen' +__date__ = '2020-12-10' + +from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, \ + DataHandlerMixedSamplingSingleStation, DataHandlerMixedSamplingWithFilter, \ + DataHandlerMixedSamplingWithFilterSingleStation, DataHandlerSeparationOfScales, \ + DataHandlerSeparationOfScalesSingleStation +from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilterSingleStation 
+from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation +from mlair.helpers import remove_items +from mlair.configuration.defaults import DEFAULT_INTERPOLATION_METHOD + +import pytest +import mock + + +class TestDataHandlerMixedSampling: + + def test_data_handler(self): + obj = object.__new__(DataHandlerMixedSampling) + assert obj.data_handler.__qualname__ == DataHandlerMixedSamplingSingleStation.__qualname__ + + def test_data_handler_transformation(self): + obj = object.__new__(DataHandlerMixedSampling) + assert obj.data_handler_transformation.__qualname__ == DataHandlerMixedSamplingSingleStation.__qualname__ + + def test_requirements(self): + obj = object.__new__(DataHandlerMixedSampling) + req = object.__new__(DataHandlerSingleStation) + assert sorted(obj._requirements) == sorted(remove_items(req.requirements(), "station")) + + +class TestDataHandlerMixedSamplingSingleStation: + + def test_requirements(self): + obj = object.__new__(DataHandlerMixedSamplingSingleStation) + req = object.__new__(DataHandlerSingleStation) + assert sorted(obj._requirements) == sorted(remove_items(req.requirements(), "station")) + + @mock.patch("mlair.data_handler.data_handler_mixed_sampling.DataHandlerMixedSamplingSingleStation.setup_samples") + def test_init(self, mock_super_init): + obj = DataHandlerMixedSamplingSingleStation("first_arg", "second", {}, test=23, sampling="hourly", + interpolation_limit=(1, 10)) + assert obj.sampling == ("hourly", "hourly") + assert obj.interpolation_limit == (1, 10) + assert obj.interpolation_method == (DEFAULT_INTERPOLATION_METHOD, DEFAULT_INTERPOLATION_METHOD) + + @pytest.fixture + def kwargs_dict(self): + return {"test1": 2, "param_2": "string", "another": (10, 2)} + + def test_update_kwargs_single_to_tuple(self, kwargs_dict): + obj = object.__new__(DataHandlerMixedSamplingSingleStation) + obj.update_kwargs("test1", "23", kwargs_dict) + assert kwargs_dict["test1"] == (2, 2) + obj.update_kwargs("param_2", "23", kwargs_dict) + assert kwargs_dict["param_2"] == ("string", "string") + + def test_update_kwargs_tuple(self, kwargs_dict): + obj = object.__new__(DataHandlerMixedSamplingSingleStation) + obj.update_kwargs("another", "23", kwargs_dict) + assert kwargs_dict["another"] == (10, 2) + + def test_update_kwargs_default(self, kwargs_dict): + obj = object.__new__(DataHandlerMixedSamplingSingleStation) + obj.update_kwargs("not_existing", "23", kwargs_dict) + assert kwargs_dict["not_existing"] == ("23", "23") + obj.update_kwargs("also_new", (4, 2), kwargs_dict) + assert kwargs_dict["also_new"] == (4, 2) + + def test_update_kwargs_assert_failure(self, kwargs_dict): + obj = object.__new__(DataHandlerMixedSamplingSingleStation) + with pytest.raises(AssertionError): + obj.update_kwargs("error_too_long", (1, 2, 3), kwargs_dict) + + def test_setup_samples(self): + pass + + def test_load_and_interpolate(self): + pass + + def test_set_inputs_and_targets(self): + pass + + def test_setup_data_path(self): + pass + + +class TestDataHandlerMixedSamplingWithFilter: + + def test_data_handler(self): + obj = object.__new__(DataHandlerMixedSamplingWithFilter) + assert obj.data_handler.__qualname__ == DataHandlerMixedSamplingWithFilterSingleStation.__qualname__ + + def test_data_handler_transformation(self): + obj = object.__new__(DataHandlerMixedSamplingWithFilter) + assert obj.data_handler_transformation.__qualname__ == DataHandlerMixedSamplingWithFilterSingleStation.__qualname__ + + def test_requirements(self): + obj = 
object.__new__(DataHandlerMixedSamplingWithFilter) + req1 = object.__new__(DataHandlerMixedSamplingSingleStation) + req2 = object.__new__(DataHandlerKzFilterSingleStation) + req = list(set(req1.requirements() + req2.requirements())) + assert sorted(obj._requirements) == sorted(remove_items(req, "station")) + + +class TestDataHandlerMixedSamplingWithFilterSingleStation: + pass + + +class TestDataHandlerSeparationOfScales: + + def test_data_handler(self): + obj = object.__new__(DataHandlerSeparationOfScales) + assert obj.data_handler.__qualname__ == DataHandlerSeparationOfScalesSingleStation.__qualname__ + + def test_data_handler_transformation(self): + obj = object.__new__(DataHandlerSeparationOfScales) + assert obj.data_handler_transformation.__qualname__ == DataHandlerSeparationOfScalesSingleStation.__qualname__ + + def test_requirements(self): + obj = object.__new__(DataHandlerMixedSamplingWithFilter) + req1 = object.__new__(DataHandlerMixedSamplingSingleStation) + req2 = object.__new__(DataHandlerKzFilterSingleStation) + req = list(set(req1.requirements() + req2.requirements())) + assert sorted(obj._requirements) == sorted(remove_items(req, "station")) + + +class TestDataHandlerSeparationOfScalesSingleStation: + pass diff --git a/test/test_data_handler/test_iterator.py b/test/test_data_handler/test_iterator.py index ec224c06e358297972097f2cc75cea86f768784f..ade5c19215e61de5e209db900920187294ac9b18 100644 --- a/test/test_data_handler/test_iterator.py +++ b/test/test_data_handler/test_iterator.py @@ -52,20 +52,57 @@ class TestDataCollection: for e, i in enumerate(data_collection): assert i == e + def test_add(self): + data_collection = DataCollection() + data_collection.add("first_element") + assert len(data_collection) == 1 + assert data_collection["first_element"] == "first_element" + assert data_collection[0] == "first_element" + + def test_name(self): + data_collection = DataCollection(name="testcase") + assert data_collection._name == "testcase" + assert data_collection.name == "testcase" + + def test_set_mapping(self): + data_collection = object.__new__(DataCollection) + data_collection._collection = ["a", "b", "c"] + data_collection._mapping = {} + data_collection._set_mapping() + assert data_collection._mapping == {"a": 0, "b": 1, "c": 2} + + def test_getitem(self): + data_collection = DataCollection(["a", "b", "c"]) + assert data_collection["a"] == "a" + assert data_collection[1] == "b" + + def test_keys(self): + collection = ["a", "b", "c"] + data_collection = DataCollection(collection) + assert data_collection.keys() == collection + data_collection.add("another") + assert data_collection.keys() == collection + ["another"] + class DummyData: def __init__(self, number_of_samples=np.random.randint(100, 150)): + np.random.seed(45) self.number_of_samples = number_of_samples def get_X(self, upsampling=False, as_numpy=True): + np.random.seed(45) X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5)) # samples, window, variables + np.random.seed(45) X2 = np.random.randint(21, 30, size=(self.number_of_samples, 10, 2)) # samples, window, variables + np.random.seed(45) X3 = np.random.randint(-5, 0, size=(self.number_of_samples, 1, 2)) # samples, window, variables return [X1, X2, X3] def get_Y(self, upsampling=False, as_numpy=True): + np.random.seed(45) Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5, 1)) # samples, window, variables + np.random.seed(45) Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 5, 1)) # samples, window, variables return [Y1, 
Y2] @@ -80,6 +117,14 @@ class TestKerasIterator: data_coll = DataCollection(collection=coll) return data_coll + @pytest.fixture + def collection_small(self): + coll = [] + for i in range(3): + coll.append(DummyData(5 + i)) + data_coll = DataCollection(collection=coll) + return data_coll + @pytest.fixture def path(self): p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata") @@ -161,6 +206,27 @@ class TestKerasIterator: assert len(iterator) == 4 assert iterator.indexes == [0, 1, 2, 3] + def test_prepare_batches_upsampling(self, collection_small, path): + iterator = object.__new__(KerasIterator) + iterator._collection = collection_small + iterator.batch_size = 100 + iterator.indexes = [] + iterator.model = None + iterator.upsampling = False + iterator._path = os.path.join(path, "%i.pickle") + os.makedirs(path) + iterator._prepare_batches() + X1, Y1 = iterator[0] + iterator.upsampling = True + iterator._prepare_batches() + X1p, Y1p = iterator[0] + assert X1[0].shape == X1p[0].shape + assert Y1[0].shape == Y1p[0].shape + assert np.testing.assert_almost_equal(X1[0].sum(), X1p[0].sum(), 2) is None + assert np.testing.assert_almost_equal(Y1[0].sum(), Y1p[0].sum(), 2) is None + f = np.testing.assert_array_almost_equal + assert np.testing.assert_raises(AssertionError, f, X1[0], X1p[0]) is None + def test_prepare_batches_no_remaining(self, path): iterator = object.__new__(KerasIterator) iterator._collection = DataCollection([DummyData(50)]) @@ -189,10 +255,6 @@ class TestKerasIterator: expected = next(iter(collection)) assert PyTestAllEqual([X, expected.get_X()]) assert PyTestAllEqual([Y, expected.get_Y()]) - reversed(iterator.indexes) - X, Y = iterator[3] - assert PyTestAllEqual([X, expected.get_X()]) - assert PyTestAllEqual([Y, expected.get_Y()]) def test_on_epoch_end(self): iterator = object.__new__(KerasIterator) @@ -226,3 +288,15 @@ class TestKerasIterator: iterator.model = mock.MagicMock(return_value=1) with pytest.raises(TypeError): iterator._get_model_rank() + + def test_permute(self): + iterator = object.__new__(KerasIterator) + X = [np.array([[1, 2, 3, 4], + [1.1, 2.1, 3.1, 4.1], + [1.2, 2.2, 3.2, 4.2]], dtype="f2")] + Y = [np.array([1, 2, 3])] + X_p, Y_p = iterator._permute_data(X, Y) + assert X_p[0].shape == X[0].shape + assert Y_p[0].shape == Y[0].shape + assert np.testing.assert_almost_equal(X_p[0].sum(), X[0].sum(), 2) is None + assert np.testing.assert_almost_equal(Y_p[0].sum(), Y[0].sum(), 2) is None diff --git a/test/test_datastore.py b/test/test_helpers/test_datastore.py similarity index 87% rename from test/test_datastore.py rename to test/test_helpers/test_datastore.py index 662c90bf04e11b8b4ff9647506c1981c8883f30b..1eecc576e60e5dc43b97a6e8254f8a2fea29728a 100644 --- a/test/test_datastore.py +++ b/test/test_helpers/test_datastore.py @@ -3,7 +3,8 @@ __date__ = '2019-11-22' import pytest -from mlair.helpers.datastore import AbstractDataStore, DataStoreByVariable, DataStoreByScope, CorrectScope +from mlair.helpers.datastore import AbstractDataStore, DataStoreByVariable, DataStoreByScope +from mlair.helpers.datastore import CorrectScope, TrackParameter from mlair.helpers.datastore import NameNotFoundInDataStore, NameNotFoundInScope, EmptyScope @@ -339,3 +340,52 @@ class TestCorrectScope: assert self.function1(21) == (21, "general", 44) assert self.function1(55, "sub", 34) == (55, "general.sub", 34) assert self.function1("string", b=99, scope="tester") == ("string", "general.tester", 99) + + +class TestTracking: + class Tracker: + def __init__(self): + self.tracker 
= [{}] + + @TrackParameter + def function2(self, arg1, arg2, arg3): + return + + @staticmethod + def function1(): + return + + def test_init(self): + track = self.Tracker() + track.function2(1, "2", "scopy") + assert track.tracker == [{1: [{"method": "function2", "scope": "scopy"}]}] + + def test_track_first_entry(self): + track = object.__new__(TrackParameter) + track.__wrapped__ = self.function1 + tracker_obj = self.Tracker() + assert len(tracker_obj.tracker[-1].keys()) == 0 + track.track(tracker_obj, "eins", 2) + assert len(tracker_obj.tracker[-1].keys()) == 1 + assert tracker_obj.tracker == [{"eins": [{"method": "function1", "scope": 2}]}] + track.track(tracker_obj, "zwei", 20) + assert len(tracker_obj.tracker[-1].keys()) == 2 + assert tracker_obj.tracker == [{"eins": [{"method": "function1", "scope": 2}], + "zwei": [{"method": "function1", "scope": 20}]}] + + def test_track_second_entry(self): + track = object.__new__(TrackParameter) + track.__wrapped__ = self.function1 + tracker_obj = self.Tracker() + assert len(tracker_obj.tracker[-1].keys()) == 0 + track.track(tracker_obj, "eins", 2) + track.track(tracker_obj, "eins", 23) + assert len(tracker_obj.tracker[-1].keys()) == 1 + assert tracker_obj.tracker == [{"eins": [{"method": "function1", "scope": 2}, + {"method": "function1", "scope": 23}]}] + + def test_decrypt_args(self): + track = object.__new__(TrackParameter) + assert track._decrypt_args(23) == (23,) + assert track._decrypt_args("test", 33, 4) == ("test", 33, 4) + assert track._decrypt_args("eins", 2) == ("eins", None, 2) diff --git a/test/test_helpers/test_helpers.py b/test/test_helpers/test_helpers.py index 723b4a87d70453327ed6b7e355d3ef78a246652a..a5aaa707c83a65c3e10f76fdbfcd8142d3615e2e 100644 --- a/test/test_helpers/test_helpers.py +++ b/test/test_helpers/test_helpers.py @@ -10,7 +10,7 @@ import os import mock import pytest -from mlair.helpers import to_list, dict_to_xarray, float_round, remove_items +from mlair.helpers import to_list, dict_to_xarray, float_round, remove_items, extract_value, select_from_dict from mlair.helpers import PyTestRegex from mlair.helpers import Logger, TimeTracking @@ -22,6 +22,10 @@ class TestToList: assert to_list('abcd') == ['abcd'] assert to_list([1, 2, 3]) == [1, 2, 3] assert to_list([45]) == [45] + s = {34, 2, "test"} + assert to_list(s) == list(s) + assert to_list((34, 2, "test")) == [34, 2, "test"] + assert to_list(("test")) == ["test"] class TestTimeTracking: @@ -164,6 +168,22 @@ class TestFloatRound: assert float_round(-34.9221, 0) == -34. 
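The helper behaviour pinned down by the surrounding assertions can be condensed into a small doctest-style sketch; the values follow the tests in this diff, only the grouping is new:

```python
from mlair.helpers import to_list, float_round

assert to_list((34, 2, "test")) == [34, 2, "test"]   # tuples are now unpacked into a list
assert sorted(to_list({2, 34})) == [2, 34]           # sets too (their ordering is not guaranteed)
assert to_list("abcd") == ["abcd"]                   # plain strings are still wrapped as a single item
assert float_round(-34.9221, 0) == -34.0             # default rounding, as asserted in the test above
```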
+class TestSelectFromDict: + + @pytest.fixture + def dictionary(self): + return {"a": 1, "b": 23, "c": "last"} + + def test_select(self, dictionary): + assert select_from_dict(dictionary, "c") == {"c": "last"} + assert select_from_dict(dictionary, ["a", "c"]) == {"a": 1, "c": "last"} + assert select_from_dict(dictionary, "d") == {} + + def test_select_no_dict_given(self): + with pytest.raises(AssertionError): + select_from_dict(["we"], "now") + + class TestRemoveItems: @pytest.fixture @@ -229,6 +249,11 @@ class TestRemoveItems: remove_items(custom_list) assert "remove_items() missing 1 required positional argument: 'items'" in e.value.args[0] + def test_remove_not_supported_type(self): + with pytest.raises(TypeError) as e: + remove_items(23, "test") + assert f"remove_items does not support type {type(23)}" in e.value.args[0] + class TestLogger: @@ -272,3 +297,18 @@ class TestLogger: with pytest.raises(TypeError) as e: logger.logger_console(1.5) assert "Level not an integer or a valid string: 1.5" == e.value.args[0] + + +class TestExtractValue: + + def test_extract(self): + assert extract_value([1]) == 1 + assert extract_value([[23]]) == 23 + assert extract_value([("test")]) == "test" + assert extract_value((2,)) == 2 + + def test_extract_multiple_elements(self): + with pytest.raises(NotImplementedError) as e: + extract_value([1, 2, 3]) + assert "Trying to extract an encapsulated value from objects with more than a single entry is not supported " \ + "by this function." in e.value.args[0] diff --git a/test/test_join.py b/test/test_helpers/test_join.py similarity index 78% rename from test/test_join.py rename to test/test_helpers/test_join.py index a9a4c381cbf58a272389b0b11283c8b0cce3ab42..e903669bf63f4056a8278401b07818d31a09616d 100644 --- a/test/test_join.py +++ b/test/test_helpers/test_join.py @@ -7,14 +7,6 @@ from mlair.helpers.join import _save_to_pandas, _correct_stat_name, _lower_list, from mlair.configuration.join_settings import join_settings -class TestJoinUrlBase: - - def test_url(self): - url, headers = join_settings() - assert url == 'https://join.fz-juelich.de/services/rest/surfacedata/' - assert headers == {} - - class TestDownloadJoin: def test_download_single_var(self): @@ -25,7 +17,18 @@ class TestDownloadJoin: def test_download_empty(self): with pytest.raises(EmptyQueryResult) as e: download_join("DEBW107", {"o3": "dma8eu"}, "traffic") - assert e.value.args[-1] == "No data found in JOIN." + assert e.value.args[-1] == "No data found for variables {'o3'} and options station=['DEBW107'], type=traffic," \ + " network=None, origin={} in JOIN." + + def test_download_incomplete(self): + with pytest.raises(EmptyQueryResult) as e: + download_join("DEBW107", {"o3": "dma8eu", "o10": "maximum"}, "background") + assert e.value.args[-1] == "No data found for variables {'o10'} and options station=['DEBW107'], " \ + "type=background, network=None, origin={} in JOIN." + with pytest.raises(EmptyQueryResult) as e: + download_join("DEBW107", {"o3": "dma8eu", "o10": "maximum"}, "background", data_origin={"o10": ""}) + assert e.value.args[-1] == "No data found for variables {'o10'} and options station=['DEBW107'], " \ + "type=background, network=None, origin={'o10': ''} in JOIN." 
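The more descriptive `EmptyQueryResult` message tested above is easiest to see in a short usage sketch; the station, statistics, and station type are taken from the test itself, while the import location of `EmptyQueryResult` is an assumption and not confirmed by this diff:

```python
from mlair.helpers.join import download_join
from mlair.helpers.join import EmptyQueryResult  # assumed location of the exception class

try:
    download_join("DEBW107", {"o3": "dma8eu"}, "traffic")  # a traffic-type query returns no series here
except EmptyQueryResult as e:
    # the message now names the missing variables plus the station, type, network and origin options
    print(e)
```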
class TestCorrectDataFormat: @@ -53,11 +56,12 @@ class TestLoadSeriesInformation: def test_standard_query(self): expected_subset = {'o3': 23031, 'no2': 39002, 'temp': 85584, 'wspeed': 17060} - assert expected_subset.items() <= load_series_information(['DEBW107'], None, None, join_settings()[0], - {}).items() + res, orig = load_series_information(['DEBW107'], None, None, join_settings()[0], {}) + assert expected_subset.items() <= res.items() def test_empty_result(self): - assert load_series_information(['DEBW107'], "traffic", None, join_settings()[0], {}) == {} + res, orig = load_series_information(['DEBW107'], "traffic", None, join_settings()[0], {}) + assert res == {} class TestSelectDistinctSeries: @@ -81,15 +85,18 @@ class TestSelectDistinctSeries: 'parameter_label': 'PRESS-REA-MIUB', 'parameter_attribute': 'REA'}] def test_no_origin_given(self, vars): - res = _select_distinct_series(vars) + res, orig = _select_distinct_series(vars) assert res == {"no2": 16686, "o3": 16687, "cloudcover": 54036, "temp": 88491, "press": 102660} + assert orig == {"no2": "", "o3": "", "cloudcover": "REA", "temp": "REA", "press": "REA"} def test_different_origins(self, vars): origin = {"no2": "test", "temp": "", "cloudcover": "REA"} - res = _select_distinct_series(vars, data_origin=origin) - assert res == {"o3": 16687, "press": 16692, "temp": 16693, "cloudcover": 54036} - res = _select_distinct_series(vars, data_origin={}) - assert res == {"no2": 16686, "o3": 16687, "press": 16692, "temp": 16693} + res, orig = _select_distinct_series(vars, data_origin=origin) + assert res == {"o3": 16687, "press": 102660, "temp": 16693, "cloudcover": 54036} + assert orig == {"no2": "test", "o3": "", "cloudcover": "REA", "temp": "", "press": "REA"} + res, orig = _select_distinct_series(vars, data_origin={}) + assert res == {"cloudcover": 54036, "no2": 16686, "o3": 16687, "press": 102660, "temp": 88491} + assert orig == {"no2": "", "o3": "", "temp": "REA", "press": "REA", "cloudcover": "REA"} class TestSaveToPandas: diff --git a/test/test_statistics.py b/test/test_helpers/test_statistics.py similarity index 100% rename from test/test_statistics.py rename to test/test_helpers/test_statistics.py diff --git a/test/test_helpers/test_testing_helpers.py b/test/test_helpers/test_testing_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..385161c740f386847ef2f2dc4df17c1c84fa7fa5 --- /dev/null +++ b/test/test_helpers/test_testing_helpers.py @@ -0,0 +1,48 @@ +from mlair.helpers.testing import PyTestRegex, PyTestAllEqual + +import re +import xarray as xr +import numpy as np + +import pytest + + +class TestPyTestRegex: + + def test_init(self): + test_regex = PyTestRegex(r"TestString\d+") + assert isinstance(test_regex._regex, re._pattern_type) + + def test_eq(self): + assert PyTestRegex(r"TestString\d*") == "TestString" + assert PyTestRegex(r"TestString\d+") == "TestString9" + assert "TestString4" == PyTestRegex(r"TestString\d+") + + def test_repr(self): + assert repr(PyTestRegex(r"TestString\d+")) == r"TestString\d+" + + +class TestPyTestAllEqual: + + def test_numpy(self): + assert PyTestAllEqual([np.array([1, 2, 3]), np.array([1, 2, 3]), np.array([1, 2, 3])]) + with pytest.raises(AssertionError): + PyTestAllEqual([np.array([1, 2, 3]), np.array([2, 2, 3]), np.array([1, 2, 3])]) + + def test_xarray(self): + assert PyTestAllEqual([xr.DataArray([1, 2, 3]), xr.DataArray([1, 2, 3])]) + with pytest.raises(AssertionError): + PyTestAllEqual([xr.DataArray([1, 2, 3]), xr.DataArray([1, 2, 3, 4])]) + + def 
test_other(self): + assert PyTestAllEqual(["test", "test", "test"]) + with pytest.raises(AssertionError): + PyTestAllEqual(["test", "test", "tes2t"]) + + def test_encapsulated(self): + assert PyTestAllEqual([[np.array([1, 2, 3]), np.array([12, 22, 32])], + [np.array([1, 2, 3]), np.array([12, 22, 32])]]) + assert PyTestAllEqual([[xr.DataArray([1, 2, 3]), xr.DataArray([12, 22, 32])], + [xr.DataArray([1, 2, 3]), xr.DataArray([12, 22, 32])]]) + assert PyTestAllEqual([["test", "test2"], + ["test", "test2"]]) diff --git a/test/test_plotting/test_tracker_plot.py b/test/test_plotting/test_tracker_plot.py index 196879657452fe12238c990fc419cb0848c9ec9c..62e171fe75c1112bb96447e9831b0008269fac7b 100644 --- a/test/test_plotting/test_tracker_plot.py +++ b/test/test_plotting/test_tracker_plot.py @@ -356,13 +356,13 @@ class TestTrackPlot: assert len(track_plot_obj.ax.lines) == 0 track_plot_obj.line(start_x=5, end_x=6, y=2) assert len(track_plot_obj.ax.lines) == 2 - pos_x, pos_y = np.array([5 + w, 6]), np.ones((2, )) * (2 + h / 2) + pos_x, pos_y = np.array([5 + w, 6]), np.ones((2,)) * (2 + h / 2) assert track_plot_obj.ax.lines[0]._color == "white" assert track_plot_obj.ax.lines[0]._linewidth == 2.5 assert track_plot_obj.ax.lines[1]._color == "darkgrey" assert track_plot_obj.ax.lines[1]._linewidth == 1.4 - assert PyTestAllEqual([track_plot_obj.ax.lines[0]._x, track_plot_obj.ax.lines[1]._x, pos_x]).is_true() - assert PyTestAllEqual([track_plot_obj.ax.lines[0]._y, track_plot_obj.ax.lines[1]._y, pos_y]).is_true() + assert PyTestAllEqual([track_plot_obj.ax.lines[0]._x, track_plot_obj.ax.lines[1]._x, pos_x]) + assert PyTestAllEqual([track_plot_obj.ax.lines[0]._y, track_plot_obj.ax.lines[1]._y, pos_y]) def test_step(self, track_plot_obj): x_int, h, w = 0.5, 0.6, 0.65 @@ -379,8 +379,8 @@ class TestTrackPlot: assert track_plot_obj.ax.lines[0]._linewidth == 2.5 assert track_plot_obj.ax.lines[1]._color == "black" assert track_plot_obj.ax.lines[1]._linewidth == 1.4 - assert PyTestAllEqual([track_plot_obj.ax.lines[0]._x, track_plot_obj.ax.lines[1]._x, pos_x]).is_true() - assert PyTestAllEqual([track_plot_obj.ax.lines[0]._y, track_plot_obj.ax.lines[1]._y, pos_y]).is_true() + assert PyTestAllEqual([track_plot_obj.ax.lines[0]._x, track_plot_obj.ax.lines[1]._x, pos_x]) + assert PyTestAllEqual([track_plot_obj.ax.lines[0]._y, track_plot_obj.ax.lines[1]._y, pos_y]) def test_rect(self, track_plot_obj): h, w = 0.5, 0.6 @@ -392,20 +392,18 @@ class TestTrackPlot: track_plot_obj.rect(x=4, y=2) assert len(track_plot_obj.ax.artists) == 1 assert len(track_plot_obj.ax.texts) == 1 - track_plot_obj.ax.artists[0].xy == (4, 2) - track_plot_obj.ax.artists[0]._height == h - track_plot_obj.ax.artists[0]._width == w - track_plot_obj.ax.artists[0]._original_facecolor == "orange" - track_plot_obj.ax.texts[0].xy == (4 + w / 2, 2 + h / 2) - track_plot_obj.ax.texts[0]._color == "w" - track_plot_obj.ax.texts[0]._text == "get" + assert track_plot_obj.ax.artists[0].xy == (4, 2) + assert track_plot_obj.ax.artists[0]._height == h + assert track_plot_obj.ax.artists[0]._width == w + assert track_plot_obj.ax.artists[0]._original_facecolor == "orange" + assert track_plot_obj.ax.texts[0].xy == (4 + w / 2, 2 + h / 2) + assert track_plot_obj.ax.texts[0]._color == "w" + assert track_plot_obj.ax.texts[0]._text == "get" track_plot_obj.rect(x=4, y=2, method="set") assert len(track_plot_obj.ax.artists) == 2 assert len(track_plot_obj.ax.texts) == 2 - track_plot_obj.ax.artists[0]._original_facecolor == "lightblue" - track_plot_obj.ax.texts[0]._text == "set" - - + 
assert track_plot_obj.ax.artists[1]._original_facecolor == "lightblue" + assert track_plot_obj.ax.texts[1]._text == "set" def test_set_ypos_anchor(self, track_plot_obj, scopes, dims): assert not hasattr(track_plot_obj, "y_pos") diff --git a/test/test_run_modules/test_pre_processing.py b/test/test_run_modules/test_pre_processing.py index bdb8fdabff67ad894275c805522b9df4cf167011..11c46e99fb38489f5cbb26a8a87032049c96c7ca 100644 --- a/test/test_run_modules/test_pre_processing.py +++ b/test/test_run_modules/test_pre_processing.py @@ -1,6 +1,7 @@ import logging import pytest +import mock from mlair.data_handler import DefaultDataHandler, DataCollection, AbstractDataHandler from mlair.helpers.datastore import NameNotFoundInScope @@ -8,6 +9,9 @@ from mlair.helpers import PyTestRegex from mlair.run_modules.experiment_setup import ExperimentSetup from mlair.run_modules.pre_processing import PreProcessing from mlair.run_modules.run_environment import RunEnvironment +import pandas as pd +import numpy as np +import multiprocessing class TestPreProcessing: @@ -86,7 +90,7 @@ class TestPreProcessing: assert data_store.get("stations", "general.awesome") == ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] @pytest.mark.parametrize("name", (None, "tester")) - def test_validate_station(self, caplog, obj_with_exp_setup, name): + def test_validate_station_serial(self, caplog, obj_with_exp_setup, name): pre = obj_with_exp_setup caplog.set_level(logging.INFO) stations = pre.data_store.get("stations", "general") @@ -97,6 +101,25 @@ class TestPreProcessing: assert valid_stations == stations[:-1] expected = "check valid stations started" + ' (%s)' % (name if name else 'all') assert caplog.record_tuples[0] == ('root', 20, expected) + assert caplog.record_tuples[1] == ('root', 20, "use serial validate station approach") + assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 6 ' + r'station\(s\). Found 5/6 valid stations.')) + + @mock.patch("multiprocessing.cpu_count", return_value=3) + @mock.patch("multiprocessing.Pool", return_value=multiprocessing.Pool(3)) + def test_validate_station_parallel(self, mock_pool, mock_cpu, caplog, obj_with_exp_setup): + pre = obj_with_exp_setup + caplog.clear() + caplog.set_level(logging.INFO) + stations = pre.data_store.get("stations", "general") + data_preparation = pre.data_store.get("data_handler") + collection, valid_stations = pre.validate_station(data_preparation, stations, set_name=None) + assert isinstance(collection, DataCollection) + assert len(valid_stations) < len(stations) + assert valid_stations == stations[:-1] + assert caplog.record_tuples[0] == ('root', 20, "check valid stations started (all)") + assert caplog.record_tuples[1] == ('root', 20, "use parallel validate station approach") + assert caplog.record_tuples[2] == ('root', 20, "running 3 processes in parallel") assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 6 ' r'station\(s\). 
Found 5/6 valid stations.')) @@ -115,3 +138,38 @@ class TestPreProcessing: assert pre.transformation(data_preparation, stations) is None class data_preparation_no_trans: pass assert pre.transformation(data_preparation_no_trans, stations) is None + + @pytest.fixture + def dummy_df(self): + data_dict = {'station_name': {'DEBW013': 'Stuttgart Bad Cannstatt', 'DEBW076': 'Baden-Baden', + 'DEBW087': 'Schwäbische_Alb', 'DEBW107': 'Tübingen', + 'DEBY081': 'Garmisch-Partenkirchen/Kreuzeckbahnstraße', '# Stations': np.nan, + '# Samples': np.nan}, + 'station_lon': {'DEBW013': 9.2297, 'DEBW076': 8.2202, 'DEBW087': 9.2076, 'DEBW107': 9.0512, + 'DEBY081': 11.0631, '# Stations': np.nan, '# Samples': np.nan}, + 'station_lat': {'DEBW013': 48.8088, 'DEBW076': 48.7731, 'DEBW087': 48.3458, 'DEBW107': 48.5077, + 'DEBY081': 47.4764, '# Stations': np.nan, '# Samples': np.nan}, + 'station_alt': {'DEBW013': 235.0, 'DEBW076': 148.0, 'DEBW087': 798.0, 'DEBW107': 325.0, + 'DEBY081': 735.0, '# Stations': np.nan, '# Samples': np.nan}, + 'train': {'DEBW013': 1413, 'DEBW076': 3002, 'DEBW087': 3016, 'DEBW107': 1782, 'DEBY081': 2837, + '# Stations': 6, '# Samples': 12050}, + 'val': {'DEBW013': 698, 'DEBW076': 715, 'DEBW087': 700, 'DEBW107': 701, 'DEBY081': 456, + '# Stations': 6, '# Samples': 3270}, + 'test': {'DEBW013': 1066, 'DEBW076': 696, 'DEBW087': 1080, 'DEBW107': 1080, 'DEBY081': 700, + '# Stations': 6, '# Samples': 4622}} + df = pd.DataFrame.from_dict(data_dict) + return df + + def test_create_column_format_for_tex(self): + df = pd.DataFrame(np.ones((2, 1))) + df_col = PreProcessing.create_column_format_for_tex(df) # len: 1+1 + assert df_col == 'lr' + assert len(df_col) == 2 + df = pd.DataFrame(np.ones((2, 2))) + df_col = PreProcessing.create_column_format_for_tex(df) # len: 2+1 + assert df_col == 'lcr' + assert len(df_col) == 3 + df = pd.DataFrame(np.ones((2, 3))) + df_col = PreProcessing.create_column_format_for_tex(df) # len: 3+1 + assert df_col == 'lccr' + assert len(df_col) == 4 diff --git a/test/test_workflows/test_abstract_workflow.py b/test/test_workflows/test_abstract_workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..6530f8565ba4a1ccc185c133978ad1809905dff9 --- /dev/null +++ b/test/test_workflows/test_abstract_workflow.py @@ -0,0 +1,53 @@ +from mlair.workflows.abstract_workflow import Workflow + +import logging + + +class TestWorkflow: + + def test_init(self): + flow = Workflow() + assert len(flow._registry_kwargs.keys()) == 0 + assert len(flow._registry) == 0 + assert flow._name == "Workflow" + flow = Workflow(name="river") + assert flow._name == "river" + + def test_add(self): + flow = Workflow() + flow.add("stage") + assert len(flow._registry_kwargs.keys()) == 1 + assert len(flow._registry) == 1 + assert len(flow._registry_kwargs[0].keys()) == 0 + flow.add("stagekwargs", test=23, another="string") + assert len(flow._registry_kwargs.keys()) == 2 + assert len(flow._registry) == 2 + assert len(flow._registry_kwargs[1].keys()) == 2 + assert list(flow._registry_kwargs.keys()) == [0, 1] + assert flow._registry == ["stage", "stagekwargs"] + assert list(flow._registry_kwargs[1].keys()) == ["test", "another"] + assert flow._registry_kwargs[1]["another"] == "string" + + def test_run(self, caplog): + caplog.set_level(logging.INFO) + + class A: + def __init__(self, a=3): + self.a = a + logging.info(self.a) + + class B: + def __init__(self): + self.b = 2 + logging.info(self.b) + + flow = Workflow() + flow.add(A, a=6) + flow.add(B) + flow.add(A) + flow.run() + pos = int(".log" in 
caplog.messages[0]) + assert caplog.record_tuples[0 + pos] == ('root', 20, "Workflow started") + assert caplog.record_tuples[1 + pos] == ('root', 20, "6") + assert caplog.record_tuples[2 + pos] == ('root', 20, "2") + assert caplog.record_tuples[3 + pos] == ('root', 20, "3") diff --git a/test/test_workflows/test_default_workflow.py b/test/test_workflows/test_default_workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..c7c198a4821f779329b9f5f19b04e757d8ebc7da --- /dev/null +++ b/test/test_workflows/test_default_workflow.py @@ -0,0 +1,48 @@ +from mlair.workflows.default_workflow import DefaultWorkflow, DefaultWorkflowHPC +from mlair.run_modules.experiment_setup import ExperimentSetup +from mlair.run_modules.pre_processing import PreProcessing +from mlair.run_modules.model_setup import ModelSetup +from mlair.run_modules.partition_check import PartitionCheck +from mlair.run_modules.training import Training +from mlair.run_modules.post_processing import PostProcessing + + +class TestDefaultWorkflow: + + def test_init_no_args(self): + flow = DefaultWorkflow() + assert flow._registry[0].__name__ == ExperimentSetup.__name__ + assert len(flow._registry_kwargs[0].keys()) == 1 + + def test_init_with_args(self): + flow = DefaultWorkflow(stations="test", start="2020", model=None) + assert flow._registry[0].__name__ == ExperimentSetup.__name__ + assert len(flow._registry_kwargs[0].keys()) == 3 + + def test_init_with_kwargs(self): + flow = DefaultWorkflow(stations="test", real_kwarg=4) + assert flow._registry[0].__name__ == ExperimentSetup.__name__ + assert len(flow._registry_kwargs[0].keys()) == 3 + assert list(flow._registry_kwargs[0].keys()) == ["experiment_date", "stations", "real_kwarg"] + + def test_setup(self): + flow = DefaultWorkflow() + assert len(flow._registry) == 5 + assert flow._registry[0].__name__ == ExperimentSetup.__name__ + assert flow._registry[1].__name__ == PreProcessing.__name__ + assert flow._registry[2].__name__ == ModelSetup.__name__ + assert flow._registry[3].__name__ == Training.__name__ + assert flow._registry[4].__name__ == PostProcessing.__name__ + + +class TestDefaultWorkflowHPC: + + def test_setup(self): + flow = DefaultWorkflowHPC() + assert len(flow._registry) == 6 + assert flow._registry[0].__name__ == ExperimentSetup.__name__ + assert flow._registry[1].__name__ == PreProcessing.__name__ + assert flow._registry[2].__name__ == PartitionCheck.__name__ + assert flow._registry[3].__name__ == ModelSetup.__name__ + assert flow._registry[4].__name__ == Training.__name__ + assert flow._registry[5].__name__ == PostProcessing.__name__
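Read together, the new workflow tests outline how `Workflow` and `DefaultWorkflow` are meant to be driven: stages are registered with `add()` (or pre-registered by `DefaultWorkflow`), instantiated in registration order by `run()`, and keyword arguments that `DefaultWorkflow` does not consume itself are forwarded to `ExperimentSetup`. The sketch below is a hypothetical usage example based only on that behaviour; the stage class, the station IDs, and the assumption that `DefaultWorkflow` inherits `run()` from `Workflow` are illustrative, not taken from MLAir documentation.

```python
import logging

from mlair.workflows.abstract_workflow import Workflow
from mlair.workflows.default_workflow import DefaultWorkflow


class MyStage:
    """Toy stage in the spirit of the classes used in TestWorkflow.test_run."""
    def __init__(self, value=3):
        logging.info(f"running MyStage with value={value}")


# Custom workflow: each add() registers a stage class plus the kwargs it will be
# instantiated with when run() walks the registry in order.
flow = Workflow(name="example")
flow.add(MyStage, value=6)
flow.add(MyStage)
flow.run()

# Default workflow: ExperimentSetup, PreProcessing, ModelSetup, Training and
# PostProcessing are pre-registered; extra kwargs go to ExperimentSetup
# (see test_init_with_kwargs above). Calling run() starts a complete experiment.
default_flow = DefaultWorkflow(stations=["DEBW107", "DEBY081"])  # illustrative station list
default_flow.run()
```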