diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 4a59b5b91edbe7a918a80884cf9e38a5d70a8826..7da41c750d8a2eb06266edbc080d9ee460e225f8 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -41,12 +41,16 @@ tests (from scratch): before_script: - chmod +x ./CI/update_badge.sh - ./CI/update_badge.sh > /dev/null + - source /opt/venv/bin/activate script: - pip install --upgrade pip - - pip install numpy wheel six==1.15.0 - - zypper --non-interactive install binutils libproj-devel gdal-devel - - zypper --non-interactive install proj geos-devel - # - cat requirements.txt | cut -f1 -d"#" | sed '/^\s*$/d' | xargs -L 1 pip install + - zypper --no-gpg-checks addrepo https://download.opensuse.org/repositories/Application:Geo/15.4/Application:Geo.repo + - zypper --no-gpg-checks refresh + - zypper --no-gpg-checks --non-interactive install proj=8.2.1 + - zypper --no-gpg-checks --non-interactive install geos=3.10.3 + - zypper --no-gpg-checks --non-interactive install geos-devel=3.9.1 + - zypper --no-gpg-checks --non-interactive install libproj22=8.2.1 + - zypper --no-gpg-checks --non-interactive install binutils libproj-devel gdal-devel - pip install -r requirements.txt - chmod +x ./CI/run_pytest.sh - ./CI/run_pytest.sh @@ -60,34 +64,6 @@ tests (from scratch): - badges/ - test_results/ -### Tests (on GPU) ### -#tests (on GPU): -# tags: -# - gpu -# - zam347 -# stage: test -# only: -# - master -# - /^release.*$/ -# - develop -# variables: -# FAILURE_THRESHOLD: 100 -# TEST_TYPE: "gpu" -# before_script: -# - chmod +x ./CI/update_badge.sh -# - ./CI/update_badge.sh > /dev/null -# script: -# - pip install -r test/requirements_tf_skip.txt -# - chmod +x ./CI/run_pytest.sh -# - ./CI/run_pytest.sh -# after_script: -# - ./CI/update_badge.sh > /dev/null -# artifacts: -# name: pages -# when: always -# paths: -# - badges/ -# - test_results/ ### Tests ### tests: @@ -100,6 +76,7 @@ tests: before_script: - chmod +x ./CI/update_badge.sh - ./CI/update_badge.sh > /dev/null + - source /opt/venv/bin/activate script: - pip install -r requirements.txt - chmod +x ./CI/run_pytest.sh @@ -125,6 +102,7 @@ coverage: before_script: - chmod +x ./CI/update_badge.sh - ./CI/update_badge.sh > /dev/null + - source /opt/venv/bin/activate script: - pip install -r requirements.txt - chmod +x ./CI/run_pytest_coverage.sh @@ -148,6 +126,7 @@ sphinx docs: before_script: - chmod +x ./CI/update_badge.sh - ./CI/update_badge.sh > /dev/null + - source /opt/venv/bin/activate script: - pip install -r requirements.txt - pip install -r docs/requirements_docs.txt diff --git a/CHANGELOG.md b/CHANGELOG.md index 266cb33ec8666099ffcb638ff85d814d7e2cf184..988e3e5a7863868cead1a2fec7c7b6d1c750b8d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,29 @@ # Changelog All notable changes to this project will be documented in this file. 
+## v2.1.0 - 2022-06-07 - new evaluation metrics and improved training + +### general: +* new evaluation metrics, IOA and MNMB +* advanced train options for early stopping +* reduced execution time by refactoring + +### new features: +* uncertainty estimation of MSE is now applied for each season separately (#374) +* added different configurations of early stopping to use either last trained or best epoch (#378) +* train monitoring plots now add a star for best epoch when using early stopping (#367) +* new evaluation metric index of agreement, IOA (#376) +* new evaluation metric modified normalised mean bias, MNMB (#380) +* new plot available that shows temporal evolution of MSE for each station (#381) + +### technical: +* reduced loading of forecast path from data store (#328) +* bug fix for uncaught error during transformation (#385) +* bug fix for data handler with climate and FIR filter that always calculated the transformation with the FIR filter (#387) +* reduced duration of LaTeX report creation at the end of preprocessing (#388) +* enhanced speed of the make prediction step in postprocessing (#389) +* fix to always create version badge from version and not from tag name (#382) + ## v2.0.0 - 2022-04-08 - tf2 usage, new model classes, and improved uncertainty estimate ### general: diff --git a/CI/Dockerfile b/CI/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..f3b99b2f8b78129d3fff4d49d88be54613bf5929 --- /dev/null +++ b/CI/Dockerfile @@ -0,0 +1,65 @@ +# ---- base node ---- +FROM opensuse/leap:latest AS base +MAINTAINER Lukas Leufen <l.leufen@fz-juelich.de> + +# install git +RUN zypper --non-interactive install git + +# install python3 +RUN zypper --non-interactive install python39 python39-devel + +# install pip +RUN zypper --non-interactive install python39-pip + +# upgrade pip +RUN pip3.9 install --upgrade pip + +# install curl +RUN zypper --non-interactive install curl + +# install make +RUN zypper --non-interactive install make + +# install gcc +RUN zypper --non-interactive install gcc-c++ + +# create and activate venv +ENV VIRTUAL_ENV=/opt/venv +RUN python3.9 -m venv $VIRTUAL_ENV +ENV PATH="$VIRTUAL_ENV/bin:$PATH" +# RUN source venv/bin/activate + +# ---- test node ---- +FROM base AS test + +# install pytest +RUN pip install pytest pytest-html pytest-lazy-fixture + +# ---- coverage node ---- +FROM test AS coverage + +# install pytest coverage +RUN pip install pytest-cov + +# ---- docs node ---- +FROM base AS docs + +# install sphinx +RUN pip install sphinx + +# ---- MLAir ready to use ---- +FROM base AS mlair + +# install geo packages +RUN zypper --no-gpg-checks addrepo https://download.opensuse.org/repositories/Application:Geo/15.4/Application:Geo.repo +RUN zypper --no-gpg-checks refresh +RUN zypper --no-gpg-checks --non-interactive install proj=8.2.1 +RUN zypper --no-gpg-checks --non-interactive install geos=3.10.3 +RUN zypper --no-gpg-checks --non-interactive install geos-devel=3.9.1 +RUN zypper --no-gpg-checks --non-interactive install libproj22=8.2.1 +RUN zypper --no-gpg-checks --non-interactive install binutils libproj-devel gdal-devel + +# install requirements +ADD requirements.txt . 
+RUN pip install -r requirements.txt + diff --git a/CI/create_version_badge.sh b/CI/create_version_badge.sh index c7a85af2b89eccb48679601dffe6a31396739cfc..87ceae3ce7449dea4ba8ce33b943357ca9fe9ce8 100644 --- a/CI/create_version_badge.sh +++ b/CI/create_version_badge.sh @@ -1,6 +1,6 @@ #!/bin/bash -VERSION="$(git describe master)" +VERSION="$(git describe --match v[0-9]*.[0-9]*.[0-9]* --abbrev=0)" COLOR="blue" BADGE_NAME="version" diff --git a/CI/run_pytest.sh b/CI/run_pytest.sh index baa7ef8e892fc2d9efdd30094917ca492017de3d..060569abac395c49d5a5fcda80a29726d5e9001a 100644 --- a/CI/run_pytest.sh +++ b/CI/run_pytest.sh @@ -1,7 +1,7 @@ #!/bin/bash # run pytest for all run_modules -python3.6 -m pytest --html=report.html --self-contained-html test/ | tee test_results.out +python -m pytest --html=report.html --self-contained-html test/ | tee test_results.out IS_FAILED=$? diff --git a/CI/run_pytest_coverage.sh b/CI/run_pytest_coverage.sh index 24d916b1a32da714abc2e5de0ac2b4c2790752a9..f6efaf2fb78223d3c41e7d7ba0a40c87befd3296 100644 --- a/CI/run_pytest_coverage.sh +++ b/CI/run_pytest_coverage.sh @@ -1,7 +1,7 @@ #!/usr/bin/env bash # run coverage twice, 1) for html deploy 2) for success evaluation -python3.6 -m pytest --cov=mlair --cov-report term --cov-report html test/ | tee coverage_results.out +python -m pytest --cov=mlair --cov-report term --cov-report html test/ | tee coverage_results.out IS_FAILED=$? diff --git a/HPC_setup/create_runscripts_HPC.sh b/HPC_setup/create_runscripts_HPC.sh index 730aa52ef42144826bd000d88c0fc81c9d508de0..b3d9d644334d06ff674a22274bf4e04619853b15 100755 --- a/HPC_setup/create_runscripts_HPC.sh +++ b/HPC_setup/create_runscripts_HPC.sh @@ -85,7 +85,7 @@ source venv_${hpcsys}/bin/activate timestamp=\`date +"%Y-%m-%d_%H%M-%S"\` -export PYTHONPATH=\${PWD}/venv_${hpcsys}/lib/python3.8/site-packages:\${PYTHONPATH} +export PYTHONPATH=\${PWD}/venv_${hpcsys}/lib/python3.9/site-packages:\${PYTHONPATH} srun --cpu-bind=none python run.py --experiment_date=\$timestamp EOT @@ -111,7 +111,7 @@ source venv_${hpcsys}/bin/activate timestamp=\`date +"%Y-%m-%d_%H%M-%S"\` -export PYTHONPATH=\${PWD}/venv_${hpcsys}/lib/python3.8/site-packages:\${PYTHONPATH} +export PYTHONPATH=\${PWD}/venv_${hpcsys}/lib/python3.9/site-packages:\${PYTHONPATH} srun --cpu-bind=none python run_HPC.py --experiment_date=\$timestamp EOT diff --git a/HPC_setup/mlt_modules_hdfml.sh b/HPC_setup/mlt_modules_hdfml.sh index df8ae0830ad70c572955447b1c5e87341b8af9ec..4efc5a69987f4a4687080740b93543bcf8107c4c 100644 --- a/HPC_setup/mlt_modules_hdfml.sh +++ b/HPC_setup/mlt_modules_hdfml.sh @@ -8,13 +8,12 @@ module --force purge module use $OTHERSTAGES -ml Stages/2020 -ml GCCcore/.10.3.0 +ml Stages/2022 +ml GCCcore/.11.2.0 -ml Jupyter/2021.3.1-Python-3.8.5 -ml Python/3.8.5 -ml TensorFlow/2.5.0-Python-3.8.5 -ml SciPy-Stack/2021-Python-3.8.5 -ml dask/2.22.0-Python-3.8.5 -ml GEOS/3.8.1-Python-3.8.5 -ml Graphviz/2.44.1 \ No newline at end of file +ml Python/3.9.6 +ml TensorFlow/2.6.0-CUDA-11.5 +ml dask/2021.9.1 +ml GEOS/3.9.1 +ml Cartopy/0.20.0 +ml Graphviz/2.49.3 diff --git a/HPC_setup/mlt_modules_juwels.sh b/HPC_setup/mlt_modules_juwels.sh index ffacfe6fc45302dfa60b108ca2493d9a27408df1..37636fb8834601768ade2d86dc8c7287e273a5d4 100755 --- a/HPC_setup/mlt_modules_juwels.sh +++ b/HPC_setup/mlt_modules_juwels.sh @@ -8,13 +8,12 @@ module --force purge module use $OTHERSTAGES -ml Stages/2020 -ml GCCcore/.10.3.0 +ml Stages/2022 +ml GCCcore/.11.2.0 -ml Jupyter/2021.3.1-Python-3.8.5 -ml Python/3.8.5 -ml TensorFlow/2.5.0-Python-3.8.5 -ml 
SciPy-Stack/2021-Python-3.8.5 -ml dask/2.22.0-Python-3.8.5 -ml GEOS/3.8.1-Python-3.8.5 -ml Graphviz/2.44.1 \ No newline at end of file +ml Python/3.9.6 +ml TensorFlow/2.6.0-CUDA-11.5 +ml dask/2021.9.1 +ml GEOS/3.9.1 +ml Cartopy/0.20.0 +ml Graphviz/2.49.3 diff --git a/HPC_setup/requirements_HDFML_additionals.txt b/HPC_setup/requirements_HDFML_additionals.txt index ebfac3cd0d989a8845f2a3fceba33d562b898b8d..90ccd220394c9afb00fba7e069af8e677d4ae0b2 100644 --- a/HPC_setup/requirements_HDFML_additionals.txt +++ b/HPC_setup/requirements_HDFML_additionals.txt @@ -1,15 +1,19 @@ -astropy==4.1 -bottleneck==1.3.2 -cached-property==1.5.2 -iniconfig==1.1.1 -ordered-set==4.0.2 -pyshp==2.1.3 -pytest-html==3.1.1 -pytest-lazy-fixture==0.6.3 -pytest-metadata==1.11.0 -pytest-sugar==0.9.4 -tabulate==0.8.8 +astropy==5.1 +pytz==2022.1 +python-dateutil==2.8.2 +requests==2.28.1 +werkzeug>=0.11.15 +wheel>=0.26 +six==1.15.0 +psutil==5.9.1 +pyparsing==3.0.9 +packaging==21.3 +timezonefinder==5.2.0 +patsy==0.5.2 +statsmodels==0.13.2 +seaborn==0.11.2 +xarray==0.16.2 +tabulate==0.8.10 wget==3.2 ---no-binary shapely Shapely==1.7.0 - -#Cartopy==0.18.0 +pydot==1.4.2 +netcdf4==1.6.0 \ No newline at end of file diff --git a/HPC_setup/requirements_JUWELS_additionals.txt b/HPC_setup/requirements_JUWELS_additionals.txt index ebfac3cd0d989a8845f2a3fceba33d562b898b8d..90ccd220394c9afb00fba7e069af8e677d4ae0b2 100644 --- a/HPC_setup/requirements_JUWELS_additionals.txt +++ b/HPC_setup/requirements_JUWELS_additionals.txt @@ -1,15 +1,19 @@ -astropy==4.1 -bottleneck==1.3.2 -cached-property==1.5.2 -iniconfig==1.1.1 -ordered-set==4.0.2 -pyshp==2.1.3 -pytest-html==3.1.1 -pytest-lazy-fixture==0.6.3 -pytest-metadata==1.11.0 -pytest-sugar==0.9.4 -tabulate==0.8.8 +astropy==5.1 +pytz==2022.1 +python-dateutil==2.8.2 +requests==2.28.1 +werkzeug>=0.11.15 +wheel>=0.26 +six==1.15.0 +psutil==5.9.1 +pyparsing==3.0.9 +packaging==21.3 +timezonefinder==5.2.0 +patsy==0.5.2 +statsmodels==0.13.2 +seaborn==0.11.2 +xarray==0.16.2 +tabulate==0.8.10 wget==3.2 ---no-binary shapely Shapely==1.7.0 - -#Cartopy==0.18.0 +pydot==1.4.2 +netcdf4==1.6.0 \ No newline at end of file diff --git a/HPC_setup/setup_venv_hdfml.sh b/HPC_setup/setup_venv_hdfml.sh index f1b4a63f9a5c90d7afacb5c3dc027adb4e6e29fc..11c273b477ea26383e53799ae0025ceb5c947a4a 100644 --- a/HPC_setup/setup_venv_hdfml.sh +++ b/HPC_setup/setup_venv_hdfml.sh @@ -22,7 +22,7 @@ python3 -m venv ${cur}../venv_hdfml source ${cur}/../venv_hdfml/bin/activate # export path for side-packages -export PYTHONPATH=${cur}/../venv_hdfml/lib/python3.8/site-packages:${PYTHONPATH} +export PYTHONPATH=${cur}/../venv_hdfml/lib/python3.9/site-packages:${PYTHONPATH} echo "##### START INSTALLING requirements_HDFML_additionals.txt #####" pip install -r ${cur}/requirements_HDFML_additionals.txt diff --git a/HPC_setup/setup_venv_juwels.sh b/HPC_setup/setup_venv_juwels.sh index 3e1f489532ef118522ccd37dd56cf6e6306046ac..8d609b8f5094de4e3840aad50656b5c11ff1a86d 100755 --- a/HPC_setup/setup_venv_juwels.sh +++ b/HPC_setup/setup_venv_juwels.sh @@ -22,7 +22,7 @@ python3 -m venv ${cur}/../venv_juwels source ${cur}/../venv_juwels/bin/activate # export path for side-packages -export PYTHONPATH=${cur}/../venv_juwels/lib/python3.8/site-packages:${PYTHONPATH} +export PYTHONPATH=${cur}/../venv_juwels/lib/python3.9/site-packages:${PYTHONPATH} echo "##### START INSTALLING requirements_JUWELS_additionals.txt #####" diff --git a/README.md b/README.md index 8decf00b29f91e0a3a014bbf57e92aff12c5e035..792c6d4a06564eb050467f271f660761ec4d3d71 100644 --- 
a/README.md +++ b/README.md @@ -34,7 +34,7 @@ HPC systems, see [here](#special-instructions-for-installation-on-jülich-hpc-sy * Installation of **MLAir**: * Either clone MLAir from the [gitlab repository](https://gitlab.jsc.fz-juelich.de/esde/machine-learning/mlair.git) and use it without installation (beside the requirements) - * or download the distribution file ([current version](https://gitlab.jsc.fz-juelich.de/esde/machine-learning/mlair/-/blob/master/dist/mlair-2.0.0-py3-none-any.whl)) + * or download the distribution file ([current version](https://gitlab.jsc.fz-juelich.de/esde/machine-learning/mlair/-/blob/master/dist/mlair-2.1.0-py3-none-any.whl)) and install it via `pip install <dist_file>.whl`. In this case, you can simply import MLAir in any python script inside your virtual environment using `import mlair`. diff --git a/dist/mlair-2.1.0-py3-none-any.whl b/dist/mlair-2.1.0-py3-none-any.whl new file mode 100644 index 0000000000000000000000000000000000000000..b5069f2ae900ff7bf43428d3adba8a50be742588 Binary files /dev/null and b/dist/mlair-2.1.0-py3-none-any.whl differ diff --git a/docs/_source/installation.rst b/docs/_source/installation.rst index 6ac4937e6a729c12e54007aa32f0e59635289fdd..6cbf8c424bdd29470c23eb95a9b5d3a5071cf39f 100644 --- a/docs/_source/installation.rst +++ b/docs/_source/installation.rst @@ -27,7 +27,7 @@ Installation of MLAir * Install all requirements from `requirements.txt <https://gitlab.jsc.fz-juelich.de/esde/machine-learning/mlair/-/blob/master/requirements.txt>`_ preferably in a virtual environment * Either clone MLAir from the `gitlab repository <https://gitlab.jsc.fz-juelich.de/esde/machine-learning/mlair.git>`_ -* or download the distribution file (`current version <https://gitlab.jsc.fz-juelich.de/esde/machine-learning/mlair/-/blob/master/dist/mlair-2.0.0-py3-none-any.whl>`_) +* or download the distribution file (`current version <https://gitlab.jsc.fz-juelich.de/esde/machine-learning/mlair/-/blob/master/dist/mlair-2.1.0-py3-none-any.whl>`_) and install it via :py:`pip install <dist_file>.whl`. In this case, you can simply import MLAir in any python script inside your virtual environment using :py:`import mlair`. 
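For readers following the installation notes above, a quick smoke test of the freshly installed wheel can confirm the upgrade to v2.1.0. This is a minimal sketch: `__version_info__` is taken from `mlair/__init__.py` as changed in this patch, while the `mlair.run()` call is an assumption based on the default-workflow entry point advertised in the MLAir documentation and may need to be replaced by an explicit workflow definition.

```python
# Minimal post-install check inside the virtual environment where the wheel was installed.
import mlair

# mlair/__init__.py (see the hunk below) now reports minor version 1.
print(mlair.__version_info__)  # expected: {'major': 2, 'minor': 1, 'micro': 0}

# Assumed convenience entry point that starts the default workflow with default settings;
# swap this for your own run script if your setup differs.
mlair.run()
```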
diff --git a/docs/requirements_docs.txt b/docs/requirements_docs.txt index ee455d83f0debc10faa09ffd82cad9a77930d936..66fca62c011263ddb81ab43b2c5258789073e641 100644 --- a/docs/requirements_docs.txt +++ b/docs/requirements_docs.txt @@ -2,8 +2,8 @@ sphinx==3.0.3 sphinx-autoapi==1.8.4 sphinx-autodoc-typehints==1.12.0 sphinx-rtd-theme==0.4.3 -#recommonmark==0.6.0 m2r2==0.3.1 docutils<0.18 mistune==0.8.4 -setuptools>=59.5.0 \ No newline at end of file +setuptools>=59.5.0 +Jinja2<3.1 \ No newline at end of file diff --git a/mlair/__init__.py b/mlair/__init__.py index 2ca5c3ab96fb3f96fa2343efab02860d465db870..901947e5313a183e3909687b1fea0096075f836c 100644 --- a/mlair/__init__.py +++ b/mlair/__init__.py @@ -1,6 +1,6 @@ __version_info__ = { 'major': 2, - 'minor': 0, + 'minor': 1, 'micro': 0, } diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py index 6f1980e2f355219865fe441757ebcd4c23b36076..5be0e7d3a3e5dd8b402230f281b528fb0a2476a2 100644 --- a/mlair/configuration/defaults.py +++ b/mlair/configuration/defaults.py @@ -9,7 +9,6 @@ DEFAULT_STATIONS = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values', 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values', 'pblheight': 'maximum'} -DEFAULT_NETWORK = "AIRBASE" DEFAULT_STATION_TYPE = "background" DEFAULT_VARIABLES = DEFAULT_VAR_ALL_DICT.keys() DEFAULT_START = "1997-01-01" diff --git a/mlair/configuration/era5_settings.py b/mlair/configuration/era5_settings.py new file mode 100644 index 0000000000000000000000000000000000000000..9f44176bd50bf95226a0ea7a4913152a34619f9a --- /dev/null +++ b/mlair/configuration/era5_settings.py @@ -0,0 +1,19 @@ +"""Settings to access not public era5 data.""" + +from typing import Tuple + + +def era5_settings(sampling="daily") -> Tuple[str, str]: + """ + Check for sampling as only hourly resolution is supported by era5 and return path on HPC systems. + + :param sampling: temporal resolution to load data for, only hourly supported (default "daily") + + :return: HPC path + """ + if sampling == "hourly": # pragma: no branch + ERA5_DATA_PATH = "." + FILE_NAMES = "*.nc" + else: + raise NameError(f"Given sampling {sampling} is not supported, only hourly sampling can be used.") + return ERA5_DATA_PATH, FILE_NAMES diff --git a/mlair/configuration/toar_data_v2_settings.py b/mlair/configuration/toar_data_v2_settings.py new file mode 100644 index 0000000000000000000000000000000000000000..a8bb9f42cf5a1967f150aa18019c2dbdc89f43a2 --- /dev/null +++ b/mlair/configuration/toar_data_v2_settings.py @@ -0,0 +1,20 @@ +"""Settings to access https://toar-data.fz-juelich.de""" +from typing import Tuple, Dict + + +def toar_data_v2_settings(sampling="daily") -> Tuple[str, Dict]: + """ + Set url for toar-data and required headers. Headers information is not required for now. + + :param sampling: temporal resolution to access. 
+ :return: Service url and optional headers + """ + if sampling == "daily": # pragma: no branch + TOAR_SERVICE_URL = "https://toar-data.fz-juelich.de/statistics/api/v1/" + headers = {} + elif sampling == "hourly" or sampling == "meta": + TOAR_SERVICE_URL = "https://toar-data.fz-juelich.de/api/v2/" + headers = {} + else: + raise NameError(f"Given sampling {sampling} is not supported, choose from either daily or hourly sampling.") + return TOAR_SERVICE_URL, headers diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py index 84596ad081b922a92a91b3df0513a4e730b8eb53..eaa6a21175bd5f88c32c9c3cb74947c0cc0956a3 100644 --- a/mlair/data_handler/data_handler_mixed_sampling.py +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -63,8 +63,7 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation): vars = [self.variables, self.target_var] stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars[ind]) data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind], - self.station_type, self.network, self.store_data_locally, self.data_origin, - self.start, self.end) + self.store_data_locally, self.data_origin, self.start, self.end) data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind], limit=self.interpolation_limit[ind], sampling=self.sampling[ind]) @@ -115,7 +114,7 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi def make_input_target(self): """ - A FIR filter is applied on the input data that has hourly resolution. Lables Y are provided as aggregated values + A FIR filter is applied on the input data that has hourly resolution. Labels Y are provided as aggregated values with daily resolution. 
""" self._data = tuple(map(self.load_and_interpolate, [0, 1])) # load input (0) and target (1) data @@ -147,8 +146,7 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars[ind]) data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind], - self.station_type, self.network, self.store_data_locally, self.data_origin, - start, end) + self.store_data_locally, self.data_origin, start, end) data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind], limit=self.interpolation_limit[ind], sampling=self.sampling[ind]) return data @@ -353,6 +351,7 @@ class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWi sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_unfiltered.requirements() if k in kwargs} sp_keys = cls.build_update_transformation(sp_keys, dh_type="unfiltered_chem") cls.prepare_build(sp_keys, chem_vars, cls.chem_indicator) + cls.correct_overwrite_option(sp_keys) sp_chem_unfiltered = cls.data_handler_unfiltered(station, **sp_keys) if len(meteo_vars) > 0: cls.set_data_handler_fir_pos(**kwargs) @@ -364,11 +363,18 @@ class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWi sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_unfiltered.requirements() if k in kwargs} sp_keys = cls.build_update_transformation(sp_keys, dh_type="unfiltered_meteo") cls.prepare_build(sp_keys, meteo_vars, cls.meteo_indicator) + cls.correct_overwrite_option(sp_keys) sp_meteo_unfiltered = cls.data_handler_unfiltered(station, **sp_keys) dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs} return cls(sp_chem, sp_meteo, sp_chem_unfiltered, sp_meteo_unfiltered, chem_vars, meteo_vars, **dp_args) + @classmethod + def correct_overwrite_option(cls, kwargs): + """Set `overwrite_local_data=False`.""" + if "overwrite_local_data" in kwargs: + kwargs["overwrite_local_data"] = False + @classmethod def set_data_handler_fir_pos(cls, **kwargs): """ diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py index 4217583d4b7ae03a2529deaae38fd33234bba5db..ec0f1f73282979a1e69945e1ad7f6817bdf3ba12 100644 --- a/mlair/data_handler/data_handler_single_station.py +++ b/mlair/data_handler/data_handler_single_station.py @@ -20,8 +20,9 @@ import xarray as xr from mlair.configuration import check_path_and_create from mlair import helpers -from mlair.helpers import join, statistics, TimeTrackingWrapper +from mlair.helpers import statistics, TimeTrackingWrapper, filter_dict_by_value, select_from_dict from mlair.data_handler.abstract_data_handler import AbstractDataHandler +from mlair.helpers import data_sources # define a more general date type for type hinting date = Union[dt.date, dt.datetime] @@ -38,8 +39,6 @@ class DataHandlerSingleStation(AbstractDataHandler): indicates that not all values up to t0 are used, a positive values indicates usage of values at t>t0. Default is 0. 
""" - DEFAULT_STATION_TYPE = "background" - DEFAULT_NETWORK = "AIRBASE" DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values', 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values', 'pblheight': 'maximum'} @@ -58,12 +57,11 @@ class DataHandlerSingleStation(AbstractDataHandler): chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5", "propane", "so2", "toluene"] - _hash = ["station", "statistics_per_var", "data_origin", "station_type", "network", "sampling", "target_dim", - "target_var", "time_dim", "iter_dim", "window_dim", "window_history_size", "window_history_offset", - "window_lead_time", "interpolation_limit", "interpolation_method", "variables", "window_history_end"] + _hash = ["station", "statistics_per_var", "data_origin", "sampling", "target_dim", "target_var", "time_dim", + "iter_dim", "window_dim", "window_history_size", "window_history_offset", "window_lead_time", + "interpolation_limit", "interpolation_method", "variables", "window_history_end"] - def __init__(self, station, data_path, statistics_per_var=None, station_type=DEFAULT_STATION_TYPE, - network=DEFAULT_NETWORK, sampling: Union[str, Tuple[str]] = DEFAULT_SAMPLING, + def __init__(self, station, data_path, statistics_per_var=None, sampling: Union[str, Tuple[str]] = DEFAULT_SAMPLING, target_dim=DEFAULT_TARGET_DIM, target_var=DEFAULT_TARGET_VAR, time_dim=DEFAULT_TIME_DIM, iter_dim=DEFAULT_ITER_DIM, window_dim=DEFAULT_WINDOW_DIM, window_history_size=DEFAULT_WINDOW_HISTORY_SIZE, window_history_offset=DEFAULT_WINDOW_HISTORY_OFFSET, @@ -87,8 +85,6 @@ class DataHandlerSingleStation(AbstractDataHandler): self.input_data, self.target_data = None, None self._transformation = self.setup_transformation(transformation) - self.station_type = station_type - self.network = network self.sampling = sampling self.target_dim = target_dim self.target_var = target_var @@ -140,9 +136,8 @@ class DataHandlerSingleStation(AbstractDataHandler): return self._data.shape, self.get_X().shape, self.get_Y().shape def __repr__(self): - return f"StationPrep(station={self.station}, data_path='{self.path}', " \ + return f"StationPrep(station={self.station}, data_path='{self.path}', data_origin={self.data_origin}, " \ f"statistics_per_var={self.statistics_per_var}, " \ - f"station_type='{self.station_type}', network='{self.network}', " \ f"sampling='{self.sampling}', target_dim='{self.target_dim}', target_var='{self.target_var}', " \ f"time_dim='{self.time_dim}', window_history_size={self.window_history_size}, " \ f"window_lead_time={self.window_lead_time}, interpolation_limit={self.interpolation_limit}, " \ @@ -169,8 +164,12 @@ class DataHandlerSingleStation(AbstractDataHandler): return self.get_transposed_label() def get_coordinates(self): - coords = self.meta.loc[["station_lon", "station_lat"]].astype(float) - return coords.rename(index={"station_lon": "lon", "station_lat": "lat"}).to_dict()[str(self)] + try: + coords = self.meta.loc[["station_lon", "station_lat"]].astype(float) + coords = coords.rename(index={"station_lon": "lon", "station_lat": "lat"}) + except KeyError: + coords = self.meta.loc[["lon", "lat"]].astype(float) + return coords.to_dict()[str(self)] def call_transform(self, inverse=False): opts_input = self._transformation[0] @@ -301,8 +300,7 @@ class DataHandlerSingleStation(AbstractDataHandler): def make_input_target(self): data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, 
self.sampling, - self.station_type, self.network, self.store_data_locally, self.data_origin, - self.start, self.end) + self.store_data_locally, self.data_origin, self.start, self.end) self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit, sampling=self.sampling) self.set_inputs_and_targets() @@ -320,8 +318,8 @@ class DataHandlerSingleStation(AbstractDataHandler): self.make_observation(self.target_dim, self.target_var, self.time_dim) self.remove_nan(self.time_dim) - def load_data(self, path, station, statistics_per_var, sampling, station_type=None, network=None, - store_data_locally=False, data_origin: Dict = None, start=None, end=None): + def load_data(self, path, station, statistics_per_var, sampling, store_data_locally=False, + data_origin: Dict = None, start=None, end=None): """ Load data and meta data either from local disk (preferred) or download new data by using a custom download method. @@ -339,35 +337,34 @@ class DataHandlerSingleStation(AbstractDataHandler): if os.path.exists(meta_file): os.remove(meta_file) data, meta = self.download_data(file_name, meta_file, station, statistics_per_var, sampling, - station_type=station_type, network=network, - store_data_locally=store_data_locally, data_origin=data_origin) + store_data_locally=store_data_locally, data_origin=data_origin, + time_dim=self.time_dim, target_dim=self.target_dim, iter_dim=self.iter_dim) logging.debug(f"loaded new data") else: try: logging.debug(f"try to load local data from: {file_name}") data = xr.open_dataarray(file_name) meta = pd.read_csv(meta_file, index_col=0) - self.check_station_meta(meta, station, station_type, network) + self.check_station_meta(meta, station, data_origin, statistics_per_var) logging.debug("loading finished") except FileNotFoundError as e: logging.debug(e) logging.debug(f"load new data") data, meta = self.download_data(file_name, meta_file, station, statistics_per_var, sampling, - station_type=station_type, network=network, - store_data_locally=store_data_locally, data_origin=data_origin) + store_data_locally=store_data_locally, data_origin=data_origin, + time_dim=self.time_dim, target_dim=self.target_dim, + iter_dim=self.iter_dim) logging.debug("loading finished") # create slices and check for negative concentration. data = self._slice_prep(data, start=start, end=end) data = self.check_for_negative_concentrations(data) return data, meta - @staticmethod - def download_data_from_join(file_name: str, meta_file: str, station, statistics_per_var, sampling, - station_type=None, network=None, store_data_locally=True, data_origin: Dict = None, - time_dim=DEFAULT_TIME_DIM, target_dim=DEFAULT_TARGET_DIM, iter_dim=DEFAULT_ITER_DIM) \ - -> [xr.DataArray, pd.DataFrame]: + def download_data(self, file_name: str, meta_file: str, station, statistics_per_var, sampling, + store_data_locally=True, data_origin: Dict = None, time_dim=DEFAULT_TIME_DIM, + target_dim=DEFAULT_TARGET_DIM, iter_dim=DEFAULT_ITER_DIM) -> [xr.DataArray, pd.DataFrame]: """ - Download data from TOAR database using the JOIN interface. + Download data from TOAR database using the JOIN interface or load local era5 data. Data is transformed to a xarray dataset. If class attribute store_data_locally is true, data is additionally stored locally using given names for file and meta file. 
@@ -378,8 +375,40 @@ class DataHandlerSingleStation(AbstractDataHandler): :return: downloaded data and its meta data """ df_all = {} - df, meta = join.download_join(station_name=station, stat_var=statistics_per_var, station_type=station_type, - network_name=network, sampling=sampling, data_origin=data_origin) + df_era5, df_toar = None, None + meta_era5, meta_toar = None, None + if data_origin is not None: + era5_origin = filter_dict_by_value(data_origin, "era5", True) + era5_stats = select_from_dict(statistics_per_var, era5_origin.keys()) + toar_origin = filter_dict_by_value(data_origin, "era5", False) + toar_stats = select_from_dict(statistics_per_var, era5_origin.keys(), filter_cond=False) + assert len(era5_origin) + len(toar_origin) == len(data_origin) + assert len(era5_stats) + len(toar_stats) == len(statistics_per_var) + else: + era5_origin, toar_origin = None, None + era5_stats, toar_stats = statistics_per_var, statistics_per_var + + # load data + if era5_origin is not None and len(era5_stats) > 0: + # load era5 data + df_era5, meta_era5 = data_sources.era5.load_era5(station_name=station, stat_var=era5_stats, + sampling=sampling, data_origin=era5_origin) + if toar_origin is None or len(toar_stats) > 0: + # load combined data from toar-data (v2 & v1) + df_toar, meta_toar = data_sources.toar_data.download_toar(station=station, toar_stats=toar_stats, + sampling=sampling, data_origin=toar_origin) + + if df_era5 is None and df_toar is None: + raise data_sources.toar_data.EmptyQueryResult(f"No data available for era5 and toar-data") + + df = pd.concat([df_era5, df_toar], axis=1, sort=True) + if meta_era5 is not None and meta_toar is not None: + meta = meta_era5.combine_first(meta_toar) + else: + meta = meta_era5 if meta_era5 is not None else meta_toar + meta.loc["data_origin"] = str(data_origin) + meta.loc["statistics_per_var"] = str(statistics_per_var) + df_all[station[0]] = df # convert df_all to xarray xarr = {k: xr.DataArray(v, dims=[time_dim, target_dim]) for k, v in df_all.items()} @@ -390,28 +419,22 @@ class DataHandlerSingleStation(AbstractDataHandler): meta.to_csv(meta_file) return xarr, meta - def download_data(self, *args, **kwargs): - data, meta = self.download_data_from_join(*args, **kwargs, time_dim=self.time_dim, target_dim=self.target_dim, - iter_dim=self.iter_dim) - return data, meta - @staticmethod - def check_station_meta(meta, station, station_type, network): + def check_station_meta(meta, station, data_origin, statistics_per_var): """ Search for the entries in meta data and compare the value with the requested values. Will raise a FileNotFoundError if the values mismatch. """ - if station_type is not None: - check_dict = {"station_type": station_type, "network_name": network} - for (k, v) in check_dict.items(): - if v is None: - continue - if meta.at[k, station[0]] != v: - logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != " - f"{meta.at[k, station[0]]} (local). Raise FileNotFoundError to trigger new " - f"grapping from web.") - raise FileNotFoundError + check_dict = {"data_origin": str(data_origin), "statistics_per_var": str(statistics_per_var)} + for (k, v) in check_dict.items(): + if v is None or k not in meta.index: + continue + if meta.at[k, station[0]] != v: + logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != " + f"{meta.at[k, station[0]]} (local). 
Raise FileNotFoundError to trigger new " + f"grapping from web.") + raise FileNotFoundError def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray: """ diff --git a/mlair/data_handler/data_handler_with_filter.py b/mlair/data_handler/data_handler_with_filter.py index 47ccc5510c8135745c518611504cd02900a1f883..e5760e9afb52f9d55071214fb632601d744f124e 100644 --- a/mlair/data_handler/data_handler_with_filter.py +++ b/mlair/data_handler/data_handler_with_filter.py @@ -68,8 +68,7 @@ class DataHandlerFilterSingleStation(DataHandlerSingleStation): def make_input_target(self): data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling, - self.station_type, self.network, self.store_data_locally, self.data_origin, - self.start, self.end) + self.store_data_locally, self.data_origin, self.start, self.end) self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit) self.set_inputs_and_targets() diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py index 300e0435c4e8441e299675319e2c72604ebb3200..ae46ad918630f2a5f083c62b558609e85d7cc2d8 100644 --- a/mlair/data_handler/default_data_handler.py +++ b/mlair/data_handler/default_data_handler.py @@ -22,7 +22,7 @@ import xarray as xr from mlair.data_handler.abstract_data_handler import AbstractDataHandler from mlair.helpers import remove_items, to_list, TimeTrackingWrapper -from mlair.helpers.join import EmptyQueryResult +from mlair.helpers.data_sources.toar_data import EmptyQueryResult number = Union[float, int] @@ -168,7 +168,7 @@ class DefaultDataHandler(AbstractDataHandler): dim = self.time_dim intersect = reduce(np.intersect1d, map(lambda x: x.coords[dim].values, X_original)) if len(intersect) < max(self.min_length, 1): - X, Y = None, None + raise ValueError(f"There is no intersection of X.") else: X = list(map(lambda x: x.sel({dim: intersect}), X_original)) Y = Y_original.sel({dim: intersect}) @@ -205,10 +205,6 @@ class DefaultDataHandler(AbstractDataHandler): if True only extract values larger than extreme_values :param timedelta: used as arguments for np.timedelta in order to mark extreme values on datetime """ - # check if X or Y is None - if (self._X is None) or (self._Y is None): - logging.debug(f"{str(self.id_class)} has no data for X or Y, skip multiply extremes") - return if extreme_values is None: logging.debug(f"No extreme values given, skip multiply extremes") self._X_extreme, self._Y_extreme = self._X, self._Y diff --git a/mlair/helpers/__init__.py b/mlair/helpers/__init__.py index 3a5b8699a11ae39c0d3510a534db1dd144419d09..cf50fa05885d576bd64de67b83df3c8ed6d272e2 100644 --- a/mlair/helpers/__init__.py +++ b/mlair/helpers/__init__.py @@ -1,6 +1,7 @@ """Collection of different supporting functions and classes.""" -from .testing import PyTestRegex, PyTestAllEqual +from .testing import PyTestRegex, PyTestAllEqual, check_nested_equality from .time_tracking import TimeTracking, TimeTrackingWrapper from .logger import Logger -from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value, select_from_dict, make_keras_pickable, sort_like +from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value, select_from_dict, \ + make_keras_pickable, sort_like, filter_dict_by_value diff --git a/mlair/helpers/data_sources/__init__.py b/mlair/helpers/data_sources/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..6b753bc3afb961be65ff0f934ef4f0de08804a0b --- /dev/null +++ b/mlair/helpers/data_sources/__init__.py @@ -0,0 +1,10 @@ +""" +Data Sources. + +The module data_sources collects different data sources, namely ERA5, TOAR-Data v1 (JOIN), and TOAR-Data v2 +""" + +__author__ = "Lukas Leufen" +__date__ = "2022-07-05" + +from . import era5, join, toar_data, toar_data_v2 diff --git a/mlair/helpers/data_sources/era5.py b/mlair/helpers/data_sources/era5.py new file mode 100644 index 0000000000000000000000000000000000000000..8eb7a03b2629db1d006e03fcc9d30b2af714c270 --- /dev/null +++ b/mlair/helpers/data_sources/era5.py @@ -0,0 +1,88 @@ +"""Methods to load era5 data.""" +__author__ = "Lukas Leufen" +__date__ = "2022-06-09" + +import logging +import os + +import pandas as pd +import xarray as xr + +from mlair import helpers +from mlair.configuration.era5_settings import era5_settings +from mlair.configuration.toar_data_v2_settings import toar_data_v2_settings +from mlair.helpers.data_sources.toar_data_v2 import load_station_information, combine_meta_data, correct_timezone +from mlair.helpers.data_sources.toar_data import EmptyQueryResult +from mlair.helpers.meteo import relative_humidity_from_dewpoint + + +def load_era5(station_name, stat_var, sampling, data_origin): + + # make sure station_name parameter is a list + station_name = helpers.to_list(station_name) + + # get data path + data_path, file_names = era5_settings(sampling) + + # correct stat_var values if data is not aggregated (hourly) + if sampling == "hourly": + stat_var = {key: "values" for key in stat_var.keys()} + else: + raise ValueError(f"Given sampling {sampling} is not supported, only hourly sampling can be used.") + + # load station meta using toar-data v2 API + meta_url_base, headers = toar_data_v2_settings("meta") + station_meta = load_station_information(station_name, meta_url_base, headers) + + # sel data for station using sel method nearest + logging.info(f"load data for {station_meta['codes'][0]} from ERA5") + try: + with xr.open_mfdataset(os.path.join(data_path, file_names)) as data: + lon, lat = station_meta["coordinates"]["lng"], station_meta["coordinates"]["lat"] + station_dask = data.sel(lon=lon, lat=lat, method="nearest", drop=True) + station_data = station_dask.to_array().T.compute() + except OSError as e: + logging.info(f"Cannot load era5 data from path {data_path} and filenames {file_names} due to: {e}") + return None, None + + # transform data and meta to pandas + station_data = station_data.to_pandas() + if "relhum" in stat_var: + station_data["RHw"] = relative_humidity_from_dewpoint(station_data["D2M"], station_data["T2M"]) + station_data.columns = _rename_era5_variables(station_data.columns) + + # check if all requested variables are available + if set(stat_var).issubset(station_data.columns) is False: + missing_variables = set(stat_var).difference(stat_var) + origin = helpers.select_from_dict(data_origin, missing_variables) + options = f"station={station_name}, origin={origin}" + raise EmptyQueryResult(f"No data found for variables {missing_variables} and options {options} in JOIN.") + else: + station_data = station_data[stat_var] + + # convert to local timezone + station_data = correct_timezone(station_data, station_meta, sampling) + + variable_meta = _emulate_meta_data(station_data) + meta = combine_meta_data(station_meta, variable_meta) + meta = pd.DataFrame.from_dict(meta, orient='index') + meta.columns = station_name + return station_data, meta + + +def 
_emulate_meta_data(station_data): + general_meta = {"sampling_frequency": "hourly", "data_origin": "model", "data_origin_type": "model"} + roles_meta = {"roles": [{"contact": {"organisation": {"name": "ERA5", "longname": "ECMWF"}}}]} + variable_meta = {var: {"variable": {"name": var}, **roles_meta, ** general_meta} for var in station_data.columns} + return variable_meta + + +def _rename_era5_variables(era5_names): + mapper = {"SP": "press", "U10M": "u", "V10M": "v", "T2M": "temp", "D2M": "dew", "BLH": "pblheight", + "TCC": "cloudcover", "RHw": "relhum"} + era5_names = list(era5_names) + try: + join_names = list(map(lambda x: mapper[x], era5_names)) + return join_names + except KeyError as e: + raise KeyError(f"Cannot map names from era5 to join naming convention: {e}") \ No newline at end of file diff --git a/mlair/helpers/data_sources/join.py b/mlair/helpers/data_sources/join.py new file mode 100644 index 0000000000000000000000000000000000000000..a978b2712a83b21f3c1256b2bf0826da63bdda3a --- /dev/null +++ b/mlair/helpers/data_sources/join.py @@ -0,0 +1,366 @@ +"""Functions to access join database.""" +__author__ = 'Felix Kleinert, Lukas Leufen' +__date__ = '2019-10-16' + +import datetime as dt +import logging +from typing import Iterator, Union, List, Dict, Tuple + +import pandas as pd + +from mlair import helpers +from mlair.configuration.join_settings import join_settings +from mlair.helpers.data_sources import toar_data, toar_data_v2 + + +# join_url_base = 'https://join.fz-juelich.de/services/rest/surfacedata/' +str_or_none = Union[str, None] + + +def download_join(station_name: Union[str, List[str]], stat_var: dict, station_type: str = None, + sampling: str = "daily", data_origin: Dict = None) -> [pd.DataFrame, pd.DataFrame]: + """ + Read data from JOIN/TOAR. + + :param station_name: Station name e.g. DEBY122 + :param stat_var: key as variable like 'O3', values as statistics on keys like 'mean' + :param station_type: set the station type like "traffic" or "background", can be none + :param sampling: sampling rate of the downloaded data, either set to daily or hourly (default daily) + :param data_origin: additional dictionary to specify data origin as key (for variable) value (origin) pair. Valid + origins are "REA" for reanalysis data and "" (empty string) for observational data. 
+ + :returns: data frame with all variables and statistics and meta data frame with all meta information + """ + # make sure station_name parameter is a list + station_name = helpers.to_list(station_name) + + # split network and origin information + data_origin, network_name = split_network_and_origin(data_origin) + + # get data connection settings + join_url_base, headers = join_settings(sampling) + + # load series information + vars_dict, data_origin = load_series_information(station_name, station_type, network_name, join_url_base, headers, + data_origin, stat_var) + + # check if all requested variables are available + if set(stat_var).issubset(vars_dict) is False: + missing_variables = set(stat_var).difference(vars_dict) + origin = helpers.select_from_dict(data_origin, missing_variables) + options = f"station={station_name}, type={station_type}, network={network_name}, origin={origin}" + raise toar_data.EmptyQueryResult(f"No data found for variables {missing_variables} and options {options} in " + f"JOIN.") + + # correct stat_var values if data is not aggregated (hourly) + if sampling == "hourly": + stat_var = {key: "values" for key in stat_var.keys()} + + # download all variables with given statistic + data = None + df = None + meta = {} + logging.info(f"load data for {station_name[0]} from JOIN") + for var in _lower_list(sorted(vars_dict.keys())): + if var in stat_var.keys(): + + logging.debug('load: {}'.format(var)) + + # create data link + opts = {'base': join_url_base, 'service': 'stats', 'id': vars_dict[var], 'statistics': stat_var[var], + 'sampling': sampling, 'capture': 0, 'format': 'json'} + + # load data + data = toar_data.get_data(opts, headers) + + # adjust data format if given as list of list + # no branch cover because this just happens when downloading hourly data using a secret token, not available + # for CI testing. 
+ if isinstance(data, list): # pragma: no branch + data = correct_data_format(data) + + # correct namespace of statistics + stat = toar_data.correct_stat_name(stat_var[var]) + + # store data in pandas dataframe + df = _save_to_pandas(df, data, stat, var) + meta[var] = _correct_meta(data["metadata"]) + + logging.debug('finished: {}'.format(var)) + + if data: + # load station meta using toar-data v2 API and convert to local timezone + meta_url_base, headers = toar_data_v2.toar_data_v2_settings("meta") + station_meta = toar_data_v2.load_station_information(station_name, meta_url_base, headers) + df = toar_data_v2.correct_timezone(df, station_meta, sampling) + + # create meta data + meta = toar_data_v2.combine_meta_data(station_meta, meta) + meta = pd.DataFrame.from_dict(meta, orient='index') + meta.columns = station_name + return df, meta + else: + raise toar_data.EmptyQueryResult("No data found in JOIN.") + + +def _correct_meta(meta): + meta_out = {} + for k, v in meta.items(): + if k.startswith("station"): + _k = k.split("_", 1)[1] + _d = meta_out.get("station", {}) + _d[_k] = v + meta_out["station"] = _d + elif k.startswith("parameter"): + _k = k.split("_", 1)[1] + _d = meta_out.get("variable", {}) + _d[_k] = v + meta_out["variable"] = _d + elif k == "network_name": + if v == "AIRBASE": + _d = {"name": "EEA", "longname": "European Environment Agency", "kind": "government"} + elif v == "UBA": + _d = {"name": "UBA", "longname": "Umweltbundesamt", "kind": "government", "country": "Germany"} + else: + _d = {"name": v} + meta_out["roles"] = [{"contact": {"organisation": _d}}] + elif k in ["google_resolution", "numid"]: + continue + else: + meta_out[k] = v + return meta_out + + +def split_network_and_origin(origin_network_dict: dict) -> Tuple[Union[None, dict], Union[None, dict]]: + """ + Split given dict into network and data origin. + + Method is required to transform Toar-Data v2 structure (using only origin) into Toar-Data v1 (JOIN) structure (which + uses origin and network parameter). Furthermore, EEA network (v2) is renamed to AIRBASE (v1). + """ + if origin_network_dict is None or len(origin_network_dict) == 0: + data_origin, network = None, None + else: + data_origin = {} + network = {} + for k, v in origin_network_dict.items(): + network[k] = [] + for _network in helpers.to_list(v): + if _network.lower() == "EEA".lower(): + network[k].append("AIRBASE") + elif _network.lower() != "REA".lower(): + network[k].append(_network) + if "REA" in v: + data_origin[k] = "REA" + else: + data_origin[k] = "" + network[k] = filter_network(network[k]) + return data_origin, network + + +def filter_network(network: list) -> Union[list, None]: + """ + Filter given list of networks. + + :param network: list of various network names (can contain duplicates) + :return: sorted list with unique entries + """ + sorted_network = [] + for v in list(filter(lambda x: x != "", network)): + if v not in sorted_network: + sorted_network.append(v) + if len(sorted_network) == 0: + sorted_network = None + return sorted_network + + +def correct_data_format(data): + """ + Transform to the standard data format. + + For some cases (e.g. hourly data), the data is returned as list instead of a dictionary with keys datetime, values + and metadata. This functions addresses this issue and transforms the data into the dictionary version. 
+ + :param data: data in hourly format + + :return: the same data but formatted to fit with aggregated format + """ + formatted = {"datetime": [], + "values": [], + "metadata": data[-1]} + for d in data[:-1]: + for k, v in zip(["datetime", "values"], d): + formatted[k].append(v) + return formatted + + +def load_series_information(station_name: List[str], station_type: str_or_none, network_name: str_or_none, + join_url_base: str, headers: Dict, data_origin: Dict = None, stat_var: Dict = None) -> [Dict, Dict]: + """ + List all series ids that are available for given station id and network name. + + :param station_name: Station name e.g. DEBW107 + :param station_type: station type like "traffic" or "background" + :param network_name: measurement network of the station like "UBA" or "AIRBASE" + :param join_url_base: base url name to download data from + :param headers: additional headers information like authorization, can be empty + :param data_origin: additional information to select a distinct series e.g. from reanalysis (REA) or from observation + ("", empty string). This dictionary should contain a key for each variable and the information as key + :return: all available series for requested station stored in an dictionary with parameter name (variable) as key + and the series id as value. + """ + network_name_opts = _create_network_name_opts(network_name) + parameter_name_opts = _create_parameter_name_opts(stat_var) + opts = {"base": join_url_base, "service": "search", "station_id": station_name[0], "station_type": station_type, + "network_name": network_name_opts, "as_dict": "true", "parameter_name": parameter_name_opts, + "columns": "id,network_name,station_id,parameter_name,parameter_label,parameter_attribute"} + station_vars = toar_data.get_data(opts, headers) + logging.debug(f"{station_name}: {station_vars}") + return _select_distinct_series(station_vars, data_origin, network_name) + + +def _create_parameter_name_opts(stat_var): + if stat_var is None: + parameter_name_opts = None + else: + parameter_name_opts = ",".join(stat_var.keys()) + return parameter_name_opts + + +def _create_network_name_opts(network_name): + if network_name is None: + network_name_opts = network_name + elif isinstance(network_name, list): + network_name_opts = ",".join(helpers.to_list(network_name)) + elif isinstance(network_name, dict): + _network = [] + for v in network_name.values(): + _network.extend(helpers.to_list(v)) + network_name_opts = ",".join(filter(lambda x: x is not None, set(_network))) + network_name_opts = None if len(network_name_opts) == 0 else network_name_opts + else: + raise TypeError(f"network_name parameter must be of type None, list, or dict. Given is {type(network_name)}.") + return network_name_opts + + +def _select_distinct_series(vars: List[Dict], data_origin: Dict = None, network_name: Union[str, List[str]] = None) \ + -> [Dict, Dict]: + """ + Select distinct series ids for all variables. Also check if a parameter is from REA or not. + """ + data_origin = {} if data_origin is None else data_origin + selected, data_origin = _select_distinct_data_origin(vars, data_origin) + + network_name = [] if network_name is None else network_name + selected = _select_distinct_network(selected, network_name) + + # extract id + selected = {k: v["id"] for k, v in selected.items()} + return selected, data_origin + + +def _select_distinct_network(vars: dict, network_name: Union[list, dict]) -> dict: + """ + Select distinct series regarding network name. 
The order the network names are provided in parameter `network_name` + indicates priority (from high to low). If no network name is provided, first entry is used and a logging info is + issued. In case network names are given but no match can be found, this method raises a ValueError. + + :param vars: dictionary with all series candidates already grouped by variable name as key. Value should be a list + of possible candidates to select from. Each candidate must be a dictionary with at least keys `id` and + `network_name`. + :param network_name: list of networks to use with increasing priority (1st element has priority). Can be empty list + indicating to use always first candidate for each variable. + :return: dictionary with single series reference for each variable + """ + if isinstance(network_name, (list, str)): + network_name = {var: helpers.to_list(network_name) for var in vars.keys()} + selected = {} + for var, series in vars.items(): + res = [] + network_list = helpers.to_list(network_name.get(var, []) or []) + for network in network_list: + res.extend(list(filter(lambda x: x["network_name"].upper() == network.upper(), series))) + if len(res) > 0: # use first match which has the highest priority + selected[var] = res[0] + else: + if len(network_list) == 0: # just print message which network is used if none is provided + selected[var] = series[0] + logging.info(f"Could not find a valid match for variable {var} and networks {network_name.get(var, [])}" + f"! Therefore, use first answer from JOIN: {series[0]}") + else: # raise error if network name is provided but no match could be found + raise ValueError(f"Cannot find a valid match for requested networks {network_name.get(var, [])} and " + f"variable {var} as only following networks are available in JOIN: " + f"{list(map(lambda x: x['network_name'], series))}") + return selected + + +def _select_distinct_data_origin(vars: List[Dict], data_origin: Dict) -> (Dict[str, List], Dict): + """ + Select distinct series regarding their data origin. Series are grouped as list according to their variable's name. + As series can be reported with different network attribution, results might contain multiple entries for a variable. + This method assumes the default data origin for chemical variables as `` (empty source) and for meteorological + variables as `REA`. + :param vars: list of all entries to check data origin for + :param data_origin: data origin to match series with, if empty default values are used + :return: dictionary with unique variable names as keys and list of respective series as values + """ + data_origin_default = {"cloudcover": "REA", "humidity": "REA", "pblheight": "REA", "press": "REA", "relhum": "REA", + "temp": "REA", "totprecip": "REA", "u": "REA", "v": "REA", + "no": "", "no2": "", "o3": "", "pm10": "", "so2": ""} + selected = {} + for var in vars: + name = var["parameter_name"].lower() + var_attr = var["parameter_attribute"].lower() + if name not in data_origin.keys(): + data_origin.update({name: data_origin_default.get(name, "")}) + attr = data_origin.get(name, "").lower() + if var_attr == attr: + selected[name] = selected.get(name, []) + helpers.to_list(var) + return selected, data_origin + + +def _save_to_pandas(df: Union[pd.DataFrame, None], data: dict, stat: str, var: str) -> pd.DataFrame: + """ + Save given data in data frame. + + If given data frame is not empty, the data is appened as new column. 
+ + :param df: data frame to append the new data, can be none + :param data: new data to append or format as data frame containing the keys 'datetime' and '<stat>' + :param stat: extracted statistic to get values from data (e.g. 'mean', 'dma8eu') + :param var: variable the data is from (e.g. 'o3') + + :return: new created or concatenated data frame + """ + if len(data["datetime"][0]) == 19: + str_format = "%Y-%m-%d %H:%M:%S" + else: + str_format = "%Y-%m-%d %H:%M" + index = map(lambda s: dt.datetime.strptime(s, str_format), data['datetime']) + if df is None: + df = pd.DataFrame(data[stat], index=index, columns=[var]) + else: + df = pd.concat([df, pd.DataFrame(data[stat], index=index, columns=[var])], axis=1) + return df + + +def _lower_list(args: List[str]) -> Iterator[str]: + """ + Lower all elements of given list. + + :param args: list with string entries to lower + + :return: iterator that lowers all list entries + """ + for string in args: + yield string.lower() + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + var_all_dic = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values', + 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values', + 'pblheight': 'maximum'} + station = 'DEBW107' + # download_join(station, var_all_dic, sampling="daily") + download_join(station, var_all_dic, sampling="hourly") diff --git a/mlair/helpers/data_sources/toar_data.py b/mlair/helpers/data_sources/toar_data.py new file mode 100644 index 0000000000000000000000000000000000000000..27522855cbe0f3c6f0b78d1598709a694fc7b862 --- /dev/null +++ b/mlair/helpers/data_sources/toar_data.py @@ -0,0 +1,128 @@ +__author__ = "Lukas Leufen" +__date__ = "2022-07-05" + + +from typing import Union, List, Dict + +from . import join, toar_data_v2 + +import requests +from requests.adapters import HTTPAdapter +from requests.packages.urllib3.util.retry import Retry +import pandas as pd + + +class EmptyQueryResult(Exception): + """Exception that get raised if a query to JOIN returns empty results.""" + + pass + + +def create_url(base: str, service: str, param_id: Union[str, int, None] = None, + **kwargs: Union[str, int, float, None]) -> str: + """ + Create a request url with given base url, service type and arbitrarily many additional keyword arguments. + + :param base: basic url of the rest service + :param service: service type, e.g. series, stats + :param param_id: id for a distinct service, is added between ending / of service and ? of kwargs + :param kwargs: keyword pairs for optional request specifications, e.g. 'statistics=maximum' + + :return: combined url as string + """ + url = f"{base}" + if not url.endswith("/"): + url += "/" + if service is not None: + url = f"{url}{service}" + if not url.endswith("/"): + url += "/" + if param_id is not None: + url = f"{url}{param_id}" + if len(kwargs) > 0: + url = f"{url}?{'&'.join(f'{k}={v}' for k, v in kwargs.items() if v is not None)}" + return url + + +def get_data(opts: Dict, headers: Dict, as_json: bool = True) -> Union[Dict, List, str]: + """ + Download join data using requests framework. + + Data is returned as json like structure. Depending on the response structure, this can lead to a list or dictionary. 
+ + :param opts: options to create the request url + :param headers: additional headers information like authorization, can be empty + :param as_json: extract response as json if true (default True) + + :return: requested data (either as list or dictionary) + """ + url = create_url(**opts) + try: + response = retries_session().get(url, headers=headers, timeout=(5, None)) # timeout=(open, read) + if response.status_code == 200: + return response.json() if as_json is True else response.text + else: + raise EmptyQueryResult(f"There was an error (STATUS {response.status_code}) for request {url}") + except requests.exceptions.RetryError as e: + raise EmptyQueryResult(f"There was an RetryError for request {url}: {e}") + + +def retries_session(max_retries=3): + retry_strategy = Retry(total=max_retries, + backoff_factor=0.1, + status_forcelist=[429, 500, 502, 503, 504], + method_whitelist=["HEAD", "GET", "OPTIONS"]) + adapter = HTTPAdapter(max_retries=retry_strategy) + http = requests.Session() + http.mount("https://", adapter) + http.mount("http://", adapter) + return http + + +def download_toar(station, toar_stats, sampling, data_origin): + + try: + # load data from toar-data (v2) + df_toar, meta_toar = toar_data_v2.download_toar(station, toar_stats, sampling=sampling, data_origin=data_origin) + except (AttributeError, EmptyQueryResult, KeyError, requests.ConnectionError, ValueError, IndexError): + df_toar, meta_toar = None, None + + try: + # load join data (toar-data v1) + df_join, meta_join = join.download_join(station_name=station, stat_var=toar_stats, sampling=sampling, + data_origin=data_origin) + except (AttributeError, EmptyQueryResult, KeyError, requests.ConnectionError, ValueError, IndexError): + df_join, meta_join = None, None + + # merge both data sources with priority on toar-data v2 + if df_toar is not None and df_join is not None: + df_merged = merge_toar_join(df_toar, df_join, sampling) + meta_merged = meta_toar + else: + df_merged = df_toar if df_toar is not None else df_join + meta_merged = meta_toar if df_toar is not None else meta_join + return df_merged, meta_merged + + +def merge_toar_join(df_toar, df_join, sampling): + start_date = min([df_toar.index.min(), df_join.index.min()]) + end_date = max([df_toar.index.max(), df_join.index.max()]) + freq = {"hourly": "1H", "daily": "1d"}.get(sampling) + full_time = pd.date_range(start_date, end_date, freq=freq) + full_data = df_toar.reindex(full_time) + full_data.update(df_join, overwrite=False) + return full_data + + +def correct_stat_name(stat: str) -> str: + """ + Map given statistic name to new namespace defined by mapping dict. + + Return given name stat if not element of mapping namespace. 
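As a side note on the merging strategy defined above: merge_toar_join gives precedence to TOAR-DATA v2 by reindexing the v2 frame onto the full time range and filling only the remaining gaps from JOIN via DataFrame.update(..., overwrite=False). A small sketch with made-up values (variable name, dates and numbers are illustrative only):

    import pandas as pd
    df_toar = pd.DataFrame({"o3": [40.0, None, 42.0]},
                           index=pd.date_range("2020-01-01", periods=3, freq="1d"))  # v2, gap on day 2
    df_join = pd.DataFrame({"o3": [41.0, 41.5, 43.0]},
                           index=pd.date_range("2020-01-02", periods=3, freq="1d"))  # JOIN, covers days 2-4
    full = df_toar.reindex(pd.date_range("2020-01-01", "2020-01-04", freq="1d"))
    full.update(df_join, overwrite=False)  # only NaN cells are filled, existing v2 values are kept
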
+ + :param stat: namespace from JOIN server + + :return: stat mapped to local namespace + """ + mapping = {'average_values': 'mean', 'maximum': 'max', 'minimum': 'min'} + return mapping.get(stat, stat) diff --git a/mlair/helpers/data_sources/toar_data_v2.py b/mlair/helpers/data_sources/toar_data_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..0fa53a7eb23f11675eeef2c12a7d5dceec3c38ac --- /dev/null +++ b/mlair/helpers/data_sources/toar_data_v2.py @@ -0,0 +1,238 @@ +"""Functions to access https://toar-data.fz-juelich.de/api/v2/""" +__author__ = 'Lukas Leufen' +__date__ = '2022-06-30' + + +import logging +from typing import Union, List, Dict +from io import StringIO + +import pandas as pd +import pytz +from timezonefinder import TimezoneFinder + +from mlair.configuration.toar_data_v2_settings import toar_data_v2_settings +from mlair.helpers import to_list +from mlair.helpers.data_sources.toar_data import EmptyQueryResult, get_data, correct_stat_name + + +str_or_none = Union[str, None] + + +def download_toar(station_name: Union[str, List[str]], stat_var: dict, + sampling: str = "daily", data_origin: Dict = None): + """ + Download data from https://toar-data.fz-juelich.de/api/v2/ + + Uses station name to indicate measurement site and keys of stat_var to indicate variable name. If data origin is + given, this method tries to load time series for this origin. In case no origin is provided, this method loads data + with the highest priority according to toar-data's order parameter. + + :param station_name: + :param stat_var: + :param sampling: + :param data_origin: + :return: + """ + + # make sure station_name parameter is a list + station_name = to_list(station_name) + + # also ensure that given data_origin dict is no reference + if data_origin is None or len(data_origin) == 0: + data_origin = None + else: + data_origin = {k: v for (k, v) in data_origin.items()} + + # get data connection settings for meta + meta_url_base, headers = toar_data_v2_settings("meta") + + # load variables + var_meta = load_variables_information(stat_var, meta_url_base, headers) + + # load station meta + station_meta = load_station_information(station_name, meta_url_base, headers) + + # load series information + timeseries_meta = load_timeseries_information(station_meta, var_meta, meta_url_base, headers, data_origin) + + # # correct stat_var values if data is not aggregated (hourly) + # if sampling == "hourly": + # stat_var = {key: "values" for key in stat_var.keys()} + + logging.info(f"load data for {station_meta['codes'][0]} from TOAR-DATA") + # get data connection settings for data + data_url_base, headers = toar_data_v2_settings(sampling) + + data_dict = {} + for var, meta in timeseries_meta.items(): + logging.debug(f"load {var}") + meta_and_opts = prepare_meta(meta, sampling, stat_var, var) + data_var = [] + for var_meta, opts in meta_and_opts: + data_var.extend(load_timeseries_data(var_meta, data_url_base, opts, headers, sampling)) + data_dict[var] = merge_data(*data_var, sampling=sampling) + data = pd.DataFrame.from_dict(data_dict) + data = correct_timezone(data, station_meta, sampling) + + meta = combine_meta_data(station_meta, {k: v[0] for k, v in timeseries_meta.items()}) + meta = pd.DataFrame.from_dict(meta, orient='index') + meta.columns = station_name + return data, meta + + +def merge_data(*args, sampling="hourly"): + start_date = min(map(lambda x: x.index.min(), args)) + end_date = max(map(lambda x: x.index.max(), args)) + freq = {"hourly": "1H", "daily": "1d"}.get(sampling) + 
full_time = pd.date_range(start_date, end_date, freq=freq) + full_data = args[0].reindex(full_time) + if not isinstance(full_data, pd.DataFrame): + full_data = full_data.to_frame() + for d in args[1:]: + full_data.update(d, overwrite=False) + return full_data.squeeze() + + +def correct_timezone(data, meta, sampling): + """ + Extract timezone information and convert data index to this timezone. + + Uses UTC if no information is provided. Note that this method only modifies data with sampling='hourly'. In all + other cases, it returns just the given data without any change. This method expects the date index of data to be in UTC. + Timezone information is not added to the index to get rid of daylight saving time and ambiguous timestamps. + """ + if sampling == "hourly": + tz_info = meta.get("timezone", "UTC") + try: + tz = pytz.timezone(tz_info) + except pytz.exceptions.UnknownTimeZoneError as e: + lon, lat = meta["coordinates"]["lng"], meta["coordinates"]["lat"] + tz = pytz.timezone(TimezoneFinder().timezone_at(lng=lon, lat=lat)) + index = data.index + index = index.tz_localize(None) + utc_offset = tz.utcoffset(index[0]) - tz.dst(index[0]) + data.index = index + utc_offset + return data + + +def prepare_meta(meta, sampling, stat_var, var): + out = [] + for m in meta: + opts = {} + if sampling == "daily": + opts["timeseries_id"] = m.pop("id") + m["id"] = None + opts["names"] = stat_var[var] + opts["sampling"] = sampling + out.append(([m], opts)) + return out + + +def combine_meta_data(station_meta, timeseries_meta): + meta = {} + for k, v in station_meta.items(): + if k == "codes": + meta[k] = v[0] + elif k in ["coordinates", "additional_metadata", "globalmeta"]: + for _key, _val in v.items(): + if _key == "lng": + meta["lon"] = _val + else: + meta[_key] = _val + elif k in ["changelog", "roles", "annotations", "aux_images", "aux_docs", "aux_urls"]: + continue + else: + meta[k] = v + for var, var_meta in timeseries_meta.items(): + for k, v in var_meta.items(): + if k in ["additional_metadata", "station", "programme", "annotations", "changelog"]: + continue + elif k == "roles": + for _key, _val in v[0]["contact"]["organisation"].items(): + new_k = f"{var}_organisation_{_key}" + meta[new_k] = _val + elif k == "variable": + for _key, _val in v.items(): + new_k = f"{var}_{_key}" + meta[new_k] = _val + else: + new_k = f"{var}_{k}" + meta[new_k] = v + return meta + + +def load_timeseries_data(timeseries_meta, url_base, opts, headers, sampling): + coll = [] + for meta in timeseries_meta: + series_id = meta["id"] + # opts = {"base": url_base, "service": f"data/timeseries/{series_id}"} + opts = {"base": url_base, "service": f"data/timeseries", "param_id": series_id, "format": "csv", **opts} + if sampling != "hourly": + opts["service"] = None + res = get_data(opts, headers, as_json=False) + data = pd.read_csv(StringIO(res), comment="#", index_col="datetime", parse_dates=True, + infer_datetime_format=True) + if len(data.index) > 0: + data = data[correct_stat_name(opts.get("names", "value"))].rename(meta["variable"]["name"]) + coll.append(data) + return coll + + +def load_station_information(station_name: List[str], url_base: str, headers: Dict): + # opts = {"base": url_base, "service": f"stationmeta/{station_name[0]}"} + opts = {"base": url_base, "service": f"stationmeta", "param_id": station_name[0]} + return get_data(opts, headers) + + +def load_timeseries_information(station_meta, var_meta, url_base: str, headers: Dict, + data_origin: Dict = None) -> [Dict, Dict]: + timeseries_id_dict = {} + missing =
[] + for var, meta in var_meta.items(): + timeseries_id_dict[var] = [] + opts = {"base": url_base, "service": "search", "station_id": station_meta["id"], "variable_id": meta["id"]} + res = get_data(opts, headers) + if len(res) == 0: + missing.append((var, meta)) + # raise EmptyQueryResult(f"Cannot find any timeseries for station id {station_meta['id']} " + # f"({station_meta['codes'][0]}) and variable id {meta['id']} ({var}).") + if data_origin is not None: + var_origin = data_origin[var] + timeseries_id_dict[var] = select_timeseries_by_origin(res, var_origin) + # if len(timeseries_id_dict[var]) == 0: + # raise EmptyQueryResult(f"Cannot find any timeseries for station id {station_meta['id']} " + # f"({station_meta['codes'][0]}), variable id {meta['id']} ({var}) " + # f"and timeseries origin {var_origin}.") + if data_origin is None or len(timeseries_id_dict[var]) == 0: + timeseries_id_dict[var] = select_timeseries_by_order(res) + if len(missing) > 0: + missing = ",".join([f"{m[0]} ({m[1]['id']})" for m in missing]) + raise EmptyQueryResult(f"Cannot find any timeseries for station id {station_meta['id']} " + f"({station_meta['codes'][0]}) and variables {missing}.") + return timeseries_id_dict + + +def select_timeseries_by_order(toar_meta): + order_dict = {meta["order"]: meta for meta in toar_meta} + res = [order_dict[order] for order in sorted(order_dict.keys())] + return res + + +def select_timeseries_by_origin(toar_meta, var_origin): + res = [] + for origin in to_list(var_origin): + for meta in toar_meta: + for roles in meta["roles"]: + if roles["contact"]["organisation"]["name"].lower() == origin.lower(): + res.append(meta) + break + return res + + +def load_variables_information(var_dict, url_base, headers): + var_meta_dict = {} + for var in var_dict.keys(): + opts = {"base": url_base, "service": f"variables", "param_id": var} + var_meta_dict[var] = get_data(opts, headers) + return var_meta_dict diff --git a/mlair/helpers/filter.py b/mlair/helpers/filter.py index 247c4fc9c7c6d57d721c1d0895cc8c719b1bd4a5..5fc3df951ed5dec9e94ed7d34d8dc02bafddf262 100644 --- a/mlair/helpers/filter.py +++ b/mlair/helpers/filter.py @@ -214,6 +214,7 @@ class ClimateFIRFilter(FIRFilter): h = [] if self.sel_opts is not None: self.sel_opts = self.sel_opts if isinstance(self.sel_opts, dict) else {self.time_dim: self.sel_opts} + self._check_sel_opts() sampling = {1: "1d", 24: "1H"}.get(int(self.fs)) logging.debug(f"{self.display_name}: create diurnal_anomalies") if self.apriori_diurnal is True and sampling == "1H": @@ -303,6 +304,10 @@ class ClimateFIRFilter(FIRFilter): except Exception as e: logging.info(f"Could not plot climate fir filter due to following reason:\n{e}") + def _check_sel_opts(self): + if len(self.data.sel(**self.sel_opts).coords[self.time_dim]) == 0: + raise ValueError(f"Abort {self.__class__.__name__} as no data is available after applying sel_opts to data") + @staticmethod def _next_order(order: list, minimum_length: Union[int, None], pos: int, window: Union[str, tuple]) -> int: next_order = 0 diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py index b583cf7dc473db96181f88b0ab26e60ee225240d..ca69f28557c6386f021b137e5861660f40b867d9 100644 --- a/mlair/helpers/helpers.py +++ b/mlair/helpers/helpers.py @@ -57,7 +57,7 @@ def to_list(obj: Any) -> List: :return: list containing obj, or obj itself (if obj was already a list) """ - if isinstance(obj, (set, tuple)): + if isinstance(obj, (set, tuple, type({}.keys()))): obj = list(obj) elif not isinstance(obj, list): obj = [obj] @@ -176,16 
+176,17 @@ def remove_items(obj: Union[List, Dict, Tuple], items: Any): raise TypeError(f"{inspect.stack()[0][3]} does not support type {type(obj)}.") -def select_from_dict(dict_obj: dict, sel_list: Any, remove_none=False): +def select_from_dict(dict_obj: dict, sel_list: Any, remove_none: bool = False, filter_cond: bool = True) -> dict: """ Extract all key values pairs whose key is contained in the sel_list. - Does not perform a check if all elements of sel_list are keys of dict_obj. Therefore the number of pairs in the - returned dict is always smaller or equal to the number of elements in the sel_list. + Does not perform a check if all elements of sel_list are keys of dict_obj. Therefore, the number of pairs in the + returned dict is always smaller or equal to the number of elements in the sel_list. If `filter_cond` is given, this + method either return the parts of the input dictionary that are included or not in `sel_list`. """ sel_list = to_list(sel_list) assert isinstance(dict_obj, dict) - sel_dict = {k: v for k, v in dict_obj.items() if k in sel_list} + sel_dict = {k: v for k, v in dict_obj.items() if (k in sel_list) is filter_cond} sel_dict = sel_dict if not remove_none else {k: v for k, v in sel_dict.items() if v is not None} return sel_dict @@ -252,6 +253,19 @@ def convert2xrda(arr: Union[xr.DataArray, xr.Dataset, np.ndarray, int, float], return xr.DataArray(arr, **kwargs) +def filter_dict_by_value(dictionary: dict, filter_val: Any, filter_cond: bool) -> dict: + """ + Filter dictionary by its values. + + :param dictionary: dict to filter + :param filter_val: search only for key value pair with a value equal to filter_val + :param filter_cond: indicate to use either all dict entries that fulfil the filter_val criteria (if `True`) or that + do not match the criteria (if `False`) + :returns: a filtered dict with either matching or non-matching elements depending on the `filter_cond` + """ + return dict(filter(lambda x: (x[1] == filter_val) is filter_cond, dictionary.items())) + + # def convert_size(size_bytes): # if size_bytes == 0: # return "0B" diff --git a/mlair/helpers/join.py b/mlair/helpers/join.py deleted file mode 100644 index 67591b29a4e4bcc8b3083869825aed09ebebaf58..0000000000000000000000000000000000000000 --- a/mlair/helpers/join.py +++ /dev/null @@ -1,275 +0,0 @@ -"""Functions to access join database.""" -__author__ = 'Felix Kleinert, Lukas Leufen' -__date__ = '2019-10-16' - -import datetime as dt -import logging -from typing import Iterator, Union, List, Dict - -import pandas as pd -import requests -from requests.adapters import HTTPAdapter -from requests.packages.urllib3.util.retry import Retry - -from mlair import helpers -from mlair.configuration.join_settings import join_settings - -# join_url_base = 'https://join.fz-juelich.de/services/rest/surfacedata/' -str_or_none = Union[str, None] - - -class EmptyQueryResult(Exception): - """Exception that get raised if a query to JOIN returns empty results.""" - - pass - - -def download_join(station_name: Union[str, List[str]], stat_var: dict, station_type: str = None, - network_name: str = None, sampling: str = "daily", data_origin: Dict = None) -> [pd.DataFrame, - pd.DataFrame]: - """ - Read data from JOIN/TOAR. - - :param station_name: Station name e.g. 
DEBY122 - :param stat_var: key as variable like 'O3', values as statistics on keys like 'mean' - :param station_type: set the station type like "traffic" or "background", can be none - :param network_name: set the measurement network like "UBA" or "AIRBASE", can be none - :param sampling: sampling rate of the downloaded data, either set to daily or hourly (default daily) - :param data_origin: additional dictionary to specify data origin as key (for variable) value (origin) pair. Valid - origins are "REA" for reanalysis data and "" (empty string) for observational data. - - :returns: data frame with all variables and statistics and meta data frame with all meta information - """ - # make sure station_name parameter is a list - station_name = helpers.to_list(station_name) - - # also ensure that given data_origin dict is no reference - data_origin = None if data_origin is None else {k: v for (k, v) in data_origin.items()} - - # get data connection settings - join_url_base, headers = join_settings(sampling) - - # load series information - vars_dict, data_origin = load_series_information(station_name, station_type, network_name, join_url_base, headers, - data_origin) - - # check if all requested variables are available - if set(stat_var).issubset(vars_dict) is False: - missing_variables = set(stat_var).difference(vars_dict) - origin = helpers.select_from_dict(data_origin, missing_variables) - options = f"station={station_name}, type={station_type}, network={network_name}, origin={origin}" - raise EmptyQueryResult(f"No data found for variables {missing_variables} and options {options} in JOIN.") - - # correct stat_var values if data is not aggregated (hourly) - if sampling == "hourly": - stat_var = {key: "values" for key in stat_var.keys()} - - # download all variables with given statistic - data = None - df = None - logging.info(f"load data for {station_name[0]} from JOIN") - for var in _lower_list(sorted(vars_dict.keys())): - if var in stat_var.keys(): - - logging.debug('load: {}'.format(var)) - - # create data link - opts = {'base': join_url_base, 'service': 'stats', 'id': vars_dict[var], 'statistics': stat_var[var], - 'sampling': sampling, 'capture': 0, 'format': 'json'} - - # load data - data = get_data(opts, headers) - - # adjust data format if given as list of list - # no branch cover because this just happens when downloading hourly data using a secret token, not available - # for CI testing. - if isinstance(data, list): # pragma: no branch - data = correct_data_format(data) - - # correct namespace of statistics - stat = _correct_stat_name(stat_var[var]) - - # store data in pandas dataframe - df = _save_to_pandas(df, data, stat, var) - - logging.debug('finished: {}'.format(var)) - - if data: - meta = pd.DataFrame.from_dict(data['metadata'], orient='index') - meta.columns = station_name - return df, meta - else: - raise EmptyQueryResult("No data found in JOIN.") - - -def correct_data_format(data): - """ - Transform to the standard data format. - - For some cases (e.g. hourly data), the data is returned as list instead of a dictionary with keys datetime, values - and metadata. This functions addresses this issue and transforms the data into the dictionary version. 
- - :param data: data in hourly format - - :return: the same data but formatted to fit with aggregated format - """ - formatted = {"datetime": [], - "values": [], - "metadata": data[-1]} - for d in data[:-1]: - for k, v in zip(["datetime", "values"], d): - formatted[k].append(v) - return formatted - - -def get_data(opts: Dict, headers: Dict) -> Union[Dict, List]: - """ - Download join data using requests framework. - - Data is returned as json like structure. Depending on the response structure, this can lead to a list or dictionary. - - :param opts: options to create the request url - :param headers: additional headers information like authorization, can be empty - - :return: requested data (either as list or dictionary) - """ - url = create_url(**opts) - response = retries_session().get(url, headers=headers, timeout=(5, None)) # timeout=(open, read) - if response.status_code == 200: - return response.json() - else: - raise EmptyQueryResult(f"There was an error (STATUS {response.status_code}) for request {url}") - - -def retries_session(max_retries=3): - retry_strategy = Retry(total=max_retries, - backoff_factor=0.1, - status_forcelist=[429, 500, 502, 503, 504], - method_whitelist=["HEAD", "GET", "OPTIONS"]) - adapter = HTTPAdapter(max_retries=retry_strategy) - http = requests.Session() - http.mount("https://", adapter) - http.mount("http://", adapter) - return http - - -def load_series_information(station_name: List[str], station_type: str_or_none, network_name: str_or_none, - join_url_base: str, headers: Dict, data_origin: Dict = None) -> [Dict, Dict]: - """ - List all series ids that are available for given station id and network name. - - :param station_name: Station name e.g. DEBW107 - :param station_type: station type like "traffic" or "background" - :param network_name: measurement network of the station like "UBA" or "AIRBASE" - :param join_url_base: base url name to download data from - :param headers: additional headers information like authorization, can be empty - :param data_origin: additional information to select a distinct series e.g. from reanalysis (REA) or from observation - ("", empty string). This dictionary should contain a key for each variable and the information as key - :return: all available series for requested station stored in an dictionary with parameter name (variable) as key - and the series id as value. - """ - opts = {"base": join_url_base, "service": "search", "station_id": station_name[0], "station_type": station_type, - "network_name": network_name, "as_dict": "true", - "columns": "id,network_name,station_id,parameter_name,parameter_label,parameter_attribute"} - station_vars = get_data(opts, headers) - logging.debug(f"{station_name}: {station_vars}") - return _select_distinct_series(station_vars, data_origin) - - -def _select_distinct_series(vars: List[Dict], data_origin: Dict = None) -> [Dict, Dict]: - """ - Select distinct series ids for all variables. Also check if a parameter is from REA or not. - """ - data_origin_default = {"cloudcover": "REA", "humidity": "REA", "pblheight": "REA", "press": "REA", "relhum": "REA", - "temp": "REA", "totprecip": "REA", "u": "REA", "v": "REA", - "no": "", "no2": "", "o3": "", "pm10": "", "so2": ""} - if data_origin is None: - data_origin = {} - # ToDo: maybe press, wdir, wspeed from obs? or also temp, ... ? 
- selected = {} - for var in vars: - name = var["parameter_name"].lower() - var_attr = var["parameter_attribute"].lower() - if name not in data_origin.keys(): - data_origin.update({name: data_origin_default.get(name, "")}) - attr = data_origin.get(name, "").lower() - if var_attr == attr: - selected[name] = var["id"] - return selected, data_origin - - -def _save_to_pandas(df: Union[pd.DataFrame, None], data: dict, stat: str, var: str) -> pd.DataFrame: - """ - Save given data in data frame. - - If given data frame is not empty, the data is appened as new column. - - :param df: data frame to append the new data, can be none - :param data: new data to append or format as data frame containing the keys 'datetime' and '<stat>' - :param stat: extracted statistic to get values from data (e.g. 'mean', 'dma8eu') - :param var: variable the data is from (e.g. 'o3') - - :return: new created or concatenated data frame - """ - if len(data["datetime"][0]) == 19: - str_format = "%Y-%m-%d %H:%M:%S" - else: - str_format = "%Y-%m-%d %H:%M" - index = map(lambda s: dt.datetime.strptime(s, str_format), data['datetime']) - if df is None: - df = pd.DataFrame(data[stat], index=index, columns=[var]) - else: - df = pd.concat([df, pd.DataFrame(data[stat], index=index, columns=[var])], axis=1) - return df - - -def _correct_stat_name(stat: str) -> str: - """ - Map given statistic name to new namespace defined by mapping dict. - - Return given name stat if not element of mapping namespace. - - :param stat: namespace from JOIN server - - :return: stat mapped to local namespace - """ - mapping = {'average_values': 'mean', 'maximum': 'max', 'minimum': 'min'} - return mapping.get(stat, stat) - - -def _lower_list(args: List[str]) -> Iterator[str]: - """ - Lower all elements of given list. - - :param args: list with string entries to lower - - :return: iterator that lowers all list entries - """ - for string in args: - yield string.lower() - - -def create_url(base: str, service: str, **kwargs: Union[str, int, float, None]) -> str: - """ - Create a request url with given base url, service type and arbitrarily many additional keyword arguments. - - :param base: basic url of the rest service - :param service: service type, e.g. series, stats - :param kwargs: keyword pairs for optional request specifications, e.g. 
'statistics=maximum' - - :return: combined url as string - """ - if not base.endswith("/"): - base += "/" - url = f"{base}{service}/?{'&'.join(f'{k}={v}' for k, v in kwargs.items() if v is not None)}" - return url - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - var_all_dic = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values', - 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values', - 'pblheight': 'maximum'} - station = 'DEBW107' - # download_join(station, var_all_dic, sampling="daily") - download_join(station, var_all_dic, sampling="hourly") diff --git a/mlair/helpers/meteo.py b/mlair/helpers/meteo.py new file mode 100644 index 0000000000000000000000000000000000000000..c43d4ff23239f4ebff2b130779b3f8e2323620ca --- /dev/null +++ b/mlair/helpers/meteo.py @@ -0,0 +1,14 @@ + +import numpy as np + + +def relative_humidity_from_dewpoint(dew, temp): + return np.clip(100 * e_sat(dew) / e_sat(temp), 0, 100) + + +def e_sat(temperature): + a1 = 611.21 # Pa + a3 = 17.502 + a4 = 32.19 # K + T0 = 273.16 # K + return a1 * np.exp(a3 * (temperature - T0) / (temperature - a4)) diff --git a/mlair/helpers/statistics.py b/mlair/helpers/statistics.py index 7633a2a9c1842219d7af7b9c7b2b4f23a034cbdf..5f3aa45161530ff7d425ccbc7625dd7e081d8839 100644 --- a/mlair/helpers/statistics.py +++ b/mlair/helpers/statistics.py @@ -419,10 +419,13 @@ class SkillScores: skill_score.loc[["CASE III", "AIII"], iahead] = np.stack(self._climatological_skill_score( data, mu_type=3, forecast_name=forecast_name, observation_name=self.observation_name, external_data=external_data).values.flatten()) - - skill_score.loc[["CASE IV", "AIV", "BIV", "CIV"], iahead] = np.stack(self._climatological_skill_score( - data, mu_type=4, forecast_name=forecast_name, observation_name=self.observation_name, - external_data=external_data).values.flatten()) + try: + skill_score.loc[["CASE IV", "AIV", "BIV", "CIV"], iahead] = np.stack( + self._climatological_skill_score(data, mu_type=4, forecast_name=forecast_name, + observation_name=self.observation_name, + external_data=external_data).values.flatten()) + except ValueError: + pass return skill_score diff --git a/mlair/reference_models/abstract_reference_model.py b/mlair/reference_models/abstract_reference_model.py index e187e7ef62e3fe84f7ba2149a490f63ac718308f..f400447385be2f29e2ebf969ef16f3df0a67fd99 100644 --- a/mlair/reference_models/abstract_reference_model.py +++ b/mlair/reference_models/abstract_reference_model.py @@ -17,13 +17,14 @@ class AbstractReferenceModel(ABC): def __init__(self, *args, **kwargs): pass - def make_reference_available_locally(self): + def make_reference_available_locally(self, *args): raise NotImplementedError @staticmethod def is_reference_available_locally(reference_path) -> bool: """ Checks if reference is available locally + :param reference_path: look in this path for data """ try: diff --git a/mlair/reference_models/reference_model_cams.py b/mlair/reference_models/reference_model_cams.py new file mode 100644 index 0000000000000000000000000000000000000000..1db19c05a846ec948d3eda71727d11dd597643fa --- /dev/null +++ b/mlair/reference_models/reference_model_cams.py @@ -0,0 +1,56 @@ +__author__ = "Lukas Leufen" +__date__ = "2022-06-27" + + +from mlair.configuration.path_config import check_path_and_create +from mlair.reference_models.abstract_reference_model import AbstractReferenceModel +import os +import xarray as xr +import pandas as pd + + +class 
CAMSforecast(AbstractReferenceModel): + + def __init__(self, ref_name: str, ref_store_path: str = None, data_path: str = None): + + super().__init__() + self.ref_name = ref_name + if ref_store_path is None: + ref_store_path = f"{self.ref_name}/" + self.ref_store_path = ref_store_path + if data_path is None: + self.data_path = os.path.abspath(".") + else: + self.data_path = os.path.abspath(data_path) + self.file_pattern = "forecasts_%s_test.nc" + self.time_dim = "index" + self.ahead_dim = "ahead" + self.type_dim = "type" + + def make_reference_available_locally(self, stations): + "dma8eu_ENS_FORECAST_2019-04-09.nc" + missing_stations = self.list_locally_available_references(self.ref_store_path, stations) + if len(missing_stations) > 0: + check_path_and_create(self.ref_store_path) + dataset = xr.open_mfdataset(os.path.join(self.data_path, "dma8eu_ENS_FORECAST_*.nc")) + darray = dataset.to_array().sortby(["longitude", "latitude"]) + for station, coords in missing_stations.items(): + lon, lat = coords["lon"], coords["lat"] + station_data = darray.sel(longitude=lon, latitude=lat, method="nearest", drop=True).squeeze(drop=True) + station_data = station_data.expand_dims(dim={self.type_dim: [self.ref_name]}).compute() + station_data.coords[self.time_dim] = station_data.coords[self.time_dim] - pd.Timedelta(days=1) + station_data.coords[self.ahead_dim] = station_data.coords[self.ahead_dim] + 1 + file_name = self.file_pattern % str(station) + station_data.to_netcdf(os.path.join(self.ref_store_path, file_name)) + + @staticmethod + def list_locally_available_references(reference_path, stations) -> dict: + try: + file_list = os.listdir(reference_path) + if len(file_list) > 0: + res = {k: v for k, v in stations.items() if all(k not in x for x in file_list)} + else: + res = stations + except FileNotFoundError: + res = stations + return res diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index 75cd9e9f0e924a3ac7aa6da174207412c1e2be40..706f6169c756a1558eaf0177801a7f484fdca1d1 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -10,7 +10,7 @@ from dill.source import getsource from mlair.configuration import path_config from mlair import helpers -from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, DEFAULT_NETWORK, DEFAULT_STATION_TYPE, \ +from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, DEFAULT_STATION_TYPE, \ DEFAULT_START, DEFAULT_END, DEFAULT_WINDOW_HISTORY_SIZE, DEFAULT_OVERWRITE_LOCAL_DATA, \ DEFAULT_HPC_LOGIN_LIST, DEFAULT_HPC_HOST_LIST, DEFAULT_CREATE_NEW_MODEL, DEFAULT_TRAIN_MODEL, \ DEFAULT_FRACTION_OF_TRAINING, DEFAULT_EXTREME_VALUES, DEFAULT_EXTREMES_ON_RIGHT_TAIL_ONLY, DEFAULT_PERMUTE_DATA, \ diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 00d82f3c6f48c3560e31d62b5bed4ddbd2bc49be..97d1817f5eb884e80c042d56f02dd7a61f88d935 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -13,6 +13,7 @@ from typing import Dict, Tuple, Union, List, Callable import numpy as np import pandas as pd import xarray as xr +import datetime as dt from mlair.configuration import path_config from mlair.data_handler import Bootstraps, KerasIterator @@ -261,11 +262,17 @@ class PostProcessing(RunEnvironment): """Ensure time dimension to be equidistant. 
Sometimes dates if missing values have been dropped.""" start_data = data.coords[dim].values[0] freq = {"daily": "1D", "hourly": "1H"}.get(sampling) - datetime_index = pd.DataFrame(index=pd.date_range(start, end, freq=freq)) + _ind = pd.date_range(start, end, freq=freq) # two steps required to include all hours of end interval + datetime_index = pd.DataFrame(index=pd.date_range(_ind.min(), _ind.max() + dt.timedelta(days=1), + closed="left", freq=freq)) t = data.sel({dim: start_data}, drop=True) res = xr.DataArray(coords=[datetime_index.index, *[t.coords[c] for c in t.coords]], dims=[dim, *t.coords]) res = res.transpose(*data.dims) - res.loc[data.coords] = data + if data.shape == res.shape: + res.loc[data.coords] = data + else: + _d = data.sel({dim: slice(start, end)}) + res.loc[_d.coords] = _d return res def load_competitors(self, station_name: str) -> xr.DataArray: @@ -610,19 +617,16 @@ class PostProcessing(RunEnvironment): try: if "PlotStationMap" in plot_list: - if self.data_store.get("hostname")[:2] in self.data_store.get("hpc_hosts") or self.data_store.get( - "hostname")[:6] in self.data_store.get("hpc_hosts"): - logging.warning( - f"Skip 'PlotStationMap` because running on a hpc node: {self.data_store.get('hostname')}") - else: - gens = [(self.train_data, {"marker": 5, "ms": 9}), - (self.val_data, {"marker": 6, "ms": 9}), - (self.test_data, {"marker": 4, "ms": 9})] - PlotStationMap(generators=gens, plot_folder=self.plot_path) - gens = [(self.train_val_data, {"marker": 8, "ms": 9}), - (self.test_data, {"marker": 9, "ms": 9})] - PlotStationMap(generators=gens, plot_folder=self.plot_path, plot_name="station_map_var") + gens = [(self.train_data, {"marker": 5, "ms": 9}), + (self.val_data, {"marker": 6, "ms": 9}), + (self.test_data, {"marker": 4, "ms": 9})] + PlotStationMap(generators=gens, plot_folder=self.plot_path) + gens = [(self.train_val_data, {"marker": 8, "ms": 9}), + (self.test_data, {"marker": 9, "ms": 9})] + PlotStationMap(generators=gens, plot_folder=self.plot_path, plot_name="station_map_var") except Exception as e: + if self.data_store.get("hostname")[:2] in self.data_store.get("hpc_hosts") or self.data_store.get("hostname")[:6] in self.data_store.get("hpc_hosts"): + logging.info(f"PlotStationMap might have failed as current workflow is running on hpc node {self.data_store.get('hostname')}. To download geographic elements, please run PlotStationMap once on login node.") logging.error(f"Could not create plot PlotStationMap due to the following error: {e}" f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") @@ -761,6 +765,7 @@ class PostProcessing(RunEnvironment): indicated by `station_name`. The name of the competitor is set in the `type` axis as indicator. This method will raise either a `FileNotFoundError` or `KeyError` if no competitor could be found for the given station. Either there is no file provided in the expected path or no forecast for given `competitor_name` in the forecast file. + Forecast is trimmed on interval start and end of test subset. 
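The trimming mentioned above relies on create_full_time_dim, whose two-step index construction (see the change to that method further up) keeps all hours of the end date. Roughly, with purely illustrative dates:

    import datetime as dt
    import pandas as pd
    _ind = pd.date_range("2019-01-01", "2019-12-31", freq="1H")
    full = pd.date_range(_ind.min(), _ind.max() + dt.timedelta(days=1), closed="left", freq="1H")
    # _ind ends at 2019-12-31 00:00, full ends at 2019-12-31 23:00
    # `closed=` matches the pandas version pinned in requirements.txt; newer pandas uses `inclusive="left"`
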
:param station_name: name of the station to load data for :param competitor_name: name of the model @@ -770,9 +775,14 @@ class PostProcessing(RunEnvironment): file = os.path.join(path, f"forecasts_{station_name}_test.nc") with xr.open_dataarray(file) as da: data = da.load() - forecast = data.sel(type=[self.forecast_indicator]) - forecast.coords[self.model_type_dim] = [competitor_name] - return forecast + if self.forecast_indicator in data.coords[self.model_type_dim]: + forecast = data.sel({self.model_type_dim: [self.forecast_indicator]}) + forecast.coords[self.model_type_dim] = [competitor_name] + else: + forecast = data.sel({self.model_type_dim: [competitor_name]}) + # limit forecast to time range of test subset + start, end = self.data_store.get("start", "test"), self.data_store.get("end", "test") + return self.create_full_time_dim(forecast, self.index_dim, self._sampling, start, end) def _create_observation(self, data, _, transformation_func: Callable, normalised: bool) -> xr.DataArray: """ diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index 5cf676dd3acfaec5532d006e6de0b22e769db873..94906427c7fb3cb953c48bdd5f9748697af44821 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -18,7 +18,7 @@ import pandas as pd from mlair.data_handler import DataCollection, AbstractDataHandler from mlair.helpers import TimeTracking, to_list, tables, remove_items from mlair.configuration import path_config -from mlair.helpers.join import EmptyQueryResult +from mlair.helpers.data_sources.toar_data import EmptyQueryResult from mlair.helpers.testing import check_nested_equality from mlair.run_modules.run_environment import RunEnvironment from mlair.helpers.datastore import DataStoreByScope @@ -122,8 +122,8 @@ class PreProcessing(RunEnvironment): +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ """ - meta_cols = ['station_name', 'station_lon', 'station_lat', 'station_alt'] - meta_round = ["station_lon", "station_lat", "station_alt"] + meta_cols = ["name", "lat", "lon", "alt", "country", "state", "type", "type_of_area", "toar1_category"] + meta_round = ["lat", "lon", "alt"] precision = 4 path = os.path.join(self.data_store.get("experiment_path"), "latex_report") path_config.check_path_and_create(path) @@ -378,7 +378,19 @@ class PreProcessing(RunEnvironment): logging.info("Prepare IntelliO3-ts-v1 model") from mlair.reference_models.reference_model_intellio3_v1 import IntelliO3_ts_v1 path = os.path.join(self.data_store.get("competitor_path"), competitor_name) - IntelliO3_ts_v1("IntelliO3-ts-v1", path).make_reference_available_locally(remove_tmp_dir=False) + IntelliO3_ts_v1("IntelliO3-ts-v1", ref_store_path=path).make_reference_available_locally(remove_tmp_dir=False) + elif competitor_name.lower() == "CAMS".lower(): + logging.info("Prepare CAMS forecasts") + from mlair.reference_models.reference_model_cams import CAMSforecast + data_path = self.data_store.get_default("cams_data_path", default=None) + path = os.path.join(self.data_store.get("competitor_path"), competitor_name) + stations = {} + for subset in ["train", "val", "test"]: + data_collection = self.data_store.get("data_collection", subset) + stations.update({str(s): s.get_coordinates() for s in data_collection if s not in stations}) + CAMSforecast("CAMS", ref_store_path=path, data_path=data_path).make_reference_available_locally(stations) + + else: logging.info("No preparation required 
because no competitor was provided to the workflow.") @@ -453,8 +465,9 @@ def f_proc(data_handler, station, name_affix, store, return_strategy="", tmp_pat def f_proc_create_info_df(data, meta_cols): station_name = str(data.id_class) + meta = data.id_class.meta res = {"station_name": station_name, "Y_shape": data.get_Y()[0].shape[0], - "meta": data.id_class.meta.loc[meta_cols].values.flatten()} + "meta": meta.reindex(meta_cols).values.flatten()} return res diff --git a/mlair/run_modules/run_environment.py b/mlair/run_modules/run_environment.py index 4da0fa8c8364764c444a662b79ffabc4f427ccc8..191ee30f0485c6abc42a8d718612b7b26feb221a 100644 --- a/mlair/run_modules/run_environment.py +++ b/mlair/run_modules/run_environment.py @@ -114,7 +114,10 @@ class RunEnvironment(object): """ if not self.del_by_exit: self.time.stop() - logging.info(f"{self._name} finished after {self.time}") + try: + logging.info(f"{self._name} finished after {self.time}") + except NameError: + pass self.del_by_exit = True # copy log file and clear data store only if called as base class and not as super class if self.__class__.__name__ == "RunEnvironment": @@ -122,7 +125,7 @@ class RunEnvironment(object): self.__plot_tracking() self.__save_tracking() self.__move_log_file() - except FileNotFoundError: + except (FileNotFoundError, NameError): pass self.data_store.clear_data_store() diff --git a/requirements.txt b/requirements.txt index 3afc17b67fddbf5a269df1e1b7e103045630a290..4f911b37f5a27be1f30caf69a613df5deef62a29 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,34 +1,34 @@ -astropy==4.1 +astropy==5.1 auto_mix_prep==0.2.0 -Cartopy==0.18.0 -dask==2021.3.0 +Cartopy==0.20.0 +dask==2021.9.1 dill==0.3.3 -fsspec==2021.11.0 -keras==2.6.0 -keras_nightly==2.5.0.dev2021032900 +fsspec==2021.10.1 +Keras==2.6.0 locket==0.2.1 -matplotlib==3.3.4 +matplotlib==3.4.3 mock==4.0.3 -netcdf4==1.5.8 -numpy==1.19.5 -pandas==1.1.5 +netcdf4==1.6.0 +numpy~=1.19.2 +pandas==1.3.4 partd==1.2.0 -psutil==5.8.0 +psutil==5.9.1 pydot==1.4.2 pytest==6.2.2 pytest-cov==2.11.1 pytest-html==3.1.1 pytest-lazy-fixture==0.6.3 -requests==2.25.1 -scipy==1.5.2 -seaborn==0.11.1 +requests==2.28.1 +scipy==1.7.1 +seaborn==0.11.2 setuptools==47.1.0 --no-binary shapely Shapely==1.8.0 six==1.15.0 -statsmodels==0.12.2 -tabulate==0.8.9 -tensorflow==2.5.0 -toolz==0.11.2 -typing_extensions==3.7.4.3 +statsmodels==0.13.2 +tabulate==0.8.10 +tensorflow==2.6.0 +timezonefinder==5.2.0 +toolz==0.11.1 +typing_extensions~=3.7.4 wget==3.2 xarray==0.16.2 diff --git a/test/test_configuration/test_defaults.py b/test/test_configuration/test_defaults.py index 07a5aa2f543b1992baf10421de4b28133feb0eac..b46590290eff09ac98d549c7d38010eb5506d09c 100644 --- a/test/test_configuration/test_defaults.py +++ b/test/test_configuration/test_defaults.py @@ -31,7 +31,6 @@ class TestAllDefaults: 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values', 'pblheight': 'maximum'} - assert DEFAULT_NETWORK == "AIRBASE" assert DEFAULT_STATION_TYPE == "background" assert DEFAULT_VARIABLES == DEFAULT_VAR_ALL_DICT.keys() assert DEFAULT_START == "1997-01-01" diff --git a/test/test_helpers/test_data_sources/test_join.py b/test/test_helpers/test_data_sources/test_join.py new file mode 100644 index 0000000000000000000000000000000000000000..f9b12f5a7ff20e898695de0a0f035bed023674f2 --- /dev/null +++ b/test/test_helpers/test_data_sources/test_join.py @@ -0,0 +1,344 @@ +from typing import Iterable + +import pytest + +from mlair.helpers.data_sources.join import * +from 
mlair.helpers.data_sources.join import _save_to_pandas, _lower_list, _select_distinct_series, \ + _select_distinct_data_origin, _select_distinct_network +from mlair.configuration.join_settings import join_settings +from mlair.helpers.testing import check_nested_equality +from mlair.helpers.data_sources.toar_data import EmptyQueryResult + + +class TestDownloadJoin: + + def test_download_single_var(self): + data, meta = download_join("DEBW107", {"o3": "dma8eu"}) + assert data.columns == "o3" + assert meta.columns == "DEBW107" + + def test_download_empty(self): + with pytest.raises(EmptyQueryResult) as e: + download_join("DEBW107", {"o3": "dma8eu"}, "traffic") + assert e.value.args[-1] == "No data found for variables {'o3'} and options station=['DEBW107'], type=traffic," \ + " network=None, origin={} in JOIN." + + def test_download_incomplete(self): + with pytest.raises(EmptyQueryResult) as e: + download_join("DEBW107", {"o3": "dma8eu", "o10": "maximum"}, "background") + assert e.value.args[-1] == "No data found for variables {'o10'} and options station=['DEBW107'], " \ + "type=background, network=None, origin={} in JOIN." + with pytest.raises(EmptyQueryResult) as e: + download_join("DEBW107", {"o3": "dma8eu", "o10": "maximum"}, "background", data_origin={"o10": ""}) + assert e.value.args[-1] == "No data found for variables {'o10'} and options station=['DEBW107'], " \ + "type=background, network={'o10': None}, origin={'o10': ''} in JOIN." + + +class TestCorrectDataFormat: + + def test_correct_data_format(self): + list_data = [["2020-01-01 06:00:01", 23.], ["2020-01-01 06:00:11", 24.], ["2020-01-01 06:00:21", 25.], + ["2020-01-01 06:00:31", 26.], ["2020-01-01 06:00:41", 27.], ["2020-01-01 06:00:51", 23.], + {"station": "test_station_001", "author": "ME", "success": True}] + dict_data = correct_data_format(list_data) + assert dict_data == {"datetime": ["2020-01-01 06:00:01", "2020-01-01 06:00:11", "2020-01-01 06:00:21", + "2020-01-01 06:00:31", "2020-01-01 06:00:41", "2020-01-01 06:00:51"], + "values": [23., 24., 25., 26., 27., 23.], + "metadata": {"station": "test_station_001", "author": "ME", "success": True}} + + +class TestLoadSeriesInformation: + + def test_standard_query(self): + expected_subset = {'o3': 17057, 'no2': 17058, 'temp': 85587, 'wspeed': 17060} + res, orig = load_series_information(['DEBW107'], None, None, join_settings()[0], {}) + assert expected_subset.items() <= res.items() + + def test_empty_result(self): + res, orig = load_series_information(['DEBW107'], "traffic", None, join_settings()[0], {}) + assert res == {} + + +class TestSelectDistinctDataOrigin: + + @pytest.fixture + def vars(self): + return [{'id': 16686, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'no2', + 'parameter_label': 'NO2', 'parameter_attribute': ''}, + {'id': 16687, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'o3', + 'parameter_label': 'O3', 'parameter_attribute': ''}, + {'id': 16692, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'press', + 'parameter_label': 'PRESS--LANUV', 'parameter_attribute': ''}, + {'id': 16693, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'temp', + 'parameter_label': 'TEMP--LANUV', 'parameter_attribute': ''}, + {'id': 54036, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'cloudcover', + 'parameter_label': 'CLOUDCOVER', 'parameter_attribute': 'REA'}, + {'id': 88491, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'temp', + 'parameter_label': 'TEMP-REA-MIUB', 
'parameter_attribute': 'REA'}, + {'id': 102660, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'press', + 'parameter_label': 'PRESS-REA-MIUB', 'parameter_attribute': 'REA'}, + {'id': 26692, 'network_name': 'AIRBASE', 'station_id': 'DENW053', 'parameter_name': 'press', + 'parameter_label': 'PRESS', 'parameter_attribute': 'REA'}] + + def test_no_origin_given(self, vars): + res, orig = _select_distinct_data_origin(vars, {}) + expected = { + "no2": [{'id': 16686, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'no2', + 'parameter_label': 'NO2', 'parameter_attribute': ''}], + "o3": [{'id': 16687, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'o3', + 'parameter_label': 'O3', 'parameter_attribute': ''}], + "cloudcover": [{'id': 54036, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'cloudcover', + 'parameter_label': 'CLOUDCOVER', 'parameter_attribute': 'REA'}], + "temp": [{'id': 88491, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'temp', + 'parameter_label': 'TEMP-REA-MIUB', 'parameter_attribute': 'REA'}], + "press": [{'id': 102660, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'press', + 'parameter_label': 'PRESS-REA-MIUB', 'parameter_attribute': 'REA'}, + {'id': 26692, 'network_name': 'AIRBASE', 'station_id': 'DENW053', + 'parameter_name': 'press', 'parameter_label': 'PRESS', 'parameter_attribute': 'REA'}]} + + assert check_nested_equality(res, expected) is True + # assert res == {"no2": 16686, "o3": 16687, "cloudcover": 54036, "temp": 88491, "press": 102660} + assert orig == {"no2": "", "o3": "", "cloudcover": "REA", "temp": "REA", "press": "REA"} + + def test_different_origins(self, vars): + origin = {"no2": "test", "temp": "", "cloudcover": "REA"} + res, orig = _select_distinct_data_origin(vars, data_origin=origin) + expected = { + "o3": [{'id': 16687, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'o3', + 'parameter_label': 'O3', 'parameter_attribute': ''}], + "cloudcover": [{'id': 54036, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'cloudcover', + 'parameter_label': 'CLOUDCOVER', 'parameter_attribute': 'REA'}], + "temp": [{'id': 16693, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'temp', + 'parameter_label': 'TEMP--LANUV', 'parameter_attribute': ''}], + "press": [{'id': 102660, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'press', + 'parameter_label': 'PRESS-REA-MIUB', 'parameter_attribute': 'REA'}, + {'id': 26692, 'network_name': 'AIRBASE', 'station_id': 'DENW053', + 'parameter_name': 'press', 'parameter_label': 'PRESS', 'parameter_attribute': 'REA'}]} + assert check_nested_equality(res, expected) is True + # assert res == {"o3": 16687, "press": 102660, "temp": 16693, "cloudcover": 54036} + assert orig == {"no2": "test", "o3": "", "cloudcover": "REA", "temp": "", "press": "REA"} + + +class TestSelectDistinctNetwork: + + @pytest.fixture + def vars(self): + return { + "no2": [{'id': 16686, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'no2', + 'parameter_label': 'NO2', 'parameter_attribute': ''}], + "o3": [{'id': 16687, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'o3', + 'parameter_label': 'O3', 'parameter_attribute': ''}], + "cloudcover": [{'id': 54036, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'cloudcover', + 'parameter_label': 'CLOUDCOVER', 'parameter_attribute': 'REA'}], + "temp": [{'id': 88491, 'network_name': 'UBA', 
'station_id': 'DENW053', 'parameter_name': 'temp', + 'parameter_label': 'TEMP-REA-MIUB', 'parameter_attribute': 'REA'}], + "press": [{'id': 102660, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'press', + 'parameter_label': 'PRESS-REA-MIUB', 'parameter_attribute': 'REA'}, + {'id': 26692, 'network_name': 'AIRBASE', 'station_id': 'DENW053', + 'parameter_name': 'press', 'parameter_label': 'PRESS', 'parameter_attribute': 'REA'}]} + + def test_no_network_given(self, caplog, vars): + caplog.set_level(logging.INFO) + res = _select_distinct_network(vars, []) + expected = { + "no2": {'id': 16686, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'no2', + 'parameter_label': 'NO2', 'parameter_attribute': ''}, + "o3": {'id': 16687, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'o3', + 'parameter_label': 'O3', 'parameter_attribute': ''}, + "cloudcover": {'id': 54036, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'cloudcover', + 'parameter_label': 'CLOUDCOVER', 'parameter_attribute': 'REA'}, + "temp": {'id': 88491, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'temp', + 'parameter_label': 'TEMP-REA-MIUB', 'parameter_attribute': 'REA'}, + "press": {'id': 102660, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'press', + 'parameter_label': 'PRESS-REA-MIUB', 'parameter_attribute': 'REA'}} + assert check_nested_equality(res, expected) is True + + message = "Could not find a valid match for variable %s and networks []! Therefore, use first answer from JOIN:" + assert message % "no2" in caplog.messages[0] + assert message % "o3" in caplog.messages[1] + assert message % "cloudcover" in caplog.messages[2] + assert message % "temp" in caplog.messages[3] + assert message % "press" in caplog.messages[4] + + def test_single_network_given(self, vars): + res = _select_distinct_network(vars, ["UBA"]) + expected = { + "no2": {'id': 16686, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'no2', + 'parameter_label': 'NO2', 'parameter_attribute': ''}, + "o3": {'id': 16687, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'o3', + 'parameter_label': 'O3', 'parameter_attribute': ''}, + "cloudcover": {'id': 54036, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'cloudcover', + 'parameter_label': 'CLOUDCOVER', 'parameter_attribute': 'REA'}, + "temp": {'id': 88491, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'temp', + 'parameter_label': 'TEMP-REA-MIUB', 'parameter_attribute': 'REA'}, + "press": {'id': 102660, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'press', + 'parameter_label': 'PRESS-REA-MIUB', 'parameter_attribute': 'REA'}} + assert check_nested_equality(res, expected) is True + + def test_single_network_given_no_match(self, vars): + with pytest.raises(ValueError) as e: # AIRBASE not avail for all variables + _select_distinct_network(vars, ["AIRBASE"]) + assert e.value.args[-1] == "Cannot find a valid match for requested networks ['AIRBASE'] and variable no2 as " \ + "only following networks are available in JOIN: ['UBA']" + + with pytest.raises(ValueError) as e: # both requested networks are not available for all variables + _select_distinct_network(vars, ["LUBW", "EMEP"]) + assert e.value.args[-1] == "Cannot find a valid match for requested networks ['LUBW', 'EMEP'] and variable " \ + "no2 as only following networks are available in JOIN: ['UBA']" + + def test_multiple_networks_given(self, vars): + res = 
_select_distinct_network(vars, ["UBA", "AIRBASE"]) + expected = { + "no2": {'id': 16686, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'no2', + 'parameter_label': 'NO2', 'parameter_attribute': ''}, + "o3": {'id': 16687, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'o3', + 'parameter_label': 'O3', 'parameter_attribute': ''}, + "cloudcover": {'id': 54036, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'cloudcover', + 'parameter_label': 'CLOUDCOVER', 'parameter_attribute': 'REA'}, + "temp": {'id': 88491, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'temp', + 'parameter_label': 'TEMP-REA-MIUB', 'parameter_attribute': 'REA'}, + "press": {'id': 102660, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'press', + 'parameter_label': 'PRESS-REA-MIUB', 'parameter_attribute': 'REA'}} + assert check_nested_equality(res, expected) is True + + res = _select_distinct_network(vars, ["AIRBASE", "UBA"]) + expected = { + "no2": {'id': 16686, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'no2', + 'parameter_label': 'NO2', 'parameter_attribute': ''}, + "o3": {'id': 16687, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'o3', + 'parameter_label': 'O3', 'parameter_attribute': ''}, + "cloudcover": {'id': 54036, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'cloudcover', + 'parameter_label': 'CLOUDCOVER', 'parameter_attribute': 'REA'}, + "temp": {'id': 88491, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'temp', + 'parameter_label': 'TEMP-REA-MIUB', 'parameter_attribute': 'REA'}, + "press": {'id': 26692, 'network_name': 'AIRBASE', 'station_id': 'DENW053', + 'parameter_name': 'press', 'parameter_label': 'PRESS', 'parameter_attribute': 'REA'}} + assert check_nested_equality(res, expected) is True + + def test_multiple_networks_given_by_dict(self, vars): + res = _select_distinct_network(vars, {"no2": "UBA", "o3": ["UBA", "AIRBASE"], "temp": ["AIRBASE", "UBA"], + "press": ["AIRBASE", "UBA"]}) + expected = { + "no2": {'id': 16686, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'no2', + 'parameter_label': 'NO2', 'parameter_attribute': ''}, + "o3": {'id': 16687, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'o3', + 'parameter_label': 'O3', 'parameter_attribute': ''}, + "cloudcover": {'id': 54036, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'cloudcover', + 'parameter_label': 'CLOUDCOVER', 'parameter_attribute': 'REA'}, + "temp": {'id': 88491, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'temp', + 'parameter_label': 'TEMP-REA-MIUB', 'parameter_attribute': 'REA'}, + "press": {'id': 26692, 'network_name': 'AIRBASE', 'station_id': 'DENW053', + 'parameter_name': 'press', 'parameter_label': 'PRESS', 'parameter_attribute': 'REA'}} + assert check_nested_equality(res, expected) is True + + +class TestSelectDistinctSeries: + + @pytest.fixture + def vars(self): + return [{'id': 16686, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'no2', + 'parameter_label': 'NO2', 'parameter_attribute': ''}, + {'id': 16687, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'o3', + 'parameter_label': 'O3', 'parameter_attribute': ''}, + {'id': 16692, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'press', + 'parameter_label': 'PRESS--LANUV', 'parameter_attribute': ''}, + {'id': 16693, 'network_name': 'UBA', 'station_id': 'DENW053', 
+class TestSelectDistinctSeries:
+
+    @pytest.fixture
+    def vars(self):
+        return [{'id': 16686, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'no2',
+                 'parameter_label': 'NO2', 'parameter_attribute': ''},
+                {'id': 16687, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'o3',
+                 'parameter_label': 'O3', 'parameter_attribute': ''},
+                {'id': 16692, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'press',
+                 'parameter_label': 'PRESS--LANUV', 'parameter_attribute': ''},
+                {'id': 16693, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'temp',
+                 'parameter_label': 'TEMP--LANUV', 'parameter_attribute': ''},
+                {'id': 54036, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'cloudcover',
+                 'parameter_label': 'CLOUDCOVER', 'parameter_attribute': 'REA'},
+                {'id': 88491, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'temp',
+                 'parameter_label': 'TEMP-REA-MIUB', 'parameter_attribute': 'REA'},
+                {'id': 102660, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'press',
+                 'parameter_label': 'PRESS-REA-MIUB', 'parameter_attribute': 'REA'},
+                {'id': 26692, 'network_name': 'AIRBASE', 'station_id': 'DENW053', 'parameter_name': 'press',
+                 'parameter_label': 'PRESS', 'parameter_attribute': 'REA'}]
+
+    def test_no_origin_given(self, vars):
+        res, orig = _select_distinct_series(vars)
+        assert res == {"no2": 16686, "o3": 16687, "cloudcover": 54036, "temp": 88491, "press": 102660}
+        assert orig == {"no2": "", "o3": "", "cloudcover": "REA", "temp": "REA", "press": "REA"}
+
+    def test_different_origins(self, vars):
+        origin = {"no2": "test", "temp": "", "cloudcover": "REA"}
+        res, orig = _select_distinct_series(vars, data_origin=origin)
+        assert res == {"o3": 16687, "press": 102660, "temp": 16693, "cloudcover": 54036}
+        assert orig == {"no2": "test", "o3": "", "cloudcover": "REA", "temp": "", "press": "REA"}
+        res, orig = _select_distinct_series(vars, data_origin={})
+        assert res == {"cloudcover": 54036, "no2": 16686, "o3": 16687, "press": 102660, "temp": 88491}
+        assert orig == {"no2": "", "o3": "", "temp": "REA", "press": "REA", "cloudcover": "REA"}
+
+    def test_different_networks(self, vars):
+        res, orig = _select_distinct_series(vars, network_name="UBA")
+        assert res == {"no2": 16686, "o3": 16687, "cloudcover": 54036, "temp": 88491, "press": 102660}
+        assert orig == {"no2": "", "o3": "", "cloudcover": "REA", "temp": "REA", "press": "REA"}
+
+        res, orig = _select_distinct_series(vars, network_name=["UBA", "EMEP", "AIRBASE"])
+        assert res == {"no2": 16686, "o3": 16687, "cloudcover": 54036, "temp": 88491, "press": 102660}
+        assert orig == {"no2": "", "o3": "", "cloudcover": "REA", "temp": "REA", "press": "REA"}
+
+        res, orig = _select_distinct_series(vars, network_name=["EMEP", "AIRBASE", "UBA"])
+        assert res == {"no2": 16686, "o3": 16687, "cloudcover": 54036, "temp": 88491, "press": 26692}
+        assert orig == {"no2": "", "o3": "", "cloudcover": "REA", "temp": "REA", "press": "REA"}
+
+    def test_network_not_available(self, vars):
+        with pytest.raises(ValueError) as e:
+            _select_distinct_series(vars, network_name="AIRBASE")
+        assert e.value.args[-1] == "Cannot find a valid match for requested networks ['AIRBASE'] and variable no2 as " \
+                                   "only following networks are available in JOIN: ['UBA']"
+
+    def test_different_network_and_origin(self, vars):
+        origin = {"no2": "test", "temp": "", "cloudcover": "REA"}
+        res, orig = _select_distinct_series(vars, data_origin=origin, network_name=["EMEP", "AIRBASE", "UBA"])
+        assert res == {"o3": 16687, "press": 26692, "temp": 16693, "cloudcover": 54036}
+        assert orig == {"no2": "test", "o3": "", "cloudcover": "REA", "temp": "", "press": "REA"}
+
+
+class TestSaveToPandas:
+
+    @staticmethod
+    def convert_date(date):
+        return map(lambda s: dt.datetime.strptime(s, "%Y-%m-%d %H:%M"), date)
+
+    @pytest.fixture
+    def date(self):
+        return ['1997-01-01 00:00', '1997-01-02 00:00', '1997-01-03 00:00', '1997-01-04 00:00']
+
+    @pytest.fixture
+    def date_len19(self):
+        return ['1997-01-01 00:00:00', '1997-01-02 00:00:00', '1997-01-03 00:00:00', '1997-01-04 00:00:00']
+
+    @pytest.fixture
+    def values(self):
+        return [86.21, 94.76, 76.96, 99.89]
+
+    @pytest.fixture
+    def alternative_values(self):
+        return [20.0, 25.2, 25.1, 23.6]
+
+    @pytest.fixture
+    def create_df(self, date, values):
+        return pd.DataFrame(values, index=self.convert_date(date), columns=['cloudcover'])
+
+    def test_empty_df(self, date, values, create_df):
+        data = {'datetime': date, 'mean': values, 'metadata': None}
+        assert pd.testing.assert_frame_equal(create_df, _save_to_pandas(None, data, 'mean', 'cloudcover')) is None
+
+    def test_not_empty_df(self, date, alternative_values, create_df):
+        data = {'datetime': date, 'max': alternative_values, 'metadata': None}
+        next_df = pd.DataFrame(data["max"], index=self.convert_date(date), columns=['temperature'])
+        df_concat = pd.concat([create_df, next_df], axis=1)
+        assert pd.testing.assert_frame_equal(df_concat, _save_to_pandas(create_df, data, 'max', 'temperature')) is None
+
+    def test_alternative_date_format(self, date_len19, values, create_df):
+        data = {'datetime': date_len19, 'mean': values, 'metadata': None}
+        assert pd.testing.assert_frame_equal(create_df, _save_to_pandas(None, data, 'mean', 'cloudcover')) is None
+
+
+class TestLowerList:
+
+    def test_string_lowering(self):
+        list_iterator = _lower_list(["Capitalised", "already_small", "UPPERCASE", "veRyStRaNGe"])
+        assert isinstance(list_iterator, Iterable)
+        assert list(list_iterator) == ["capitalised", "already_small", "uppercase", "verystrange"]
+
+
diff --git a/test/test_helpers/test_data_sources/test_toar_data.py b/test/test_helpers/test_data_sources/test_toar_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..abaec10cc580b592d85d7dcc842616c67777f174
--- /dev/null
+++ b/test/test_helpers/test_data_sources/test_toar_data.py
@@ -0,0 +1,53 @@
+from mlair.configuration.join_settings import join_settings
+from mlair.helpers.data_sources.toar_data import get_data, create_url, correct_stat_name
+
+
+class TestGetData:
+
+    def test(self):
+        opts = {"base": join_settings()[0], "service": "series", "station_id": 'DEBW107', "network_name": "UBA",
+                "parameter_name": "o3,no2"}
+        assert get_data(opts, headers={}) == [[17057, 'UBA', 'DEBW107', 'O3'], [17058, 'UBA', 'DEBW107', 'NO2']]
+
+
+class TestCreateUrl:
+
+    def test_minimal_args_given(self):
+        url = create_url("www.base.edu", "testingservice")
+        assert url == "www.base.edu/testingservice/"
+
+    def test_given_kwargs(self):
+        url = create_url("www.base2.edu/", "testingservice", mood="happy", confidence=0.98)
+        assert url == "www.base2.edu/testingservice/?mood=happy&confidence=0.98"
+
+    def test_single_kwargs(self):
+        url = create_url("www.base2.edu/", "testingservice", mood="undefined")
+        assert url == "www.base2.edu/testingservice/?mood=undefined"
+
+    def test_none_kwargs(self):
+        url = create_url("www.base2.edu/", "testingservice", mood="sad", happiness=None, stress_factor=100)
+        assert url == "www.base2.edu/testingservice/?mood=sad&stress_factor=100"
+
+    def test_param_id(self):
+        url = create_url("www.base.edu", "testingservice", param_id="2001")
+        assert url == "www.base.edu/testingservice/2001"
+
+    def test_param_id_kwargs(self):
+        url = create_url("www.base.edu", "testingservice", param_id=2001, mood="sad", happiness=None, stress_factor=100)
+        assert url == "www.base.edu/testingservice/2001?mood=sad&stress_factor=100"
+
+        url = create_url("www.base.edu", "testingservice", param_id=2001, mood="sad", series_id=222)
+        assert url == "www.base.edu/testingservice/2001?mood=sad&series_id=222"
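TestCreateUrl pins down the expected URL layout: base and service joined with single slashes plus a trailing slash, an optional param_id appended as a path segment, and keyword arguments serialised as a query string with None values dropped. The following minimal sketch is consistent with these assertions but is not necessarily the toar_data implementation; build_url is a placeholder name.

def build_url(base, service, param_id=None, **kwargs):
    """Sketch of the behaviour asserted above; not the actual create_url code."""
    url = f"{base.rstrip('/')}/{service}/"
    if param_id is not None:
        url = f"{url}{param_id}"  # e.g. .../testingservice/2001 (no trailing slash)
    params = {k: v for k, v in kwargs.items() if v is not None}  # drop None-valued kwargs
    if params:
        url = url + "?" + "&".join(f"{k}={v}" for k, v in params.items())
    return url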
"www.base.edu/testingservice/2001?mood=sad&series_id=222" + + +class TestCorrectStatName: + + def test_nothing_to_do(self): + assert correct_stat_name("dma8eu") == "dma8eu" + assert correct_stat_name("max") == "max" + + def test_correct_string(self): + assert correct_stat_name("maximum") == "max" + assert correct_stat_name("minimum") == "min" + assert correct_stat_name("average_values") == "mean" + diff --git a/test/test_helpers/test_helpers.py b/test/test_helpers/test_helpers.py index 70640be9d56d71e4f68145b3bb68fb835e1e27a5..6f787d5835bd917fcfc55341d93a2d302f2c6e6e 100644 --- a/test/test_helpers/test_helpers.py +++ b/test/test_helpers/test_helpers.py @@ -12,8 +12,9 @@ import mock import pytest import string -from mlair.helpers import to_list, dict_to_xarray, float_round, remove_items, extract_value, select_from_dict, sort_like -from mlair.helpers import PyTestRegex +from mlair.helpers import to_list, dict_to_xarray, float_round, remove_items, extract_value, select_from_dict, \ + sort_like, filter_dict_by_value +from mlair.helpers import PyTestRegex, check_nested_equality from mlair.helpers import Logger, TimeTracking from mlair.helpers.helpers import is_xarray, convert2xrda, relative_round @@ -223,6 +224,10 @@ class TestSelectFromDict: assert select_from_dict(dictionary, ["a", "e"]) == {"a": 1, "e": None} assert select_from_dict(dictionary, ["a", "e"], remove_none=True) == {"a": 1} + def test_select_condition(self, dictionary): + assert select_from_dict(dictionary, ["a", "e"], filter_cond=False) == {"b": 23, "c": "last"} + assert select_from_dict(dictionary, ["a", "c"], filter_cond=False, remove_none=True) == {"b": 23} + class TestRemoveItems: @@ -487,3 +492,22 @@ class TestSortLike: l_obj = [1, 2, 3, 8, 4] with pytest.raises(AssertionError) as e: sort_like(l_obj, [1, 2, 3, 5, 6, 7, 8]) + + +class TestFilterDictByValue: + + def test_filter_dict_by_value(self): + data_origin = {'o3': '', 'no': '', 'no2': '', 'relhum': 'REA', 'u': 'REA', 'cloudcover': 'REA', 'temp': 'era5'} + expected = {'temp': 'era5'} + assert check_nested_equality(filter_dict_by_value(data_origin, "era5", True), expected) is True + expected = {'o3': '', 'no': '', 'no2': '', 'relhum': 'REA', 'u': 'REA', 'cloudcover': 'REA'} + assert check_nested_equality(filter_dict_by_value(data_origin, "era5", False), expected) is True + expected = {'o3': '', 'no': '', 'no2': ''} + assert check_nested_equality(filter_dict_by_value(data_origin, "", True), expected) is True + + def test_filter_dict_by_value_not_avail(self): + data_origin = {'o3': '', 'no': '', 'no2': '', 'relhum': 'REA', 'u': 'REA', 'cloudcover': 'REA', 'temp': 'era5'} + expected = {} + assert check_nested_equality(filter_dict_by_value(data_origin, "not_avail", True), expected) is True + assert check_nested_equality(filter_dict_by_value(data_origin, "EA", True), expected) is True + diff --git a/test/test_helpers/test_join.py b/test/test_helpers/test_join.py deleted file mode 100644 index e903669bf63f4056a8278401b07818d31a09616d..0000000000000000000000000000000000000000 --- a/test/test_helpers/test_join.py +++ /dev/null @@ -1,179 +0,0 @@ -from typing import Iterable - -import pytest - -from mlair.helpers.join import * -from mlair.helpers.join import _save_to_pandas, _correct_stat_name, _lower_list, _select_distinct_series -from mlair.configuration.join_settings import join_settings - - -class TestDownloadJoin: - - def test_download_single_var(self): - data, meta = download_join("DEBW107", {"o3": "dma8eu"}) - assert data.columns == "o3" - assert meta.columns == "DEBW107" - 
diff --git a/test/test_helpers/test_join.py b/test/test_helpers/test_join.py
deleted file mode 100644
index e903669bf63f4056a8278401b07818d31a09616d..0000000000000000000000000000000000000000
--- a/test/test_helpers/test_join.py
+++ /dev/null
@@ -1,179 +0,0 @@
-from typing import Iterable
-
-import pytest
-
-from mlair.helpers.join import *
-from mlair.helpers.join import _save_to_pandas, _correct_stat_name, _lower_list, _select_distinct_series
-from mlair.configuration.join_settings import join_settings
-
-
-class TestDownloadJoin:
-
-    def test_download_single_var(self):
-        data, meta = download_join("DEBW107", {"o3": "dma8eu"})
-        assert data.columns == "o3"
-        assert meta.columns == "DEBW107"
-
-    def test_download_empty(self):
-        with pytest.raises(EmptyQueryResult) as e:
-            download_join("DEBW107", {"o3": "dma8eu"}, "traffic")
-        assert e.value.args[-1] == "No data found for variables {'o3'} and options station=['DEBW107'], type=traffic," \
-                                   " network=None, origin={} in JOIN."
-
-    def test_download_incomplete(self):
-        with pytest.raises(EmptyQueryResult) as e:
-            download_join("DEBW107", {"o3": "dma8eu", "o10": "maximum"}, "background")
-        assert e.value.args[-1] == "No data found for variables {'o10'} and options station=['DEBW107'], " \
-                                   "type=background, network=None, origin={} in JOIN."
-        with pytest.raises(EmptyQueryResult) as e:
-            download_join("DEBW107", {"o3": "dma8eu", "o10": "maximum"}, "background", data_origin={"o10": ""})
-        assert e.value.args[-1] == "No data found for variables {'o10'} and options station=['DEBW107'], " \
-                                   "type=background, network=None, origin={'o10': ''} in JOIN."
-
-
-class TestCorrectDataFormat:
-
-    def test_correct_data_format(self):
-        list_data = [["2020-01-01 06:00:01", 23.], ["2020-01-01 06:00:11", 24.], ["2020-01-01 06:00:21", 25.],
-                     ["2020-01-01 06:00:31", 26.], ["2020-01-01 06:00:41", 27.], ["2020-01-01 06:00:51", 23.],
-                     {"station": "test_station_001", "author": "ME", "success": True}]
-        dict_data = correct_data_format(list_data)
-        assert dict_data == {"datetime": ["2020-01-01 06:00:01", "2020-01-01 06:00:11", "2020-01-01 06:00:21",
-                                          "2020-01-01 06:00:31", "2020-01-01 06:00:41", "2020-01-01 06:00:51"],
-                             "values": [23., 24., 25., 26., 27., 23.],
-                             "metadata": {"station": "test_station_001", "author": "ME", "success": True}}
-
-
-class TestGetData:
-
-    def test(self):
-        opts = {"base": join_settings()[0], "service": "series", "station_id": 'DEBW107', "network_name": "UBA",
-                "parameter_name": "o3,no2"}
-        assert get_data(opts, headers={}) == [[17057, 'UBA', 'DEBW107', 'O3'], [17058, 'UBA', 'DEBW107', 'NO2']]
-
-
-class TestLoadSeriesInformation:
-
-    def test_standard_query(self):
-        expected_subset = {'o3': 23031, 'no2': 39002, 'temp': 85584, 'wspeed': 17060}
-        res, orig = load_series_information(['DEBW107'], None, None, join_settings()[0], {})
-        assert expected_subset.items() <= res.items()
-
-    def test_empty_result(self):
-        res, orig = load_series_information(['DEBW107'], "traffic", None, join_settings()[0], {})
-        assert res == {}
-
-
-class TestSelectDistinctSeries:
-
-    @pytest.fixture
-    def vars(self):
-        return [{'id': 16686, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'no2',
-                 'parameter_label': 'NO2', 'parameter_attribute': ''},
-                {'id': 16687, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'o3',
-                 'parameter_label': 'O3',
-                 'parameter_attribute': ''},
-                {'id': 16692, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'press',
-                 'parameter_label': 'PRESS--LANUV', 'parameter_attribute': ''},
-                {'id': 16693, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'temp',
-                 'parameter_label': 'TEMP--LANUV', 'parameter_attribute': ''},
-                {'id': 54036, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'cloudcover',
-                 'parameter_label': 'CLOUDCOVER', 'parameter_attribute': 'REA'},
-                {'id': 88491, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'temp',
-                 'parameter_label': 'TEMP-REA-MIUB', 'parameter_attribute': 'REA'},
-                {'id': 102660, 'network_name': 'UBA', 'station_id': 'DENW053', 'parameter_name': 'press',
-                 'parameter_label': 'PRESS-REA-MIUB', 'parameter_attribute': 'REA'}]
-
-    def test_no_origin_given(self, vars):
-        res, orig = _select_distinct_series(vars)
-        assert res == {"no2": 16686, "o3": 16687, "cloudcover": 54036, "temp": 88491, "press": 102660}
-        assert orig == {"no2": "", "o3": "", "cloudcover": "REA", "temp": "REA", "press": "REA"}
-
-    def test_different_origins(self, vars):
-        origin = {"no2": "test", "temp": "", "cloudcover": "REA"}
-        res, orig = _select_distinct_series(vars, data_origin=origin)
-        assert res == {"o3": 16687, "press": 102660, "temp": 16693, "cloudcover": 54036}
-        assert orig == {"no2": "test", "o3": "", "cloudcover": "REA", "temp": "", "press": "REA"}
-        res, orig = _select_distinct_series(vars, data_origin={})
-        assert res == {"cloudcover": 54036, "no2": 16686, "o3": 16687, "press": 102660, "temp": 88491}
-        assert orig == {"no2": "", "o3": "", "temp": "REA", "press": "REA", "cloudcover": "REA"}
-
-
-class TestSaveToPandas:
-
-    @staticmethod
-    def convert_date(date):
-        return map(lambda s: dt.datetime.strptime(s, "%Y-%m-%d %H:%M"), date)
-
-    @pytest.fixture
-    def date(self):
-        return ['1997-01-01 00:00', '1997-01-02 00:00', '1997-01-03 00:00', '1997-01-04 00:00']
-
-    @pytest.fixture
-    def date_len19(self):
-        return ['1997-01-01 00:00:00', '1997-01-02 00:00:00', '1997-01-03 00:00:00', '1997-01-04 00:00:00']
-
-    @pytest.fixture
-    def values(self):
-        return [86.21, 94.76, 76.96, 99.89]
-
-    @pytest.fixture
-    def alternative_values(self):
-        return [20.0, 25.2, 25.1, 23.6]
-
-    @pytest.fixture
-    def create_df(self, date, values):
-        return pd.DataFrame(values, index=self.convert_date(date), columns=['cloudcover'])
-
-    def test_empty_df(self, date, values, create_df):
-        data = {'datetime': date, 'mean': values, 'metadata': None}
-        assert pd.testing.assert_frame_equal(create_df, _save_to_pandas(None, data, 'mean', 'cloudcover')) is None
-
-    def test_not_empty_df(self, date, alternative_values, create_df):
-        data = {'datetime': date, 'max': alternative_values, 'metadata': None}
-        next_df = pd.DataFrame(data["max"], index=self.convert_date(date), columns=['temperature'])
-        df_concat = pd.concat([create_df, next_df], axis=1)
-        assert pd.testing.assert_frame_equal(df_concat, _save_to_pandas(create_df, data, 'max', 'temperature')) is None
-
-    def test_alternative_date_format(self, date_len19, values, create_df):
-        data = {'datetime': date_len19, 'mean': values, 'metadata': None}
-        assert pd.testing.assert_frame_equal(create_df, _save_to_pandas(None, data, 'mean', 'cloudcover')) is None
-
-
-class TestCorrectStatName:
-
-    def test_nothing_to_do(self):
-        assert _correct_stat_name("dma8eu") == "dma8eu"
-        assert _correct_stat_name("max") == "max"
-
-    def test_correct_string(self):
-        assert _correct_stat_name("maximum") == "max"
-        assert _correct_stat_name("minimum") == "min"
-        assert _correct_stat_name("average_values") == "mean"
-
-
-class TestLowerList:
-
-    def test_string_lowering(self):
-        list_iterator = _lower_list(["Capitalised", "already_small", "UPPERCASE", "veRyStRaNGe"])
-        assert isinstance(list_iterator, Iterable)
-        assert list(list_iterator) == ["capitalised", "already_small", "uppercase", "verystrange"]
-
-
-class TestCreateUrl:
-
-    def test_minimal_args_given(self):
-        url = create_url("www.base.edu", "testingservice")
-        assert url == "www.base.edu/testingservice/?"
-
-    def test_given_kwargs(self):
-        url = create_url("www.base2.edu/", "testingservice", mood="happy", confidence=0.98)
-        assert url == "www.base2.edu/testingservice/?mood=happy&confidence=0.98"
-
-    def test_single_kwargs(self):
-        url = create_url("www.base2.edu/", "testingservice", mood="undefined")
-        assert url == "www.base2.edu/testingservice/?mood=undefined"
-
-    def test_none_kwargs(self):
-        url = create_url("www.base2.edu/", "testingservice", mood="sad", happiness=None, stress_factor=100)
-        assert url == "www.base2.edu/testingservice/?mood=sad&stress_factor=100"
diff --git a/test/test_model_modules/test_abstract_model_class.py b/test/test_model_modules/test_abstract_model_class.py
index a1ec4c63a2b3b44c26bbf722a3d4d84aec112bec..2a1578aa28c061fce40be2e3f2f2a29306663463 100644
--- a/test/test_model_modules/test_abstract_model_class.py
+++ b/test/test_model_modules/test_abstract_model_class.py
@@ -147,16 +147,16 @@ class TestAbstractModelClass:
         with pytest.raises(ValueError) as einfo:
             amc.compile_options = {"optimizer": keras.optimizers.Adam()}
         assert "Got different values or arguments for same argument: self.optimizer=<class" \
-               " 'tensorflow.python.keras.optimizer_v2.gradient_descent.SGD'> and " \
-               "'optimizer': <class 'tensorflow.python.keras.optimizer_v2.adam.Adam'>" in str(einfo.value)
+               " 'keras.optimizer_v2.gradient_descent.SGD'> and " \
+               "'optimizer': <class 'keras.optimizer_v2.adam.Adam'>" in str(einfo.value)
 
     def test_compile_options_setter_as_mix_attr_dict_invalid_duplicates_same_optimizer_other_args(self, amc):
         amc.optimizer = keras.optimizers.SGD(lr=0.1)
         with pytest.raises(ValueError) as einfo:
             amc.compile_options = {"optimizer": keras.optimizers.SGD(lr=0.001)}
         assert "Got different values or arguments for same argument: self.optimizer=<class" \
-               " 'tensorflow.python.keras.optimizer_v2.gradient_descent.SGD'> and " \
-               "'optimizer': <class 'tensorflow.python.keras.optimizer_v2.gradient_descent.SGD'>" in str(einfo.value)
+               " 'keras.optimizer_v2.gradient_descent.SGD'> and " \
+               "'optimizer': <class 'keras.optimizer_v2.gradient_descent.SGD'>" in str(einfo.value)
 
     def test_compile_options_setter_as_dict_invalid_keys(self, amc):
         with pytest.raises(ValueError) as einfo:
diff --git a/test/test_model_modules/test_flatten_tail.py b/test/test_model_modules/test_flatten_tail.py
index 83861be561fbe164d09048f1b748b51977b2fc27..b53e381ea1cecc1d6dfbe019264c726f11946479 100644
--- a/test/test_model_modules/test_flatten_tail.py
+++ b/test/test_model_modules/test_flatten_tail.py
@@ -27,7 +27,7 @@ class TestGetActivation:
     def test_layer_act(self, model_input):
         x_in = get_activation(model_input, activation=ELU, name='adv_layer')
         act = x_in._keras_history[0]
-        assert act.name == 'adv_layer'
+        assert act.name == 'tf.nn.elu'
 
     def test_layer_act_invalid(self, model_input):
         with pytest.raises(TypeError) as einfo:
@@ -62,8 +62,8 @@ class TestFlattenTail:
         assert final_dense.units == 2
         assert final_dense.kernel_regularizer is None
         inner_act = self.step_in(final_dense)
-        assert inner_act.name == 'Main_tail_act'
-        assert inner_act.__class__.__name__ == 'ELU'
+        assert inner_act.name == 'tf.nn.elu_1'
+        assert inner_act.__class__.__name__ == 'TFOpLambda'
         inner_dense = self.step_in(inner_act)
         assert inner_dense.name == 'Main_tail_inner_Dense'
         assert inner_dense.units == 64
@@ -112,9 +112,8 @@ class TestFlattenTail:
                                           'dtype': 'float32', 'data_format': 'channels_last'}
         reduc_act = self.step_in(flatten)
-        assert reduc_act.get_config() == {'name': 'Main_tail_all_conv_act', 'trainable': True,
-                                          'dtype': 'float32', 'alpha': 1.0}
-
+        assert reduc_act.get_config() == {'name': 'tf.nn.elu_2', 'trainable': True, 'function': 'nn.elu',
+                                          'dtype': 'float32'}
         reduc_conv = self.step_in(reduc_act)
 
         assert reduc_conv.kernel_size == (1, 1)
diff --git a/test/test_model_modules/test_inception_model.py b/test/test_model_modules/test_inception_model.py
index 0ed975d054841d9d4cfb8b4c964fa0cd2d4e2667..0a0dd38fa9d354c1243127df1ae9e079a6ca88e9 100644
--- a/test/test_model_modules/test_inception_model.py
+++ b/test/test_model_modules/test_inception_model.py
@@ -43,7 +43,7 @@ class TestInceptionModelBase:
         assert base.part_of_block == 1
         assert tower.name == 'Block_0a_act_2/Relu:0'
         act_layer = tower._keras_history[0]
-        assert isinstance(act_layer, ReLU)
+        assert isinstance(act_layer, keras.layers.ReLU)
         assert act_layer.name == "Block_0a_act_2"
         # check previous element of tower (conv2D)
         conv_layer = self.step_in(act_layer)
@@ -60,7 +60,7 @@ class TestInceptionModelBase:
         assert pad_layer.name == 'Block_0a_Pad'
         # check previous element of tower (activation)
         act_layer2 = self.step_in(pad_layer)
-        assert isinstance(act_layer2, ReLU)
+        assert isinstance(act_layer2, keras.layers.ReLU)
         assert act_layer2.name == "Block_0a_act_1"
         # check previous element of tower (conv2D)
         conv_layer2 = self.step_in(act_layer2)
@@ -80,7 +80,7 @@ class TestInceptionModelBase:
         # assert tower.name == 'Block_0a_act_2/Relu:0'
         assert tower.name == 'Block_0a_act_2/Relu:0'
         act_layer = tower._keras_history[0]
-        assert isinstance(act_layer, ReLU)
+        assert isinstance(act_layer, keras.layers.ReLU)
         assert act_layer.name == "Block_0a_act_2"
         # check previous element of tower (batch_normal)
         batch_layer = self.step_in(act_layer)
@@ -101,7 +101,7 @@ class TestInceptionModelBase:
         assert pad_layer.name == 'Block_0a_Pad'
         # check previous element of tower (activation)
         act_layer2 = self.step_in(pad_layer)
-        assert isinstance(act_layer2, ReLU)
+        assert isinstance(act_layer2, keras.layers.ReLU)
         assert act_layer2.name == "Block_0a_act_1"
         # check previous element of tower (conv2D)
         conv_layer2 = self.step_in(act_layer2)
@@ -124,7 +124,7 @@ class TestInceptionModelBase:
         tower = base.create_conv_tower(activation=keras.layers.LeakyReLU, **opts)
         assert tower.name == 'Block_0b_act_2/LeakyRelu:0'
         act_layer = tower._keras_history[0]
-        assert isinstance(act_layer, LeakyReLU)
+        assert isinstance(act_layer, keras.layers.LeakyReLU)
         assert act_layer.name == "Block_0b_act_2"
 
     def test_create_conv_tower_1x1(self, base, input_x):
@@ -134,7 +134,7 @@ class TestInceptionModelBase:
         assert base.part_of_block == 1
         assert tower.name == 'Block_0a_act_1/Relu:0'
         act_layer = tower._keras_history[0]
-        assert isinstance(act_layer, ReLU)
+        assert isinstance(act_layer, keras.layers.ReLU)
         assert act_layer.name == "Block_0a_act_1"
         # check previous element of tower (conv2D)
         conv_layer = self.step_in(act_layer)
@@ -160,7 +160,7 @@ class TestInceptionModelBase:
         assert base.part_of_block == 1
         assert tower.name == 'Block_0a_act_1/Relu:0'
         act_layer = tower._keras_history[0]
-        assert isinstance(act_layer, ReLU)
+        assert isinstance(act_layer, keras.layers.ReLU)
         assert act_layer.name == "Block_0a_act_1"
         # check previous element of tower (conv2D)
         conv_layer = self.step_in(act_layer)
diff --git a/test/test_run_modules/test_pre_processing.py b/test/test_run_modules/test_pre_processing.py
index 1dafdbd5c4882932e3d57e726e7a06bea22a745d..4618a5e4f3f5eaf2a419e68a5a0e18156aa7fb0d 100644
--- a/test/test_run_modules/test_pre_processing.py
+++ b/test/test_run_modules/test_pre_processing.py
@@ -30,7 +30,7 @@ class TestPreProcessing:
 
     @pytest.fixture
     def obj_with_exp_setup(self):
-        ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
+        ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW99X'],
                         statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background",
                         data_handler=DefaultDataHandler)
         pre = object.__new__(PreProcessing)
@@ -87,7 +87,7 @@ class TestPreProcessing:
     def test_create_set_split_all_stations(self, caplog, obj_with_exp_setup):
         caplog.set_level(logging.DEBUG)
         obj_with_exp_setup.create_set_split(slice(0, 2), "awesome")
-        message = "Awesome stations (len=6): ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001']"
+        message = "Awesome stations (len=6): ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW99X']"
         assert ('root', 10, message) in caplog.record_tuples
         data_store = obj_with_exp_setup.data_store
         assert isinstance(data_store.get("data_collection", "general.awesome"), DataCollection)
diff --git a/test/test_workflows/test_default_workflow.py b/test/test_workflows/test_default_workflow.py
index c7c198a4821f779329b9f5f19b04e757d8ebc7da..790fb5f5de2fef207c64fdc430028f0739eb20fa 100644
--- a/test/test_workflows/test_default_workflow.py
+++ b/test/test_workflows/test_default_workflow.py
@@ -23,7 +23,7 @@ class TestDefaultWorkflow:
         flow = DefaultWorkflow(stations="test", real_kwarg=4)
         assert flow._registry[0].__name__ == ExperimentSetup.__name__
         assert len(flow._registry_kwargs[0].keys()) == 3
-        assert list(flow._registry_kwargs[0].keys()) == ["experiment_date", "stations", "real_kwarg"]
+        assert sorted(list(flow._registry_kwargs[0].keys())) == ["experiment_date", "real_kwarg", "stations"]
 
     def test_setup(self):
         flow = DefaultWorkflow()